Move backend source to server/ subpackage

To clarify the organization.
This commit is contained in:
Valentin Tolmer
2021-08-31 16:46:31 +02:00
committed by nitnelave
parent 3eb53ba5bf
commit d8df47b35d
30 changed files with 93 additions and 88 deletions

View File

@@ -0,0 +1,22 @@
use thiserror::Error;
#[allow(clippy::enum_variant_names)]
#[derive(Error, Debug)]
pub enum DomainError {
#[error("Authentication error for `{0}`")]
AuthenticationError(String),
#[error("Database error: `{0}`")]
DatabaseError(#[from] sqlx::Error),
#[error("Authentication protocol error for `{0}`")]
AuthenticationProtocolError(#[from] lldap_auth::opaque::AuthenticationError),
#[error("Unknown crypto error: `{0}`")]
UnknownCryptoError(#[from] orion::errors::UnknownCryptoError),
#[error("Binary serialization error: `{0}`")]
BinarySerializationError(#[from] bincode::Error),
#[error("Invalid base64: `{0}`")]
Base64DecodeError(#[from] base64::DecodeError),
#[error("Internal error: `{0}`")]
InternalError(String),
}
pub type Result<T> = std::result::Result<T, DomainError>;

View File

@@ -0,0 +1,103 @@
use super::error::*;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
#[cfg_attr(not(target_arch = "wasm32"), derive(sqlx::FromRow))]
pub struct User {
pub user_id: String,
pub email: String,
pub display_name: Option<String>,
pub first_name: Option<String>,
pub last_name: Option<String>,
// pub avatar: ?,
pub creation_date: chrono::DateTime<chrono::Utc>,
}
impl Default for User {
fn default() -> Self {
use chrono::TimeZone;
User {
user_id: String::new(),
email: String::new(),
display_name: None,
first_name: None,
last_name: None,
creation_date: chrono::Utc.timestamp(0, 0),
}
}
}
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
pub struct Group {
pub display_name: String,
pub users: Vec<String>,
}
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone)]
pub struct BindRequest {
pub name: String,
pub password: String,
}
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone)]
pub enum RequestFilter {
And(Vec<RequestFilter>),
Or(Vec<RequestFilter>),
Not(Box<RequestFilter>),
Equality(String, String),
}
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone, Default)]
pub struct CreateUserRequest {
// Same fields as User, but no creation_date, and with password.
pub user_id: String,
pub email: String,
pub display_name: Option<String>,
pub first_name: Option<String>,
pub last_name: Option<String>,
}
#[async_trait]
pub trait LoginHandler: Clone + Send {
async fn bind(&self, request: BindRequest) -> Result<()>;
}
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct GroupId(pub i32);
#[async_trait]
pub trait BackendHandler: Clone + Send {
async fn list_users(&self, filters: Option<RequestFilter>) -> Result<Vec<User>>;
async fn list_groups(&self) -> Result<Vec<Group>>;
async fn get_user_details(&self, user_id: &str) -> Result<User>;
async fn create_user(&self, request: CreateUserRequest) -> Result<()>;
async fn delete_user(&self, user_id: &str) -> Result<()>;
async fn create_group(&self, group_name: &str) -> Result<GroupId>;
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
async fn get_user_groups(&self, user: &str) -> Result<HashSet<String>>;
}
#[cfg(test)]
mockall::mock! {
pub TestBackendHandler{}
impl Clone for TestBackendHandler {
fn clone(&self) -> Self;
}
#[async_trait]
impl BackendHandler for TestBackendHandler {
async fn list_users(&self, filters: Option<RequestFilter>) -> Result<Vec<User>>;
async fn list_groups(&self) -> Result<Vec<Group>>;
async fn get_user_details(&self, user_id: &str) -> Result<User>;
async fn create_user(&self, request: CreateUserRequest) -> Result<()>;
async fn delete_user(&self, user_id: &str) -> Result<()>;
async fn create_group(&self, group_name: &str) -> Result<GroupId>;
async fn get_user_groups(&self, user: &str) -> Result<HashSet<String>>;
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
}
#[async_trait]
impl LoginHandler for TestBackendHandler {
async fn bind(&self, request: BindRequest) -> Result<()>;
}
}

6
server/src/domain/mod.rs Normal file
View File

@@ -0,0 +1,6 @@
pub mod error;
pub mod handler;
pub mod opaque_handler;
pub mod sql_backend_handler;
pub mod sql_opaque_handler;
pub mod sql_tables;

View File

