domain: introduce UserId to make uid case insensitive
Note that if there was a non-lowercase user already in the DB, it cannot be found again. To fix this, run in the DB: sqlite> UPDATE users SET user_id = LOWER(user_id);
This commit is contained in:
committed by
nitnelave
parent
26cedcb621
commit
ca19e61f50
@@ -3,7 +3,7 @@ use thiserror::Error;
|
||||
#[allow(clippy::enum_variant_names)]
|
||||
#[derive(Error, Debug)]
|
||||
pub enum DomainError {
|
||||
#[error("Authentication error for `{0}`")]
|
||||
#[error("Authentication error: `{0}`")]
|
||||
AuthenticationError(String),
|
||||
#[error("Database error: `{0}`")]
|
||||
DatabaseError(#[from] sqlx::Error),
|
||||
|
||||
@@ -3,10 +3,41 @@ use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashSet;
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Debug, Default, Serialize, Deserialize)]
|
||||
#[cfg_attr(not(target_arch = "wasm32"), derive(sqlx::FromRow))]
|
||||
#[serde(from = "String")]
|
||||
pub struct UserId(String);
|
||||
|
||||
impl UserId {
|
||||
pub fn new(user_id: &str) -> Self {
|
||||
Self(user_id.to_lowercase())
|
||||
}
|
||||
|
||||
pub fn as_str(&self) -> &str {
|
||||
self.0.as_str()
|
||||
}
|
||||
|
||||
pub fn into_string(self) -> String {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for UserId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for UserId {
|
||||
fn from(s: String) -> Self {
|
||||
Self::new(&s)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
|
||||
#[cfg_attr(not(target_arch = "wasm32"), derive(sqlx::FromRow))]
|
||||
pub struct User {
|
||||
pub user_id: String,
|
||||
pub user_id: UserId,
|
||||
pub email: String,
|
||||
pub display_name: String,
|
||||
pub first_name: String,
|
||||
@@ -19,7 +50,7 @@ impl Default for User {
|
||||
fn default() -> Self {
|
||||
use chrono::TimeZone;
|
||||
User {
|
||||
user_id: String::new(),
|
||||
user_id: UserId::default(),
|
||||
email: String::new(),
|
||||
display_name: String::new(),
|
||||
first_name: String::new(),
|
||||
@@ -33,12 +64,12 @@ impl Default for User {
|
||||
pub struct Group {
|
||||
pub id: GroupId,
|
||||
pub display_name: String,
|
||||
pub users: Vec<String>,
|
||||
pub users: Vec<UserId>,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct BindRequest {
|
||||
pub name: String,
|
||||
pub name: UserId,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
@@ -47,6 +78,7 @@ pub enum UserRequestFilter {
|
||||
And(Vec<UserRequestFilter>),
|
||||
Or(Vec<UserRequestFilter>),
|
||||
Not(Box<UserRequestFilter>),
|
||||
UserId(UserId),
|
||||
Equality(String, String),
|
||||
// Check if a user belongs to a group identified by name.
|
||||
MemberOf(String),
|
||||
@@ -62,13 +94,13 @@ pub enum GroupRequestFilter {
|
||||
DisplayName(String),
|
||||
GroupId(GroupId),
|
||||
// Check if the group contains a user identified by uid.
|
||||
Member(String),
|
||||
Member(UserId),
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone, Default)]
|
||||
pub struct CreateUserRequest {
|
||||
// Same fields as User, but no creation_date, and with password.
|
||||
pub user_id: String,
|
||||
pub user_id: UserId,
|
||||
pub email: String,
|
||||
pub display_name: Option<String>,
|
||||
pub first_name: Option<String>,
|
||||
@@ -78,7 +110,7 @@ pub struct CreateUserRequest {
|
||||
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone, Default)]
|
||||
pub struct UpdateUserRequest {
|
||||
// Same fields as CreateUserRequest, but no with an extra layer of Option.
|
||||
pub user_id: String,
|
||||
pub user_id: UserId,
|
||||
pub email: Option<String>,
|
||||
pub display_name: Option<String>,
|
||||
pub first_name: Option<String>,
|
||||
@@ -106,17 +138,17 @@ pub struct GroupIdAndName(pub GroupId, pub String);
|
||||
pub trait BackendHandler: Clone + Send {
|
||||
async fn list_users(&self, filters: Option<UserRequestFilter>) -> Result<Vec<User>>;
|
||||
async fn list_groups(&self, filters: Option<GroupRequestFilter>) -> Result<Vec<Group>>;
|
||||
async fn get_user_details(&self, user_id: &str) -> Result<User>;
|
||||
async fn get_user_details(&self, user_id: &UserId) -> Result<User>;
|
||||
async fn get_group_details(&self, group_id: GroupId) -> Result<GroupIdAndName>;
|
||||
async fn create_user(&self, request: CreateUserRequest) -> Result<()>;
|
||||
async fn update_user(&self, request: UpdateUserRequest) -> Result<()>;
|
||||
async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>;
|
||||
async fn delete_user(&self, user_id: &str) -> Result<()>;
|
||||
async fn delete_user(&self, user_id: &UserId) -> Result<()>;
|
||||
async fn create_group(&self, group_name: &str) -> Result<GroupId>;
|
||||
async fn delete_group(&self, group_id: GroupId) -> Result<()>;
|
||||
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
|
||||
async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
|
||||
async fn get_user_groups(&self, user: &str) -> Result<HashSet<GroupIdAndName>>;
|
||||
async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>;
|
||||
async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>;
|
||||
async fn get_user_groups(&self, user_id: &UserId) -> Result<HashSet<GroupIdAndName>>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -129,17 +161,17 @@ mockall::mock! {
|
||||
impl BackendHandler for TestBackendHandler {
|
||||
async fn list_users(&self, filters: Option<UserRequestFilter>) -> Result<Vec<User>>;
|
||||
async fn list_groups(&self, filters: Option<GroupRequestFilter>) -> Result<Vec<Group>>;
|
||||
async fn get_user_details(&self, user_id: &str) -> Result<User>;
|
||||
async fn get_user_details(&self, user_id: &UserId) -> Result<User>;
|
||||
async fn get_group_details(&self, group_id: GroupId) -> Result<GroupIdAndName>;
|
||||
async fn create_user(&self, request: CreateUserRequest) -> Result<()>;
|
||||
async fn update_user(&self, request: UpdateUserRequest) -> Result<()>;
|
||||
async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>;
|
||||
async fn delete_user(&self, user_id: &str) -> Result<()>;
|
||||
async fn delete_user(&self, user_id: &UserId) -> Result<()>;
|
||||
async fn create_group(&self, group_name: &str) -> Result<GroupId>;
|
||||
async fn delete_group(&self, group_id: GroupId) -> Result<()>;
|
||||
async fn get_user_groups(&self, user: &str) -> Result<HashSet<GroupIdAndName>>;
|
||||
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
|
||||
async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
|
||||
async fn get_user_groups(&self, user_id: &UserId) -> Result<HashSet<GroupIdAndName>>;
|
||||
async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>;
|
||||
async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>;
|
||||
}
|
||||
#[async_trait]
|
||||
impl LoginHandler for TestBackendHandler {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use super::error::*;
|
||||
use crate::domain::{error::*, handler::UserId};
|
||||
use async_trait::async_trait;
|
||||
|
||||
pub use lldap_auth::{login, registration};
|
||||
@@ -9,7 +9,7 @@ pub trait OpaqueHandler: Clone + Send {
|
||||
&self,
|
||||
request: login::ClientLoginStartRequest,
|
||||
) -> Result<login::ServerLoginStartResponse>;
|
||||
async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<String>;
|
||||
async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<UserId>;
|
||||
async fn registration_start(
|
||||
&self,
|
||||
request: registration::ClientRegistrationStartRequest,
|
||||
@@ -32,7 +32,7 @@ mockall::mock! {
|
||||
&self,
|
||||
request: login::ClientLoginStartRequest
|
||||
) -> Result<login::ServerLoginStartResponse>;
|
||||
async fn login_finish(&self, request: login::ClientLoginFinishRequest ) -> Result<String>;
|
||||
async fn login_finish(&self, request: login::ClientLoginFinishRequest ) -> Result<UserId>;
|
||||
async fn registration_start(
|
||||
&self,
|
||||
request: registration::ClientRegistrationStartRequest
|
||||
|
||||
@@ -51,12 +51,16 @@ fn get_user_filter_expr(filter: UserRequestFilter) -> (RequiresGroup, SimpleExpr
|
||||
let (requires_group, filters) = get_user_filter_expr(*f);
|
||||
(requires_group, Expr::not(Expr::expr(filters)))
|
||||
}
|
||||
UserId(user_id) => (
|
||||
RequiresGroup(false),
|
||||
Expr::col((Users::Table, Users::UserId)).eq(user_id),
|
||||
),
|
||||
Equality(s1, s2) => (
|
||||
RequiresGroup(false),
|
||||
if s1 == Users::DisplayName.to_string() {
|
||||
Expr::col((Users::Table, Users::DisplayName)).eq(s2)
|
||||
} else if s1 == Users::UserId.to_string() {
|
||||
Expr::col((Users::Table, Users::UserId)).eq(s2)
|
||||
panic!("User id should be wrapped")
|
||||
} else {
|
||||
Expr::expr(Expr::cust(&s1)).eq(s2)
|
||||
},
|
||||
@@ -205,17 +209,17 @@ impl BackendHandler for SqlBackendHandler {
|
||||
id: group_id,
|
||||
display_name,
|
||||
users: rows
|
||||
.map(|row| row.get::<String, _>(&*Memberships::UserId.to_string()))
|
||||
.map(|row| row.get::<UserId, _>(&*Memberships::UserId.to_string()))
|
||||
// If a group has no users, an empty string is returned because of the left
|
||||
// join.
|
||||
.filter(|s| !s.is_empty())
|
||||
.filter(|s| !s.as_str().is_empty())
|
||||
.collect(),
|
||||
});
|
||||
}
|
||||
Ok(groups)
|
||||
}
|
||||
|
||||
async fn get_user_details(&self, user_id: &str) -> Result<User> {
|
||||
async fn get_user_details(&self, user_id: &UserId) -> Result<User> {
|
||||
let query = Query::select()
|
||||
.column(Users::UserId)
|
||||
.column(Users::Email)
|
||||
@@ -246,8 +250,8 @@ impl BackendHandler for SqlBackendHandler {
|
||||
.await?)
|
||||
}
|
||||
|
||||
async fn get_user_groups(&self, user: &str) -> Result<HashSet<GroupIdAndName>> {
|
||||
if user == self.config.ldap_user_dn {
|
||||
async fn get_user_groups(&self, user_id: &UserId) -> Result<HashSet<GroupIdAndName>> {
|
||||
if *user_id == self.config.ldap_user_dn {
|
||||
let mut groups = HashSet::new();
|
||||
groups.insert(GroupIdAndName(GroupId(1), "lldap_admin".to_string()));
|
||||
return Ok(groups);
|
||||
@@ -261,7 +265,7 @@ impl BackendHandler for SqlBackendHandler {
|
||||
Expr::tbl(Groups::Table, Groups::GroupId)
|
||||
.equals(Memberships::Table, Memberships::GroupId),
|
||||
)
|
||||
.and_where(Expr::col(Memberships::UserId).eq(user))
|
||||
.and_where(Expr::col(Memberships::UserId).eq(user_id))
|
||||
.to_string(DbQueryBuilder {});
|
||||
|
||||
sqlx::query(&query)
|
||||
@@ -294,7 +298,7 @@ impl BackendHandler for SqlBackendHandler {
|
||||
Users::CreationDate,
|
||||
];
|
||||
let values = vec![
|
||||
request.user_id.clone().into(),
|
||||
request.user_id.into(),
|
||||
request.email.into(),
|
||||
request.display_name.unwrap_or_default().into(),
|
||||
request.first_name.unwrap_or_default().into(),
|
||||
@@ -353,7 +357,7 @@ impl BackendHandler for SqlBackendHandler {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn delete_user(&self, user_id: &str) -> Result<()> {
|
||||
async fn delete_user(&self, user_id: &UserId) -> Result<()> {
|
||||
let delete_query = Query::delete()
|
||||
.from_table(Users::Table)
|
||||
.and_where(Expr::col(Users::UserId).eq(user_id))
|
||||
@@ -387,7 +391,7 @@ impl BackendHandler for SqlBackendHandler {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()> {
|
||||
async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> {
|
||||
let query = Query::insert()
|
||||
.into_table(Memberships::Table)
|
||||
.columns(vec![Memberships::UserId, Memberships::GroupId])
|
||||
@@ -397,7 +401,7 @@ impl BackendHandler for SqlBackendHandler {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()> {
|
||||
async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()> {
|
||||
let query = Query::delete()
|
||||
.from_table(Memberships::Table)
|
||||
.and_where(Expr::col(Memberships::GroupId).eq(group_id))
|
||||
@@ -463,7 +467,7 @@ mod tests {
|
||||
async fn insert_user_no_password(handler: &SqlBackendHandler, name: &str) {
|
||||
handler
|
||||
.create_user(CreateUserRequest {
|
||||
user_id: name.to_string(),
|
||||
user_id: UserId::new(name),
|
||||
email: "bob@bob.bob".to_string(),
|
||||
..Default::default()
|
||||
})
|
||||
@@ -476,21 +480,24 @@ mod tests {
|
||||
}
|
||||
|
||||
async fn insert_membership(handler: &SqlBackendHandler, group_id: GroupId, user_id: &str) {
|
||||
handler.add_user_to_group(user_id, group_id).await.unwrap();
|
||||
handler
|
||||
.add_user_to_group(&UserId::new(user_id), group_id)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_bind_admin() {
|
||||
let sql_pool = get_in_memory_db().await;
|
||||
let config = ConfigurationBuilder::default()
|
||||
.ldap_user_dn("admin".to_string())
|
||||
.ldap_user_dn(UserId::new("admin"))
|
||||
.ldap_user_pass(secstr::SecUtf8::from("test"))
|
||||
.build()
|
||||
.unwrap();
|
||||
let handler = SqlBackendHandler::new(config, sql_pool);
|
||||
handler
|
||||
.bind(BindRequest {
|
||||
name: "admin".to_string(),
|
||||
name: UserId::new("admin"),
|
||||
password: "test".to_string(),
|
||||
})
|
||||
.await
|
||||
@@ -506,21 +513,21 @@ mod tests {
|
||||
|
||||
handler
|
||||
.bind(BindRequest {
|
||||
name: "bob".to_string(),
|
||||
name: UserId::new("bob"),
|
||||
password: "bob00".to_string(),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
handler
|
||||
.bind(BindRequest {
|
||||
name: "andrew".to_string(),
|
||||
name: UserId::new("andrew"),
|
||||
password: "bob00".to_string(),
|
||||
})
|
||||
.await
|
||||
.unwrap_err();
|
||||
handler
|
||||
.bind(BindRequest {
|
||||
name: "bob".to_string(),
|
||||
name: UserId::new("bob"),
|
||||
password: "wrong_password".to_string(),
|
||||
})
|
||||
.await
|
||||
@@ -536,7 +543,7 @@ mod tests {
|
||||
|
||||
handler
|
||||
.bind(BindRequest {
|
||||
name: "bob".to_string(),
|
||||
name: UserId::new("bob"),
|
||||
password: "bob00".to_string(),
|
||||
})
|
||||
.await
|
||||
@@ -557,47 +564,44 @@ mod tests {
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|u| u.user_id)
|
||||
.map(|u| u.user_id.to_string())
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(users, vec!["John", "bob", "patrick"]);
|
||||
assert_eq!(users, vec!["bob", "john", "patrick"]);
|
||||
}
|
||||
{
|
||||
let users = handler
|
||||
.list_users(Some(UserRequestFilter::Equality(
|
||||
"user_id".to_string(),
|
||||
"bob".to_string(),
|
||||
)))
|
||||
.list_users(Some(UserRequestFilter::UserId(UserId::new("bob"))))
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|u| u.user_id)
|
||||
.map(|u| u.user_id.to_string())
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(users, vec!["bob"]);
|
||||
}
|
||||
{
|
||||
let users = handler
|
||||
.list_users(Some(UserRequestFilter::Or(vec![
|
||||
UserRequestFilter::Equality("user_id".to_string(), "bob".to_string()),
|
||||
UserRequestFilter::Equality("user_id".to_string(), "John".to_string()),
|
||||
UserRequestFilter::UserId(UserId::new("bob")),
|
||||
UserRequestFilter::UserId(UserId::new("John")),
|
||||
])))
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|u| u.user_id)
|
||||
.map(|u| u.user_id.to_string())
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(users, vec!["John", "bob"]);
|
||||
assert_eq!(users, vec!["bob", "john"]);
|
||||
}
|
||||
{
|
||||
let users = handler
|
||||
.list_users(Some(UserRequestFilter::Not(Box::new(
|
||||
UserRequestFilter::Equality("user_id".to_string(), "bob".to_string()),
|
||||
UserRequestFilter::UserId(UserId::new("bob")),
|
||||
))))
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|u| u.user_id)
|
||||
.map(|u| u.user_id.to_string())
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(users, vec!["John", "patrick"]);
|
||||
assert_eq!(users, vec!["john", "patrick"]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -622,7 +626,7 @@ mod tests {
|
||||
Group {
|
||||
id: group_1,
|
||||
display_name: "Best Group".to_string(),
|
||||
users: vec!["bob".to_string(), "patrick".to_string()]
|
||||
users: vec![UserId::new("bob"), UserId::new("patrick")]
|
||||
},
|
||||
Group {
|
||||
id: group_3,
|
||||
@@ -632,7 +636,7 @@ mod tests {
|
||||
Group {
|
||||
id: group_2,
|
||||
display_name: "Worst Group".to_string(),
|
||||
users: vec!["John".to_string(), "patrick".to_string()]
|
||||
users: vec![UserId::new("john"), UserId::new("patrick")]
|
||||
},
|
||||
]
|
||||
);
|
||||
@@ -640,7 +644,7 @@ mod tests {
|
||||
handler
|
||||
.list_groups(Some(GroupRequestFilter::Or(vec![
|
||||
GroupRequestFilter::DisplayName("Empty Group".to_string()),
|
||||
GroupRequestFilter::Member("bob".to_string()),
|
||||
GroupRequestFilter::Member(UserId::new("bob")),
|
||||
])))
|
||||
.await
|
||||
.unwrap(),
|
||||
@@ -648,7 +652,7 @@ mod tests {
|
||||
Group {
|
||||
id: group_1,
|
||||
display_name: "Best Group".to_string(),
|
||||
users: vec!["bob".to_string(), "patrick".to_string()]
|
||||
users: vec![UserId::new("bob"), UserId::new("patrick")]
|
||||
},
|
||||
Group {
|
||||
id: group_3,
|
||||
@@ -670,7 +674,7 @@ mod tests {
|
||||
vec![Group {
|
||||
id: group_1,
|
||||
display_name: "Best Group".to_string(),
|
||||
users: vec!["bob".to_string(), "patrick".to_string()]
|
||||
users: vec![UserId::new("bob"), UserId::new("patrick")]
|
||||
}]
|
||||
);
|
||||
}
|
||||
@@ -682,13 +686,35 @@ mod tests {
|
||||
let handler = SqlBackendHandler::new(config, sql_pool);
|
||||
insert_user(&handler, "bob", "bob00").await;
|
||||
{
|
||||
let user = handler.get_user_details("bob").await.unwrap();
|
||||
assert_eq!(user.user_id, "bob".to_string());
|
||||
let user = handler.get_user_details(&UserId::new("bob")).await.unwrap();
|
||||
assert_eq!(user.user_id.as_str(), "bob");
|
||||
}
|
||||
{
|
||||
handler.get_user_details("John").await.unwrap_err();
|
||||
handler
|
||||
.get_user_details(&UserId::new("John"))
|
||||
.await
|
||||
.unwrap_err();
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_user_lowercase() {
|
||||
let sql_pool = get_initialized_db().await;
|
||||
let config = get_default_config();
|
||||
let handler = SqlBackendHandler::new(config, sql_pool);
|
||||
insert_user(&handler, "Bob", "bob00").await;
|
||||
{
|
||||
let user = handler.get_user_details(&UserId::new("bOb")).await.unwrap();
|
||||
assert_eq!(user.user_id.as_str(), "bob");
|
||||
}
|
||||
{
|
||||
handler
|
||||
.get_user_details(&UserId::new("John"))
|
||||
.await
|
||||
.unwrap_err();
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_user_groups() {
|
||||
let sql_pool = get_initialized_db().await;
|
||||
@@ -707,13 +733,19 @@ mod tests {
|
||||
let mut patrick_groups = HashSet::new();
|
||||
patrick_groups.insert(GroupIdAndName(group_1, "Group1".to_string()));
|
||||
patrick_groups.insert(GroupIdAndName(group_2, "Group2".to_string()));
|
||||
assert_eq!(handler.get_user_groups("bob").await.unwrap(), bob_groups);
|
||||
assert_eq!(
|
||||
handler.get_user_groups("patrick").await.unwrap(),
|
||||
handler.get_user_groups(&UserId::new("bob")).await.unwrap(),
|
||||
bob_groups
|
||||
);
|
||||
assert_eq!(
|
||||
handler
|
||||
.get_user_groups(&UserId::new("patrick"))
|
||||
.await
|
||||
.unwrap(),
|
||||
patrick_groups
|
||||
);
|
||||
assert_eq!(
|
||||
handler.get_user_groups("John").await.unwrap(),
|
||||
handler.get_user_groups(&UserId::new("John")).await.unwrap(),
|
||||
HashSet::new()
|
||||
);
|
||||
}
|
||||
@@ -729,29 +761,29 @@ mod tests {
|
||||
insert_user(&handler, "Jennz", "boupBoup").await;
|
||||
|
||||
// Remove a user
|
||||
let _request_result = handler.delete_user("Jennz").await.unwrap();
|
||||
let _request_result = handler.delete_user(&UserId::new("Jennz")).await.unwrap();
|
||||
|
||||
let users = handler
|
||||
.list_users(None)
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|u| u.user_id)
|
||||
.map(|u| u.user_id.to_string())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(users, vec!["Hector", "val"]);
|
||||
assert_eq!(users, vec!["hector", "val"]);
|
||||
|
||||
// Insert new user and remove two
|
||||
insert_user(&handler, "NewBoi", "Joni").await;
|
||||
let _request_result = handler.delete_user("Hector").await.unwrap();
|
||||
let _request_result = handler.delete_user("NewBoi").await.unwrap();
|
||||
let _request_result = handler.delete_user(&UserId::new("Hector")).await.unwrap();
|
||||
let _request_result = handler.delete_user(&UserId::new("NewBoi")).await.unwrap();
|
||||
|
||||
let users = handler
|
||||
.list_users(None)
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|u| u.user_id)
|
||||
.map(|u| u.user_id.to_string())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(users, vec!["val"]);
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use super::{
|
||||
error::*,
|
||||
handler::{BindRequest, LoginHandler},
|
||||
handler::{BindRequest, LoginHandler, UserId},
|
||||
opaque_handler::*,
|
||||
sql_backend_handler::SqlBackendHandler,
|
||||
sql_tables::*,
|
||||
@@ -18,7 +18,7 @@ fn passwords_match(
|
||||
password_file_bytes: &[u8],
|
||||
clear_password: &str,
|
||||
server_setup: &opaque::server::ServerSetup,
|
||||
username: &str,
|
||||
username: &UserId,
|
||||
) -> Result<()> {
|
||||
use opaque::{client, server};
|
||||
let mut rng = rand::rngs::OsRng;
|
||||
@@ -31,7 +31,7 @@ fn passwords_match(
|
||||
server_setup,
|
||||
Some(password_file),
|
||||
client_login_start_result.message,
|
||||
username,
|
||||
username.as_str(),
|
||||
)?;
|
||||
client::login::finish_login(
|
||||
client_login_start_result.state,
|
||||
@@ -88,13 +88,16 @@ impl LoginHandler for SqlBackendHandler {
|
||||
return Ok(());
|
||||
} else {
|
||||
debug!(r#"Invalid password for LDAP bind user"#);
|
||||
return Err(DomainError::AuthenticationError(request.name));
|
||||
return Err(DomainError::AuthenticationError(format!(
|
||||
" for user '{}'",
|
||||
request.name
|
||||
)));
|
||||
}
|
||||
}
|
||||
let query = Query::select()
|
||||
.column(Users::PasswordHash)
|
||||
.from(Users::Table)
|
||||
.and_where(Expr::col(Users::UserId).eq(request.name.as_str()))
|
||||
.and_where(Expr::col(Users::UserId).eq(&request.name))
|
||||
.to_string(DbQueryBuilder {});
|
||||
if let Ok(row) = sqlx::query(&query).fetch_one(&self.sql_pool).await {
|
||||
if let Some(password_hash) =
|
||||
@@ -106,17 +109,20 @@ impl LoginHandler for SqlBackendHandler {
|
||||
self.config.get_server_setup(),
|
||||
&request.name,
|
||||
) {
|
||||
debug!(r#"Invalid password for "{}": {}"#, request.name, e);
|
||||
debug!(r#"Invalid password for "{}": {}"#, &request.name, e);
|
||||
} else {
|
||||
return Ok(());
|
||||
}
|
||||
} else {
|
||||
debug!(r#"User "{}" has no password"#, request.name);
|
||||
debug!(r#"User "{}" has no password"#, &request.name);
|
||||
}
|
||||
} else {
|
||||
debug!(r#"No user found for "{}""#, request.name);
|
||||
debug!(r#"No user found for "{}""#, &request.name);
|
||||
}
|
||||
Err(DomainError::AuthenticationError(request.name))
|
||||
Err(DomainError::AuthenticationError(format!(
|
||||
" for user '{}'",
|
||||
request.name
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -150,7 +156,7 @@ impl OpaqueHandler for SqlOpaqueHandler {
|
||||
})
|
||||
}
|
||||
|
||||
async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<String> {
|
||||
async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<UserId> {
|
||||
let secret_key = self.get_orion_secret_key()?;
|
||||
let login::ServerData {
|
||||
username,
|
||||
@@ -165,7 +171,7 @@ impl OpaqueHandler for SqlOpaqueHandler {
|
||||
opaque::server::login::finish_login(server_login, request.credential_finalization)?
|
||||
.session_key;
|
||||
|
||||
Ok(username)
|
||||
Ok(UserId::new(&username))
|
||||
}
|
||||
|
||||
async fn registration_start(
|
||||
@@ -220,7 +226,7 @@ impl OpaqueHandler for SqlOpaqueHandler {
|
||||
/// Convenience function to set a user's password.
|
||||
pub(crate) async fn register_password(
|
||||
opaque_handler: &SqlOpaqueHandler,
|
||||
username: &str,
|
||||
username: &UserId,
|
||||
password: &SecUtf8,
|
||||
) -> Result<()> {
|
||||
let mut rng = rand::rngs::OsRng;
|
||||
@@ -278,7 +284,7 @@ mod tests {
|
||||
async fn insert_user_no_password(handler: &SqlBackendHandler, name: &str) {
|
||||
handler
|
||||
.create_user(CreateUserRequest {
|
||||
user_id: name.to_string(),
|
||||
user_id: UserId::new(name),
|
||||
email: "bob@bob.bob".to_string(),
|
||||
..Default::default()
|
||||
})
|
||||
@@ -323,7 +329,12 @@ mod tests {
|
||||
attempt_login(&opaque_handler, "bob", "bob00")
|
||||
.await
|
||||
.unwrap_err();
|
||||
register_password(&opaque_handler, "bob", &secstr::SecUtf8::from("bob00")).await?;
|
||||
register_password(
|
||||
&opaque_handler,
|
||||
&UserId::new("bob"),
|
||||
&secstr::SecUtf8::from("bob00"),
|
||||
)
|
||||
.await?;
|
||||
attempt_login(&opaque_handler, "bob", "wrong_password")
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use super::handler::GroupId;
|
||||
use super::handler::{GroupId, UserId};
|
||||
use sea_query::*;
|
||||
|
||||
pub type Pool = sqlx::sqlite::SqlitePool;
|
||||
@@ -37,6 +37,43 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<DB> sqlx::Type<DB> for UserId
|
||||
where
|
||||
DB: sqlx::Database,
|
||||
String: sqlx::Type<DB>,
|
||||
{
|
||||
fn type_info() -> <DB as sqlx::Database>::TypeInfo {
|
||||
<String as sqlx::Type<DB>>::type_info()
|
||||
}
|
||||
fn compatible(ty: &<DB as sqlx::Database>::TypeInfo) -> bool {
|
||||
<String as sqlx::Type<DB>>::compatible(ty)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r, DB> sqlx::Decode<'r, DB> for UserId
|
||||
where
|
||||
DB: sqlx::Database,
|
||||
String: sqlx::Decode<'r, DB>,
|
||||
{
|
||||
fn decode(
|
||||
value: <DB as sqlx::database::HasValueRef<'r>>::ValueRef,
|
||||
) -> Result<Self, Box<dyn std::error::Error + Sync + Send + 'static>> {
|
||||
<String as sqlx::Decode<'r, DB>>::decode(value).map(|s| UserId::new(&s))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<UserId> for sea_query::Value {
|
||||
fn from(user_id: UserId) -> Self {
|
||||
user_id.into_string().into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&UserId> for sea_query::Value {
|
||||
fn from(user_id: &UserId) -> Self {
|
||||
user_id.as_str().into()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Iden)]
|
||||
pub enum Users {
|
||||
Table,
|
||||
|
||||
Reference in New Issue
Block a user