sqlx: update dependency and protect against injections

This commit is contained in:
Valentin Tolmer
2022-03-22 20:45:59 +01:00
committed by nitnelave
parent bafb1dc5cc
commit 5e2eea0d97
10 changed files with 496 additions and 309 deletions

View File

@@ -2,8 +2,9 @@ use super::{error::*, handler::*, sql_tables::*};
use crate::infra::configuration::Configuration;
use async_trait::async_trait;
use futures_util::StreamExt;
use sea_query::{Expr, Iden, Order, Query, SimpleExpr};
use sqlx::{FromRow, Row};
use sea_query::{Cond, Expr, Iden, Order, Query, SimpleExpr};
use sea_query_binder::SqlxBinder;
use sqlx::{query_as_with, query_with, FromRow, Row};
use std::collections::HashSet;
#[derive(Debug, Clone)]
@@ -101,7 +102,7 @@ fn get_group_filter_expr(filter: GroupRequestFilter) -> SimpleExpr {
Query::select()
.column(Memberships::GroupId)
.from(Memberships::Table)
.and_where(Expr::col(Memberships::UserId).eq(user))
.cond_where(Expr::col(Memberships::UserId).eq(user))
.take(),
),
}
@@ -114,7 +115,7 @@ impl BackendHandler for SqlBackendHandler {
filters: Option<UserRequestFilter>,
get_groups: bool,
) -> Result<Vec<UserAndGroups>> {
let query = {
let (query, values) = {
let mut query_builder = Query::select()
.column((Users::Table, Users::UserId))
.column(Users::Email)
@@ -154,14 +155,24 @@ impl BackendHandler for SqlBackendHandler {
&& filter != UserRequestFilter::Or(Vec::new())
{
let (RequiresGroup(requires_group), condition) = get_user_filter_expr(filter);
query_builder.and_where(condition);
if requires_group && !get_groups {
add_join_group_tables(&mut query_builder);
query_builder.cond_where(condition);
if requires_group {
query_builder
.left_join(
Memberships::Table,
Expr::tbl(Users::Table, Users::UserId)
.equals(Memberships::Table, Memberships::UserId),
)
.left_join(
Groups::Table,
Expr::tbl(Memberships::Table, Memberships::GroupId)
.equals(Groups::Table, Groups::GroupId),
);
}
}
}
query_builder.to_string(DbQueryBuilder {})
query_builder.build_sqlx(DbQueryBuilder {})
};
log::error!("query: {}", &query);
@@ -170,7 +181,7 @@ impl BackendHandler for SqlBackendHandler {
let mut users = Vec::new();
// The rows are returned sorted by user_id. We group them by
// this key which gives us one element (`rows`) per group.
for (_, rows) in &sqlx::query(&query)
for (_, rows) in &query_with(&query, values)
.fetch_all(&self.sql_pool)
.await?
.into_iter()
@@ -200,7 +211,7 @@ impl BackendHandler for SqlBackendHandler {
}
async fn list_groups(&self, filters: Option<GroupRequestFilter>) -> Result<Vec<Group>> {
let query: String = {
let (query, values) = {
let mut query_builder = Query::select()
.column((Groups::Table, Groups::GroupId))
.column(Groups::DisplayName)
@@ -223,11 +234,11 @@ impl BackendHandler for SqlBackendHandler {
if filter != GroupRequestFilter::And(Vec::new())
&& filter != GroupRequestFilter::Or(Vec::new())
{
query_builder.and_where(get_group_filter_expr(filter));
query_builder.cond_where(get_group_filter_expr(filter));
}
}
query_builder.to_string(DbQueryBuilder {})
query_builder.build_sqlx(DbQueryBuilder {})
};
// For group_by.
@@ -235,7 +246,7 @@ impl BackendHandler for SqlBackendHandler {
let mut groups = Vec::new();
// The rows are returned sorted by display_name, equivalent to group_id. We group them by
// this key which gives us one element (`rows`) per group.
for ((group_id, display_name), rows) in &sqlx::query(&query)
for ((group_id, display_name), rows) in &query_with(query.as_str(), values)
.fetch_all(&self.sql_pool)
.await?
.into_iter()
@@ -261,7 +272,7 @@ impl BackendHandler for SqlBackendHandler {
}
async fn get_user_details(&self, user_id: &UserId) -> Result<User> {
let query = Query::select()
let (query, values) = Query::select()
.column(Users::UserId)
.column(Users::Email)
.column(Users::DisplayName)
@@ -270,25 +281,27 @@ impl BackendHandler for SqlBackendHandler {
.column(Users::Avatar)
.column(Users::CreationDate)
.from(Users::Table)
.and_where(Expr::col(Users::UserId).eq(user_id))
.to_string(DbQueryBuilder {});
.cond_where(Expr::col(Users::UserId).eq(user_id))
.build_sqlx(DbQueryBuilder {});
Ok(sqlx::query_as::<_, User>(&query)
Ok(query_as_with::<_, User, _>(query.as_str(), values)
.fetch_one(&self.sql_pool)
.await?)
}
async fn get_group_details(&self, group_id: GroupId) -> Result<GroupIdAndName> {
let query = Query::select()
let (query, values) = Query::select()
.column(Groups::GroupId)
.column(Groups::DisplayName)
.from(Groups::Table)
.and_where(Expr::col(Groups::GroupId).eq(group_id))
.to_string(DbQueryBuilder {});
.cond_where(Expr::col(Groups::GroupId).eq(group_id))
.build_sqlx(DbQueryBuilder {});
Ok(sqlx::query_as::<_, GroupIdAndName>(&query)
.fetch_one(&self.sql_pool)
.await?)
Ok(
query_as_with::<_, GroupIdAndName, _>(query.as_str(), values)
.fetch_one(&self.sql_pool)
.await?,
)
}
async fn get_user_groups(&self, user_id: &UserId) -> Result<HashSet<GroupIdAndName>> {
@@ -297,7 +310,7 @@ impl BackendHandler for SqlBackendHandler {
groups.insert(GroupIdAndName(GroupId(1), "lldap_admin".to_string()));
return Ok(groups);
}
let query: String = Query::select()
let (query, values) = Query::select()
.column((Groups::Table, Groups::GroupId))
.column(Groups::DisplayName)
.from(Groups::Table)
@@ -306,10 +319,10 @@ impl BackendHandler for SqlBackendHandler {
Expr::tbl(Groups::Table, Groups::GroupId)
.equals(Memberships::Table, Memberships::GroupId),
)
.and_where(Expr::col(Memberships::UserId).eq(user_id))
.to_string(DbQueryBuilder {});
.cond_where(Expr::col(Memberships::UserId).eq(user_id))
.build_sqlx(DbQueryBuilder {});
sqlx::query(&query)
query_with(query.as_str(), values)
// Extract the group id from the row.
.map(|row: DbRow| {
GroupIdAndName(
@@ -338,20 +351,21 @@ impl BackendHandler for SqlBackendHandler {
Users::LastName,
Users::CreationDate,
];
let values = vec![
request.user_id.into(),
request.email.into(),
request.display_name.unwrap_or_default().into(),
request.first_name.unwrap_or_default().into(),
request.last_name.unwrap_or_default().into(),
chrono::Utc::now().naive_utc().into(),
];
let query = Query::insert()
let (query, values) = Query::insert()
.into_table(Users::Table)
.columns(columns)
.values_panic(values)
.to_string(DbQueryBuilder {});
sqlx::query(&query).execute(&self.sql_pool).await?;
.values_panic(vec![
request.user_id.into(),
request.email.into(),
request.display_name.unwrap_or_default().into(),
request.first_name.unwrap_or_default().into(),
request.last_name.unwrap_or_default().into(),
chrono::Utc::now().naive_utc().into(),
])
.build_sqlx(DbQueryBuilder {});
query_with(query.as_str(), values)
.execute(&self.sql_pool)
.await?;
Ok(())
}
@@ -372,12 +386,14 @@ impl BackendHandler for SqlBackendHandler {
if values.is_empty() {
return Ok(());
}
let query = Query::update()
let (query, values) = Query::update()
.table(Users::Table)
.values(values)
.and_where(Expr::col(Users::UserId).eq(request.user_id))
.to_string(DbQueryBuilder {});
sqlx::query(&query).execute(&self.sql_pool).await?;
.cond_where(Expr::col(Users::UserId).eq(request.user_id))
.build_sqlx(DbQueryBuilder {});
query_with(query.as_str(), values)
.execute(&self.sql_pool)
.await?;
Ok(())
}
@@ -389,66 +405,83 @@ impl BackendHandler for SqlBackendHandler {
if values.is_empty() {
return Ok(());
}
let query = Query::update()
let (query, values) = Query::update()
.table(Groups::Table)
.values(values)
.and_where(Expr::col(Groups::GroupId).eq(request.group_id))
.to_string(DbQueryBuilder {});
sqlx::query(&query).execute(&self.sql_pool).await?;
.cond_where(Expr::col(Groups::GroupId).eq(request.group_id))
.build_sqlx(DbQueryBuilder {});
query_with(query.as_str(), values)
.execute(&self.sql_pool)
.await?;
Ok(())
}
async fn delete_user(&self, user_id: &UserId) -> Result<()> {
let delete_query = Query::delete()
let (delete_query, values) = Query::delete()
.from_table(Users::Table)
.and_where(Expr::col(Users::UserId).eq(user_id))
.to_string(DbQueryBuilder {});
sqlx::query(&delete_query).execute(&self.sql_pool).await?;
.cond_where(Expr::col(Users::UserId).eq(user_id))
.build_sqlx(DbQueryBuilder {});
query_with(delete_query.as_str(), values)
.execute(&self.sql_pool)
.await?;
Ok(())
}
async fn create_group(&self, group_name: &str) -> Result<GroupId> {
let query = Query::insert()
let (query, values) = Query::insert()
.into_table(Groups::Table)
.columns(vec![Groups::DisplayName])
.values_panic(vec![group_name.into()])
.to_string(DbQueryBuilder {});
sqlx::query(&query).execute(&self.sql_pool).await?;
let query = Query::select()
.build_sqlx(DbQueryBuilder {});
query_with(query.as_str(), values)
.execute(&self.sql_pool)
.await?;
let (query, values) = Query::select()
.column(Groups::GroupId)
.from(Groups::Table)
.and_where(Expr::col(Groups::DisplayName).eq(group_name))
.to_string(DbQueryBuilder {});
let row = sqlx::query(&query).fetch_one(&self.sql_pool).await?;
.cond_where(Expr::col(Groups::DisplayName).eq(group_name))
.build_sqlx(DbQueryBuilder {});
let row = query_with(query.as_str(), values)
.fetch_one(&self.sql_pool)
.await?;
Ok(GroupId(row.get::<i32, _>(&*Groups::GroupId.to_string())))
}
async fn delete_group(&self, group_id: GroupId) -> Result<()> {
let delete_query = Query::delete()
let (delete_query, values) = Query::delete()
.from_table(Groups::Table)
.and_where(Expr::col(Groups::GroupId).eq(group_id))
.to_string(DbQueryBuilder {});
sqlx::query(&delete_query).execute(&self.sql_pool).await?;
.cond_where(Expr::col(Groups::GroupId).eq(group_id))
.build_sqlx(DbQueryBuilder {});
query_with(delete_query.as_str(), values)
.execute(&self.sql_pool)
.await?;
Ok(())
}
async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> {
let query = Query::insert()
let (query, values) = Query::insert()
.into_table(Memberships::Table)
.columns(vec![Memberships::UserId, Memberships::GroupId])
.values_panic(vec![user_id.into(), group_id.into()])
.to_string(DbQueryBuilder {});
sqlx::query(&query).execute(&self.sql_pool).await?;
.build_sqlx(DbQueryBuilder {});
query_with(query.as_str(), values)
.execute(&self.sql_pool)
.await?;
Ok(())
}
async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> {
let query = Query::delete()
let (query, values) = Query::delete()
.from_table(Memberships::Table)
.and_where(Expr::col(Memberships::GroupId).eq(group_id))
.and_where(Expr::col(Memberships::UserId).eq(user_id))
.to_string(DbQueryBuilder {});
sqlx::query(&query).execute(&self.sql_pool).await?;
.cond_where(
Cond::all()
.add(Expr::col(Memberships::GroupId).eq(group_id))
.add(Expr::col(Memberships::UserId).eq(user_id)),
)
.build_sqlx(DbQueryBuilder {});
query_with(query.as_str(), values)
.execute(&self.sql_pool)
.await?;
Ok(())
}
}
@@ -846,4 +879,26 @@ mod tests {
assert_eq!(get_user_names(&handler, None).await, vec!["val"]);
}
#[tokio::test]
async fn test_sql_injection() {
let sql_pool = get_initialized_db().await;
let config = get_default_config();
let handler = SqlBackendHandler::new(config, sql_pool);
let user_name = UserId::new(r#"bob"e"i'o;aü"#);
insert_user(&handler, user_name.as_str(), "bob00").await;
{
let users = handler
.list_users(None, false)
.await
.unwrap()
.into_iter()
.map(|u| u.user.user_id)
.collect::<Vec<_>>();
assert_eq!(users, vec![user_name.clone()]);
let user = handler.get_user_details(&user_name).await.unwrap();
assert_eq!(user.user_id, user_name);
}
}
}