Move backend source to server/ subpackage

To clarify the organization.
This commit is contained in:
Valentin Tolmer
2021-08-31 16:46:31 +02:00
committed by nitnelave
parent 3eb53ba5bf
commit d8df47b35d
30 changed files with 93 additions and 88 deletions

View File

@@ -0,0 +1,22 @@
use thiserror::Error;
#[allow(clippy::enum_variant_names)]
#[derive(Error, Debug)]
pub enum DomainError {
#[error("Authentication error for `{0}`")]
AuthenticationError(String),
#[error("Database error: `{0}`")]
DatabaseError(#[from] sqlx::Error),
#[error("Authentication protocol error for `{0}`")]
AuthenticationProtocolError(#[from] lldap_auth::opaque::AuthenticationError),
#[error("Unknown crypto error: `{0}`")]
UnknownCryptoError(#[from] orion::errors::UnknownCryptoError),
#[error("Binary serialization error: `{0}`")]
BinarySerializationError(#[from] bincode::Error),
#[error("Invalid base64: `{0}`")]
Base64DecodeError(#[from] base64::DecodeError),
#[error("Internal error: `{0}`")]
InternalError(String),
}
pub type Result<T> = std::result::Result<T, DomainError>;

View File

@@ -0,0 +1,103 @@
use super::error::*;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
#[cfg_attr(not(target_arch = "wasm32"), derive(sqlx::FromRow))]
pub struct User {
pub user_id: String,
pub email: String,
pub display_name: Option<String>,
pub first_name: Option<String>,
pub last_name: Option<String>,
// pub avatar: ?,
pub creation_date: chrono::DateTime<chrono::Utc>,
}
impl Default for User {
fn default() -> Self {
use chrono::TimeZone;
User {
user_id: String::new(),
email: String::new(),
display_name: None,
first_name: None,
last_name: None,
creation_date: chrono::Utc.timestamp(0, 0),
}
}
}
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
pub struct Group {
pub display_name: String,
pub users: Vec<String>,
}
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone)]
pub struct BindRequest {
pub name: String,
pub password: String,
}
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone)]
pub enum RequestFilter {
And(Vec<RequestFilter>),
Or(Vec<RequestFilter>),
Not(Box<RequestFilter>),
Equality(String, String),
}
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone, Default)]
pub struct CreateUserRequest {
// Same fields as User, but no creation_date, and with password.
pub user_id: String,
pub email: String,
pub display_name: Option<String>,
pub first_name: Option<String>,
pub last_name: Option<String>,
}
#[async_trait]
pub trait LoginHandler: Clone + Send {
async fn bind(&self, request: BindRequest) -> Result<()>;
}
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct GroupId(pub i32);
#[async_trait]
pub trait BackendHandler: Clone + Send {
async fn list_users(&self, filters: Option<RequestFilter>) -> Result<Vec<User>>;
async fn list_groups(&self) -> Result<Vec<Group>>;
async fn get_user_details(&self, user_id: &str) -> Result<User>;
async fn create_user(&self, request: CreateUserRequest) -> Result<()>;
async fn delete_user(&self, user_id: &str) -> Result<()>;
async fn create_group(&self, group_name: &str) -> Result<GroupId>;
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
async fn get_user_groups(&self, user: &str) -> Result<HashSet<String>>;
}
#[cfg(test)]
mockall::mock! {
pub TestBackendHandler{}
impl Clone for TestBackendHandler {
fn clone(&self) -> Self;
}
#[async_trait]
impl BackendHandler for TestBackendHandler {
async fn list_users(&self, filters: Option<RequestFilter>) -> Result<Vec<User>>;
async fn list_groups(&self) -> Result<Vec<Group>>;
async fn get_user_details(&self, user_id: &str) -> Result<User>;
async fn create_user(&self, request: CreateUserRequest) -> Result<()>;
async fn delete_user(&self, user_id: &str) -> Result<()>;
async fn create_group(&self, group_name: &str) -> Result<GroupId>;
async fn get_user_groups(&self, user: &str) -> Result<HashSet<String>>;
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
}
#[async_trait]
impl LoginHandler for TestBackendHandler {
async fn bind(&self, request: BindRequest) -> Result<()>;
}
}

6
server/src/domain/mod.rs Normal file
View File

@@ -0,0 +1,6 @@
pub mod error;
pub mod handler;
pub mod opaque_handler;
pub mod sql_backend_handler;
pub mod sql_opaque_handler;
pub mod sql_tables;

View File

@@ -0,0 +1,36 @@
use super::error::*;
use async_trait::async_trait;
pub use lldap_auth::{login, registration};
#[async_trait]
pub trait OpaqueHandler: Clone + Send {
async fn login_start(
&self,
request: login::ClientLoginStartRequest,
) -> Result<login::ServerLoginStartResponse>;
async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<String>;
async fn registration_start(
&self,
request: registration::ClientRegistrationStartRequest,
) -> Result<registration::ServerRegistrationStartResponse>;
async fn registration_finish(
&self,
request: registration::ClientRegistrationFinishRequest,
) -> Result<()>;
}
#[cfg(test)]
mockall::mock! {
pub TestOpaqueHandler{}
impl Clone for TestOpaqueHandler {
fn clone(&self) -> Self;
}
#[async_trait]
impl OpaqueHandler for TestOpaqueHandler {
async fn login_start(&self, request: login::ClientLoginStartRequest) -> Result<login::ServerLoginStartResponse>;
async fn login_finish(&self, request: login::ClientLoginFinishRequest ) -> Result<String>;
async fn registration_start(&self, request: registration::ClientRegistrationStartRequest) -> Result<registration::ServerRegistrationStartResponse>;
async fn registration_finish(&self, request: registration::ClientRegistrationFinishRequest ) -> Result<()>;
}
}

View File

