domain: introduce UserId to make uid case insensitive
Note that if there was a non-lowercase user already in the DB, it cannot be found again. To fix this, run in the DB: sqlite> UPDATE users SET user_id = LOWER(user_id);
This commit is contained in:
committed by
nitnelave
parent
26cedcb621
commit
ca19e61f50
@@ -25,7 +25,7 @@ use lldap_auth::{login, opaque, password_reset, registration, JWTClaims};
|
||||
use crate::{
|
||||
domain::{
|
||||
error::DomainError,
|
||||
handler::{BackendHandler, BindRequest, GroupIdAndName, LoginHandler},
|
||||
handler::{BackendHandler, BindRequest, GroupIdAndName, LoginHandler, UserId},
|
||||
opaque_handler::OpaqueHandler,
|
||||
},
|
||||
infra::{
|
||||
@@ -51,7 +51,7 @@ fn create_jwt(key: &Hmac<Sha512>, user: String, groups: HashSet<GroupIdAndName>)
|
||||
jwt::Token::new(header, claims).sign_with_key(key).unwrap()
|
||||
}
|
||||
|
||||
fn parse_refresh_token(token: &str) -> std::result::Result<(u64, String), HttpResponse> {
|
||||
fn parse_refresh_token(token: &str) -> std::result::Result<(u64, UserId), HttpResponse> {
|
||||
match token.split_once('+') {
|
||||
None => Err(HttpResponse::Unauthorized().body("Invalid refresh token")),
|
||||
Some((token, u)) => {
|
||||
@@ -60,12 +60,12 @@ fn parse_refresh_token(token: &str) -> std::result::Result<(u64, String), HttpRe
|
||||
token.hash(&mut s);
|
||||
s.finish()
|
||||
};
|
||||
Ok((refresh_token_hash, u.to_string()))
|
||||
Ok((refresh_token_hash, UserId::new(u)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_refresh_token(request: HttpRequest) -> std::result::Result<(u64, String), HttpResponse> {
|
||||
fn get_refresh_token(request: HttpRequest) -> std::result::Result<(u64, UserId), HttpResponse> {
|
||||
match (
|
||||
request.cookie("refresh_token"),
|
||||
request.headers().get("refresh-token"),
|
||||
@@ -134,14 +134,14 @@ where
|
||||
{
|
||||
let user_id = match request.match_info().get("user_id") {
|
||||
None => return HttpResponse::BadRequest().body("Missing user ID"),
|
||||
Some(id) => id,
|
||||
Some(id) => UserId::new(id),
|
||||
};
|
||||
let token = match data.backend_handler.start_password_reset(user_id).await {
|
||||
let token = match data.backend_handler.start_password_reset(&user_id).await {
|
||||
Err(e) => return HttpResponse::InternalServerError().body(e.to_string()),
|
||||
Ok(None) => return HttpResponse::Ok().finish(),
|
||||
Ok(Some(token)) => token,
|
||||
};
|
||||
let user = match data.backend_handler.get_user_details(user_id).await {
|
||||
let user = match data.backend_handler.get_user_details(&user_id).await {
|
||||
Err(e) => {
|
||||
warn!("Error getting used details: {:#?}", e);
|
||||
return HttpResponse::Ok().finish();
|
||||
@@ -196,7 +196,7 @@ where
|
||||
.finish(),
|
||||
)
|
||||
.json(&password_reset::ServerPasswordResetResponse {
|
||||
user_id,
|
||||
user_id: user_id.to_string(),
|
||||
token: token.as_str().to_owned(),
|
||||
})
|
||||
}
|
||||
@@ -276,7 +276,7 @@ where
|
||||
|
||||
async fn get_login_successful_response<Backend>(
|
||||
data: &web::Data<AppState<Backend>>,
|
||||
name: &str,
|
||||
name: &UserId,
|
||||
) -> HttpResponse
|
||||
where
|
||||
Backend: TcpBackendHandler + BackendHandler,
|
||||
@@ -289,7 +289,7 @@ where
|
||||
.await
|
||||
.map(|(groups, (refresh_token, max_age))| {
|
||||
let token = create_jwt(&data.jwt_key, name.to_string(), groups);
|
||||
let refresh_token_plus_name = refresh_token + "+" + name;
|
||||
let refresh_token_plus_name = refresh_token + "+" + name.as_str();
|
||||
|
||||
HttpResponse::Ok()
|
||||
.cookie(
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
use crate::infra::cli::{GeneralConfigOpts, RunOpts, SmtpOpts, TestEmailOpts};
|
||||
use crate::{
|
||||
domain::handler::UserId,
|
||||
infra::cli::{GeneralConfigOpts, RunOpts, SmtpOpts, TestEmailOpts},
|
||||
};
|
||||
use anyhow::{Context, Result};
|
||||
use figment::{
|
||||
providers::{Env, Format, Serialized, Toml},
|
||||
@@ -49,8 +52,8 @@ pub struct Configuration {
|
||||
pub jwt_secret: SecUtf8,
|
||||
#[builder(default = r#"String::from("dc=example,dc=com")"#)]
|
||||
pub ldap_base_dn: String,
|
||||
#[builder(default = r#"String::from("admin")"#)]
|
||||
pub ldap_user_dn: String,
|
||||
#[builder(default = r#"UserId::new("admin")"#)]
|
||||
pub ldap_user_dn: UserId,
|
||||
#[builder(default = r#"SecUtf8::from("password")"#)]
|
||||
pub ldap_user_pass: SecUtf8,
|
||||
#[builder(default = r#"String::from("sqlite://users.db?mode=rwc")"#)]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::domain::handler::{
|
||||
BackendHandler, CreateUserRequest, GroupId, UpdateGroupRequest, UpdateUserRequest,
|
||||
BackendHandler, CreateUserRequest, GroupId, UpdateGroupRequest, UpdateUserRequest, UserId,
|
||||
};
|
||||
use juniper::{graphql_object, FieldResult, GraphQLInputObject, GraphQLObject};
|
||||
|
||||
@@ -66,10 +66,11 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
|
||||
if !context.validation_result.is_admin {
|
||||
return Err("Unauthorized user creation".into());
|
||||
}
|
||||
let user_id = UserId::new(&user.id);
|
||||
context
|
||||
.handler
|
||||
.create_user(CreateUserRequest {
|
||||
user_id: user.id.clone(),
|
||||
user_id: user_id.clone(),
|
||||
email: user.email,
|
||||
display_name: user.display_name,
|
||||
first_name: user.first_name,
|
||||
@@ -78,7 +79,7 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
|
||||
.await?;
|
||||
Ok(context
|
||||
.handler
|
||||
.get_user_details(&user.id)
|
||||
.get_user_details(&user_id)
|
||||
.await
|
||||
.map(Into::into)?)
|
||||
}
|
||||
@@ -108,7 +109,7 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
|
||||
context
|
||||
.handler
|
||||
.update_user(UpdateUserRequest {
|
||||
user_id: user.id,
|
||||
user_id: UserId::new(&user.id),
|
||||
email: user.email,
|
||||
display_name: user.display_name,
|
||||
first_name: user.first_name,
|
||||
@@ -148,7 +149,7 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
|
||||
}
|
||||
context
|
||||
.handler
|
||||
.add_user_to_group(&user_id, GroupId(group_id))
|
||||
.add_user_to_group(&UserId::new(&user_id), GroupId(group_id))
|
||||
.await?;
|
||||
Ok(Success::new())
|
||||
}
|
||||
@@ -166,7 +167,7 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
|
||||
}
|
||||
context
|
||||
.handler
|
||||
.remove_user_from_group(&user_id, GroupId(group_id))
|
||||
.remove_user_from_group(&UserId::new(&user_id), GroupId(group_id))
|
||||
.await?;
|
||||
Ok(Success::new())
|
||||
}
|
||||
@@ -178,7 +179,7 @@ impl<Handler: BackendHandler + Sync> Mutation<Handler> {
|
||||
if context.validation_result.user == user_id {
|
||||
return Err("Cannot delete current user".into());
|
||||
}
|
||||
context.handler.delete_user(&user_id).await?;
|
||||
context.handler.delete_user(&UserId::new(&user_id)).await?;
|
||||
Ok(Success::new())
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::domain::handler::{BackendHandler, GroupId, GroupIdAndName};
|
||||
use crate::domain::handler::{BackendHandler, GroupId, GroupIdAndName, UserId};
|
||||
use juniper::{graphql_object, FieldResult, GraphQLInputObject};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -48,6 +48,9 @@ impl TryInto<DomainRequestFilter> for RequestFilter {
|
||||
return Err("Multiple fields specified in request filter".to_string());
|
||||
}
|
||||
if let Some(e) = self.eq {
|
||||
if e.field.to_lowercase() == "uid" {
|
||||
return Ok(DomainRequestFilter::UserId(UserId::new(&e.value)));
|
||||
}
|
||||
return Ok(DomainRequestFilter::Equality(e.field, e.value));
|
||||
}
|
||||
if let Some(c) = self.any {
|
||||
@@ -109,7 +112,7 @@ impl<Handler: BackendHandler + Sync> Query<Handler> {
|
||||
}
|
||||
Ok(context
|
||||
.handler
|
||||
.get_user_details(&user_id)
|
||||
.get_user_details(&UserId::new(&user_id))
|
||||
.await
|
||||
.map(Into::into)?)
|
||||
}
|
||||
@@ -170,7 +173,7 @@ impl<Handler: BackendHandler> Default for User<Handler> {
|
||||
#[graphql_object(context = Context<Handler>)]
|
||||
impl<Handler: BackendHandler + Sync> User<Handler> {
|
||||
fn id(&self) -> &str {
|
||||
&self.user.user_id
|
||||
self.user.user_id.as_str()
|
||||
}
|
||||
|
||||
fn email(&self) -> &str {
|
||||
@@ -260,7 +263,7 @@ impl<Handler: BackendHandler> From<DomainGroup> for Group<Handler> {
|
||||
Self {
|
||||
group_id: group.id.0,
|
||||
display_name: group.display_name,
|
||||
members: Some(group.users.into_iter().map(Into::into).collect()),
|
||||
members: Some(group.users.into_iter().map(UserId::into_string).collect()),
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
@@ -305,10 +308,10 @@ mod tests {
|
||||
|
||||
let mut mock = MockTestBackendHandler::new();
|
||||
mock.expect_get_user_details()
|
||||
.with(eq("bob"))
|
||||
.with(eq(UserId::new("bob")))
|
||||
.return_once(|_| {
|
||||
Ok(DomainUser {
|
||||
user_id: "bob".to_string(),
|
||||
user_id: UserId::new("bob"),
|
||||
email: "bob@bobbers.on".to_string(),
|
||||
..Default::default()
|
||||
})
|
||||
@@ -316,7 +319,7 @@ mod tests {
|
||||
let mut groups = HashSet::new();
|
||||
groups.insert(GroupIdAndName(GroupId(3), "Bobbersons".to_string()));
|
||||
mock.expect_get_user_groups()
|
||||
.with(eq("bob"))
|
||||
.with(eq(UserId::new("bob")))
|
||||
.return_once(|_| Ok(groups));
|
||||
|
||||
let context = Context::<MockTestBackendHandler> {
|
||||
@@ -369,12 +372,12 @@ mod tests {
|
||||
.return_once(|_| {
|
||||
Ok(vec![
|
||||
DomainUser {
|
||||
user_id: "bob".to_string(),
|
||||
user_id: UserId::new("bob"),
|
||||
email: "bob@bobbers.on".to_string(),
|
||||
..Default::default()
|
||||
},
|
||||
DomainUser {
|
||||
user_id: "robert".to_string(),
|
||||
user_id: UserId::new("robert"),
|
||||
email: "robert@bobbers.on".to_string(),
|
||||
..Default::default()
|
||||
},
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::domain::{
|
||||
handler::{
|
||||
BackendHandler, BindRequest, Group, GroupRequestFilter, LoginHandler, User,
|
||||
BackendHandler, BindRequest, Group, GroupRequestFilter, LoginHandler, User, UserId,
|
||||
UserRequestFilter,
|
||||
},
|
||||
opaque_handler::OpaqueHandler,
|
||||
@@ -71,7 +71,7 @@ fn get_user_id_from_distinguished_name(
|
||||
dn: &str,
|
||||
base_tree: &[(String, String)],
|
||||
base_dn_str: &str,
|
||||
) -> Result<String> {
|
||||
) -> Result<UserId> {
|
||||
let parts = parse_distinguished_name(dn).context("while parsing a user ID")?;
|
||||
if !is_subtree(&parts, base_tree) {
|
||||
bail!("Not a subtree of the base tree");
|
||||
@@ -84,7 +84,7 @@ fn get_user_id_from_distinguished_name(
|
||||
base_dn_str
|
||||
);
|
||||
}
|
||||
Ok(parts[0].1.to_string())
|
||||
Ok(UserId::new(&parts[0].1))
|
||||
} else {
|
||||
bail!(
|
||||
r#"Unexpected user DN format. Got "{}", expected: "cn=username,ou=people,{}""#,
|
||||
@@ -103,7 +103,7 @@ fn get_user_attribute(user: &User, attribute: &str, dn: &str) -> Result<Vec<Stri
|
||||
"person".to_string(),
|
||||
]),
|
||||
"dn" => Ok(vec![dn.to_string()]),
|
||||
"uid" => Ok(vec![user.user_id.clone()]),
|
||||
"uid" => Ok(vec![user.user_id.to_string()]),
|
||||
"mail" => Ok(vec![user.email.clone()]),
|
||||
"givenname" => Ok(vec![user.first_name.clone()]),
|
||||
"sn" => Ok(vec![user.last_name.clone()]),
|
||||
@@ -118,7 +118,7 @@ fn make_ldap_search_user_result_entry(
|
||||
base_dn_str: &str,
|
||||
attributes: &[String],
|
||||
) -> Result<LdapSearchResultEntry> {
|
||||
let dn = format!("cn={},ou=people,{}", user.user_id, base_dn_str);
|
||||
let dn = format!("cn={},ou=people,{}", user.user_id.as_str(), base_dn_str);
|
||||
Ok(LdapSearchResultEntry {
|
||||
dn: dn.clone(),
|
||||
attributes: attributes
|
||||
@@ -264,17 +264,17 @@ fn root_dse_response(base_dn: &str) -> LdapOp {
|
||||
}
|
||||
|
||||
pub struct LdapHandler<Backend: BackendHandler + LoginHandler + OpaqueHandler> {
|
||||
dn: String,
|
||||
dn: UserId,
|
||||
backend_handler: Backend,
|
||||
pub base_dn: Vec<(String, String)>,
|
||||
base_dn_str: String,
|
||||
ldap_user_dn: String,
|
||||
ldap_user_dn: UserId,
|
||||
}
|
||||
|
||||
impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend> {
|
||||
pub fn new(backend_handler: Backend, ldap_base_dn: String, ldap_user_dn: String) -> Self {
|
||||
pub fn new(backend_handler: Backend, ldap_base_dn: String, ldap_user_dn: UserId) -> Self {
|
||||
Self {
|
||||
dn: "Unauthenticated".to_string(),
|
||||
dn: UserId::new("unauthenticated"),
|
||||
backend_handler,
|
||||
base_dn: parse_distinguished_name(&ldap_base_dn).unwrap_or_else(|_| {
|
||||
panic!(
|
||||
@@ -282,7 +282,7 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
|
||||
ldap_base_dn
|
||||
)
|
||||
}),
|
||||
ldap_user_dn: format!("cn={},ou=people,{}", ldap_user_dn, &ldap_base_dn),
|
||||
ldap_user_dn: UserId::new(&format!("cn={},ou=people,{}", ldap_user_dn, &ldap_base_dn)),
|
||||
base_dn_str: ldap_base_dn,
|
||||
}
|
||||
}
|
||||
@@ -307,14 +307,14 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
self.dn = request.dn.clone();
|
||||
self.dn = UserId::new(&request.dn);
|
||||
(LdapResultCode::Success, "".to_string())
|
||||
}
|
||||
Err(_) => (LdapResultCode::InvalidCredentials, "".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn change_password(&mut self, user: &str, password: &str) -> Result<()> {
|
||||
async fn change_password(&mut self, user: &UserId, password: &str) -> Result<()> {
|
||||
use lldap_auth::*;
|
||||
let mut rng = rand::rngs::OsRng;
|
||||
let registration_start_request =
|
||||
@@ -527,7 +527,7 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
|
||||
}
|
||||
LdapOp::SearchRequest(request) => self.do_search(&request).await,
|
||||
LdapOp::UnbindRequest => {
|
||||
self.dn = "Unauthenticated".to_string();
|
||||
self.dn = UserId::new("unauthenticated");
|
||||
// No need to notify on unbind (per rfc4511)
|
||||
return None;
|
||||
}
|
||||
@@ -617,10 +617,12 @@ impl<Backend: BackendHandler + LoginHandler + OpaqueHandler> LdapHandler<Backend
|
||||
))))
|
||||
}
|
||||
} else {
|
||||
Ok(UserRequestFilter::Equality(
|
||||
map_field(field)?,
|
||||
value.clone(),
|
||||
))
|
||||
let field = map_field(field)?;
|
||||
if field == "user_id" {
|
||||
Ok(UserRequestFilter::UserId(UserId::new(value)))
|
||||
} else {
|
||||
Ok(UserRequestFilter::Equality(field, value.clone()))
|
||||
}
|
||||
}
|
||||
}
|
||||
LdapFilter::Present(field) => {
|
||||
@@ -661,17 +663,17 @@ mod tests {
|
||||
impl BackendHandler for TestBackendHandler {
|
||||
async fn list_users(&self, filters: Option<UserRequestFilter>) -> Result<Vec<User>>;
|
||||
async fn list_groups(&self, filters: Option<GroupRequestFilter>) -> Result<Vec<Group>>;
|
||||
async fn get_user_details(&self, user_id: &str) -> Result<User>;
|
||||
async fn get_user_details(&self, user_id: &UserId) -> Result<User>;
|
||||
async fn get_group_details(&self, group_id: GroupId) -> Result<GroupIdAndName>;
|
||||
async fn get_user_groups(&self, user: &str) -> Result<HashSet<GroupIdAndName>>;
|
||||
async fn get_user_groups(&self, user: &UserId) -> Result<HashSet<GroupIdAndName>>;
|
||||
async fn create_user(&self, request: CreateUserRequest) -> Result<()>;
|
||||
async fn update_user(&self, request: UpdateUserRequest) -> Result<()>;
|
||||
async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>;
|
||||
async fn delete_user(&self, user_id: &str) -> Result<()>;
|
||||
async fn delete_user(&self, user_id: &UserId) -> Result<()>;
|
||||
async fn create_group(&self, group_name: &str) -> Result<GroupId>;
|
||||
async fn delete_group(&self, group_id: GroupId) -> Result<()>;
|
||||
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
|
||||
async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
|
||||
async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>;
|
||||
async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>;
|
||||
}
|
||||
#[async_trait]
|
||||
impl OpaqueHandler for TestBackendHandler {
|
||||
@@ -679,7 +681,7 @@ mod tests {
|
||||
&self,
|
||||
request: login::ClientLoginStartRequest
|
||||
) -> Result<login::ServerLoginStartResponse>;
|
||||
async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<String>;
|
||||
async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<UserId>;
|
||||
async fn registration_start(
|
||||
&self,
|
||||
request: registration::ClientRegistrationStartRequest
|
||||
@@ -720,12 +722,12 @@ mod tests {
|
||||
) -> LdapHandler<MockTestBackendHandler> {
|
||||
mock.expect_bind()
|
||||
.with(eq(BindRequest {
|
||||
name: "test".to_string(),
|
||||
name: UserId::new("test"),
|
||||
password: "pass".to_string(),
|
||||
}))
|
||||
.return_once(|_| Ok(()));
|
||||
let mut ldap_handler =
|
||||
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string());
|
||||
LdapHandler::new(mock, "dc=example,dc=com".to_string(), UserId::new("test"));
|
||||
let request = LdapBindRequest {
|
||||
dn: "cn=test,ou=people,dc=example,dc=com".to_string(),
|
||||
cred: LdapBindCred::Simple("pass".to_string()),
|
||||
@@ -742,13 +744,13 @@ mod tests {
|
||||
let mut mock = MockTestBackendHandler::new();
|
||||
mock.expect_bind()
|
||||
.with(eq(crate::domain::handler::BindRequest {
|
||||
name: "bob".to_string(),
|
||||
name: UserId::new("bob"),
|
||||
password: "pass".to_string(),
|
||||
}))
|
||||
.times(1)
|
||||
.return_once(|_| Ok(()));
|
||||
let mut ldap_handler =
|
||||
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string());
|
||||
LdapHandler::new(mock, "dc=example,dc=com".to_string(), UserId::new("test"));
|
||||
|
||||
let request = LdapOp::BindRequest(LdapBindRequest {
|
||||
dn: "cn=bob,ou=people,dc=example,dc=com".to_string(),
|
||||
@@ -773,13 +775,13 @@ mod tests {
|
||||
let mut mock = MockTestBackendHandler::new();
|
||||
mock.expect_bind()
|
||||
.with(eq(crate::domain::handler::BindRequest {
|
||||
name: "test".to_string(),
|
||||
name: UserId::new("test"),
|
||||
password: "pass".to_string(),
|
||||
}))
|
||||
.times(1)
|
||||
.return_once(|_| Ok(()));
|
||||
let mut ldap_handler =
|
||||
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string());
|
||||
LdapHandler::new(mock, "dc=example,dc=com".to_string(), UserId::new("test"));
|
||||
|
||||
let request = LdapBindRequest {
|
||||
dn: "cn=test,ou=people,dc=example,dc=com".to_string(),
|
||||
@@ -796,13 +798,13 @@ mod tests {
|
||||
let mut mock = MockTestBackendHandler::new();
|
||||
mock.expect_bind()
|
||||
.with(eq(crate::domain::handler::BindRequest {
|
||||
name: "test".to_string(),
|
||||
name: UserId::new("test"),
|
||||
password: "pass".to_string(),
|
||||
}))
|
||||
.times(1)
|
||||
.return_once(|_| Ok(()));
|
||||
let mut ldap_handler =
|
||||
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "admin".to_string());
|
||||
LdapHandler::new(mock, "dc=example,dc=com".to_string(), UserId::new("admin"));
|
||||
|
||||
let request = LdapBindRequest {
|
||||
dn: "cn=test,ou=people,dc=example,dc=com".to_string(),
|
||||
@@ -827,7 +829,7 @@ mod tests {
|
||||
async fn test_bind_invalid_dn() {
|
||||
let mock = MockTestBackendHandler::new();
|
||||
let mut ldap_handler =
|
||||
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "admin".to_string());
|
||||
LdapHandler::new(mock, "dc=example,dc=com".to_string(), UserId::new("admin"));
|
||||
|
||||
let request = LdapBindRequest {
|
||||
dn: "cn=bob,dc=example,dc=com".to_string(),
|
||||
@@ -903,7 +905,7 @@ mod tests {
|
||||
mock.expect_list_users().times(1).return_once(|_| {
|
||||
Ok(vec![
|
||||
User {
|
||||
user_id: "bob_1".to_string(),
|
||||
user_id: UserId::new("bob_1"),
|
||||
email: "bob@bobmail.bob".to_string(),
|
||||
display_name: "Bôb Böbberson".to_string(),
|
||||
first_name: "Bôb".to_string(),
|
||||
@@ -911,7 +913,7 @@ mod tests {
|
||||
..Default::default()
|
||||
},
|
||||
User {
|
||||
user_id: "jim".to_string(),
|
||||
user_id: UserId::new("jim"),
|
||||
email: "jim@cricket.jim".to_string(),
|
||||
display_name: "Jimminy Cricket".to_string(),
|
||||
first_name: "Jim".to_string(),
|
||||
@@ -1037,12 +1039,12 @@ mod tests {
|
||||
Group {
|
||||
id: GroupId(1),
|
||||
display_name: "group_1".to_string(),
|
||||
users: vec!["bob".to_string(), "john".to_string()],
|
||||
users: vec![UserId::new("bob"), UserId::new("john")],
|
||||
},
|
||||
Group {
|
||||
id: GroupId(3),
|
||||
display_name: "bestgroup".to_string(),
|
||||
users: vec!["john".to_string()],
|
||||
users: vec![UserId::new("john")],
|
||||
},
|
||||
])
|
||||
});
|
||||
@@ -1111,7 +1113,7 @@ mod tests {
|
||||
mock.expect_list_groups()
|
||||
.with(eq(Some(GroupRequestFilter::And(vec![
|
||||
GroupRequestFilter::DisplayName("group_1".to_string()),
|
||||
GroupRequestFilter::Member("bob".to_string()),
|
||||
GroupRequestFilter::Member(UserId::new("bob")),
|
||||
GroupRequestFilter::And(vec![]),
|
||||
]))))
|
||||
.times(1)
|
||||
@@ -1250,10 +1252,7 @@ mod tests {
|
||||
mock.expect_list_users()
|
||||
.with(eq(Some(UserRequestFilter::And(vec![
|
||||
UserRequestFilter::Or(vec![
|
||||
UserRequestFilter::Not(Box::new(UserRequestFilter::Equality(
|
||||
"user_id".to_string(),
|
||||
"bob".to_string(),
|
||||
))),
|
||||
UserRequestFilter::Not(Box::new(UserRequestFilter::UserId(UserId::new("bob")))),
|
||||
UserRequestFilter::And(vec![]),
|
||||
UserRequestFilter::Not(Box::new(UserRequestFilter::And(vec![]))),
|
||||
UserRequestFilter::And(vec![]),
|
||||
@@ -1342,7 +1341,7 @@ mod tests {
|
||||
.times(1)
|
||||
.return_once(|_| {
|
||||
Ok(vec![User {
|
||||
user_id: "bob_1".to_string(),
|
||||
user_id: UserId::new("bob_1"),
|
||||
..Default::default()
|
||||
}])
|
||||
});
|
||||
@@ -1378,7 +1377,7 @@ mod tests {
|
||||
let mut mock = MockTestBackendHandler::new();
|
||||
mock.expect_list_users().times(1).return_once(|_| {
|
||||
Ok(vec![User {
|
||||
user_id: "bob_1".to_string(),
|
||||
user_id: UserId::new("bob_1"),
|
||||
email: "bob@bobmail.bob".to_string(),
|
||||
display_name: "Bôb Böbberson".to_string(),
|
||||
first_name: "Bôb".to_string(),
|
||||
@@ -1393,7 +1392,7 @@ mod tests {
|
||||
Ok(vec![Group {
|
||||
id: GroupId(1),
|
||||
display_name: "group_1".to_string(),
|
||||
users: vec!["bob".to_string(), "john".to_string()],
|
||||
users: vec![UserId::new("bob"), UserId::new("john")],
|
||||
}])
|
||||
});
|
||||
let mut ldap_handler = setup_bound_handler(mock).await;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use super::{jwt_sql_tables::*, tcp_backend_handler::*};
|
||||
use crate::domain::{error::*, sql_backend_handler::SqlBackendHandler};
|
||||
use crate::domain::{error::*, handler::UserId, sql_backend_handler::SqlBackendHandler};
|
||||
use async_trait::async_trait;
|
||||
use futures_util::StreamExt;
|
||||
use sea_query::{Expr, Iden, Query, SimpleExpr};
|
||||
@@ -34,7 +34,7 @@ impl TcpBackendHandler for SqlBackendHandler {
|
||||
.map_err(|e| anyhow::anyhow!(e))
|
||||
}
|
||||
|
||||
async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)> {
|
||||
async fn create_refresh_token(&self, user: &UserId) -> Result<(String, chrono::Duration)> {
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
// TODO: Initialize the rng only once. Maybe Arc<Cell>?
|
||||
@@ -62,7 +62,7 @@ impl TcpBackendHandler for SqlBackendHandler {
|
||||
Ok((refresh_token, duration))
|
||||
}
|
||||
|
||||
async fn check_token(&self, refresh_token_hash: u64, user: &str) -> Result<bool> {
|
||||
async fn check_token(&self, refresh_token_hash: u64, user: &UserId) -> Result<bool> {
|
||||
let query = Query::select()
|
||||
.expr(SimpleExpr::Value(1.into()))
|
||||
.from(JwtRefreshStorage::Table)
|
||||
@@ -74,7 +74,7 @@ impl TcpBackendHandler for SqlBackendHandler {
|
||||
.await?
|
||||
.is_some())
|
||||
}
|
||||
async fn blacklist_jwts(&self, user: &str) -> Result<HashSet<u64>> {
|
||||
async fn blacklist_jwts(&self, user: &UserId) -> Result<HashSet<u64>> {
|
||||
use sqlx::Result;
|
||||
let query = Query::select()
|
||||
.column(JwtStorage::JwtHash)
|
||||
@@ -106,7 +106,7 @@ impl TcpBackendHandler for SqlBackendHandler {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_password_reset(&self, user: &str) -> Result<Option<String>> {
|
||||
async fn start_password_reset(&self, user: &UserId) -> Result<Option<String>> {
|
||||
let query = Query::select()
|
||||
.column(Users::UserId)
|
||||
.from(Users::Table)
|
||||
@@ -138,7 +138,7 @@ impl TcpBackendHandler for SqlBackendHandler {
|
||||
Ok(Some(token))
|
||||
}
|
||||
|
||||
async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<String> {
|
||||
async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<UserId> {
|
||||
let query = Query::select()
|
||||
.column(PasswordResetTokens::UserId)
|
||||
.from(PasswordResetTokens::Table)
|
||||
|
||||
@@ -1,22 +1,22 @@
|
||||
use async_trait::async_trait;
|
||||
use std::collections::HashSet;
|
||||
|
||||
use crate::domain::error::Result;
|
||||
use crate::domain::{error::Result, handler::UserId};
|
||||
|
||||
#[async_trait]
|
||||
pub trait TcpBackendHandler {
|
||||
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
|
||||
async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)>;
|
||||
async fn check_token(&self, refresh_token_hash: u64, user: &str) -> Result<bool>;
|
||||
async fn blacklist_jwts(&self, user: &str) -> Result<HashSet<u64>>;
|
||||
async fn create_refresh_token(&self, user: &UserId) -> Result<(String, chrono::Duration)>;
|
||||
async fn check_token(&self, refresh_token_hash: u64, user: &UserId) -> Result<bool>;
|
||||
async fn blacklist_jwts(&self, user: &UserId) -> Result<HashSet<u64>>;
|
||||
async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()>;
|
||||
|
||||
/// Request a token to reset a user's password.
|
||||
/// If the user doesn't exist, returns `Ok(None)`, otherwise `Ok(Some(token))`.
|
||||
async fn start_password_reset(&self, user: &str) -> Result<Option<String>>;
|
||||
async fn start_password_reset(&self, user: &UserId) -> Result<Option<String>>;
|
||||
|
||||
/// Get the user ID associated with a password reset token.
|
||||
async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<String>;
|
||||
async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<UserId>;
|
||||
|
||||
async fn delete_password_reset_token(&self, token: &str) -> Result<()>;
|
||||
}
|
||||
@@ -37,27 +37,27 @@ mockall::mock! {
|
||||
impl BackendHandler for TestTcpBackendHandler {
|
||||
async fn list_users(&self, filters: Option<UserRequestFilter>) -> Result<Vec<User>>;
|
||||
async fn list_groups(&self, filters: Option<GroupRequestFilter>) -> Result<Vec<Group>>;
|
||||
async fn get_user_details(&self, user_id: &str) -> Result<User>;
|
||||
async fn get_user_details(&self, user_id: &UserId) -> Result<User>;
|
||||
async fn get_group_details(&self, group_id: GroupId) -> Result<GroupIdAndName>;
|
||||
async fn get_user_groups(&self, user: &str) -> Result<HashSet<GroupIdAndName>>;
|
||||
async fn get_user_groups(&self, user: &UserId) -> Result<HashSet<GroupIdAndName>>;
|
||||
async fn create_user(&self, request: CreateUserRequest) -> Result<()>;
|
||||
async fn update_user(&self, request: UpdateUserRequest) -> Result<()>;
|
||||
async fn update_group(&self, request: UpdateGroupRequest) -> Result<()>;
|
||||
async fn delete_user(&self, user_id: &str) -> Result<()>;
|
||||
async fn delete_user(&self, user_id: &UserId) -> Result<()>;
|
||||
async fn create_group(&self, group_name: &str) -> Result<GroupId>;
|
||||
async fn delete_group(&self, group_id: GroupId) -> Result<()>;
|
||||
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
|
||||
async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
|
||||
async fn add_user_to_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>;
|
||||
async fn remove_user_from_group(&self, user_id: &UserId, group_id: GroupId) -> Result<()>;
|
||||
}
|
||||
#[async_trait]
|
||||
impl TcpBackendHandler for TestTcpBackendHandler {
|
||||
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
|
||||
async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)>;
|
||||
async fn check_token(&self, refresh_token_hash: u64, user: &str) -> Result<bool>;
|
||||
async fn blacklist_jwts(&self, user: &str) -> Result<HashSet<u64>>;
|
||||
async fn create_refresh_token(&self, user: &UserId) -> Result<(String, chrono::Duration)>;
|
||||
async fn check_token(&self, refresh_token_hash: u64, user: &UserId) -> Result<bool>;
|
||||
async fn blacklist_jwts(&self, user: &UserId) -> Result<HashSet<u64>>;
|
||||
async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()>;
|
||||
async fn start_password_reset(&self, user: &str) -> Result<Option<String>>;
|
||||
async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<String>;
|
||||
async fn start_password_reset(&self, user: &UserId) -> Result<Option<String>>;
|
||||
async fn get_user_id_for_password_reset_token(&self, token: &str) -> Result<UserId>;
|
||||
async fn delete_password_reset_token(&self, token: &str) -> Result<()>;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user