server: make attributes names, group names and emails case insensitive

In addition, group names and emails keep their casing
This commit is contained in:
Valentin Tolmer
2023-12-15 22:28:59 +01:00
committed by nitnelave
parent 71d37b9e5e
commit 272c84c574
27 changed files with 721 additions and 328 deletions

View File

@@ -1,8 +1,8 @@
use crate::domain::{
error::Result,
types::{
AttributeType, AttributeValue, Group, GroupDetails, GroupId, JpegPhoto, User,
UserAndGroups, UserColumn, UserId, Uuid,
AttributeName, AttributeType, AttributeValue, Email, Group, GroupDetails, GroupId,
GroupName, JpegPhoto, User, UserAndGroups, UserColumn, UserId, Uuid,
},
};
use async_trait::async_trait;
@@ -54,10 +54,10 @@ pub enum UserRequestFilter {
UserId(UserId),
UserIdSubString(SubStringFilter),
Equality(UserColumn, String),
AttributeEquality(String, String),
AttributeEquality(AttributeName, String),
SubString(UserColumn, SubStringFilter),
// Check if a user belongs to a group identified by name.
MemberOf(String),
MemberOf(GroupName),
// Same, by id.
MemberOfId(GroupId),
}
@@ -77,7 +77,7 @@ pub enum GroupRequestFilter {
And(Vec<GroupRequestFilter>),
Or(Vec<GroupRequestFilter>),
Not(Box<GroupRequestFilter>),
DisplayName(String),
DisplayName(GroupName),
DisplayNameSubString(SubStringFilter),
Uuid(Uuid),
GroupId(GroupId),
@@ -99,7 +99,7 @@ impl From<bool> for GroupRequestFilter {
pub struct CreateUserRequest {
// Same fields as User, but no creation_date, and with password.
pub user_id: UserId,
pub email: String,
pub email: Email,
pub display_name: Option<String>,
pub first_name: Option<String>,
pub last_name: Option<String>,
@@ -111,32 +111,32 @@ pub struct CreateUserRequest {
pub struct UpdateUserRequest {
// Same fields as CreateUserRequest, but no with an extra layer of Option.
pub user_id: UserId,
pub email: Option<String>,
pub email: Option<Email>,
pub display_name: Option<String>,
pub first_name: Option<String>,
pub last_name: Option<String>,
pub avatar: Option<JpegPhoto>,
pub delete_attributes: Vec<String>,
pub delete_attributes: Vec<AttributeName>,
pub insert_attributes: Vec<AttributeValue>,
}
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone, Default)]
pub struct CreateGroupRequest {
pub display_name: String,
pub display_name: GroupName,
pub attributes: Vec<AttributeValue>,
}
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone)]
pub struct UpdateGroupRequest {
pub group_id: GroupId,
pub display_name: Option<String>,
pub delete_attributes: Vec<String>,
pub display_name: Option<GroupName>,
pub delete_attributes: Vec<AttributeName>,
pub insert_attributes: Vec<AttributeValue>,
}
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone)]
pub struct AttributeSchema {
pub name: String,
pub name: AttributeName,
//TODO: pub aliases: Vec<String>,
pub attribute_type: AttributeType,
pub is_list: bool,
@@ -147,7 +147,7 @@ pub struct AttributeSchema {
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone)]
pub struct CreateAttributeRequest {
pub name: String,
pub name: AttributeName,
pub attribute_type: AttributeType,
pub is_list: bool,
pub is_visible: bool,
@@ -160,11 +160,11 @@ pub struct AttributeList {
}
impl AttributeList {
pub fn get_attribute_schema(&self, name: &str) -> Option<&AttributeSchema> {
self.attributes.iter().find(|a| a.name == name)
pub fn get_attribute_schema(&self, name: &AttributeName) -> Option<&AttributeSchema> {
self.attributes.iter().find(|a| a.name == *name)
}
pub fn get_attribute_type(&self, name: &str) -> Option<(AttributeType, bool)> {
pub fn get_attribute_type(&self, name: &AttributeName) -> Option<(AttributeType, bool)> {
self.get_attribute_schema(name)
.map(|a| (a.attribute_type, a.is_list))
}
@@ -224,8 +224,8 @@ pub trait SchemaBackendHandler: ReadSchemaBackendHandler {
async fn add_user_attribute(&self, request: CreateAttributeRequest) -> Result<()>;
async fn add_group_attribute(&self, request: CreateAttributeRequest) -> Result<()>;
// Note: It's up to the caller to make sure that the attribute is not hardcoded.
async fn delete_user_attribute(&self, name: &str) -> Result<()>;
async fn delete_group_attribute(&self, name: &str) -> Result<()>;
async fn delete_user_attribute(&self, name: &AttributeName) -> Result<()>;
async fn delete_group_attribute(&self, name: &AttributeName) -> Result<()>;
}
#[async_trait]

View File

@@ -7,7 +7,7 @@ use crate::domain::{
handler::{GroupListerBackendHandler, GroupRequestFilter},
ldap::error::LdapError,
schema::{PublicSchema, SchemaGroupAttributeExtractor},
types::{Group, UserId, Uuid},
types::{AttributeName, Group, UserId, Uuid},
};
use super::{
@@ -23,15 +23,15 @@ pub fn get_group_attribute(
base_dn_str: &str,
attribute: &str,
user_filter: &Option<UserId>,
ignored_group_attributes: &[String],
ignored_group_attributes: &[AttributeName],
schema: &PublicSchema,
) -> Option<Vec<Vec<u8>>> {
let attribute = attribute.to_ascii_lowercase();
let attribute = AttributeName::from(attribute);
let attribute_values = match attribute.as_str() {
"objectclass" => vec![b"groupOfUniqueNames".to_vec()],
// Always returned as part of the base response.
"dn" | "distinguishedname" => return None,
"cn" | "uid" | "id" => vec![group.display_name.clone().into_bytes()],
"cn" | "uid" | "id" => vec![group.display_name.to_string().into_bytes()],
"entryuuid" | "uuid" => vec![group.uuid.to_string().into_bytes()],
"member" | "uniquemember" => group
.users
@@ -48,11 +48,11 @@ pub fn get_group_attribute(
attribute
)
}
attr => {
_ => {
if !ignored_group_attributes.contains(&attribute) {
match get_custom_attribute::<SchemaGroupAttributeExtractor>(
&group.attributes,
attr,
&attribute,
schema,
) {
Some(v) => return Some(v),
@@ -91,7 +91,7 @@ fn make_ldap_search_group_result_entry(
base_dn_str: &str,
attributes: &[String],
user_filter: &Option<UserId>,
ignored_group_attributes: &[String],
ignored_group_attributes: &[AttributeName],
schema: &PublicSchema,
) -> LdapSearchResultEntry {
let expanded_attributes = expand_group_attribute_wildcards(attributes);
@@ -125,12 +125,12 @@ fn convert_group_filter(
let rec = |f| convert_group_filter(ldap_info, f);
match filter {
LdapFilter::Equality(field, value) => {
let field = &field.to_ascii_lowercase();
let value = &value.to_ascii_lowercase();
let field = AttributeName::from(field.as_str());
let value = value.to_ascii_lowercase();
match field.as_str() {
"member" | "uniquemember" => {
let user_name = get_user_id_from_distinguished_name(
value,
&value,
&ldap_info.base_dn,
&ldap_info.base_dn_str,
)?;
@@ -150,8 +150,8 @@ fn convert_group_filter(
warn!("Invalid dn filter on group: {}", value);
GroupRequestFilter::from(false)
})),
_ => match map_group_field(field) {
Some("display_name") => Ok(GroupRequestFilter::DisplayName(value.to_string())),
_ => match map_group_field(&field) {
Some("display_name") => Ok(GroupRequestFilter::DisplayName(value.into())),
Some("uuid") => Ok(GroupRequestFilter::Uuid(
Uuid::try_from(value.as_str()).map_err(|e| LdapError {
code: LdapResultCode::InappropriateMatching,
@@ -159,9 +159,9 @@ fn convert_group_filter(
})?,
)),
_ => {
if !ldap_info.ignored_group_attributes.contains(field) {
if !ldap_info.ignored_group_attributes.contains(&field) {
warn!(
r#"Ignoring unknown group attribute "{:?}" in filter.\n\
r#"Ignoring unknown group attribute "{}" in filter.\n\
To disable this warning, add it to "ignored_group_attributes" in the config."#,
field
);
@@ -179,24 +179,24 @@ fn convert_group_filter(
)),
LdapFilter::Not(filter) => Ok(GroupRequestFilter::Not(Box::new(rec(filter)?))),
LdapFilter::Present(field) => {
let field = &field.to_ascii_lowercase();
let field = AttributeName::from(field.as_str());
Ok(GroupRequestFilter::from(
field == "objectclass"
|| field == "dn"
|| field == "distinguishedname"
|| map_group_field(field).is_some(),
field.as_str() == "objectclass"
|| field.as_str() == "dn"
|| field.as_str() == "distinguishedname"
|| map_group_field(&field).is_some(),
))
}
LdapFilter::Substring(field, substring_filter) => {
let field = &field.to_ascii_lowercase();
match map_group_field(field.as_str()) {
let field = AttributeName::from(field.as_str());
match map_group_field(&field) {
Some("display_name") => Ok(GroupRequestFilter::DisplayNameSubString(
substring_filter.clone().into(),
)),
_ => Err(LdapError {
code: LdapResultCode::UnwillingToPerform,
message: format!(
"Unsupported group attribute for substring filter: {:?}",
"Unsupported group attribute for substring filter: \"{}\"",
field
),
}),

View File

@@ -14,7 +14,7 @@ use crate::domain::{
},
},
schema::{PublicSchema, SchemaUserAttributeExtractor},
types::{GroupDetails, User, UserAndGroups, UserColumn, UserId},
types::{AttributeName, GroupDetails, User, UserAndGroups, UserColumn, UserId},
};
pub fn get_user_attribute(
@@ -22,10 +22,10 @@ pub fn get_user_attribute(
attribute: &str,
base_dn_str: &str,
groups: Option<&[GroupDetails]>,
ignored_user_attributes: &[String],
ignored_user_attributes: &[AttributeName],
schema: &PublicSchema,
) -> Option<Vec<Vec<u8>>> {
let attribute = attribute.to_ascii_lowercase();
let attribute = AttributeName::from(attribute);
let attribute_values = match attribute.as_str() {
"objectclass" => vec![
b"inetOrgPerson".to_vec(),
@@ -37,20 +37,22 @@ pub fn get_user_attribute(
"dn" | "distinguishedname" => return None,
"uid" | "user_id" | "id" => vec![user.user_id.to_string().into_bytes()],
"entryuuid" | "uuid" => vec![user.uuid.to_string().into_bytes()],
"mail" | "email" => vec![user.email.clone().into_bytes()],
"givenname" | "first_name" | "firstname" => get_custom_attribute::<
SchemaUserAttributeExtractor,
>(
&user.attributes, "first_name", schema
)?,
"mail" | "email" => vec![user.email.to_string().into_bytes()],
"givenname" | "first_name" | "firstname" => {
get_custom_attribute::<SchemaUserAttributeExtractor>(
&user.attributes,
&"first_name".into(),
schema,
)?
}
"sn" | "last_name" | "lastname" => get_custom_attribute::<SchemaUserAttributeExtractor>(
&user.attributes,
"last_name",
&"last_name".into(),
schema,
)?,
"jpegphoto" | "avatar" => get_custom_attribute::<SchemaUserAttributeExtractor>(
&user.attributes,
"avatar",
&"avatar".into(),
schema,
)?,
"memberof" => groups
@@ -80,7 +82,7 @@ pub fn get_user_attribute(
if !ignored_user_attributes.contains(&attribute) {
match get_custom_attribute::<SchemaUserAttributeExtractor>(
&user.attributes,
attr,
&attribute,
schema,
) {
Some(v) => return Some(v),
@@ -118,7 +120,7 @@ fn make_ldap_search_user_result_entry(
base_dn_str: &str,
attributes: &[String],
groups: Option<&[GroupDetails]>,
ignored_user_attributes: &[String],
ignored_user_attributes: &[AttributeName],
schema: &PublicSchema,
) -> LdapSearchResultEntry {
let expanded_attributes = expand_user_attribute_wildcards(attributes);
@@ -156,7 +158,7 @@ fn convert_user_filter(ldap_info: &LdapInfo, filter: &LdapFilter) -> LdapResult<
)),
LdapFilter::Not(filter) => Ok(UserRequestFilter::Not(Box::new(rec(filter)?))),
LdapFilter::Equality(field, value) => {
let field = &field.to_ascii_lowercase();
let field = AttributeName::from(field.as_str());
match field.as_str() {
"memberof" => Ok(UserRequestFilter::MemberOf(
get_group_id_from_distinguished_name(
@@ -179,7 +181,7 @@ fn convert_user_filter(ldap_info: &LdapInfo, filter: &LdapFilter) -> LdapResult<
warn!("Invalid dn filter on user: {}", value);
UserRequestFilter::from(false)
})),
_ => match map_user_field(field) {
_ => match map_user_field(&field) {
UserFieldType::PrimaryField(UserColumn::UserId) => {
Ok(UserRequestFilter::UserId(UserId::new(value)))
}
@@ -187,11 +189,11 @@ fn convert_user_filter(ldap_info: &LdapInfo, filter: &LdapFilter) -> LdapResult<
Ok(UserRequestFilter::Equality(field, value.clone()))
}
UserFieldType::Attribute(field) => Ok(UserRequestFilter::AttributeEquality(
field.to_owned(),
AttributeName::from(field),
value.clone(),
)),
UserFieldType::NoMatch => {
if !ldap_info.ignored_user_attributes.contains(field) {
if !ldap_info.ignored_user_attributes.contains(&field) {
warn!(
r#"Ignoring unknown user attribute "{}" in filter.\n\
To disable this warning, add it to "ignored_user_attributes" in the config"#,
@@ -204,18 +206,18 @@ fn convert_user_filter(ldap_info: &LdapInfo, filter: &LdapFilter) -> LdapResult<
}
}
LdapFilter::Present(field) => {
let field = &field.to_ascii_lowercase();
let field = AttributeName::from(field.as_str());
// Check that it's a field we support.
Ok(UserRequestFilter::from(
field == "objectclass"
|| field == "dn"
|| field == "distinguishedname"
|| !matches!(map_user_field(field), UserFieldType::NoMatch),
field.as_str() == "objectclass"
|| field.as_str() == "dn"
|| field.as_str() == "distinguishedname"
|| !matches!(map_user_field(&field), UserFieldType::NoMatch),
))
}
LdapFilter::Substring(field, substring_filter) => {
let field = &field.to_ascii_lowercase();
match map_user_field(field.as_str()) {
let field = AttributeName::from(field.as_str());
match map_user_field(&field) {
UserFieldType::PrimaryField(UserColumn::UserId) => Ok(
UserRequestFilter::UserIdSubString(substring_filter.clone().into()),
),

View File

@@ -7,7 +7,9 @@ use crate::domain::{
handler::SubStringFilter,
ldap::error::{LdapError, LdapResult},
schema::{PublicSchema, SchemaAttributeExtractor},
types::{AttributeType, AttributeValue, JpegPhoto, UserColumn, UserId},
types::{
AttributeName, AttributeType, AttributeValue, GroupName, JpegPhoto, UserColumn, UserId,
},
};
impl From<LdapSubstringFilter> for SubStringFilter {
@@ -103,8 +105,8 @@ pub fn get_group_id_from_distinguished_name(
dn: &str,
base_tree: &[(String, String)],
base_dn_str: &str,
) -> LdapResult<String> {
get_id_from_distinguished_name(dn, base_tree, base_dn_str, true)
) -> LdapResult<GroupName> {
get_id_from_distinguished_name(dn, base_tree, base_dn_str, true).map(GroupName::from)
}
#[instrument(skip(all_attribute_keys), level = "debug")]
@@ -160,9 +162,8 @@ pub enum UserFieldType {
Attribute(&'static str),
}
pub fn map_user_field(field: &str) -> UserFieldType {
assert!(field == field.to_ascii_lowercase());
match field {
pub fn map_user_field(field: &AttributeName) -> UserFieldType {
match field.as_str() {
"uid" | "user_id" | "id" => UserFieldType::PrimaryField(UserColumn::UserId),
"mail" | "email" => UserFieldType::PrimaryField(UserColumn::Email),
"cn" | "displayname" | "display_name" => {
@@ -179,9 +180,8 @@ pub fn map_user_field(field: &str) -> UserFieldType {
}
}
pub fn map_group_field(field: &str) -> Option<&'static str> {
assert!(field == field.to_ascii_lowercase());
Some(match field {
pub fn map_group_field(field: &AttributeName) -> Option<&'static str> {
Some(match field.as_str() {
"cn" | "displayname" | "uid" | "display_name" => "display_name",
"creationdate" | "createtimestamp" | "modifytimestamp" | "creation_date" => "creation_date",
"entryuuid" | "uuid" => "uuid",
@@ -192,13 +192,13 @@ pub fn map_group_field(field: &str) -> Option<&'static str> {
pub struct LdapInfo {
pub base_dn: Vec<(String, String)>,
pub base_dn_str: String,
pub ignored_user_attributes: Vec<String>,
pub ignored_group_attributes: Vec<String>,
pub ignored_user_attributes: Vec<AttributeName>,
pub ignored_group_attributes: Vec<AttributeName>,
}
pub fn get_custom_attribute<Extractor: SchemaAttributeExtractor>(
attributes: &[AttributeValue],
attribute_name: &str,
attribute_name: &AttributeName,
schema: &PublicSchema,
) -> Option<Vec<Vec<u8>>> {
let convert_date = |date| {
@@ -212,7 +212,7 @@ pub fn get_custom_attribute<Extractor: SchemaAttributeExtractor>(
.and_then(|attribute_type| {
attributes
.iter()
.find(|a| a.name == attribute_name)
.find(|a| &a.name == attribute_name)
.map(|attribute| match attribute_type {
(AttributeType::String, false) => {
vec![attribute.value.unwrap::<String>().into_bytes()]

View File

@@ -1,7 +1,10 @@
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
use crate::domain::{handler::AttributeSchema, types::AttributeType};
use crate::domain::{
handler::AttributeSchema,
types::{AttributeName, AttributeType},
};
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
#[sea_orm(table_name = "group_attribute_schema")]
@@ -11,7 +14,7 @@ pub struct Model {
auto_increment = false,
column_name = "group_attribute_schema_name"
)]
pub attribute_name: String,
pub attribute_name: AttributeName,
#[sea_orm(column_name = "group_attribute_schema_type")]
pub attribute_type: AttributeType,
#[sea_orm(column_name = "group_attribute_schema_is_list")]

View File

@@ -1,7 +1,7 @@
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
use crate::domain::types::{AttributeValue, GroupId, Serialized};
use crate::domain::types::{AttributeName, AttributeValue, GroupId, Serialized};
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
#[sea_orm(table_name = "group_attributes")]
@@ -17,7 +17,7 @@ pub struct Model {
auto_increment = false,
column_name = "group_attribute_name"
)]
pub attribute_name: String,
pub attribute_name: AttributeName,
#[sea_orm(column_name = "group_attribute_value")]
pub value: Serialized,
}

View File

@@ -3,14 +3,15 @@
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
use crate::domain::types::{GroupId, Uuid};
use crate::domain::types::{GroupId, GroupName, Uuid};
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
#[sea_orm(table_name = "groups")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub group_id: GroupId,
pub display_name: String,
pub display_name: GroupName,
pub lowercase_display_name: String,
pub creation_date: chrono::NaiveDateTime,
pub uuid: Uuid,
}

View File

@@ -1,7 +1,10 @@
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
use crate::domain::{handler::AttributeSchema, types::AttributeType};
use crate::domain::{
handler::AttributeSchema,
types::{AttributeName, AttributeType},
};
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
#[sea_orm(table_name = "user_attribute_schema")]
@@ -11,7 +14,7 @@ pub struct Model {
auto_increment = false,
column_name = "user_attribute_schema_name"
)]
pub attribute_name: String,
pub attribute_name: AttributeName,
#[sea_orm(column_name = "user_attribute_schema_type")]
pub attribute_type: AttributeType,
#[sea_orm(column_name = "user_attribute_schema_is_list")]

View File

@@ -1,7 +1,7 @@
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
use crate::domain::types::{AttributeValue, Serialized, UserId};
use crate::domain::types::{AttributeName, AttributeValue, Serialized, UserId};
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
#[sea_orm(table_name = "user_attributes")]
@@ -17,7 +17,7 @@ pub struct Model {
auto_increment = false,
column_name = "user_attribute_name"
)]
pub attribute_name: String,
pub attribute_name: AttributeName,
#[sea_orm(column_name = "user_attribute_value")]
pub value: Serialized,
}

View File

@@ -3,7 +3,7 @@
use sea_orm::{entity::prelude::*, sea_query::BlobSize};
use serde::{Deserialize, Serialize};
use crate::domain::types::{UserId, Uuid};
use crate::domain::types::{Email, UserId, Uuid};
#[derive(Copy, Clone, Default, Debug, DeriveEntity)]
pub struct Entity;
@@ -13,7 +13,8 @@ pub struct Entity;
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub user_id: UserId,
pub email: String,
pub email: Email,
pub lowercase_email: String,
pub display_name: Option<String>,
pub creation_date: chrono::NaiveDateTime,
pub password_hash: Option<Vec<u8>>,
@@ -32,6 +33,7 @@ impl EntityName for Entity {
pub enum Column {
UserId,
Email,
LowercaseEmail,
DisplayName,
CreationDate,
PasswordHash,
@@ -47,6 +49,7 @@ impl ColumnTrait for Column {
match self {
Column::UserId => ColumnType::String(Some(255)),
Column::Email => ColumnType::String(Some(255)),
Column::LowercaseEmail => ColumnType::String(Some(255)),
Column::DisplayName => ColumnType::String(Some(255)),
Column::CreationDate => ColumnType::DateTime,
Column::PasswordHash => ColumnType::Binary(BlobSize::Medium),

View File

@@ -37,7 +37,7 @@ impl From<Schema> for PublicSchema {
fn from(mut schema: Schema) -> Self {
schema.user_attributes.attributes.extend_from_slice(&[
AttributeSchema {
name: "user_id".to_owned(),
name: "user_id".into(),
attribute_type: AttributeType::String,
is_list: false,
is_visible: true,
@@ -45,7 +45,7 @@ impl From<Schema> for PublicSchema {
is_hardcoded: true,
},
AttributeSchema {
name: "creation_date".to_owned(),
name: "creation_date".into(),
attribute_type: AttributeType::DateTime,
is_list: false,
is_visible: true,
@@ -53,7 +53,7 @@ impl From<Schema> for PublicSchema {
is_hardcoded: true,
},
AttributeSchema {
name: "mail".to_owned(),
name: "mail".into(),
attribute_type: AttributeType::String,
is_list: false,
is_visible: true,
@@ -61,7 +61,7 @@ impl From<Schema> for PublicSchema {
is_hardcoded: true,
},
AttributeSchema {
name: "uuid".to_owned(),
name: "uuid".into(),
attribute_type: AttributeType::String,
is_list: false,
is_visible: true,
@@ -69,7 +69,7 @@ impl From<Schema> for PublicSchema {
is_hardcoded: true,
},
AttributeSchema {
name: "display_name".to_owned(),
name: "display_name".into(),
attribute_type: AttributeType::String,
is_list: false,
is_visible: true,
@@ -83,7 +83,7 @@ impl From<Schema> for PublicSchema {
.sort_by(|a, b| a.name.cmp(&b.name));
schema.group_attributes.attributes.extend_from_slice(&[
AttributeSchema {
name: "group_id".to_owned(),
name: "group_id".into(),
attribute_type: AttributeType::Integer,
is_list: false,
is_visible: true,
@@ -91,7 +91,7 @@ impl From<Schema> for PublicSchema {
is_hardcoded: true,
},
AttributeSchema {
name: "creation_date".to_owned(),
name: "creation_date".into(),
attribute_type: AttributeType::DateTime,
is_list: false,
is_visible: true,
@@ -99,7 +99,7 @@ impl From<Schema> for PublicSchema {
is_hardcoded: true,
},
AttributeSchema {
name: "uuid".to_owned(),
name: "uuid".into(),
attribute_type: AttributeType::String,
is_list: false,
is_visible: true,
@@ -107,7 +107,7 @@ impl From<Schema> for PublicSchema {
is_hardcoded: true,
},
AttributeSchema {
name: "display_name".to_owned(),
name: "display_name".into(),
attribute_type: AttributeType::String,
is_list: false,
is_visible: true,

View File

@@ -87,7 +87,7 @@ pub mod tests {
handler
.create_user(CreateUserRequest {
user_id: UserId::new(name),
email: format!("{}@bob.bob", name),
email: format!("{}@bob.bob", name).into(),
display_name: Some("display ".to_string() + name),
first_name: Some("first ".to_string() + name),
last_name: Some("last ".to_string() + name),
@@ -100,7 +100,7 @@ pub mod tests {
pub async fn insert_group(handler: &SqlBackendHandler, name: &str) -> GroupId {
handler
.create_group(CreateGroupRequest {
display_name: name.to_owned(),
display_name: name.into(),
..Default::default()
})
.await

View File

@@ -37,7 +37,9 @@ fn get_group_filter_expr(filter: GroupRequestFilter) -> Cond {
}
}
Not(f) => get_group_filter_expr(*f).not(),
DisplayName(name) => GroupColumn::DisplayName.eq(name).into_condition(),
DisplayName(name) => GroupColumn::LowercaseDisplayName
.eq(name.as_str().to_lowercase())
.into_condition(),
GroupId(id) => GroupColumn::GroupId.eq(id.0).into_condition(),
Uuid(uuid) => GroupColumn::Uuid.eq(uuid.to_string()).into_condition(),
// WHERE (group_id in (SELECT group_id FROM memberships WHERE user_id = user))
@@ -153,9 +155,11 @@ impl GroupBackendHandler for SqlBackendHandler {
#[instrument(skip(self), level = "debug", ret, err)]
async fn create_group(&self, request: CreateGroupRequest) -> Result<GroupId> {
let now = chrono::Utc::now().naive_utc();
let uuid = Uuid::from_name_and_date(&request.display_name, &now);
let uuid = Uuid::from_name_and_date(request.display_name.as_str(), &now);
let lower_display_name = request.display_name.as_str().to_lowercase();
let new_group = model::groups::ActiveModel {
display_name: Set(request.display_name),
lowercase_display_name: Set(lower_display_name),
creation_date: Set(now),
uuid: Set(uuid),
..Default::default()
@@ -217,9 +221,14 @@ impl SqlBackendHandler {
request: UpdateGroupRequest,
transaction: &DatabaseTransaction,
) -> Result<()> {
let lower_display_name = request
.display_name
.as_ref()
.map(|s| s.as_str().to_lowercase());
let update_group = model::groups::ActiveModel {
group_id: Set(request.group_id),
display_name: request.display_name.map(Set).unwrap_or_default(),
lowercase_display_name: lower_display_name.map(Set).unwrap_or_default(),
..Default::default()
};
update_group.update(transaction).await?;
@@ -288,7 +297,7 @@ mod tests {
use crate::domain::{
handler::{CreateAttributeRequest, SchemaBackendHandler, SubStringFilter},
sql_backend_handler::tests::*,
types::{AttributeType, Serialized, UserId},
types::{AttributeType, GroupName, Serialized, UserId},
};
use pretty_assertions::assert_eq;
@@ -308,7 +317,7 @@ mod tests {
async fn get_group_names(
handler: &SqlBackendHandler,
filters: Option<GroupRequestFilter>,
) -> Vec<String> {
) -> Vec<GroupName> {
handler
.list_groups(filters)
.await
@@ -324,9 +333,9 @@ mod tests {
assert_eq!(
get_group_names(&fixture.handler, None).await,
vec![
"Best Group".to_owned(),
"Empty Group".to_owned(),
"Worst Group".to_owned()
"Best Group".into(),
"Empty Group".into(),
"Worst Group".into()
]
);
}
@@ -338,12 +347,25 @@ mod tests {
get_group_names(
&fixture.handler,
Some(GroupRequestFilter::Or(vec![
GroupRequestFilter::DisplayName("Empty Group".to_owned()),
GroupRequestFilter::DisplayName("Empty Group".into()),
GroupRequestFilter::Member(UserId::new("bob")),
]))
)
.await,
vec!["Best Group".to_owned(), "Empty Group".to_owned()]
vec!["Best Group".into(), "Empty Group".into()]
);
}
#[tokio::test]
async fn test_list_groups_case_insensitive_filter() {
let fixture = TestFixture::new().await;
assert_eq!(
get_group_names(
&fixture.handler,
Some(GroupRequestFilter::DisplayName("eMpTy gRoup".into()),)
)
.await,
vec!["Empty Group".into()]
);
}
@@ -355,7 +377,7 @@ mod tests {
&fixture.handler,
Some(GroupRequestFilter::And(vec![
GroupRequestFilter::Not(Box::new(GroupRequestFilter::DisplayName(
"value".to_owned()
"value".into()
))),
GroupRequestFilter::GroupId(fixture.groups[0]),
]))
@@ -392,7 +414,7 @@ mod tests {
.await
.unwrap();
assert_eq!(details.group_id, fixture.groups[0]);
assert_eq!(details.display_name, "Best Group");
assert_eq!(details.display_name, "Best Group".into());
assert_eq!(
get_group_ids(
&fixture.handler,
@@ -410,7 +432,7 @@ mod tests {
.handler
.update_group(UpdateGroupRequest {
group_id: fixture.groups[0],
display_name: Some("Awesomest Group".to_owned()),
display_name: Some("Awesomest Group".into()),
delete_attributes: Vec::new(),
insert_attributes: Vec::new(),
})
@@ -421,7 +443,7 @@ mod tests {
.get_group_details(fixture.groups[0])
.await
.unwrap();
assert_eq!(details.display_name, "Awesomest Group");
assert_eq!(details.display_name, "Awesomest Group".into());
}
#[tokio::test]
@@ -452,7 +474,7 @@ mod tests {
fixture
.handler
.add_group_attribute(CreateAttributeRequest {
name: "new_attribute".to_owned(),
name: "new_attribute".into(),
attribute_type: AttributeType::String,
is_list: false,
is_visible: true,
@@ -463,9 +485,9 @@ mod tests {
let new_group_id = fixture
.handler
.create_group(CreateGroupRequest {
display_name: "New Group".to_owned(),
display_name: "New Group".into(),
attributes: vec![AttributeValue {
name: "new_attribute".to_owned(),
name: "new_attribute".into(),
value: Serialized::from("value"),
}],
})
@@ -476,11 +498,11 @@ mod tests {
.get_group_details(new_group_id)
.await
.unwrap();
assert_eq!(&group_details.display_name, "New Group");
assert_eq!(group_details.display_name, "New Group".into());
assert_eq!(
group_details.attributes,
vec![AttributeValue {
name: "new_attribute".to_owned(),
name: "new_attribute".into(),
value: Serialized::from("value"),
}]
);
@@ -492,7 +514,7 @@ mod tests {
fixture
.handler
.add_group_attribute(CreateAttributeRequest {
name: "new_attribute".to_owned(),
name: "new_attribute".into(),
attribute_type: AttributeType::Integer,
is_list: false,
is_visible: true,
@@ -502,7 +524,7 @@ mod tests {
.unwrap();
let group_id = fixture.groups[0];
let attributes = vec![AttributeValue {
name: "new_attribute".to_owned(),
name: "new_attribute".into(),
value: Serialized::from(&42i64),
}];
fixture
@@ -522,7 +544,7 @@ mod tests {
.update_group(UpdateGroupRequest {
group_id,
display_name: None,
delete_attributes: vec!["new_attribute".to_owned()],
delete_attributes: vec!["new_attribute".into()],
insert_attributes: Vec::new(),
})
.await

View File

@@ -18,6 +18,7 @@ pub enum Users {
Table,
UserId,
Email,
LowercaseEmail,
DisplayName,
FirstName,
LastName,
@@ -34,6 +35,7 @@ pub enum Groups {
Table,
GroupId,
DisplayName,
LowercaseDisplayName,
CreationDate,
Uuid,
}
@@ -875,6 +877,53 @@ async fn migrate_to_v5(transaction: DatabaseTransaction) -> Result<DatabaseTrans
Ok(transaction)
}
async fn migrate_to_v6(transaction: DatabaseTransaction) -> Result<DatabaseTransaction, DbErr> {
let builder = transaction.get_database_backend();
transaction
.execute(
builder.build(
Table::alter().table(Groups::Table).add_column(
ColumnDef::new(Groups::LowercaseDisplayName)
.string_len(255)
.not_null()
.default("UNSET"),
),
),
)
.await?;
transaction
.execute(
builder.build(
Table::alter().table(Users::Table).add_column(
ColumnDef::new(Users::LowercaseEmail)
.string_len(255)
.not_null()
.default("UNSET"),
),
),
)
.await?;
transaction
.execute(builder.build(Query::update().table(Groups::Table).value(
Groups::LowercaseDisplayName,
Func::lower(Expr::col(Groups::DisplayName)),
)))
.await?;
transaction
.execute(
builder.build(
Query::update()
.table(Users::Table)
.value(Users::LowercaseEmail, Func::lower(Expr::col(Users::Email))),
),
)
.await?;
Ok(transaction)
}
// This is needed to make an array of async functions.
macro_rules! to_sync {
($l:ident) => {
@@ -900,6 +949,7 @@ pub async fn migrate_from_version(
to_sync!(migrate_to_v3),
to_sync!(migrate_to_v4),
to_sync!(migrate_to_v5),
to_sync!(migrate_to_v6),
];
assert_eq!(migrations.len(), (LAST_SCHEMA_VERSION.0 - 1) as usize);
for migration in 2..=last_version.0 {

View File

@@ -6,6 +6,7 @@ use crate::domain::{
},
model,
sql_backend_handler::SqlBackendHandler,
types::AttributeName,
};
use async_trait::async_trait;
use sea_orm::{
@@ -52,15 +53,15 @@ impl SchemaBackendHandler for SqlBackendHandler {
Ok(())
}
async fn delete_user_attribute(&self, name: &str) -> Result<()> {
model::UserAttributeSchema::delete_by_id(name)
async fn delete_user_attribute(&self, name: &AttributeName) -> Result<()> {
model::UserAttributeSchema::delete_by_id(name.clone())
.exec(&self.sql_pool)
.await?;
Ok(())
}
async fn delete_group_attribute(&self, name: &str) -> Result<()> {
model::GroupAttributeSchema::delete_by_id(name)
async fn delete_group_attribute(&self, name: &AttributeName) -> Result<()> {
model::GroupAttributeSchema::delete_by_id(name.clone())
.exec(&self.sql_pool)
.await?;
Ok(())
@@ -123,7 +124,7 @@ mod tests {
user_attributes: AttributeList {
attributes: vec![
AttributeSchema {
name: "avatar".to_owned(),
name: "avatar".into(),
attribute_type: AttributeType::JpegPhoto,
is_list: false,
is_visible: true,
@@ -131,7 +132,7 @@ mod tests {
is_hardcoded: true,
},
AttributeSchema {
name: "first_name".to_owned(),
name: "first_name".into(),
attribute_type: AttributeType::String,
is_list: false,
is_visible: true,
@@ -139,7 +140,7 @@ mod tests {
is_hardcoded: true,
},
AttributeSchema {
name: "last_name".to_owned(),
name: "last_name".into(),
attribute_type: AttributeType::String,
is_list: false,
is_visible: true,
@@ -159,7 +160,7 @@ mod tests {
async fn test_user_attribute_add_and_delete() {
let fixture = TestFixture::new().await;
let new_attribute = CreateAttributeRequest {
name: "new_attribute".to_owned(),
name: "new_attribute".into(),
attribute_type: AttributeType::Integer,
is_list: true,
is_visible: false,
@@ -171,7 +172,7 @@ mod tests {
.await
.unwrap();
let expected_value = AttributeSchema {
name: "new_attribute".to_owned(),
name: "new_attribute".into(),
attribute_type: AttributeType::Integer,
is_list: true,
is_visible: false,
@@ -188,7 +189,7 @@ mod tests {
.contains(&expected_value));
fixture
.handler
.delete_user_attribute("new_attribute")
.delete_user_attribute(&"new_attribute".into())
.await
.unwrap();
assert!(!fixture
@@ -205,7 +206,7 @@ mod tests {
async fn test_group_attribute_add_and_delete() {
let fixture = TestFixture::new().await;
let new_attribute = CreateAttributeRequest {
name: "new_attribute".to_owned(),
name: "NeW_aTTribute".into(),
attribute_type: AttributeType::JpegPhoto,
is_list: false,
is_visible: true,
@@ -217,7 +218,7 @@ mod tests {
.await
.unwrap();
let expected_value = AttributeSchema {
name: "new_attribute".to_owned(),
name: "new_attribute".into(),
attribute_type: AttributeType::JpegPhoto,
is_list: false,
is_visible: true,
@@ -234,7 +235,7 @@ mod tests {
.contains(&expected_value));
fixture
.handler
.delete_group_attribute("new_attribute")
.delete_group_attribute(&"new_attriBUte".into())
.await
.unwrap();
assert!(!fixture

View File

@@ -6,7 +6,7 @@ pub type DbConnection = sea_orm::DatabaseConnection;
#[derive(Copy, PartialEq, Eq, Debug, Clone, PartialOrd, Ord, DeriveValueType)]
pub struct SchemaVersion(pub i16);
pub const LAST_SCHEMA_VERSION: SchemaVersion = SchemaVersion(5);
pub const LAST_SCHEMA_VERSION: SchemaVersion = SchemaVersion(6);
pub async fn init_table(pool: &DbConnection) -> anyhow::Result<()> {
let version = {
@@ -51,8 +51,8 @@ mod tests {
sql_pool
.execute(raw_statement(
r#"INSERT INTO users
(user_id, email, display_name, creation_date, password_hash, uuid)
VALUES ("bôb", "böb@bob.bob", "Bob Bobbersön", "1970-01-01 00:00:00", "bob00", "abc")"#,
(user_id, email, lowercase_email, display_name, creation_date, password_hash, uuid)
VALUES ("bôb", "böb@bob.bob", "böb@bob.bob", "Bob Bobbersön", "1970-01-01 00:00:00", "bob00", "abc")"#,
))
.await
.unwrap();
@@ -373,6 +373,83 @@ mod tests {
);
}
#[tokio::test]
async fn test_migration_to_v6() {
crate::infra::logging::init_for_tests();
let sql_pool = get_in_memory_db().await;
upgrade_to_v1(&sql_pool).await.unwrap();
migrate_from_version(&sql_pool, SchemaVersion(1), SchemaVersion(5))
.await
.unwrap();
sql_pool
.execute(raw_statement(
r#"INSERT INTO users (user_id, email, display_name, creation_date, uuid)
VALUES ("bob", "BOb@bob.com", "", "1970-01-01 00:00:00", "a02eaf13-48a7-30f6-a3d4-040ff7c52b04")"#,
))
.await
.unwrap();
sql_pool
.execute(raw_statement(
r#"INSERT INTO groups (display_name, creation_date, uuid)
VALUES ("BestGroup", "1970-01-01 00:00:00", "986765a5-3f03-389e-b47b-536b2d6e1bec")"#,
))
.await
.unwrap();
migrate_from_version(&sql_pool, SchemaVersion(5), SchemaVersion(6))
.await
.unwrap();
assert_eq!(
sql_migrations::JustSchemaVersion::find_by_statement(raw_statement(
r#"SELECT version FROM metadata"#
))
.one(&sql_pool)
.await
.unwrap()
.unwrap(),
sql_migrations::JustSchemaVersion {
version: SchemaVersion(6)
}
);
#[derive(FromQueryResult, PartialEq, Eq, Debug)]
struct ShortUserDetails {
email: String,
lowercase_email: String,
}
let result = ShortUserDetails::find_by_statement(raw_statement(
r#"SELECT email, lowercase_email FROM users WHERE user_id = "bob""#,
))
.one(&sql_pool)
.await
.unwrap()
.unwrap();
assert_eq!(
result,
ShortUserDetails {
email: "BOb@bob.com".to_owned(),
lowercase_email: "bob@bob.com".to_owned(),
}
);
#[derive(FromQueryResult, PartialEq, Eq, Debug)]
struct ShortGroupDetails {
display_name: String,
lowercase_display_name: String,
}
let result = ShortGroupDetails::find_by_statement(raw_statement(
r#"SELECT display_name, lowercase_display_name FROM groups"#,
))
.one(&sql_pool)
.await
.unwrap()
.unwrap();
assert_eq!(
result,
ShortGroupDetails {
display_name: "BestGroup".to_owned(),
lowercase_display_name: "bestgroup".to_owned(),
}
);
}
#[tokio::test]
async fn test_too_high_version() {
let sql_pool = get_in_memory_db().await;

View File

@@ -6,7 +6,10 @@ use crate::domain::{
},
model::{self, GroupColumn, UserColumn},
sql_backend_handler::SqlBackendHandler,
types::{AttributeValue, GroupDetails, GroupId, Serialized, User, UserAndGroups, UserId, Uuid},
types::{
AttributeName, AttributeValue, GroupDetails, GroupId, Serialized, User, UserAndGroups,
UserId, Uuid,
},
};
use async_trait::async_trait;
use sea_orm::{
@@ -19,7 +22,7 @@ use sea_orm::{
use std::collections::HashSet;
use tracing::instrument;
fn attribute_condition(name: String, value: String) -> Cond {
fn attribute_condition(name: AttributeName, value: String) -> Cond {
Expr::in_subquery(
Expr::col(UserColumn::UserId.as_column_ref()),
model::UserAttributes::find()
@@ -53,14 +56,17 @@ fn get_user_filter_expr(filter: UserRequestFilter) -> Cond {
Or(fs) => get_repeated_filter(fs, Cond::any(), false),
Not(f) => get_user_filter_expr(*f).not(),
UserId(user_id) => ColumnTrait::eq(&UserColumn::UserId, user_id).into_condition(),
Equality(s1, s2) => {
if s1 == UserColumn::UserId {
Equality(column, value) => {
if column == UserColumn::UserId {
panic!("User id should be wrapped")
} else if column == UserColumn::Email {
ColumnTrait::eq(&UserColumn::LowercaseEmail, value.as_str().to_lowercase())
.into_condition()
} else {
ColumnTrait::eq(&s1, s2).into_condition()
ColumnTrait::eq(&column, value).into_condition()
}
}
AttributeEquality(s1, s2) => attribute_condition(s1, s2),
AttributeEquality(column, value) => attribute_condition(column, value),
MemberOf(group) => Expr::col((group_table, GroupColumn::DisplayName))
.eq(group)
.into_condition(),
@@ -159,9 +165,11 @@ impl SqlBackendHandler {
transaction: &DatabaseTransaction,
request: UpdateUserRequest,
) -> Result<()> {
let lower_email = request.email.as_ref().map(|s| s.as_str().to_lowercase());
let update_user = model::users::ActiveModel {
user_id: ActiveValue::Set(request.user_id.clone()),
email: request.email.map(ActiveValue::Set).unwrap_or_default(),
lowercase_email: lower_email.map(ActiveValue::Set).unwrap_or_default(),
display_name: to_value(&request.display_name),
..Default::default()
};
@@ -173,27 +181,27 @@ impl SqlBackendHandler {
let mut update_user_attributes = Vec::new();
let mut remove_user_attributes = Vec::new();
let mut process_serialized =
|value: ActiveValue<Serialized>, attribute_name: &str| match &value {
|value: ActiveValue<Serialized>, attribute_name: AttributeName| match &value {
ActiveValue::NotSet => {
remove_user_attributes.push(attribute_name.to_owned());
remove_user_attributes.push(attribute_name);
}
ActiveValue::Set(_) => {
update_user_attributes.push(model::user_attributes::ActiveModel {
user_id: Set(request.user_id.clone()),
attribute_name: Set(attribute_name.to_owned()),
attribute_name: Set(attribute_name),
value,
})
}
_ => unreachable!(),
};
if let Some(value) = to_serialized_value(&request.first_name) {
process_serialized(value, "first_name");
process_serialized(value, "first_name".into());
}
if let Some(value) = to_serialized_value(&request.last_name) {
process_serialized(value, "last_name");
process_serialized(value, "last_name".into());
}
if let Some(avatar) = request.avatar {
process_serialized(avatar.into_active_value(), "avatar");
process_serialized(avatar.into_active_value(), "avatar".into());
}
let schema = Self::get_schema_with_transaction(transaction).await?;
for attribute in request.insert_attributes {
@@ -202,7 +210,7 @@ impl SqlBackendHandler {
.get_attribute_type(&attribute.name)
.is_some()
{
process_serialized(ActiveValue::Set(attribute.value), &attribute.name);
process_serialized(ActiveValue::Set(attribute.value), attribute.name.clone());
} else {
return Err(DomainError::InternalError(format!(
"User attribute name {} doesn't exist in the schema, yet was attempted to be inserted in the database",
@@ -287,9 +295,11 @@ impl UserBackendHandler for SqlBackendHandler {
async fn create_user(&self, request: CreateUserRequest) -> Result<()> {
let now = chrono::Utc::now().naive_utc();
let uuid = Uuid::from_name_and_date(request.user_id.as_str(), &now);
let lower_email = request.email.as_str().to_lowercase();
let new_user = model::users::ActiveModel {
user_id: Set(request.user_id.clone()),
email: Set(request.email),
lowercase_email: Set(lower_email),
display_name: to_value(&request.display_name),
creation_date: ActiveValue::Set(now),
uuid: ActiveValue::Set(uuid),
@@ -299,21 +309,21 @@ impl UserBackendHandler for SqlBackendHandler {
if let Some(first_name) = request.first_name {
new_user_attributes.push(model::user_attributes::ActiveModel {
user_id: Set(request.user_id.clone()),
attribute_name: Set("first_name".to_owned()),
attribute_name: Set("first_name".into()),
value: Set(Serialized::from(&first_name)),
});
}
if let Some(last_name) = request.last_name {
new_user_attributes.push(model::user_attributes::ActiveModel {
user_id: Set(request.user_id.clone()),
attribute_name: Set("last_name".to_owned()),
attribute_name: Set("last_name".into()),
value: Set(Serialized::from(&last_name)),
});
}
if let Some(avatar) = request.avatar {
new_user_attributes.push(model::user_attributes::ActiveModel {
user_id: Set(request.user_id.clone()),
attribute_name: Set("avatar".to_owned()),
attribute_name: Set("avatar".into()),
value: Set(Serialized::from(&avatar)),
});
}
@@ -452,7 +462,7 @@ mod tests {
let users = get_user_names(
&fixture.handler,
Some(UserRequestFilter::AttributeEquality(
"first_name".to_string(),
AttributeName::from("first_name"),
"first bob".to_string(),
)),
)
@@ -460,6 +470,30 @@ mod tests {
assert_eq!(users, vec!["bob"]);
}
#[tokio::test]
async fn test_list_users_email_filter_uppercase_email() {
let fixture = TestFixture::new().await;
insert_user_no_password(&fixture.handler, "UppEr").await;
let users_and_emails = fixture
.handler
.list_users(
Some(UserRequestFilter::Equality(
UserColumn::Email,
"uPPer@bob.bob".to_string(),
)),
false,
)
.await
.unwrap()
.into_iter()
.map(|u| (u.user.user_id.to_string(), u.user.email.to_string()))
.collect::<Vec<_>>();
assert_eq!(
users_and_emails,
vec![("upper".to_owned(), "UppEr@bob.bob".to_owned())]
);
}
#[tokio::test]
async fn test_list_users_substring_filter() {
let fixture = TestFixture::new().await;
@@ -503,7 +537,7 @@ mod tests {
let fixture = TestFixture::new().await;
let users = get_user_names(
&fixture.handler,
Some(UserRequestFilter::MemberOf("Best Group".to_string())),
Some(UserRequestFilter::MemberOf("Best Group".into())),
)
.await;
assert_eq!(users, vec!["bob", "patrick"]);
@@ -515,7 +549,7 @@ mod tests {
let users = get_user_names(
&fixture.handler,
Some(UserRequestFilter::Or(vec![
UserRequestFilter::MemberOf("Best Group".to_string()),
UserRequestFilter::MemberOf("Best Group".into()),
UserRequestFilter::Equality(UserColumn::Uuid, "abc".to_string()),
])),
)
@@ -764,7 +798,7 @@ mod tests {
.handler
.update_user(UpdateUserRequest {
user_id: UserId::new("bob"),
email: Some("email".to_string()),
email: Some("email".into()),
display_name: Some("display_name".to_string()),
first_name: Some("first_name".to_string()),
last_name: Some("last_name".to_string()),
@@ -780,21 +814,21 @@ mod tests {
.get_user_details(&UserId::new("bob"))
.await
.unwrap();
assert_eq!(user.email, "email");
assert_eq!(user.email, "email".into());
assert_eq!(user.display_name.unwrap(), "display_name");
assert_eq!(
user.attributes,
vec![
AttributeValue {
name: "avatar".to_owned(),
name: "avatar".into(),
value: Serialized::from(&JpegPhoto::for_tests())
},
AttributeValue {
name: "first_name".to_owned(),
name: "first_name".into(),
value: Serialized::from("first_name")
},
AttributeValue {
name: "last_name".to_owned(),
name: "last_name".into(),
value: Serialized::from("last_name")
}
]
@@ -827,11 +861,11 @@ mod tests {
user.attributes,
vec![
AttributeValue {
name: "avatar".to_owned(),
name: "avatar".into(),
value: Serialized::from(&JpegPhoto::for_tests())
},
AttributeValue {
name: "first_name".to_owned(),
name: "first_name".into(),
value: Serialized::from("first bob")
}
]
@@ -850,7 +884,7 @@ mod tests {
last_name: None,
avatar: None,
insert_attributes: vec![AttributeValue {
name: "first_name".to_owned(),
name: "first_name".into(),
value: Serialized::from("new first"),
}],
..Default::default()
@@ -867,11 +901,11 @@ mod tests {
user.attributes,
vec![
AttributeValue {
name: "first_name".to_owned(),
name: "first_name".into(),
value: Serialized::from("new first")
},
AttributeValue {
name: "last_name".to_owned(),
name: "last_name".into(),
value: Serialized::from("last bob")
}
]
@@ -889,7 +923,7 @@ mod tests {
first_name: None,
last_name: None,
avatar: None,
delete_attributes: vec!["first_name".to_owned()],
delete_attributes: vec!["first_name".into()],
..Default::default()
})
.await
@@ -903,7 +937,7 @@ mod tests {
assert_eq!(
user.attributes,
vec![AttributeValue {
name: "last_name".to_owned(),
name: "last_name".into(),
value: Serialized::from("last bob")
}]
);
@@ -920,9 +954,9 @@ mod tests {
first_name: None,
last_name: None,
avatar: None,
delete_attributes: vec!["first_name".to_owned()],
delete_attributes: vec!["first_name".into()],
insert_attributes: vec![AttributeValue {
name: "first_name".to_owned(),
name: "first_name".into(),
value: Serialized::from("new first"),
}],
..Default::default()
@@ -939,11 +973,11 @@ mod tests {
user.attributes,
vec![
AttributeValue {
name: "first_name".to_owned(),
name: "first_name".into(),
value: Serialized::from("new first")
},
AttributeValue {
name: "last_name".to_owned(),
name: "last_name".into(),
value: Serialized::from("last bob")
},
]
@@ -970,7 +1004,7 @@ mod tests {
.await
.unwrap();
let avatar = AttributeValue {
name: "avatar".to_owned(),
name: "avatar".into(),
value: Serialized::from(&JpegPhoto::for_tests()),
};
assert!(user.attributes.contains(&avatar));
@@ -1000,13 +1034,13 @@ mod tests {
.handler
.create_user(CreateUserRequest {
user_id: UserId::new("james"),
email: "email".to_string(),
email: "email".into(),
display_name: Some("display_name".to_string()),
first_name: None,
last_name: Some("last_name".to_string()),
avatar: Some(JpegPhoto::for_tests()),
attributes: vec![AttributeValue {
name: "first_name".to_owned(),
name: "first_name".into(),
value: Serialized::from("First Name"),
}],
})
@@ -1018,21 +1052,21 @@ mod tests {
.get_user_details(&UserId::new("james"))
.await
.unwrap();
assert_eq!(user.email, "email");
assert_eq!(user.email, "email".into());
assert_eq!(user.display_name.unwrap(), "display_name");
assert_eq!(
user.attributes,
vec![
AttributeValue {
name: "avatar".to_owned(),
name: "avatar".into(),
value: Serialized::from(&JpegPhoto::for_tests())
},
AttributeValue {
name: "first_name".to_owned(),
name: "first_name".into(),
value: Serialized::from("First Name")
},
AttributeValue {
name: "last_name".to_owned(),
name: "last_name".into(),
value: Serialized::from("last_name")
}
]

View File

@@ -1,3 +1,5 @@
use std::cmp::Ordering;
use base64::Engine;
use chrono::{NaiveDateTime, TimeZone};
use sea_orm::{
@@ -121,12 +123,22 @@ impl Serialized {
}
#[derive(
PartialEq, Eq, PartialOrd, Ord, Clone, Debug, Default, Serialize, Deserialize, DeriveValueType,
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Debug,
Default,
Hash,
Serialize,
Deserialize,
DeriveValueType,
)]
#[serde(from = "String")]
pub struct UserId(String);
pub struct CaseInsensitiveString(String);
impl UserId {
impl CaseInsensitiveString {
pub fn new(user_id: &str) -> Self {
Self(user_id.to_lowercase())
}
@@ -140,29 +152,185 @@ impl UserId {
}
}
impl std::fmt::Display for UserId {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
impl From<String> for UserId {
impl From<String> for CaseInsensitiveString {
fn from(s: String) -> Self {
Self::new(&s)
}
}
impl From<&UserId> for Value {
fn from(user_id: &UserId) -> Self {
user_id.as_str().into()
macro_rules! make_case_insensitive_string {
($c:ident) => {
#[derive(
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Debug,
Default,
Hash,
Serialize,
Deserialize,
DeriveValueType,
)]
#[serde(from = "CaseInsensitiveString")]
pub struct $c(CaseInsensitiveString);
impl $c {
pub fn new(raw: &str) -> Self {
Self(CaseInsensitiveString::new(raw))
}
pub fn as_str(&self) -> &str {
self.0.as_str()
}
pub fn into_string(self) -> String {
self.0.into_string()
}
}
impl From<CaseInsensitiveString> for $c {
fn from(s: CaseInsensitiveString) -> Self {
Self(s)
}
}
impl From<String> for $c {
fn from(s: String) -> Self {
Self(CaseInsensitiveString::from(s))
}
}
impl From<&str> for $c {
fn from(s: &str) -> Self {
Self::new(s)
}
}
impl std::fmt::Display for $c {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}", self.0.as_str())
}
}
impl From<&$c> for Value {
fn from(user_id: &$c) -> Self {
user_id.as_str().into()
}
}
impl TryFromU64 for $c {
fn try_from_u64(_n: u64) -> Result<Self, DbErr> {
Err(DbErr::ConvertFromU64("$c cannot be constructed from u64"))
}
}
};
}
fn compare_str_case_insensitive(s1: &str, s2: &str) -> Ordering {
let mut it_1 = s1.chars().flat_map(|c| c.to_lowercase());
let mut it_2 = s2.chars().flat_map(|c| c.to_lowercase());
loop {
match (it_1.next(), it_2.next()) {
(Some(c1), Some(c2)) => {
let o = c1.cmp(&c2);
if o != Ordering::Equal {
return o;
}
}
(None, Some(_)) => return Ordering::Less,
(Some(_), None) => return Ordering::Greater,
(None, None) => return Ordering::Equal,
}
}
}
impl TryFromU64 for UserId {
fn try_from_u64(_n: u64) -> Result<Self, DbErr> {
Err(DbErr::ConvertFromU64(
"UserId cannot be constructed from u64",
))
macro_rules! make_case_insensitive_comparable_string {
($c:ident) => {
#[derive(Clone, Debug, Default, Serialize, Deserialize, DeriveValueType)]
pub struct $c(String);
impl PartialEq for $c {
fn eq(&self, other: &Self) -> bool {
compare_str_case_insensitive(&self.0, &other.0) == Ordering::Equal
}
}
impl Eq for $c {}
impl PartialOrd for $c {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(compare_str_case_insensitive(&self.0, &other.0))
}
}
impl Ord for $c {
fn cmp(&self, other: &Self) -> Ordering {
compare_str_case_insensitive(&self.0, &other.0)
}
}
impl std::hash::Hash for $c {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.0.to_lowercase().hash(state)
}
}
impl $c {
pub fn new(raw: &str) -> Self {
Self(raw.to_owned())
}
pub fn as_str(&self) -> &str {
self.0.as_str()
}
pub fn into_string(self) -> String {
self.0
}
}
impl From<String> for $c {
fn from(s: String) -> Self {
Self(s)
}
}
impl From<&str> for $c {
fn from(s: &str) -> Self {
Self::new(s)
}
}
impl std::fmt::Display for $c {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}", self.0.as_str())
}
}
impl From<&$c> for Value {
fn from(user_id: &$c) -> Self {
user_id.as_str().into()
}
}
impl TryFromU64 for $c {
fn try_from_u64(_n: u64) -> Result<Self, DbErr> {
Err(DbErr::ConvertFromU64("$c cannot be constructed from u64"))
}
}
};
}
make_case_insensitive_string!(UserId);
make_case_insensitive_string!(AttributeName);
make_case_insensitive_comparable_string!(Email);
make_case_insensitive_comparable_string!(GroupName);
impl AsRef<GroupName> for GroupName {
fn as_ref(&self) -> &GroupName {
self
}
}
@@ -283,14 +451,14 @@ impl IntoActiveValue<Serialized> for JpegPhoto {
#[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize, Hash)]
pub struct AttributeValue {
pub name: String,
pub name: AttributeName,
pub value: Serialized,
}
#[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize)]
pub struct User {
pub user_id: UserId,
pub email: String,
pub email: Email,
pub display_name: Option<String>,
pub creation_date: NaiveDateTime,
pub uuid: Uuid,
@@ -303,7 +471,7 @@ impl Default for User {
let epoch = chrono::Utc.timestamp_opt(0, 0).unwrap().naive_utc();
User {
user_id: UserId::default(),
email: String::new(),
email: Email::default(),
display_name: None,
creation_date: epoch,
uuid: Uuid::from_name_and_date("", &epoch),
@@ -397,7 +565,7 @@ impl ValueType for AttributeType {
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
pub struct Group {
pub id: GroupId,
pub display_name: String,
pub display_name: GroupName,
pub creation_date: NaiveDateTime,
pub uuid: Uuid,
pub users: Vec<UserId>,
@@ -407,7 +575,7 @@ pub struct Group {
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct GroupDetails {
pub group_id: GroupId,
pub display_name: String,
pub display_name: GroupName,
pub creation_date: NaiveDateTime,
pub uuid: Uuid,
pub attributes: Vec<AttributeValue>,