@@ -0,0 +1,536 @@
use super::{error::*, handler::*, sql_tables::*};
use crate::infra::configuration::Configuration;
use async_trait::async_trait;
use futures_util::StreamExt;
use futures_util::TryStreamExt;
use sea_query::{Expr, Iden, Order, Query, SimpleExpr, Value};
use sqlx::Row;
use std::collections::HashSet;
#[derive(Debug, Clone)]
pub struct SqlBackendHandler {
pub(crate) config: Configuration,
pub(crate) sql_pool: Pool,
}
impl SqlBackendHandler {
pub fn new(config: Configuration, sql_pool: Pool) -> Self {
SqlBackendHandler { config, sql_pool }
}
}
fn get_filter_expr(filter: RequestFilter) -> SimpleExpr {
use RequestFilter::*;
fn get_repeated_filter(
fs: Vec<RequestFilter>,
field: &dyn Fn(SimpleExpr, SimpleExpr) -> SimpleExpr,
) -> SimpleExpr {
let mut it = fs.into_iter();
let first_expr = match it.next() {
None => return Expr::value(true),
Some(f) => get_filter_expr(f),
};
it.fold(first_expr, |e, f| field(e, get_filter_expr(f)))
}
match filter {
And(fs) => get_repeated_filter(fs, &SimpleExpr::and),
Or(fs) => get_repeated_filter(fs, &SimpleExpr::or),
Not(f) => Expr::not(Expr::expr(get_filter_expr(*f))),
Equality(s1, s2) => Expr::expr(Expr::cust(&s1)).eq(s2),
}
}
#[async_trait]
impl BackendHandler for SqlBackendHandler {
async fn list_users(&self, filters: Option<RequestFilter>) -> Result<Vec<User>> {
let query = {
let mut query_builder = Query::select()
.column(Users::UserId)
.column(Users::Email)
.column(Users::DisplayName)
.column(Users::FirstName)
.column(Users::LastName)
.column(Users::Avatar)
.column(Users::CreationDate)
.from(Users::Table)
.order_by(Users::UserId, Order::Asc)
.to_owned();
if let Some(filter) = filters {
if filter != RequestFilter::And(Vec::new())
&& filter != RequestFilter::Or(Vec::new())
{
query_builder.and_where(get_filter_expr(filter));
}
}
query_builder.to_string(DbQueryBuilder {})
};
let results = sqlx::query_as::<_, User>(&query)
.fetch(&self.sql_pool)
.collect::<Vec<sqlx::Result<User>>>()
.await;
Ok(results.into_iter().collect::<sqlx::Result<Vec<User>>>()?)
}
async fn list_groups(&self) -> Result<Vec<Group>> {
let query: String = Query::select()
.column(Groups::DisplayName)
.column(Memberships::UserId)
.from(Groups::Table)
.left_join(
Memberships::Table,
Expr::tbl(Groups::Table, Groups::GroupId)
.equals(Memberships::Table, Memberships::GroupId),
)
.order_by(Groups::DisplayName, Order::Asc)
.order_by(Memberships::UserId, Order::Asc)
.to_string(DbQueryBuilder {});
let mut results = sqlx::query(&query).fetch(&self.sql_pool);
let mut groups = Vec::new();
// The rows are ordered by group, user, so we need to group them into vectors.
{
let mut current_group = String::new();
let mut current_users = Vec::new();
while let Some(row) = results.try_next().await? {
let display_name = row.get::<String, _>(&*Groups::DisplayName.to_string());
if display_name != current_group {
if !current_group.is_empty() {
groups.push(Group {
display_name: current_group,
users: current_users,
});
current_users = Vec::new();
}
current_group = display_name.clone();
}
current_users.push(row.get::<String, _>(&*Memberships::UserId.to_string()));
}
groups.push(Group {
display_name: current_group,
users: current_users,
});
}
Ok(groups)
}
async fn get_user_details(&self, user_id: &str) -> Result<User> {
let query = Query::select()
.column(Users::UserId)
.column(Users::Email)
.column(Users::DisplayName)
.column(Users::FirstName)
.column(Users::LastName)
.column(Users::Avatar)
.column(Users::CreationDate)
.from(Users::Table)
.and_where(Expr::col(Users::UserId).eq(user_id))
.to_string(DbQueryBuilder {});
Ok(sqlx::query_as::<_, User>(&query)
.fetch_one(&self.sql_pool)
.await?)
}
async fn get_user_groups(&self, user: &str) -> Result<HashSet<String>> {
if user == self.config.ldap_user_dn {
let mut groups = HashSet::new();
groups.insert("lldap_admin".to_string());
return Ok(groups);
}
let query: String = Query::select()
.column(Groups::DisplayName)
.from(Groups::Table)
.inner_join(
Memberships::Table,
Expr::tbl(Groups::Table, Groups::GroupId)
.equals(Memberships::Table, Memberships::GroupId),
)
.and_where(Expr::col(Memberships::UserId).eq(user))
.to_string(DbQueryBuilder {});
sqlx::query(&query)
// Extract the group id from the row.
.map(|row: DbRow| row.get::<String, _>(&*Groups::DisplayName.to_string()))
.fetch(&self.sql_pool)
// Collect the vector of rows, each potentially an error.
.collect::<Vec<sqlx::Result<String>>>()
.await
.into_iter()
// Transform it into a single result (the first error if any), and group the group_ids
// into a HashSet.
.collect::<sqlx::Result<HashSet<_>>>()
// Map the sqlx::Error into a DomainError.
.map_err(DomainError::DatabaseError)
}
async fn create_user(&self, request: CreateUserRequest) -> Result<()> {
let columns = vec![
Users::UserId,
Users::Email,
Users::DisplayName,
Users::FirstName,
Users::LastName,
Users::CreationDate,
];
let values = vec![
request.user_id.clone().into(),
request.email.into(),
request.display_name.map(Into::into).unwrap_or(Value::Null),
request.first_name.map(Into::into).unwrap_or(Value::Null),
request.last_name.map(Into::into).unwrap_or(Value::Null),
chrono::Utc::now().naive_utc().into(),
];
let query = Query::insert()
.into_table(Users::Table)
.columns(columns)
.values_panic(values)
.to_string(DbQueryBuilder {});
sqlx::query(&query).execute(&self.sql_pool).await?;
Ok(())
}
async fn delete_user(&self, user_id: &str) -> Result<()> {
let delete_query = Query::delete()
.from_table(Users::Table)
.and_where(Expr::col(Users::UserId).eq(user_id))
.to_string(DbQueryBuilder {});
sqlx::query(&delete_query).execute(&self.sql_pool).await?;
Ok(())
}
async fn create_group(&self, group_name: &str) -> Result<GroupId> {
let query = Query::insert()
.into_table(Groups::Table)
.columns(vec![Groups::DisplayName])
.values_panic(vec![group_name.into()])
.to_string(DbQueryBuilder {});
sqlx::query(&query).execute(&self.sql_pool).await?;
let query = Query::select()
.column(Groups::GroupId)
.from(Groups::Table)
.and_where(Expr::col(Groups::DisplayName).eq(group_name))
.to_string(DbQueryBuilder {});
let row = sqlx::query(&query).fetch_one(&self.sql_pool).await?;
Ok(GroupId(row.get::<i32, _>(&*Groups::GroupId.to_string())))
}
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()> {
let query = Query::insert()
.into_table(Memberships::Table)
.columns(vec![Memberships::UserId, Memberships::GroupId])
.values_panic(vec![user_id.into(), group_id.0.into()])
.to_string(DbQueryBuilder {});
sqlx::query(&query).execute(&self.sql_pool).await?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::domain::sql_tables::init_table;
use crate::infra::configuration::ConfigurationBuilder;
use lldap_auth::{opaque, registration};
fn get_default_config() -> Configuration {
ConfigurationBuilder::default()
.verbose(true)
.build()
.unwrap()
}
async fn get_in_memory_db() -> Pool {
PoolOptions::new().connect("sqlite::memory:").await.unwrap()
}
async fn get_initialized_db() -> Pool {
let sql_pool = get_in_memory_db().await;
init_table(&sql_pool).await.unwrap();
sql_pool
}
async fn insert_user(handler: &SqlBackendHandler, name: &str, pass: &str) {
use crate::domain::opaque_handler::OpaqueHandler;
insert_user_no_password(handler, name).await;
let mut rng = rand::rngs::OsRng;
let client_registration_start =
opaque::client::registration::start_registration(pass, &mut rng).unwrap();
let response = handler
.registration_start(registration::ClientRegistrationStartRequest {
username: name.to_string(),
registration_start_request: client_registration_start.message,
})
.await
.unwrap();
let registration_upload = opaque::client::registration::finish_registration(
client_registration_start.state,
response.registration_response,
&mut rng,
)
.unwrap();
handler
.registration_finish(registration::ClientRegistrationFinishRequest {
server_data: response.server_data,
registration_upload: registration_upload.message,
})
.await
.unwrap();
}
async fn insert_user_no_password(handler: &SqlBackendHandler, name: &str) {
handler
.create_user(CreateUserRequest {
user_id: name.to_string(),
email: "bob@bob.bob".to_string(),
..Default::default()
})
.await
.unwrap();
}
async fn insert_group(handler: &SqlBackendHandler, name: &str) -> GroupId {
handler.create_group(name).await.unwrap()
}
async fn insert_membership(handler: &SqlBackendHandler, group_id: GroupId, user_id: &str) {
handler.add_user_to_group(user_id, group_id).await.unwrap();
}
#[tokio::test]
async fn test_bind_admin() {
let sql_pool = get_in_memory_db().await;
let config = ConfigurationBuilder::default()
.ldap_user_dn("admin".to_string())
.ldap_user_pass("test".to_string())
.build()
.unwrap();
let handler = SqlBackendHandler::new(config, sql_pool);
handler
.bind(BindRequest {
name: "admin".to_string(),
password: "test".to_string(),
})
.await
.unwrap();
}
#[tokio::test]
async fn test_bind_user() {
let sql_pool = get_initialized_db().await;
let config = get_default_config();
let handler = SqlBackendHandler::new(config, sql_pool.clone());
insert_user(&handler, "bob", "bob00").await;
handler
.bind(BindRequest {
name: "bob".to_string(),
password: "bob00".to_string(),
})
.await
.unwrap();
handler
.bind(BindRequest {
name: "andrew".to_string(),
password: "bob00".to_string(),
})
.await
.unwrap_err();
handler
.bind(BindRequest {
name: "bob".to_string(),
password: "wrong_password".to_string(),
})
.await
.unwrap_err();
}
#[tokio::test]
async fn test_user_no_password() {
let sql_pool = get_initialized_db().await;
let config = get_default_config();
let handler = SqlBackendHandler::new(config, sql_pool.clone());
insert_user_no_password(&handler, "bob").await;
handler
.bind(BindRequest {
name: "bob".to_string(),
password: "bob00".to_string(),
})
.await
.unwrap_err();
}
#[tokio::test]
async fn test_list_users() {
let sql_pool = get_initialized_db().await;
let config = get_default_config();
let handler = SqlBackendHandler::new(config, sql_pool);
insert_user(&handler, "bob", "bob00").await;
insert_user(&handler, "patrick", "pass").await;
insert_user(&handler, "John", "Pa33w0rd!").await;
{
let users = handler
.list_users(None)
.await
.unwrap()
.into_iter()
.map(|u| u.user_id)
.collect::<Vec<_>>();
assert_eq!(users, vec!["John", "bob", "patrick"]);
}
{
let users = handler
.list_users(Some(RequestFilter::Equality(
"user_id".to_string(),
"bob".to_string(),
)))
.await
.unwrap()
.into_iter()
.map(|u| u.user_id)
.collect::<Vec<_>>();
assert_eq!(users, vec!["bob"]);
}
{
let users = handler
.list_users(Some(RequestFilter::Or(vec![
RequestFilter::Equality("user_id".to_string(), "bob".to_string()),
RequestFilter::Equality("user_id".to_string(), "John".to_string()),
])))
.await
.unwrap()
.into_iter()
.map(|u| u.user_id)
.collect::<Vec<_>>();
assert_eq!(users, vec!["John", "bob"]);
}
{
let users = handler
.list_users(Some(RequestFilter::Not(Box::new(RequestFilter::Equality(
"user_id".to_string(),
"bob".to_string(),
)))))
.await
.unwrap()
.into_iter()
.map(|u| u.user_id)
.collect::<Vec<_>>();
assert_eq!(users, vec!["John", "patrick"]);
}
}
#[tokio::test]
async fn test_list_groups() {
let sql_pool = get_initialized_db().await;
let config = get_default_config();
let handler = SqlBackendHandler::new(config, sql_pool.clone());
insert_user(&handler, "bob", "bob00").await;
insert_user(&handler, "patrick", "pass").await;
insert_user(&handler, "John", "Pa33w0rd!").await;
let group_1 = insert_group(&handler, "Best Group").await;
let group_2 = insert_group(&handler, "Worst Group").await;
insert_membership(&handler, group_1, "bob").await;
insert_membership(&handler, group_1, "patrick").await;
insert_membership(&handler, group_2, "patrick").await;
insert_membership(&handler, group_2, "John").await;
assert_eq!(
handler.list_groups().await.unwrap(),
vec![
Group {
display_name: "Best Group".to_string(),
users: vec!["bob".to_string(), "patrick".to_string()]
},
Group {
display_name: "Worst Group".to_string(),
users: vec!["John".to_string(), "patrick".to_string()]
}
]
);
}
#[tokio::test]
async fn test_get_user_details() {
let sql_pool = get_initialized_db().await;
let config = get_default_config();
let handler = SqlBackendHandler::new(config, sql_pool);
insert_user(&handler, "bob", "bob00").await;
{
let user = handler.get_user_details("bob").await.unwrap();
assert_eq!(user.user_id, "bob".to_string());
}
{
handler.get_user_details("John").await.unwrap_err();
}
}
#[tokio::test]
async fn test_get_user_groups() {
let sql_pool = get_initialized_db().await;
let config = get_default_config();
let handler = SqlBackendHandler::new(config, sql_pool.clone());
insert_user(&handler, "bob", "bob00").await;
insert_user(&handler, "patrick", "pass").await;
insert_user(&handler, "John", "Pa33w0rd!").await;
let group_1 = insert_group(&handler, "Group1").await;
let group_2 = insert_group(&handler, "Group2").await;
insert_membership(&handler, group_1, "bob").await;
insert_membership(&handler, group_1, "patrick").await;
insert_membership(&handler, group_2, "patrick").await;
let mut bob_groups = HashSet::new();
bob_groups.insert("Group1".to_string());
let mut patrick_groups = HashSet::new();
patrick_groups.insert("Group1".to_string());
patrick_groups.insert("Group2".to_string());
assert_eq!(handler.get_user_groups("bob").await.unwrap(), bob_groups);
assert_eq!(
handler.get_user_groups("patrick").await.unwrap(),
patrick_groups
);
assert_eq!(
handler.get_user_groups("John").await.unwrap(),
HashSet::new()
);
}
#[tokio::test]
async fn test_delete_user() {
let sql_pool = get_initialized_db().await;
let config = get_default_config();
let handler = SqlBackendHandler::new(config, sql_pool.clone());
insert_user(&handler, "val", "s3np4i").await;
insert_user(&handler, "Hector", "Be$t").await;
insert_user(&handler, "Jennz", "boupBoup").await;
// Remove a user
let _request_result = handler.delete_user("Jennz").await.unwrap();
let users = handler
.list_users(None)
.await
.unwrap()
.into_iter()
.map(|u| u.user_id)
.collect::<Vec<_>>();
assert_eq!(users, vec!["Hector", "val"]);
// Insert new user and remove two
insert_user(&handler, "NewBoi", "Joni").await;
let _request_result = handler.delete_user("Hector").await.unwrap();
let _request_result = handler.delete_user("NewBoi").await.unwrap();
let users = handler
.list_users(None)
.await
.unwrap()
.into_iter()
.map(|u| u.user_id)
.collect::<Vec<_>>();
assert_eq!(users, vec!["val"]);
}
}

View File

@@ -0,0 +1,331 @@
use super::{
error::*,
handler::{BindRequest, LoginHandler},
opaque_handler::*,
sql_backend_handler::SqlBackendHandler,
sql_tables::*,
};
use async_trait::async_trait;
use lldap_auth::opaque;
use log::*;
use sea_query::{Expr, Iden, Query};
use sqlx::Row;
type SqlOpaqueHandler = SqlBackendHandler;
fn passwords_match(
password_file_bytes: &[u8],
clear_password: &str,
server_setup: &opaque::server::ServerSetup,
username: &str,
) -> Result<()> {
use opaque::{client, server};
let mut rng = rand::rngs::OsRng;
let client_login_start_result = client::login::start_login(clear_password, &mut rng)?;
let password_file = server::ServerRegistration::deserialize(password_file_bytes)
.map_err(opaque::AuthenticationError::ProtocolError)?;
let server_login_start_result = server::login::start_login(
&mut rng,
server_setup,
Some(password_file),
client_login_start_result.message,
username,
)?;
client::login::finish_login(
client_login_start_result.state,
server_login_start_result.message,
)?;
Ok(())
}
impl SqlBackendHandler {
fn get_orion_secret_key(&self) -> Result<orion::aead::SecretKey> {
Ok(orion::aead::SecretKey::from_slice(
self.config.get_server_keys().private(),
)?)
}
async fn get_password_file_for_user(
&self,
username: &str,
) -> Result<Option<opaque::server::ServerRegistration>> {
// Fetch the previously registered password file from the DB.
let password_file_bytes = {
let query = Query::select()
.column(Users::PasswordHash)
.from(Users::Table)
.and_where(Expr::col(Users::UserId).eq(username))
.to_string(DbQueryBuilder {});
if let Some(row) = sqlx::query(&query).fetch_optional(&self.sql_pool).await? {
if let Some(bytes) =
row.get::<Option<Vec<u8>>, _>(&*Users::PasswordHash.to_string())
{
bytes
} else {
// No password set.
return Ok(None);
}
} else {
// No such user.
return Ok(None);
}
};
opaque::server::ServerRegistration::deserialize(&password_file_bytes)
.map(Option::Some)
.map_err(|_| {
DomainError::InternalError(format!("Corrupted password file for {}", username))
})
}
}
#[async_trait]
impl LoginHandler for SqlBackendHandler {
async fn bind(&self, request: BindRequest) -> Result<()> {
if request.name == self.config.ldap_user_dn {
if request.password == self.config.ldap_user_pass {
return Ok(());
} else {
debug!(r#"Invalid password for LDAP bind user"#);
return Err(DomainError::AuthenticationError(request.name));
}
}
let query = Query::select()
.column(Users::PasswordHash)
.from(Users::Table)
.and_where(Expr::col(Users::UserId).eq(request.name.as_str()))
.to_string(DbQueryBuilder {});
if let Ok(row) = sqlx::query(&query).fetch_one(&self.sql_pool).await {
if let Some(password_hash) =
row.get::<Option<Vec<u8>>, _>(&*Users::PasswordHash.to_string())
{
if let Err(e) = passwords_match(
&password_hash,
&request.password,
self.config.get_server_setup(),
&request.name,
) {
debug!(r#"Invalid password for "{}": {}"#, request.name, e);
} else {
return Ok(());
}
} else {
debug!(r#"User "{}" has no password"#, request.name);
}
} else {
debug!(r#"No user found for "{}""#, request.name);
}
Err(DomainError::AuthenticationError(request.name))
}
}
#[async_trait]
impl OpaqueHandler for SqlOpaqueHandler {
async fn login_start(
&self,
request: login::ClientLoginStartRequest,
) -> Result<login::ServerLoginStartResponse> {
let maybe_password_file = self.get_password_file_for_user(&request.username).await?;
let mut rng = rand::rngs::OsRng;
// Get the CredentialResponse for the user, or a dummy one if no user/no password.
let start_response = opaque::server::login::start_login(
&mut rng,
self.config.get_server_setup(),
maybe_password_file,
request.login_start_request,
&request.username,
)?;
let secret_key = self.get_orion_secret_key()?;
let server_data = login::ServerData {
username: request.username,
server_login: start_response.state,
};
let encrypted_state = orion::aead::seal(&secret_key, &bincode::serialize(&server_data)?)?;
Ok(login::ServerLoginStartResponse {
server_data: base64::encode(&encrypted_state),
credential_response: start_response.message,
})
}
async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<String> {
let secret_key = self.get_orion_secret_key()?;
let login::ServerData {
username,
server_login,
} = bincode::deserialize(&orion::aead::open(
&secret_key,
&base64::decode(&request.server_data)?,
)?)?;
// Finish the login: this makes sure the client data is correct, and gives a session key we
// don't need.
let _session_key =
opaque::server::login::finish_login(server_login, request.credential_finalization)?
.session_key;
Ok(username)
}
async fn registration_start(
&self,
request: registration::ClientRegistrationStartRequest,
) -> Result<registration::ServerRegistrationStartResponse> {
// Generate the server-side key and derive the data to send back.
let start_response = opaque::server::registration::start_registration(
self.config.get_server_setup(),
request.registration_start_request,
&request.username,
)?;
let secret_key = self.get_orion_secret_key()?;
let server_data = registration::ServerData {
username: request.username,
};
let encrypted_state = orion::aead::seal(&secret_key, &bincode::serialize(&server_data)?)?;
Ok(registration::ServerRegistrationStartResponse {
server_data: base64::encode(encrypted_state),
registration_response: start_response.message,
})
}
async fn registration_finish(
&self,
request: registration::ClientRegistrationFinishRequest,
) -> Result<()> {
let secret_key = self.get_orion_secret_key()?;
let registration::ServerData { username } = bincode::deserialize(&orion::aead::open(
&secret_key,
&base64::decode(&request.server_data)?,
)?)?;
let password_file =
opaque::server::registration::get_password_file(request.registration_upload);
{
// Set the user password to the new password.
let update_query = Query::update()
.table(Users::Table)
.values(vec![(
Users::PasswordHash,
password_file.serialize().into(),
)])
.and_where(Expr::col(Users::UserId).eq(username))
.to_string(DbQueryBuilder {});
sqlx::query(&update_query).execute(&self.sql_pool).await?;
}
Ok(())
}
}
/// Convenience function to set a user's password.
pub(crate) async fn register_password(
opaque_handler: &SqlOpaqueHandler,
username: &str,
password: &str,
) -> Result<()> {
let mut rng = rand::rngs::OsRng;
use registration::*;
let registration_start = opaque::client::registration::start_registration(password, &mut rng)?;
let start_response = opaque_handler
.registration_start(ClientRegistrationStartRequest {
username: username.to_string(),
registration_start_request: registration_start.message,
})
.await?;
let registration_finish = opaque::client::registration::finish_registration(
registration_start.state,
start_response.registration_response,
&mut rng,
)?;
opaque_handler
.registration_finish(ClientRegistrationFinishRequest {
server_data: start_response.server_data,
registration_upload: registration_finish.message,
})
.await
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
domain::{
handler::{BackendHandler, CreateUserRequest},
sql_backend_handler::SqlBackendHandler,
sql_tables::init_table,
},
infra::configuration::{Configuration, ConfigurationBuilder},
};
fn get_default_config() -> Configuration {
ConfigurationBuilder::default()
.verbose(true)
.build()
.unwrap()
}
async fn get_in_memory_db() -> Pool {
PoolOptions::new().connect("sqlite::memory:").await.unwrap()
}
async fn get_initialized_db() -> Pool {
let sql_pool = get_in_memory_db().await;
init_table(&sql_pool).await.unwrap();
sql_pool
}
async fn insert_user_no_password(handler: &SqlBackendHandler, name: &str) {
handler
.create_user(CreateUserRequest {
user_id: name.to_string(),
email: "bob@bob.bob".to_string(),
..Default::default()
})
.await
.unwrap();
}
async fn attempt_login(
opaque_handler: &SqlOpaqueHandler,
username: &str,
password: &str,
) -> Result<()> {
let mut rng = rand::rngs::OsRng;
use login::*;
let login_start = opaque::client::login::start_login(password, &mut rng)?;
let start_response = opaque_handler
.login_start(ClientLoginStartRequest {
username: username.to_string(),
login_start_request: login_start.message,
})
.await?;
let login_finish = opaque::client::login::finish_login(
login_start.state,
start_response.credential_response,
)?;
opaque_handler
.login_finish(ClientLoginFinishRequest {
server_data: start_response.server_data,
credential_finalization: login_finish.message,
})
.await?;
Ok(())
}
#[tokio::test]
async fn test_flow() -> Result<()> {
let sql_pool = get_initialized_db().await;
let config = get_default_config();
let backend_handler = SqlBackendHandler::new(config.clone(), sql_pool.clone());
let opaque_handler = SqlOpaqueHandler::new(config, sql_pool);
insert_user_no_password(&backend_handler, "bob").await;
attempt_login(&opaque_handler, "bob", "bob00")
.await
.unwrap_err();
register_password(&opaque_handler, "bob", "bob00").await?;
attempt_login(&opaque_handler, "bob", "wrong_password")
.await
.unwrap_err();
attempt_login(&opaque_handler, "bob", "bob00").await?;
Ok(())
}
}

View File

@@ -0,0 +1,152 @@
use sea_query::*;
pub type Pool = sqlx::sqlite::SqlitePool;
pub type PoolOptions = sqlx::sqlite::SqlitePoolOptions;
pub type DbRow = sqlx::sqlite::SqliteRow;
pub type DbQueryBuilder = SqliteQueryBuilder;
#[derive(Iden)]
pub enum Users {
Table,
UserId,
Email,
DisplayName,
FirstName,
LastName,
Avatar,
CreationDate,
PasswordHash,
TotpSecret,
MfaType,
}
#[derive(Iden)]
pub enum Groups {
Table,
GroupId,
DisplayName,
}
#[derive(Iden)]
pub enum Memberships {
Table,
UserId,
GroupId,
}
pub async fn init_table(pool: &Pool) -> sqlx::Result<()> {
// SQLite needs this pragma to be turned on. Other DB might not understand this, so ignore the
// error.
let _ = sqlx::query("PRAGMA foreign_keys = ON").execute(pool).await;
sqlx::query(
&Table::create()
.table(Users::Table)
.if_not_exists()
.col(
ColumnDef::new(Users::UserId)
.string_len(255)
.not_null()
.primary_key(),
)
.col(ColumnDef::new(Users::Email).string_len(255).not_null())
.col(ColumnDef::new(Users::DisplayName).string_len(255))
.col(ColumnDef::new(Users::FirstName).string_len(255))
.col(ColumnDef::new(Users::LastName).string_len(255))
.col(ColumnDef::new(Users::Avatar).binary())
.col(ColumnDef::new(Users::CreationDate).date_time().not_null())
.col(ColumnDef::new(Users::PasswordHash).binary())
.col(ColumnDef::new(Users::TotpSecret).string_len(64))
.col(ColumnDef::new(Users::MfaType).string_len(64))
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await?;
sqlx::query(
&Table::create()
.table(Groups::Table)
.if_not_exists()
.col(
ColumnDef::new(Groups::GroupId)
.integer()
.not_null()
.primary_key(),
)
.col(
ColumnDef::new(Groups::DisplayName)
.string_len(255)
.unique_key()
.not_null(),
)
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await?;
sqlx::query(
&Table::create()
.table(Memberships::Table)
.if_not_exists()
.col(
ColumnDef::new(Memberships::UserId)
.string_len(255)
.not_null(),
)
.col(ColumnDef::new(Memberships::GroupId).integer().not_null())
.foreign_key(
ForeignKey::create()
.name("MembershipUserForeignKey")
.table(Memberships::Table, Users::Table)
.col(Memberships::UserId, Users::UserId)
.on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade),
)
.foreign_key(
ForeignKey::create()
.name("MembershipGroupForeignKey")
.table(Memberships::Table, Groups::Table)
.col(Memberships::GroupId, Groups::GroupId)
.on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade),
)
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await?;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use chrono::prelude::*;
use sqlx::{Column, Row};
#[actix_rt::test]
async fn test_init_table() {
let sql_pool = PoolOptions::new().connect("sqlite::memory:").await.unwrap();
init_table(&sql_pool).await.unwrap();
sqlx::query(r#"INSERT INTO users
(user_id, email, display_name, first_name, last_name, creation_date, password_hash)
VALUES ("bôb", "böb@bob.bob", "Bob Bobbersön", "Bob", "Bobberson", "1970-01-01 00:00:00", "bob00")"#).execute(&sql_pool).await.unwrap();
let row =
sqlx::query(r#"SELECT display_name, creation_date FROM users WHERE user_id = "bôb""#)
.fetch_one(&sql_pool)
.await
.unwrap();
assert_eq!(row.column(0).name(), "display_name");
assert_eq!(row.get::<String, _>("display_name"), "Bob Bobbersön");
assert_eq!(
row.get::<DateTime<Utc>, _>("creation_date"),
Utc.timestamp(0, 0),
);
}
#[actix_rt::test]
async fn test_already_init_table() {
let sql_pool = PoolOptions::new().connect("sqlite::memory:").await.unwrap();
init_table(&sql_pool).await.unwrap();
init_table(&sql_pool).await.unwrap();
}
}

View File

@@ -0,0 +1,415 @@
use crate::{
domain::{
error::DomainError,
handler::{BackendHandler, BindRequest, LoginHandler},
opaque_handler::OpaqueHandler,
},
infra::{
tcp_backend_handler::*,
tcp_server::{error_to_http_response, AppState},
},
};
use actix_web::{
cookie::{Cookie, SameSite},
dev::{Service, ServiceRequest, ServiceResponse, Transform},
error::{ErrorBadRequest, ErrorUnauthorized},
web, HttpRequest, HttpResponse,
};
use anyhow::Result;
use chrono::prelude::*;
use futures::future::{ok, Ready};
use futures_util::{FutureExt, TryFutureExt};
use hmac::Hmac;
use jwt::{SignWithKey, VerifyWithKey};
use lldap_auth::{login, registration, JWTClaims};
use sha2::Sha512;
use std::collections::{hash_map::DefaultHasher, HashSet};
use std::hash::{Hash, Hasher};
use std::pin::Pin;
use std::task::{Context, Poll};
use time::ext::NumericalDuration;
type Token<S> = jwt::Token<jwt::Header, JWTClaims, S>;
type SignedToken = Token<jwt::token::Signed>;
fn create_jwt(key: &Hmac<Sha512>, user: String, groups: HashSet<String>) -> SignedToken {
let claims = JWTClaims {
exp: Utc::now() + chrono::Duration::days(1),
iat: Utc::now(),
user,
groups,
};
let header = jwt::Header {
algorithm: jwt::AlgorithmType::Hs512,
..Default::default()
};
jwt::Token::new(header, claims).sign_with_key(key).unwrap()
}
fn get_refresh_token_from_cookie(
request: HttpRequest,
) -> std::result::Result<(u64, String), HttpResponse> {
match request.cookie("refresh_token") {
None => Err(HttpResponse::Unauthorized().body("Missing refresh token")),
Some(t) => match t.value().split_once("+") {
None => Err(HttpResponse::Unauthorized().body("Invalid refresh token")),
Some((token, u)) => {
let refresh_token_hash = {
let mut s = DefaultHasher::new();
token.hash(&mut s);
s.finish()
};
Ok((refresh_token_hash, u.to_string()))
}
},
}
}
async fn get_refresh<Backend>(
data: web::Data<AppState<Backend>>,
request: HttpRequest,
) -> HttpResponse
where
Backend: TcpBackendHandler + BackendHandler + 'static,
{
let backend_handler = &data.backend_handler;
let jwt_key = &data.jwt_key;
let (refresh_token_hash, user) = match get_refresh_token_from_cookie(request) {
Ok(t) => t,
Err(http_response) => return http_response,
};
let res_found = data
.backend_handler
.check_token(refresh_token_hash, &user)
.await;
// Async closures are not supported yet.
match res_found {
Ok(found) => {
if found {
backend_handler.get_user_groups(&user).await
} else {
Err(DomainError::AuthenticationError(
"Invalid refresh token".to_string(),
))
}
}
Err(e) => Err(e),
}
.map(|groups| create_jwt(jwt_key, user.to_string(), groups))
.map(|token| {
HttpResponse::Ok()
.cookie(
Cookie::build("token", token.as_str())
.max_age(1.days())
.path("/api")
.http_only(true)
.same_site(SameSite::Strict)
.finish(),
)
.body(token.as_str().to_owned())
})
.unwrap_or_else(error_to_http_response)
}
async fn get_logout<Backend>(
data: web::Data<AppState<Backend>>,
request: HttpRequest,
) -> HttpResponse
where
Backend: TcpBackendHandler + BackendHandler + 'static,
{
let (refresh_token_hash, user) = match get_refresh_token_from_cookie(request) {
Ok(t) => t,
Err(http_response) => return http_response,
};
if let Err(response) = data
.backend_handler
.delete_refresh_token(refresh_token_hash)
.map_err(error_to_http_response)
.await
{
return response;
};
match data
.backend_handler
.blacklist_jwts(&user)
.map_err(error_to_http_response)
.await
{
Ok(new_blacklisted_jwts) => {
let mut jwt_blacklist = data.jwt_blacklist.write().unwrap();
for jwt in new_blacklisted_jwts {
jwt_blacklist.insert(jwt);
}
}
Err(response) => return response,
};
HttpResponse::Ok()
.cookie(
Cookie::build("token", "")
.max_age(0.days())
.path("/api")
.http_only(true)
.same_site(SameSite::Strict)
.finish(),
)
.cookie(
Cookie::build("refresh_token", "")
.max_age(0.days())
.path("/auth")
.http_only(true)
.same_site(SameSite::Strict)
.finish(),
)
.finish()
}
pub(crate) fn error_to_api_response<T>(error: DomainError) -> ApiResult<T> {
ApiResult::Right(error_to_http_response(error))
}
pub type ApiResult<M> = actix_web::Either<web::Json<M>, HttpResponse>;
async fn opaque_login_start<Backend>(
data: web::Data<AppState<Backend>>,
request: web::Json<login::ClientLoginStartRequest>,
) -> ApiResult<login::ServerLoginStartResponse>
where
Backend: OpaqueHandler + 'static,
{
data.backend_handler
.login_start(request.into_inner())
.await
.map(|res| ApiResult::Left(web::Json(res)))
.unwrap_or_else(error_to_api_response)
}
async fn get_login_successful_response<Backend>(
data: &web::Data<AppState<Backend>>,
name: &str,
) -> HttpResponse
where
Backend: TcpBackendHandler + BackendHandler,
{
// The authentication was successful, we need to fetch the groups to create the JWT
// token.
data.backend_handler
.get_user_groups(name)
.and_then(|g| async { Ok((g, data.backend_handler.create_refresh_token(name).await?)) })
.await
.map(|(groups, (refresh_token, max_age))| {
let token = create_jwt(&data.jwt_key, name.to_string(), groups);
HttpResponse::Ok()
.cookie(
Cookie::build("token", token.as_str())
.max_age(1.days())
.path("/api")
.http_only(true)
.same_site(SameSite::Strict)
.finish(),
)
.cookie(
Cookie::build("refresh_token", refresh_token + "+" + name)
.max_age(max_age.num_days().days())
.path("/auth")
.http_only(true)
.same_site(SameSite::Strict)
.finish(),
)
.body(token.as_str().to_owned())
})
.unwrap_or_else(error_to_http_response)
}
async fn opaque_login_finish<Backend>(
data: web::Data<AppState<Backend>>,
request: web::Json<login::ClientLoginFinishRequest>,
) -> HttpResponse
where
Backend: TcpBackendHandler + BackendHandler + OpaqueHandler + 'static,
{
let name = match data
.backend_handler
.login_finish(request.into_inner())
.await
{
Ok(n) => n,
Err(e) => return error_to_http_response(e),
};
get_login_successful_response(&data, &name).await
}
async fn post_authorize<Backend>(
data: web::Data<AppState<Backend>>,
request: web::Json<BindRequest>,
) -> HttpResponse
where
Backend: TcpBackendHandler + BackendHandler + LoginHandler + 'static,
{
let name = request.name.clone();
if let Err(e) = data.backend_handler.bind(request.into_inner()).await {
return error_to_http_response(e);
}
get_login_successful_response(&data, &name).await
}
async fn opaque_register_start<Backend>(
data: web::Data<AppState<Backend>>,
request: web::Json<registration::ClientRegistrationStartRequest>,
) -> ApiResult<registration::ServerRegistrationStartResponse>
where
Backend: OpaqueHandler + 'static,
{
data.backend_handler
.registration_start(request.into_inner())
.await
.map(|res| ApiResult::Left(web::Json(res)))
.unwrap_or_else(error_to_api_response)
}
async fn opaque_register_finish<Backend>(
data: web::Data<AppState<Backend>>,
request: web::Json<registration::ClientRegistrationFinishRequest>,
) -> HttpResponse
where
Backend: TcpBackendHandler + BackendHandler + OpaqueHandler + 'static,
{
if let Err(e) = data
.backend_handler
.registration_finish(request.into_inner())
.await
{
return error_to_http_response(e);
}
HttpResponse::Ok().finish()
}
pub struct CookieToHeaderTranslatorFactory;
impl<S> Transform<S, ServiceRequest> for CookieToHeaderTranslatorFactory
where
S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error>,
S::Future: 'static,
{
type Response = ServiceResponse;
type Error = actix_web::Error;
type InitError = ();
type Transform = CookieToHeaderTranslator<S>;
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ok(CookieToHeaderTranslator { service })
}
}
pub struct CookieToHeaderTranslator<S> {
service: S,
}
impl<S> Service<ServiceRequest> for CookieToHeaderTranslator<S>
where
S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error>,
S::Future: 'static,
{
type Response = ServiceResponse;
type Error = actix_web::Error;
#[allow(clippy::type_complexity)]
type Future = Pin<Box<dyn core::future::Future<Output = Result<Self::Response, Self::Error>>>>;
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}
fn call(&self, mut req: ServiceRequest) -> Self::Future {
if let Some(token_cookie) = req.cookie("token") {
if let Ok(header_value) = actix_http::header::HeaderValue::from_str(&format!(
"Bearer {}",
token_cookie.value()
)) {
req.headers_mut()
.insert(actix_http::header::AUTHORIZATION, header_value);
} else {
return async move {
Ok(req.error_response(ErrorBadRequest("Invalid token cookie")))
}
.boxed_local();
}
};
Box::pin(self.service.call(req))
}
}
pub struct ValidationResults {
pub user: String,
pub is_admin: bool,
}
impl ValidationResults {
#[cfg(test)]
pub fn admin() -> Self {
Self {
user: "admin".to_string(),
is_admin: true,
}
}
pub fn can_access(&self, user: &str) -> bool {
self.is_admin || self.user == user
}
}
pub(crate) fn check_if_token_is_valid<Backend>(
state: &AppState<Backend>,
token_str: &str,
) -> Result<ValidationResults, actix_web::Error> {
let token: Token<_> = VerifyWithKey::verify_with_key(token_str, &state.jwt_key)
.map_err(|_| ErrorUnauthorized("Invalid JWT"))?;
if token.claims().exp.lt(&Utc::now()) {
return Err(ErrorUnauthorized("Expired JWT"));
}
if token.header().algorithm != jwt::AlgorithmType::Hs512 {
return Err(ErrorUnauthorized(format!(
"Unsupported JWT algorithm: '{:?}'. Supported ones are: ['HS512']",
token.header().algorithm
)));
}
let jwt_hash = {
let mut s = DefaultHasher::new();
token_str.hash(&mut s);
s.finish()
};
if state.jwt_blacklist.read().unwrap().contains(&jwt_hash) {
return Err(ErrorUnauthorized("JWT was logged out"));
}
let is_admin = token.claims().groups.contains("lldap_admin");
Ok(ValidationResults {
user: token.claims().user.clone(),
is_admin,
})
}
pub fn configure_server<Backend>(cfg: &mut web::ServiceConfig)
where
Backend: TcpBackendHandler + LoginHandler + OpaqueHandler + BackendHandler + 'static,
{
cfg.service(web::resource("").route(web::post().to(post_authorize::<Backend>)))
.service(
web::resource("/opaque/login/start")
.route(web::post().to(opaque_login_start::<Backend>)),
)
.service(
web::resource("/opaque/login/finish")
.route(web::post().to(opaque_login_finish::<Backend>)),
)
.service(
web::resource("/opaque/register/start")
.route(web::post().to(opaque_register_start::<Backend>)),
)
.service(
web::resource("/opaque/register/finish")
.route(web::post().to(opaque_register_finish::<Backend>)),
)
.service(web::resource("/refresh").route(web::get().to(get_refresh::<Backend>)))
.service(web::resource("/logout").route(web::get().to(get_logout::<Backend>)));
}

50
server/src/infra/cli.rs Normal file
View File

@@ -0,0 +1,50 @@
use clap::Clap;
/// lldap is a lightweight LDAP server
#[derive(Debug, Clap, Clone)]
#[clap(version = "0.1", author = "The LLDAP team")]
pub struct CLIOpts {
/// Export
#[clap(subcommand)]
pub command: Command,
}
#[derive(Debug, Clap, Clone)]
pub enum Command {
/// Export the GraphQL schema to *.graphql.
#[clap(name = "export_graphql_schema")]
ExportGraphQLSchema(ExportGraphQLSchemaOpts),
/// Run the LDAP and GraphQL server.
#[clap(name = "run")]
Run(RunOpts),
}
#[derive(Debug, Clap, Clone)]
pub struct RunOpts {
/// Change config file name
#[clap(short, long, default_value = "lldap_config.toml")]
pub config_file: String,
/// Change ldap port. Default: 389
#[clap(long)]
pub ldap_port: Option<u16>,
/// Change ldap ssl port. Default: 636
#[clap(long)]
pub ldaps_port: Option<u16>,
/// Set verbose logging
#[clap(short, long)]
pub verbose: bool,
}
#[derive(Debug, Clap, Clone)]
pub struct ExportGraphQLSchemaOpts {
/// Output to a file. If not specified, the config is printed to the standard output.
#[clap(short, long)]
pub output_file: Option<String>,
}
pub fn init() -> CLIOpts {
CLIOpts::parse()
}

View File

@@ -0,0 +1,121 @@
use anyhow::{Context, Result};
use figment::{
providers::{Env, Format, Serialized, Toml},
Figment,
};
use lldap_auth::opaque::{server::ServerSetup, KeyPair};
use serde::{Deserialize, Serialize};
use crate::infra::cli::RunOpts;
#[derive(Clone, Debug, Deserialize, Serialize, derive_builder::Builder)]
#[builder(
pattern = "owned",
default = "Configuration::default()",
build_fn(name = "private_build", validate = "Self::validate")
)]
pub struct Configuration {
pub ldap_port: u16,
pub ldaps_port: u16,
pub http_port: u16,
pub jwt_secret: String,
pub ldap_base_dn: String,
pub ldap_user_dn: String,
pub ldap_user_pass: String,
pub database_url: String,
pub verbose: bool,
pub key_file: String,
#[serde(skip)]
#[builder(field(private), setter(strip_option))]
server_setup: Option<ServerSetup>,
}
impl ConfigurationBuilder {
#[cfg(test)]
pub fn build(self) -> Result<Configuration> {
let server_setup = get_server_setup(self.key_file.as_deref().unwrap_or("server_key"))?;
Ok(self.server_setup(server_setup).private_build()?)
}
fn validate(&self) -> Result<(), String> {
if self.server_setup.is_none() {
Err("Don't use `private_build`, use `build` instead".to_string())
} else {
Ok(())
}
}
}
impl Configuration {
pub fn get_server_setup(&self) -> &ServerSetup {
self.server_setup.as_ref().unwrap()
}
pub fn get_server_keys(&self) -> &KeyPair {
self.get_server_setup().keypair()
}
fn merge_with_cli(mut self: Configuration, cli_opts: RunOpts) -> Configuration {
if cli_opts.verbose {
self.verbose = true;
}
if let Some(port) = cli_opts.ldap_port {
self.ldap_port = port;
}
if let Some(port) = cli_opts.ldaps_port {
self.ldaps_port = port;
}
self
}
pub(super) fn default() -> Self {
Configuration {
ldap_port: 3890,
ldaps_port: 6360,
http_port: 17170,
jwt_secret: String::from("secretjwtsecret"),
ldap_base_dn: String::from("dc=example,dc=com"),
// cn=admin,dc=example,dc=com
ldap_user_dn: String::from("admin"),
ldap_user_pass: String::from("password"),
database_url: String::from("sqlite://users.db?mode=rwc"),
verbose: false,
key_file: String::from("server_key"),
server_setup: None,
}
}
}
fn get_server_setup(file_path: &str) -> Result<ServerSetup> {
use std::path::Path;
let path = Path::new(file_path);
if path.exists() {
let bytes =
std::fs::read(file_path).context(format!("Could not read key file `{}`", file_path))?;
Ok(ServerSetup::deserialize(&bytes)?)
} else {
let mut rng = rand::rngs::OsRng;
let server_setup = ServerSetup::new(&mut rng);
std::fs::write(path, server_setup.serialize()).context(format!(
"Could not write the generated server setup to file `{}`",
file_path,
))?;
Ok(server_setup)
}
}
pub fn init(cli_opts: RunOpts) -> Result<Configuration> {
let config_file = cli_opts.config_file.clone();
let config: Configuration = Figment::from(Serialized::defaults(Configuration::default()))
.merge(Toml::file(config_file))
.merge(Env::prefixed("LLDAP_"))
.extract()?;
let mut config = config.merge_with_cli(cli_opts);
config.server_setup = Some(get_server_setup(&config.key_file)?);
Ok(config)
}

View File

@@ -0,0 +1,82 @@
use crate::{
domain::sql_tables::{DbQueryBuilder, Pool},
infra::jwt_sql_tables::{JwtRefreshStorage, JwtStorage},
};
use actix::prelude::*;
use chrono::Local;
use cron::Schedule;
use sea_query::{Expr, Query};
use std::{str::FromStr, time::Duration};
// Define actor
pub struct Scheduler {
schedule: Schedule,
sql_pool: Pool,
}
// Provide Actor implementation for our actor
impl Actor for Scheduler {
type Context = Context<Self>;
fn started(&mut self, context: &mut Context<Self>) {
log::info!("DB Cleanup Cron started");
context.run_later(self.duration_until_next(), move |this, ctx| {
this.schedule_task(ctx)
});
}
fn stopped(&mut self, _ctx: &mut Context<Self>) {
log::info!("DB Cleanup stopped");
}
}
impl Scheduler {
pub fn new(cron_expression: &str, sql_pool: Pool) -> Self {
let schedule = Schedule::from_str(cron_expression).unwrap();
Self { schedule, sql_pool }
}
fn schedule_task(&self, ctx: &mut Context<Self>) {
log::info!("Cleaning DB");
let future = actix::fut::wrap_future::<_, Self>(Self::cleanup_db(self.sql_pool.clone()));
ctx.spawn(future);
ctx.run_later(self.duration_until_next(), move |this, ctx| {
this.schedule_task(ctx)
});
}
async fn cleanup_db(sql_pool: Pool) {
if let Err(e) = sqlx::query(
&Query::delete()
.from_table(JwtRefreshStorage::Table)
.and_where(Expr::col(JwtRefreshStorage::ExpiryDate).lt(Local::now().naive_utc()))
.to_string(DbQueryBuilder {}),
)
.execute(&sql_pool)
.await
{
log::error!("DB error while cleaning up JWT refresh tokens: {}", e);
};
if let Err(e) = sqlx::query(
&Query::delete()
.from_table(JwtStorage::Table)
.and_where(Expr::col(JwtStorage::ExpiryDate).lt(Local::now().naive_utc()))
.to_string(DbQueryBuilder {}),
)
.execute(&sql_pool)
.await
{
log::error!("DB error while cleaning up JWT storage: {}", e);
};
log::info!("DB cleaned!");
}
fn duration_until_next(&self) -> Duration {
let now = Local::now();
let next = self.schedule.upcoming(Local).next().unwrap();
let duration_until = next.signed_duration_since(now);
duration_until.to_std().unwrap()
}
}

View File

@@ -0,0 +1,100 @@
use crate::{
domain::handler::BackendHandler,
infra::{
auth_service::{check_if_token_is_valid, ValidationResults},
cli::ExportGraphQLSchemaOpts,
tcp_server::AppState,
},
};
use actix_web::{web, Error, HttpResponse};
use actix_web_httpauth::extractors::bearer::BearerAuth;
use juniper::{EmptySubscription, RootNode};
use juniper_actix::{graphiql_handler, graphql_handler, playground_handler};
use super::{mutation::Mutation, query::Query};
pub struct Context<Handler: BackendHandler> {
pub handler: Box<Handler>,
pub validation_result: ValidationResults,
}
impl<Handler: BackendHandler> juniper::Context for Context<Handler> {}
type Schema<Handler> =
RootNode<'static, Query<Handler>, Mutation<Handler>, EmptySubscription<Context<Handler>>>;
fn schema<Handler: BackendHandler + Sync>() -> Schema<Handler> {
Schema::new(
Query::<Handler>::new(),
Mutation::<Handler>::new(),
EmptySubscription::<Context<Handler>>::new(),
)
}
pub fn export_schema(opts: ExportGraphQLSchemaOpts) -> anyhow::Result<()> {
use crate::domain::sql_backend_handler::SqlBackendHandler;
use anyhow::Context;
let output = schema::<SqlBackendHandler>().as_schema_language();
match opts.output_file {
None => println!("{}", output),
Some(path) => {
use std::fs::File;
use std::io::prelude::*;
use std::path::Path;
let path = Path::new(&path);
let mut file =
File::create(&path).context(format!("unable to open '{}'", path.display()))?;
file.write_all(output.as_bytes())
.context(format!("unable to write in '{}'", path.display()))?;
}
}
Ok(())
}
async fn graphiql_route() -> Result<HttpResponse, Error> {
graphiql_handler("/api/graphql", None).await
}
async fn playground_route() -> Result<HttpResponse, Error> {
playground_handler("/api/graphql", None).await
}
async fn graphql_route<Handler: BackendHandler + Sync>(
req: actix_web::HttpRequest,
mut payload: actix_web::web::Payload,
data: web::Data<AppState<Handler>>,
) -> Result<HttpResponse, Error> {
use actix_web::FromRequest;
let bearer = BearerAuth::from_request(&req, &mut payload.0).await?;
let validation_result = check_if_token_is_valid(&data, bearer.token())?;
let context = Context::<Handler> {
handler: Box::new(data.backend_handler.clone()),
validation_result,
};
graphql_handler(&schema(), &context, req, payload).await
}
pub fn configure_endpoint<Backend>(cfg: &mut web::ServiceConfig)
where
Backend: BackendHandler + Sync + 'static,
{
let json_config = web::JsonConfig::default()
.limit(4096)
.error_handler(|err, _req| {
// create custom error response
log::error!("API error: {}", err);
let msg = err.to_string();
actix_web::error::InternalError::from_response(
err,
HttpResponse::BadRequest().body(msg),
)
.into()
});
cfg.app_data(json_config);
cfg.service(
web::resource("/graphql")
.route(web::post().to(graphql_route::<Backend>))
.route(web::get().to(graphql_route::<Backend>)),
);
cfg.service(web::resource("/graphql/playground").route(web::get().to(playground_route)));
cfg.service(web::resource("/graphql/graphiql").route(web::get().to(graphiql_route)));
}

View File

@@ -0,0 +1,3 @@
pub mod api;
pub mod mutation;
pub mod query;

View File

@@ -0,0 +1,55 @@
use crate::domain::handler::{BackendHandler, CreateUserRequest};
use juniper::{graphql_object, FieldResult, GraphQLInputObject};
use super::api::Context;
#[derive(PartialEq, Eq, Debug)]
/// The top-level GraphQL mutation type.
pub struct Mutation<Handler: BackendHandler> {
_phantom: std::marker::PhantomData<Box<Handler>>,
}
impl<Handler: BackendHandler> Mutation<Handler> {
pub fn new() -> Self {
Self {
_phantom: std::marker::PhantomData,
}
}
}
#[derive(PartialEq, Eq, Debug, GraphQLInputObject)]
/// The details required to create a user.
pub struct UserInput {
id: String,
email: String,
display_name: Option<String>,
first_name: Option<String>,
last_name: Option<String>,
}
#[graphql_object(context = Context<Handler>)]
impl<Handler: BackendHandler + Sync> Mutation<Handler> {
async fn create_user(
context: &Context<Handler>,
user: UserInput,
) -> FieldResult<super::query::User<Handler>> {
if !context.validation_result.is_admin {
return Err("Unauthorized user creation".into());
}
context
.handler
.create_user(CreateUserRequest {
user_id: user.id.clone(),
email: user.email,
display_name: user.display_name,
first_name: user.first_name,
last_name: user.last_name,
})
.await?;
Ok(context
.handler
.get_user_details(&user.id)
.await
.map(Into::into)?)
}
}

View File

@@ -0,0 +1,348 @@
use crate::domain::handler::BackendHandler;
use juniper::{graphql_object, FieldResult, GraphQLInputObject};
use serde::{Deserialize, Serialize};
use std::convert::TryInto;
type DomainRequestFilter = crate::domain::handler::RequestFilter;
type DomainUser = crate::domain::handler::User;
use super::api::Context;
#[derive(PartialEq, Eq, Debug, GraphQLInputObject)]
/// A filter for requests, specifying a boolean expression based on field constraints. Only one of
/// the fields can be set at a time.
pub struct RequestFilter {
any: Option<Vec<RequestFilter>>,
all: Option<Vec<RequestFilter>>,
not: Option<Box<RequestFilter>>,
eq: Option<EqualityConstraint>,
}
impl TryInto<DomainRequestFilter> for RequestFilter {
type Error = String;
fn try_into(self) -> Result<DomainRequestFilter, Self::Error> {
let mut field_count = 0;
if self.any.is_some() {
field_count += 1;
}
if self.all.is_some() {
field_count += 1;
}
if self.not.is_some() {
field_count += 1;
}
if self.eq.is_some() {
field_count += 1;
}
if field_count == 0 {
return Err("No field specified in request filter".to_string());
}
if field_count > 1 {
return Err("Multiple fields specified in request filter".to_string());
}
if let Some(e) = self.eq {
return Ok(DomainRequestFilter::Equality(e.field, e.value));
}
if let Some(c) = self.any {
return Ok(DomainRequestFilter::Or(
c.into_iter()
.map(TryInto::try_into)
.collect::<Result<Vec<_>, String>>()?,
));
}
if let Some(c) = self.all {
return Ok(DomainRequestFilter::And(
c.into_iter()
.map(TryInto::try_into)
.collect::<Result<Vec<_>, String>>()?,
));
}
if let Some(c) = self.not {
return Ok(DomainRequestFilter::Not(Box::new((*c).try_into()?)));
}
unreachable!();
}
}
#[derive(PartialEq, Eq, Debug, GraphQLInputObject)]
pub struct EqualityConstraint {
field: String,
value: String,
}
#[derive(PartialEq, Eq, Debug)]
/// The top-level GraphQL query type.
pub struct Query<Handler: BackendHandler> {
_phantom: std::marker::PhantomData<Box<Handler>>,
}
impl<Handler: BackendHandler> Query<Handler> {
pub fn new() -> Self {
Self {
_phantom: std::marker::PhantomData,
}
}
}
#[graphql_object(context = Context<Handler>)]
impl<Handler: BackendHandler + Sync> Query<Handler> {
fn api_version() -> &'static str {
"1.0"
}
pub async fn user(context: &Context<Handler>, user_id: String) -> FieldResult<User<Handler>> {
if !context.validation_result.can_access(&user_id) {
return Err("Unauthorized access to user data".into());
}
Ok(context
.handler
.get_user_details(&user_id)
.await
.map(Into::into)?)
}
async fn users(
context: &Context<Handler>,
#[graphql(name = "where")] filters: Option<RequestFilter>,
) -> FieldResult<Vec<User<Handler>>> {
if !context.validation_result.is_admin {
return Err("Unauthorized access to user list".into());
}
Ok(context
.handler
.list_users(filters.map(TryInto::try_into).transpose()?)
.await
.map(|v| v.into_iter().map(Into::into).collect())?)
}
}
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
/// Represents a single user.
pub struct User<Handler: BackendHandler> {
user: DomainUser,
_phantom: std::marker::PhantomData<Box<Handler>>,
}
impl<Handler: BackendHandler> Default for User<Handler> {
fn default() -> Self {
Self {
user: DomainUser::default(),
_phantom: std::marker::PhantomData,
}
}
}
#[graphql_object(context = Context<Handler>)]
impl<Handler: BackendHandler + Sync> User<Handler> {
fn id(&self) -> &str {
&self.user.user_id
}
fn email(&self) -> &str {
&self.user.email
}
fn display_name(&self) -> Option<&String> {
self.user.display_name.as_ref()
}
fn first_name(&self) -> Option<&String> {
self.user.first_name.as_ref()
}
fn last_name(&self) -> Option<&String> {
self.user.last_name.as_ref()
}
fn creation_date(&self) -> chrono::DateTime<chrono::Utc> {
self.user.creation_date
}
/// The groups to which this user belongs.
async fn groups(&self, context: &Context<Handler>) -> FieldResult<Vec<Group<Handler>>> {
Ok(context
.handler
.get_user_groups(&self.user.user_id)
.await
.map(|set| set.into_iter().map(Into::into).collect())?)
}
}
impl<Handler: BackendHandler> From<DomainUser> for User<Handler> {
fn from(user: DomainUser) -> Self {
Self {
user,
_phantom: std::marker::PhantomData,
}
}
}
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
/// Represents a single group.
pub struct Group<Handler: BackendHandler> {
group_id: String,
_phantom: std::marker::PhantomData<Box<Handler>>,
}
#[graphql_object(context = Context<Handler>)]
impl<Handler: BackendHandler + Sync> Group<Handler> {
fn id(&self) -> String {
self.group_id.clone()
}
/// The groups to which this user belongs.
async fn users(&self, context: &Context<Handler>) -> FieldResult<Vec<User<Handler>>> {
if !context.validation_result.is_admin {
return Err("Unauthorized access to group data".into());
}
unimplemented!()
}
}
impl<Handler: BackendHandler> From<String> for Group<Handler> {
fn from(group_id: String) -> Self {
Self {
group_id,
_phantom: std::marker::PhantomData,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{domain::handler::MockTestBackendHandler, infra::auth_service::ValidationResults};
use juniper::{
execute, graphql_value, DefaultScalarValue, EmptyMutation, EmptySubscription, GraphQLType,
RootNode, Variables,
};
use mockall::predicate::eq;
use std::collections::HashSet;
fn schema<'q, C, Q>(query_root: Q) -> RootNode<'q, Q, EmptyMutation<C>, EmptySubscription<C>>
where
Q: GraphQLType<DefaultScalarValue, Context = C, TypeInfo = ()> + 'q,
{
RootNode::new(
query_root,
EmptyMutation::<C>::new(),
EmptySubscription::<C>::new(),
)
}
#[tokio::test]
async fn get_user_by_id() {
const QUERY: &str = r#"{
user(userId: "bob") {
id
email
groups {
id
}
}
}"#;
let mut mock = MockTestBackendHandler::new();
mock.expect_get_user_details()
.with(eq("bob"))
.return_once(|_| {
Ok(DomainUser {
user_id: "bob".to_string(),
email: "bob@bobbers.on".to_string(),
..Default::default()
})
});
let mut groups = HashSet::<String>::new();
groups.insert("Bobbersons".to_string());
mock.expect_get_user_groups()
.with(eq("bob"))
.return_once(|_| Ok(groups));
let context = Context::<MockTestBackendHandler> {
handler: Box::new(mock),
validation_result: ValidationResults::admin(),
};
let schema = schema(Query::<MockTestBackendHandler>::new());
assert_eq!(
execute(QUERY, None, &schema, &Variables::new(), &context).await,
Ok((
graphql_value!(
{
"user": {
"id": "bob",
"email": "bob@bobbers.on",
"groups": [{"id": "Bobbersons"}]
}
}),
vec![]
))
);
}
#[tokio::test]
async fn list_users() {
const QUERY: &str = r#"{
users(filters: {
any: [
{eq: {
field: "id"
value: "bob"
}},
{eq: {
field: "email"
value: "robert@bobbers.on"
}}
]}) {
id
email
}
}"#;
let mut mock = MockTestBackendHandler::new();
use crate::domain::handler::RequestFilter;
mock.expect_list_users()
.with(eq(Some(RequestFilter::Or(vec![
RequestFilter::Equality("id".to_string(), "bob".to_string()),
RequestFilter::Equality("email".to_string(), "robert@bobbers.on".to_string()),
]))))
.return_once(|_| {
Ok(vec![
DomainUser {
user_id: "bob".to_string(),
email: "bob@bobbers.on".to_string(),
..Default::default()
},
DomainUser {
user_id: "robert".to_string(),
email: "robert@bobbers.on".to_string(),
..Default::default()
},
])
});
let context = Context::<MockTestBackendHandler> {
handler: Box::new(mock),
validation_result: ValidationResults::admin(),
};
let schema = schema(Query::<MockTestBackendHandler>::new());
assert_eq!(
execute(QUERY, None, &schema, &Variables::new(), &context).await,
Ok((
graphql_value!(
{
"users": [
{
"id": "bob",
"email": "bob@bobbers.on"
},
{
"id": "robert",
"email": "robert@bobbers.on"
},
]
}),
vec![]
))
);
}
}

View File

@@ -0,0 +1,99 @@
use sea_query::*;
pub use crate::domain::sql_tables::*;
/// Contains the refresh tokens for a given user.
#[derive(Iden)]
pub enum JwtRefreshStorage {
Table,
RefreshTokenHash,
UserId,
ExpiryDate,
}
/// Contains the blacklisted JWT that haven't expired yet.
#[derive(Iden)]
pub enum JwtStorage {
Table,
JwtHash,
UserId,
ExpiryDate,
Blacklisted,
}
/// This needs to be initialized after the domain tables are.
pub async fn init_table(pool: &Pool) -> sqlx::Result<()> {
sqlx::query(
&Table::create()
.table(JwtRefreshStorage::Table)
.if_not_exists()
.col(
ColumnDef::new(JwtRefreshStorage::RefreshTokenHash)
.big_integer()
.not_null()
.primary_key(),
)
.col(
ColumnDef::new(JwtRefreshStorage::UserId)
.string_len(255)
.not_null(),
)
.col(
ColumnDef::new(JwtRefreshStorage::ExpiryDate)
.date_time()
.not_null(),
)
.foreign_key(
ForeignKey::create()
.name("JwtRefreshStorageUserForeignKey")
.table(JwtRefreshStorage::Table, Users::Table)
.col(JwtRefreshStorage::UserId, Users::UserId)
.on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade),
)
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await?;
sqlx::query(
&Table::create()
.table(JwtStorage::Table)
.if_not_exists()
.col(
ColumnDef::new(JwtStorage::JwtHash)
.big_integer()
.not_null()
.primary_key(),
)
.col(
ColumnDef::new(JwtStorage::UserId)
.string_len(255)
.not_null(),
)
.col(
ColumnDef::new(JwtStorage::ExpiryDate)
.date_time()
.not_null(),
)
.col(
ColumnDef::new(JwtStorage::Blacklisted)
.boolean()
.default(false)
.not_null(),
)
.foreign_key(
ForeignKey::create()
.name("JwtStorageUserForeignKey")
.table(JwtStorage::Table, Users::Table)
.col(JwtStorage::UserId, Users::UserId)
.on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade),
)
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await?;
Ok(())
}

View File

@@ -0,0 +1,632 @@
use crate::domain::handler::{BackendHandler, LoginHandler, RequestFilter, User};
use anyhow::{bail, Result};
use ldap3_server::simple::*;
fn make_dn_pair<I>(mut iter: I) -> Result<(String, String)>
where
I: Iterator<Item = String>,
{
let pair = (
iter.next()
.ok_or_else(|| anyhow::Error::msg("Empty DN element"))?,
iter.next()
.ok_or_else(|| anyhow::Error::msg("Missing DN value"))?,
);
if let Some(e) = iter.next() {
bail!(
r#"Too many elements in distinguished name: "{:?}", "{:?}", "{:?}""#,
pair.0,
pair.1,
e
)
}
Ok(pair)
}
fn parse_distinguished_name(dn: &str) -> Result<Vec<(String, String)>> {
dn.split(',')
.map(|s| make_dn_pair(s.split('=').map(String::from)))
.collect()
}
fn get_user_id_from_distinguished_name(
dn: &str,
base_tree: &[(String, String)],
base_dn_str: &str,
ldap_user_dn: &str,
) -> Result<String> {
let parts = parse_distinguished_name(dn)?;
if !is_subtree(&parts, base_tree) {
bail!("Not a subtree of the base tree");
}
if parts.len() == base_tree.len() + 1 {
if dn != ldap_user_dn {
bail!(r#"Wrong admin DN. Expected: "{}""#, ldap_user_dn);
}
Ok(parts[0].1.to_string())
} else if parts.len() == base_tree.len() + 2 {
if parts[1].0 != "ou" || parts[1].1 != "people" || parts[0].0 != "cn" {
bail!(
r#"Unexpected user DN format. Expected: "cn=username,ou=people,{}""#,
base_dn_str
);
}
Ok(parts[0].1.to_string())
} else {
bail!(
r#"Unexpected user DN format. Expected: "cn=username,ou=people,{}""#,
base_dn_str
);
}
}
fn get_attribute(user: &User, attribute: &str) -> Result<Vec<String>> {
match attribute {
"objectClass" => Ok(vec![
"inetOrgPerson".to_string(),
"posixAccount".to_string(),
"mailAccount".to_string(),
]),
"uid" => Ok(vec![user.user_id.clone()]),
"mail" => Ok(vec![user.email.clone()]),
"givenName" => Ok(vec![user.first_name.clone().unwrap_or_default()]),
"sn" => Ok(vec![user.last_name.clone().unwrap_or_default()]),
"cn" => Ok(vec![user
.display_name
.clone()
.unwrap_or_else(|| user.user_id.clone())]),
_ => bail!("Unsupported attribute: {}", attribute),
}
}
fn make_ldap_search_result_entry(
user: User,
base_dn_str: &str,
attributes: &[String],
) -> Result<LdapSearchResultEntry> {
Ok(LdapSearchResultEntry {
dn: format!("cn={},{}", user.user_id, base_dn_str),
attributes: attributes
.iter()
.map(|a| {
Ok(LdapPartialAttribute {
atype: a.to_string(),
vals: get_attribute(&user, a)?,
})
})
.collect::<Result<Vec<LdapPartialAttribute>>>()?,
})
}
fn is_subtree(subtree: &[(String, String)], base_tree: &[(String, String)]) -> bool {
if subtree.len() < base_tree.len() {
return false;
}
let size_diff = subtree.len() - base_tree.len();
for i in 0..base_tree.len() {
if subtree[size_diff + i] != base_tree[i] {
return false;
}
}
true
}
fn map_field(field: &str) -> Result<String> {
Ok(if field == "uid" {
"user_id".to_string()
} else if field == "mail" {
"email".to_string()
} else if field == "cn" {
"display_name".to_string()
} else if field == "givenName" {
"first_name".to_string()
} else if field == "sn" {
"last_name".to_string()
} else if field == "avatar" {
"avatar".to_string()
} else if field == "creationDate" {
"creation_date".to_string()
} else {
bail!("Unknown field: {}", field);
})
}
fn convert_filter(filter: &LdapFilter) -> Result<RequestFilter> {
match filter {
LdapFilter::And(filters) => Ok(RequestFilter::And(
filters.iter().map(convert_filter).collect::<Result<_>>()?,
)),
LdapFilter::Or(filters) => Ok(RequestFilter::Or(
filters.iter().map(convert_filter).collect::<Result<_>>()?,
)),
LdapFilter::Not(filter) => Ok(RequestFilter::Not(Box::new(convert_filter(&*filter)?))),
LdapFilter::Equality(field, value) => {
Ok(RequestFilter::Equality(map_field(field)?, value.clone()))
}
_ => bail!("Unsupported filter"),
}
}
pub struct LdapHandler<Backend: BackendHandler + LoginHandler> {
dn: String,
backend_handler: Backend,
pub base_dn: Vec<(String, String)>,
base_dn_str: String,
ldap_user_dn: String,
}
impl<Backend: BackendHandler + LoginHandler> LdapHandler<Backend> {
pub fn new(backend_handler: Backend, ldap_base_dn: String, ldap_user_dn: String) -> Self {
Self {
dn: "Unauthenticated".to_string(),
backend_handler,
base_dn: parse_distinguished_name(&ldap_base_dn).unwrap_or_else(|_| {
panic!(
"Invalid value for ldap_base_dn in configuration: {}",
ldap_base_dn
)
}),
ldap_user_dn: format!("cn={},{}", ldap_user_dn, &ldap_base_dn),
base_dn_str: ldap_base_dn,
}
}
pub async fn do_bind(&mut self, sbr: &SimpleBindRequest) -> LdapMsg {
let user_id = match get_user_id_from_distinguished_name(
&sbr.dn,
&self.base_dn,
&self.base_dn_str,
&self.ldap_user_dn,
) {
Ok(s) => s,
Err(e) => return sbr.gen_error(LdapResultCode::NamingViolation, e.to_string()),
};
match self
.backend_handler
.bind(crate::domain::handler::BindRequest {
name: user_id,
password: sbr.pw.clone(),
})
.await
{
Ok(()) => {
self.dn = sbr.dn.clone();
sbr.gen_success()
}
Err(_) => sbr.gen_invalid_cred(),
}
}
pub async fn do_search(&mut self, lsr: &SearchRequest) -> Vec<LdapMsg> {
if self.dn != self.ldap_user_dn {
return vec![lsr.gen_error(
LdapResultCode::InsufficentAccessRights,
r#"Current user is not allowed to query LDAP"#.to_string(),
)];
}
let dn_parts = match parse_distinguished_name(&lsr.base) {
Ok(dn) => dn,
Err(_) => {
return vec![lsr.gen_error(
LdapResultCode::OperationsError,
format!(r#"Could not parse base DN: "{}""#, lsr.base),
)]
}
};
if !is_subtree(&dn_parts, &self.base_dn) {
// Search path is not in our tree, just return an empty success.
return vec![lsr.gen_success()];
}
let filters = match convert_filter(&lsr.filter) {
Ok(f) => Some(f),
Err(_) => {
return vec![lsr.gen_error(
LdapResultCode::UnwillingToPerform,
"Unsupported filter".to_string(),
)]
}
};
let users = match self.backend_handler.list_users(filters).await {
Ok(users) => users,
Err(e) => {
return vec![lsr.gen_error(
LdapResultCode::Other,
format!(r#"Error during search for "{}": {}"#, lsr.base, e),
)]
}
};
users
.into_iter()
.map(|u| make_ldap_search_result_entry(u, &self.base_dn_str, &lsr.attrs))
.map(|entry| Ok(lsr.gen_result_entry(entry?)))
// If the processing succeeds, add a success message at the end.
.chain(std::iter::once(Ok(lsr.gen_success())))
.collect::<Result<Vec<_>>>()
.unwrap_or_else(|e| vec![lsr.gen_error(LdapResultCode::NoSuchAttribute, e.to_string())])
}
pub fn do_whoami(&mut self, wr: &WhoamiRequest) -> LdapMsg {
if self.dn == "Unauthenticated" {
wr.gen_operror("Unauthenticated")
} else {
wr.gen_success(format!("dn: {}", self.dn).as_str())
}
}
pub async fn handle_ldap_message(&mut self, server_op: ServerOps) -> Option<Vec<LdapMsg>> {
let result = match server_op {
ServerOps::SimpleBind(sbr) => vec![self.do_bind(&sbr).await],
ServerOps::Search(sr) => self.do_search(&sr).await,
ServerOps::Unbind(_) => {
// No need to notify on unbind (per rfc4511)
return None;
}
ServerOps::Whoami(wr) => vec![self.do_whoami(&wr)],
};
Some(result)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::domain::handler::BindRequest;
use crate::domain::handler::MockTestBackendHandler;
use mockall::predicate::eq;
use tokio;
async fn setup_bound_handler(
mut mock: MockTestBackendHandler,
) -> LdapHandler<MockTestBackendHandler> {
mock.expect_bind()
.with(eq(BindRequest {
name: "test".to_string(),
password: "pass".to_string(),
}))
.return_once(|_| Ok(()));
let mut ldap_handler =
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string());
let request = SimpleBindRequest {
msgid: 1,
dn: "cn=test,dc=example,dc=com".to_string(),
pw: "pass".to_string(),
};
ldap_handler.do_bind(&request).await;
ldap_handler
}
#[tokio::test]
async fn test_bind() {
let mut mock = MockTestBackendHandler::new();
mock.expect_bind()
.with(eq(crate::domain::handler::BindRequest {
name: "bob".to_string(),
password: "pass".to_string(),
}))
.times(1)
.return_once(|_| Ok(()));
let mut ldap_handler =
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string());
let request = WhoamiRequest { msgid: 1 };
assert_eq!(
ldap_handler.do_whoami(&request),
request.gen_operror("Unauthenticated")
);
let request = SimpleBindRequest {
msgid: 2,
dn: "cn=bob,ou=people,dc=example,dc=com".to_string(),
pw: "pass".to_string(),
};
assert_eq!(ldap_handler.do_bind(&request).await, request.gen_success());
let request = WhoamiRequest { msgid: 3 };
assert_eq!(
ldap_handler.do_whoami(&request),
request.gen_success("dn: cn=bob,ou=people,dc=example,dc=com")
);
}
#[tokio::test]
async fn test_admin_bind() {
let mut mock = MockTestBackendHandler::new();
mock.expect_bind()
.with(eq(crate::domain::handler::BindRequest {
name: "test".to_string(),
password: "pass".to_string(),
}))
.times(1)
.return_once(|_| Ok(()));
let mut ldap_handler =
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string());
let request = WhoamiRequest { msgid: 1 };
assert_eq!(
ldap_handler.do_whoami(&request),
request.gen_operror("Unauthenticated")
);
let request = SimpleBindRequest {
msgid: 2,
dn: "cn=test,dc=example,dc=com".to_string(),
pw: "pass".to_string(),
};
assert_eq!(ldap_handler.do_bind(&request).await, request.gen_success());
let request = WhoamiRequest { msgid: 3 };
assert_eq!(
ldap_handler.do_whoami(&request),
request.gen_success("dn: cn=test,dc=example,dc=com")
);
}
#[tokio::test]
async fn test_bind_invalid_credentials() {
let mut mock = MockTestBackendHandler::new();
mock.expect_bind()
.with(eq(crate::domain::handler::BindRequest {
name: "test".to_string(),
password: "pass".to_string(),
}))
.times(1)
.return_once(|_| Ok(()));
let mut ldap_handler =
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "admin".to_string());
let request = WhoamiRequest { msgid: 1 };
assert_eq!(
ldap_handler.do_whoami(&request),
request.gen_operror("Unauthenticated")
);
let request = SimpleBindRequest {
msgid: 2,
dn: "cn=test,ou=people,dc=example,dc=com".to_string(),
pw: "pass".to_string(),
};
assert_eq!(ldap_handler.do_bind(&request).await, request.gen_success());
let request = WhoamiRequest { msgid: 3 };
assert_eq!(
ldap_handler.do_whoami(&request),
request.gen_success("dn: cn=test,ou=people,dc=example,dc=com")
);
let request = SearchRequest {
msgid: 2,
base: "ou=people,dc=example,dc=com".to_string(),
scope: LdapSearchScope::Base,
filter: LdapFilter::And(vec![]),
attrs: vec![],
};
assert_eq!(
ldap_handler.do_search(&request).await,
vec![request.gen_error(
LdapResultCode::InsufficentAccessRights,
r#"Current user is not allowed to query LDAP"#.to_string()
)]
);
}
#[tokio::test]
async fn test_bind_invalid_dn() {
let mock = MockTestBackendHandler::new();
let mut ldap_handler =
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "admin".to_string());
let request = SimpleBindRequest {
msgid: 2,
dn: "cn=bob,dc=example,dc=com".to_string(),
pw: "pass".to_string(),
};
assert_eq!(
ldap_handler.do_bind(&request).await,
request.gen_error(
LdapResultCode::NamingViolation,
r#"Wrong admin DN. Expected: "cn=admin,dc=example,dc=com""#.to_string()
)
);
let request = SimpleBindRequest {
msgid: 2,
dn: "cn=bob,ou=groups,dc=example,dc=com".to_string(),
pw: "pass".to_string(),
};
assert_eq!(
ldap_handler.do_bind(&request).await,
request.gen_error(
LdapResultCode::NamingViolation,
r#"Unexpected user DN format. Expected: "cn=username,ou=people,dc=example,dc=com""#
.to_string()
)
);
}
#[test]
fn test_is_subtree() {
let subtree1 = &[
("ou".to_string(), "people".to_string()),
("dc".to_string(), "example".to_string()),
("dc".to_string(), "com".to_string()),
];
let root = &[
("dc".to_string(), "example".to_string()),
("dc".to_string(), "com".to_string()),
];
assert!(is_subtree(subtree1, root));
assert!(!is_subtree(&[], root));
}
#[test]
fn test_parse_distinguished_name() {
let parsed_dn = &[
("ou".to_string(), "people".to_string()),
("dc".to_string(), "example".to_string()),
("dc".to_string(), "com".to_string()),
];
assert_eq!(
parse_distinguished_name("ou=people,dc=example,dc=com").expect("parsing failed"),
parsed_dn
);
}
#[tokio::test]
async fn test_search() {
let mut mock = MockTestBackendHandler::new();
mock.expect_list_users().times(1).return_once(|_| {
Ok(vec![
User {
user_id: "bob_1".to_string(),
email: "bob@bobmail.bob".to_string(),
display_name: Some("Bôb Böbberson".to_string()),
first_name: Some("Bôb".to_string()),
last_name: Some("Böbberson".to_string()),
..Default::default()
},
User {
user_id: "jim".to_string(),
email: "jim@cricket.jim".to_string(),
display_name: Some("Jimminy Cricket".to_string()),
first_name: Some("Jim".to_string()),
last_name: Some("Cricket".to_string()),
..Default::default()
},
])
});
let mut ldap_handler = setup_bound_handler(mock).await;
let request = SearchRequest {
msgid: 2,
base: "ou=people,dc=example,dc=com".to_string(),
scope: LdapSearchScope::Base,
filter: LdapFilter::And(vec![]),
attrs: vec![
"objectClass".to_string(),
"uid".to_string(),
"mail".to_string(),
"givenName".to_string(),
"sn".to_string(),
"cn".to_string(),
],
};
assert_eq!(
ldap_handler.do_search(&request).await,
vec![
request.gen_result_entry(LdapSearchResultEntry {
dn: "cn=bob_1,dc=example,dc=com".to_string(),
attributes: vec![
LdapPartialAttribute {
atype: "objectClass".to_string(),
vals: vec![
"inetOrgPerson".to_string(),
"posixAccount".to_string(),
"mailAccount".to_string()
]
},
LdapPartialAttribute {
atype: "uid".to_string(),
vals: vec!["bob_1".to_string()]
},
LdapPartialAttribute {
atype: "mail".to_string(),
vals: vec!["bob@bobmail.bob".to_string()]
},
LdapPartialAttribute {
atype: "givenName".to_string(),
vals: vec!["Bôb".to_string()]
},
LdapPartialAttribute {
atype: "sn".to_string(),
vals: vec!["Böbberson".to_string()]
},
LdapPartialAttribute {
atype: "cn".to_string(),
vals: vec!["Bôb Böbberson".to_string()]
}
],
}),
request.gen_result_entry(LdapSearchResultEntry {
dn: "cn=jim,dc=example,dc=com".to_string(),
attributes: vec![
LdapPartialAttribute {
atype: "objectClass".to_string(),
vals: vec![
"inetOrgPerson".to_string(),
"posixAccount".to_string(),
"mailAccount".to_string()
]
},
LdapPartialAttribute {
atype: "uid".to_string(),
vals: vec!["jim".to_string()]
},
LdapPartialAttribute {
atype: "mail".to_string(),
vals: vec!["jim@cricket.jim".to_string()]
},
LdapPartialAttribute {
atype: "givenName".to_string(),
vals: vec!["Jim".to_string()]
},
LdapPartialAttribute {
atype: "sn".to_string(),
vals: vec!["Cricket".to_string()]
},
LdapPartialAttribute {
atype: "cn".to_string(),
vals: vec!["Jimminy Cricket".to_string()]
}
],
}),
request.gen_success()
]
);
}
#[tokio::test]
async fn test_search_filters() {
let mut mock = MockTestBackendHandler::new();
mock.expect_list_users()
.with(eq(Some(RequestFilter::And(vec![RequestFilter::Or(vec![
RequestFilter::Not(Box::new(RequestFilter::Equality(
"user_id".to_string(),
"bob".to_string(),
))),
])]))))
.times(1)
.return_once(|_| Ok(vec![]));
let mut ldap_handler = setup_bound_handler(mock).await;
let request = SearchRequest {
msgid: 2,
base: "ou=people,dc=example,dc=com".to_string(),
scope: LdapSearchScope::Base,
filter: LdapFilter::And(vec![LdapFilter::Or(vec![LdapFilter::Not(Box::new(
LdapFilter::Equality("uid".to_string(), "bob".to_string()),
))])]),
attrs: vec!["objectClass".to_string()],
};
assert_eq!(
ldap_handler.do_search(&request).await,
vec![request.gen_success()]
);
}
#[tokio::test]
async fn test_search_unsupported_filters() {
let mut ldap_handler = setup_bound_handler(MockTestBackendHandler::new()).await;
let request = SearchRequest {
msgid: 2,
base: "ou=people,dc=example,dc=com".to_string(),
scope: LdapSearchScope::Base,
filter: LdapFilter::Present("uid".to_string()),
attrs: vec!["objectClass".to_string()],
};
assert_eq!(
ldap_handler.do_search(&request).await,
vec![request.gen_error(
LdapResultCode::UnwillingToPerform,
"Unsupported filter".to_string()
)]
);
}
}

View File

@@ -0,0 +1,102 @@
use crate::domain::handler::{BackendHandler, LoginHandler};
use crate::infra::configuration::Configuration;
use crate::infra::ldap_handler::LdapHandler;
use actix_rt::net::TcpStream;
use actix_server::ServerBuilder;
use actix_service::{fn_service, ServiceFactoryExt};
use anyhow::{bail, Result};
use futures_util::future::ok;
use ldap3_server::simple::*;
use ldap3_server::LdapCodec;
use log::*;
use tokio::net::tcp::WriteHalf;
use tokio_util::codec::{FramedRead, FramedWrite};
async fn handle_incoming_message<Backend>(
msg: Result<LdapMsg, std::io::Error>,
resp: &mut FramedWrite<WriteHalf<'_>, LdapCodec>,
session: &mut LdapHandler<Backend>,
) -> Result<bool>
where
Backend: BackendHandler + LoginHandler,
{
use futures_util::SinkExt;
use std::convert::TryFrom;
let server_op = match msg.map_err(|_e| ()).and_then(ServerOps::try_from) {
Ok(a_value) => a_value,
Err(an_error) => {
let _err = resp
.send(DisconnectionNotice::gen(
LdapResultCode::Other,
"Internal Server Error",
))
.await;
let _err = resp.flush().await;
bail!("Internal server error: {:?}", an_error);
}
};
match session.handle_ldap_message(server_op).await {
None => return Ok(false),
Some(result) => {
for rmsg in result.into_iter() {
if let Err(e) = resp.send(rmsg).await {
bail!("Error while sending a response: {:?}", e);
}
}
if let Err(e) = resp.flush().await {
bail!("Error while flushing responses: {:?}", e);
}
}
}
Ok(true)
}
pub fn build_ldap_server<Backend>(
config: &Configuration,
backend_handler: Backend,
server_builder: ServerBuilder,
) -> Result<ServerBuilder>
where
Backend: BackendHandler + LoginHandler + 'static,
{
use futures_util::StreamExt;
let ldap_base_dn = config.ldap_base_dn.clone();
let ldap_user_dn = config.ldap_user_dn.clone();
Ok(
server_builder.bind("ldap", ("0.0.0.0", config.ldap_port), move || {
let backend_handler = backend_handler.clone();
let ldap_base_dn = ldap_base_dn.clone();
let ldap_user_dn = ldap_user_dn.clone();
fn_service(move |mut stream: TcpStream| {
let backend_handler = backend_handler.clone();
let ldap_base_dn = ldap_base_dn.clone();
let ldap_user_dn = ldap_user_dn.clone();
async move {
// Configure the codec etc.
let (r, w) = stream.split();
let mut requests = FramedRead::new(r, LdapCodec);
let mut resp = FramedWrite::new(w, LdapCodec);
let mut session = LdapHandler::new(backend_handler, ldap_base_dn, ldap_user_dn);
while let Some(msg) = requests.next().await {
if !handle_incoming_message(msg, &mut resp, &mut session).await? {
break;
}
}
Ok(stream)
}
})
.map_err(|err: anyhow::Error| error!("Service Error: {:?}", err))
// catch
.and_then(move |_| {
// finally
ok(())
})
})?,
)
}

View File

@@ -0,0 +1,25 @@
use crate::infra::configuration::Configuration;
use anyhow::Context;
use tracing::subscriber::set_global_default;
use tracing_log::LogTracer;
pub fn init(config: Configuration) -> anyhow::Result<()> {
let max_log_level = log_level_from_config(config);
let subscriber = tracing_subscriber::fmt()
.with_timer(tracing_subscriber::fmt::time::time())
.with_target(false)
.with_level(true)
.with_max_level(max_log_level)
.finish();
LogTracer::init().context("Failed to set logger")?;
set_global_default(subscriber).context("Failed to set subscriber")?;
Ok(())
}
fn log_level_from_config(config: Configuration) -> tracing::Level {
if config.verbose {
tracing::Level::DEBUG
} else {
tracing::Level::INFO
}
}

12
server/src/infra/mod.rs Normal file
View File

@@ -0,0 +1,12 @@
pub mod auth_service;
pub mod cli;
pub mod configuration;
pub mod db_cleaner;
pub mod graphql;
pub mod jwt_sql_tables;
pub mod ldap_handler;
pub mod ldap_server;
pub mod logging;
pub mod sql_backend_handler;
pub mod tcp_backend_handler;
pub mod tcp_server;

View File

@@ -0,0 +1,105 @@
use super::{jwt_sql_tables::*, tcp_backend_handler::*};
use crate::domain::{error::*, sql_backend_handler::SqlBackendHandler};
use async_trait::async_trait;
use futures_util::StreamExt;
use sea_query::{Expr, Iden, Query, SimpleExpr};
use sqlx::Row;
use std::collections::HashSet;
#[async_trait]
impl TcpBackendHandler for SqlBackendHandler {
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>> {
use sqlx::Result;
let query = Query::select()
.column(JwtStorage::JwtHash)
.from(JwtStorage::Table)
.to_string(DbQueryBuilder {});
sqlx::query(&query)
.map(|row: DbRow| row.get::<i64, _>(&*JwtStorage::JwtHash.to_string()) as u64)
.fetch(&self.sql_pool)
.collect::<Vec<sqlx::Result<u64>>>()
.await
.into_iter()
.collect::<Result<HashSet<u64>>>()
.map_err(|e| anyhow::anyhow!(e))
}
async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)> {
use rand::{distributions::Alphanumeric, rngs::SmallRng, Rng, SeedableRng};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
// TODO: Initialize the rng only once. Maybe Arc<Cell>?
let mut rng = SmallRng::from_entropy();
let refresh_token: String = std::iter::repeat(())
.map(|()| rng.sample(Alphanumeric))
.map(char::from)
.take(100)
.collect();
let refresh_token_hash = {
let mut s = DefaultHasher::new();
refresh_token.hash(&mut s);
s.finish()
};
let duration = chrono::Duration::days(30);
let query = Query::insert()
.into_table(JwtRefreshStorage::Table)
.columns(vec![
JwtRefreshStorage::RefreshTokenHash,
JwtRefreshStorage::UserId,
JwtRefreshStorage::ExpiryDate,
])
.values_panic(vec![
(refresh_token_hash as i64).into(),
user.into(),
(chrono::Utc::now() + duration).naive_utc().into(),
])
.to_string(DbQueryBuilder {});
sqlx::query(&query).execute(&self.sql_pool).await?;
Ok((refresh_token, duration))
}
async fn check_token(&self, refresh_token_hash: u64, user: &str) -> Result<bool> {
let query = Query::select()
.expr(SimpleExpr::Value(1.into()))
.from(JwtRefreshStorage::Table)
.and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash as i64))
.and_where(Expr::col(JwtRefreshStorage::UserId).eq(user))
.to_string(DbQueryBuilder {});
Ok(sqlx::query(&query)
.fetch_optional(&self.sql_pool)
.await?
.is_some())
}
async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>> {
use sqlx::Result;
let query = Query::select()
.column(JwtStorage::JwtHash)
.from(JwtStorage::Table)
.and_where(Expr::col(JwtStorage::UserId).eq(user))
.and_where(Expr::col(JwtStorage::Blacklisted).eq(true))
.to_string(DbQueryBuilder {});
let result = sqlx::query(&query)
.map(|row: DbRow| row.get::<i64, _>(&*JwtStorage::JwtHash.to_string()) as u64)
.fetch(&self.sql_pool)
.collect::<Vec<sqlx::Result<u64>>>()
.await
.into_iter()
.collect::<Result<HashSet<u64>>>();
let query = Query::update()
.table(JwtStorage::Table)
.values(vec![(JwtStorage::Blacklisted, true.into())])
.and_where(Expr::col(JwtStorage::UserId).eq(user))
.to_string(DbQueryBuilder {});
sqlx::query(&query).execute(&self.sql_pool).await?;
Ok(result?)
}
async fn delete_refresh_token(&self, refresh_token_hash: u64) -> DomainResult<()> {
let query = Query::delete()
.from_table(JwtRefreshStorage::Table)
.and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash))
.to_string(DbQueryBuilder {});
sqlx::query(&query).execute(&self.sql_pool).await?;
Ok(())
}
}

