server,app: migrate to sea-orm

This commit is contained in:
Valentin Tolmer
2022-11-21 09:13:25 +01:00
committed by nitnelave
parent a3a27f0049
commit e89b1538af
40 changed files with 2125 additions and 1390 deletions

View File

@@ -26,7 +26,7 @@ use crate::domain::handler::UserRequestFilter;
use crate::{
domain::{
error::DomainError,
handler::{BackendHandler, BindRequest, GroupDetails, LoginHandler, UserId},
handler::{BackendHandler, BindRequest, GroupDetails, LoginHandler, UserColumn, UserId},
opaque_handler::OpaqueHandler,
},
infra::{
@@ -149,10 +149,7 @@ where
.list_users(
Some(UserRequestFilter::Or(vec![
UserRequestFilter::UserId(UserId::new(user_string)),
UserRequestFilter::Equality(
crate::domain::sql_tables::UserColumn::Email,
user_string.to_owned(),
),
UserRequestFilter::Equality(UserColumn::Email, user_string.to_owned()),
])),
false,
)
@@ -174,7 +171,9 @@ where
Some(token) => token,
};
if let Err(e) = super::mail::send_password_reset_email(
&user.display_name,
user.display_name
.as_deref()
.unwrap_or_else(|| user.user_id.as_str()),
&user.email,
&token,
&data.server_url,

View File

@@ -1,18 +1,17 @@
use crate::{
domain::sql_tables::{DbQueryBuilder, Pool},
infra::jwt_sql_tables::{JwtRefreshStorage, JwtStorage},
use crate::domain::{
model::{self, JwtRefreshStorageColumn, JwtStorageColumn, PasswordResetTokensColumn},
sql_tables::DbConnection,
};
use actix::prelude::*;
use chrono::Local;
use actix::prelude::{Actor, AsyncContext, Context};
use cron::Schedule;
use sea_query::{Expr, Query};
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
use std::{str::FromStr, time::Duration};
use tracing::{debug, error, info, instrument};
use tracing::{error, info, instrument};
// Define actor
pub struct Scheduler {
schedule: Schedule,
sql_pool: Pool,
sql_pool: DbConnection,
}
// Provide Actor implementation for our actor
@@ -33,7 +32,7 @@ impl Actor for Scheduler {
}
impl Scheduler {
pub fn new(cron_expression: &str, sql_pool: Pool) -> Self {
pub fn new(cron_expression: &str, sql_pool: DbConnection) -> Self {
let schedule = Schedule::from_str(cron_expression).unwrap();
Self { schedule, sql_pool }
}
@@ -48,33 +47,35 @@ impl Scheduler {
}
#[instrument(skip_all)]
async fn cleanup_db(sql_pool: Pool) {
async fn cleanup_db(sql_pool: DbConnection) {
info!("Cleaning DB");
let query = Query::delete()
.from_table(JwtRefreshStorage::Table)
.and_where(Expr::col(JwtRefreshStorage::ExpiryDate).lt(Local::now().naive_utc()))
.to_string(DbQueryBuilder {});
debug!(%query);
if let Err(e) = sqlx::query(&query).execute(&sql_pool).await {
if let Err(e) = model::JwtRefreshStorage::delete_many()
.filter(JwtRefreshStorageColumn::ExpiryDate.lt(chrono::Utc::now().naive_utc()))
.exec(&sql_pool)
.await
{
error!("DB error while cleaning up JWT refresh tokens: {}", e);
};
if let Err(e) = sqlx::query(
&Query::delete()
.from_table(JwtStorage::Table)
.and_where(Expr::col(JwtStorage::ExpiryDate).lt(Local::now().naive_utc()))
.to_string(DbQueryBuilder {}),
)
.execute(&sql_pool)
.await
}
if let Err(e) = model::JwtStorage::delete_many()
.filter(JwtStorageColumn::ExpiryDate.lt(chrono::Utc::now().naive_utc()))
.exec(&sql_pool)
.await
{
error!("DB error while cleaning up JWT storage: {}", e);
};
if let Err(e) = model::PasswordResetTokens::delete_many()
.filter(PasswordResetTokensColumn::ExpiryDate.lt(chrono::Utc::now().naive_utc()))
.exec(&sql_pool)
.await
{
error!("DB error while cleaning up password reset tokens: {}", e);
};
info!("DB cleaned!");
}
fn duration_until_next(&self) -> Duration {
let now = Local::now();
let next = self.schedule.upcoming(Local).next().unwrap();
let now = chrono::Utc::now();
let next = self.schedule.upcoming(chrono::Utc).next().unwrap();
let duration_until = next.signed_duration_since(now);
duration_until.to_std().unwrap()
}

View File

@@ -1,7 +1,6 @@
use crate::domain::{
handler::{BackendHandler, GroupDetails, GroupId, UserId},
handler::{BackendHandler, GroupDetails, GroupId, UserColumn, UserId},
ldap::utils::map_user_field,
sql_tables::UserColumn,
};
use juniper::{graphql_object, FieldResult, GraphQLInputObject};
use serde::{Deserialize, Serialize};
@@ -214,19 +213,19 @@ impl<Handler: BackendHandler + Sync> User<Handler> {
}
fn display_name(&self) -> &str {
&self.user.display_name
self.user.display_name.as_deref().unwrap_or("")
}
fn first_name(&self) -> &str {
&self.user.first_name
self.user.first_name.as_deref().unwrap_or("")
}
fn last_name(&self) -> &str {
&self.user.last_name
self.user.last_name.as_deref().unwrap_or("")
}
fn avatar(&self) -> String {
(&self.user.avatar).into()
fn avatar(&self) -> Option<String> {
self.user.avatar.as_ref().map(String::from)
}
fn creation_date(&self) -> chrono::DateTime<chrono::Utc> {
@@ -392,7 +391,7 @@ mod tests {
Ok(DomainUser {
user_id: UserId::new("bob"),
email: "bob@bobbers.on".to_string(),
creation_date: chrono::Utc.timestamp_millis(42),
creation_date: chrono::Utc.timestamp_millis_opt(42).unwrap(),
uuid: crate::uuid!("b1a2a3a4b1b2c1c2d1d2d3d4d5d6d7d8"),
..Default::default()
})

View File

@@ -1,6 +1,7 @@
use sea_query::*;
use sea_orm::ConnectionTrait;
use sea_query::{ColumnDef, ForeignKey, ForeignKeyAction, Iden, Table};
pub use crate::domain::sql_tables::*;
pub use crate::domain::{sql_migrations::Users, sql_tables::DbConnection};
/// Contains the refresh tokens for a given user.
#[derive(Iden)]
@@ -31,110 +32,112 @@ pub enum PasswordResetTokens {
}
/// This needs to be initialized after the domain tables are.
pub async fn init_table(pool: &Pool) -> sqlx::Result<()> {
sqlx::query(
&Table::create()
.table(JwtRefreshStorage::Table)
.if_not_exists()
.col(
ColumnDef::new(JwtRefreshStorage::RefreshTokenHash)
.big_integer()
.not_null()
.primary_key(),
)
.col(
ColumnDef::new(JwtRefreshStorage::UserId)
.string_len(255)
.not_null(),
)
.col(
ColumnDef::new(JwtRefreshStorage::ExpiryDate)
.date_time()
.not_null(),
)
.foreign_key(
ForeignKey::create()
.name("JwtRefreshStorageUserForeignKey")
.from(JwtRefreshStorage::Table, JwtRefreshStorage::UserId)
.to(Users::Table, Users::UserId)
.on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade),
)
.to_string(DbQueryBuilder {}),
pub async fn init_table(pool: &DbConnection) -> std::result::Result<(), sea_orm::DbErr> {
let builder = pool.get_database_backend();
pool.execute(
builder.build(
Table::create()
.table(JwtRefreshStorage::Table)
.if_not_exists()
.col(
ColumnDef::new(JwtRefreshStorage::RefreshTokenHash)
.big_integer()
.not_null()
.primary_key(),
)
.col(
ColumnDef::new(JwtRefreshStorage::UserId)
.string_len(255)
.not_null(),
)
.col(
ColumnDef::new(JwtRefreshStorage::ExpiryDate)
.date_time()
.not_null(),
)
.foreign_key(
ForeignKey::create()
.name("JwtRefreshStorageUserForeignKey")
.from(JwtRefreshStorage::Table, JwtRefreshStorage::UserId)
.to(Users::Table, Users::UserId)
.on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade),
),
),
)
.execute(pool)
.await?;
sqlx::query(
&Table::create()
.table(JwtStorage::Table)
.if_not_exists()
.col(
ColumnDef::new(JwtStorage::JwtHash)
.big_integer()
.not_null()
.primary_key(),
)
.col(
ColumnDef::new(JwtStorage::UserId)
.string_len(255)
.not_null(),
)
.col(
ColumnDef::new(JwtStorage::ExpiryDate)
.date_time()
.not_null(),
)
.col(
ColumnDef::new(JwtStorage::Blacklisted)
.boolean()
.default(false)
.not_null(),
)
.foreign_key(
ForeignKey::create()
.name("JwtStorageUserForeignKey")
.from(JwtStorage::Table, JwtStorage::UserId)
.to(Users::Table, Users::UserId)
.on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade),
)
.to_string(DbQueryBuilder {}),
pool.execute(
builder.build(
Table::create()
.table(JwtStorage::Table)
.if_not_exists()
.col(
ColumnDef::new(JwtStorage::JwtHash)
.big_integer()
.not_null()
.primary_key(),
)
.col(
ColumnDef::new(JwtStorage::UserId)
.string_len(255)
.not_null(),
)
.col(
ColumnDef::new(JwtStorage::ExpiryDate)
.date_time()
.not_null(),
)
.col(
ColumnDef::new(JwtStorage::Blacklisted)
.boolean()
.default(false)
.not_null(),
)
.foreign_key(
ForeignKey::create()
.name("JwtStorageUserForeignKey")
.from(JwtStorage::Table, JwtStorage::UserId)
.to(Users::Table, Users::UserId)
.on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade),
),
),
)
.execute(pool)
.await?;
sqlx::query(
&Table::create()
.table(PasswordResetTokens::Table)
.if_not_exists()
.col(
ColumnDef::new(PasswordResetTokens::Token)
.string_len(255)
.not_null()
.primary_key(),
)
.col(
ColumnDef::new(PasswordResetTokens::UserId)
.string_len(255)
.not_null(),
)
.col(
ColumnDef::new(PasswordResetTokens::ExpiryDate)
.date_time()
.not_null(),
)
.foreign_key(
ForeignKey::create()
.name("PasswordResetTokensUserForeignKey")
.from(PasswordResetTokens::Table, PasswordResetTokens::UserId)
.to(Users::Table, Users::UserId)
.on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade),
)
.to_string(DbQueryBuilder {}),
pool.execute(
builder.build(
Table::create()
.table(PasswordResetTokens::Table)
.if_not_exists()
.col(
ColumnDef::new(PasswordResetTokens::Token)
.string_len(255)
.not_null()
.primary_key(),
)
.col(
ColumnDef::new(PasswordResetTokens::UserId)
.string_len(255)
.not_null(),
)
.col(
ColumnDef::new(PasswordResetTokens::ExpiryDate)
.date_time()
.not_null(),
)
.foreign_key(
ForeignKey::create()
.name("PasswordResetTokensUserForeignKey")
.from(PasswordResetTokens::Table, PasswordResetTokens::UserId)
.to(Users::Table, Users::UserId)
.on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade),
),
),
)
.execute(pool)
.await?;
Ok(())

View File

@@ -569,7 +569,7 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
mod tests {
use super::*;
use crate::{
domain::{error::Result, handler::*, opaque_handler::*, sql_tables::UserColumn},
domain::{error::Result, handler::*, opaque_handler::*},
uuid,
};
use async_trait::async_trait;
@@ -669,7 +669,7 @@ mod tests {
set.insert(GroupDetails {
group_id: GroupId(42),
display_name: group,
creation_date: chrono::Utc.timestamp(42, 42),
creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
uuid: uuid!("a1a2a3a4b1b2c1c2d1d2d3d4d5d6d7d8"),
});
Ok(set)
@@ -756,7 +756,7 @@ mod tests {
set.insert(GroupDetails {
group_id: GroupId(42),
display_name: "lldap_admin".to_string(),
creation_date: chrono::Utc.timestamp(42, 42),
creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
uuid: uuid!("a1a2a3a4b1b2c1c2d1d2d3d4d5d6d7d8"),
});
Ok(set)
@@ -843,7 +843,7 @@ mod tests {
groups: Some(vec![GroupDetails {
group_id: GroupId(42),
display_name: "rockstars".to_string(),
creation_date: chrono::Utc.timestamp(42, 42),
creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
uuid: uuid!("a1a2a3a4b1b2c1c2d1d2d3d4d5d6d7d8"),
}]),
}])
@@ -991,9 +991,9 @@ mod tests {
user: User {
user_id: UserId::new("bob_1"),
email: "bob@bobmail.bob".to_string(),
display_name: "Bôb Böbberson".to_string(),
first_name: "Bôb".to_string(),
last_name: "Böbberson".to_string(),
display_name: Some("Bôb Böbberson".to_string()),
first_name: Some("Bôb".to_string()),
last_name: Some("Böbberson".to_string()),
uuid: uuid!("698e1d5f-7a40-3151-8745-b9b8a37839da"),
..Default::default()
},
@@ -1003,12 +1003,12 @@ mod tests {
user: User {
user_id: UserId::new("jim"),
email: "jim@cricket.jim".to_string(),
display_name: "Jimminy Cricket".to_string(),
first_name: "Jim".to_string(),
last_name: "Cricket".to_string(),
avatar: JpegPhoto::for_tests(),
display_name: Some("Jimminy Cricket".to_string()),
first_name: Some("Jim".to_string()),
last_name: Some("Cricket".to_string()),
avatar: Some(JpegPhoto::for_tests()),
uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"),
creation_date: Utc.ymd(2014, 7, 8).and_hms(9, 10, 11),
creation_date: Utc.with_ymd_and_hms(2014, 7, 8, 9, 10, 11).unwrap(),
},
groups: None,
},
@@ -1137,14 +1137,14 @@ mod tests {
Group {
id: GroupId(1),
display_name: "group_1".to_string(),
creation_date: chrono::Utc.timestamp(42, 42),
creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
users: vec![UserId::new("bob"), UserId::new("john")],
uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"),
},
Group {
id: GroupId(3),
display_name: "BestGroup".to_string(),
creation_date: chrono::Utc.timestamp(42, 42),
creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
users: vec![UserId::new("john")],
uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"),
},
@@ -1230,7 +1230,7 @@ mod tests {
Ok(vec![Group {
display_name: "group_1".to_string(),
id: GroupId(1),
creation_date: chrono::Utc.timestamp(42, 42),
creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
users: vec![],
uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"),
}])
@@ -1281,7 +1281,7 @@ mod tests {
Ok(vec![Group {
display_name: "group_1".to_string(),
id: GroupId(1),
creation_date: chrono::Utc.timestamp(42, 42),
creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
users: vec![],
uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"),
}])
@@ -1542,9 +1542,9 @@ mod tests {
user: User {
user_id: UserId::new("bob_1"),
email: "bob@bobmail.bob".to_string(),
display_name: "Bôb Böbberson".to_string(),
first_name: "Bôb".to_string(),
last_name: "Böbberson".to_string(),
display_name: Some("Bôb Böbberson".to_string()),
first_name: Some("Bôb".to_string()),
last_name: Some("Böbberson".to_string()),
..Default::default()
},
groups: None,
@@ -1557,7 +1557,7 @@ mod tests {
Ok(vec![Group {
id: GroupId(1),
display_name: "group_1".to_string(),
creation_date: chrono::Utc.timestamp(42, 42),
creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
users: vec![UserId::new("bob"), UserId::new("john")],
uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"),
}])
@@ -1616,9 +1616,9 @@ mod tests {
user: User {
user_id: UserId::new("bob_1"),
email: "bob@bobmail.bob".to_string(),
display_name: "Bôb Böbberson".to_string(),
last_name: "Böbberson".to_string(),
avatar: JpegPhoto::for_tests(),
display_name: Some("Bôb Böbberson".to_string()),
last_name: Some("Böbberson".to_string()),
avatar: Some(JpegPhoto::for_tests()),
uuid: uuid!("b4ac75e0-2900-3e21-926c-2f732c26b3fc"),
..Default::default()
},
@@ -1631,7 +1631,7 @@ mod tests {
Ok(vec![Group {
id: GroupId(1),
display_name: "group_1".to_string(),
creation_date: chrono::Utc.timestamp(42, 42),
creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
users: vec![UserId::new("bob"), UserId::new("john")],
uuid: uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc"),
}])
@@ -1680,7 +1680,11 @@ mod tests {
},
LdapPartialAttribute {
atype: "createtimestamp".to_string(),
vals: vec![chrono::Utc.timestamp(0, 0).to_rfc3339().into_bytes()],
vals: vec![chrono::Utc
.timestamp_opt(0, 0)
.unwrap()
.to_rfc3339()
.into_bytes()],
},
LdapPartialAttribute {
atype: "entryuuid".to_string(),
@@ -1960,7 +1964,7 @@ mod tests {
groups.insert(GroupDetails {
group_id: GroupId(0),
display_name: "lldap_admin".to_string(),
creation_date: chrono::Utc.timestamp(42, 42),
creation_date: chrono::Utc.timestamp_opt(42, 42).unwrap(),
uuid: uuid!("a1a2a3a4b1b2c1c2d1d2d3d4d5d6d7d8"),
});
mock.expect_get_user_groups()

View File

@@ -48,3 +48,14 @@ pub fn init(config: &Configuration) -> anyhow::Result<()> {
.init();
Ok(())
}
#[cfg(test)]
pub fn init_for_tests() {
if let Err(e) = tracing_subscriber::FmtSubscriber::builder()
.with_max_level(tracing::Level::DEBUG)
.with_test_writer()
.try_init()
{
log::warn!("Could not set up test logging: {:#}", e);
}
}

View File

@@ -1,10 +1,16 @@
use super::{jwt_sql_tables::*, tcp_backend_handler::*};
use crate::domain::{error::*, handler::UserId, sql_backend_handler::SqlBackendHandler};
use super::tcp_backend_handler::TcpBackendHandler;
use crate::domain::{
error::*,
handler::UserId,
model::{self, JwtRefreshStorageColumn, JwtStorageColumn, PasswordResetTokensColumn},
sql_backend_handler::SqlBackendHandler,
};
use async_trait::async_trait;
use futures_util::StreamExt;
use sea_query::{Expr, Iden, Query, SimpleExpr};
use sea_query_binder::SqlxBinder;
use sqlx::{query_as_with, query_with, Row};
use sea_orm::{
sea_query::Cond, ActiveModelTrait, ColumnTrait, EntityTrait, FromQueryResult, IntoActiveModel,
QueryFilter, QuerySelect,
};
use sea_query::Expr;
use std::collections::HashSet;
use tracing::{debug, instrument};
@@ -18,126 +24,102 @@ fn gen_random_string(len: usize) -> String {
.collect()
}
#[derive(FromQueryResult)]
struct OnlyJwtHash {
jwt_hash: i64,
}
#[async_trait]
impl TcpBackendHandler for SqlBackendHandler {
#[instrument(skip_all, level = "debug")]
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>> {
let (query, values) = Query::select()
.column(JwtStorage::JwtHash)
.from(JwtStorage::Table)
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(&query, values)
.map(|row: DbRow| row.get::<i64, _>(&*JwtStorage::JwtHash.to_string()) as u64)
.fetch(&self.sql_pool)
.collect::<Vec<sqlx::Result<u64>>>()
.await
Ok(model::JwtStorage::find()
.select_only()
.column(JwtStorageColumn::JwtHash)
.filter(JwtStorageColumn::Blacklisted.eq(true))
.into_model::<OnlyJwtHash>()
.all(&self.sql_pool)
.await?
.into_iter()
.collect::<sqlx::Result<HashSet<u64>>>()
.map_err(|e| anyhow::anyhow!(e))
.map(|m| m.jwt_hash as u64)
.collect::<HashSet<u64>>())
}
#[instrument(skip_all, level = "debug")]
async fn create_refresh_token(&self, user: &UserId) -> Result<(String, chrono::Duration)> {
debug!(?user);
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
// TODO: Initialize the rng only once. Maybe Arc<Cell>?
let refresh_token = gen_random_string(100);
let refresh_token_hash = {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut s = DefaultHasher::new();
refresh_token.hash(&mut s);
s.finish()
};
let duration = chrono::Duration::days(30);
let (query, values) = Query::insert()
.into_table(JwtRefreshStorage::Table)
.columns(vec![
JwtRefreshStorage::RefreshTokenHash,
JwtRefreshStorage::UserId,
JwtRefreshStorage::ExpiryDate,
])
.values_panic(vec![
(refresh_token_hash as i64).into(),
user.into(),
(chrono::Utc::now() + duration).naive_utc().into(),
])
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(&query, values).execute(&self.sql_pool).await?;
let new_token = model::jwt_refresh_storage::Model {
refresh_token_hash: refresh_token_hash as i64,
user_id: user.clone(),
expiry_date: chrono::Utc::now() + duration,
}
.into_active_model();
new_token.insert(&self.sql_pool).await?;
Ok((refresh_token, duration))
}
#[instrument(skip_all, level = "debug")]
async fn check_token(&self, refresh_token_hash: u64, user: &UserId) -> Result<bool> {
debug!(?user);
let (query, values) = Query::select()
.expr(SimpleExpr::Value(1.into()))
.from(JwtRefreshStorage::Table)
.and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash as i64))
.and_where(Expr::col(JwtRefreshStorage::UserId).eq(user))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
Ok(query_with(&query, values)
.fetch_optional(&self.sql_pool)
.await?
.is_some())
Ok(
model::JwtRefreshStorage::find_by_id(refresh_token_hash as i64)
.filter(JwtRefreshStorageColumn::UserId.eq(user))
.one(&self.sql_pool)
.await?
.is_some(),
)
}
#[instrument(skip_all, level = "debug")]
async fn blacklist_jwts(&self, user: &UserId) -> Result<HashSet<u64>> {
debug!(?user);
use sqlx::Result;
let (query, values) = Query::select()
.column(JwtStorage::JwtHash)
.from(JwtStorage::Table)
.and_where(Expr::col(JwtStorage::UserId).eq(user))
.and_where(Expr::col(JwtStorage::Blacklisted).eq(true))
.build_sqlx(DbQueryBuilder {});
let result = query_with(&query, values)
.map(|row: DbRow| row.get::<i64, _>(&*JwtStorage::JwtHash.to_string()) as u64)
.fetch(&self.sql_pool)
.collect::<Vec<sqlx::Result<u64>>>()
.await
let valid_tokens = model::JwtStorage::find()
.select_only()
.column(JwtStorageColumn::JwtHash)
.filter(
Cond::all()
.add(JwtStorageColumn::UserId.eq(user))
.add(JwtStorageColumn::Blacklisted.eq(false)),
)
.into_model::<OnlyJwtHash>()
.all(&self.sql_pool)
.await?
.into_iter()
.collect::<Result<HashSet<u64>>>();
let (query, values) = Query::update()
.table(JwtStorage::Table)
.values(vec![(JwtStorage::Blacklisted, true.into())])
.and_where(Expr::col(JwtStorage::UserId).eq(user))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(&query, values).execute(&self.sql_pool).await?;
Ok(result?)
.map(|t| t.jwt_hash as u64)
.collect::<HashSet<u64>>();
model::JwtStorage::update_many()
.col_expr(JwtStorageColumn::Blacklisted, Expr::value(true))
.filter(JwtStorageColumn::UserId.eq(user))
.exec(&self.sql_pool)
.await?;
Ok(valid_tokens)
}
#[instrument(skip_all, level = "debug")]
async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()> {
let (query, values) = Query::delete()
.from_table(JwtRefreshStorage::Table)
.and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash as i64))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(&query, values).execute(&self.sql_pool).await?;
model::JwtRefreshStorage::delete_by_id(refresh_token_hash as i64)
.exec(&self.sql_pool)
.await?;
Ok(())
}
#[instrument(skip_all, level = "debug")]
async fn start_password_reset(&self, user: &UserId) -> Result<Option<String>> {
debug!(?user);
let (query, values) = Query::select()
.column(Users::UserId)
.from(Users::Table)
.and_where(Expr::col(Users::UserId).eq(user))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
// Check that the user exists.
if query_with(&query, values)
.fetch_one(&self.sql_pool)
.await
.is_err()
if model::User::find_by_id(user.clone())
.one(&self.sql_pool)
.await?
.is_none()
{
debug!("User not found");
return Ok(None);
@@ -146,50 +128,37 @@ impl TcpBackendHandler for SqlBackendHandler {
let token = gen_random_string(100);
let duration = chrono::Duration::minutes(10);
let (query, values) = Query::insert()
.into_table(PasswordResetTokens::Table)
.columns(vec![
PasswordResetTokens::Token,
PasswordResetTokens::UserId,
PasswordResetTokens::ExpiryDate,
])
.values_panic(vec![
token.clone().into(),
user.into(),
(chrono::Utc::now() + duration).naive_utc().into(),
])
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(&query, values).execute(&self.sql_pool).await?;
let new_token = model::password_reset_tokens::Model {
token: token.clone(),
user_id: user.clone(),
expiry_date: chrono::Utc::now() + duration,
}
.into_active_model();
new_token.insert(&self.sql_pool).await?;
Ok(Some(token))
}
#[instrument(skip_all, level = "debug", ret)]
async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<UserId> {
let (query, values) = Query::select()
.column(PasswordResetTokens::UserId)
.from(PasswordResetTokens::Table)
.and_where(Expr::col(PasswordResetTokens::Token).eq(token))
.and_where(
Expr::col(PasswordResetTokens::ExpiryDate).gt(chrono::Utc::now().naive_utc()),
)
.build_sqlx(DbQueryBuilder {});
debug!(%query);
let (user_id,) = query_as_with(&query, values)
.fetch_one(&self.sql_pool)
.await?;
Ok(user_id)
Ok(model::PasswordResetTokens::find_by_id(token.to_owned())
.filter(PasswordResetTokensColumn::ExpiryDate.gt(chrono::Utc::now().naive_utc()))
.one(&self.sql_pool)
.await?
.ok_or_else(|| DomainError::EntityNotFound("Invalid reset token".to_owned()))?
.user_id)
}
#[instrument(skip_all, level = "debug")]
async fn delete_password_reset_token(&self, token: &str) -> Result<()> {
let (query, values) = Query::delete()
.from_table(PasswordResetTokens::Table)
.and_where(Expr::col(PasswordResetTokens::Token).eq(token))
.build_sqlx(DbQueryBuilder {});
debug!(%query);
query_with(&query, values).execute(&self.sql_pool).await?;
let result = model::PasswordResetTokens::delete_by_id(token.to_owned())
.exec(&self.sql_pool)
.await?;
if result.rows_affected == 0 {
return Err(DomainError::EntityNotFound(format!(
"No such password reset token: '{}'",
token
)));
}
Ok(())
}
}

View File

@@ -52,9 +52,9 @@ pub(crate) fn error_to_http_response(error: TcpError) -> HttpResponse {
DomainError::DatabaseError(_)
| DomainError::InternalError(_)
| DomainError::UnknownCryptoError(_) => HttpResponse::InternalServerError(),
DomainError::Base64DecodeError(_) | DomainError::BinarySerializationError(_) => {
HttpResponse::BadRequest()
}
DomainError::Base64DecodeError(_)
| DomainError::BinarySerializationError(_)
| DomainError::EntityNotFound(_) => HttpResponse::BadRequest(),
},
TcpError::BadRequest(_) => HttpResponse::BadRequest(),
TcpError::InternalServerError(_) => HttpResponse::InternalServerError(),