@@ -0,0 +1,36 @@
use super::error::*;
use async_trait::async_trait;
pub use lldap_auth::{login, registration};
#[async_trait]
pub trait OpaqueHandler: Clone + Send {
async fn login_start(
&self,
request: login::ClientLoginStartRequest,
) -> Result<login::ServerLoginStartResponse>;
async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<String>;
async fn registration_start(
&self,
request: registration::ClientRegistrationStartRequest,
) -> Result<registration::ServerRegistrationStartResponse>;
async fn registration_finish(
&self,
request: registration::ClientRegistrationFinishRequest,
) -> Result<()>;
}
#[cfg(test)]
mockall::mock! {
pub TestOpaqueHandler{}
impl Clone for TestOpaqueHandler {
fn clone(&self) -> Self;
}
#[async_trait]
impl OpaqueHandler for TestOpaqueHandler {
async fn login_start(&self, request: login::ClientLoginStartRequest) -> Result<login::ServerLoginStartResponse>;
async fn login_finish(&self, request: login::ClientLoginFinishRequest ) -> Result<String>;
async fn registration_start(&self, request: registration::ClientRegistrationStartRequest) -> Result<registration::ServerRegistrationStartResponse>;
async fn registration_finish(&self, request: registration::ClientRegistrationFinishRequest ) -> Result<()>;
}
}

View File