View File

@@ -0,0 +1,46 @@
use async_trait::async_trait;
use std::collections::HashSet;
pub type DomainResult<T> = crate::domain::error::Result<T>;
#[async_trait]
pub trait TcpBackendHandler {
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>;
async fn check_token(&self, refresh_token_hash: u64, user: &str) -> DomainResult<bool>;
async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>>;
async fn delete_refresh_token(&self, refresh_token_hash: u64) -> DomainResult<()>;
}
#[cfg(test)]
use crate::domain::handler::*;
#[cfg(test)]
mockall::mock! {
pub TestTcpBackendHandler{}
impl Clone for TestTcpBackendHandler {
fn clone(&self) -> Self;
}
#[async_trait]
impl LoginHandler for TestTcpBackendHandler {
async fn bind(&self, request: BindRequest) -> DomainResult<()>;
}
#[async_trait]
impl BackendHandler for TestTcpBackendHandler {
async fn list_users(&self, filters: Option<RequestFilter>) -> DomainResult<Vec<User>>;
async fn list_groups(&self) -> DomainResult<Vec<Group>>;
async fn get_user_details(&self, user_id: &str) -> DomainResult<User>;
async fn get_user_groups(&self, user: &str) -> DomainResult<HashSet<String>>;
async fn create_user(&self, request: CreateUserRequest) -> DomainResult<()>;
async fn delete_user(&self, user_id: &str) -> DomainResult<()>;
async fn create_group(&self, group_name: &str) -> DomainResult<GroupId>;
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> DomainResult<()>;
}
#[async_trait]
impl TcpBackendHandler for TestTcpBackendHandler {
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>;
async fn check_token(&self, refresh_token_hash: u64, user: &str) -> DomainResult<bool>;
async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>>;
async fn delete_refresh_token(&self, refresh_token_hash: u64) -> DomainResult<()>;
}
}

