Move backend source to server/ subpackage
To clarify the organization.
This commit is contained in:
committed by
nitnelave
parent
3eb53ba5bf
commit
d8df47b35d
22
server/src/domain/error.rs
Normal file
22
server/src/domain/error.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
use thiserror::Error;
|
||||
|
||||
#[allow(clippy::enum_variant_names)]
|
||||
#[derive(Error, Debug)]
|
||||
pub enum DomainError {
|
||||
#[error("Authentication error for `{0}`")]
|
||||
AuthenticationError(String),
|
||||
#[error("Database error: `{0}`")]
|
||||
DatabaseError(#[from] sqlx::Error),
|
||||
#[error("Authentication protocol error for `{0}`")]
|
||||
AuthenticationProtocolError(#[from] lldap_auth::opaque::AuthenticationError),
|
||||
#[error("Unknown crypto error: `{0}`")]
|
||||
UnknownCryptoError(#[from] orion::errors::UnknownCryptoError),
|
||||
#[error("Binary serialization error: `{0}`")]
|
||||
BinarySerializationError(#[from] bincode::Error),
|
||||
#[error("Invalid base64: `{0}`")]
|
||||
Base64DecodeError(#[from] base64::DecodeError),
|
||||
#[error("Internal error: `{0}`")]
|
||||
InternalError(String),
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, DomainError>;
|
||||
103
server/src/domain/handler.rs
Normal file
103
server/src/domain/handler.rs
Normal file
@@ -0,0 +1,103 @@
|
||||
use super::error::*;
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashSet;
|
||||
|
||||
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
|
||||
#[cfg_attr(not(target_arch = "wasm32"), derive(sqlx::FromRow))]
|
||||
pub struct User {
|
||||
pub user_id: String,
|
||||
pub email: String,
|
||||
pub display_name: Option<String>,
|
||||
pub first_name: Option<String>,
|
||||
pub last_name: Option<String>,
|
||||
// pub avatar: ?,
|
||||
pub creation_date: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
impl Default for User {
|
||||
fn default() -> Self {
|
||||
use chrono::TimeZone;
|
||||
User {
|
||||
user_id: String::new(),
|
||||
email: String::new(),
|
||||
display_name: None,
|
||||
first_name: None,
|
||||
last_name: None,
|
||||
creation_date: chrono::Utc.timestamp(0, 0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
|
||||
pub struct Group {
|
||||
pub display_name: String,
|
||||
pub users: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct BindRequest {
|
||||
pub name: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone)]
|
||||
pub enum RequestFilter {
|
||||
And(Vec<RequestFilter>),
|
||||
Or(Vec<RequestFilter>),
|
||||
Not(Box<RequestFilter>),
|
||||
Equality(String, String),
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone, Default)]
|
||||
pub struct CreateUserRequest {
|
||||
// Same fields as User, but no creation_date, and with password.
|
||||
pub user_id: String,
|
||||
pub email: String,
|
||||
pub display_name: Option<String>,
|
||||
pub first_name: Option<String>,
|
||||
pub last_name: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait LoginHandler: Clone + Send {
|
||||
async fn bind(&self, request: BindRequest) -> Result<()>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
pub struct GroupId(pub i32);
|
||||
|
||||
#[async_trait]
|
||||
pub trait BackendHandler: Clone + Send {
|
||||
async fn list_users(&self, filters: Option<RequestFilter>) -> Result<Vec<User>>;
|
||||
async fn list_groups(&self) -> Result<Vec<Group>>;
|
||||
async fn get_user_details(&self, user_id: &str) -> Result<User>;
|
||||
async fn create_user(&self, request: CreateUserRequest) -> Result<()>;
|
||||
async fn delete_user(&self, user_id: &str) -> Result<()>;
|
||||
async fn create_group(&self, group_name: &str) -> Result<GroupId>;
|
||||
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
|
||||
async fn get_user_groups(&self, user: &str) -> Result<HashSet<String>>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mockall::mock! {
|
||||
pub TestBackendHandler{}
|
||||
impl Clone for TestBackendHandler {
|
||||
fn clone(&self) -> Self;
|
||||
}
|
||||
#[async_trait]
|
||||
impl BackendHandler for TestBackendHandler {
|
||||
async fn list_users(&self, filters: Option<RequestFilter>) -> Result<Vec<User>>;
|
||||
async fn list_groups(&self) -> Result<Vec<Group>>;
|
||||
async fn get_user_details(&self, user_id: &str) -> Result<User>;
|
||||
async fn create_user(&self, request: CreateUserRequest) -> Result<()>;
|
||||
async fn delete_user(&self, user_id: &str) -> Result<()>;
|
||||
async fn create_group(&self, group_name: &str) -> Result<GroupId>;
|
||||
async fn get_user_groups(&self, user: &str) -> Result<HashSet<String>>;
|
||||
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
|
||||
}
|
||||
#[async_trait]
|
||||
impl LoginHandler for TestBackendHandler {
|
||||
async fn bind(&self, request: BindRequest) -> Result<()>;
|
||||
}
|
||||
}
|
||||
6
server/src/domain/mod.rs
Normal file
6
server/src/domain/mod.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
pub mod error;
|
||||
pub mod handler;
|
||||
pub mod opaque_handler;
|
||||
pub mod sql_backend_handler;
|
||||
pub mod sql_opaque_handler;
|
||||
pub mod sql_tables;
|
||||
36
server/src/domain/opaque_handler.rs
Normal file
36
server/src/domain/opaque_handler.rs
Normal file
@@ -0,0 +1,36 @@
|
||||
use super::error::*;
|
||||
use async_trait::async_trait;
|
||||
|
||||
pub use lldap_auth::{login, registration};
|
||||
|
||||
#[async_trait]
|
||||
pub trait OpaqueHandler: Clone + Send {
|
||||
async fn login_start(
|
||||
&self,
|
||||
request: login::ClientLoginStartRequest,
|
||||
) -> Result<login::ServerLoginStartResponse>;
|
||||
async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<String>;
|
||||
async fn registration_start(
|
||||
&self,
|
||||
request: registration::ClientRegistrationStartRequest,
|
||||
) -> Result<registration::ServerRegistrationStartResponse>;
|
||||
async fn registration_finish(
|
||||
&self,
|
||||
request: registration::ClientRegistrationFinishRequest,
|
||||
) -> Result<()>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mockall::mock! {
|
||||
pub TestOpaqueHandler{}
|
||||
impl Clone for TestOpaqueHandler {
|
||||
fn clone(&self) -> Self;
|
||||
}
|
||||
#[async_trait]
|
||||
impl OpaqueHandler for TestOpaqueHandler {
|
||||
async fn login_start(&self, request: login::ClientLoginStartRequest) -> Result<login::ServerLoginStartResponse>;
|
||||
async fn login_finish(&self, request: login::ClientLoginFinishRequest ) -> Result<String>;
|
||||
async fn registration_start(&self, request: registration::ClientRegistrationStartRequest) -> Result<registration::ServerRegistrationStartResponse>;
|
||||
async fn registration_finish(&self, request: registration::ClientRegistrationFinishRequest ) -> Result<()>;
|
||||
}
|
||||
}
|
||||
536
server/src/domain/sql_backend_handler.rs
Normal file
536
server/src/domain/sql_backend_handler.rs
Normal file
@@ -0,0 +1,536 @@
|
||||
use super::{error::*, handler::*, sql_tables::*};
|
||||
use crate::infra::configuration::Configuration;
|
||||
use async_trait::async_trait;
|
||||
use futures_util::StreamExt;
|
||||
use futures_util::TryStreamExt;
|
||||
use sea_query::{Expr, Iden, Order, Query, SimpleExpr, Value};
|
||||
use sqlx::Row;
|
||||
use std::collections::HashSet;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SqlBackendHandler {
|
||||
pub(crate) config: Configuration,
|
||||
pub(crate) sql_pool: Pool,
|
||||
}
|
||||
|
||||
impl SqlBackendHandler {
|
||||
pub fn new(config: Configuration, sql_pool: Pool) -> Self {
|
||||
SqlBackendHandler { config, sql_pool }
|
||||
}
|
||||
}
|
||||
|
||||
fn get_filter_expr(filter: RequestFilter) -> SimpleExpr {
|
||||
use RequestFilter::*;
|
||||
fn get_repeated_filter(
|
||||
fs: Vec<RequestFilter>,
|
||||
field: &dyn Fn(SimpleExpr, SimpleExpr) -> SimpleExpr,
|
||||
) -> SimpleExpr {
|
||||
let mut it = fs.into_iter();
|
||||
let first_expr = match it.next() {
|
||||
None => return Expr::value(true),
|
||||
Some(f) => get_filter_expr(f),
|
||||
};
|
||||
it.fold(first_expr, |e, f| field(e, get_filter_expr(f)))
|
||||
}
|
||||
match filter {
|
||||
And(fs) => get_repeated_filter(fs, &SimpleExpr::and),
|
||||
Or(fs) => get_repeated_filter(fs, &SimpleExpr::or),
|
||||
Not(f) => Expr::not(Expr::expr(get_filter_expr(*f))),
|
||||
Equality(s1, s2) => Expr::expr(Expr::cust(&s1)).eq(s2),
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl BackendHandler for SqlBackendHandler {
|
||||
async fn list_users(&self, filters: Option<RequestFilter>) -> Result<Vec<User>> {
|
||||
let query = {
|
||||
let mut query_builder = Query::select()
|
||||
.column(Users::UserId)
|
||||
.column(Users::Email)
|
||||
.column(Users::DisplayName)
|
||||
.column(Users::FirstName)
|
||||
.column(Users::LastName)
|
||||
.column(Users::Avatar)
|
||||
.column(Users::CreationDate)
|
||||
.from(Users::Table)
|
||||
.order_by(Users::UserId, Order::Asc)
|
||||
.to_owned();
|
||||
if let Some(filter) = filters {
|
||||
if filter != RequestFilter::And(Vec::new())
|
||||
&& filter != RequestFilter::Or(Vec::new())
|
||||
{
|
||||
query_builder.and_where(get_filter_expr(filter));
|
||||
}
|
||||
}
|
||||
|
||||
query_builder.to_string(DbQueryBuilder {})
|
||||
};
|
||||
|
||||
let results = sqlx::query_as::<_, User>(&query)
|
||||
.fetch(&self.sql_pool)
|
||||
.collect::<Vec<sqlx::Result<User>>>()
|
||||
.await;
|
||||
|
||||
Ok(results.into_iter().collect::<sqlx::Result<Vec<User>>>()?)
|
||||
}
|
||||
|
||||
async fn list_groups(&self) -> Result<Vec<Group>> {
|
||||
let query: String = Query::select()
|
||||
.column(Groups::DisplayName)
|
||||
.column(Memberships::UserId)
|
||||
.from(Groups::Table)
|
||||
.left_join(
|
||||
Memberships::Table,
|
||||
Expr::tbl(Groups::Table, Groups::GroupId)
|
||||
.equals(Memberships::Table, Memberships::GroupId),
|
||||
)
|
||||
.order_by(Groups::DisplayName, Order::Asc)
|
||||
.order_by(Memberships::UserId, Order::Asc)
|
||||
.to_string(DbQueryBuilder {});
|
||||
|
||||
let mut results = sqlx::query(&query).fetch(&self.sql_pool);
|
||||
let mut groups = Vec::new();
|
||||
// The rows are ordered by group, user, so we need to group them into vectors.
|
||||
{
|
||||
let mut current_group = String::new();
|
||||
let mut current_users = Vec::new();
|
||||
while let Some(row) = results.try_next().await? {
|
||||
let display_name = row.get::<String, _>(&*Groups::DisplayName.to_string());
|
||||
if display_name != current_group {
|
||||
if !current_group.is_empty() {
|
||||
groups.push(Group {
|
||||
display_name: current_group,
|
||||
users: current_users,
|
||||
});
|
||||
current_users = Vec::new();
|
||||
}
|
||||
current_group = display_name.clone();
|
||||
}
|
||||
current_users.push(row.get::<String, _>(&*Memberships::UserId.to_string()));
|
||||
}
|
||||
groups.push(Group {
|
||||
display_name: current_group,
|
||||
users: current_users,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(groups)
|
||||
}
|
||||
|
||||
async fn get_user_details(&self, user_id: &str) -> Result<User> {
|
||||
let query = Query::select()
|
||||
.column(Users::UserId)
|
||||
.column(Users::Email)
|
||||
.column(Users::DisplayName)
|
||||
.column(Users::FirstName)
|
||||
.column(Users::LastName)
|
||||
.column(Users::Avatar)
|
||||
.column(Users::CreationDate)
|
||||
.from(Users::Table)
|
||||
.and_where(Expr::col(Users::UserId).eq(user_id))
|
||||
.to_string(DbQueryBuilder {});
|
||||
|
||||
Ok(sqlx::query_as::<_, User>(&query)
|
||||
.fetch_one(&self.sql_pool)
|
||||
.await?)
|
||||
}
|
||||
|
||||
async fn get_user_groups(&self, user: &str) -> Result<HashSet<String>> {
|
||||
if user == self.config.ldap_user_dn {
|
||||
let mut groups = HashSet::new();
|
||||
groups.insert("lldap_admin".to_string());
|
||||
return Ok(groups);
|
||||
}
|
||||
let query: String = Query::select()
|
||||
.column(Groups::DisplayName)
|
||||
.from(Groups::Table)
|
||||
.inner_join(
|
||||
Memberships::Table,
|
||||
Expr::tbl(Groups::Table, Groups::GroupId)
|
||||
.equals(Memberships::Table, Memberships::GroupId),
|
||||
)
|
||||
.and_where(Expr::col(Memberships::UserId).eq(user))
|
||||
.to_string(DbQueryBuilder {});
|
||||
|
||||
sqlx::query(&query)
|
||||
// Extract the group id from the row.
|
||||
.map(|row: DbRow| row.get::<String, _>(&*Groups::DisplayName.to_string()))
|
||||
.fetch(&self.sql_pool)
|
||||
// Collect the vector of rows, each potentially an error.
|
||||
.collect::<Vec<sqlx::Result<String>>>()
|
||||
.await
|
||||
.into_iter()
|
||||
// Transform it into a single result (the first error if any), and group the group_ids
|
||||
// into a HashSet.
|
||||
.collect::<sqlx::Result<HashSet<_>>>()
|
||||
// Map the sqlx::Error into a DomainError.
|
||||
.map_err(DomainError::DatabaseError)
|
||||
}
|
||||
|
||||
async fn create_user(&self, request: CreateUserRequest) -> Result<()> {
|
||||
let columns = vec![
|
||||
Users::UserId,
|
||||
Users::Email,
|
||||
Users::DisplayName,
|
||||
Users::FirstName,
|
||||
Users::LastName,
|
||||
Users::CreationDate,
|
||||
];
|
||||
let values = vec![
|
||||
request.user_id.clone().into(),
|
||||
request.email.into(),
|
||||
request.display_name.map(Into::into).unwrap_or(Value::Null),
|
||||
request.first_name.map(Into::into).unwrap_or(Value::Null),
|
||||
request.last_name.map(Into::into).unwrap_or(Value::Null),
|
||||
chrono::Utc::now().naive_utc().into(),
|
||||
];
|
||||
let query = Query::insert()
|
||||
.into_table(Users::Table)
|
||||
.columns(columns)
|
||||
.values_panic(values)
|
||||
.to_string(DbQueryBuilder {});
|
||||
sqlx::query(&query).execute(&self.sql_pool).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn delete_user(&self, user_id: &str) -> Result<()> {
|
||||
let delete_query = Query::delete()
|
||||
.from_table(Users::Table)
|
||||
.and_where(Expr::col(Users::UserId).eq(user_id))
|
||||
.to_string(DbQueryBuilder {});
|
||||
sqlx::query(&delete_query).execute(&self.sql_pool).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_group(&self, group_name: &str) -> Result<GroupId> {
|
||||
let query = Query::insert()
|
||||
.into_table(Groups::Table)
|
||||
.columns(vec![Groups::DisplayName])
|
||||
.values_panic(vec![group_name.into()])
|
||||
.to_string(DbQueryBuilder {});
|
||||
sqlx::query(&query).execute(&self.sql_pool).await?;
|
||||
let query = Query::select()
|
||||
.column(Groups::GroupId)
|
||||
.from(Groups::Table)
|
||||
.and_where(Expr::col(Groups::DisplayName).eq(group_name))
|
||||
.to_string(DbQueryBuilder {});
|
||||
let row = sqlx::query(&query).fetch_one(&self.sql_pool).await?;
|
||||
Ok(GroupId(row.get::<i32, _>(&*Groups::GroupId.to_string())))
|
||||
}
|
||||
|
||||
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()> {
|
||||
let query = Query::insert()
|
||||
.into_table(Memberships::Table)
|
||||
.columns(vec![Memberships::UserId, Memberships::GroupId])
|
||||
.values_panic(vec![user_id.into(), group_id.0.into()])
|
||||
.to_string(DbQueryBuilder {});
|
||||
sqlx::query(&query).execute(&self.sql_pool).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::domain::sql_tables::init_table;
|
||||
use crate::infra::configuration::ConfigurationBuilder;
|
||||
use lldap_auth::{opaque, registration};
|
||||
|
||||
fn get_default_config() -> Configuration {
|
||||
ConfigurationBuilder::default()
|
||||
.verbose(true)
|
||||
.build()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
async fn get_in_memory_db() -> Pool {
|
||||
PoolOptions::new().connect("sqlite::memory:").await.unwrap()
|
||||
}
|
||||
|
||||
async fn get_initialized_db() -> Pool {
|
||||
let sql_pool = get_in_memory_db().await;
|
||||
init_table(&sql_pool).await.unwrap();
|
||||
sql_pool
|
||||
}
|
||||
|
||||
async fn insert_user(handler: &SqlBackendHandler, name: &str, pass: &str) {
|
||||
use crate::domain::opaque_handler::OpaqueHandler;
|
||||
insert_user_no_password(handler, name).await;
|
||||
let mut rng = rand::rngs::OsRng;
|
||||
let client_registration_start =
|
||||
opaque::client::registration::start_registration(pass, &mut rng).unwrap();
|
||||
let response = handler
|
||||
.registration_start(registration::ClientRegistrationStartRequest {
|
||||
username: name.to_string(),
|
||||
registration_start_request: client_registration_start.message,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let registration_upload = opaque::client::registration::finish_registration(
|
||||
client_registration_start.state,
|
||||
response.registration_response,
|
||||
&mut rng,
|
||||
)
|
||||
.unwrap();
|
||||
handler
|
||||
.registration_finish(registration::ClientRegistrationFinishRequest {
|
||||
server_data: response.server_data,
|
||||
registration_upload: registration_upload.message,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
async fn insert_user_no_password(handler: &SqlBackendHandler, name: &str) {
|
||||
handler
|
||||
.create_user(CreateUserRequest {
|
||||
user_id: name.to_string(),
|
||||
email: "bob@bob.bob".to_string(),
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
async fn insert_group(handler: &SqlBackendHandler, name: &str) -> GroupId {
|
||||
handler.create_group(name).await.unwrap()
|
||||
}
|
||||
|
||||
async fn insert_membership(handler: &SqlBackendHandler, group_id: GroupId, user_id: &str) {
|
||||
handler.add_user_to_group(user_id, group_id).await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_bind_admin() {
|
||||
let sql_pool = get_in_memory_db().await;
|
||||
let config = ConfigurationBuilder::default()
|
||||
.ldap_user_dn("admin".to_string())
|
||||
.ldap_user_pass("test".to_string())
|
||||
.build()
|
||||
.unwrap();
|
||||
let handler = SqlBackendHandler::new(config, sql_pool);
|
||||
handler
|
||||
.bind(BindRequest {
|
||||
name: "admin".to_string(),
|
||||
password: "test".to_string(),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_bind_user() {
|
||||
let sql_pool = get_initialized_db().await;
|
||||
let config = get_default_config();
|
||||
let handler = SqlBackendHandler::new(config, sql_pool.clone());
|
||||
insert_user(&handler, "bob", "bob00").await;
|
||||
|
||||
handler
|
||||
.bind(BindRequest {
|
||||
name: "bob".to_string(),
|
||||
password: "bob00".to_string(),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
handler
|
||||
.bind(BindRequest {
|
||||
name: "andrew".to_string(),
|
||||
password: "bob00".to_string(),
|
||||
})
|
||||
.await
|
||||
.unwrap_err();
|
||||
handler
|
||||
.bind(BindRequest {
|
||||
name: "bob".to_string(),
|
||||
password: "wrong_password".to_string(),
|
||||
})
|
||||
.await
|
||||
.unwrap_err();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_user_no_password() {
|
||||
let sql_pool = get_initialized_db().await;
|
||||
let config = get_default_config();
|
||||
let handler = SqlBackendHandler::new(config, sql_pool.clone());
|
||||
insert_user_no_password(&handler, "bob").await;
|
||||
|
||||
handler
|
||||
.bind(BindRequest {
|
||||
name: "bob".to_string(),
|
||||
password: "bob00".to_string(),
|
||||
})
|
||||
.await
|
||||
.unwrap_err();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_users() {
|
||||
let sql_pool = get_initialized_db().await;
|
||||
let config = get_default_config();
|
||||
let handler = SqlBackendHandler::new(config, sql_pool);
|
||||
insert_user(&handler, "bob", "bob00").await;
|
||||
insert_user(&handler, "patrick", "pass").await;
|
||||
insert_user(&handler, "John", "Pa33w0rd!").await;
|
||||
{
|
||||
let users = handler
|
||||
.list_users(None)
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|u| u.user_id)
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(users, vec!["John", "bob", "patrick"]);
|
||||
}
|
||||
{
|
||||
let users = handler
|
||||
.list_users(Some(RequestFilter::Equality(
|
||||
"user_id".to_string(),
|
||||
"bob".to_string(),
|
||||
)))
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|u| u.user_id)
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(users, vec!["bob"]);
|
||||
}
|
||||
{
|
||||
let users = handler
|
||||
.list_users(Some(RequestFilter::Or(vec![
|
||||
RequestFilter::Equality("user_id".to_string(), "bob".to_string()),
|
||||
RequestFilter::Equality("user_id".to_string(), "John".to_string()),
|
||||
])))
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|u| u.user_id)
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(users, vec!["John", "bob"]);
|
||||
}
|
||||
{
|
||||
let users = handler
|
||||
.list_users(Some(RequestFilter::Not(Box::new(RequestFilter::Equality(
|
||||
"user_id".to_string(),
|
||||
"bob".to_string(),
|
||||
)))))
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|u| u.user_id)
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(users, vec!["John", "patrick"]);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_groups() {
|
||||
let sql_pool = get_initialized_db().await;
|
||||
let config = get_default_config();
|
||||
let handler = SqlBackendHandler::new(config, sql_pool.clone());
|
||||
insert_user(&handler, "bob", "bob00").await;
|
||||
insert_user(&handler, "patrick", "pass").await;
|
||||
insert_user(&handler, "John", "Pa33w0rd!").await;
|
||||
let group_1 = insert_group(&handler, "Best Group").await;
|
||||
let group_2 = insert_group(&handler, "Worst Group").await;
|
||||
insert_membership(&handler, group_1, "bob").await;
|
||||
insert_membership(&handler, group_1, "patrick").await;
|
||||
insert_membership(&handler, group_2, "patrick").await;
|
||||
insert_membership(&handler, group_2, "John").await;
|
||||
assert_eq!(
|
||||
handler.list_groups().await.unwrap(),
|
||||
vec![
|
||||
Group {
|
||||
display_name: "Best Group".to_string(),
|
||||
users: vec!["bob".to_string(), "patrick".to_string()]
|
||||
},
|
||||
Group {
|
||||
display_name: "Worst Group".to_string(),
|
||||
users: vec!["John".to_string(), "patrick".to_string()]
|
||||
}
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_user_details() {
|
||||
let sql_pool = get_initialized_db().await;
|
||||
let config = get_default_config();
|
||||
let handler = SqlBackendHandler::new(config, sql_pool);
|
||||
insert_user(&handler, "bob", "bob00").await;
|
||||
{
|
||||
let user = handler.get_user_details("bob").await.unwrap();
|
||||
assert_eq!(user.user_id, "bob".to_string());
|
||||
}
|
||||
{
|
||||
handler.get_user_details("John").await.unwrap_err();
|
||||
}
|
||||
}
|
||||
#[tokio::test]
|
||||
async fn test_get_user_groups() {
|
||||
let sql_pool = get_initialized_db().await;
|
||||
let config = get_default_config();
|
||||
let handler = SqlBackendHandler::new(config, sql_pool.clone());
|
||||
insert_user(&handler, "bob", "bob00").await;
|
||||
insert_user(&handler, "patrick", "pass").await;
|
||||
insert_user(&handler, "John", "Pa33w0rd!").await;
|
||||
let group_1 = insert_group(&handler, "Group1").await;
|
||||
let group_2 = insert_group(&handler, "Group2").await;
|
||||
insert_membership(&handler, group_1, "bob").await;
|
||||
insert_membership(&handler, group_1, "patrick").await;
|
||||
insert_membership(&handler, group_2, "patrick").await;
|
||||
let mut bob_groups = HashSet::new();
|
||||
bob_groups.insert("Group1".to_string());
|
||||
let mut patrick_groups = HashSet::new();
|
||||
patrick_groups.insert("Group1".to_string());
|
||||
patrick_groups.insert("Group2".to_string());
|
||||
assert_eq!(handler.get_user_groups("bob").await.unwrap(), bob_groups);
|
||||
assert_eq!(
|
||||
handler.get_user_groups("patrick").await.unwrap(),
|
||||
patrick_groups
|
||||
);
|
||||
assert_eq!(
|
||||
handler.get_user_groups("John").await.unwrap(),
|
||||
HashSet::new()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_user() {
|
||||
let sql_pool = get_initialized_db().await;
|
||||
let config = get_default_config();
|
||||
let handler = SqlBackendHandler::new(config, sql_pool.clone());
|
||||
|
||||
insert_user(&handler, "val", "s3np4i").await;
|
||||
insert_user(&handler, "Hector", "Be$t").await;
|
||||
insert_user(&handler, "Jennz", "boupBoup").await;
|
||||
|
||||
// Remove a user
|
||||
let _request_result = handler.delete_user("Jennz").await.unwrap();
|
||||
|
||||
let users = handler
|
||||
.list_users(None)
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|u| u.user_id)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(users, vec!["Hector", "val"]);
|
||||
|
||||
// Insert new user and remove two
|
||||
insert_user(&handler, "NewBoi", "Joni").await;
|
||||
let _request_result = handler.delete_user("Hector").await.unwrap();
|
||||
let _request_result = handler.delete_user("NewBoi").await.unwrap();
|
||||
|
||||
let users = handler
|
||||
.list_users(None)
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|u| u.user_id)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(users, vec!["val"]);
|
||||
}
|
||||
}
|
||||
331
server/src/domain/sql_opaque_handler.rs
Normal file
331
server/src/domain/sql_opaque_handler.rs
Normal file
@@ -0,0 +1,331 @@
|
||||
use super::{
|
||||
error::*,
|
||||
handler::{BindRequest, LoginHandler},
|
||||
opaque_handler::*,
|
||||
sql_backend_handler::SqlBackendHandler,
|
||||
sql_tables::*,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use lldap_auth::opaque;
|
||||
use log::*;
|
||||
use sea_query::{Expr, Iden, Query};
|
||||
use sqlx::Row;
|
||||
|
||||
type SqlOpaqueHandler = SqlBackendHandler;
|
||||
|
||||
fn passwords_match(
|
||||
password_file_bytes: &[u8],
|
||||
clear_password: &str,
|
||||
server_setup: &opaque::server::ServerSetup,
|
||||
username: &str,
|
||||
) -> Result<()> {
|
||||
use opaque::{client, server};
|
||||
let mut rng = rand::rngs::OsRng;
|
||||
let client_login_start_result = client::login::start_login(clear_password, &mut rng)?;
|
||||
|
||||
let password_file = server::ServerRegistration::deserialize(password_file_bytes)
|
||||
.map_err(opaque::AuthenticationError::ProtocolError)?;
|
||||
let server_login_start_result = server::login::start_login(
|
||||
&mut rng,
|
||||
server_setup,
|
||||
Some(password_file),
|
||||
client_login_start_result.message,
|
||||
username,
|
||||
)?;
|
||||
client::login::finish_login(
|
||||
client_login_start_result.state,
|
||||
server_login_start_result.message,
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl SqlBackendHandler {
|
||||
fn get_orion_secret_key(&self) -> Result<orion::aead::SecretKey> {
|
||||
Ok(orion::aead::SecretKey::from_slice(
|
||||
self.config.get_server_keys().private(),
|
||||
)?)
|
||||
}
|
||||
|
||||
async fn get_password_file_for_user(
|
||||
&self,
|
||||
username: &str,
|
||||
) -> Result<Option<opaque::server::ServerRegistration>> {
|
||||
// Fetch the previously registered password file from the DB.
|
||||
let password_file_bytes = {
|
||||
let query = Query::select()
|
||||
.column(Users::PasswordHash)
|
||||
.from(Users::Table)
|
||||
.and_where(Expr::col(Users::UserId).eq(username))
|
||||
.to_string(DbQueryBuilder {});
|
||||
if let Some(row) = sqlx::query(&query).fetch_optional(&self.sql_pool).await? {
|
||||
if let Some(bytes) =
|
||||
row.get::<Option<Vec<u8>>, _>(&*Users::PasswordHash.to_string())
|
||||
{
|
||||
bytes
|
||||
} else {
|
||||
// No password set.
|
||||
return Ok(None);
|
||||
}
|
||||
} else {
|
||||
// No such user.
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
opaque::server::ServerRegistration::deserialize(&password_file_bytes)
|
||||
.map(Option::Some)
|
||||
.map_err(|_| {
|
||||
DomainError::InternalError(format!("Corrupted password file for {}", username))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LoginHandler for SqlBackendHandler {
|
||||
async fn bind(&self, request: BindRequest) -> Result<()> {
|
||||
if request.name == self.config.ldap_user_dn {
|
||||
if request.password == self.config.ldap_user_pass {
|
||||
return Ok(());
|
||||
} else {
|
||||
debug!(r#"Invalid password for LDAP bind user"#);
|
||||
return Err(DomainError::AuthenticationError(request.name));
|
||||
}
|
||||
}
|
||||
let query = Query::select()
|
||||
.column(Users::PasswordHash)
|
||||
.from(Users::Table)
|
||||
.and_where(Expr::col(Users::UserId).eq(request.name.as_str()))
|
||||
.to_string(DbQueryBuilder {});
|
||||
if let Ok(row) = sqlx::query(&query).fetch_one(&self.sql_pool).await {
|
||||
if let Some(password_hash) =
|
||||
row.get::<Option<Vec<u8>>, _>(&*Users::PasswordHash.to_string())
|
||||
{
|
||||
if let Err(e) = passwords_match(
|
||||
&password_hash,
|
||||
&request.password,
|
||||
self.config.get_server_setup(),
|
||||
&request.name,
|
||||
) {
|
||||
debug!(r#"Invalid password for "{}": {}"#, request.name, e);
|
||||
} else {
|
||||
return Ok(());
|
||||
}
|
||||
} else {
|
||||
debug!(r#"User "{}" has no password"#, request.name);
|
||||
}
|
||||
} else {
|
||||
debug!(r#"No user found for "{}""#, request.name);
|
||||
}
|
||||
Err(DomainError::AuthenticationError(request.name))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl OpaqueHandler for SqlOpaqueHandler {
|
||||
async fn login_start(
|
||||
&self,
|
||||
request: login::ClientLoginStartRequest,
|
||||
) -> Result<login::ServerLoginStartResponse> {
|
||||
let maybe_password_file = self.get_password_file_for_user(&request.username).await?;
|
||||
|
||||
let mut rng = rand::rngs::OsRng;
|
||||
// Get the CredentialResponse for the user, or a dummy one if no user/no password.
|
||||
let start_response = opaque::server::login::start_login(
|
||||
&mut rng,
|
||||
self.config.get_server_setup(),
|
||||
maybe_password_file,
|
||||
request.login_start_request,
|
||||
&request.username,
|
||||
)?;
|
||||
let secret_key = self.get_orion_secret_key()?;
|
||||
let server_data = login::ServerData {
|
||||
username: request.username,
|
||||
server_login: start_response.state,
|
||||
};
|
||||
let encrypted_state = orion::aead::seal(&secret_key, &bincode::serialize(&server_data)?)?;
|
||||
|
||||
Ok(login::ServerLoginStartResponse {
|
||||
server_data: base64::encode(&encrypted_state),
|
||||
credential_response: start_response.message,
|
||||
})
|
||||
}
|
||||
|
||||
async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<String> {
|
||||
let secret_key = self.get_orion_secret_key()?;
|
||||
let login::ServerData {
|
||||
username,
|
||||
server_login,
|
||||
} = bincode::deserialize(&orion::aead::open(
|
||||
&secret_key,
|
||||
&base64::decode(&request.server_data)?,
|
||||
)?)?;
|
||||
// Finish the login: this makes sure the client data is correct, and gives a session key we
|
||||
// don't need.
|
||||
let _session_key =
|
||||
opaque::server::login::finish_login(server_login, request.credential_finalization)?
|
||||
.session_key;
|
||||
|
||||
Ok(username)
|
||||
}
|
||||
|
||||
async fn registration_start(
|
||||
&self,
|
||||
request: registration::ClientRegistrationStartRequest,
|
||||
) -> Result<registration::ServerRegistrationStartResponse> {
|
||||
// Generate the server-side key and derive the data to send back.
|
||||
let start_response = opaque::server::registration::start_registration(
|
||||
self.config.get_server_setup(),
|
||||
request.registration_start_request,
|
||||
&request.username,
|
||||
)?;
|
||||
let secret_key = self.get_orion_secret_key()?;
|
||||
let server_data = registration::ServerData {
|
||||
username: request.username,
|
||||
};
|
||||
let encrypted_state = orion::aead::seal(&secret_key, &bincode::serialize(&server_data)?)?;
|
||||
Ok(registration::ServerRegistrationStartResponse {
|
||||
server_data: base64::encode(encrypted_state),
|
||||
registration_response: start_response.message,
|
||||
})
|
||||
}
|
||||
|
||||
async fn registration_finish(
|
||||
&self,
|
||||
request: registration::ClientRegistrationFinishRequest,
|
||||
) -> Result<()> {
|
||||
let secret_key = self.get_orion_secret_key()?;
|
||||
let registration::ServerData { username } = bincode::deserialize(&orion::aead::open(
|
||||
&secret_key,
|
||||
&base64::decode(&request.server_data)?,
|
||||
)?)?;
|
||||
|
||||
let password_file =
|
||||
opaque::server::registration::get_password_file(request.registration_upload);
|
||||
{
|
||||
// Set the user password to the new password.
|
||||
let update_query = Query::update()
|
||||
.table(Users::Table)
|
||||
.values(vec![(
|
||||
Users::PasswordHash,
|
||||
password_file.serialize().into(),
|
||||
)])
|
||||
.and_where(Expr::col(Users::UserId).eq(username))
|
||||
.to_string(DbQueryBuilder {});
|
||||
sqlx::query(&update_query).execute(&self.sql_pool).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Convenience function to set a user's password.
|
||||
pub(crate) async fn register_password(
|
||||
opaque_handler: &SqlOpaqueHandler,
|
||||
username: &str,
|
||||
password: &str,
|
||||
) -> Result<()> {
|
||||
let mut rng = rand::rngs::OsRng;
|
||||
use registration::*;
|
||||
let registration_start = opaque::client::registration::start_registration(password, &mut rng)?;
|
||||
let start_response = opaque_handler
|
||||
.registration_start(ClientRegistrationStartRequest {
|
||||
username: username.to_string(),
|
||||
registration_start_request: registration_start.message,
|
||||
})
|
||||
.await?;
|
||||
let registration_finish = opaque::client::registration::finish_registration(
|
||||
registration_start.state,
|
||||
start_response.registration_response,
|
||||
&mut rng,
|
||||
)?;
|
||||
opaque_handler
|
||||
.registration_finish(ClientRegistrationFinishRequest {
|
||||
server_data: start_response.server_data,
|
||||
registration_upload: registration_finish.message,
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{
|
||||
domain::{
|
||||
handler::{BackendHandler, CreateUserRequest},
|
||||
sql_backend_handler::SqlBackendHandler,
|
||||
sql_tables::init_table,
|
||||
},
|
||||
infra::configuration::{Configuration, ConfigurationBuilder},
|
||||
};
|
||||
|
||||
fn get_default_config() -> Configuration {
|
||||
ConfigurationBuilder::default()
|
||||
.verbose(true)
|
||||
.build()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
async fn get_in_memory_db() -> Pool {
|
||||
PoolOptions::new().connect("sqlite::memory:").await.unwrap()
|
||||
}
|
||||
|
||||
async fn get_initialized_db() -> Pool {
|
||||
let sql_pool = get_in_memory_db().await;
|
||||
init_table(&sql_pool).await.unwrap();
|
||||
sql_pool
|
||||
}
|
||||
|
||||
async fn insert_user_no_password(handler: &SqlBackendHandler, name: &str) {
|
||||
handler
|
||||
.create_user(CreateUserRequest {
|
||||
user_id: name.to_string(),
|
||||
email: "bob@bob.bob".to_string(),
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
async fn attempt_login(
|
||||
opaque_handler: &SqlOpaqueHandler,
|
||||
username: &str,
|
||||
password: &str,
|
||||
) -> Result<()> {
|
||||
let mut rng = rand::rngs::OsRng;
|
||||
use login::*;
|
||||
let login_start = opaque::client::login::start_login(password, &mut rng)?;
|
||||
let start_response = opaque_handler
|
||||
.login_start(ClientLoginStartRequest {
|
||||
username: username.to_string(),
|
||||
login_start_request: login_start.message,
|
||||
})
|
||||
.await?;
|
||||
let login_finish = opaque::client::login::finish_login(
|
||||
login_start.state,
|
||||
start_response.credential_response,
|
||||
)?;
|
||||
opaque_handler
|
||||
.login_finish(ClientLoginFinishRequest {
|
||||
server_data: start_response.server_data,
|
||||
credential_finalization: login_finish.message,
|
||||
})
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_flow() -> Result<()> {
|
||||
let sql_pool = get_initialized_db().await;
|
||||
let config = get_default_config();
|
||||
let backend_handler = SqlBackendHandler::new(config.clone(), sql_pool.clone());
|
||||
let opaque_handler = SqlOpaqueHandler::new(config, sql_pool);
|
||||
insert_user_no_password(&backend_handler, "bob").await;
|
||||
attempt_login(&opaque_handler, "bob", "bob00")
|
||||
.await
|
||||
.unwrap_err();
|
||||
register_password(&opaque_handler, "bob", "bob00").await?;
|
||||
attempt_login(&opaque_handler, "bob", "wrong_password")
|
||||
.await
|
||||
.unwrap_err();
|
||||
attempt_login(&opaque_handler, "bob", "bob00").await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
152
server/src/domain/sql_tables.rs
Normal file
152
server/src/domain/sql_tables.rs
Normal file
@@ -0,0 +1,152 @@
|
||||
use sea_query::*;
|
||||
|
||||
pub type Pool = sqlx::sqlite::SqlitePool;
|
||||
pub type PoolOptions = sqlx::sqlite::SqlitePoolOptions;
|
||||
pub type DbRow = sqlx::sqlite::SqliteRow;
|
||||
pub type DbQueryBuilder = SqliteQueryBuilder;
|
||||
|
||||
#[derive(Iden)]
|
||||
pub enum Users {
|
||||
Table,
|
||||
UserId,
|
||||
Email,
|
||||
DisplayName,
|
||||
FirstName,
|
||||
LastName,
|
||||
Avatar,
|
||||
CreationDate,
|
||||
PasswordHash,
|
||||
TotpSecret,
|
||||
MfaType,
|
||||
}
|
||||
|
||||
#[derive(Iden)]
|
||||
pub enum Groups {
|
||||
Table,
|
||||
GroupId,
|
||||
DisplayName,
|
||||
}
|
||||
|
||||
#[derive(Iden)]
|
||||
pub enum Memberships {
|
||||
Table,
|
||||
UserId,
|
||||
GroupId,
|
||||
}
|
||||
|
||||
pub async fn init_table(pool: &Pool) -> sqlx::Result<()> {
|
||||
// SQLite needs this pragma to be turned on. Other DB might not understand this, so ignore the
|
||||
// error.
|
||||
let _ = sqlx::query("PRAGMA foreign_keys = ON").execute(pool).await;
|
||||
sqlx::query(
|
||||
&Table::create()
|
||||
.table(Users::Table)
|
||||
.if_not_exists()
|
||||
.col(
|
||||
ColumnDef::new(Users::UserId)
|
||||
.string_len(255)
|
||||
.not_null()
|
||||
.primary_key(),
|
||||
)
|
||||
.col(ColumnDef::new(Users::Email).string_len(255).not_null())
|
||||
.col(ColumnDef::new(Users::DisplayName).string_len(255))
|
||||
.col(ColumnDef::new(Users::FirstName).string_len(255))
|
||||
.col(ColumnDef::new(Users::LastName).string_len(255))
|
||||
.col(ColumnDef::new(Users::Avatar).binary())
|
||||
.col(ColumnDef::new(Users::CreationDate).date_time().not_null())
|
||||
.col(ColumnDef::new(Users::PasswordHash).binary())
|
||||
.col(ColumnDef::new(Users::TotpSecret).string_len(64))
|
||||
.col(ColumnDef::new(Users::MfaType).string_len(64))
|
||||
.to_string(DbQueryBuilder {}),
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
&Table::create()
|
||||
.table(Groups::Table)
|
||||
.if_not_exists()
|
||||
.col(
|
||||
ColumnDef::new(Groups::GroupId)
|
||||
.integer()
|
||||
.not_null()
|
||||
.primary_key(),
|
||||
)
|
||||
.col(
|
||||
ColumnDef::new(Groups::DisplayName)
|
||||
.string_len(255)
|
||||
.unique_key()
|
||||
.not_null(),
|
||||
)
|
||||
.to_string(DbQueryBuilder {}),
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
&Table::create()
|
||||
.table(Memberships::Table)
|
||||
.if_not_exists()
|
||||
.col(
|
||||
ColumnDef::new(Memberships::UserId)
|
||||
.string_len(255)
|
||||
.not_null(),
|
||||
)
|
||||
.col(ColumnDef::new(Memberships::GroupId).integer().not_null())
|
||||
.foreign_key(
|
||||
ForeignKey::create()
|
||||
.name("MembershipUserForeignKey")
|
||||
.table(Memberships::Table, Users::Table)
|
||||
.col(Memberships::UserId, Users::UserId)
|
||||
.on_delete(ForeignKeyAction::Cascade)
|
||||
.on_update(ForeignKeyAction::Cascade),
|
||||
)
|
||||
.foreign_key(
|
||||
ForeignKey::create()
|
||||
.name("MembershipGroupForeignKey")
|
||||
.table(Memberships::Table, Groups::Table)
|
||||
.col(Memberships::GroupId, Groups::GroupId)
|
||||
.on_delete(ForeignKeyAction::Cascade)
|
||||
.on_update(ForeignKeyAction::Cascade),
|
||||
)
|
||||
.to_string(DbQueryBuilder {}),
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use chrono::prelude::*;
|
||||
use sqlx::{Column, Row};
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn test_init_table() {
|
||||
let sql_pool = PoolOptions::new().connect("sqlite::memory:").await.unwrap();
|
||||
init_table(&sql_pool).await.unwrap();
|
||||
sqlx::query(r#"INSERT INTO users
|
||||
(user_id, email, display_name, first_name, last_name, creation_date, password_hash)
|
||||
VALUES ("bôb", "böb@bob.bob", "Bob Bobbersön", "Bob", "Bobberson", "1970-01-01 00:00:00", "bob00")"#).execute(&sql_pool).await.unwrap();
|
||||
let row =
|
||||
sqlx::query(r#"SELECT display_name, creation_date FROM users WHERE user_id = "bôb""#)
|
||||
.fetch_one(&sql_pool)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(row.column(0).name(), "display_name");
|
||||
assert_eq!(row.get::<String, _>("display_name"), "Bob Bobbersön");
|
||||
assert_eq!(
|
||||
row.get::<DateTime<Utc>, _>("creation_date"),
|
||||
Utc.timestamp(0, 0),
|
||||
);
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn test_already_init_table() {
|
||||
let sql_pool = PoolOptions::new().connect("sqlite::memory:").await.unwrap();
|
||||
init_table(&sql_pool).await.unwrap();
|
||||
init_table(&sql_pool).await.unwrap();
|
||||
}
|
||||
}
|
||||
415
server/src/infra/auth_service.rs
Normal file
415
server/src/infra/auth_service.rs
Normal file
@@ -0,0 +1,415 @@
|
||||
use crate::{
|
||||
domain::{
|
||||
error::DomainError,
|
||||
handler::{BackendHandler, BindRequest, LoginHandler},
|
||||
opaque_handler::OpaqueHandler,
|
||||
},
|
||||
infra::{
|
||||
tcp_backend_handler::*,
|
||||
tcp_server::{error_to_http_response, AppState},
|
||||
},
|
||||
};
|
||||
use actix_web::{
|
||||
cookie::{Cookie, SameSite},
|
||||
dev::{Service, ServiceRequest, ServiceResponse, Transform},
|
||||
error::{ErrorBadRequest, ErrorUnauthorized},
|
||||
web, HttpRequest, HttpResponse,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use chrono::prelude::*;
|
||||
use futures::future::{ok, Ready};
|
||||
use futures_util::{FutureExt, TryFutureExt};
|
||||
use hmac::Hmac;
|
||||
use jwt::{SignWithKey, VerifyWithKey};
|
||||
use lldap_auth::{login, registration, JWTClaims};
|
||||
use sha2::Sha512;
|
||||
use std::collections::{hash_map::DefaultHasher, HashSet};
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use time::ext::NumericalDuration;
|
||||
|
||||
type Token<S> = jwt::Token<jwt::Header, JWTClaims, S>;
|
||||
type SignedToken = Token<jwt::token::Signed>;
|
||||
|
||||
fn create_jwt(key: &Hmac<Sha512>, user: String, groups: HashSet<String>) -> SignedToken {
|
||||
let claims = JWTClaims {
|
||||
exp: Utc::now() + chrono::Duration::days(1),
|
||||
iat: Utc::now(),
|
||||
user,
|
||||
groups,
|
||||
};
|
||||
let header = jwt::Header {
|
||||
algorithm: jwt::AlgorithmType::Hs512,
|
||||
..Default::default()
|
||||
};
|
||||
jwt::Token::new(header, claims).sign_with_key(key).unwrap()
|
||||
}
|
||||
|
||||
fn get_refresh_token_from_cookie(
|
||||
request: HttpRequest,
|
||||
) -> std::result::Result<(u64, String), HttpResponse> {
|
||||
match request.cookie("refresh_token") {
|
||||
None => Err(HttpResponse::Unauthorized().body("Missing refresh token")),
|
||||
Some(t) => match t.value().split_once("+") {
|
||||
None => Err(HttpResponse::Unauthorized().body("Invalid refresh token")),
|
||||
Some((token, u)) => {
|
||||
let refresh_token_hash = {
|
||||
let mut s = DefaultHasher::new();
|
||||
token.hash(&mut s);
|
||||
s.finish()
|
||||
};
|
||||
Ok((refresh_token_hash, u.to_string()))
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_refresh<Backend>(
|
||||
data: web::Data<AppState<Backend>>,
|
||||
request: HttpRequest,
|
||||
) -> HttpResponse
|
||||
where
|
||||
Backend: TcpBackendHandler + BackendHandler + 'static,
|
||||
{
|
||||
let backend_handler = &data.backend_handler;
|
||||
let jwt_key = &data.jwt_key;
|
||||
let (refresh_token_hash, user) = match get_refresh_token_from_cookie(request) {
|
||||
Ok(t) => t,
|
||||
Err(http_response) => return http_response,
|
||||
};
|
||||
let res_found = data
|
||||
.backend_handler
|
||||
.check_token(refresh_token_hash, &user)
|
||||
.await;
|
||||
// Async closures are not supported yet.
|
||||
match res_found {
|
||||
Ok(found) => {
|
||||
if found {
|
||||
backend_handler.get_user_groups(&user).await
|
||||
} else {
|
||||
Err(DomainError::AuthenticationError(
|
||||
"Invalid refresh token".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
.map(|groups| create_jwt(jwt_key, user.to_string(), groups))
|
||||
.map(|token| {
|
||||
HttpResponse::Ok()
|
||||
.cookie(
|
||||
Cookie::build("token", token.as_str())
|
||||
.max_age(1.days())
|
||||
.path("/api")
|
||||
.http_only(true)
|
||||
.same_site(SameSite::Strict)
|
||||
.finish(),
|
||||
)
|
||||
.body(token.as_str().to_owned())
|
||||
})
|
||||
.unwrap_or_else(error_to_http_response)
|
||||
}
|
||||
|
||||
async fn get_logout<Backend>(
|
||||
data: web::Data<AppState<Backend>>,
|
||||
request: HttpRequest,
|
||||
) -> HttpResponse
|
||||
where
|
||||
Backend: TcpBackendHandler + BackendHandler + 'static,
|
||||
{
|
||||
let (refresh_token_hash, user) = match get_refresh_token_from_cookie(request) {
|
||||
Ok(t) => t,
|
||||
Err(http_response) => return http_response,
|
||||
};
|
||||
if let Err(response) = data
|
||||
.backend_handler
|
||||
.delete_refresh_token(refresh_token_hash)
|
||||
.map_err(error_to_http_response)
|
||||
.await
|
||||
{
|
||||
return response;
|
||||
};
|
||||
match data
|
||||
.backend_handler
|
||||
.blacklist_jwts(&user)
|
||||
.map_err(error_to_http_response)
|
||||
.await
|
||||
{
|
||||
Ok(new_blacklisted_jwts) => {
|
||||
let mut jwt_blacklist = data.jwt_blacklist.write().unwrap();
|
||||
for jwt in new_blacklisted_jwts {
|
||||
jwt_blacklist.insert(jwt);
|
||||
}
|
||||
}
|
||||
Err(response) => return response,
|
||||
};
|
||||
HttpResponse::Ok()
|
||||
.cookie(
|
||||
Cookie::build("token", "")
|
||||
.max_age(0.days())
|
||||
.path("/api")
|
||||
.http_only(true)
|
||||
.same_site(SameSite::Strict)
|
||||
.finish(),
|
||||
)
|
||||
.cookie(
|
||||
Cookie::build("refresh_token", "")
|
||||
.max_age(0.days())
|
||||
.path("/auth")
|
||||
.http_only(true)
|
||||
.same_site(SameSite::Strict)
|
||||
.finish(),
|
||||
)
|
||||
.finish()
|
||||
}
|
||||
|
||||
pub(crate) fn error_to_api_response<T>(error: DomainError) -> ApiResult<T> {
|
||||
ApiResult::Right(error_to_http_response(error))
|
||||
}
|
||||
|
||||
pub type ApiResult<M> = actix_web::Either<web::Json<M>, HttpResponse>;
|
||||
|
||||
async fn opaque_login_start<Backend>(
|
||||
data: web::Data<AppState<Backend>>,
|
||||
request: web::Json<login::ClientLoginStartRequest>,
|
||||
) -> ApiResult<login::ServerLoginStartResponse>
|
||||
where
|
||||
Backend: OpaqueHandler + 'static,
|
||||
{
|
||||
data.backend_handler
|
||||
.login_start(request.into_inner())
|
||||
.await
|
||||
.map(|res| ApiResult::Left(web::Json(res)))
|
||||
.unwrap_or_else(error_to_api_response)
|
||||
}
|
||||
|
||||
async fn get_login_successful_response<Backend>(
|
||||
data: &web::Data<AppState<Backend>>,
|
||||
name: &str,
|
||||
) -> HttpResponse
|
||||
where
|
||||
Backend: TcpBackendHandler + BackendHandler,
|
||||
{
|
||||
// The authentication was successful, we need to fetch the groups to create the JWT
|
||||
// token.
|
||||
data.backend_handler
|
||||
.get_user_groups(name)
|
||||
.and_then(|g| async { Ok((g, data.backend_handler.create_refresh_token(name).await?)) })
|
||||
.await
|
||||
.map(|(groups, (refresh_token, max_age))| {
|
||||
let token = create_jwt(&data.jwt_key, name.to_string(), groups);
|
||||
HttpResponse::Ok()
|
||||
.cookie(
|
||||
Cookie::build("token", token.as_str())
|
||||
.max_age(1.days())
|
||||
.path("/api")
|
||||
.http_only(true)
|
||||
.same_site(SameSite::Strict)
|
||||
.finish(),
|
||||
)
|
||||
.cookie(
|
||||
Cookie::build("refresh_token", refresh_token + "+" + name)
|
||||
.max_age(max_age.num_days().days())
|
||||
.path("/auth")
|
||||
.http_only(true)
|
||||
.same_site(SameSite::Strict)
|
||||
.finish(),
|
||||
)
|
||||
.body(token.as_str().to_owned())
|
||||
})
|
||||
.unwrap_or_else(error_to_http_response)
|
||||
}
|
||||
|
||||
async fn opaque_login_finish<Backend>(
|
||||
data: web::Data<AppState<Backend>>,
|
||||
request: web::Json<login::ClientLoginFinishRequest>,
|
||||
) -> HttpResponse
|
||||
where
|
||||
Backend: TcpBackendHandler + BackendHandler + OpaqueHandler + 'static,
|
||||
{
|
||||
let name = match data
|
||||
.backend_handler
|
||||
.login_finish(request.into_inner())
|
||||
.await
|
||||
{
|
||||
Ok(n) => n,
|
||||
Err(e) => return error_to_http_response(e),
|
||||
};
|
||||
get_login_successful_response(&data, &name).await
|
||||
}
|
||||
|
||||
async fn post_authorize<Backend>(
|
||||
data: web::Data<AppState<Backend>>,
|
||||
request: web::Json<BindRequest>,
|
||||
) -> HttpResponse
|
||||
where
|
||||
Backend: TcpBackendHandler + BackendHandler + LoginHandler + 'static,
|
||||
{
|
||||
let name = request.name.clone();
|
||||
if let Err(e) = data.backend_handler.bind(request.into_inner()).await {
|
||||
return error_to_http_response(e);
|
||||
}
|
||||
get_login_successful_response(&data, &name).await
|
||||
}
|
||||
|
||||
async fn opaque_register_start<Backend>(
|
||||
data: web::Data<AppState<Backend>>,
|
||||
request: web::Json<registration::ClientRegistrationStartRequest>,
|
||||
) -> ApiResult<registration::ServerRegistrationStartResponse>
|
||||
where
|
||||
Backend: OpaqueHandler + 'static,
|
||||
{
|
||||
data.backend_handler
|
||||
.registration_start(request.into_inner())
|
||||
.await
|
||||
.map(|res| ApiResult::Left(web::Json(res)))
|
||||
.unwrap_or_else(error_to_api_response)
|
||||
}
|
||||
|
||||
async fn opaque_register_finish<Backend>(
|
||||
data: web::Data<AppState<Backend>>,
|
||||
request: web::Json<registration::ClientRegistrationFinishRequest>,
|
||||
) -> HttpResponse
|
||||
where
|
||||
Backend: TcpBackendHandler + BackendHandler + OpaqueHandler + 'static,
|
||||
{
|
||||
if let Err(e) = data
|
||||
.backend_handler
|
||||
.registration_finish(request.into_inner())
|
||||
.await
|
||||
{
|
||||
return error_to_http_response(e);
|
||||
}
|
||||
HttpResponse::Ok().finish()
|
||||
}
|
||||
|
||||
pub struct CookieToHeaderTranslatorFactory;
|
||||
|
||||
impl<S> Transform<S, ServiceRequest> for CookieToHeaderTranslatorFactory
|
||||
where
|
||||
S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error>,
|
||||
S::Future: 'static,
|
||||
{
|
||||
type Response = ServiceResponse;
|
||||
type Error = actix_web::Error;
|
||||
type InitError = ();
|
||||
type Transform = CookieToHeaderTranslator<S>;
|
||||
type Future = Ready<Result<Self::Transform, Self::InitError>>;
|
||||
|
||||
fn new_transform(&self, service: S) -> Self::Future {
|
||||
ok(CookieToHeaderTranslator { service })
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CookieToHeaderTranslator<S> {
|
||||
service: S,
|
||||
}
|
||||
|
||||
impl<S> Service<ServiceRequest> for CookieToHeaderTranslator<S>
|
||||
where
|
||||
S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error>,
|
||||
S::Future: 'static,
|
||||
{
|
||||
type Response = ServiceResponse;
|
||||
type Error = actix_web::Error;
|
||||
#[allow(clippy::type_complexity)]
|
||||
type Future = Pin<Box<dyn core::future::Future<Output = Result<Self::Response, Self::Error>>>>;
|
||||
|
||||
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.service.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn call(&self, mut req: ServiceRequest) -> Self::Future {
|
||||
if let Some(token_cookie) = req.cookie("token") {
|
||||
if let Ok(header_value) = actix_http::header::HeaderValue::from_str(&format!(
|
||||
"Bearer {}",
|
||||
token_cookie.value()
|
||||
)) {
|
||||
req.headers_mut()
|
||||
.insert(actix_http::header::AUTHORIZATION, header_value);
|
||||
} else {
|
||||
return async move {
|
||||
Ok(req.error_response(ErrorBadRequest("Invalid token cookie")))
|
||||
}
|
||||
.boxed_local();
|
||||
}
|
||||
};
|
||||
|
||||
Box::pin(self.service.call(req))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ValidationResults {
|
||||
pub user: String,
|
||||
pub is_admin: bool,
|
||||
}
|
||||
|
||||
impl ValidationResults {
|
||||
#[cfg(test)]
|
||||
pub fn admin() -> Self {
|
||||
Self {
|
||||
user: "admin".to_string(),
|
||||
is_admin: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn can_access(&self, user: &str) -> bool {
|
||||
self.is_admin || self.user == user
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn check_if_token_is_valid<Backend>(
|
||||
state: &AppState<Backend>,
|
||||
token_str: &str,
|
||||
) -> Result<ValidationResults, actix_web::Error> {
|
||||
let token: Token<_> = VerifyWithKey::verify_with_key(token_str, &state.jwt_key)
|
||||
.map_err(|_| ErrorUnauthorized("Invalid JWT"))?;
|
||||
if token.claims().exp.lt(&Utc::now()) {
|
||||
return Err(ErrorUnauthorized("Expired JWT"));
|
||||
}
|
||||
if token.header().algorithm != jwt::AlgorithmType::Hs512 {
|
||||
return Err(ErrorUnauthorized(format!(
|
||||
"Unsupported JWT algorithm: '{:?}'. Supported ones are: ['HS512']",
|
||||
token.header().algorithm
|
||||
)));
|
||||
}
|
||||
let jwt_hash = {
|
||||
let mut s = DefaultHasher::new();
|
||||
token_str.hash(&mut s);
|
||||
s.finish()
|
||||
};
|
||||
if state.jwt_blacklist.read().unwrap().contains(&jwt_hash) {
|
||||
return Err(ErrorUnauthorized("JWT was logged out"));
|
||||
}
|
||||
let is_admin = token.claims().groups.contains("lldap_admin");
|
||||
Ok(ValidationResults {
|
||||
user: token.claims().user.clone(),
|
||||
is_admin,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn configure_server<Backend>(cfg: &mut web::ServiceConfig)
|
||||
where
|
||||
Backend: TcpBackendHandler + LoginHandler + OpaqueHandler + BackendHandler + 'static,
|
||||
{
|
||||
cfg.service(web::resource("").route(web::post().to(post_authorize::<Backend>)))
|
||||
.service(
|
||||
web::resource("/opaque/login/start")
|
||||
.route(web::post().to(opaque_login_start::<Backend>)),
|
||||
)
|
||||
.service(
|
||||
web::resource("/opaque/login/finish")
|
||||
.route(web::post().to(opaque_login_finish::<Backend>)),
|
||||
)
|
||||
.service(
|
||||
web::resource("/opaque/register/start")
|
||||
.route(web::post().to(opaque_register_start::<Backend>)),
|
||||
)
|
||||
.service(
|
||||
web::resource("/opaque/register/finish")
|
||||
.route(web::post().to(opaque_register_finish::<Backend>)),
|
||||
)
|
||||
.service(web::resource("/refresh").route(web::get().to(get_refresh::<Backend>)))
|
||||
.service(web::resource("/logout").route(web::get().to(get_logout::<Backend>)));
|
||||
}
|
||||
50
server/src/infra/cli.rs
Normal file
50
server/src/infra/cli.rs
Normal file
@@ -0,0 +1,50 @@
|
||||
use clap::Clap;
|
||||
|
||||
/// lldap is a lightweight LDAP server
|
||||
#[derive(Debug, Clap, Clone)]
|
||||
#[clap(version = "0.1", author = "The LLDAP team")]
|
||||
pub struct CLIOpts {
|
||||
/// Export
|
||||
#[clap(subcommand)]
|
||||
pub command: Command,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clap, Clone)]
|
||||
pub enum Command {
|
||||
/// Export the GraphQL schema to *.graphql.
|
||||
#[clap(name = "export_graphql_schema")]
|
||||
ExportGraphQLSchema(ExportGraphQLSchemaOpts),
|
||||
/// Run the LDAP and GraphQL server.
|
||||
#[clap(name = "run")]
|
||||
Run(RunOpts),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clap, Clone)]
|
||||
pub struct RunOpts {
|
||||
/// Change config file name
|
||||
#[clap(short, long, default_value = "lldap_config.toml")]
|
||||
pub config_file: String,
|
||||
|
||||
/// Change ldap port. Default: 389
|
||||
#[clap(long)]
|
||||
pub ldap_port: Option<u16>,
|
||||
|
||||
/// Change ldap ssl port. Default: 636
|
||||
#[clap(long)]
|
||||
pub ldaps_port: Option<u16>,
|
||||
|
||||
/// Set verbose logging
|
||||
#[clap(short, long)]
|
||||
pub verbose: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clap, Clone)]
|
||||
pub struct ExportGraphQLSchemaOpts {
|
||||
/// Output to a file. If not specified, the config is printed to the standard output.
|
||||
#[clap(short, long)]
|
||||
pub output_file: Option<String>,
|
||||
}
|
||||
|
||||
pub fn init() -> CLIOpts {
|
||||
CLIOpts::parse()
|
||||
}
|
||||
121
server/src/infra/configuration.rs
Normal file
121
server/src/infra/configuration.rs
Normal file
@@ -0,0 +1,121 @@
|
||||
use anyhow::{Context, Result};
|
||||
use figment::{
|
||||
providers::{Env, Format, Serialized, Toml},
|
||||
Figment,
|
||||
};
|
||||
use lldap_auth::opaque::{server::ServerSetup, KeyPair};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::infra::cli::RunOpts;
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, derive_builder::Builder)]
|
||||
#[builder(
|
||||
pattern = "owned",
|
||||
default = "Configuration::default()",
|
||||
build_fn(name = "private_build", validate = "Self::validate")
|
||||
)]
|
||||
pub struct Configuration {
|
||||
pub ldap_port: u16,
|
||||
pub ldaps_port: u16,
|
||||
pub http_port: u16,
|
||||
pub jwt_secret: String,
|
||||
pub ldap_base_dn: String,
|
||||
pub ldap_user_dn: String,
|
||||
pub ldap_user_pass: String,
|
||||
pub database_url: String,
|
||||
pub verbose: bool,
|
||||
pub key_file: String,
|
||||
#[serde(skip)]
|
||||
#[builder(field(private), setter(strip_option))]
|
||||
server_setup: Option<ServerSetup>,
|
||||
}
|
||||
|
||||
impl ConfigurationBuilder {
|
||||
#[cfg(test)]
|
||||
pub fn build(self) -> Result<Configuration> {
|
||||
let server_setup = get_server_setup(self.key_file.as_deref().unwrap_or("server_key"))?;
|
||||
Ok(self.server_setup(server_setup).private_build()?)
|
||||
}
|
||||
|
||||
fn validate(&self) -> Result<(), String> {
|
||||
if self.server_setup.is_none() {
|
||||
Err("Don't use `private_build`, use `build` instead".to_string())
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Configuration {
|
||||
pub fn get_server_setup(&self) -> &ServerSetup {
|
||||
self.server_setup.as_ref().unwrap()
|
||||
}
|
||||
|
||||
pub fn get_server_keys(&self) -> &KeyPair {
|
||||
self.get_server_setup().keypair()
|
||||
}
|
||||
|
||||
fn merge_with_cli(mut self: Configuration, cli_opts: RunOpts) -> Configuration {
|
||||
if cli_opts.verbose {
|
||||
self.verbose = true;
|
||||
}
|
||||
|
||||
if let Some(port) = cli_opts.ldap_port {
|
||||
self.ldap_port = port;
|
||||
}
|
||||
|
||||
if let Some(port) = cli_opts.ldaps_port {
|
||||
self.ldaps_port = port;
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub(super) fn default() -> Self {
|
||||
Configuration {
|
||||
ldap_port: 3890,
|
||||
ldaps_port: 6360,
|
||||
http_port: 17170,
|
||||
jwt_secret: String::from("secretjwtsecret"),
|
||||
ldap_base_dn: String::from("dc=example,dc=com"),
|
||||
// cn=admin,dc=example,dc=com
|
||||
ldap_user_dn: String::from("admin"),
|
||||
ldap_user_pass: String::from("password"),
|
||||
database_url: String::from("sqlite://users.db?mode=rwc"),
|
||||
verbose: false,
|
||||
key_file: String::from("server_key"),
|
||||
server_setup: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_server_setup(file_path: &str) -> Result<ServerSetup> {
|
||||
use std::path::Path;
|
||||
let path = Path::new(file_path);
|
||||
if path.exists() {
|
||||
let bytes =
|
||||
std::fs::read(file_path).context(format!("Could not read key file `{}`", file_path))?;
|
||||
Ok(ServerSetup::deserialize(&bytes)?)
|
||||
} else {
|
||||
let mut rng = rand::rngs::OsRng;
|
||||
let server_setup = ServerSetup::new(&mut rng);
|
||||
std::fs::write(path, server_setup.serialize()).context(format!(
|
||||
"Could not write the generated server setup to file `{}`",
|
||||
file_path,
|
||||
))?;
|
||||
Ok(server_setup)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init(cli_opts: RunOpts) -> Result<Configuration> {
|
||||
let config_file = cli_opts.config_file.clone();
|
||||
|
||||
let config: Configuration = Figment::from(Serialized::defaults(Configuration::default()))
|
||||
.merge(Toml::file(config_file))
|
||||
.merge(Env::prefixed("LLDAP_"))
|
||||
.extract()?;
|
||||
|
||||
let mut config = config.merge_with_cli(cli_opts);
|
||||
config.server_setup = Some(get_server_setup(&config.key_file)?);
|
||||
Ok(config)
|
||||
}
|
||||
82
server/src/infra/db_cleaner.rs
Normal file
82
server/src/infra/db_cleaner.rs
Normal file
@@ -0,0 +1,82 @@
|
||||
use crate::{
|
||||
domain::sql_tables::{DbQueryBuilder, Pool},
|
||||
infra::jwt_sql_tables::{JwtRefreshStorage, JwtStorage},
|
||||
};
|
||||
use actix::prelude::*;
|
||||
use chrono::Local;
|
||||
use cron::Schedule;
|
||||
use sea_query::{Expr, Query};
|
||||
use std::{str::FromStr, time::Duration};
|
||||
|
||||
// Define actor
|
||||
pub struct Scheduler {
|
||||
schedule: Schedule,
|
||||
sql_pool: Pool,
|
||||
}
|
||||
|
||||
// Provide Actor implementation for our actor
|
||||
impl Actor for Scheduler {
|
||||
type Context = Context<Self>;
|
||||
|
||||
fn started(&mut self, context: &mut Context<Self>) {
|
||||
log::info!("DB Cleanup Cron started");
|
||||
|
||||
context.run_later(self.duration_until_next(), move |this, ctx| {
|
||||
this.schedule_task(ctx)
|
||||
});
|
||||
}
|
||||
|
||||
fn stopped(&mut self, _ctx: &mut Context<Self>) {
|
||||
log::info!("DB Cleanup stopped");
|
||||
}
|
||||
}
|
||||
|
||||
impl Scheduler {
|
||||
pub fn new(cron_expression: &str, sql_pool: Pool) -> Self {
|
||||
let schedule = Schedule::from_str(cron_expression).unwrap();
|
||||
Self { schedule, sql_pool }
|
||||
}
|
||||
|
||||
fn schedule_task(&self, ctx: &mut Context<Self>) {
|
||||
log::info!("Cleaning DB");
|
||||
let future = actix::fut::wrap_future::<_, Self>(Self::cleanup_db(self.sql_pool.clone()));
|
||||
ctx.spawn(future);
|
||||
|
||||
ctx.run_later(self.duration_until_next(), move |this, ctx| {
|
||||
this.schedule_task(ctx)
|
||||
});
|
||||
}
|
||||
|
||||
async fn cleanup_db(sql_pool: Pool) {
|
||||
if let Err(e) = sqlx::query(
|
||||
&Query::delete()
|
||||
.from_table(JwtRefreshStorage::Table)
|
||||
.and_where(Expr::col(JwtRefreshStorage::ExpiryDate).lt(Local::now().naive_utc()))
|
||||
.to_string(DbQueryBuilder {}),
|
||||
)
|
||||
.execute(&sql_pool)
|
||||
.await
|
||||
{
|
||||
log::error!("DB error while cleaning up JWT refresh tokens: {}", e);
|
||||
};
|
||||
if let Err(e) = sqlx::query(
|
||||
&Query::delete()
|
||||
.from_table(JwtStorage::Table)
|
||||
.and_where(Expr::col(JwtStorage::ExpiryDate).lt(Local::now().naive_utc()))
|
||||
.to_string(DbQueryBuilder {}),
|
||||
)
|
||||
.execute(&sql_pool)
|
||||
.await
|
||||
{
|
||||
log::error!("DB error while cleaning up JWT storage: {}", e);
|
||||
};
|
||||
log::info!("DB cleaned!");
|
||||
}
|
||||
|
||||
fn duration_until_next(&self) -> Duration {
|
||||
let now = Local::now();
|
||||
let next = self.schedule.upcoming(Local).next().unwrap();
|
||||
let duration_until = next.signed_duration_since(now);
|
||||
duration_until.to_std().unwrap()
|
||||
}
|
||||
}
|
||||
100
server/src/infra/graphql/api.rs
Normal file
100
server/src/infra/graphql/api.rs
Normal file
@@ -0,0 +1,100 @@
|
||||
use crate::{
|
||||
domain::handler::BackendHandler,
|
||||
infra::{
|
||||
auth_service::{check_if_token_is_valid, ValidationResults},
|
||||
cli::ExportGraphQLSchemaOpts,
|
||||
tcp_server::AppState,
|
||||
},
|
||||
};
|
||||
use actix_web::{web, Error, HttpResponse};
|
||||
use actix_web_httpauth::extractors::bearer::BearerAuth;
|
||||
use juniper::{EmptySubscription, RootNode};
|
||||
use juniper_actix::{graphiql_handler, graphql_handler, playground_handler};
|
||||
|
||||
use super::{mutation::Mutation, query::Query};
|
||||
|
||||
pub struct Context<Handler: BackendHandler> {
|
||||
pub handler: Box<Handler>,
|
||||
pub validation_result: ValidationResults,
|
||||
}
|
||||
|
||||
impl<Handler: BackendHandler> juniper::Context for Context<Handler> {}
|
||||
|
||||
type Schema<Handler> =
|
||||
RootNode<'static, Query<Handler>, Mutation<Handler>, EmptySubscription<Context<Handler>>>;
|
||||
|
||||
fn schema<Handler: BackendHandler + Sync>() -> Schema<Handler> {
|
||||
Schema::new(
|
||||
Query::<Handler>::new(),
|
||||
Mutation::<Handler>::new(),
|
||||
EmptySubscription::<Context<Handler>>::new(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn export_schema(opts: ExportGraphQLSchemaOpts) -> anyhow::Result<()> {
|
||||
use crate::domain::sql_backend_handler::SqlBackendHandler;
|
||||
use anyhow::Context;
|
||||
let output = schema::<SqlBackendHandler>().as_schema_language();
|
||||
match opts.output_file {
|
||||
None => println!("{}", output),
|
||||
Some(path) => {
|
||||
use std::fs::File;
|
||||
use std::io::prelude::*;
|
||||
use std::path::Path;
|
||||
let path = Path::new(&path);
|
||||
let mut file =
|
||||
File::create(&path).context(format!("unable to open '{}'", path.display()))?;
|
||||
file.write_all(output.as_bytes())
|
||||
.context(format!("unable to write in '{}'", path.display()))?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn graphiql_route() -> Result<HttpResponse, Error> {
|
||||
graphiql_handler("/api/graphql", None).await
|
||||
}
|
||||
async fn playground_route() -> Result<HttpResponse, Error> {
|
||||
playground_handler("/api/graphql", None).await
|
||||
}
|
||||
|
||||
async fn graphql_route<Handler: BackendHandler + Sync>(
|
||||
req: actix_web::HttpRequest,
|
||||
mut payload: actix_web::web::Payload,
|
||||
data: web::Data<AppState<Handler>>,
|
||||
) -> Result<HttpResponse, Error> {
|
||||
use actix_web::FromRequest;
|
||||
let bearer = BearerAuth::from_request(&req, &mut payload.0).await?;
|
||||
let validation_result = check_if_token_is_valid(&data, bearer.token())?;
|
||||
let context = Context::<Handler> {
|
||||
handler: Box::new(data.backend_handler.clone()),
|
||||
validation_result,
|
||||
};
|
||||
graphql_handler(&schema(), &context, req, payload).await
|
||||
}
|
||||
|
||||
pub fn configure_endpoint<Backend>(cfg: &mut web::ServiceConfig)
|
||||
where
|
||||
Backend: BackendHandler + Sync + 'static,
|
||||
{
|
||||
let json_config = web::JsonConfig::default()
|
||||
.limit(4096)
|
||||
.error_handler(|err, _req| {
|
||||
// create custom error response
|
||||
log::error!("API error: {}", err);
|
||||
let msg = err.to_string();
|
||||
actix_web::error::InternalError::from_response(
|
||||
err,
|
||||
HttpResponse::BadRequest().body(msg),
|
||||
)
|
||||
.into()
|
||||
});
|
||||
cfg.app_data(json_config);
|
||||
cfg.service(
|
||||
web::resource("/graphql")
|
||||
.route(web::post().to(graphql_route::<Backend>))
|
||||
.route(web::get().to(graphql_route::<Backend>)),
|
||||
);
|
||||
cfg.service(web::resource("/graphql/playground").route(web::get().to(playground_route)));
|
||||
cfg.service(web::resource("/graphql/graphiql").route(web::get().to(graphiql_route)));
|
||||
}
|
||||
3
server/src/infra/graphql/mod.rs
Normal file
3
server/src/infra/graphql/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod api;
|
||||
pub mod mutation;
|
||||
pub mod query;
|
||||
55
server/src/infra/graphql/mutation.rs
Normal file
55
server/src/infra/graphql/mutation.rs
Normal file
@@ -0,0 +1,55 @@
|
||||
use crate::domain::handler::{BackendHandler, CreateUserRequest};
|
||||
use juniper::{graphql_object, FieldResult, GraphQLInputObject};
|
||||
|
||||
use super::api::Context;
|
||||
|
||||
#[derive(PartialEq, Eq, Debug)]
|
||||
/// The top-level GraphQL mutation type.
|
||||
pub struct Mutation<Handler: BackendHandler> {
|
||||
_phantom: std::marker::PhantomData<Box<Handler>>,
|
||||
}
|
||||
|
||||
impl<Handler: BackendHandler> Mutation<Handler> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Debug, GraphQLInputObject)]
|
||||
/// The details required to create a user.
|
||||
pub struct UserInput {
|
||||
id: String,
|
||||
email: String,
|
||||
display_name: Option<String>,
|
||||
first_name: Option<String>,
|
||||
last_name: Option<String>,
|
||||
}
|
||||
|
||||
#[graphql_object(context = Context<Handler>)]
|
||||
impl<Handler: BackendHandler + Sync> Mutation<Handler> {
|
||||
async fn create_user(
|
||||
context: &Context<Handler>,
|
||||
user: UserInput,
|
||||
) -> FieldResult<super::query::User<Handler>> {
|
||||
if !context.validation_result.is_admin {
|
||||
return Err("Unauthorized user creation".into());
|
||||
}
|
||||
context
|
||||
.handler
|
||||
.create_user(CreateUserRequest {
|
||||
user_id: user.id.clone(),
|
||||
email: user.email,
|
||||
display_name: user.display_name,
|
||||
first_name: user.first_name,
|
||||
last_name: user.last_name,
|
||||
})
|
||||
.await?;
|
||||
Ok(context
|
||||
.handler
|
||||
.get_user_details(&user.id)
|
||||
.await
|
||||
.map(Into::into)?)
|
||||
}
|
||||
}
|
||||
348
server/src/infra/graphql/query.rs
Normal file
348
server/src/infra/graphql/query.rs
Normal file
@@ -0,0 +1,348 @@
|
||||
use crate::domain::handler::BackendHandler;
|
||||
use juniper::{graphql_object, FieldResult, GraphQLInputObject};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::convert::TryInto;
|
||||
|
||||
type DomainRequestFilter = crate::domain::handler::RequestFilter;
|
||||
type DomainUser = crate::domain::handler::User;
|
||||
use super::api::Context;
|
||||
|
||||
#[derive(PartialEq, Eq, Debug, GraphQLInputObject)]
|
||||
/// A filter for requests, specifying a boolean expression based on field constraints. Only one of
|
||||
/// the fields can be set at a time.
|
||||
pub struct RequestFilter {
|
||||
any: Option<Vec<RequestFilter>>,
|
||||
all: Option<Vec<RequestFilter>>,
|
||||
not: Option<Box<RequestFilter>>,
|
||||
eq: Option<EqualityConstraint>,
|
||||
}
|
||||
|
||||
impl TryInto<DomainRequestFilter> for RequestFilter {
|
||||
type Error = String;
|
||||
fn try_into(self) -> Result<DomainRequestFilter, Self::Error> {
|
||||
let mut field_count = 0;
|
||||
if self.any.is_some() {
|
||||
field_count += 1;
|
||||
}
|
||||
if self.all.is_some() {
|
||||
field_count += 1;
|
||||
}
|
||||
if self.not.is_some() {
|
||||
field_count += 1;
|
||||
}
|
||||
if self.eq.is_some() {
|
||||
field_count += 1;
|
||||
}
|
||||
if field_count == 0 {
|
||||
return Err("No field specified in request filter".to_string());
|
||||
}
|
||||
if field_count > 1 {
|
||||
return Err("Multiple fields specified in request filter".to_string());
|
||||
}
|
||||
if let Some(e) = self.eq {
|
||||
return Ok(DomainRequestFilter::Equality(e.field, e.value));
|
||||
}
|
||||
if let Some(c) = self.any {
|
||||
return Ok(DomainRequestFilter::Or(
|
||||
c.into_iter()
|
||||
.map(TryInto::try_into)
|
||||
.collect::<Result<Vec<_>, String>>()?,
|
||||
));
|
||||
}
|
||||
if let Some(c) = self.all {
|
||||
return Ok(DomainRequestFilter::And(
|
||||
c.into_iter()
|
||||
.map(TryInto::try_into)
|
||||
.collect::<Result<Vec<_>, String>>()?,
|
||||
));
|
||||
}
|
||||
if let Some(c) = self.not {
|
||||
return Ok(DomainRequestFilter::Not(Box::new((*c).try_into()?)));
|
||||
}
|
||||
unreachable!();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Debug, GraphQLInputObject)]
|
||||
pub struct EqualityConstraint {
|
||||
field: String,
|
||||
value: String,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Debug)]
|
||||
/// The top-level GraphQL query type.
|
||||
pub struct Query<Handler: BackendHandler> {
|
||||
_phantom: std::marker::PhantomData<Box<Handler>>,
|
||||
}
|
||||
|
||||
impl<Handler: BackendHandler> Query<Handler> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[graphql_object(context = Context<Handler>)]
|
||||
impl<Handler: BackendHandler + Sync> Query<Handler> {
|
||||
fn api_version() -> &'static str {
|
||||
"1.0"
|
||||
}
|
||||
|
||||
pub async fn user(context: &Context<Handler>, user_id: String) -> FieldResult<User<Handler>> {
|
||||
if !context.validation_result.can_access(&user_id) {
|
||||
return Err("Unauthorized access to user data".into());
|
||||
}
|
||||
Ok(context
|
||||
.handler
|
||||
.get_user_details(&user_id)
|
||||
.await
|
||||
.map(Into::into)?)
|
||||
}
|
||||
|
||||
async fn users(
|
||||
context: &Context<Handler>,
|
||||
#[graphql(name = "where")] filters: Option<RequestFilter>,
|
||||
) -> FieldResult<Vec<User<Handler>>> {
|
||||
if !context.validation_result.is_admin {
|
||||
return Err("Unauthorized access to user list".into());
|
||||
}
|
||||
Ok(context
|
||||
.handler
|
||||
.list_users(filters.map(TryInto::try_into).transpose()?)
|
||||
.await
|
||||
.map(|v| v.into_iter().map(Into::into).collect())?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
|
||||
/// Represents a single user.
|
||||
pub struct User<Handler: BackendHandler> {
|
||||
user: DomainUser,
|
||||
_phantom: std::marker::PhantomData<Box<Handler>>,
|
||||
}
|
||||
|
||||
impl<Handler: BackendHandler> Default for User<Handler> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
user: DomainUser::default(),
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[graphql_object(context = Context<Handler>)]
|
||||
impl<Handler: BackendHandler + Sync> User<Handler> {
|
||||
fn id(&self) -> &str {
|
||||
&self.user.user_id
|
||||
}
|
||||
|
||||
fn email(&self) -> &str {
|
||||
&self.user.email
|
||||
}
|
||||
|
||||
fn display_name(&self) -> Option<&String> {
|
||||
self.user.display_name.as_ref()
|
||||
}
|
||||
|
||||
fn first_name(&self) -> Option<&String> {
|
||||
self.user.first_name.as_ref()
|
||||
}
|
||||
|
||||
fn last_name(&self) -> Option<&String> {
|
||||
self.user.last_name.as_ref()
|
||||
}
|
||||
|
||||
fn creation_date(&self) -> chrono::DateTime<chrono::Utc> {
|
||||
self.user.creation_date
|
||||
}
|
||||
|
||||
/// The groups to which this user belongs.
|
||||
async fn groups(&self, context: &Context<Handler>) -> FieldResult<Vec<Group<Handler>>> {
|
||||
Ok(context
|
||||
.handler
|
||||
.get_user_groups(&self.user.user_id)
|
||||
.await
|
||||
.map(|set| set.into_iter().map(Into::into).collect())?)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Handler: BackendHandler> From<DomainUser> for User<Handler> {
|
||||
fn from(user: DomainUser) -> Self {
|
||||
Self {
|
||||
user,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
|
||||
/// Represents a single group.
|
||||
pub struct Group<Handler: BackendHandler> {
|
||||
group_id: String,
|
||||
_phantom: std::marker::PhantomData<Box<Handler>>,
|
||||
}
|
||||
|
||||
#[graphql_object(context = Context<Handler>)]
|
||||
impl<Handler: BackendHandler + Sync> Group<Handler> {
|
||||
fn id(&self) -> String {
|
||||
self.group_id.clone()
|
||||
}
|
||||
/// The groups to which this user belongs.
|
||||
async fn users(&self, context: &Context<Handler>) -> FieldResult<Vec<User<Handler>>> {
|
||||
if !context.validation_result.is_admin {
|
||||
return Err("Unauthorized access to group data".into());
|
||||
}
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Handler: BackendHandler> From<String> for Group<Handler> {
|
||||
fn from(group_id: String) -> Self {
|
||||
Self {
|
||||
group_id,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{domain::handler::MockTestBackendHandler, infra::auth_service::ValidationResults};
|
||||
use juniper::{
|
||||
execute, graphql_value, DefaultScalarValue, EmptyMutation, EmptySubscription, GraphQLType,
|
||||
RootNode, Variables,
|
||||
};
|
||||
use mockall::predicate::eq;
|
||||
use std::collections::HashSet;
|
||||
|
||||
fn schema<'q, C, Q>(query_root: Q) -> RootNode<'q, Q, EmptyMutation<C>, EmptySubscription<C>>
|
||||
where
|
||||
Q: GraphQLType<DefaultScalarValue, Context = C, TypeInfo = ()> + 'q,
|
||||
{
|
||||
RootNode::new(
|
||||
query_root,
|
||||
EmptyMutation::<C>::new(),
|
||||
EmptySubscription::<C>::new(),
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn get_user_by_id() {
|
||||
const QUERY: &str = r#"{
|
||||
user(userId: "bob") {
|
||||
id
|
||||
email
|
||||
groups {
|
||||
id
|
||||
}
|
||||
}
|
||||
}"#;
|
||||
|
||||
let mut mock = MockTestBackendHandler::new();
|
||||
mock.expect_get_user_details()
|
||||
.with(eq("bob"))
|
||||
.return_once(|_| {
|
||||
Ok(DomainUser {
|
||||
user_id: "bob".to_string(),
|
||||
email: "bob@bobbers.on".to_string(),
|
||||
..Default::default()
|
||||
})
|
||||
});
|
||||
let mut groups = HashSet::<String>::new();
|
||||
groups.insert("Bobbersons".to_string());
|
||||
mock.expect_get_user_groups()
|
||||
.with(eq("bob"))
|
||||
.return_once(|_| Ok(groups));
|
||||
|
||||
let context = Context::<MockTestBackendHandler> {
|
||||
handler: Box::new(mock),
|
||||
validation_result: ValidationResults::admin(),
|
||||
};
|
||||
|
||||
let schema = schema(Query::<MockTestBackendHandler>::new());
|
||||
assert_eq!(
|
||||
execute(QUERY, None, &schema, &Variables::new(), &context).await,
|
||||
Ok((
|
||||
graphql_value!(
|
||||
{
|
||||
"user": {
|
||||
"id": "bob",
|
||||
"email": "bob@bobbers.on",
|
||||
"groups": [{"id": "Bobbersons"}]
|
||||
}
|
||||
}),
|
||||
vec![]
|
||||
))
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_users() {
|
||||
const QUERY: &str = r#"{
|
||||
users(filters: {
|
||||
any: [
|
||||
{eq: {
|
||||
field: "id"
|
||||
value: "bob"
|
||||
}},
|
||||
{eq: {
|
||||
field: "email"
|
||||
value: "robert@bobbers.on"
|
||||
}}
|
||||
]}) {
|
||||
id
|
||||
email
|
||||
}
|
||||
}"#;
|
||||
|
||||
let mut mock = MockTestBackendHandler::new();
|
||||
use crate::domain::handler::RequestFilter;
|
||||
mock.expect_list_users()
|
||||
.with(eq(Some(RequestFilter::Or(vec![
|
||||
RequestFilter::Equality("id".to_string(), "bob".to_string()),
|
||||
RequestFilter::Equality("email".to_string(), "robert@bobbers.on".to_string()),
|
||||
]))))
|
||||
.return_once(|_| {
|
||||
Ok(vec![
|
||||
DomainUser {
|
||||
user_id: "bob".to_string(),
|
||||
email: "bob@bobbers.on".to_string(),
|
||||
..Default::default()
|
||||
},
|
||||
DomainUser {
|
||||
user_id: "robert".to_string(),
|
||||
email: "robert@bobbers.on".to_string(),
|
||||
..Default::default()
|
||||
},
|
||||
])
|
||||
});
|
||||
|
||||
let context = Context::<MockTestBackendHandler> {
|
||||
handler: Box::new(mock),
|
||||
validation_result: ValidationResults::admin(),
|
||||
};
|
||||
|
||||
let schema = schema(Query::<MockTestBackendHandler>::new());
|
||||
assert_eq!(
|
||||
execute(QUERY, None, &schema, &Variables::new(), &context).await,
|
||||
Ok((
|
||||
graphql_value!(
|
||||
{
|
||||
"users": [
|
||||
{
|
||||
"id": "bob",
|
||||
"email": "bob@bobbers.on"
|
||||
},
|
||||
{
|
||||
"id": "robert",
|
||||
"email": "robert@bobbers.on"
|
||||
},
|
||||
]
|
||||
}),
|
||||
vec![]
|
||||
))
|
||||
);
|
||||
}
|
||||
}
|
||||
99
server/src/infra/jwt_sql_tables.rs
Normal file
99
server/src/infra/jwt_sql_tables.rs
Normal file
@@ -0,0 +1,99 @@
|
||||
use sea_query::*;
|
||||
|
||||
pub use crate::domain::sql_tables::*;
|
||||
|
||||
/// Contains the refresh tokens for a given user.
|
||||
#[derive(Iden)]
|
||||
pub enum JwtRefreshStorage {
|
||||
Table,
|
||||
RefreshTokenHash,
|
||||
UserId,
|
||||
ExpiryDate,
|
||||
}
|
||||
|
||||
/// Contains the blacklisted JWT that haven't expired yet.
|
||||
#[derive(Iden)]
|
||||
pub enum JwtStorage {
|
||||
Table,
|
||||
JwtHash,
|
||||
UserId,
|
||||
ExpiryDate,
|
||||
Blacklisted,
|
||||
}
|
||||
|
||||
/// This needs to be initialized after the domain tables are.
|
||||
pub async fn init_table(pool: &Pool) -> sqlx::Result<()> {
|
||||
sqlx::query(
|
||||
&Table::create()
|
||||
.table(JwtRefreshStorage::Table)
|
||||
.if_not_exists()
|
||||
.col(
|
||||
ColumnDef::new(JwtRefreshStorage::RefreshTokenHash)
|
||||
.big_integer()
|
||||
.not_null()
|
||||
.primary_key(),
|
||||
)
|
||||
.col(
|
||||
ColumnDef::new(JwtRefreshStorage::UserId)
|
||||
.string_len(255)
|
||||
.not_null(),
|
||||
)
|
||||
.col(
|
||||
ColumnDef::new(JwtRefreshStorage::ExpiryDate)
|
||||
.date_time()
|
||||
.not_null(),
|
||||
)
|
||||
.foreign_key(
|
||||
ForeignKey::create()
|
||||
.name("JwtRefreshStorageUserForeignKey")
|
||||
.table(JwtRefreshStorage::Table, Users::Table)
|
||||
.col(JwtRefreshStorage::UserId, Users::UserId)
|
||||
.on_delete(ForeignKeyAction::Cascade)
|
||||
.on_update(ForeignKeyAction::Cascade),
|
||||
)
|
||||
.to_string(DbQueryBuilder {}),
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
&Table::create()
|
||||
.table(JwtStorage::Table)
|
||||
.if_not_exists()
|
||||
.col(
|
||||
ColumnDef::new(JwtStorage::JwtHash)
|
||||
.big_integer()
|
||||
.not_null()
|
||||
.primary_key(),
|
||||
)
|
||||
.col(
|
||||
ColumnDef::new(JwtStorage::UserId)
|
||||
.string_len(255)
|
||||
.not_null(),
|
||||
)
|
||||
.col(
|
||||
ColumnDef::new(JwtStorage::ExpiryDate)
|
||||
.date_time()
|
||||
.not_null(),
|
||||
)
|
||||
.col(
|
||||
ColumnDef::new(JwtStorage::Blacklisted)
|
||||
.boolean()
|
||||
.default(false)
|
||||
.not_null(),
|
||||
)
|
||||
.foreign_key(
|
||||
ForeignKey::create()
|
||||
.name("JwtStorageUserForeignKey")
|
||||
.table(JwtStorage::Table, Users::Table)
|
||||
.col(JwtStorage::UserId, Users::UserId)
|
||||
.on_delete(ForeignKeyAction::Cascade)
|
||||
.on_update(ForeignKeyAction::Cascade),
|
||||
)
|
||||
.to_string(DbQueryBuilder {}),
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
632
server/src/infra/ldap_handler.rs
Normal file
632
server/src/infra/ldap_handler.rs
Normal file
@@ -0,0 +1,632 @@
|
||||
use crate::domain::handler::{BackendHandler, LoginHandler, RequestFilter, User};
|
||||
use anyhow::{bail, Result};
|
||||
use ldap3_server::simple::*;
|
||||
|
||||
fn make_dn_pair<I>(mut iter: I) -> Result<(String, String)>
|
||||
where
|
||||
I: Iterator<Item = String>,
|
||||
{
|
||||
let pair = (
|
||||
iter.next()
|
||||
.ok_or_else(|| anyhow::Error::msg("Empty DN element"))?,
|
||||
iter.next()
|
||||
.ok_or_else(|| anyhow::Error::msg("Missing DN value"))?,
|
||||
);
|
||||
if let Some(e) = iter.next() {
|
||||
bail!(
|
||||
r#"Too many elements in distinguished name: "{:?}", "{:?}", "{:?}""#,
|
||||
pair.0,
|
||||
pair.1,
|
||||
e
|
||||
)
|
||||
}
|
||||
Ok(pair)
|
||||
}
|
||||
|
||||
fn parse_distinguished_name(dn: &str) -> Result<Vec<(String, String)>> {
|
||||
dn.split(',')
|
||||
.map(|s| make_dn_pair(s.split('=').map(String::from)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn get_user_id_from_distinguished_name(
|
||||
dn: &str,
|
||||
base_tree: &[(String, String)],
|
||||
base_dn_str: &str,
|
||||
ldap_user_dn: &str,
|
||||
) -> Result<String> {
|
||||
let parts = parse_distinguished_name(dn)?;
|
||||
if !is_subtree(&parts, base_tree) {
|
||||
bail!("Not a subtree of the base tree");
|
||||
}
|
||||
if parts.len() == base_tree.len() + 1 {
|
||||
if dn != ldap_user_dn {
|
||||
bail!(r#"Wrong admin DN. Expected: "{}""#, ldap_user_dn);
|
||||
}
|
||||
Ok(parts[0].1.to_string())
|
||||
} else if parts.len() == base_tree.len() + 2 {
|
||||
if parts[1].0 != "ou" || parts[1].1 != "people" || parts[0].0 != "cn" {
|
||||
bail!(
|
||||
r#"Unexpected user DN format. Expected: "cn=username,ou=people,{}""#,
|
||||
base_dn_str
|
||||
);
|
||||
}
|
||||
Ok(parts[0].1.to_string())
|
||||
} else {
|
||||
bail!(
|
||||
r#"Unexpected user DN format. Expected: "cn=username,ou=people,{}""#,
|
||||
base_dn_str
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn get_attribute(user: &User, attribute: &str) -> Result<Vec<String>> {
|
||||
match attribute {
|
||||
"objectClass" => Ok(vec![
|
||||
"inetOrgPerson".to_string(),
|
||||
"posixAccount".to_string(),
|
||||
"mailAccount".to_string(),
|
||||
]),
|
||||
"uid" => Ok(vec![user.user_id.clone()]),
|
||||
"mail" => Ok(vec![user.email.clone()]),
|
||||
"givenName" => Ok(vec![user.first_name.clone().unwrap_or_default()]),
|
||||
"sn" => Ok(vec![user.last_name.clone().unwrap_or_default()]),
|
||||
"cn" => Ok(vec![user
|
||||
.display_name
|
||||
.clone()
|
||||
.unwrap_or_else(|| user.user_id.clone())]),
|
||||
_ => bail!("Unsupported attribute: {}", attribute),
|
||||
}
|
||||
}
|
||||
|
||||
fn make_ldap_search_result_entry(
|
||||
user: User,
|
||||
base_dn_str: &str,
|
||||
attributes: &[String],
|
||||
) -> Result<LdapSearchResultEntry> {
|
||||
Ok(LdapSearchResultEntry {
|
||||
dn: format!("cn={},{}", user.user_id, base_dn_str),
|
||||
attributes: attributes
|
||||
.iter()
|
||||
.map(|a| {
|
||||
Ok(LdapPartialAttribute {
|
||||
atype: a.to_string(),
|
||||
vals: get_attribute(&user, a)?,
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<LdapPartialAttribute>>>()?,
|
||||
})
|
||||
}
|
||||
|
||||
fn is_subtree(subtree: &[(String, String)], base_tree: &[(String, String)]) -> bool {
|
||||
if subtree.len() < base_tree.len() {
|
||||
return false;
|
||||
}
|
||||
let size_diff = subtree.len() - base_tree.len();
|
||||
for i in 0..base_tree.len() {
|
||||
if subtree[size_diff + i] != base_tree[i] {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
fn map_field(field: &str) -> Result<String> {
|
||||
Ok(if field == "uid" {
|
||||
"user_id".to_string()
|
||||
} else if field == "mail" {
|
||||
"email".to_string()
|
||||
} else if field == "cn" {
|
||||
"display_name".to_string()
|
||||
} else if field == "givenName" {
|
||||
"first_name".to_string()
|
||||
} else if field == "sn" {
|
||||
"last_name".to_string()
|
||||
} else if field == "avatar" {
|
||||
"avatar".to_string()
|
||||
} else if field == "creationDate" {
|
||||
"creation_date".to_string()
|
||||
} else {
|
||||
bail!("Unknown field: {}", field);
|
||||
})
|
||||
}
|
||||
|
||||
fn convert_filter(filter: &LdapFilter) -> Result<RequestFilter> {
|
||||
match filter {
|
||||
LdapFilter::And(filters) => Ok(RequestFilter::And(
|
||||
filters.iter().map(convert_filter).collect::<Result<_>>()?,
|
||||
)),
|
||||
LdapFilter::Or(filters) => Ok(RequestFilter::Or(
|
||||
filters.iter().map(convert_filter).collect::<Result<_>>()?,
|
||||
)),
|
||||
LdapFilter::Not(filter) => Ok(RequestFilter::Not(Box::new(convert_filter(&*filter)?))),
|
||||
LdapFilter::Equality(field, value) => {
|
||||
Ok(RequestFilter::Equality(map_field(field)?, value.clone()))
|
||||
}
|
||||
_ => bail!("Unsupported filter"),
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LdapHandler<Backend: BackendHandler + LoginHandler> {
|
||||
dn: String,
|
||||
backend_handler: Backend,
|
||||
pub base_dn: Vec<(String, String)>,
|
||||
base_dn_str: String,
|
||||
ldap_user_dn: String,
|
||||
}
|
||||
|
||||
impl<Backend: BackendHandler + LoginHandler> LdapHandler<Backend> {
|
||||
pub fn new(backend_handler: Backend, ldap_base_dn: String, ldap_user_dn: String) -> Self {
|
||||
Self {
|
||||
dn: "Unauthenticated".to_string(),
|
||||
backend_handler,
|
||||
base_dn: parse_distinguished_name(&ldap_base_dn).unwrap_or_else(|_| {
|
||||
panic!(
|
||||
"Invalid value for ldap_base_dn in configuration: {}",
|
||||
ldap_base_dn
|
||||
)
|
||||
}),
|
||||
ldap_user_dn: format!("cn={},{}", ldap_user_dn, &ldap_base_dn),
|
||||
base_dn_str: ldap_base_dn,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn do_bind(&mut self, sbr: &SimpleBindRequest) -> LdapMsg {
|
||||
let user_id = match get_user_id_from_distinguished_name(
|
||||
&sbr.dn,
|
||||
&self.base_dn,
|
||||
&self.base_dn_str,
|
||||
&self.ldap_user_dn,
|
||||
) {
|
||||
Ok(s) => s,
|
||||
Err(e) => return sbr.gen_error(LdapResultCode::NamingViolation, e.to_string()),
|
||||
};
|
||||
match self
|
||||
.backend_handler
|
||||
.bind(crate::domain::handler::BindRequest {
|
||||
name: user_id,
|
||||
password: sbr.pw.clone(),
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
self.dn = sbr.dn.clone();
|
||||
sbr.gen_success()
|
||||
}
|
||||
Err(_) => sbr.gen_invalid_cred(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn do_search(&mut self, lsr: &SearchRequest) -> Vec<LdapMsg> {
|
||||
if self.dn != self.ldap_user_dn {
|
||||
return vec![lsr.gen_error(
|
||||
LdapResultCode::InsufficentAccessRights,
|
||||
r#"Current user is not allowed to query LDAP"#.to_string(),
|
||||
)];
|
||||
}
|
||||
let dn_parts = match parse_distinguished_name(&lsr.base) {
|
||||
Ok(dn) => dn,
|
||||
Err(_) => {
|
||||
return vec![lsr.gen_error(
|
||||
LdapResultCode::OperationsError,
|
||||
format!(r#"Could not parse base DN: "{}""#, lsr.base),
|
||||
)]
|
||||
}
|
||||
};
|
||||
if !is_subtree(&dn_parts, &self.base_dn) {
|
||||
// Search path is not in our tree, just return an empty success.
|
||||
return vec![lsr.gen_success()];
|
||||
}
|
||||
let filters = match convert_filter(&lsr.filter) {
|
||||
Ok(f) => Some(f),
|
||||
Err(_) => {
|
||||
return vec![lsr.gen_error(
|
||||
LdapResultCode::UnwillingToPerform,
|
||||
"Unsupported filter".to_string(),
|
||||
)]
|
||||
}
|
||||
};
|
||||
let users = match self.backend_handler.list_users(filters).await {
|
||||
Ok(users) => users,
|
||||
Err(e) => {
|
||||
return vec![lsr.gen_error(
|
||||
LdapResultCode::Other,
|
||||
format!(r#"Error during search for "{}": {}"#, lsr.base, e),
|
||||
)]
|
||||
}
|
||||
};
|
||||
|
||||
users
|
||||
.into_iter()
|
||||
.map(|u| make_ldap_search_result_entry(u, &self.base_dn_str, &lsr.attrs))
|
||||
.map(|entry| Ok(lsr.gen_result_entry(entry?)))
|
||||
// If the processing succeeds, add a success message at the end.
|
||||
.chain(std::iter::once(Ok(lsr.gen_success())))
|
||||
.collect::<Result<Vec<_>>>()
|
||||
.unwrap_or_else(|e| vec![lsr.gen_error(LdapResultCode::NoSuchAttribute, e.to_string())])
|
||||
}
|
||||
|
||||
pub fn do_whoami(&mut self, wr: &WhoamiRequest) -> LdapMsg {
|
||||
if self.dn == "Unauthenticated" {
|
||||
wr.gen_operror("Unauthenticated")
|
||||
} else {
|
||||
wr.gen_success(format!("dn: {}", self.dn).as_str())
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_ldap_message(&mut self, server_op: ServerOps) -> Option<Vec<LdapMsg>> {
|
||||
let result = match server_op {
|
||||
ServerOps::SimpleBind(sbr) => vec![self.do_bind(&sbr).await],
|
||||
ServerOps::Search(sr) => self.do_search(&sr).await,
|
||||
ServerOps::Unbind(_) => {
|
||||
// No need to notify on unbind (per rfc4511)
|
||||
return None;
|
||||
}
|
||||
ServerOps::Whoami(wr) => vec![self.do_whoami(&wr)],
|
||||
};
|
||||
Some(result)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::domain::handler::BindRequest;
|
||||
use crate::domain::handler::MockTestBackendHandler;
|
||||
use mockall::predicate::eq;
|
||||
use tokio;
|
||||
|
||||
async fn setup_bound_handler(
|
||||
mut mock: MockTestBackendHandler,
|
||||
) -> LdapHandler<MockTestBackendHandler> {
|
||||
mock.expect_bind()
|
||||
.with(eq(BindRequest {
|
||||
name: "test".to_string(),
|
||||
password: "pass".to_string(),
|
||||
}))
|
||||
.return_once(|_| Ok(()));
|
||||
let mut ldap_handler =
|
||||
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string());
|
||||
let request = SimpleBindRequest {
|
||||
msgid: 1,
|
||||
dn: "cn=test,dc=example,dc=com".to_string(),
|
||||
pw: "pass".to_string(),
|
||||
};
|
||||
ldap_handler.do_bind(&request).await;
|
||||
ldap_handler
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_bind() {
|
||||
let mut mock = MockTestBackendHandler::new();
|
||||
mock.expect_bind()
|
||||
.with(eq(crate::domain::handler::BindRequest {
|
||||
name: "bob".to_string(),
|
||||
password: "pass".to_string(),
|
||||
}))
|
||||
.times(1)
|
||||
.return_once(|_| Ok(()));
|
||||
let mut ldap_handler =
|
||||
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string());
|
||||
|
||||
let request = WhoamiRequest { msgid: 1 };
|
||||
assert_eq!(
|
||||
ldap_handler.do_whoami(&request),
|
||||
request.gen_operror("Unauthenticated")
|
||||
);
|
||||
|
||||
let request = SimpleBindRequest {
|
||||
msgid: 2,
|
||||
dn: "cn=bob,ou=people,dc=example,dc=com".to_string(),
|
||||
pw: "pass".to_string(),
|
||||
};
|
||||
assert_eq!(ldap_handler.do_bind(&request).await, request.gen_success());
|
||||
|
||||
let request = WhoamiRequest { msgid: 3 };
|
||||
assert_eq!(
|
||||
ldap_handler.do_whoami(&request),
|
||||
request.gen_success("dn: cn=bob,ou=people,dc=example,dc=com")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_admin_bind() {
|
||||
let mut mock = MockTestBackendHandler::new();
|
||||
mock.expect_bind()
|
||||
.with(eq(crate::domain::handler::BindRequest {
|
||||
name: "test".to_string(),
|
||||
password: "pass".to_string(),
|
||||
}))
|
||||
.times(1)
|
||||
.return_once(|_| Ok(()));
|
||||
let mut ldap_handler =
|
||||
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string());
|
||||
|
||||
let request = WhoamiRequest { msgid: 1 };
|
||||
assert_eq!(
|
||||
ldap_handler.do_whoami(&request),
|
||||
request.gen_operror("Unauthenticated")
|
||||
);
|
||||
|
||||
let request = SimpleBindRequest {
|
||||
msgid: 2,
|
||||
dn: "cn=test,dc=example,dc=com".to_string(),
|
||||
pw: "pass".to_string(),
|
||||
};
|
||||
assert_eq!(ldap_handler.do_bind(&request).await, request.gen_success());
|
||||
|
||||
let request = WhoamiRequest { msgid: 3 };
|
||||
assert_eq!(
|
||||
ldap_handler.do_whoami(&request),
|
||||
request.gen_success("dn: cn=test,dc=example,dc=com")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_bind_invalid_credentials() {
|
||||
let mut mock = MockTestBackendHandler::new();
|
||||
mock.expect_bind()
|
||||
.with(eq(crate::domain::handler::BindRequest {
|
||||
name: "test".to_string(),
|
||||
password: "pass".to_string(),
|
||||
}))
|
||||
.times(1)
|
||||
.return_once(|_| Ok(()));
|
||||
let mut ldap_handler =
|
||||
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "admin".to_string());
|
||||
|
||||
let request = WhoamiRequest { msgid: 1 };
|
||||
assert_eq!(
|
||||
ldap_handler.do_whoami(&request),
|
||||
request.gen_operror("Unauthenticated")
|
||||
);
|
||||
|
||||
let request = SimpleBindRequest {
|
||||
msgid: 2,
|
||||
dn: "cn=test,ou=people,dc=example,dc=com".to_string(),
|
||||
pw: "pass".to_string(),
|
||||
};
|
||||
assert_eq!(ldap_handler.do_bind(&request).await, request.gen_success());
|
||||
|
||||
let request = WhoamiRequest { msgid: 3 };
|
||||
assert_eq!(
|
||||
ldap_handler.do_whoami(&request),
|
||||
request.gen_success("dn: cn=test,ou=people,dc=example,dc=com")
|
||||
);
|
||||
|
||||
let request = SearchRequest {
|
||||
msgid: 2,
|
||||
base: "ou=people,dc=example,dc=com".to_string(),
|
||||
scope: LdapSearchScope::Base,
|
||||
filter: LdapFilter::And(vec![]),
|
||||
attrs: vec![],
|
||||
};
|
||||
assert_eq!(
|
||||
ldap_handler.do_search(&request).await,
|
||||
vec![request.gen_error(
|
||||
LdapResultCode::InsufficentAccessRights,
|
||||
r#"Current user is not allowed to query LDAP"#.to_string()
|
||||
)]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_bind_invalid_dn() {
|
||||
let mock = MockTestBackendHandler::new();
|
||||
let mut ldap_handler =
|
||||
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "admin".to_string());
|
||||
|
||||
let request = SimpleBindRequest {
|
||||
msgid: 2,
|
||||
dn: "cn=bob,dc=example,dc=com".to_string(),
|
||||
pw: "pass".to_string(),
|
||||
};
|
||||
assert_eq!(
|
||||
ldap_handler.do_bind(&request).await,
|
||||
request.gen_error(
|
||||
LdapResultCode::NamingViolation,
|
||||
r#"Wrong admin DN. Expected: "cn=admin,dc=example,dc=com""#.to_string()
|
||||
)
|
||||
);
|
||||
let request = SimpleBindRequest {
|
||||
msgid: 2,
|
||||
dn: "cn=bob,ou=groups,dc=example,dc=com".to_string(),
|
||||
pw: "pass".to_string(),
|
||||
};
|
||||
assert_eq!(
|
||||
ldap_handler.do_bind(&request).await,
|
||||
request.gen_error(
|
||||
LdapResultCode::NamingViolation,
|
||||
r#"Unexpected user DN format. Expected: "cn=username,ou=people,dc=example,dc=com""#
|
||||
.to_string()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_subtree() {
|
||||
let subtree1 = &[
|
||||
("ou".to_string(), "people".to_string()),
|
||||
("dc".to_string(), "example".to_string()),
|
||||
("dc".to_string(), "com".to_string()),
|
||||
];
|
||||
let root = &[
|
||||
("dc".to_string(), "example".to_string()),
|
||||
("dc".to_string(), "com".to_string()),
|
||||
];
|
||||
assert!(is_subtree(subtree1, root));
|
||||
assert!(!is_subtree(&[], root));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_distinguished_name() {
|
||||
let parsed_dn = &[
|
||||
("ou".to_string(), "people".to_string()),
|
||||
("dc".to_string(), "example".to_string()),
|
||||
("dc".to_string(), "com".to_string()),
|
||||
];
|
||||
assert_eq!(
|
||||
parse_distinguished_name("ou=people,dc=example,dc=com").expect("parsing failed"),
|
||||
parsed_dn
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_search() {
|
||||
let mut mock = MockTestBackendHandler::new();
|
||||
mock.expect_list_users().times(1).return_once(|_| {
|
||||
Ok(vec![
|
||||
User {
|
||||
user_id: "bob_1".to_string(),
|
||||
email: "bob@bobmail.bob".to_string(),
|
||||
display_name: Some("Bôb Böbberson".to_string()),
|
||||
first_name: Some("Bôb".to_string()),
|
||||
last_name: Some("Böbberson".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
User {
|
||||
user_id: "jim".to_string(),
|
||||
email: "jim@cricket.jim".to_string(),
|
||||
display_name: Some("Jimminy Cricket".to_string()),
|
||||
first_name: Some("Jim".to_string()),
|
||||
last_name: Some("Cricket".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
])
|
||||
});
|
||||
let mut ldap_handler = setup_bound_handler(mock).await;
|
||||
let request = SearchRequest {
|
||||
msgid: 2,
|
||||
base: "ou=people,dc=example,dc=com".to_string(),
|
||||
scope: LdapSearchScope::Base,
|
||||
filter: LdapFilter::And(vec![]),
|
||||
attrs: vec![
|
||||
"objectClass".to_string(),
|
||||
"uid".to_string(),
|
||||
"mail".to_string(),
|
||||
"givenName".to_string(),
|
||||
"sn".to_string(),
|
||||
"cn".to_string(),
|
||||
],
|
||||
};
|
||||
assert_eq!(
|
||||
ldap_handler.do_search(&request).await,
|
||||
vec![
|
||||
request.gen_result_entry(LdapSearchResultEntry {
|
||||
dn: "cn=bob_1,dc=example,dc=com".to_string(),
|
||||
attributes: vec![
|
||||
LdapPartialAttribute {
|
||||
atype: "objectClass".to_string(),
|
||||
vals: vec![
|
||||
"inetOrgPerson".to_string(),
|
||||
"posixAccount".to_string(),
|
||||
"mailAccount".to_string()
|
||||
]
|
||||
},
|
||||
LdapPartialAttribute {
|
||||
atype: "uid".to_string(),
|
||||
vals: vec!["bob_1".to_string()]
|
||||
},
|
||||
LdapPartialAttribute {
|
||||
atype: "mail".to_string(),
|
||||
vals: vec!["bob@bobmail.bob".to_string()]
|
||||
},
|
||||
LdapPartialAttribute {
|
||||
atype: "givenName".to_string(),
|
||||
vals: vec!["Bôb".to_string()]
|
||||
},
|
||||
LdapPartialAttribute {
|
||||
atype: "sn".to_string(),
|
||||
vals: vec!["Böbberson".to_string()]
|
||||
},
|
||||
LdapPartialAttribute {
|
||||
atype: "cn".to_string(),
|
||||
vals: vec!["Bôb Böbberson".to_string()]
|
||||
}
|
||||
],
|
||||
}),
|
||||
request.gen_result_entry(LdapSearchResultEntry {
|
||||
dn: "cn=jim,dc=example,dc=com".to_string(),
|
||||
attributes: vec![
|
||||
LdapPartialAttribute {
|
||||
atype: "objectClass".to_string(),
|
||||
vals: vec![
|
||||
"inetOrgPerson".to_string(),
|
||||
"posixAccount".to_string(),
|
||||
"mailAccount".to_string()
|
||||
]
|
||||
},
|
||||
LdapPartialAttribute {
|
||||
atype: "uid".to_string(),
|
||||
vals: vec!["jim".to_string()]
|
||||
},
|
||||
LdapPartialAttribute {
|
||||
atype: "mail".to_string(),
|
||||
vals: vec!["jim@cricket.jim".to_string()]
|
||||
},
|
||||
LdapPartialAttribute {
|
||||
atype: "givenName".to_string(),
|
||||
vals: vec!["Jim".to_string()]
|
||||
},
|
||||
LdapPartialAttribute {
|
||||
atype: "sn".to_string(),
|
||||
vals: vec!["Cricket".to_string()]
|
||||
},
|
||||
LdapPartialAttribute {
|
||||
atype: "cn".to_string(),
|
||||
vals: vec!["Jimminy Cricket".to_string()]
|
||||
}
|
||||
],
|
||||
}),
|
||||
request.gen_success()
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_search_filters() {
|
||||
let mut mock = MockTestBackendHandler::new();
|
||||
mock.expect_list_users()
|
||||
.with(eq(Some(RequestFilter::And(vec![RequestFilter::Or(vec![
|
||||
RequestFilter::Not(Box::new(RequestFilter::Equality(
|
||||
"user_id".to_string(),
|
||||
"bob".to_string(),
|
||||
))),
|
||||
])]))))
|
||||
.times(1)
|
||||
.return_once(|_| Ok(vec![]));
|
||||
let mut ldap_handler = setup_bound_handler(mock).await;
|
||||
let request = SearchRequest {
|
||||
msgid: 2,
|
||||
base: "ou=people,dc=example,dc=com".to_string(),
|
||||
scope: LdapSearchScope::Base,
|
||||
filter: LdapFilter::And(vec![LdapFilter::Or(vec![LdapFilter::Not(Box::new(
|
||||
LdapFilter::Equality("uid".to_string(), "bob".to_string()),
|
||||
))])]),
|
||||
attrs: vec!["objectClass".to_string()],
|
||||
};
|
||||
assert_eq!(
|
||||
ldap_handler.do_search(&request).await,
|
||||
vec![request.gen_success()]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_search_unsupported_filters() {
|
||||
let mut ldap_handler = setup_bound_handler(MockTestBackendHandler::new()).await;
|
||||
let request = SearchRequest {
|
||||
msgid: 2,
|
||||
base: "ou=people,dc=example,dc=com".to_string(),
|
||||
scope: LdapSearchScope::Base,
|
||||
filter: LdapFilter::Present("uid".to_string()),
|
||||
attrs: vec!["objectClass".to_string()],
|
||||
};
|
||||
assert_eq!(
|
||||
ldap_handler.do_search(&request).await,
|
||||
vec![request.gen_error(
|
||||
LdapResultCode::UnwillingToPerform,
|
||||
"Unsupported filter".to_string()
|
||||
)]
|
||||
);
|
||||
}
|
||||
}
|
||||
102
server/src/infra/ldap_server.rs
Normal file
102
server/src/infra/ldap_server.rs
Normal file
@@ -0,0 +1,102 @@
|
||||
use crate::domain::handler::{BackendHandler, LoginHandler};
|
||||
use crate::infra::configuration::Configuration;
|
||||
use crate::infra::ldap_handler::LdapHandler;
|
||||
use actix_rt::net::TcpStream;
|
||||
use actix_server::ServerBuilder;
|
||||
use actix_service::{fn_service, ServiceFactoryExt};
|
||||
use anyhow::{bail, Result};
|
||||
use futures_util::future::ok;
|
||||
use ldap3_server::simple::*;
|
||||
use ldap3_server::LdapCodec;
|
||||
use log::*;
|
||||
use tokio::net::tcp::WriteHalf;
|
||||
use tokio_util::codec::{FramedRead, FramedWrite};
|
||||
|
||||
async fn handle_incoming_message<Backend>(
|
||||
msg: Result<LdapMsg, std::io::Error>,
|
||||
resp: &mut FramedWrite<WriteHalf<'_>, LdapCodec>,
|
||||
session: &mut LdapHandler<Backend>,
|
||||
) -> Result<bool>
|
||||
where
|
||||
Backend: BackendHandler + LoginHandler,
|
||||
{
|
||||
use futures_util::SinkExt;
|
||||
use std::convert::TryFrom;
|
||||
let server_op = match msg.map_err(|_e| ()).and_then(ServerOps::try_from) {
|
||||
Ok(a_value) => a_value,
|
||||
Err(an_error) => {
|
||||
let _err = resp
|
||||
.send(DisconnectionNotice::gen(
|
||||
LdapResultCode::Other,
|
||||
"Internal Server Error",
|
||||
))
|
||||
.await;
|
||||
let _err = resp.flush().await;
|
||||
bail!("Internal server error: {:?}", an_error);
|
||||
}
|
||||
};
|
||||
|
||||
match session.handle_ldap_message(server_op).await {
|
||||
None => return Ok(false),
|
||||
Some(result) => {
|
||||
for rmsg in result.into_iter() {
|
||||
if let Err(e) = resp.send(rmsg).await {
|
||||
bail!("Error while sending a response: {:?}", e);
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(e) = resp.flush().await {
|
||||
bail!("Error while flushing responses: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
pub fn build_ldap_server<Backend>(
|
||||
config: &Configuration,
|
||||
backend_handler: Backend,
|
||||
server_builder: ServerBuilder,
|
||||
) -> Result<ServerBuilder>
|
||||
where
|
||||
Backend: BackendHandler + LoginHandler + 'static,
|
||||
{
|
||||
use futures_util::StreamExt;
|
||||
|
||||
let ldap_base_dn = config.ldap_base_dn.clone();
|
||||
let ldap_user_dn = config.ldap_user_dn.clone();
|
||||
Ok(
|
||||
server_builder.bind("ldap", ("0.0.0.0", config.ldap_port), move || {
|
||||
let backend_handler = backend_handler.clone();
|
||||
let ldap_base_dn = ldap_base_dn.clone();
|
||||
let ldap_user_dn = ldap_user_dn.clone();
|
||||
fn_service(move |mut stream: TcpStream| {
|
||||
let backend_handler = backend_handler.clone();
|
||||
let ldap_base_dn = ldap_base_dn.clone();
|
||||
let ldap_user_dn = ldap_user_dn.clone();
|
||||
async move {
|
||||
// Configure the codec etc.
|
||||
let (r, w) = stream.split();
|
||||
let mut requests = FramedRead::new(r, LdapCodec);
|
||||
let mut resp = FramedWrite::new(w, LdapCodec);
|
||||
|
||||
let mut session = LdapHandler::new(backend_handler, ldap_base_dn, ldap_user_dn);
|
||||
|
||||
while let Some(msg) = requests.next().await {
|
||||
if !handle_incoming_message(msg, &mut resp, &mut session).await? {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
})
|
||||
.map_err(|err: anyhow::Error| error!("Service Error: {:?}", err))
|
||||
// catch
|
||||
.and_then(move |_| {
|
||||
// finally
|
||||
ok(())
|
||||
})
|
||||
})?,
|
||||
)
|
||||
}
|
||||
25
server/src/infra/logging.rs
Normal file
25
server/src/infra/logging.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
use crate::infra::configuration::Configuration;
|
||||
use anyhow::Context;
|
||||
use tracing::subscriber::set_global_default;
|
||||
use tracing_log::LogTracer;
|
||||
|
||||
pub fn init(config: Configuration) -> anyhow::Result<()> {
|
||||
let max_log_level = log_level_from_config(config);
|
||||
let subscriber = tracing_subscriber::fmt()
|
||||
.with_timer(tracing_subscriber::fmt::time::time())
|
||||
.with_target(false)
|
||||
.with_level(true)
|
||||
.with_max_level(max_log_level)
|
||||
.finish();
|
||||
LogTracer::init().context("Failed to set logger")?;
|
||||
set_global_default(subscriber).context("Failed to set subscriber")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn log_level_from_config(config: Configuration) -> tracing::Level {
|
||||
if config.verbose {
|
||||
tracing::Level::DEBUG
|
||||
} else {
|
||||
tracing::Level::INFO
|
||||
}
|
||||
}
|
||||
12
server/src/infra/mod.rs
Normal file
12
server/src/infra/mod.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
pub mod auth_service;
|
||||
pub mod cli;
|
||||
pub mod configuration;
|
||||
pub mod db_cleaner;
|
||||
pub mod graphql;
|
||||
pub mod jwt_sql_tables;
|
||||
pub mod ldap_handler;
|
||||
pub mod ldap_server;
|
||||
pub mod logging;
|
||||
pub mod sql_backend_handler;
|
||||
pub mod tcp_backend_handler;
|
||||
pub mod tcp_server;
|
||||
105
server/src/infra/sql_backend_handler.rs
Normal file
105
server/src/infra/sql_backend_handler.rs
Normal file
@@ -0,0 +1,105 @@
|
||||
use super::{jwt_sql_tables::*, tcp_backend_handler::*};
|
||||
use crate::domain::{error::*, sql_backend_handler::SqlBackendHandler};
|
||||
use async_trait::async_trait;
|
||||
use futures_util::StreamExt;
|
||||
use sea_query::{Expr, Iden, Query, SimpleExpr};
|
||||
use sqlx::Row;
|
||||
use std::collections::HashSet;
|
||||
|
||||
#[async_trait]
|
||||
impl TcpBackendHandler for SqlBackendHandler {
|
||||
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>> {
|
||||
use sqlx::Result;
|
||||
let query = Query::select()
|
||||
.column(JwtStorage::JwtHash)
|
||||
.from(JwtStorage::Table)
|
||||
.to_string(DbQueryBuilder {});
|
||||
|
||||
sqlx::query(&query)
|
||||
.map(|row: DbRow| row.get::<i64, _>(&*JwtStorage::JwtHash.to_string()) as u64)
|
||||
.fetch(&self.sql_pool)
|
||||
.collect::<Vec<sqlx::Result<u64>>>()
|
||||
.await
|
||||
.into_iter()
|
||||
.collect::<Result<HashSet<u64>>>()
|
||||
.map_err(|e| anyhow::anyhow!(e))
|
||||
}
|
||||
|
||||
async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)> {
|
||||
use rand::{distributions::Alphanumeric, rngs::SmallRng, Rng, SeedableRng};
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
// TODO: Initialize the rng only once. Maybe Arc<Cell>?
|
||||
let mut rng = SmallRng::from_entropy();
|
||||
let refresh_token: String = std::iter::repeat(())
|
||||
.map(|()| rng.sample(Alphanumeric))
|
||||
.map(char::from)
|
||||
.take(100)
|
||||
.collect();
|
||||
let refresh_token_hash = {
|
||||
let mut s = DefaultHasher::new();
|
||||
refresh_token.hash(&mut s);
|
||||
s.finish()
|
||||
};
|
||||
let duration = chrono::Duration::days(30);
|
||||
let query = Query::insert()
|
||||
.into_table(JwtRefreshStorage::Table)
|
||||
.columns(vec![
|
||||
JwtRefreshStorage::RefreshTokenHash,
|
||||
JwtRefreshStorage::UserId,
|
||||
JwtRefreshStorage::ExpiryDate,
|
||||
])
|
||||
.values_panic(vec![
|
||||
(refresh_token_hash as i64).into(),
|
||||
user.into(),
|
||||
(chrono::Utc::now() + duration).naive_utc().into(),
|
||||
])
|
||||
.to_string(DbQueryBuilder {});
|
||||
sqlx::query(&query).execute(&self.sql_pool).await?;
|
||||
Ok((refresh_token, duration))
|
||||
}
|
||||
|
||||
async fn check_token(&self, refresh_token_hash: u64, user: &str) -> Result<bool> {
|
||||
let query = Query::select()
|
||||
.expr(SimpleExpr::Value(1.into()))
|
||||
.from(JwtRefreshStorage::Table)
|
||||
.and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash as i64))
|
||||
.and_where(Expr::col(JwtRefreshStorage::UserId).eq(user))
|
||||
.to_string(DbQueryBuilder {});
|
||||
Ok(sqlx::query(&query)
|
||||
.fetch_optional(&self.sql_pool)
|
||||
.await?
|
||||
.is_some())
|
||||
}
|
||||
async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>> {
|
||||
use sqlx::Result;
|
||||
let query = Query::select()
|
||||
.column(JwtStorage::JwtHash)
|
||||
.from(JwtStorage::Table)
|
||||
.and_where(Expr::col(JwtStorage::UserId).eq(user))
|
||||
.and_where(Expr::col(JwtStorage::Blacklisted).eq(true))
|
||||
.to_string(DbQueryBuilder {});
|
||||
let result = sqlx::query(&query)
|
||||
.map(|row: DbRow| row.get::<i64, _>(&*JwtStorage::JwtHash.to_string()) as u64)
|
||||
.fetch(&self.sql_pool)
|
||||
.collect::<Vec<sqlx::Result<u64>>>()
|
||||
.await
|
||||
.into_iter()
|
||||
.collect::<Result<HashSet<u64>>>();
|
||||
let query = Query::update()
|
||||
.table(JwtStorage::Table)
|
||||
.values(vec![(JwtStorage::Blacklisted, true.into())])
|
||||
.and_where(Expr::col(JwtStorage::UserId).eq(user))
|
||||
.to_string(DbQueryBuilder {});
|
||||
sqlx::query(&query).execute(&self.sql_pool).await?;
|
||||
Ok(result?)
|
||||
}
|
||||
async fn delete_refresh_token(&self, refresh_token_hash: u64) -> DomainResult<()> {
|
||||
let query = Query::delete()
|
||||
.from_table(JwtRefreshStorage::Table)
|
||||
.and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash))
|
||||
.to_string(DbQueryBuilder {});
|
||||
sqlx::query(&query).execute(&self.sql_pool).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
46
server/src/infra/tcp_backend_handler.rs
Normal file
46
server/src/infra/tcp_backend_handler.rs
Normal file
@@ -0,0 +1,46 @@
|
||||
use async_trait::async_trait;
|
||||
use std::collections::HashSet;
|
||||
|
||||
pub type DomainResult<T> = crate::domain::error::Result<T>;
|
||||
|
||||
#[async_trait]
|
||||
pub trait TcpBackendHandler {
|
||||
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
|
||||
async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>;
|
||||
async fn check_token(&self, refresh_token_hash: u64, user: &str) -> DomainResult<bool>;
|
||||
async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>>;
|
||||
async fn delete_refresh_token(&self, refresh_token_hash: u64) -> DomainResult<()>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
use crate::domain::handler::*;
|
||||
#[cfg(test)]
|
||||
mockall::mock! {
|
||||
pub TestTcpBackendHandler{}
|
||||
impl Clone for TestTcpBackendHandler {
|
||||
fn clone(&self) -> Self;
|
||||
}
|
||||
#[async_trait]
|
||||
impl LoginHandler for TestTcpBackendHandler {
|
||||
async fn bind(&self, request: BindRequest) -> DomainResult<()>;
|
||||
}
|
||||
#[async_trait]
|
||||
impl BackendHandler for TestTcpBackendHandler {
|
||||
async fn list_users(&self, filters: Option<RequestFilter>) -> DomainResult<Vec<User>>;
|
||||
async fn list_groups(&self) -> DomainResult<Vec<Group>>;
|
||||
async fn get_user_details(&self, user_id: &str) -> DomainResult<User>;
|
||||
async fn get_user_groups(&self, user: &str) -> DomainResult<HashSet<String>>;
|
||||
async fn create_user(&self, request: CreateUserRequest) -> DomainResult<()>;
|
||||
async fn delete_user(&self, user_id: &str) -> DomainResult<()>;
|
||||
async fn create_group(&self, group_name: &str) -> DomainResult<GroupId>;
|
||||
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> DomainResult<()>;
|
||||
}
|
||||
#[async_trait]
|
||||
impl TcpBackendHandler for TestTcpBackendHandler {
|
||||
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
|
||||
async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>;
|
||||
async fn check_token(&self, refresh_token_hash: u64, user: &str) -> DomainResult<bool>;
|
||||
async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>>;
|
||||
async fn delete_refresh_token(&self, refresh_token_hash: u64) -> DomainResult<()>;
|
||||
}
|
||||
}
|
||||
134
server/src/infra/tcp_server.rs
Normal file
134
server/src/infra/tcp_server.rs
Normal file
@@ -0,0 +1,134 @@
|
||||
use crate::{
|
||||
domain::{
|
||||
error::DomainError,
|
||||
handler::{BackendHandler, LoginHandler},
|
||||
opaque_handler::OpaqueHandler,
|
||||
},
|
||||
infra::{auth_service, configuration::Configuration, tcp_backend_handler::*},
|
||||
};
|
||||
use actix_files::{Files, NamedFile};
|
||||
use actix_http::HttpServiceBuilder;
|
||||
use actix_server::ServerBuilder;
|
||||
use actix_service::map_config;
|
||||
use actix_web::{dev::AppConfig, web, App, HttpRequest, HttpResponse};
|
||||
use anyhow::{Context, Result};
|
||||
use hmac::{Hmac, NewMac};
|
||||
use sha2::Sha512;
|
||||
use std::collections::HashSet;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::RwLock;
|
||||
|
||||
async fn index(req: HttpRequest) -> actix_web::Result<NamedFile> {
|
||||
let mut path = PathBuf::new();
|
||||
path.push("../app");
|
||||
let file = req.match_info().query("filename");
|
||||
path.push(if file.is_empty() { "index.html" } else { file });
|
||||
Ok(NamedFile::open(path)?)
|
||||
}
|
||||
|
||||
pub(crate) fn error_to_http_response(error: DomainError) -> HttpResponse {
|
||||
match error {
|
||||
DomainError::AuthenticationError(_) | DomainError::AuthenticationProtocolError(_) => {
|
||||
HttpResponse::Unauthorized()
|
||||
}
|
||||
DomainError::DatabaseError(_)
|
||||
| DomainError::InternalError(_)
|
||||
| DomainError::UnknownCryptoError(_) => HttpResponse::InternalServerError(),
|
||||
DomainError::Base64DecodeError(_) | DomainError::BinarySerializationError(_) => {
|
||||
HttpResponse::BadRequest()
|
||||
}
|
||||
}
|
||||
.body(error.to_string())
|
||||
}
|
||||
|
||||
fn http_config<Backend>(
|
||||
cfg: &mut web::ServiceConfig,
|
||||
backend_handler: Backend,
|
||||
jwt_secret: String,
|
||||
jwt_blacklist: HashSet<u64>,
|
||||
) where
|
||||
Backend: TcpBackendHandler + BackendHandler + LoginHandler + OpaqueHandler + Sync + 'static,
|
||||
{
|
||||
cfg.app_data(web::Data::new(AppState::<Backend> {
|
||||
backend_handler,
|
||||
jwt_key: Hmac::new_varkey(jwt_secret.as_bytes()).unwrap(),
|
||||
jwt_blacklist: RwLock::new(jwt_blacklist),
|
||||
}))
|
||||
// Serve index.html and main.js, and default to index.html.
|
||||
.route(
|
||||
"/{filename:(index\\.html|main\\.js)?}",
|
||||
web::get().to(index),
|
||||
)
|
||||
.service(web::scope("/auth").configure(auth_service::configure_server::<Backend>))
|
||||
// API endpoint.
|
||||
.service(
|
||||
web::scope("/api")
|
||||
.wrap(auth_service::CookieToHeaderTranslatorFactory)
|
||||
.configure(super::graphql::api::configure_endpoint::<Backend>),
|
||||
)
|
||||
// Serve the /pkg path with the compiled WASM app.
|
||||
.service(Files::new("/pkg", "./app/pkg"))
|
||||
// Default to serve index.html for unknown routes, to support routing.
|
||||
.service(web::scope("/").route("/.*", web::get().to(index)));
|
||||
}
|
||||
|
||||
pub(crate) struct AppState<Backend> {
|
||||
pub backend_handler: Backend,
|
||||
pub jwt_key: Hmac<Sha512>,
|
||||
pub jwt_blacklist: RwLock<HashSet<u64>>,
|
||||
}
|
||||
|
||||
pub async fn build_tcp_server<Backend>(
|
||||
config: &Configuration,
|
||||
backend_handler: Backend,
|
||||
server_builder: ServerBuilder,
|
||||
) -> Result<ServerBuilder>
|
||||
where
|
||||
Backend: TcpBackendHandler + BackendHandler + LoginHandler + OpaqueHandler + Sync + 'static,
|
||||
{
|
||||
let jwt_secret = config.jwt_secret.clone();
|
||||
let jwt_blacklist = backend_handler.get_jwt_blacklist().await?;
|
||||
server_builder
|
||||
.bind("http", ("0.0.0.0", config.http_port), move || {
|
||||
let backend_handler = backend_handler.clone();
|
||||
let jwt_secret = jwt_secret.clone();
|
||||
let jwt_blacklist = jwt_blacklist.clone();
|
||||
HttpServiceBuilder::new()
|
||||
.finish(map_config(
|
||||
App::new().configure(move |cfg| {
|
||||
http_config(cfg, backend_handler, jwt_secret, jwt_blacklist)
|
||||
}),
|
||||
|_| AppConfig::default(),
|
||||
))
|
||||
.tcp()
|
||||
})
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"While bringing up the TCP server with port {}",
|
||||
config.http_port
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use actix_web::test::TestRequest;
|
||||
use std::path::Path;
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn test_index_ok() {
|
||||
let req = TestRequest::default().to_http_request();
|
||||
let resp = index(req).await.unwrap();
|
||||
assert_eq!(resp.path(), Path::new("../app/index.html"));
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn test_index_main_js() {
|
||||
let req = TestRequest::default()
|
||||
.param("filename", "main.js")
|
||||
.to_http_request();
|
||||
let resp = index(req).await.unwrap();
|
||||
assert_eq!(resp.path(), Path::new("../app/main.js"));
|
||||
}
|
||||
}
|
||||
88
server/src/main.rs
Normal file
88
server/src/main.rs
Normal file
@@ -0,0 +1,88 @@
|
||||
#![forbid(unsafe_code)]
|
||||
#![allow(clippy::nonstandard_macro_braces)]
|
||||
|
||||
use crate::{
|
||||
domain::{
|
||||
handler::{BackendHandler, CreateUserRequest},
|
||||
sql_backend_handler::SqlBackendHandler,
|
||||
sql_opaque_handler::register_password,
|
||||
sql_tables::PoolOptions,
|
||||
},
|
||||
infra::{cli::*, configuration::Configuration, db_cleaner::Scheduler},
|
||||
};
|
||||
use actix::Actor;
|
||||
use anyhow::{Context, Result};
|
||||
use futures_util::TryFutureExt;
|
||||
use log::*;
|
||||
|
||||
mod domain;
|
||||
mod infra;
|
||||
|
||||
async fn create_admin_user(handler: &SqlBackendHandler, config: &Configuration) -> Result<()> {
|
||||
handler
|
||||
.create_user(CreateUserRequest {
|
||||
user_id: config.ldap_user_dn.clone(),
|
||||
..Default::default()
|
||||
})
|
||||
.and_then(|_| register_password(handler, &config.ldap_user_dn, &config.ldap_user_pass))
|
||||
.await
|
||||
.context("Error creating admin user")?;
|
||||
let admin_group_id = handler
|
||||
.create_group("lldap_admin")
|
||||
.await
|
||||
.context("Error creating admin group")?;
|
||||
handler
|
||||
.add_user_to_group(&config.ldap_user_dn, admin_group_id)
|
||||
.await
|
||||
.context("Error adding admin user to group")
|
||||
}
|
||||
|
||||
async fn run_server(config: Configuration) -> Result<()> {
|
||||
let sql_pool = PoolOptions::new()
|
||||
.max_connections(5)
|
||||
.connect(&config.database_url)
|
||||
.await?;
|
||||
domain::sql_tables::init_table(&sql_pool).await?;
|
||||
let backend_handler = SqlBackendHandler::new(config.clone(), sql_pool.clone());
|
||||
create_admin_user(&backend_handler, &config)
|
||||
.await
|
||||
.unwrap_or_else(|e| warn!("Error setting up admin login/account: {}", e));
|
||||
let server_builder = infra::ldap_server::build_ldap_server(
|
||||
&config,
|
||||
backend_handler.clone(),
|
||||
actix_server::Server::build(),
|
||||
)?;
|
||||
infra::jwt_sql_tables::init_table(&sql_pool).await?;
|
||||
let server_builder =
|
||||
infra::tcp_server::build_tcp_server(&config, backend_handler, server_builder).await?;
|
||||
// Run every hour.
|
||||
let scheduler = Scheduler::new("0 0 * * * * *", sql_pool);
|
||||
scheduler.start();
|
||||
server_builder.workers(1).run().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_server_command(opts: RunOpts) -> Result<()> {
|
||||
let config = infra::configuration::init(opts.clone())?;
|
||||
infra::logging::init(config.clone())?;
|
||||
|
||||
info!("Starting LLDAP....");
|
||||
|
||||
debug!("CLI: {:#?}", opts);
|
||||
debug!("Configuration: {:#?}", config);
|
||||
|
||||
actix::run(
|
||||
run_server(config).unwrap_or_else(|e| error!("Could not bring up the servers: {:?}", e)),
|
||||
)?;
|
||||
|
||||
info!("End.");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let cli_opts = infra::cli::init();
|
||||
match cli_opts.command {
|
||||
Command::ExportGraphQLSchema(opts) => infra::graphql::api::export_schema(opts),
|
||||
Command::Run(opts) => run_server_command(opts),
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user