@@ -0,0 +1,536 @@
use super::{error::*, handler::*, sql_tables::*};
use crate::infra::configuration::Configuration;
use async_trait::async_trait;
use futures_util::StreamExt;
use futures_util::TryStreamExt;
use sea_query::{Expr, Iden, Order, Query, SimpleExpr, Value};
use sqlx::Row;
use std::collections::HashSet;
#[derive(Debug, Clone)]
pub struct SqlBackendHandler {
pub(crate) config: Configuration,
pub(crate) sql_pool: Pool,
}
impl SqlBackendHandler {
pub fn new(config: Configuration, sql_pool: Pool) -> Self {
SqlBackendHandler { config, sql_pool }
}
}
fn get_filter_expr(filter: RequestFilter) -> SimpleExpr {
use RequestFilter::*;
fn get_repeated_filter(
fs: Vec<RequestFilter>,
field: &dyn Fn(SimpleExpr, SimpleExpr) -> SimpleExpr,
) -> SimpleExpr {
let mut it = fs.into_iter();
let first_expr = match it.next() {
None => return Expr::value(true),
Some(f) => get_filter_expr(f),
};
it.fold(first_expr, |e, f| field(e, get_filter_expr(f)))
}
match filter {
And(fs) => get_repeated_filter(fs, &SimpleExpr::and),
Or(fs) => get_repeated_filter(fs, &SimpleExpr::or),
Not(f) => Expr::not(Expr::expr(get_filter_expr(*f))),
Equality(s1, s2) => Expr::expr(Expr::cust(&s1)).eq(s2),
}
}
#[async_trait]
impl BackendHandler for SqlBackendHandler {
async fn list_users(&self, filters: Option<RequestFilter>) -> Result<Vec<User>> {
let query = {
let mut query_builder = Query::select()
.column(Users::UserId)
.column(Users::Email)
.column(Users::DisplayName)
.column(Users::FirstName)
.column(Users::LastName)
.column(Users::Avatar)
.column(Users::CreationDate)
.from(Users::Table)
.order_by(Users::UserId, Order::Asc)
.to_owned();
if let Some(filter) = filters {
if filter != RequestFilter::And(Vec::new())
&& filter != RequestFilter::Or(Vec::new())
{
query_builder.and_where(get_filter_expr(filter));
}
}
query_builder.to_string(DbQueryBuilder {})
};
let results = sqlx::query_as::<_, User>(&query)
.fetch(&self.sql_pool)
.collect::<Vec<sqlx::Result<User>>>()
.await;
Ok(results.into_iter().collect::<sqlx::Result<Vec<User>>>()?)
}
async fn list_groups(&self) -> Result<Vec<Group>> {
let query: String = Query::select()
.column(Groups::DisplayName)
.column(Memberships::UserId)
.from(Groups::Table)
.left_join(
Memberships::Table,
Expr::tbl(Groups::Table, Groups::GroupId)
.equals(Memberships::Table, Memberships::GroupId),
)
.order_by(Groups::DisplayName, Order::Asc)
.order_by(Memberships::UserId, Order::Asc)
.to_string(DbQueryBuilder {});
let mut results = sqlx::query(&query).fetch(&self.sql_pool);
let mut groups = Vec::new();
// The rows are ordered by group, user, so we need to group them into vectors.
{
let mut current_group = String::new();
let mut current_users = Vec::new();
while let Some(row) = results.try_next().await? {
let display_name = row.get::<String, _>(&*Groups::DisplayName.to_string());
if display_name != current_group {
if !current_group.is_empty() {
groups.push(Group {
display_name: current_group,
users: current_users,
});
current_users = Vec::new();
}
current_group = display_name.clone();
}
current_users.push(row.get::<String, _>(&*Memberships::UserId.to_string()));
}
groups.push(Group {
display_name: current_group,
users: current_users,
});
}
Ok(groups)
}
async fn get_user_details(&self, user_id: &str) -> Result<User> {
let query = Query::select()
.column(Users::UserId)
.column(Users::Email)
.column(Users::DisplayName)
.column(Users::FirstName)
.column(Users::LastName)
.column(Users::Avatar)
.column(Users::CreationDate)
.from(Users::Table)
.and_where(Expr::col(Users::UserId).eq(user_id))
.to_string(DbQueryBuilder {});
Ok(sqlx::query_as::<_, User>(&query)
.fetch_one(&self.sql_pool)
.await?)
}
async fn get_user_groups(&self, user: &str) -> Result<HashSet<String>> {
if user == self.config.ldap_user_dn {
let mut groups = HashSet::new();
groups.insert("lldap_admin".to_string());
return Ok(groups);
}
let query: String = Query::select()
.column(Groups::DisplayName)
.from(Groups::Table)
.inner_join(
Memberships::Table,
Expr::tbl(Groups::Table, Groups::GroupId)
.equals(Memberships::Table, Memberships::GroupId),
)
.and_where(Expr::col(Memberships::UserId).eq(user))
.to_string(DbQueryBuilder {});
sqlx::query(&query)
// Extract the group id from the row.
.map(|row: DbRow| row.get::<String, _>(&*Groups::DisplayName.to_string()))
.fetch(&self.sql_pool)
// Collect the vector of rows, each potentially an error.
.collect::<Vec<sqlx::Result<String>>>()
.await
.into_iter()
// Transform it into a single result (the first error if any), and group the group_ids
// into a HashSet.
.collect::<sqlx::Result<HashSet<_>>>()
// Map the sqlx::Error into a DomainError.
.map_err(DomainError::DatabaseError)
}
async fn create_user(&self, request: CreateUserRequest) -> Result<()> {
let columns = vec![
Users::UserId,
Users::Email,
Users::DisplayName,
Users::FirstName,
Users::LastName,
Users::CreationDate,
];
let values = vec![
request.user_id.clone().into(),
request.email.into(),
request.display_name.map(Into::into).unwrap_or(Value::Null),
request.first_name.map(Into::into).unwrap_or(Value::Null),
request.last_name.map(Into::into).unwrap_or(Value::Null),
chrono::Utc::now().naive_utc().into(),
];
let query = Query::insert()
.into_table(Users::Table)
.columns(columns)
.values_panic(values)
.to_string(DbQueryBuilder {});
sqlx::query(&query).execute(&self.sql_pool).await?;
Ok(())
}
async fn delete_user(&self, user_id: &str) -> Result<()> {
let delete_query = Query::delete()
.from_table(Users::Table)
.and_where(Expr::col(Users::UserId).eq(user_id))
.to_string(DbQueryBuilder {});
sqlx::query(&delete_query).execute(&self.sql_pool).await?;
Ok(())
}
async fn create_group(&self, group_name: &str) -> Result<GroupId> {
let query = Query::insert()
.into_table(Groups::Table)
.columns(vec![Groups::DisplayName])
.values_panic(vec![group_name.into()])
.to_string(DbQueryBuilder {});
sqlx::query(&query).execute(&self.sql_pool).await?;
let query = Query::select()
.column(Groups::GroupId)
.from(Groups::Table)
.and_where(Expr::col(Groups::DisplayName).eq(group_name))
.to_string(DbQueryBuilder {});
let row = sqlx::query(&query).fetch_one(&self.sql_pool).await?;
Ok(GroupId(row.get::<i32, _>(&*Groups::GroupId.to_string())))
}
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()> {
let query = Query::insert()
.into_table(Memberships::Table)
.columns(vec![Memberships::UserId, Memberships::GroupId])
.values_panic(vec![user_id.into(), group_id.0.into()])
.to_string(DbQueryBuilder {});
sqlx::query(&query).execute(&self.sql_pool).await?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::domain::sql_tables::init_table;
use crate::infra::configuration::ConfigurationBuilder;
use lldap_auth::{opaque, registration};
fn get_default_config() -> Configuration {
ConfigurationBuilder::default()
.verbose(true)
.build()
.unwrap()
}
async fn get_in_memory_db() -> Pool {
PoolOptions::new().connect("sqlite::memory:").await.unwrap()
}
async fn get_initialized_db() -> Pool {
let sql_pool = get_in_memory_db().await;
init_table(&sql_pool).await.unwrap();
sql_pool
}
async fn insert_user(handler: &SqlBackendHandler, name: &str, pass: &str) {
use crate::domain::opaque_handler::OpaqueHandler;
insert_user_no_password(handler, name).await;
let mut rng = rand::rngs::OsRng;
let client_registration_start =
opaque::client::registration::start_registration(pass, &mut rng).unwrap();
let response = handler
.registration_start(registration::ClientRegistrationStartRequest {
username: name.to_string(),
registration_start_request: client_registration_start.message,
})
.await
.unwrap();
let registration_upload = opaque::client::registration::finish_registration(
client_registration_start.state,
response.registration_response,
&mut rng,
)
.unwrap();
handler
.registration_finish(registration::ClientRegistrationFinishRequest {
server_data: response.server_data,
registration_upload: registration_upload.message,
})
.await
.unwrap();
}
async fn insert_user_no_password(handler: &SqlBackendHandler, name: &str) {
handler
.create_user(CreateUserRequest {
user_id: name.to_string(),
email: "bob@bob.bob".to_string(),
..Default::default()
})
.await
.unwrap();
}
async fn insert_group(handler: &SqlBackendHandler, name: &str) -> GroupId {
handler.create_group(name).await.unwrap()
}
async fn insert_membership(handler: &SqlBackendHandler, group_id: GroupId, user_id: &str) {
handler.add_user_to_group(user_id, group_id).await.unwrap();
}
#[tokio::test]
async fn test_bind_admin() {
let sql_pool = get_in_memory_db().await;
let config = ConfigurationBuilder::default()
.ldap_user_dn("admin".to_string())
.ldap_user_pass("test".to_string())
.build()
.unwrap();
let handler = SqlBackendHandler::new(config, sql_pool);
handler
.bind(BindRequest {
name: "admin".to_string(),
password: "test".to_string(),
})
.await
.unwrap();
}
#[tokio::test]
async fn test_bind_user() {
let sql_pool = get_initialized_db().await;
let config = get_default_config();
let handler = SqlBackendHandler::new(config, sql_pool.clone());
insert_user(&handler, "bob", "bob00").await;
handler
.bind(BindRequest {
name: "bob".to_string(),
password: "bob00".to_string(),
})
.await
.unwrap();
handler
.bind(BindRequest {
name: "andrew".to_string(),
password: "bob00".to_string(),
})
.await
.unwrap_err();
handler
.bind(BindRequest {
name: "bob".to_string(),
password: "wrong_password".to_string(),
})
.await
.unwrap_err();
}
#[tokio::test]
async fn test_user_no_password() {
let sql_pool = get_initialized_db().await;
let config = get_default_config();
let handler = SqlBackendHandler::new(config, sql_pool.clone());
insert_user_no_password(&handler, "bob").await;
handler
.bind(BindRequest {
name: "bob".to_string(),
password: "bob00".to_string(),
})
.await
.unwrap_err();
}
#[tokio::test]
async fn test_list_users() {
let sql_pool = get_initialized_db().await;
let config = get_default_config();
let handler = SqlBackendHandler::new(config, sql_pool);
insert_user(&handler, "bob", "bob00").await;
insert_user(&handler, "patrick", "pass").await;
insert_user(&handler, "John", "Pa33w0rd!").await;
{
let users = handler
.list_users(None)
.await
.unwrap()
.into_iter()
.map(|u| u.user_id)
.collect::<Vec<_>>();
assert_eq!(users, vec!["John", "bob", "patrick"]);
}
{
let users = handler
.list_users(Some(RequestFilter::Equality(
"user_id".to_string(),
"bob".to_string(),
)))
.await
.unwrap()
.into_iter()
.map(|u| u.user_id)
.collect::<Vec<_>>();
assert_eq!(users, vec!["bob"]);
}
{
let users = handler
.list_users(Some(RequestFilter::Or(vec![
RequestFilter::Equality("user_id".to_string(), "bob".to_string()),
RequestFilter::Equality("user_id".to_string(), "John".to_string()),
])))
.await
.unwrap()
.into_iter()
.map(|u| u.user_id)
.collect::<Vec<_>>();
assert_eq!(users, vec!["John", "bob"]);
}
{
let users = handler
.list_users(Some(RequestFilter::Not(Box::new(RequestFilter::Equality(
"user_id".to_string(),
"bob".to_string(),
)))))
.await
.unwrap()
.into_iter()
.map(|u| u.user_id)
.collect::<Vec<_>>();
assert_eq!(users, vec!["John", "patrick"]);
}
}
#[tokio::test]
async fn test_list_groups() {
let sql_pool = get_initialized_db().await;
let config = get_default_config();
let handler = SqlBackendHandler::new(config, sql_pool.clone());
insert_user(&handler, "bob", "bob00").await;
insert_user(&handler, "patrick", "pass").await;
insert_user(&handler, "John", "Pa33w0rd!").await;
let group_1 = insert_group(&handler, "Best Group").await;
let group_2 = insert_group(&handler, "Worst Group").await;
insert_membership(&handler, group_1, "bob").await;
insert_membership(&handler, group_1, "patrick").await;
insert_membership(&handler, group_2, "patrick").await;
insert_membership(&handler, group_2, "John").await;
assert_eq!(
handler.list_groups().await.unwrap(),
vec![
Group {
display_name: "Best Group".to_string(),
users: vec!["bob".to_string(), "patrick".to_string()]
},
Group {
display_name: "Worst Group".to_string(),
users: vec!["John".to_string(), "patrick".to_string()]
}
]
);
}
#[tokio::test]
async fn test_get_user_details() {
let sql_pool = get_initialized_db().await;
let config = get_default_config();
let handler = SqlBackendHandler::new(config, sql_pool);
insert_user(&handler, "bob", "bob00").await;
{
let user = handler.get_user_details("bob").await.unwrap();
assert_eq!(user.user_id, "bob".to_string());
}
{
handler.get_user_details("John").await.unwrap_err();
}
}
#[tokio::test]
async fn test_get_user_groups() {
let sql_pool = get_initialized_db().await;
let config = get_default_config();
let handler = SqlBackendHandler::new(config, sql_pool.clone());
insert_user(&handler, "bob", "bob00").await;
insert_user(&handler, "patrick", "pass").await;
insert_user(&handler, "John", "Pa33w0rd!").await;
let group_1 = insert_group(&handler, "Group1").await;
let group_2 = insert_group(&handler, "Group2").await;
insert_membership(&handler, group_1, "bob").await;
insert_membership(&handler, group_1, "patrick").await;
insert_membership(&handler, group_2, "patrick").await;
let mut bob_groups = HashSet::new();
bob_groups.insert("Group1".to_string());
let mut patrick_groups = HashSet::new();
patrick_groups.insert("Group1".to_string());
patrick_groups.insert("Group2".to_string());
assert_eq!(handler.get_user_groups("bob").await.unwrap(), bob_groups);
assert_eq!(
handler.get_user_groups("patrick").await.unwrap(),
patrick_groups
);
assert_eq!(
handler.get_user_groups("John").await.unwrap(),
HashSet::new()
);
}
#[tokio::test]
async fn test_delete_user() {
let sql_pool = get_initialized_db().await;
let config = get_default_config();
let handler = SqlBackendHandler::new(config, sql_pool.clone());
insert_user(&handler, "val", "s3np4i").await;
insert_user(&handler, "Hector", "Be$t").await;
insert_user(&handler, "Jennz", "boupBoup").await;
// Remove a user
let _request_result = handler.delete_user("Jennz").await.unwrap();
let users = handler
.list_users(None)
.await
.unwrap()
.into_iter()
.map(|u| u.user_id)
.collect::<Vec<_>>();
assert_eq!(users, vec!["Hector", "val"]);
// Insert new user and remove two
insert_user(&handler, "NewBoi", "Joni").await;
let _request_result = handler.delete_user("Hector").await.unwrap();
let _request_result = handler.delete_user("NewBoi").await.unwrap();
let users = handler
.list_users(None)
.await
.unwrap()
.into_iter()
.map(|u| u.user_id)
.collect::<Vec<_>>();
assert_eq!(users, vec!["val"]);
}
}

View File

@@ -0,0 +1,331 @@
use super::{
error::*,
handler::{BindRequest, LoginHandler},
opaque_handler::*,
sql_backend_handler::SqlBackendHandler,
sql_tables::*,
};
use async_trait::async_trait;
use lldap_auth::opaque;
use log::*;
use sea_query::{Expr, Iden, Query};
use sqlx::Row;
type SqlOpaqueHandler = SqlBackendHandler;
fn passwords_match(
password_file_bytes: &[u8],
clear_password: &str,
server_setup: &opaque::server::ServerSetup,
username: &str,
) -> Result<()> {
use opaque::{client, server};
let mut rng = rand::rngs::OsRng;
let client_login_start_result = client::login::start_login(clear_password, &mut rng)?;
let password_file = server::ServerRegistration::deserialize(password_file_bytes)
.map_err(opaque::AuthenticationError::ProtocolError)?;
let server_login_start_result = server::login::start_login(
&mut rng,
server_setup,
Some(password_file),
client_login_start_result.message,
username,
)?;
client::login::finish_login(
client_login_start_result.state,
server_login_start_result.message,
)?;
Ok(())
}
impl SqlBackendHandler {
fn get_orion_secret_key(&self) -> Result<orion::aead::SecretKey> {
Ok(orion::aead::SecretKey::from_slice(
self.config.get_server_keys().private(),
)?)
}
async fn get_password_file_for_user(
&self,
username: &str,
) -> Result<Option<opaque::server::ServerRegistration>> {
// Fetch the previously registered password file from the DB.
let password_file_bytes = {
let query = Query::select()
.column(Users::PasswordHash)
.from(Users::Table)
.and_where(Expr::col(Users::UserId).eq(username))
.to_string(DbQueryBuilder {});
if let Some(row) = sqlx::query(&query).fetch_optional(&self.sql_pool).await? {
if let Some(bytes) =
row.get::<Option<Vec<u8>>, _>(&*Users::PasswordHash.to_string())
{
bytes
} else {
// No password set.
return Ok(None);
}
} else {
// No such user.
return Ok(None);
}
};
opaque::server::ServerRegistration::deserialize(&password_file_bytes)
.map(Option::Some)
.map_err(|_| {
DomainError::InternalError(format!("Corrupted password file for {}", username))
})
}
}
#[async_trait]
impl LoginHandler for SqlBackendHandler {
async fn bind(&self, request: BindRequest) -> Result<()> {
if request.name == self.config.ldap_user_dn {
if request.password == self.config.ldap_user_pass {
return Ok(());
} else {
debug!(r#"Invalid password for LDAP bind user"#);
return Err(DomainError::AuthenticationError(request.name));
}
}
let query = Query::select()
.column(Users::PasswordHash)
.from(Users::Table)
.and_where(Expr::col(Users::UserId).eq(request.name.as_str()))
.to_string(DbQueryBuilder {});
if let Ok(row) = sqlx::query(&query).fetch_one(&self.sql_pool).await {
if let Some(password_hash) =
row.get::<Option<Vec<u8>>, _>(&*Users::PasswordHash.to_string())
{
if let Err(e) = passwords_match(
&password_hash,
&request.password,
self.config.get_server_setup(),
&request.name,
) {
debug!(r#"Invalid password for "{}": {}"#, request.name, e);
} else {
return Ok(());
}
} else {
debug!(r#"User "{}" has no password"#, request.name);
}
} else {
debug!(r#"No user found for "{}""#, request.name);
}
Err(DomainError::AuthenticationError(request.name))
}
}
#[async_trait]
impl OpaqueHandler for SqlOpaqueHandler {
async fn login_start(
&self,
request: login::ClientLoginStartRequest,
) -> Result<login::ServerLoginStartResponse> {
let maybe_password_file = self.get_password_file_for_user(&request.username).await?;
let mut rng = rand::rngs::OsRng;
// Get the CredentialResponse for the user, or a dummy one if no user/no password.
let start_response = opaque::server::login::start_login(
&mut rng,
self.config.get_server_setup(),
maybe_password_file,
request.login_start_request,
&request.username,
)?;
let secret_key = self.get_orion_secret_key()?;
let server_data = login::ServerData {
username: request.username,
server_login: start_response.state,
};
let encrypted_state = orion::aead::seal(&secret_key, &bincode::serialize(&server_data)?)?;
Ok(login::ServerLoginStartResponse {
server_data: base64::encode(&encrypted_state),
credential_response: start_response.message,
})
}
async fn login_finish(&self, request: login::ClientLoginFinishRequest) -> Result<String> {
let secret_key = self.get_orion_secret_key()?;
let login::ServerData {
username,
server_login,
} = bincode::deserialize(&orion::aead::open(
&secret_key,
&base64::decode(&request.server_data)?,
)?)?;
// Finish the login: this makes sure the client data is correct, and gives a session key we
// don't need.
let _session_key =
opaque::server::login::finish_login(server_login, request.credential_finalization)?
.session_key;
Ok(username)
}
async fn registration_start(
&self,
request: registration::ClientRegistrationStartRequest,
) -> Result<registration::ServerRegistrationStartResponse> {
// Generate the server-side key and derive the data to send back.
let start_response = opaque::server::registration::start_registration(
self.config.get_server_setup(),
request.registration_start_request,
&request.username,
)?;
let secret_key = self.get_orion_secret_key()?;
let server_data = registration::ServerData {
username: request.username,
};
let encrypted_state = orion::aead::seal(&secret_key, &bincode::serialize(&server_data)?)?;
Ok(registration::ServerRegistrationStartResponse {
server_data: base64::encode(encrypted_state),
registration_response: start_response.message,
})
}
async fn registration_finish(
&self,
request: registration::ClientRegistrationFinishRequest,
) -> Result<()> {
let secret_key = self.get_orion_secret_key()?;
let registration::ServerData { username } = bincode::deserialize(&orion::aead::open(
&secret_key,
&base64::decode(&request.server_data)?,
)?)?;
let password_file =
opaque::server::registration::get_password_file(request.registration_upload);
{
// Set the user password to the new password.
let update_query = Query::update()
.table(Users::Table)
.values(vec![(
Users::PasswordHash,
password_file.serialize().into(),
)])
.and_where(Expr::col(Users::UserId).eq(username))
.to_string(DbQueryBuilder {});
sqlx::query(&update_query).execute(&self.sql_pool).await?;
}
Ok(())
}
}
/// Convenience function to set a user's password.
pub(crate) async fn register_password(
opaque_handler: &SqlOpaqueHandler,
username: &str,
password: &str,
) -> Result<()> {
let mut rng = rand::rngs::OsRng;
use registration::*;
let registration_start = opaque::client::registration::start_registration(password, &mut rng)?;
let start_response = opaque_handler
.registration_start(ClientRegistrationStartRequest {
username: username.to_string(),
registration_start_request: registration_start.message,
})
.await?;
let registration_finish = opaque::client::registration::finish_registration(
registration_start.state,
start_response.registration_response,
&mut rng,
)?;
opaque_handler
.registration_finish(ClientRegistrationFinishRequest {
server_data: start_response.server_data,
registration_upload: registration_finish.message,
})
.await
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
domain::{
handler::{BackendHandler, CreateUserRequest},
sql_backend_handler::SqlBackendHandler,
sql_tables::init_table,
},
infra::configuration::{Configuration, ConfigurationBuilder},
};
fn get_default_config() -> Configuration {
ConfigurationBuilder::default()
.verbose(true)
.build()
.unwrap()
}
async fn get_in_memory_db() -> Pool {
PoolOptions::new().connect("sqlite::memory:").await.unwrap()
}
async fn get_initialized_db() -> Pool {
let sql_pool = get_in_memory_db().await;
init_table(&sql_pool).await.unwrap();
sql_pool
}
async fn insert_user_no_password(handler: &SqlBackendHandler, name: &str) {
handler
.create_user(CreateUserRequest {
user_id: name.to_string(),
email: "bob@bob.bob".to_string(),
..Default::default()
})
.await
.unwrap();
}
async fn attempt_login(
opaque_handler: &SqlOpaqueHandler,
username: &str,
password: &str,
) -> Result<()> {
let mut rng = rand::rngs::OsRng;
use login::*;
let login_start = opaque::client::login::start_login(password, &mut rng)?;
let start_response = opaque_handler
.login_start(ClientLoginStartRequest {
username: username.to_string(),
login_start_request: login_start.message,
})
.await?;
let login_finish = opaque::client::login::finish_login(
login_start.state,
start_response.credential_response,
)?;
opaque_handler
.login_finish(ClientLoginFinishRequest {
server_data: start_response.server_data,
credential_finalization: login_finish.message,
})
.await?;
Ok(())
}
#[tokio::test]
async fn test_flow() -> Result<()> {
let sql_pool = get_initialized_db().await;
let config = get_default_config();
let backend_handler = SqlBackendHandler::new(config.clone(), sql_pool.clone());
let opaque_handler = SqlOpaqueHandler::new(config, sql_pool);
insert_user_no_password(&backend_handler, "bob").await;
attempt_login(&opaque_handler, "bob", "bob00")
.await
.unwrap_err();
register_password(&opaque_handler, "bob", "bob00").await?;
attempt_login(&opaque_handler, "bob", "wrong_password")
.await
.unwrap_err();
attempt_login(&opaque_handler, "bob", "bob00").await?;
Ok(())
}
}

View File

@@ -0,0 +1,152 @@
use sea_query::*;
pub type Pool = sqlx::sqlite::SqlitePool;
pub type PoolOptions = sqlx::sqlite::SqlitePoolOptions;
pub type DbRow = sqlx::sqlite::SqliteRow;
pub type DbQueryBuilder = SqliteQueryBuilder;
#[derive(Iden)]
pub enum Users {
Table,
UserId,
Email,
DisplayName,
FirstName,
LastName,
Avatar,
CreationDate,
PasswordHash,
TotpSecret,
MfaType,
}
#[derive(Iden)]
pub enum Groups {
Table,
GroupId,
DisplayName,
}
#[derive(Iden)]
pub enum Memberships {
Table,
UserId,
GroupId,
}
pub async fn init_table(pool: &Pool) -> sqlx::Result<()> {
// SQLite needs this pragma to be turned on. Other DB might not understand this, so ignore the
// error.
let _ = sqlx::query("PRAGMA foreign_keys = ON").execute(pool).await;
sqlx::query(
&Table::create()
.table(Users::Table)
.if_not_exists()
.col(
ColumnDef::new(Users::UserId)
.string_len(255)
.not_null()
.primary_key(),
)
.col(ColumnDef::new(Users::Email).string_len(255).not_null())
.col(ColumnDef::new(Users::DisplayName).string_len(255))
.col(ColumnDef::new(Users::FirstName).string_len(255))
.col(ColumnDef::new(Users::LastName).string_len(255))
.col(ColumnDef::new(Users::Avatar).binary())
.col(ColumnDef::new(Users::CreationDate).date_time().not_null())
.col(ColumnDef::new(Users::PasswordHash).binary())
.col(ColumnDef::new(Users::TotpSecret).string_len(64))
.col(ColumnDef::new(Users::MfaType).string_len(64))
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await?;
sqlx::query(
&Table::create()
.table(Groups::Table)
.if_not_exists()
.col(
ColumnDef::new(Groups::GroupId)
.integer()
.not_null()
.primary_key(),
)
.col(
ColumnDef::new(Groups::DisplayName)
.string_len(255)
.unique_key()
.not_null(),
)
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await?;
sqlx::query(
&Table::create()
.table(Memberships::Table)
.if_not_exists()
.col(
ColumnDef::new(Memberships::UserId)
.string_len(255)
.not_null(),
)
.col(ColumnDef::new(Memberships::GroupId).integer().not_null())
.foreign_key(
ForeignKey::create()
.name("MembershipUserForeignKey")
.table(Memberships::Table, Users::Table)
.col(Memberships::UserId, Users::UserId)
.on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade),
)
.foreign_key(
ForeignKey::create()
.name("MembershipGroupForeignKey")
.table(Memberships::Table, Groups::Table)
.col(Memberships::GroupId, Groups::GroupId)
.on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade),
)
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await?;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use chrono::prelude::*;
use sqlx::{Column, Row};
#[actix_rt::test]
async fn test_init_table() {
let sql_pool = PoolOptions::new().connect("sqlite::memory:").await.unwrap();
init_table(&sql_pool).await.unwrap();
sqlx::query(r#"INSERT INTO users
(user_id, email, display_name, first_name, last_name, creation_date, password_hash)
VALUES ("bôb", "böb@bob.bob", "Bob Bobbersön", "Bob", "Bobberson", "1970-01-01 00:00:00", "bob00")"#).execute(&sql_pool).await.unwrap();
let row =
sqlx::query(r#"SELECT display_name, creation_date FROM users WHERE user_id = "bôb""#)
.fetch_one(&sql_pool)
.await
.unwrap();
assert_eq!(row.column(0).name(), "display_name");
assert_eq!(row.get::<String, _>("display_name"), "Bob Bobbersön");
assert_eq!(
row.get::<DateTime<Utc>, _>("creation_date"),
Utc.timestamp(0, 0),
);
}
#[actix_rt::test]
async fn test_already_init_table() {
let sql_pool = PoolOptions::new().connect("sqlite::memory:").await.unwrap();
init_table(&sql_pool).await.unwrap();
init_table(&sql_pool).await.unwrap();
}
}