View File

@@ -0,0 +1,134 @@
use crate::{
domain::{
error::DomainError,
handler::{BackendHandler, LoginHandler},
opaque_handler::OpaqueHandler,
},
infra::{auth_service, configuration::Configuration, tcp_backend_handler::*},
};
use actix_files::{Files, NamedFile};
use actix_http::HttpServiceBuilder;
use actix_server::ServerBuilder;
use actix_service::map_config;
use actix_web::{dev::AppConfig, web, App, HttpRequest, HttpResponse};
use anyhow::{Context, Result};
use hmac::{Hmac, NewMac};
use sha2::Sha512;
use std::collections::HashSet;
use std::path::PathBuf;
use std::sync::RwLock;
async fn index(req: HttpRequest) -> actix_web::Result<NamedFile> {
let mut path = PathBuf::new();
path.push("../app");
let file = req.match_info().query("filename");
path.push(if file.is_empty() { "index.html" } else { file });
Ok(NamedFile::open(path)?)
}
pub(crate) fn error_to_http_response(error: DomainError) -> HttpResponse {
match error {
DomainError::AuthenticationError(_) | DomainError::AuthenticationProtocolError(_) => {
HttpResponse::Unauthorized()
}
DomainError::DatabaseError(_)
| DomainError::InternalError(_)
| DomainError::UnknownCryptoError(_) => HttpResponse::InternalServerError(),
DomainError::Base64DecodeError(_) | DomainError::BinarySerializationError(_) => {
HttpResponse::BadRequest()
}
}
.body(error.to_string())
}
fn http_config<Backend>(
cfg: &mut web::ServiceConfig,
backend_handler: Backend,
jwt_secret: String,
jwt_blacklist: HashSet<u64>,
) where
Backend: TcpBackendHandler + BackendHandler + LoginHandler + OpaqueHandler + Sync + 'static,
{
cfg.app_data(web::Data::new(AppState::<Backend> {
backend_handler,
jwt_key: Hmac::new_varkey(jwt_secret.as_bytes()).unwrap(),
jwt_blacklist: RwLock::new(jwt_blacklist),
}))
// Serve index.html and main.js, and default to index.html.
.route(
"/{filename:(index\\.html|main\\.js)?}",
web::get().to(index),
)
.service(web::scope("/auth").configure(auth_service::configure_server::<Backend>))
// API endpoint.
.service(
web::scope("/api")
.wrap(auth_service::CookieToHeaderTranslatorFactory)
.configure(super::graphql::api::configure_endpoint::<Backend>),
)
// Serve the /pkg path with the compiled WASM app.
.service(Files::new("/pkg", "./app/pkg"))
// Default to serve index.html for unknown routes, to support routing.
.service(web::scope("/").route("/.*", web::get().to(index)));
}
pub(crate) struct AppState<Backend> {
pub backend_handler: Backend,
pub jwt_key: Hmac<Sha512>,
pub jwt_blacklist: RwLock<HashSet<u64>>,
}
pub async fn build_tcp_server<Backend>(
config: &Configuration,
backend_handler: Backend,
server_builder: ServerBuilder,
) -> Result<ServerBuilder>
where
Backend: TcpBackendHandler + BackendHandler + LoginHandler + OpaqueHandler + Sync + 'static,
{
let jwt_secret = config.jwt_secret.clone();
let jwt_blacklist = backend_handler.get_jwt_blacklist().await?;
server_builder
.bind("http", ("0.0.0.0", config.http_port), move || {
let backend_handler = backend_handler.clone();
let jwt_secret = jwt_secret.clone();
let jwt_blacklist = jwt_blacklist.clone();
HttpServiceBuilder::new()
.finish(map_config(
App::new().configure(move |cfg| {
http_config(cfg, backend_handler, jwt_secret, jwt_blacklist)
}),
|_| AppConfig::default(),
))
.tcp()
})
.with_context(|| {
format!(
"While bringing up the TCP server with port {}",
config.http_port
)
})
}
#[cfg(test)]
mod tests {
use super::*;
use actix_web::test::TestRequest;
use std::path::Path;
#[actix_rt::test]
async fn test_index_ok() {
let req = TestRequest::default().to_http_request();
let resp = index(req).await.unwrap();
assert_eq!(resp.path(), Path::new("../app/index.html"));
}
#[actix_rt::test]
async fn test_index_main_js() {
let req = TestRequest::default()
.param("filename", "main.js")
.to_http_request();
let resp = index(req).await.unwrap();
assert_eq!(resp.path(), Path::new("../app/main.js"));
}
}

