server: Move the definition of UserId down to lldap_auth

This commit is contained in:
Valentin Tolmer
2024-01-15 23:37:42 +01:00
committed by nitnelave
parent 10609b25e9
commit 2ea17c04ba
18 changed files with 212 additions and 162 deletions

View File

@@ -63,7 +63,7 @@ pub mod tests {
opaque::client::registration::start_registration(pass.as_bytes(), &mut rng).unwrap();
let response = handler
.registration_start(registration::ClientRegistrationStartRequest {
username: name.to_string(),
username: name.into(),
registration_start_request: client_registration_start.message,
})
.await

View File

@@ -33,7 +33,7 @@ fn passwords_match(
server_setup,
Some(password_file),
client_login_start_result.message,
username.as_str(),
username,
)?;
client::login::finish_login(
client_login_start_result.state,
@@ -100,15 +100,13 @@ impl OpaqueHandler for SqlOpaqueHandler {
&self,
request: login::ClientLoginStartRequest,
) -> Result<login::ServerLoginStartResponse> {
let user_id = request.username;
let maybe_password_file = self
.get_password_file_for_user(UserId::new(&request.username))
.get_password_file_for_user(user_id.clone())
.await?
.map(|bytes| {
opaque::server::ServerRegistration::deserialize(&bytes).map_err(|_| {
DomainError::InternalError(format!(
"Corrupted password file for {}",
&request.username
))
DomainError::InternalError(format!("Corrupted password file for {}", &user_id))
})
})
.transpose()?;
@@ -120,11 +118,11 @@ impl OpaqueHandler for SqlOpaqueHandler {
self.config.get_server_setup(),
maybe_password_file,
request.login_start_request,
&request.username,
&user_id,
)?;
let secret_key = self.get_orion_secret_key()?;
let server_data = login::ServerData {
username: request.username,
username: user_id,
server_login: start_response.state,
};
let encrypted_state = orion::aead::seal(&secret_key, &bincode::serialize(&server_data)?)?;
@@ -151,7 +149,7 @@ impl OpaqueHandler for SqlOpaqueHandler {
opaque::server::login::finish_login(server_login, request.credential_finalization)?
.session_key;
Ok(UserId::new(&username))
Ok(username)
}
#[instrument(skip_all, level = "debug", err)]
@@ -191,7 +189,7 @@ impl OpaqueHandler for SqlOpaqueHandler {
opaque::server::registration::get_password_file(request.registration_upload);
// Set the user password to the new password.
let user_update = model::users::ActiveModel {
user_id: ActiveValue::Set(UserId::new(&username)),
user_id: ActiveValue::Set(username),
password_hash: ActiveValue::Set(Some(password_file.serialize())),
..Default::default()
};
@@ -204,7 +202,7 @@ impl OpaqueHandler for SqlOpaqueHandler {
#[instrument(skip_all, level = "debug", err, fields(username = %username.as_str()))]
pub(crate) async fn register_password(
opaque_handler: &SqlOpaqueHandler,
username: &UserId,
username: UserId,
password: &SecUtf8,
) -> Result<()> {
let mut rng = rand::rngs::OsRng;
@@ -213,7 +211,7 @@ pub(crate) async fn register_password(
opaque::client::registration::start_registration(password.unsecure().as_bytes(), &mut rng)?;
let start_response = opaque_handler
.registration_start(ClientRegistrationStartRequest {
username: username.to_string(),
username,
registration_start_request: registration_start.message,
})
.await?;
@@ -245,7 +243,7 @@ mod tests {
let login_start = opaque::client::login::start_login(password, &mut rng)?;
let start_response = opaque_handler
.login_start(ClientLoginStartRequest {
username: username.to_string(),
username: UserId::new(username),
login_start_request: login_start.message,
})
.await?;
@@ -276,7 +274,7 @@ mod tests {
.unwrap_err();
register_password(
&opaque_handler,
&UserId::new("bob"),
UserId::new("bob"),
&secstr::SecUtf8::from("bob00"),
)
.await?;

View File

@@ -2,6 +2,7 @@ use std::cmp::Ordering;
use base64::Engine;
use chrono::{NaiveDateTime, TimeZone};
use lldap_auth::types::CaseInsensitiveString;
use sea_orm::{
entity::IntoActiveValue,
sea_query::{value::ValueType, ArrayType, BlobSize, ColumnType, Nullable, ValueTypeErr},
@@ -11,6 +12,7 @@ use serde::{Deserialize, Serialize};
use strum::{EnumString, IntoStaticStr};
pub use super::model::UserColumn;
pub use lldap_auth::types::UserId;
#[derive(PartialEq, Hash, Eq, Clone, Debug, Default, Serialize, Deserialize, DeriveValueType)]
#[serde(try_from = "&str")]
@@ -122,112 +124,6 @@ impl Serialized {
}
}
#[derive(
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Debug,
Default,
Hash,
Serialize,
Deserialize,
DeriveValueType,
)]
#[serde(from = "String")]
pub struct CaseInsensitiveString(String);
impl CaseInsensitiveString {
pub fn new(user_id: &str) -> Self {
Self(user_id.to_lowercase())
}
pub fn as_str(&self) -> &str {
self.0.as_str()
}
pub fn into_string(self) -> String {
self.0
}
}
impl From<String> for CaseInsensitiveString {
fn from(s: String) -> Self {
Self::new(&s)
}
}
macro_rules! make_case_insensitive_string {
($c:ident) => {
#[derive(
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Debug,
Default,
Hash,
Serialize,
Deserialize,
DeriveValueType,
)]
#[serde(from = "CaseInsensitiveString")]
pub struct $c(CaseInsensitiveString);
impl $c {
pub fn new(raw: &str) -> Self {
Self(CaseInsensitiveString::new(raw))
}
pub fn as_str(&self) -> &str {
self.0.as_str()
}
pub fn into_string(self) -> String {
self.0.into_string()
}
}
impl From<CaseInsensitiveString> for $c {
fn from(s: CaseInsensitiveString) -> Self {
Self(s)
}
}
impl From<String> for $c {
fn from(s: String) -> Self {
Self(CaseInsensitiveString::from(s))
}
}
impl From<&str> for $c {
fn from(s: &str) -> Self {
Self::new(s)
}
}
impl std::fmt::Display for $c {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}", self.0.as_str())
}
}
impl From<&$c> for Value {
fn from(user_id: &$c) -> Self {
user_id.as_str().into()
}
}
impl TryFromU64 for $c {
fn try_from_u64(_n: u64) -> Result<Self, DbErr> {
Err(DbErr::ConvertFromU64("$c cannot be constructed from u64"))
}
}
};
}
fn compare_str_case_insensitive(s1: &str, s2: &str) -> Ordering {
let mut it_1 = s1.chars().flat_map(|c| c.to_lowercase());
let mut it_2 = s2.chars().flat_map(|c| c.to_lowercase());
@@ -323,8 +219,58 @@ macro_rules! make_case_insensitive_comparable_string {
};
}
make_case_insensitive_string!(UserId);
make_case_insensitive_string!(AttributeName);
#[derive(
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Debug,
Default,
Hash,
Serialize,
Deserialize,
DeriveValueType,
)]
#[serde(from = "CaseInsensitiveString")]
pub struct AttributeName(CaseInsensitiveString);
impl AttributeName {
pub fn new(s: &str) -> Self {
s.into()
}
pub fn as_str(&self) -> &str {
self.0.as_str()
}
pub fn into_string(self) -> String {
self.0.into_string()
}
}
impl<T> From<T> for AttributeName
where
T: Into<CaseInsensitiveString>,
{
fn from(s: T) -> Self {
Self(s.into())
}
}
impl std::fmt::Display for AttributeName {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}", self.0.as_str())
}
}
impl From<&AttributeName> for Value {
fn from(attribute_name: &AttributeName) -> Self {
attribute_name.as_str().into()
}
}
impl TryFromU64 for AttributeName {
fn try_from_u64(_n: u64) -> Result<Self, DbErr> {
Err(DbErr::ConvertFromU64(
"AttributeName cannot be constructed from u64",
))
}
}
make_case_insensitive_comparable_string!(Email);
make_case_insensitive_comparable_string!(GroupName);