From 70146e0b702b95c3d59238a4b0ed80c735556b0b Mon Sep 17 00:00:00 2001 From: Valentin Tolmer Date: Tue, 16 May 2023 04:29:34 +0200 Subject: [PATCH] server: prepare DB schema for user attributes First step of #67. --- .github/workflows/docker-build-static.yml | 16 +- Cargo.lock | 20 +- scripts/sqlite_dump_commands.sh | 4 +- server/Cargo.toml | 1 + server/src/domain/error.rs | 11 + server/src/domain/handler.rs | 1 + server/src/domain/ldap/group.rs | 10 +- server/src/domain/ldap/user.rs | 44 +- server/src/domain/ldap/utils.rs | 50 +- .../domain/model/group_attribute_schema.rs | 39 ++ server/src/domain/model/group_attributes.rs | 57 ++ server/src/domain/model/mod.rs | 6 + server/src/domain/model/prelude.rs | 8 + .../src/domain/model/user_attribute_schema.rs | 39 ++ server/src/domain/model/user_attributes.rs | 57 ++ server/src/domain/model/users.rs | 17 +- server/src/domain/sql_migrations.rs | 557 ++++++++++++++---- server/src/domain/sql_tables.rs | 139 ++++- server/src/domain/sql_user_backend_handler.rs | 378 +++++++++--- server/src/domain/types.rs | 125 +++- server/src/infra/graphql/query.rs | 35 +- server/src/infra/ldap_handler.rs | 47 +- server/src/infra/tcp_server.rs | 1 + server/tests/common/fixture.rs | 2 +- 24 files changed, 1362 insertions(+), 302 deletions(-) create mode 100644 server/src/domain/model/group_attribute_schema.rs create mode 100644 server/src/domain/model/group_attributes.rs create mode 100644 server/src/domain/model/user_attribute_schema.rs create mode 100644 server/src/domain/model/user_attributes.rs diff --git a/.github/workflows/docker-build-static.yml b/.github/workflows/docker-build-static.yml index 5315c51..dcb9832 100644 --- a/.github/workflows/docker-build-static.yml +++ b/.github/workflows/docker-build-static.yml @@ -350,7 +350,7 @@ jobs: curl -L https://raw.githubusercontent.com/lldap/lldap/main/scripts/sqlite_dump_commands.sh -o helper.sh chmod +x ./helper.sh ./helper.sh | sqlite3 ./users.db > ./dump.sql - sed -i -r -e "s/X'([[:xdigit:]]+'[^'])/'\\\x\\1/g" -e '1s/^/BEGIN;\n/' -e '$aCOMMIT;' ./dump.sql + sed -i -r -e "s/X'([[:xdigit:]]+'[^'])/'\\\x\\1/g" -e ":a; s/(INSERT INTO user_attribute_schema\(.*\) VALUES\(.*),1([^']*\);)$/\1,true\2/; s/(INSERT INTO user_attribute_schema\(.*\) VALUES\(.*),0([^']*\);)$/\1,false\2/; ta" -e '1s/^/BEGIN;\n/' -e '$aCOMMIT;' ./dump.sql - name: Create schema on postgres run: | @@ -358,7 +358,6 @@ jobs: - name: Copy converted db to postgress and import run: | - docker ps -a docker cp ./dump.sql postgresql:/tmp/dump.sql docker exec postgresql bash -c "psql -U lldapuser -d lldap < /tmp/dump.sql" rm ./dump.sql @@ -377,7 +376,6 @@ jobs: - name: Copy converted db to mariadb and import run: | - docker ps -a docker cp ./dump.sql mariadb:/tmp/dump.sql docker exec mariadb bash -c "mariadb -ulldapuser -plldappass -f lldap < /tmp/dump.sql" rm ./dump.sql @@ -395,7 +393,6 @@ jobs: - name: Copy converted db to mysql and import run: | - docker ps -a docker cp ./dump.sql mysql:/tmp/dump.sql docker exec mysql bash -c "mysql -ulldapuser -plldappass -f lldap < /tmp/dump.sql" rm ./dump.sql @@ -434,11 +431,12 @@ jobs: LLDAP_http_port: 17173 LLDAP_JWT_SECRET: somejwtsecret - - name: Test Dummy User - run: | - ldapsearch -H ldap://localhost:3891 -LLL -D "uid=dummyuser,ou=people,dc=example,dc=com" -w 'dummypassword' -s "One" -b "ou=people,dc=example,dc=com" - ldapsearch -H ldap://localhost:3892 -LLL -D "uid=dummyuser,ou=people,dc=example,dc=com" -w 'dummypassword' -s "One" -b "ou=people,dc=example,dc=com" - ldapsearch -H ldap://localhost:3893 -LLL -D "uid=dummyuser,ou=people,dc=example,dc=com" -w 'dummypassword' -s "One" -b "ou=people,dc=example,dc=com" + - name: Test Dummy User Postgres + run: ldapsearch -H ldap://localhost:3891 -LLL -D "uid=dummyuser,ou=people,dc=example,dc=com" -w 'dummypassword' -s "One" -b "ou=people,dc=example,dc=com" + - name: Test Dummy User MariaDB + run: ldapsearch -H ldap://localhost:3892 -LLL -D "uid=dummyuser,ou=people,dc=example,dc=com" -w 'dummypassword' -s "One" -b "ou=people,dc=example,dc=com" + - name: Test Dummy User MySQL + run: ldapsearch -H ldap://localhost:3893 -LLL -D "uid=dummyuser,ou=people,dc=example,dc=com" -w 'dummypassword' -s "One" -b "ou=people,dc=example,dc=com" build-docker-image: needs: [build-ui, build-bin] diff --git a/Cargo.lock b/Cargo.lock index 6d9d74e..5ea673d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2453,6 +2453,7 @@ dependencies = [ "serde_json", "serial_test", "sha2 0.10.6", + "strum", "thiserror", "time 0.3.19", "tokio", @@ -2466,7 +2467,7 @@ dependencies = [ "tracing-log", "tracing-subscriber", "urlencoding", - "uuid 0.8.2", + "uuid 1.3.1", "webpki-roots", ] @@ -2609,12 +2610,6 @@ dependencies = [ "digest 0.10.6", ] -[[package]] -name = "md5" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" - [[package]] name = "memchr" version = "2.5.0" @@ -4065,6 +4060,12 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" +[[package]] +name = "strum" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "063e6045c0e62079840579a7e47a355ae92f60eb74daaf156fb1e84ba164e63f" + [[package]] name = "subtle" version = "2.4.1" @@ -4530,10 +4531,6 @@ name = "uuid" version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7" -dependencies = [ - "getrandom 0.2.8", - "md5", -] [[package]] name = "uuid" @@ -4542,6 +4539,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b55a3fef2a1e3b3a00ce878640918820d3c51081576ac657d23af9fc7928fdb" dependencies = [ "getrandom 0.2.8", + "md-5", ] [[package]] diff --git a/scripts/sqlite_dump_commands.sh b/scripts/sqlite_dump_commands.sh index 4b5778e..cdcf87a 100644 --- a/scripts/sqlite_dump_commands.sh +++ b/scripts/sqlite_dump_commands.sh @@ -1,9 +1,9 @@ #! /bin/bash -tables=("users" "groups" "memberships" "jwt_refresh_storage" "jwt_storage" "password_reset_tokens") +tables=("users" "groups" "memberships" "jwt_refresh_storage" "jwt_storage" "password_reset_tokens" "group_attribute_schema" "group_attributes" "user_attribute_schema" "user_attributes") echo ".header on" for table in ${tables[@]}; do echo ".mode insert $table" echo "select * from $table;" -done \ No newline at end of file +done diff --git a/server/Cargo.toml b/server/Cargo.toml index e42cf4a..fde5807 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -43,6 +43,7 @@ serde = "*" serde_bytes = "0.11" serde_json = "1" sha2 = "0.10" +strum = "0.24" thiserror = "*" time = "0.3" tokio-rustls = "0.23" diff --git a/server/src/domain/error.rs b/server/src/domain/error.rs index d5ec981..24d1aa5 100644 --- a/server/src/domain/error.rs +++ b/server/src/domain/error.rs @@ -7,6 +7,8 @@ pub enum DomainError { AuthenticationError(String), #[error("Database error: `{0}`")] DatabaseError(#[from] sea_orm::DbErr), + #[error("Database transaction error: `{0}`")] + DatabaseTransactionError(#[from] sea_orm::TransactionError), #[error("Authentication protocol error for `{0}`")] AuthenticationProtocolError(#[from] lldap_auth::opaque::AuthenticationError), #[error("Unknown crypto error: `{0}`")] @@ -21,4 +23,13 @@ pub enum DomainError { InternalError(String), } +impl From> for DomainError { + fn from(value: sea_orm::TransactionError) -> Self { + match value { + sea_orm::TransactionError::Connection(e) => e.into(), + sea_orm::TransactionError::Transaction(e) => e, + } + } +} + pub type Result = std::result::Result; diff --git a/server/src/domain/handler.rs b/server/src/domain/handler.rs index ded7f7d..591ab22 100644 --- a/server/src/domain/handler.rs +++ b/server/src/domain/handler.rs @@ -53,6 +53,7 @@ pub enum UserRequestFilter { UserId(UserId), UserIdSubString(SubStringFilter), Equality(UserColumn, String), + AttributeEquality(String, String), SubString(UserColumn, SubStringFilter), // Check if a user belongs to a group identified by name. MemberOf(String), diff --git a/server/src/domain/ldap/group.rs b/server/src/domain/ldap/group.rs index 4f29b7d..2d4294c 100644 --- a/server/src/domain/ldap/group.rs +++ b/server/src/domain/ldap/group.rs @@ -6,7 +6,7 @@ use tracing::{debug, instrument, warn}; use crate::domain::{ handler::{GroupListerBackendHandler, GroupRequestFilter}, ldap::error::LdapError, - types::{Group, GroupColumn, UserId, Uuid}, + types::{Group, UserId, Uuid}, }; use super::{ @@ -140,10 +140,8 @@ fn convert_group_filter( GroupRequestFilter::from(false) })), _ => match map_group_field(field) { - Some(GroupColumn::DisplayName) => { - Ok(GroupRequestFilter::DisplayName(value.to_string())) - } - Some(GroupColumn::Uuid) => Ok(GroupRequestFilter::Uuid( + Some("display_name") => Ok(GroupRequestFilter::DisplayName(value.to_string())), + Some("uuid") => Ok(GroupRequestFilter::Uuid( Uuid::try_from(value.as_str()).map_err(|e| LdapError { code: LdapResultCode::InappropriateMatching, message: format!("Invalid UUID: {:#}", e), @@ -181,7 +179,7 @@ fn convert_group_filter( LdapFilter::Substring(field, substring_filter) => { let field = &field.to_ascii_lowercase(); match map_group_field(field.as_str()) { - Some(GroupColumn::DisplayName) => Ok(GroupRequestFilter::DisplayNameSubString( + Some("display_name") => Ok(GroupRequestFilter::DisplayNameSubString( substring_filter.clone().into(), )), _ => Err(LdapError { diff --git a/server/src/domain/ldap/user.rs b/server/src/domain/ldap/user.rs index fec9b95..02be15c 100644 --- a/server/src/domain/ldap/user.rs +++ b/server/src/domain/ldap/user.rs @@ -7,17 +7,15 @@ use tracing::{debug, instrument, warn}; use crate::domain::{ handler::{UserListerBackendHandler, UserRequestFilter}, ldap::{ - error::LdapError, - utils::{expand_attribute_wildcards, get_user_id_from_distinguished_name}, + error::{LdapError, LdapResult}, + utils::{ + expand_attribute_wildcards, get_group_id_from_distinguished_name, + get_user_id_from_distinguished_name, map_user_field, LdapInfo, UserFieldType, + }, }, types::{GroupDetails, User, UserAndGroups, UserColumn, UserId}, }; -use super::{ - error::LdapResult, - utils::{get_group_id_from_distinguished_name, map_user_field, LdapInfo}, -}; - pub fn get_user_attribute( user: &User, attribute: &str, @@ -154,9 +152,17 @@ fn convert_user_filter(ldap_info: &LdapInfo, filter: &LdapFilter) -> LdapResult< UserRequestFilter::from(false) })), _ => match map_user_field(field) { - Some(UserColumn::UserId) => Ok(UserRequestFilter::UserId(UserId::new(value))), - Some(field) => Ok(UserRequestFilter::Equality(field, value.clone())), - None => { + UserFieldType::PrimaryField(UserColumn::UserId) => { + Ok(UserRequestFilter::UserId(UserId::new(value))) + } + UserFieldType::PrimaryField(field) => { + Ok(UserRequestFilter::Equality(field, value.clone())) + } + UserFieldType::Attribute(field) => Ok(UserRequestFilter::AttributeEquality( + field.to_owned(), + value.clone(), + )), + UserFieldType::NoMatch => { if !ldap_info.ignored_user_attributes.contains(field) { warn!( r#"Ignoring unknown user attribute "{}" in filter.\n\ @@ -176,26 +182,26 @@ fn convert_user_filter(ldap_info: &LdapInfo, filter: &LdapFilter) -> LdapResult< field == "objectclass" || field == "dn" || field == "distinguishedname" - || map_user_field(field).is_some(), + || !matches!(map_user_field(field), UserFieldType::NoMatch), )) } LdapFilter::Substring(field, substring_filter) => { let field = &field.to_ascii_lowercase(); match map_user_field(field.as_str()) { - Some(UserColumn::UserId) => Ok(UserRequestFilter::UserIdSubString( - substring_filter.clone().into(), - )), - None - | Some(UserColumn::CreationDate) - | Some(UserColumn::Avatar) - | Some(UserColumn::Uuid) => Err(LdapError { + UserFieldType::PrimaryField(UserColumn::UserId) => Ok( + UserRequestFilter::UserIdSubString(substring_filter.clone().into()), + ), + UserFieldType::NoMatch + | UserFieldType::Attribute(_) + | UserFieldType::PrimaryField(UserColumn::CreationDate) + | UserFieldType::PrimaryField(UserColumn::Uuid) => Err(LdapError { code: LdapResultCode::UnwillingToPerform, message: format!( "Unsupported user attribute for substring filter: {:?}", field ), }), - Some(field) => Ok(UserRequestFilter::SubString( + UserFieldType::PrimaryField(field) => Ok(UserRequestFilter::SubString( field, substring_filter.clone().into(), )), diff --git a/server/src/domain/ldap/utils.rs b/server/src/domain/ldap/utils.rs index 852d6e6..99f9328 100644 --- a/server/src/domain/ldap/utils.rs +++ b/server/src/domain/ldap/utils.rs @@ -5,7 +5,7 @@ use tracing::{debug, instrument, warn}; use crate::domain::{ handler::SubStringFilter, ldap::error::{LdapError, LdapResult}, - types::{GroupColumn, UserColumn, UserId}, + types::{UserColumn, UserId}, }; impl From for SubStringFilter { @@ -152,31 +152,37 @@ pub fn is_subtree(subtree: &[(String, String)], base_tree: &[(String, String)]) true } -pub fn map_user_field(field: &str) -> Option { - assert!(field == field.to_ascii_lowercase()); - Some(match field { - "uid" | "user_id" | "id" => UserColumn::UserId, - "mail" | "email" => UserColumn::Email, - "cn" | "displayname" | "display_name" => UserColumn::DisplayName, - "givenname" | "first_name" | "firstname" => UserColumn::FirstName, - "sn" | "last_name" | "lastname" => UserColumn::LastName, - "avatar" | "jpegphoto" => UserColumn::Avatar, - "creationdate" | "createtimestamp" | "modifytimestamp" | "creation_date" => { - UserColumn::CreationDate - } - "entryuuid" | "uuid" => UserColumn::Uuid, - _ => return None, - }) +pub enum UserFieldType { + NoMatch, + PrimaryField(UserColumn), + Attribute(&'static str), } -pub fn map_group_field(field: &str) -> Option { +pub fn map_user_field(field: &str) -> UserFieldType { + assert!(field == field.to_ascii_lowercase()); + match field { + "uid" | "user_id" | "id" => UserFieldType::PrimaryField(UserColumn::UserId), + "mail" | "email" => UserFieldType::PrimaryField(UserColumn::Email), + "cn" | "displayname" | "display_name" => { + UserFieldType::PrimaryField(UserColumn::DisplayName) + } + "givenname" | "first_name" | "firstname" => UserFieldType::Attribute("first_name"), + "sn" | "last_name" | "lastname" => UserFieldType::Attribute("last_name"), + "avatar" | "jpegphoto" => UserFieldType::Attribute("avatar"), + "creationdate" | "createtimestamp" | "modifytimestamp" | "creation_date" => { + UserFieldType::PrimaryField(UserColumn::CreationDate) + } + "entryuuid" | "uuid" => UserFieldType::PrimaryField(UserColumn::Uuid), + _ => UserFieldType::NoMatch, + } +} + +pub fn map_group_field(field: &str) -> Option<&'static str> { assert!(field == field.to_ascii_lowercase()); Some(match field { - "cn" | "displayname" | "uid" | "display_name" => GroupColumn::DisplayName, - "creationdate" | "createtimestamp" | "modifytimestamp" | "creation_date" => { - GroupColumn::CreationDate - } - "entryuuid" | "uuid" => GroupColumn::Uuid, + "cn" | "displayname" | "uid" | "display_name" => "display_name", + "creationdate" | "createtimestamp" | "modifytimestamp" | "creation_date" => "creation_date", + "entryuuid" | "uuid" => "uuid", _ => return None, }) } diff --git a/server/src/domain/model/group_attribute_schema.rs b/server/src/domain/model/group_attribute_schema.rs new file mode 100644 index 0000000..c6acd95 --- /dev/null +++ b/server/src/domain/model/group_attribute_schema.rs @@ -0,0 +1,39 @@ +use sea_orm::entity::prelude::*; +use serde::{Deserialize, Serialize}; + +use crate::domain::types::AttributeType; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)] +#[sea_orm(table_name = "group_attribute_schema")] +pub struct Model { + #[sea_orm( + primary_key, + auto_increment = false, + column_name = "group_attribute_schema_name" + )] + pub attribute_name: String, + #[sea_orm(column_name = "group_attribute_schema_type")] + pub attribute_type: AttributeType, + #[sea_orm(column_name = "group_attribute_schema_is_list")] + pub is_list: bool, + #[sea_orm(column_name = "group_attribute_schema_is_group_visible")] + pub is_group_visible: bool, + #[sea_orm(column_name = "group_attribute_schema_is_group_editable")] + pub is_group_editable: bool, + #[sea_orm(column_name = "group_attribute_schema_is_hardcoded")] + pub is_hardcoded: bool, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm(has_many = "super::group_attributes::Entity")] + GroupAttributes, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::GroupAttributes.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/server/src/domain/model/group_attributes.rs b/server/src/domain/model/group_attributes.rs new file mode 100644 index 0000000..a6325a1 --- /dev/null +++ b/server/src/domain/model/group_attributes.rs @@ -0,0 +1,57 @@ +use sea_orm::entity::prelude::*; +use serde::{Deserialize, Serialize}; + +use crate::domain::types::{GroupId, Serialized}; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)] +#[sea_orm(table_name = "group_attributes")] +pub struct Model { + #[sea_orm( + primary_key, + auto_increment = false, + column_name = "group_attribute_group_id" + )] + pub group_id: GroupId, + #[sea_orm( + primary_key, + auto_increment = false, + column_name = "group_attribute_name" + )] + pub attribute_name: String, + #[sea_orm(column_name = "group_attribute_value")] + pub value: Serialized, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::groups::Entity", + from = "Column::GroupId", + to = "super::groups::Column::GroupId", + on_update = "Cascade", + on_delete = "Cascade" + )] + Groups, + #[sea_orm( + belongs_to = "super::group_attribute_schema::Entity", + from = "Column::AttributeName", + to = "super::group_attribute_schema::Column::AttributeName", + on_update = "Cascade", + on_delete = "Cascade" + )] + GroupAttributeSchema, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Groups.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::GroupAttributeSchema.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/server/src/domain/model/mod.rs b/server/src/domain/model/mod.rs index 36b1060..622f478 100644 --- a/server/src/domain/model/mod.rs +++ b/server/src/domain/model/mod.rs @@ -9,4 +9,10 @@ pub mod memberships; pub mod password_reset_tokens; pub mod users; +pub mod user_attribute_schema; +pub mod user_attributes; + +pub mod group_attribute_schema; +pub mod group_attributes; + pub use prelude::*; diff --git a/server/src/domain/model/prelude.rs b/server/src/domain/model/prelude.rs index a25ffe6..337b85b 100644 --- a/server/src/domain/model/prelude.rs +++ b/server/src/domain/model/prelude.rs @@ -1,5 +1,9 @@ //! `SeaORM` Entity. Generated by sea-orm-codegen 0.10.3 +pub use super::group_attribute_schema::Column as GroupAttributeSchemaColumn; +pub use super::group_attribute_schema::Entity as GroupAttributeSchema; +pub use super::group_attributes::Column as GroupAttributesColumn; +pub use super::group_attributes::Entity as GroupAttributes; pub use super::groups::Column as GroupColumn; pub use super::groups::Entity as Group; pub use super::jwt_refresh_storage::Column as JwtRefreshStorageColumn; @@ -10,5 +14,9 @@ pub use super::memberships::Column as MembershipColumn; pub use super::memberships::Entity as Membership; pub use super::password_reset_tokens::Column as PasswordResetTokensColumn; pub use super::password_reset_tokens::Entity as PasswordResetTokens; +pub use super::user_attribute_schema::Column as UserAttributeSchemaColumn; +pub use super::user_attribute_schema::Entity as UserAttributeSchema; +pub use super::user_attributes::Column as UserAttributesColumn; +pub use super::user_attributes::Entity as UserAttributes; pub use super::users::Column as UserColumn; pub use super::users::Entity as User; diff --git a/server/src/domain/model/user_attribute_schema.rs b/server/src/domain/model/user_attribute_schema.rs new file mode 100644 index 0000000..e83ab06 --- /dev/null +++ b/server/src/domain/model/user_attribute_schema.rs @@ -0,0 +1,39 @@ +use sea_orm::entity::prelude::*; +use serde::{Deserialize, Serialize}; + +use crate::domain::types::AttributeType; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)] +#[sea_orm(table_name = "user_attribute_schema")] +pub struct Model { + #[sea_orm( + primary_key, + auto_increment = false, + column_name = "user_attribute_schema_name" + )] + pub attribute_name: String, + #[sea_orm(column_name = "user_attribute_schema_type")] + pub attribute_type: AttributeType, + #[sea_orm(column_name = "user_attribute_schema_is_list")] + pub is_list: bool, + #[sea_orm(column_name = "user_attribute_schema_is_user_visible")] + pub is_user_visible: bool, + #[sea_orm(column_name = "user_attribute_schema_is_user_editable")] + pub is_user_editable: bool, + #[sea_orm(column_name = "user_attribute_schema_is_hardcoded")] + pub is_hardcoded: bool, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm(has_many = "super::user_attributes::Entity")] + UserAttributes, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::UserAttributes.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/server/src/domain/model/user_attributes.rs b/server/src/domain/model/user_attributes.rs new file mode 100644 index 0000000..1639c0a --- /dev/null +++ b/server/src/domain/model/user_attributes.rs @@ -0,0 +1,57 @@ +use sea_orm::entity::prelude::*; +use serde::{Deserialize, Serialize}; + +use crate::domain::types::{Serialized, UserId}; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)] +#[sea_orm(table_name = "user_attributes")] +pub struct Model { + #[sea_orm( + primary_key, + auto_increment = false, + column_name = "user_attribute_user_id" + )] + pub user_id: UserId, + #[sea_orm( + primary_key, + auto_increment = false, + column_name = "user_attribute_name" + )] + pub attribute_name: String, + #[sea_orm(column_name = "user_attribute_value")] + pub value: Serialized, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::users::Entity", + from = "Column::UserId", + to = "super::users::Column::UserId", + on_update = "Cascade", + on_delete = "Cascade" + )] + Users, + #[sea_orm( + belongs_to = "super::user_attribute_schema::Entity", + from = "Column::AttributeName", + to = "super::user_attribute_schema::Column::AttributeName", + on_update = "Cascade", + on_delete = "Cascade" + )] + UserAttributeSchema, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Users.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::UserAttributeSchema.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/server/src/domain/model/users.rs b/server/src/domain/model/users.rs index 84b583a..c48e8a4 100644 --- a/server/src/domain/model/users.rs +++ b/server/src/domain/model/users.rs @@ -3,7 +3,7 @@ use sea_orm::{entity::prelude::*, sea_query::BlobSize}; use serde::{Deserialize, Serialize}; -use crate::domain::types::{JpegPhoto, UserId, Uuid}; +use crate::domain::types::{UserId, Uuid}; #[derive(Copy, Clone, Default, Debug, DeriveEntity)] pub struct Entity; @@ -15,9 +15,6 @@ pub struct Model { pub user_id: UserId, pub email: String, pub display_name: Option, - pub first_name: Option, - pub last_name: Option, - pub avatar: Option, pub creation_date: chrono::NaiveDateTime, pub password_hash: Option>, pub totp_secret: Option, @@ -36,9 +33,6 @@ pub enum Column { UserId, Email, DisplayName, - FirstName, - LastName, - Avatar, CreationDate, PasswordHash, TotpSecret, @@ -54,9 +48,6 @@ impl ColumnTrait for Column { Column::UserId => ColumnType::String(Some(255)), Column::Email => ColumnType::String(Some(255)), Column::DisplayName => ColumnType::String(Some(255)), - Column::FirstName => ColumnType::String(Some(255)), - Column::LastName => ColumnType::String(Some(255)), - Column::Avatar => ColumnType::Binary(BlobSize::Long), Column::CreationDate => ColumnType::DateTime, Column::PasswordHash => ColumnType::Binary(BlobSize::Medium), Column::TotpSecret => ColumnType::String(Some(64)), @@ -124,11 +115,11 @@ impl From for crate::domain::types::User { user_id: user.user_id, email: user.email, display_name: user.display_name, - first_name: user.first_name, - last_name: user.last_name, + first_name: None, + last_name: None, creation_date: user.creation_date, uuid: user.uuid, - avatar: user.avatar, + avatar: None, } } } diff --git a/server/src/domain/sql_migrations.rs b/server/src/domain/sql_migrations.rs index 37e27ca..8aeab71 100644 --- a/server/src/domain/sql_migrations.rs +++ b/server/src/domain/sql_migrations.rs @@ -1,17 +1,17 @@ use crate::domain::{ - sql_tables::{DbConnection, SchemaVersion}, - types::{GroupId, UserId, Uuid}, + sql_tables::{DbConnection, SchemaVersion, LAST_SCHEMA_VERSION}, + types::{AttributeType, GroupId, JpegPhoto, Serialized, UserId, Uuid}, }; -use anyhow::Context; use itertools::Itertools; use sea_orm::{ sea_query::{ self, all, ColumnDef, Expr, ForeignKey, ForeignKeyAction, Func, Index, Query, Table, Value, }, - ConnectionTrait, FromQueryResult, Iden, Order, Statement, TransactionTrait, + ConnectionTrait, DatabaseTransaction, DbErr, FromQueryResult, Iden, Order, Statement, + TransactionTrait, }; use serde::{Deserialize, Serialize}; -use tracing::{info, instrument, warn}; +use tracing::{error, info, instrument, warn}; #[derive(Iden, PartialEq, Eq, Debug, Serialize, Deserialize, Clone, Copy)] pub enum Users { @@ -45,6 +45,44 @@ pub enum Memberships { GroupId, } +#[derive(Iden, PartialEq, Eq, Debug, Serialize, Deserialize, Clone, Copy)] +pub enum UserAttributeSchema { + Table, + UserAttributeSchemaName, + UserAttributeSchemaType, + UserAttributeSchemaIsList, + UserAttributeSchemaIsUserVisible, + UserAttributeSchemaIsUserEditable, + UserAttributeSchemaIsHardcoded, +} + +#[derive(Iden, PartialEq, Eq, Debug, Serialize, Deserialize, Clone, Copy)] +pub enum UserAttributes { + Table, + UserAttributeUserId, + UserAttributeName, + UserAttributeValue, +} + +#[derive(Iden, PartialEq, Eq, Debug, Serialize, Deserialize, Clone, Copy)] +pub enum GroupAttributeSchema { + Table, + GroupAttributeSchemaName, + GroupAttributeSchemaType, + GroupAttributeSchemaIsList, + GroupAttributeSchemaIsGroupVisible, + GroupAttributeSchemaIsGroupEditable, + GroupAttributeSchemaIsHardcoded, +} + +#[derive(Iden, PartialEq, Eq, Debug, Serialize, Deserialize, Clone, Copy)] +pub enum GroupAttributes { + Table, + GroupAttributeGroupId, + GroupAttributeName, + GroupAttributeValue, +} + // Metadata about the SQL DB. #[derive(Iden)] pub enum Metadata { @@ -337,72 +375,64 @@ pub async fn upgrade_to_v1(pool: &DbConnection) -> std::result::Result<(), sea_o } async fn replace_column( - pool: &DbConnection, + transaction: DatabaseTransaction, table_name: I, column_name: I, mut new_column: ColumnDef, update_values: [Statement; N], -) -> anyhow::Result<()> { +) -> Result { // Update the definition of a column (in a compatible way). Due to Sqlite, this is more complicated: // - rename the column to a temporary name // - create the column with the new definition // - copy the data from the temp column to the new one // - update the new one if there are changes needed // - drop the old one - let builder = pool.get_database_backend(); - pool.transaction::<_, (), sea_orm::DbErr>(move |transaction| { - Box::pin(async move { - #[derive(Iden)] - enum TempTable { - TempName, - } - transaction - .execute( - builder.build( - Table::alter() - .table(table_name) - .rename_column(column_name, TempTable::TempName), - ), - ) - .await?; - transaction - .execute( - builder.build(Table::alter().table(table_name).add_column(&mut new_column)), - ) - .await?; - transaction - .execute( - builder.build( - Query::update() - .table(table_name) - .value(column_name, Expr::col((table_name, TempTable::TempName))), - ), - ) - .await?; - for statement in update_values { - transaction.execute(statement).await?; - } - transaction - .execute( - builder.build( - Table::alter() - .table(table_name) - .drop_column(TempTable::TempName), - ), - ) - .await?; - Ok(()) - }) - }) - .await?; - Ok(()) + let builder = transaction.get_database_backend(); + #[derive(Iden)] + enum TempTable { + TempName, + } + transaction + .execute( + builder.build( + Table::alter() + .table(table_name) + .rename_column(column_name, TempTable::TempName), + ), + ) + .await?; + transaction + .execute(builder.build(Table::alter().table(table_name).add_column(&mut new_column))) + .await?; + transaction + .execute( + builder.build( + Query::update() + .table(table_name) + .value(column_name, Expr::col((table_name, TempTable::TempName))), + ), + ) + .await?; + for statement in update_values { + transaction.execute(statement).await?; + } + transaction + .execute( + builder.build( + Table::alter() + .table(table_name) + .drop_column(TempTable::TempName), + ), + ) + .await?; + Ok(transaction) } -async fn migrate_to_v2(pool: &DbConnection) -> anyhow::Result<()> { - let builder = pool.get_database_backend(); +async fn migrate_to_v2(transaction: DatabaseTransaction) -> Result { + let builder = transaction.get_database_backend(); // Allow nulls in DisplayName, and change empty string to null. - replace_column( - pool, + let transaction = replace_column( + transaction, Users::Table, Users::DisplayName, ColumnDef::new(Users::DisplayName) @@ -416,14 +446,14 @@ async fn migrate_to_v2(pool: &DbConnection) -> anyhow::Result<()> { )], ) .await?; - Ok(()) + Ok(transaction) } -async fn migrate_to_v3(pool: &DbConnection) -> anyhow::Result<()> { - let builder = pool.get_database_backend(); +async fn migrate_to_v3(transaction: DatabaseTransaction) -> Result { + let builder = transaction.get_database_backend(); // Allow nulls in First and LastName. Users who created their DB in 0.4.1 have the not null constraint. - replace_column( - pool, + let transaction = replace_column( + transaction, Users::Table, Users::FirstName, ColumnDef::new(Users::FirstName).string_len(255).to_owned(), @@ -435,8 +465,8 @@ async fn migrate_to_v3(pool: &DbConnection) -> anyhow::Result<()> { )], ) .await?; - replace_column( - pool, + let transaction = replace_column( + transaction, Users::Table, Users::LastName, ColumnDef::new(Users::LastName).string_len(255).to_owned(), @@ -449,8 +479,8 @@ async fn migrate_to_v3(pool: &DbConnection) -> anyhow::Result<()> { ) .await?; // Change Avatar from binary to blob(long), because for MySQL this is 64kb. - replace_column( - pool, + let transaction = replace_column( + transaction, Users::Table, Users::Avatar, ColumnDef::new(Users::Avatar) @@ -459,13 +489,13 @@ async fn migrate_to_v3(pool: &DbConnection) -> anyhow::Result<()> { [], ) .await?; - Ok(()) + Ok(transaction) } -async fn migrate_to_v4(pool: &DbConnection) -> anyhow::Result<()> { - let builder = pool.get_database_backend(); +async fn migrate_to_v4(transaction: DatabaseTransaction) -> Result { + let builder = transaction.get_database_backend(); // Make emails and UUIDs unique. - if let Err(e) = pool + if let Err(e) = transaction .execute( builder.build( Index::create() @@ -477,16 +507,16 @@ async fn migrate_to_v4(pool: &DbConnection) -> anyhow::Result<()> { ), ) .await - .context( - r#"while enforcing unicity on emails (2 users have the same email). + { + error!( + r#"Found several users with the same email. See https://github.com/lldap/lldap/blob/main/docs/migration_guides/v0.5.md for details. +Conflicting emails: "#, - ) - { - warn!("Found several users with the same email:"); - for (email, users) in &pool + ); + for (email, users) in &transaction .query_all( builder.build( Query::select() @@ -528,39 +558,329 @@ See https://github.com/lldap/lldap/blob/main/docs/migration_guides/v0.5.md for d } return Err(e); } - pool.execute( - builder.build( - Index::create() - .if_not_exists() - .name("unique-user-uuid") - .table(Users::Table) - .col(Users::Uuid) - .unique(), - ), - ) - .await - .context("while enforcing unicity on user UUIDs (2 users have the same UUID)")?; - pool.execute( - builder.build( - Index::create() - .if_not_exists() - .name("unique-group-uuid") - .table(Groups::Table) - .col(Groups::Uuid) - .unique(), - ), - ) - .await - .context("while enforcing unicity on group UUIDs (2 groups have the same UUID)")?; - Ok(()) + transaction + .execute( + builder.build( + Index::create() + .if_not_exists() + .name("unique-user-uuid") + .table(Users::Table) + .col(Users::Uuid) + .unique(), + ), + ) + .await?; + transaction + .execute( + builder.build( + Index::create() + .if_not_exists() + .name("unique-group-uuid") + .table(Groups::Table) + .col(Groups::Uuid) + .unique(), + ), + ) + .await?; + Ok(transaction) +} + +async fn migrate_to_v5(transaction: DatabaseTransaction) -> Result { + let builder = transaction.get_database_backend(); + transaction + .execute( + builder.build( + Table::create() + .table(UserAttributeSchema::Table) + .col( + ColumnDef::new(UserAttributeSchema::UserAttributeSchemaName) + .string_len(64) + .not_null() + .primary_key(), + ) + .col( + ColumnDef::new(UserAttributeSchema::UserAttributeSchemaType) + .string_len(64) + .not_null(), + ) + .col( + ColumnDef::new(UserAttributeSchema::UserAttributeSchemaIsList) + .boolean() + .not_null(), + ) + .col( + ColumnDef::new(UserAttributeSchema::UserAttributeSchemaIsUserVisible) + .boolean() + .not_null(), + ) + .col( + ColumnDef::new(UserAttributeSchema::UserAttributeSchemaIsUserEditable) + .boolean() + .not_null(), + ) + .col( + ColumnDef::new(UserAttributeSchema::UserAttributeSchemaIsHardcoded) + .boolean() + .not_null(), + ), + ), + ) + .await?; + + transaction + .execute( + builder.build( + Table::create() + .table(GroupAttributeSchema::Table) + .col( + ColumnDef::new(GroupAttributeSchema::GroupAttributeSchemaName) + .string_len(64) + .not_null() + .primary_key(), + ) + .col( + ColumnDef::new(GroupAttributeSchema::GroupAttributeSchemaType) + .string_len(64) + .not_null(), + ) + .col( + ColumnDef::new(GroupAttributeSchema::GroupAttributeSchemaIsList) + .boolean() + .not_null(), + ) + .col( + ColumnDef::new(GroupAttributeSchema::GroupAttributeSchemaIsGroupVisible) + .boolean() + .not_null(), + ) + .col( + ColumnDef::new(GroupAttributeSchema::GroupAttributeSchemaIsGroupEditable) + .boolean() + .not_null(), + ) + .col( + ColumnDef::new(GroupAttributeSchema::GroupAttributeSchemaIsHardcoded) + .boolean() + .not_null(), + ), + ), + ) + .await?; + + transaction + .execute( + builder.build( + Table::create() + .table(UserAttributes::Table) + .col( + ColumnDef::new(UserAttributes::UserAttributeUserId) + .string_len(255) + .not_null(), + ) + .col( + ColumnDef::new(UserAttributes::UserAttributeName) + .string_len(64) + .not_null(), + ) + .col( + ColumnDef::new(UserAttributes::UserAttributeValue) + .blob(sea_query::BlobSize::Long) + .not_null(), + ) + .foreign_key( + ForeignKey::create() + .name("UserAttributeUserIdForeignKey") + .from(UserAttributes::Table, UserAttributes::UserAttributeUserId) + .to(Users::Table, Users::UserId) + .on_delete(ForeignKeyAction::Cascade) + .on_update(ForeignKeyAction::Cascade), + ) + .foreign_key( + ForeignKey::create() + .name("UserAttributeNameForeignKey") + .from(UserAttributes::Table, UserAttributes::UserAttributeName) + .to( + UserAttributeSchema::Table, + UserAttributeSchema::UserAttributeSchemaName, + ) + .on_delete(ForeignKeyAction::Cascade) + .on_update(ForeignKeyAction::Cascade), + ) + .primary_key( + Index::create() + .col(UserAttributes::UserAttributeUserId) + .col(UserAttributes::UserAttributeName), + ), + ), + ) + .await?; + + transaction + .execute( + builder.build( + Table::create() + .table(GroupAttributes::Table) + .col( + ColumnDef::new(GroupAttributes::GroupAttributeGroupId) + .integer() + .not_null(), + ) + .col( + ColumnDef::new(GroupAttributes::GroupAttributeName) + .string_len(64) + .not_null(), + ) + .col( + ColumnDef::new(GroupAttributes::GroupAttributeValue) + .blob(sea_query::BlobSize::Long) + .not_null(), + ) + .foreign_key( + ForeignKey::create() + .name("GroupAttributeGroupIdForeignKey") + .from( + GroupAttributes::Table, + GroupAttributes::GroupAttributeGroupId, + ) + .to(Groups::Table, Groups::GroupId) + .on_delete(ForeignKeyAction::Cascade) + .on_update(ForeignKeyAction::Cascade), + ) + .foreign_key( + ForeignKey::create() + .name("GroupAttributeNameForeignKey") + .from(GroupAttributes::Table, GroupAttributes::GroupAttributeName) + .to( + GroupAttributeSchema::Table, + GroupAttributeSchema::GroupAttributeSchemaName, + ) + .on_delete(ForeignKeyAction::Cascade) + .on_update(ForeignKeyAction::Cascade), + ) + .primary_key( + Index::create() + .col(GroupAttributes::GroupAttributeGroupId) + .col(GroupAttributes::GroupAttributeName), + ), + ), + ) + .await?; + + transaction + .execute( + builder.build( + Query::insert() + .into_table(UserAttributeSchema::Table) + .columns([ + UserAttributeSchema::UserAttributeSchemaName, + UserAttributeSchema::UserAttributeSchemaType, + UserAttributeSchema::UserAttributeSchemaIsList, + UserAttributeSchema::UserAttributeSchemaIsUserVisible, + UserAttributeSchema::UserAttributeSchemaIsUserEditable, + UserAttributeSchema::UserAttributeSchemaIsHardcoded, + ]) + .values_panic([ + "first_name".into(), + AttributeType::String.into(), + false.into(), + true.into(), + true.into(), + true.into(), + ]) + .values_panic([ + "last_name".into(), + AttributeType::String.into(), + false.into(), + true.into(), + true.into(), + true.into(), + ]) + .values_panic([ + "avatar".into(), + AttributeType::JpegPhoto.into(), + false.into(), + true.into(), + true.into(), + true.into(), + ]), + ), + ) + .await?; + + { + let mut user_statement = Query::insert() + .into_table(UserAttributes::Table) + .columns([ + UserAttributes::UserAttributeUserId, + UserAttributes::UserAttributeName, + UserAttributes::UserAttributeValue, + ]) + .to_owned(); + #[derive(FromQueryResult)] + struct FullUserDetails { + user_id: UserId, + first_name: Option, + last_name: Option, + avatar: Option, + } + let mut any_user = false; + for user in FullUserDetails::find_by_statement(builder.build( + Query::select().from(Users::Table).columns([ + Users::UserId, + Users::FirstName, + Users::LastName, + Users::Avatar, + ]), + )) + .all(&transaction) + .await? + { + if let Some(name) = &user.first_name { + any_user = true; + user_statement.values_panic([ + user.user_id.clone().into(), + "first_name".into(), + Serialized::from(name).into(), + ]); + } + if let Some(name) = &user.last_name { + any_user = true; + user_statement.values_panic([ + user.user_id.clone().into(), + "last_name".into(), + Serialized::from(name).into(), + ]); + } + if let Some(avatar) = &user.avatar { + any_user = true; + user_statement.values_panic([ + user.user_id.clone().into(), + "avatar".into(), + Serialized::from(avatar).into(), + ]); + } + } + + if any_user { + transaction.execute(builder.build(&user_statement)).await?; + } + } + + for column in [Users::FirstName, Users::LastName, Users::Avatar] { + transaction + .execute(builder.build(Table::alter().table(Users::Table).drop_column(column))) + .await?; + } + + Ok(transaction) } // This is needed to make an array of async functions. macro_rules! to_sync { ($l:ident) => { - |pool| -> std::pin::Pin>>> { - Box::pin($l(pool)) - } + move |transaction| -> std::pin::Pin< + Box>>, + > { Box::pin($l(transaction)) } }; } @@ -579,21 +899,26 @@ pub async fn migrate_from_version( to_sync!(migrate_to_v2), to_sync!(migrate_to_v3), to_sync!(migrate_to_v4), + to_sync!(migrate_to_v5), ]; - for migration in 2..=4 { + assert_eq!(migrations.len(), (LAST_SCHEMA_VERSION.0 - 1) as usize); + for migration in 2..=last_version.0 { if version < SchemaVersion(migration) && SchemaVersion(migration) <= last_version { info!("Upgrading DB schema to version {}", migration); - migrations[(migration - 2) as usize](pool).await?; + let transaction = pool.begin().await?; + let transaction = migrations[(migration - 2) as usize](transaction).await?; + let builder = transaction.get_database_backend(); + transaction + .execute( + builder.build( + Query::update() + .table(Metadata::Table) + .value(Metadata::Version, Value::from(migration)), + ), + ) + .await?; + transaction.commit().await?; } } - let builder = pool.get_database_backend(); - pool.execute( - builder.build( - Query::update() - .table(Metadata::Table) - .value(Metadata::Version, Value::from(last_version)), - ), - ) - .await?; Ok(()) } diff --git a/server/src/domain/sql_tables.rs b/server/src/domain/sql_tables.rs index cc0d37b..eff6057 100644 --- a/server/src/domain/sql_tables.rs +++ b/server/src/domain/sql_tables.rs @@ -21,7 +21,7 @@ impl From for Value { } } -const LAST_SCHEMA_VERSION: SchemaVersion = SchemaVersion(4); +pub const LAST_SCHEMA_VERSION: SchemaVersion = SchemaVersion(5); pub async fn init_table(pool: &DbConnection) -> anyhow::Result<()> { let version = { @@ -40,7 +40,7 @@ pub async fn init_table(pool: &DbConnection) -> anyhow::Result<()> { mod tests { use crate::domain::{ sql_migrations, - types::{GroupId, Uuid}, + types::{GroupId, JpegPhoto, Serialized, Uuid}, }; use super::*; @@ -62,10 +62,22 @@ mod tests { async fn test_init_table() { let sql_pool = get_in_memory_db().await; init_table(&sql_pool).await.unwrap(); - sql_pool.execute(raw_statement( - r#"INSERT INTO users - (user_id, email, display_name, first_name, last_name, creation_date, password_hash, uuid) - VALUES ("bôb", "böb@bob.bob", "Bob Bobbersön", "Bob", "Bobberson", "1970-01-01 00:00:00", "bob00", "abc")"#)).await.unwrap(); + sql_pool + .execute(raw_statement( + r#"INSERT INTO users + (user_id, email, display_name, creation_date, password_hash, uuid) + VALUES ("bôb", "böb@bob.bob", "Bob Bobbersön", "1970-01-01 00:00:00", "bob00", "abc")"#, + )) + .await + .unwrap(); + sql_pool + .execute(raw_statement( + r#"INSERT INTO user_attributes + (user_attribute_user_id, user_attribute_name, user_attribute_value) + VALUES ("bôb", "first_name", "Bob")"#, + )) + .await + .unwrap(); #[derive(FromQueryResult, PartialEq, Eq, Debug)] struct ShortUserDetails { display_name: String, @@ -97,11 +109,12 @@ mod tests { #[tokio::test] async fn test_migrate_tables() { + crate::infra::logging::init_for_tests(); // Test that we add the column creation_date to groups and uuid to users and groups. let sql_pool = get_in_memory_db().await; sql_pool .execute(raw_statement( - r#"CREATE TABLE users ( user_id TEXT, display_name TEXT, first_name TEXT NOT NULL, last_name TEXT, avatar BLOB, creation_date TEXT, email TEXT);"#, + r#"CREATE TABLE users ( user_id TEXT PRIMARY KEY, display_name TEXT, first_name TEXT NOT NULL, last_name TEXT, avatar BLOB, creation_date TEXT, email TEXT);"#, )) .await .unwrap(); @@ -143,12 +156,11 @@ mod tests { #[derive(FromQueryResult, PartialEq, Eq, Debug)] struct SimpleUser { display_name: Option, - first_name: Option, uuid: Uuid, } assert_eq!( SimpleUser::find_by_statement(raw_statement( - r#"SELECT display_name, first_name, uuid FROM users ORDER BY display_name"# + r#"SELECT display_name, uuid FROM users ORDER BY display_name"# )) .all(&sql_pool) .await @@ -156,17 +168,36 @@ mod tests { vec![ SimpleUser { display_name: None, - first_name: None, uuid: crate::uuid!("a02eaf13-48a7-30f6-a3d4-040ff7c52b04") }, SimpleUser { display_name: Some("John Doe".to_owned()), - first_name: Some("John".to_owned()), uuid: crate::uuid!("986765a5-3f03-389e-b47b-536b2d6e1bec") } ] ); #[derive(FromQueryResult, PartialEq, Eq, Debug)] + struct UserAttribute { + user_attribute_user_id: String, + user_attribute_name: String, + user_attribute_value: Serialized, + } + assert_eq!( + UserAttribute::find_by_statement(raw_statement( + r#"SELECT user_attribute_user_id, user_attribute_name, user_attribute_value FROM user_attributes ORDER BY user_attribute_user_id, user_attribute_value"# + )) + .all(&sql_pool) + .await + .unwrap(), + vec![ + UserAttribute { + user_attribute_user_id: "john".to_owned(), + user_attribute_name: "first_name".to_owned(), + user_attribute_value: Serialized::from("John"), + } + ] + ); + #[derive(FromQueryResult, PartialEq, Eq, Debug)] struct ShortGroupDetails { group_id: GroupId, display_name: String, @@ -270,6 +301,92 @@ mod tests { ); } + #[tokio::test] + async fn test_migration_to_v5() { + crate::infra::logging::init_for_tests(); + let sql_pool = get_in_memory_db().await; + upgrade_to_v1(&sql_pool).await.unwrap(); + migrate_from_version(&sql_pool, SchemaVersion(1), SchemaVersion(4)) + .await + .unwrap(); + sql_pool + .execute(raw_statement( + r#"INSERT INTO users (user_id, email, creation_date, uuid) + VALUES ("bob", "bob@bob.com", "1970-01-01 00:00:00", "a02eaf13-48a7-30f6-a3d4-040ff7c52b04")"#, + )) + .await + .unwrap(); + sql_pool + .execute(sea_orm::Statement::from_sql_and_values(DbBackend::Sqlite, + r#"INSERT INTO users (user_id, email, display_name, first_name, last_name, avatar, creation_date, uuid) + VALUES ("bob2", "bob2@bob.com", "display bob", "first bob", "last bob", $1, "1970-01-01 00:00:00", "986765a5-3f03-389e-b47b-536b2d6e1bec")"#, [JpegPhoto::for_tests().into()]), + ) + .await + .unwrap(); + migrate_from_version(&sql_pool, SchemaVersion(4), SchemaVersion(5)) + .await + .unwrap(); + assert_eq!( + sql_migrations::JustSchemaVersion::find_by_statement(raw_statement( + r#"SELECT version FROM metadata"# + )) + .one(&sql_pool) + .await + .unwrap() + .unwrap(), + sql_migrations::JustSchemaVersion { + version: SchemaVersion(5) + } + ); + #[derive(FromQueryResult, PartialEq, Eq, Debug)] + pub struct UserV5 { + user_id: String, + email: String, + display_name: Option, + } + assert_eq!( + UserV5::find_by_statement(raw_statement( + r#"SELECT user_id, email, display_name FROM users ORDER BY user_id ASC"# + )) + .all(&sql_pool) + .await + .unwrap(), + vec![ + UserV5 { + user_id: "bob".to_owned(), + email: "bob@bob.com".to_owned(), + display_name: None + }, + UserV5 { + user_id: "bob2".to_owned(), + email: "bob2@bob.com".to_owned(), + display_name: Some("display bob".to_owned()) + }, + ] + ); + sql_pool + .execute(raw_statement(r#"SELECT first_name FROM users"#)) + .await + .unwrap_err(); + #[derive(FromQueryResult, PartialEq, Eq, Debug)] + pub struct UserAttribute { + user_attribute_user_id: String, + user_attribute_name: String, + user_attribute_value: Serialized, + } + assert_eq!( + UserAttribute::find_by_statement(raw_statement(r#"SELECT * FROM user_attributes ORDER BY user_attribute_user_id, user_attribute_name ASC"#)) + .all(&sql_pool) + .await + .unwrap(), + vec![ + UserAttribute { user_attribute_user_id: "bob2".to_string(), user_attribute_name: "avatar".to_owned(), user_attribute_value: Serialized::from(&JpegPhoto::for_tests()) }, + UserAttribute { user_attribute_user_id: "bob2".to_string(), user_attribute_name: "first_name".to_owned(), user_attribute_value: Serialized::from("first bob") }, + UserAttribute { user_attribute_user_id: "bob2".to_string(), user_attribute_name: "last_name".to_owned(), user_attribute_value: Serialized::from("last bob") }, + ] + ); + } + #[tokio::test] async fn test_too_high_version() { let sql_pool = get_in_memory_db().await; diff --git a/server/src/domain/sql_user_backend_handler.rs b/server/src/domain/sql_user_backend_handler.rs index a53de3d..c5dfc2d 100644 --- a/server/src/domain/sql_user_backend_handler.rs +++ b/server/src/domain/sql_user_backend_handler.rs @@ -6,17 +6,31 @@ use crate::domain::{ }, model::{self, GroupColumn, UserColumn}, sql_backend_handler::SqlBackendHandler, - types::{GroupDetails, GroupId, User, UserAndGroups, UserId, Uuid}, + types::{GroupDetails, GroupId, Serialized, User, UserAndGroups, UserId, Uuid}, }; use async_trait::async_trait; use sea_orm::{ - entity::IntoActiveValue, - sea_query::{Alias, Cond, Expr, Func, IntoColumnRef, IntoCondition, SimpleExpr}, - ActiveModelTrait, ActiveValue, ColumnTrait, EntityTrait, ModelTrait, QueryFilter, QueryOrder, - QuerySelect, QueryTrait, Set, + sea_query::{ + query::OnConflict, Alias, Cond, Expr, Func, IntoColumnRef, IntoCondition, SimpleExpr, + }, + ActiveModelTrait, ActiveValue, ColumnTrait, EntityTrait, IntoActiveValue, ModelTrait, + QueryFilter, QueryOrder, QuerySelect, QueryTrait, Set, TransactionTrait, }; use std::collections::HashSet; -use tracing::{debug, instrument}; +use tracing::{debug, instrument, warn}; + +fn attribute_condition(name: String, value: String) -> Cond { + Expr::in_subquery( + Expr::col(UserColumn::UserId.as_column_ref()), + model::UserAttributes::find() + .select_only() + .column(model::UserAttributesColumn::UserId) + .filter(model::UserAttributesColumn::AttributeName.eq(name)) + .filter(model::UserAttributesColumn::Value.eq(Serialized::from(&value))) + .into_query(), + ) + .into_condition() +} fn get_user_filter_expr(filter: UserRequestFilter) -> Cond { use UserRequestFilter::*; @@ -46,6 +60,7 @@ fn get_user_filter_expr(filter: UserRequestFilter) -> Cond { ColumnTrait::eq(&s1, s2).into_condition() } } + AttributeEquality(s1, s2) => attribute_condition(s1, s2), MemberOf(group) => Expr::col((group_table, GroupColumn::DisplayName)) .eq(group) .into_condition(), @@ -55,9 +70,11 @@ fn get_user_filter_expr(filter: UserRequestFilter) -> Cond { UserIdSubString(filter) => UserColumn::UserId .like(&filter.to_sql_filter()) .into_condition(), - SubString(col, filter) => SimpleExpr::FunctionCall(Func::lower(Expr::col(col))) - .like(filter.to_sql_filter()) - .into_condition(), + SubString(col, filter) => { + SimpleExpr::FunctionCall(Func::lower(Expr::col(col.as_column_ref()))) + .like(filter.to_sql_filter()) + .into_condition() + } } } @@ -78,10 +95,11 @@ impl UserListerBackendHandler for SqlBackendHandler { async fn list_users( &self, filters: Option, - get_groups: bool, + // To simplify the query, we always fetch groups. TODO: cleanup. + _get_groups: bool, ) -> Result> { debug!(?filters); - let query = model::User::find() + let results = model::User::find() .filter( filters .map(|f| { @@ -98,45 +116,62 @@ impl UserListerBackendHandler for SqlBackendHandler { }) .unwrap_or_else(|| SimpleExpr::Value(true.into()).into_condition()), ) - .order_by_asc(UserColumn::UserId); - if !get_groups { - Ok(query - .into_model::() - .all(&self.sql_pool) - .await? - .into_iter() - .map(|u| UserAndGroups { - user: u, - groups: None, - }) - .collect()) - } else { - let results = query - //find_with_linked? - .find_also_linked(model::memberships::UserToGroup) - .order_by_asc(SimpleExpr::Column( - (Alias::new("r1"), GroupColumn::GroupId).into_column_ref(), - )) - .all(&self.sql_pool) - .await?; - use itertools::Itertools; - Ok(results - .iter() - .group_by(|(u, _)| u) - .into_iter() - .map(|(user, groups)| { - let groups: Vec<_> = groups - .into_iter() - .flat_map(|(_, g)| g) - .map(|g| GroupDetails::from(g.clone())) - .collect(); - UserAndGroups { - user: user.clone().into(), - groups: Some(groups), - } - }) - .collect()) + .order_by_asc(UserColumn::UserId) + //find_with_linked? + .find_also_linked(model::memberships::UserToGroup) + .order_by_asc(SimpleExpr::Column( + (Alias::new("r1"), GroupColumn::GroupId).into_column_ref(), + )) + .all(&self.sql_pool) + .await?; + use itertools::Itertools; + let mut users: Vec<_> = results + .iter() + .group_by(|(u, _)| u) + .into_iter() + .map(|(user, groups)| { + let groups: Vec<_> = groups + .into_iter() + .flat_map(|(_, g)| g) + .map(|g| GroupDetails::from(g.clone())) + .collect(); + UserAndGroups { + user: user.clone().into(), + groups: Some(groups), + } + }) + .collect(); + // At this point, the users don't have attributes, we need to populate it with another query. + let user_ids = users + .iter() + .map(|u| u.user.user_id.clone()) + .collect::>(); + let attributes = model::UserAttributes::find() + .filter(model::UserAttributesColumn::UserId.is_in(&user_ids)) + .order_by_asc(model::UserAttributesColumn::UserId) + .all(&self.sql_pool) + .await?; + let mut attributes_iter = attributes.iter().peekable(); + for user in users.iter_mut() { + attributes_iter + .peeking_take_while(|u| u.user_id < user.user.user_id) + .for_each(|_| ()); + + for model::user_attributes::Model { + user_id: _, + attribute_name, + value, + } in attributes_iter.take_while_ref(|u| u.user_id == user.user.user_id) + { + match attribute_name.as_str() { + "first_name" => user.user.first_name = Some(value.unwrap()), + "last_name" => user.user.last_name = Some(value.unwrap()), + "avatar" => user.user.avatar = Some(value.unwrap()), + _ => warn!("Unknown attribute name: {}", attribute_name), + } + } } + Ok(users) } } @@ -145,11 +180,30 @@ impl UserBackendHandler for SqlBackendHandler { #[instrument(skip_all, level = "debug", ret)] async fn get_user_details(&self, user_id: &UserId) -> Result { debug!(?user_id); - model::User::find_by_id(user_id.to_owned()) - .into_model::() - .one(&self.sql_pool) - .await? - .ok_or_else(|| DomainError::EntityNotFound(user_id.to_string())) + let mut user = User::from( + model::User::find_by_id(user_id.to_owned()) + .one(&self.sql_pool) + .await? + .ok_or_else(|| DomainError::EntityNotFound(user_id.to_string()))?, + ); + let attributes = model::UserAttributes::find() + .filter(model::UserAttributesColumn::UserId.eq(user_id)) + .all(&self.sql_pool) + .await?; + for model::user_attributes::Model { + user_id: _, + attribute_name, + value, + } in attributes + { + match attribute_name.as_str() { + "first_name" => user.first_name = Some(value.unwrap()), + "last_name" => user.last_name = Some(value.unwrap()), + "avatar" => user.avatar = Some(value.unwrap()), + _ => warn!("Unknown attribute name: {}", attribute_name), + } + } + Ok(user) } #[instrument(skip_all, level = "debug", ret, err)] @@ -173,17 +227,48 @@ impl UserBackendHandler for SqlBackendHandler { let now = chrono::Utc::now().naive_utc(); let uuid = Uuid::from_name_and_date(request.user_id.as_str(), &now); let new_user = model::users::ActiveModel { - user_id: Set(request.user_id), + user_id: Set(request.user_id.clone()), email: Set(request.email), display_name: to_value(&request.display_name), - first_name: to_value(&request.first_name), - last_name: to_value(&request.last_name), - avatar: request.avatar.into_active_value(), creation_date: ActiveValue::Set(now), uuid: ActiveValue::Set(uuid), ..Default::default() }; - new_user.insert(&self.sql_pool).await?; + let mut new_user_attributes = Vec::new(); + if let Some(first_name) = request.first_name { + new_user_attributes.push(model::user_attributes::ActiveModel { + user_id: Set(request.user_id.clone()), + attribute_name: Set("first_name".to_owned()), + value: Set(Serialized::from(&first_name)), + }); + } + if let Some(last_name) = request.last_name { + new_user_attributes.push(model::user_attributes::ActiveModel { + user_id: Set(request.user_id.clone()), + attribute_name: Set("last_name".to_owned()), + value: Set(Serialized::from(&last_name)), + }); + } + if let Some(avatar) = request.avatar { + new_user_attributes.push(model::user_attributes::ActiveModel { + user_id: Set(request.user_id), + attribute_name: Set("avatar".to_owned()), + value: Set(Serialized::from(&avatar)), + }); + } + self.sql_pool + .transaction::<_, (), DomainError>(|transaction| { + Box::pin(async move { + new_user.insert(transaction).await?; + if !new_user_attributes.is_empty() { + model::UserAttributes::insert_many(new_user_attributes) + .exec(transaction) + .await?; + } + Ok(()) + }) + }) + .await?; Ok(()) } @@ -191,15 +276,72 @@ impl UserBackendHandler for SqlBackendHandler { async fn update_user(&self, request: UpdateUserRequest) -> Result<()> { debug!(user_id = ?request.user_id); let update_user = model::users::ActiveModel { - user_id: ActiveValue::Set(request.user_id), + user_id: ActiveValue::Set(request.user_id.clone()), email: request.email.map(ActiveValue::Set).unwrap_or_default(), display_name: to_value(&request.display_name), - first_name: to_value(&request.first_name), - last_name: to_value(&request.last_name), - avatar: request.avatar.into_active_value(), ..Default::default() }; - update_user.update(&self.sql_pool).await?; + let mut update_user_attributes = Vec::new(); + let mut remove_user_attributes = Vec::new(); + let to_serialized_value = |s: &Option| match s.as_ref().map(|s| s.as_str()) { + None => None, + Some("") => Some(ActiveValue::NotSet), + Some(s) => Some(ActiveValue::Set(Serialized::from(s))), + }; + let mut process_serialized = + |value: ActiveValue, attribute_name: &str| match &value { + ActiveValue::NotSet => { + remove_user_attributes.push(attribute_name.to_owned()); + } + ActiveValue::Set(_) => { + update_user_attributes.push(model::user_attributes::ActiveModel { + user_id: Set(request.user_id.clone()), + attribute_name: Set(attribute_name.to_owned()), + value, + }) + } + _ => unreachable!(), + }; + if let Some(value) = to_serialized_value(&request.first_name) { + process_serialized(value, "first_name"); + } + if let Some(value) = to_serialized_value(&request.last_name) { + process_serialized(value, "last_name"); + } + if let Some(avatar) = request.avatar { + process_serialized(avatar.into_active_value(), "avatar"); + } + self.sql_pool + .transaction::<_, (), DomainError>(|transaction| { + Box::pin(async move { + update_user.update(transaction).await?; + if !update_user_attributes.is_empty() { + model::UserAttributes::insert_many(update_user_attributes) + .on_conflict( + OnConflict::columns([ + model::UserAttributesColumn::UserId, + model::UserAttributesColumn::AttributeName, + ]) + .update_column(model::UserAttributesColumn::Value) + .to_owned(), + ) + .exec(transaction) + .await?; + } + if !remove_user_attributes.is_empty() { + model::UserAttributes::delete_many() + .filter(model::UserAttributesColumn::UserId.eq(&request.user_id)) + .filter( + model::UserAttributesColumn::AttributeName + .is_in(remove_user_attributes), + ) + .exec(transaction) + .await?; + } + Ok(()) + }) + }) + .await?; Ok(()) } @@ -291,8 +433,8 @@ mod tests { let fixture = TestFixture::new().await; let users = get_user_names( &fixture.handler, - Some(UserRequestFilter::Equality( - UserColumn::FirstName, + Some(UserRequestFilter::AttributeEquality( + "first_name".to_string(), "first bob".to_string(), )), ) @@ -312,10 +454,10 @@ mod tests { final_: Some("K".to_owned()), }), UserRequestFilter::SubString( - UserColumn::FirstName, + UserColumn::DisplayName, SubStringFilter { initial: None, - any: vec!["r".to_owned(), "t".to_owned()], + any: vec!["t".to_owned(), "r".to_owned()], final_: None, }, ), @@ -633,8 +775,9 @@ mod tests { .handler .update_user(UpdateUserRequest { user_id: UserId::new("bob"), - first_name: Some("first_name".to_string()), + first_name: None, last_name: Some(String::new()), + avatar: Some(JpegPhoto::for_tests()), ..Default::default() }) .await @@ -646,11 +789,78 @@ mod tests { .await .unwrap(); assert_eq!(user.display_name.unwrap(), "display bob"); - assert_eq!(user.first_name.unwrap(), "first_name"); + assert_eq!(user.first_name.unwrap(), "first bob"); assert_eq!(user.last_name, None); + assert_eq!(user.avatar, Some(JpegPhoto::for_tests())); + } + + #[tokio::test] + async fn test_update_user_delete_avatar() { + let fixture = TestFixture::new().await; + + fixture + .handler + .update_user(UpdateUserRequest { + user_id: UserId::new("bob"), + avatar: Some(JpegPhoto::for_tests()), + ..Default::default() + }) + .await + .unwrap(); + + let user = fixture + .handler + .get_user_details(&UserId::new("bob")) + .await + .unwrap(); + assert_eq!(user.avatar, Some(JpegPhoto::for_tests())); + fixture + .handler + .update_user(UpdateUserRequest { + user_id: UserId::new("bob"), + avatar: Some(JpegPhoto::null()), + ..Default::default() + }) + .await + .unwrap(); + + let user = fixture + .handler + .get_user_details(&UserId::new("bob")) + .await + .unwrap(); assert_eq!(user.avatar, None); } + #[tokio::test] + async fn test_create_user_all_values() { + let fixture = TestFixture::new().await; + + fixture + .handler + .create_user(CreateUserRequest { + user_id: UserId::new("james"), + email: "email".to_string(), + display_name: Some("display_name".to_string()), + first_name: Some("first_name".to_string()), + last_name: Some("last_name".to_string()), + avatar: Some(JpegPhoto::for_tests()), + }) + .await + .unwrap(); + + let user = fixture + .handler + .get_user_details(&UserId::new("james")) + .await + .unwrap(); + assert_eq!(user.email, "email"); + assert_eq!(user.display_name.unwrap(), "display_name"); + assert_eq!(user.first_name.unwrap(), "first_name"); + assert_eq!(user.last_name.unwrap(), "last_name"); + assert_eq!(user.avatar, Some(JpegPhoto::for_tests())); + } + #[tokio::test] async fn test_remove_user_from_group() { let fixture = TestFixture::new().await; @@ -670,4 +880,32 @@ mod tests { vec!["patrick"] ); } + + #[tokio::test] + async fn test_delete_user_not_found() { + let fixture = TestFixture::new().await; + + fixture + .handler + .delete_user(&UserId::new("not found")) + .await + .expect_err("Should have failed"); + } + + #[tokio::test] + async fn test_remove_user_from_group_not_found() { + let fixture = TestFixture::new().await; + + fixture + .handler + .remove_user_from_group(&UserId::new("not found"), fixture.groups[0]) + .await + .expect_err("Should have failed"); + + fixture + .handler + .remove_user_from_group(&UserId::new("not found"), GroupId(16242)) + .await + .expect_err("Should have failed"); + } } diff --git a/server/src/domain/types.rs b/server/src/domain/types.rs index 0c1d1aa..15a3c8e 100644 --- a/server/src/domain/types.rs +++ b/server/src/domain/types.rs @@ -2,7 +2,8 @@ use base64::Engine; use chrono::{NaiveDateTime, TimeZone}; use sea_orm::{ entity::IntoActiveValue, - sea_query::{value::ValueType, ArrayType, ColumnType, Nullable, ValueTypeErr}, + sea_query::{value::ValueType, ArrayType, BlobSize, ColumnType, Nullable, ValueTypeErr}, + strum::{EnumString, IntoStaticStr}, DbErr, FromQueryResult, QueryResult, TryFromU64, TryGetError, TryGetable, Value, }; use serde::{Deserialize, Serialize}; @@ -103,7 +104,64 @@ macro_rules! uuid { }; } -#[derive(PartialEq, Eq, Clone, Debug, Default, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct Serialized(Vec); + +impl<'a, T: Serialize + ?Sized> From<&'a T> for Serialized { + fn from(t: &'a T) -> Self { + Self(bincode::serialize(&t).unwrap()) + } +} + +impl Serialized { + pub fn unwrap<'a, T: Deserialize<'a>>(&'a self) -> T { + bincode::deserialize(&self.0).unwrap() + } + + pub fn expect<'a, T: Deserialize<'a>>(&'a self, message: &str) -> T { + bincode::deserialize(&self.0).expect(message) + } +} + +impl From for Value { + fn from(ser: Serialized) -> Self { + ser.0.into() + } +} + +impl TryGetable for Serialized { + fn try_get_by(res: &QueryResult, index: I) -> Result { + Ok(Self(Vec::::try_get_by(res, index)?)) + } +} + +impl TryFromU64 for Serialized { + fn try_from_u64(_n: u64) -> Result { + Err(DbErr::ConvertFromU64( + "Serialized cannot be constructed from u64", + )) + } +} + +impl ValueType for Serialized { + fn try_from(v: Value) -> Result { + Ok(Self( as ValueType>::try_from(v)?)) + } + + fn type_name() -> String { + "Serialized".to_owned() + } + + fn array_type() -> ArrayType { + ArrayType::Bytes + } + + fn column_type() -> ColumnType { + ColumnType::Binary(BlobSize::Long) + } +} + +#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Debug, Default, Serialize, Deserialize)] #[serde(from = "String")] pub struct UserId(String); @@ -238,6 +296,10 @@ impl From<&JpegPhoto> for String { } impl JpegPhoto { + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + pub fn null() -> Self { Self(vec![]) } @@ -296,7 +358,7 @@ impl ValueType for JpegPhoto { } fn column_type() -> ColumnType { - ColumnType::Binary(sea_orm::sea_query::BlobSize::Long) + ColumnType::Binary(BlobSize::Long) } } @@ -306,13 +368,17 @@ impl Nullable for JpegPhoto { } } -impl IntoActiveValue for JpegPhoto { - fn into_active_value(self) -> sea_orm::ActiveValue { - sea_orm::ActiveValue::Set(self) +impl IntoActiveValue for JpegPhoto { + fn into_active_value(self) -> sea_orm::ActiveValue { + if self.is_empty() { + sea_orm::ActiveValue::NotSet + } else { + sea_orm::ActiveValue::Set(Serialized::from(&self)) + } } } -#[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize, FromQueryResult)] +#[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize)] pub struct User { pub user_id: UserId, pub email: String, @@ -380,6 +446,51 @@ impl TryFromU64 for GroupId { } } +#[derive( + Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, EnumString, IntoStaticStr, +)] +pub enum AttributeType { + String, + Integer, + JpegPhoto, + DateTime, +} + +impl From for Value { + fn from(attribute_type: AttributeType) -> Self { + Into::<&'static str>::into(attribute_type).into() + } +} + +impl TryGetable for AttributeType { + fn try_get_by(res: &QueryResult, index: I) -> Result { + use std::str::FromStr; + Ok(AttributeType::from_str(&String::try_get_by(res, index)?).expect("Invalid enum value")) + } +} + +impl ValueType for AttributeType { + fn try_from(v: Value) -> Result { + use std::str::FromStr; + Ok( + AttributeType::from_str(&::try_from(v)?) + .expect("Invalid enum value"), + ) + } + + fn type_name() -> String { + "AttributeType".to_owned() + } + + fn array_type() -> ArrayType { + ArrayType::String + } + + fn column_type() -> ColumnType { + ColumnType::String(Some(64)) + } +} + #[derive(PartialEq, Eq, Debug, Serialize, Deserialize)] pub struct Group { pub id: GroupId, diff --git a/server/src/infra/graphql/query.rs b/server/src/infra/graphql/query.rs index 1eb2f5e..5e8d453 100644 --- a/server/src/infra/graphql/query.rs +++ b/server/src/infra/graphql/query.rs @@ -1,7 +1,7 @@ use crate::{ domain::{ handler::BackendHandler, - ldap::utils::map_user_field, + ldap::utils::{map_user_field, UserFieldType}, types::{GroupDetails, GroupId, UserColumn, UserId}, }, infra::{ @@ -61,14 +61,19 @@ impl TryInto for RequestFilter { return Err("Multiple fields specified in request filter".to_string()); } if let Some(e) = self.eq { - if let Some(column) = map_user_field(&e.field) { - if column == UserColumn::UserId { - return Ok(DomainRequestFilter::UserId(UserId::new(&e.value))); + return match map_user_field(&e.field.to_ascii_lowercase()) { + UserFieldType::NoMatch => Err(format!("Unknown request filter: {}", &e.field)), + UserFieldType::PrimaryField(UserColumn::UserId) => { + Ok(DomainRequestFilter::UserId(UserId::new(&e.value))) } - return Ok(DomainRequestFilter::Equality(column, e.value)); - } else { - return Err(format!("Unknown request filter: {}", &e.field)); - } + UserFieldType::PrimaryField(column) => { + Ok(DomainRequestFilter::Equality(column, e.value)) + } + UserFieldType::Attribute(column) => Ok(DomainRequestFilter::AttributeEquality( + column.to_owned(), + e.value, + )), + }; } if let Some(c) = self.any { return Ok(DomainRequestFilter::Or( @@ -461,6 +466,10 @@ mod tests { {eq: { field: "email" value: "robert@bobbers.on" + }}, + {eq: { + field: "firstName" + value: "robert" }} ]}) { id @@ -475,7 +484,11 @@ mod tests { DomainRequestFilter::UserId(UserId::new("bob")), DomainRequestFilter::Equality( UserColumn::Email, - "robert@bobbers.on".to_string(), + "robert@bobbers.on".to_owned(), + ), + DomainRequestFilter::AttributeEquality( + "first_name".to_owned(), + "robert".to_owned(), ), ]))), eq(false), @@ -485,7 +498,7 @@ mod tests { DomainUserAndGroups { user: DomainUser { user_id: UserId::new("bob"), - email: "bob@bobbers.on".to_string(), + email: "bob@bobbers.on".to_owned(), ..Default::default() }, groups: None, @@ -493,7 +506,7 @@ mod tests { DomainUserAndGroups { user: DomainUser { user_id: UserId::new("robert"), - email: "robert@bobbers.on".to_string(), + email: "robert@bobbers.on".to_owned(), ..Default::default() }, groups: None, diff --git a/server/src/infra/ldap_handler.rs b/server/src/infra/ldap_handler.rs index 81d320c..4266d95 100644 --- a/server/src/infra/ldap_handler.rs +++ b/server/src/infra/ldap_handler.rs @@ -1307,6 +1307,7 @@ mod tests { GroupRequestFilter::Member(UserId::new("bob")), GroupRequestFilter::DisplayName("rockstars".to_string()), false.into(), + GroupRequestFilter::Uuid(uuid!("04ac75e0-2900-3e21-926c-2f732c26b3fc")), true.into(), true.into(), true.into(), @@ -1345,6 +1346,10 @@ mod tests { "dn".to_string(), "uid=rockstars,ou=people,dc=example,dc=com".to_string(), ), + LdapFilter::Equality( + "uuid".to_string(), + "04ac75e0-2900-3e21-926c-2f732c26b3fc".to_string(), + ), LdapFilter::Equality("obJEctclass".to_string(), "groupofUniqueNames".to_string()), LdapFilter::Equality("objectclass".to_string(), "groupOfNames".to_string()), LdapFilter::Present("objectclass".to_string()), @@ -1530,6 +1535,10 @@ mod tests { true.into(), true.into(), false.into(), + UserRequestFilter::AttributeEquality( + "first_name".to_owned(), + "firstname".to_owned(), + ), false.into(), UserRequestFilter::UserIdSubString(SubStringFilter { initial: Some("iNIt".to_owned()), @@ -1537,7 +1546,7 @@ mod tests { final_: Some("finAl".to_owned()), }), UserRequestFilter::SubString( - UserColumn::FirstName, + UserColumn::DisplayName, SubStringFilter { initial: Some("iNIt".to_owned()), any: vec!["1".to_owned(), "2aA".to_owned()], @@ -1570,6 +1579,7 @@ mod tests { LdapFilter::Present("objectClass".to_string()), LdapFilter::Present("uid".to_string()), LdapFilter::Present("unknown".to_string()), + LdapFilter::Equality("givenname".to_string(), "firstname".to_string()), LdapFilter::Equality("unknown_attribute".to_string(), "randomValue".to_string()), LdapFilter::Substring( "uid".to_owned(), @@ -1580,7 +1590,7 @@ mod tests { }, ), LdapFilter::Substring( - "firstName".to_owned(), + "displayName".to_owned(), LdapSubstringFilter { initial: Some("iNIt".to_owned()), any: vec!["1".to_owned(), "2aA".to_owned()], @@ -1596,6 +1606,35 @@ mod tests { ); } + #[tokio::test] + async fn test_search_unsupported_substring_filter() { + let mut ldap_handler = setup_bound_admin_handler(MockTestBackendHandler::new()).await; + let request = make_user_search_request( + LdapFilter::Substring( + "uuid".to_owned(), + LdapSubstringFilter { + initial: Some("iNIt".to_owned()), + any: vec!["1".to_owned(), "2aA".to_owned()], + final_: Some("finAl".to_owned()), + }, + ), + vec!["objectClass"], + ); + ldap_handler.do_search_or_dse(&request).await.unwrap_err(); + let request = make_user_search_request( + LdapFilter::Substring( + "givenname".to_owned(), + LdapSubstringFilter { + initial: Some("iNIt".to_owned()), + any: vec!["1".to_owned(), "2aA".to_owned()], + final_: Some("finAl".to_owned()), + }, + ), + vec!["objectClass"], + ); + ldap_handler.do_search_or_dse(&request).await.unwrap_err(); + } + #[tokio::test] async fn test_search_member_of_filter() { let mut mock = MockTestBackendHandler::new(); @@ -1652,7 +1691,7 @@ mod tests { .with( eq(Some(UserRequestFilter::And(vec![UserRequestFilter::Or( vec![UserRequestFilter::Not(Box::new( - UserRequestFilter::Equality(UserColumn::FirstName, "bob".to_string()), + UserRequestFilter::Equality(UserColumn::DisplayName, "bob".to_string()), ))], )]))), eq(false), @@ -1670,7 +1709,7 @@ mod tests { let mut ldap_handler = setup_bound_admin_handler(mock).await; let request = make_user_search_request( LdapFilter::And(vec![LdapFilter::Or(vec![LdapFilter::Not(Box::new( - LdapFilter::Equality("givenname".to_string(), "bob".to_string()), + LdapFilter::Equality("displayname".to_string(), "bob".to_string()), ))])]), vec!["objectclass"], ); diff --git a/server/src/infra/tcp_server.rs b/server/src/infra/tcp_server.rs index 43f65ea..989688d 100644 --- a/server/src/infra/tcp_server.rs +++ b/server/src/infra/tcp_server.rs @@ -53,6 +53,7 @@ pub(crate) fn error_to_http_response(error: TcpError) -> HttpResponse { HttpResponse::Unauthorized() } DomainError::DatabaseError(_) + | DomainError::DatabaseTransactionError(_) | DomainError::InternalError(_) | DomainError::UnknownCryptoError(_) => HttpResponse::InternalServerError(), DomainError::Base64DecodeError(_) diff --git a/server/tests/common/fixture.rs b/server/tests/common/fixture.rs index 77c188c..338f213 100644 --- a/server/tests/common/fixture.rs +++ b/server/tests/common/fixture.rs @@ -222,7 +222,7 @@ impl Drop for LLDAPFixture { pub fn new_id(prefix: Option<&str>) -> String { let id = Uuid::new_v4(); - let id = format!("{}-lldap-test", id.to_simple()); + let id = format!("{}-lldap-test", id.simple()); match prefix { Some(prefix) => format!("{}{}", prefix, id), None => id,