88
server/src/main.rs Normal file
View File

@@ -0,0 +1,88 @@
#![forbid(unsafe_code)]
#![allow(clippy::nonstandard_macro_braces)]
use crate::{
domain::{
handler::{BackendHandler, CreateUserRequest},
sql_backend_handler::SqlBackendHandler,
sql_opaque_handler::register_password,
sql_tables::PoolOptions,
},
infra::{cli::*, configuration::Configuration, db_cleaner::Scheduler},
};
use actix::Actor;
use anyhow::{Context, Result};
use futures_util::TryFutureExt;
use log::*;
mod domain;
mod infra;
async fn create_admin_user(handler: &SqlBackendHandler, config: &Configuration) -> Result<()> {
handler
.create_user(CreateUserRequest {
user_id: config.ldap_user_dn.clone(),
..Default::default()
})
.and_then(|_| register_password(handler, &config.ldap_user_dn, &config.ldap_user_pass))
.await
.context("Error creating admin user")?;
let admin_group_id = handler
.create_group("lldap_admin")
.await
.context("Error creating admin group")?;
handler
.add_user_to_group(&config.ldap_user_dn, admin_group_id)
.await
.context("Error adding admin user to group")
}
async fn run_server(config: Configuration) -> Result<()> {
let sql_pool = PoolOptions::new()
.max_connections(5)
.connect(&config.database_url)
.await?;
domain::sql_tables::init_table(&sql_pool).await?;
let backend_handler = SqlBackendHandler::new(config.clone(), sql_pool.clone());
create_admin_user(&backend_handler, &config)
.await
.unwrap_or_else(|e| warn!("Error setting up admin login/account: {}", e));
let server_builder = infra::ldap_server::build_ldap_server(
&config,
backend_handler.clone(),
actix_server::Server::build(),
)?;
infra::jwt_sql_tables::init_table(&sql_pool).await?;
let server_builder =
infra::tcp_server::build_tcp_server(&config, backend_handler, server_builder).await?;
// Run every hour.
let scheduler = Scheduler::new("0 0 * * * * *", sql_pool);
scheduler.start();
server_builder.workers(1).run().await?;
Ok(())
}
fn run_server_command(opts: RunOpts) -> Result<()> {
let config = infra::configuration::init(opts.clone())?;
infra::logging::init(config.clone())?;
info!("Starting LLDAP....");
debug!("CLI: {:#?}", opts);
debug!("Configuration: {:#?}", config);
actix::run(
run_server(config).unwrap_or_else(|e| error!("Could not bring up the servers: {:?}", e)),
)?;
info!("End.");
Ok(())
}
fn main() -> Result<()> {
let cli_opts = infra::cli::init();
match cli_opts.command {
Command::ExportGraphQLSchema(opts) => infra::graphql::api::export_schema(opts),
Command::Run(opts) => run_server_command(opts),
}
}