sqlx: update dependency and protect against injections
This commit is contained in:
committed by
nitnelave
parent
bafb1dc5cc
commit
5e2eea0d97
@@ -2,8 +2,9 @@ use super::{error::*, handler::*, sql_tables::*};
|
||||
use crate::infra::configuration::Configuration;
|
||||
use async_trait::async_trait;
|
||||
use futures_util::StreamExt;
|
||||
use sea_query::{Expr, Iden, Order, Query, SimpleExpr};
|
||||
use sqlx::{FromRow, Row};
|
||||
use sea_query::{Cond, Expr, Iden, Order, Query, SimpleExpr};
|
||||
use sea_query_binder::SqlxBinder;
|
||||
use sqlx::{query_as_with, query_with, FromRow, Row};
|
||||
use std::collections::HashSet;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -101,7 +102,7 @@ fn get_group_filter_expr(filter: GroupRequestFilter) -> SimpleExpr {
|
||||
Query::select()
|
||||
.column(Memberships::GroupId)
|
||||
.from(Memberships::Table)
|
||||
.and_where(Expr::col(Memberships::UserId).eq(user))
|
||||
.cond_where(Expr::col(Memberships::UserId).eq(user))
|
||||
.take(),
|
||||
),
|
||||
}
|
||||
@@ -114,7 +115,7 @@ impl BackendHandler for SqlBackendHandler {
|
||||
filters: Option<UserRequestFilter>,
|
||||
get_groups: bool,
|
||||
) -> Result<Vec<UserAndGroups>> {
|
||||
let query = {
|
||||
let (query, values) = {
|
||||
let mut query_builder = Query::select()
|
||||
.column((Users::Table, Users::UserId))
|
||||
.column(Users::Email)
|
||||
@@ -154,14 +155,24 @@ impl BackendHandler for SqlBackendHandler {
|
||||
&& filter != UserRequestFilter::Or(Vec::new())
|
||||
{
|
||||
let (RequiresGroup(requires_group), condition) = get_user_filter_expr(filter);
|
||||
query_builder.and_where(condition);
|
||||
if requires_group && !get_groups {
|
||||
add_join_group_tables(&mut query_builder);
|
||||
query_builder.cond_where(condition);
|
||||
if requires_group {
|
||||
query_builder
|
||||
.left_join(
|
||||
Memberships::Table,
|
||||
Expr::tbl(Users::Table, Users::UserId)
|
||||
.equals(Memberships::Table, Memberships::UserId),
|
||||
)
|
||||
.left_join(
|
||||
Groups::Table,
|
||||
Expr::tbl(Memberships::Table, Memberships::GroupId)
|
||||
.equals(Groups::Table, Groups::GroupId),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
query_builder.to_string(DbQueryBuilder {})
|
||||
query_builder.build_sqlx(DbQueryBuilder {})
|
||||
};
|
||||
log::error!("query: {}", &query);
|
||||
|
||||
@@ -170,7 +181,7 @@ impl BackendHandler for SqlBackendHandler {
|
||||
let mut users = Vec::new();
|
||||
// The rows are returned sorted by user_id. We group them by
|
||||
// this key which gives us one element (`rows`) per group.
|
||||
for (_, rows) in &sqlx::query(&query)
|
||||
for (_, rows) in &query_with(&query, values)
|
||||
.fetch_all(&self.sql_pool)
|
||||
.await?
|
||||
.into_iter()
|
||||
@@ -200,7 +211,7 @@ impl BackendHandler for SqlBackendHandler {
|
||||
}
|
||||
|
||||
async fn list_groups(&self, filters: Option<GroupRequestFilter>) -> Result<Vec<Group>> {
|
||||
let query: String = {
|
||||
let (query, values) = {
|
||||
let mut query_builder = Query::select()
|
||||
.column((Groups::Table, Groups::GroupId))
|
||||
.column(Groups::DisplayName)
|
||||
@@ -223,11 +234,11 @@ impl BackendHandler for SqlBackendHandler {
|
||||
if filter != GroupRequestFilter::And(Vec::new())
|
||||
&& filter != GroupRequestFilter::Or(Vec::new())
|
||||
{
|
||||
query_builder.and_where(get_group_filter_expr(filter));
|
||||
query_builder.cond_where(get_group_filter_expr(filter));
|
||||
}
|
||||
}
|
||||
|
||||
query_builder.to_string(DbQueryBuilder {})
|
||||
query_builder.build_sqlx(DbQueryBuilder {})
|
||||
};
|
||||
|
||||
// For group_by.
|
||||
@@ -235,7 +246,7 @@ impl BackendHandler for SqlBackendHandler {
|
||||
let mut groups = Vec::new();
|
||||
// The rows are returned sorted by display_name, equivalent to group_id. We group them by
|
||||
// this key which gives us one element (`rows`) per group.
|
||||
for ((group_id, display_name), rows) in &sqlx::query(&query)
|
||||
for ((group_id, display_name), rows) in &query_with(query.as_str(), values)
|
||||
.fetch_all(&self.sql_pool)
|
||||
.await?
|
||||
.into_iter()
|
||||
@@ -261,7 +272,7 @@ impl BackendHandler for SqlBackendHandler {
|
||||
}
|
||||
|
||||
async fn get_user_details(&self, user_id: &UserId) -> Result<User> {
|
||||
let query = Query::select()
|
||||
let (query, values) = Query::select()
|
||||
.column(Users::UserId)
|
||||
.column(Users::Email)
|
||||
.column(Users::DisplayName)
|
||||
@@ -270,25 +281,27 @@ impl BackendHandler for SqlBackendHandler {
|
||||
.column(Users::Avatar)
|
||||
.column(Users::CreationDate)
|
||||
.from(Users::Table)
|
||||
.and_where(Expr::col(Users::UserId).eq(user_id))
|
||||
.to_string(DbQueryBuilder {});
|
||||
.cond_where(Expr::col(Users::UserId).eq(user_id))
|
||||
.build_sqlx(DbQueryBuilder {});
|
||||
|
||||
Ok(sqlx::query_as::<_, User>(&query)
|
||||
Ok(query_as_with::<_, User, _>(query.as_str(), values)
|
||||
.fetch_one(&self.sql_pool)
|
||||
.await?)
|
||||
}
|
||||
|
||||
async fn get_group_details(&self, group_id: GroupId) -> Result<GroupIdAndName> {
|
||||
let query = Query::select()
|
||||
let (query, values) = Query::select()
|
||||
.column(Groups::GroupId)
|
||||
.column(Groups::DisplayName)
|
||||
.from(Groups::Table)
|
||||
.and_where(Expr::col(Groups::GroupId).eq(group_id))
|
||||
.to_string(DbQueryBuilder {});
|
||||
.cond_where(Expr::col(Groups::GroupId).eq(group_id))
|
||||
.build_sqlx(DbQueryBuilder {});
|
||||
|
||||
Ok(sqlx::query_as::<_, GroupIdAndName>(&query)
|
||||
.fetch_one(&self.sql_pool)
|
||||
.await?)
|
||||
Ok(
|
||||
query_as_with::<_, GroupIdAndName, _>(query.as_str(), values)
|
||||
.fetch_one(&self.sql_pool)
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
|
||||
async fn get_user_groups(&self, user_id: &UserId) -> Result<HashSet<GroupIdAndName>> {
|
||||
@@ -297,7 +310,7 @@ impl BackendHandler for SqlBackendHandler {
|
||||
groups.insert(GroupIdAndName(GroupId(1), "lldap_admin".to_string()));
|
||||
return Ok(groups);
|
||||
}
|
||||
let query: String = Query::select()
|
||||
let (query, values) = Query::select()
|
||||
.column((Groups::Table, Groups::GroupId))
|
||||
.column(Groups::DisplayName)
|
||||
.from(Groups::Table)
|
||||
@@ -306,10 +319,10 @@ impl BackendHandler for SqlBackendHandler {
|
||||
Expr::tbl(Groups::Table, Groups::GroupId)
|
||||
.equals(Memberships::Table, Memberships::GroupId),
|
||||
)
|
||||
.and_where(Expr::col(Memberships::UserId).eq(user_id))
|
||||
.to_string(DbQueryBuilder {});
|
||||
.cond_where(Expr::col(Memberships::UserId).eq(user_id))
|
||||
.build_sqlx(DbQueryBuilder {});
|
||||
|
||||
sqlx::query(&query)
|
||||
query_with(query.as_str(), values)
|
||||
// Extract the group id from the row.
|
||||
.map(|row: DbRow| {
|
||||
GroupIdAndName(
|
||||
@@ -338,20 +351,21 @@ impl BackendHandler for SqlBackendHandler {
|
||||
Users::LastName,
|
||||
Users::CreationDate,
|
||||
];
|
||||
let values = vec![
|
||||
request.user_id.into(),
|
||||
request.email.into(),
|
||||
request.display_name.unwrap_or_default().into(),
|
||||
request.first_name.unwrap_or_default().into(),
|
||||
request.last_name.unwrap_or_default().into(),
|
||||
chrono::Utc::now().naive_utc().into(),
|
||||
];
|
||||
let query = Query::insert()
|
||||
let (query, values) = Query::insert()
|
||||
.into_table(Users::Table)
|
||||
.columns(columns)
|
||||
.values_panic(values)
|
||||
.to_string(DbQueryBuilder {});
|
||||
sqlx::query(&query).execute(&self.sql_pool).await?;
|
||||
.values_panic(vec![
|
||||
request.user_id.into(),
|
||||
request.email.into(),
|
||||
request.display_name.unwrap_or_default().into(),
|
||||
request.first_name.unwrap_or_default().into(),
|
||||
request.last_name.unwrap_or_default().into(),
|
||||
chrono::Utc::now().naive_utc().into(),
|
||||
])
|
||||
.build_sqlx(DbQueryBuilder {});
|
||||
query_with(query.as_str(), values)
|
||||
.execute(&self.sql_pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -372,12 +386,14 @@ impl BackendHandler for SqlBackendHandler {
|
||||
if values.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
let query = Query::update()
|
||||
let (query, values) = Query::update()
|
||||
.table(Users::Table)
|
||||
.values(values)
|
||||
.and_where(Expr::col(Users::UserId).eq(request.user_id))
|
||||
.to_string(DbQueryBuilder {});
|
||||
sqlx::query(&query).execute(&self.sql_pool).await?;
|
||||
.cond_where(Expr::col(Users::UserId).eq(request.user_id))
|
||||
.build_sqlx(DbQueryBuilder {});
|
||||
query_with(query.as_str(), values)
|
||||
.execute(&self.sql_pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -389,66 +405,83 @@ impl BackendHandler for SqlBackendHandler {
|
||||
if values.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
let query = Query::update()
|
||||
let (query, values) = Query::update()
|
||||
.table(Groups::Table)
|
||||
.values(values)
|
||||
.and_where(Expr::col(Groups::GroupId).eq(request.group_id))
|
||||
.to_string(DbQueryBuilder {});
|
||||
sqlx::query(&query).execute(&self.sql_pool).await?;
|
||||
.cond_where(Expr::col(Groups::GroupId).eq(request.group_id))
|
||||
.build_sqlx(DbQueryBuilder {});
|
||||
query_with(query.as_str(), values)
|
||||
.execute(&self.sql_pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn delete_user(&self, user_id: &UserId) -> Result<()> {
|
||||
let delete_query = Query::delete()
|
||||
let (delete_query, values) = Query::delete()
|
||||
.from_table(Users::Table)
|
||||
.and_where(Expr::col(Users::UserId).eq(user_id))
|
||||
.to_string(DbQueryBuilder {});
|
||||
sqlx::query(&delete_query).execute(&self.sql_pool).await?;
|
||||
.cond_where(Expr::col(Users::UserId).eq(user_id))
|
||||
.build_sqlx(DbQueryBuilder {});
|
||||
query_with(delete_query.as_str(), values)
|
||||
.execute(&self.sql_pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_group(&self, group_name: &str) -> Result<GroupId> {
|
||||
let query = Query::insert()
|
||||
let (query, values) = Query::insert()
|
||||
.into_table(Groups::Table)
|
||||
.columns(vec![Groups::DisplayName])
|
||||
.values_panic(vec![group_name.into()])
|
||||
.to_string(DbQueryBuilder {});
|
||||
sqlx::query(&query).execute(&self.sql_pool).await?;
|
||||
let query = Query::select()
|
||||
.build_sqlx(DbQueryBuilder {});
|
||||
query_with(query.as_str(), values)
|
||||
.execute(&self.sql_pool)
|
||||
.await?;
|
||||
let (query, values) = Query::select()
|
||||
.column(Groups::GroupId)
|
||||
.from(Groups::Table)
|
||||
.and_where(Expr::col(Groups::DisplayName).eq(group_name))
|
||||
.to_string(DbQueryBuilder {});
|
||||
let row = sqlx::query(&query).fetch_one(&self.sql_pool).await?;
|
||||
.cond_where(Expr::col(Groups::DisplayName).eq(group_name))
|
||||
.build_sqlx(DbQueryBuilder {});
|
||||
let row = query_with(query.as_str(), values)
|
||||
.fetch_one(&self.sql_pool)
|
||||
.await?;
|
||||
Ok(GroupId(row.get::<i32, _>(&*Groups::GroupId.to_string())))
|
||||
}
|
||||
|
||||
async fn delete_group(&self, group_id: GroupId) -> Result<()> {
|
||||
let delete_query = Query::delete()
|
||||
let (delete_query, values) = Query::delete()
|
||||
.from_table(Groups::Table)
|
||||
.and_where(Expr::col(Groups::GroupId).eq(group_id))
|
||||
.to_string(DbQueryBuilder {});
|
||||
sqlx::query(&delete_query).execute(&self.sql_pool).await?;
|
||||
.cond_where(Expr::col(Groups::GroupId).eq(group_id))
|
||||
.build_sqlx(DbQueryBuilder {});
|
||||
query_with(delete_query.as_str(), values)
|
||||
.execute(&self.sql_pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> {
|
||||
let query = Query::insert()
|
||||
let (query, values) = Query::insert()
|
||||
.into_table(Memberships::Table)
|
||||
.columns(vec![Memberships::UserId, Memberships::GroupId])
|
||||
.values_panic(vec![user_id.into(), group_id.into()])
|
||||
.to_string(DbQueryBuilder {});
|
||||
sqlx::query(&query).execute(&self.sql_pool).await?;
|
||||
.build_sqlx(DbQueryBuilder {});
|
||||
query_with(query.as_str(), values)
|
||||
.execute(&self.sql_pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> {
|
||||
let query = Query::delete()
|
||||
let (query, values) = Query::delete()
|
||||
.from_table(Memberships::Table)
|
||||
.and_where(Expr::col(Memberships::GroupId).eq(group_id))
|
||||
.and_where(Expr::col(Memberships::UserId).eq(user_id))
|
||||
.to_string(DbQueryBuilder {});
|
||||
sqlx::query(&query).execute(&self.sql_pool).await?;
|
||||
.cond_where(
|
||||
Cond::all()
|
||||
.add(Expr::col(Memberships::GroupId).eq(group_id))
|
||||
.add(Expr::col(Memberships::UserId).eq(user_id)),
|
||||
)
|
||||
.build_sqlx(DbQueryBuilder {});
|
||||
query_with(query.as_str(), values)
|
||||
.execute(&self.sql_pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -846,4 +879,26 @@ mod tests {
|
||||
|
||||
assert_eq!(get_user_names(&handler, None).await, vec!["val"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sql_injection() {
|
||||
let sql_pool = get_initialized_db().await;
|
||||
let config = get_default_config();
|
||||
let handler = SqlBackendHandler::new(config, sql_pool);
|
||||
let user_name = UserId::new(r#"bob"e"i'o;aü"#);
|
||||
insert_user(&handler, user_name.as_str(), "bob00").await;
|
||||
{
|
||||
let users = handler
|
||||
.list_users(None, false)
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|u| u.user.user_id)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(users, vec![user_name.clone()]);
|
||||
let user = handler.get_user_details(&user_name).await.unwrap();
|
||||
assert_eq!(user.user_id, user_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ use async_trait::async_trait;
|
||||
use lldap_auth::opaque;
|
||||
use log::*;
|
||||
use sea_query::{Expr, Iden, Query};
|
||||
use sea_query_binder::SqlxBinder;
|
||||
use secstr::SecUtf8;
|
||||
use sqlx::Row;
|
||||
|
||||
@@ -53,12 +54,15 @@ impl SqlBackendHandler {
|
||||
) -> Result<Option<opaque::server::ServerRegistration>> {
|
||||
// Fetch the previously registered password file from the DB.
|
||||
let password_file_bytes = {
|
||||
let query = Query::select()
|
||||
let (query, values) = Query::select()
|
||||
.column(Users::PasswordHash)
|
||||
.from(Users::Table)
|
||||
.and_where(Expr::col(Users::UserId).eq(username))
|
||||
.to_string(DbQueryBuilder {});
|
||||
if let Some(row) = sqlx::query(&query).fetch_optional(&self.sql_pool).await? {
|
||||
.cond_where(Expr::col(Users::UserId).eq(username))
|
||||
.build_sqlx(DbQueryBuilder {});
|
||||
if let Some(row) = sqlx::query_with(query.as_str(), values)
|
||||
.fetch_optional(&self.sql_pool)
|
||||
.await?
|
||||
{
|
||||
if let Some(bytes) =
|
||||
row.get::<Option<Vec<u8>>, _>(&*Users::PasswordHash.to_string())
|
||||
{
|
||||
@@ -94,12 +98,15 @@ impl LoginHandler for SqlBackendHandler {
|
||||
)));
|
||||
}
|
||||
}
|
||||
let query = Query::select()
|
||||
let (query, values) = Query::select()
|
||||
.column(Users::PasswordHash)
|
||||
.from(Users::Table)
|
||||
.and_where(Expr::col(Users::UserId).eq(&request.name))
|
||||
.to_string(DbQueryBuilder {});
|
||||
if let Ok(row) = sqlx::query(&query).fetch_one(&self.sql_pool).await {
|
||||
.cond_where(Expr::col(Users::UserId).eq(&request.name))
|
||||
.build_sqlx(DbQueryBuilder {});
|
||||
if let Ok(row) = sqlx::query_with(&query, values)
|
||||
.fetch_one(&self.sql_pool)
|
||||
.await
|
||||
{
|
||||
if let Some(password_hash) =
|
||||
row.get::<Option<Vec<u8>>, _>(&*Users::PasswordHash.to_string())
|
||||
{
|
||||
@@ -209,15 +216,14 @@ impl OpaqueHandler for SqlOpaqueHandler {
|
||||
opaque::server::registration::get_password_file(request.registration_upload);
|
||||
{
|
||||
// Set the user password to the new password.
|
||||
let update_query = Query::update()
|
||||
let (update_query, values) = Query::update()
|
||||
.table(Users::Table)
|
||||
.values(vec![(
|
||||
Users::PasswordHash,
|
||||
password_file.serialize().into(),
|
||||
)])
|
||||
.and_where(Expr::col(Users::UserId).eq(username))
|
||||
.to_string(DbQueryBuilder {});
|
||||
sqlx::query(&update_query).execute(&self.sql_pool).await?;
|
||||
.value(Users::PasswordHash, password_file.serialize().into())
|
||||
.cond_where(Expr::col(Users::UserId).eq(username))
|
||||
.build_sqlx(DbQueryBuilder {});
|
||||
sqlx::query_with(update_query.as_str(), values)
|
||||
.execute(&self.sql_pool)
|
||||
.await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -169,16 +169,16 @@ pub async fn init_table(pool: &Pool) -> sqlx::Result<()> {
|
||||
.foreign_key(
|
||||
ForeignKey::create()
|
||||
.name("MembershipUserForeignKey")
|
||||
.table(Memberships::Table, Users::Table)
|
||||
.col(Memberships::UserId, Users::UserId)
|
||||
.from(Memberships::Table, Memberships::UserId)
|
||||
.to(Users::Table, Users::UserId)
|
||||
.on_delete(ForeignKeyAction::Cascade)
|
||||
.on_update(ForeignKeyAction::Cascade),
|
||||
)
|
||||
.foreign_key(
|
||||
ForeignKey::create()
|
||||
.name("MembershipGroupForeignKey")
|
||||
.table(Memberships::Table, Groups::Table)
|
||||
.col(Memberships::GroupId, Groups::GroupId)
|
||||
.from(Memberships::Table, Memberships::GroupId)
|
||||
.to(Groups::Table, Groups::GroupId)
|
||||
.on_delete(ForeignKeyAction::Cascade)
|
||||
.on_update(ForeignKeyAction::Cascade),
|
||||
)
|
||||
|
||||
@@ -55,8 +55,8 @@ pub async fn init_table(pool: &Pool) -> sqlx::Result<()> {
|
||||
.foreign_key(
|
||||
ForeignKey::create()
|
||||
.name("JwtRefreshStorageUserForeignKey")
|
||||
.table(JwtRefreshStorage::Table, Users::Table)
|
||||
.col(JwtRefreshStorage::UserId, Users::UserId)
|
||||
.from(JwtRefreshStorage::Table, JwtRefreshStorage::UserId)
|
||||
.to(Users::Table, Users::UserId)
|
||||
.on_delete(ForeignKeyAction::Cascade)
|
||||
.on_update(ForeignKeyAction::Cascade),
|
||||
)
|
||||
@@ -94,8 +94,8 @@ pub async fn init_table(pool: &Pool) -> sqlx::Result<()> {
|
||||
.foreign_key(
|
||||
ForeignKey::create()
|
||||
.name("JwtStorageUserForeignKey")
|
||||
.table(JwtStorage::Table, Users::Table)
|
||||
.col(JwtStorage::UserId, Users::UserId)
|
||||
.from(JwtStorage::Table, JwtStorage::UserId)
|
||||
.to(Users::Table, Users::UserId)
|
||||
.on_delete(ForeignKeyAction::Cascade)
|
||||
.on_update(ForeignKeyAction::Cascade),
|
||||
)
|
||||
@@ -127,8 +127,8 @@ pub async fn init_table(pool: &Pool) -> sqlx::Result<()> {
|
||||
.foreign_key(
|
||||
ForeignKey::create()
|
||||
.name("PasswordResetTokensUserForeignKey")
|
||||
.table(PasswordResetTokens::Table, Users::Table)
|
||||
.col(PasswordResetTokens::UserId, Users::UserId)
|
||||
.from(PasswordResetTokens::Table, PasswordResetTokens::UserId)
|
||||
.to(Users::Table, Users::UserId)
|
||||
.on_delete(ForeignKeyAction::Cascade)
|
||||
.on_update(ForeignKeyAction::Cascade),
|
||||
)
|
||||
|
||||
@@ -3,7 +3,8 @@ use crate::domain::{error::*, handler::UserId, sql_backend_handler::SqlBackendHa
|
||||
use async_trait::async_trait;
|
||||
use futures_util::StreamExt;
|
||||
use sea_query::{Expr, Iden, Query, SimpleExpr};
|
||||
use sqlx::Row;
|
||||
use sea_query_binder::SqlxBinder;
|
||||
use sqlx::{query_as_with, query_with, Row};
|
||||
use std::collections::HashSet;
|
||||
|
||||
fn gen_random_string(len: usize) -> String {
|
||||
@@ -19,12 +20,12 @@ fn gen_random_string(len: usize) -> String {
|
||||
#[async_trait]
|
||||
impl TcpBackendHandler for SqlBackendHandler {
|
||||
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>> {
|
||||
let query = Query::select()
|
||||
let (query, values) = Query::select()
|
||||
.column(JwtStorage::JwtHash)
|
||||
.from(JwtStorage::Table)
|
||||
.to_string(DbQueryBuilder {});
|
||||
.build_sqlx(DbQueryBuilder {});
|
||||
|
||||
sqlx::query(&query)
|
||||
query_with(&query, values)
|
||||
.map(|row: DbRow| row.get::<i64, _>(&*JwtStorage::JwtHash.to_string()) as u64)
|
||||
.fetch(&self.sql_pool)
|
||||
.collect::<Vec<sqlx::Result<u64>>>()
|
||||
@@ -45,7 +46,7 @@ impl TcpBackendHandler for SqlBackendHandler {
|
||||
s.finish()
|
||||
};
|
||||
let duration = chrono::Duration::days(30);
|
||||
let query = Query::insert()
|
||||
let (query, values) = Query::insert()
|
||||
.into_table(JwtRefreshStorage::Table)
|
||||
.columns(vec![
|
||||
JwtRefreshStorage::RefreshTokenHash,
|
||||
@@ -57,71 +58,75 @@ impl TcpBackendHandler for SqlBackendHandler {
|
||||
user.into(),
|
||||
(chrono::Utc::now() + duration).naive_utc().into(),
|
||||
])
|
||||
.to_string(DbQueryBuilder {});
|
||||
sqlx::query(&query).execute(&self.sql_pool).await?;
|
||||
.build_sqlx(DbQueryBuilder {});
|
||||
query_with(&query, values).execute(&self.sql_pool).await?;
|
||||
Ok((refresh_token, duration))
|
||||
}
|
||||
|
||||
async fn check_token(&self, refresh_token_hash: u64, user: &UserId) -> Result<bool> {
|
||||
let query = Query::select()
|
||||
let (query, values) = Query::select()
|
||||
.expr(SimpleExpr::Value(1.into()))
|
||||
.from(JwtRefreshStorage::Table)
|
||||
.and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash as i64))
|
||||
.and_where(Expr::col(JwtRefreshStorage::UserId).eq(user))
|
||||
.to_string(DbQueryBuilder {});
|
||||
Ok(sqlx::query(&query)
|
||||
.build_sqlx(DbQueryBuilder {});
|
||||
Ok(query_with(&query, values)
|
||||
.fetch_optional(&self.sql_pool)
|
||||
.await?
|
||||
.is_some())
|
||||
}
|
||||
async fn blacklist_jwts(&self, user: &UserId) -> Result<HashSet<u64>> {
|
||||
use sqlx::Result;
|
||||
let query = Query::select()
|
||||
let (query, values) = Query::select()
|
||||
.column(JwtStorage::JwtHash)
|
||||
.from(JwtStorage::Table)
|
||||
.and_where(Expr::col(JwtStorage::UserId).eq(user))
|
||||
.and_where(Expr::col(JwtStorage::Blacklisted).eq(true))
|
||||
.to_string(DbQueryBuilder {});
|
||||
let result = sqlx::query(&query)
|
||||
.build_sqlx(DbQueryBuilder {});
|
||||
let result = query_with(&query, values)
|
||||
.map(|row: DbRow| row.get::<i64, _>(&*JwtStorage::JwtHash.to_string()) as u64)
|
||||
.fetch(&self.sql_pool)
|
||||
.collect::<Vec<sqlx::Result<u64>>>()
|
||||
.await
|
||||
.into_iter()
|
||||
.collect::<Result<HashSet<u64>>>();
|
||||
let query = Query::update()
|
||||
let (query, values) = Query::update()
|
||||
.table(JwtStorage::Table)
|
||||
.values(vec![(JwtStorage::Blacklisted, true.into())])
|
||||
.and_where(Expr::col(JwtStorage::UserId).eq(user))
|
||||
.to_string(DbQueryBuilder {});
|
||||
sqlx::query(&query).execute(&self.sql_pool).await?;
|
||||
.build_sqlx(DbQueryBuilder {});
|
||||
query_with(&query, values).execute(&self.sql_pool).await?;
|
||||
Ok(result?)
|
||||
}
|
||||
async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()> {
|
||||
let query = Query::delete()
|
||||
let (query, values) = Query::delete()
|
||||
.from_table(JwtRefreshStorage::Table)
|
||||
.and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash))
|
||||
.to_string(DbQueryBuilder {});
|
||||
sqlx::query(&query).execute(&self.sql_pool).await?;
|
||||
.build_sqlx(DbQueryBuilder {});
|
||||
query_with(&query, values).execute(&self.sql_pool).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_password_reset(&self, user: &UserId) -> Result<Option<String>> {
|
||||
let query = Query::select()
|
||||
let (query, values) = Query::select()
|
||||
.column(Users::UserId)
|
||||
.from(Users::Table)
|
||||
.and_where(Expr::col(Users::UserId).eq(user))
|
||||
.to_string(DbQueryBuilder {});
|
||||
.build_sqlx(DbQueryBuilder {});
|
||||
|
||||
// Check that the user exists.
|
||||
if sqlx::query(&query).fetch_one(&self.sql_pool).await.is_err() {
|
||||
if query_with(&query, values)
|
||||
.fetch_one(&self.sql_pool)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let token = gen_random_string(100);
|
||||
let duration = chrono::Duration::minutes(10);
|
||||
|
||||
let query = Query::insert()
|
||||
let (query, values) = Query::insert()
|
||||
.into_table(PasswordResetTokens::Table)
|
||||
.columns(vec![
|
||||
PasswordResetTokens::Token,
|
||||
@@ -133,31 +138,33 @@ impl TcpBackendHandler for SqlBackendHandler {
|
||||
user.into(),
|
||||
(chrono::Utc::now() + duration).naive_utc().into(),
|
||||
])
|
||||
.to_string(DbQueryBuilder {});
|
||||
sqlx::query(&query).execute(&self.sql_pool).await?;
|
||||
.build_sqlx(DbQueryBuilder {});
|
||||
query_with(&query, values).execute(&self.sql_pool).await?;
|
||||
Ok(Some(token))
|
||||
}
|
||||
|
||||
async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<UserId> {
|
||||
let query = Query::select()
|
||||
let (query, values) = Query::select()
|
||||
.column(PasswordResetTokens::UserId)
|
||||
.from(PasswordResetTokens::Table)
|
||||
.and_where(Expr::col(PasswordResetTokens::Token).eq(token))
|
||||
.and_where(
|
||||
Expr::col(PasswordResetTokens::ExpiryDate).gt(chrono::Utc::now().naive_utc()),
|
||||
)
|
||||
.to_string(DbQueryBuilder {});
|
||||
.build_sqlx(DbQueryBuilder {});
|
||||
|
||||
let (user_id,) = sqlx::query_as(&query).fetch_one(&self.sql_pool).await?;
|
||||
let (user_id,) = query_as_with(&query, values)
|
||||
.fetch_one(&self.sql_pool)
|
||||
.await?;
|
||||
Ok(user_id)
|
||||
}
|
||||
|
||||
async fn delete_password_reset_token(&self, token: &str) -> Result<()> {
|
||||
let query = Query::delete()
|
||||
let (query, values) = Query::delete()
|
||||
.from_table(PasswordResetTokens::Table)
|
||||
.and_where(Expr::col(PasswordResetTokens::Token).eq(token))
|
||||
.to_string(DbQueryBuilder {});
|
||||
sqlx::query(&query).execute(&self.sql_pool).await?;
|
||||
.build_sqlx(DbQueryBuilder {});
|
||||
query_with(&query, values).execute(&self.sql_pool).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user