Move backend source to server/ subpackage

To clarify the organization.
This commit is contained in:
Valentin Tolmer
2021-08-31 16:46:31 +02:00
committed by nitnelave
parent 3eb53ba5bf
commit d8df47b35d
30 changed files with 93 additions and 88 deletions

View File

@@ -0,0 +1,415 @@
use crate::{
domain::{
error::DomainError,
handler::{BackendHandler, BindRequest, LoginHandler},
opaque_handler::OpaqueHandler,
},
infra::{
tcp_backend_handler::*,
tcp_server::{error_to_http_response, AppState},
},
};
use actix_web::{
cookie::{Cookie, SameSite},
dev::{Service, ServiceRequest, ServiceResponse, Transform},
error::{ErrorBadRequest, ErrorUnauthorized},
web, HttpRequest, HttpResponse,
};
use anyhow::Result;
use chrono::prelude::*;
use futures::future::{ok, Ready};
use futures_util::{FutureExt, TryFutureExt};
use hmac::Hmac;
use jwt::{SignWithKey, VerifyWithKey};
use lldap_auth::{login, registration, JWTClaims};
use sha2::Sha512;
use std::collections::{hash_map::DefaultHasher, HashSet};
use std::hash::{Hash, Hasher};
use std::pin::Pin;
use std::task::{Context, Poll};
use time::ext::NumericalDuration;
type Token<S> = jwt::Token<jwt::Header, JWTClaims, S>;
type SignedToken = Token<jwt::token::Signed>;
fn create_jwt(key: &Hmac<Sha512>, user: String, groups: HashSet<String>) -> SignedToken {
let claims = JWTClaims {
exp: Utc::now() + chrono::Duration::days(1),
iat: Utc::now(),
user,
groups,
};
let header = jwt::Header {
algorithm: jwt::AlgorithmType::Hs512,
..Default::default()
};
jwt::Token::new(header, claims).sign_with_key(key).unwrap()
}
fn get_refresh_token_from_cookie(
request: HttpRequest,
) -> std::result::Result<(u64, String), HttpResponse> {
match request.cookie("refresh_token") {
None => Err(HttpResponse::Unauthorized().body("Missing refresh token")),
Some(t) => match t.value().split_once("+") {
None => Err(HttpResponse::Unauthorized().body("Invalid refresh token")),
Some((token, u)) => {
let refresh_token_hash = {
let mut s = DefaultHasher::new();
token.hash(&mut s);
s.finish()
};
Ok((refresh_token_hash, u.to_string()))
}
},
}
}
async fn get_refresh<Backend>(
data: web::Data<AppState<Backend>>,
request: HttpRequest,
) -> HttpResponse
where
Backend: TcpBackendHandler + BackendHandler + 'static,
{
let backend_handler = &data.backend_handler;
let jwt_key = &data.jwt_key;
let (refresh_token_hash, user) = match get_refresh_token_from_cookie(request) {
Ok(t) => t,
Err(http_response) => return http_response,
};
let res_found = data
.backend_handler
.check_token(refresh_token_hash, &user)
.await;
// Async closures are not supported yet.
match res_found {
Ok(found) => {
if found {
backend_handler.get_user_groups(&user).await
} else {
Err(DomainError::AuthenticationError(
"Invalid refresh token".to_string(),
))
}
}
Err(e) => Err(e),
}
.map(|groups| create_jwt(jwt_key, user.to_string(), groups))
.map(|token| {
HttpResponse::Ok()
.cookie(
Cookie::build("token", token.as_str())
.max_age(1.days())
.path("/api")
.http_only(true)
.same_site(SameSite::Strict)
.finish(),
)
.body(token.as_str().to_owned())
})
.unwrap_or_else(error_to_http_response)
}
async fn get_logout<Backend>(
data: web::Data<AppState<Backend>>,
request: HttpRequest,
) -> HttpResponse
where
Backend: TcpBackendHandler + BackendHandler + 'static,
{
let (refresh_token_hash, user) = match get_refresh_token_from_cookie(request) {
Ok(t) => t,
Err(http_response) => return http_response,
};
if let Err(response) = data
.backend_handler
.delete_refresh_token(refresh_token_hash)
.map_err(error_to_http_response)
.await
{
return response;
};
match data
.backend_handler
.blacklist_jwts(&user)
.map_err(error_to_http_response)
.await
{
Ok(new_blacklisted_jwts) => {
let mut jwt_blacklist = data.jwt_blacklist.write().unwrap();
for jwt in new_blacklisted_jwts {
jwt_blacklist.insert(jwt);
}
}
Err(response) => return response,
};
HttpResponse::Ok()
.cookie(
Cookie::build("token", "")
.max_age(0.days())
.path("/api")
.http_only(true)
.same_site(SameSite::Strict)
.finish(),
)
.cookie(
Cookie::build("refresh_token", "")
.max_age(0.days())
.path("/auth")
.http_only(true)
.same_site(SameSite::Strict)
.finish(),
)
.finish()
}
pub(crate) fn error_to_api_response<T>(error: DomainError) -> ApiResult<T> {
ApiResult::Right(error_to_http_response(error))
}
pub type ApiResult<M> = actix_web::Either<web::Json<M>, HttpResponse>;
async fn opaque_login_start<Backend>(
data: web::Data<AppState<Backend>>,
request: web::Json<login::ClientLoginStartRequest>,
) -> ApiResult<login::ServerLoginStartResponse>
where
Backend: OpaqueHandler + 'static,
{
data.backend_handler
.login_start(request.into_inner())
.await
.map(|res| ApiResult::Left(web::Json(res)))
.unwrap_or_else(error_to_api_response)
}
async fn get_login_successful_response<Backend>(
data: &web::Data<AppState<Backend>>,
name: &str,
) -> HttpResponse
where
Backend: TcpBackendHandler + BackendHandler,
{
// The authentication was successful, we need to fetch the groups to create the JWT
// token.
data.backend_handler
.get_user_groups(name)
.and_then(|g| async { Ok((g, data.backend_handler.create_refresh_token(name).await?)) })
.await
.map(|(groups, (refresh_token, max_age))| {
let token = create_jwt(&data.jwt_key, name.to_string(), groups);
HttpResponse::Ok()
.cookie(
Cookie::build("token", token.as_str())
.max_age(1.days())
.path("/api")
.http_only(true)
.same_site(SameSite::Strict)
.finish(),
)
.cookie(
Cookie::build("refresh_token", refresh_token + "+" + name)
.max_age(max_age.num_days().days())
.path("/auth")
.http_only(true)
.same_site(SameSite::Strict)
.finish(),
)
.body(token.as_str().to_owned())
})
.unwrap_or_else(error_to_http_response)
}
async fn opaque_login_finish<Backend>(
data: web::Data<AppState<Backend>>,
request: web::Json<login::ClientLoginFinishRequest>,
) -> HttpResponse
where
Backend: TcpBackendHandler + BackendHandler + OpaqueHandler + 'static,
{
let name = match data
.backend_handler
.login_finish(request.into_inner())
.await
{
Ok(n) => n,
Err(e) => return error_to_http_response(e),
};
get_login_successful_response(&data, &name).await
}
async fn post_authorize<Backend>(
data: web::Data<AppState<Backend>>,
request: web::Json<BindRequest>,
) -> HttpResponse
where
Backend: TcpBackendHandler + BackendHandler + LoginHandler + 'static,
{
let name = request.name.clone();
if let Err(e) = data.backend_handler.bind(request.into_inner()).await {
return error_to_http_response(e);
}
get_login_successful_response(&data, &name).await
}
async fn opaque_register_start<Backend>(
data: web::Data<AppState<Backend>>,
request: web::Json<registration::ClientRegistrationStartRequest>,
) -> ApiResult<registration::ServerRegistrationStartResponse>
where
Backend: OpaqueHandler + 'static,
{
data.backend_handler
.registration_start(request.into_inner())
.await
.map(|res| ApiResult::Left(web::Json(res)))
.unwrap_or_else(error_to_api_response)
}
async fn opaque_register_finish<Backend>(
data: web::Data<AppState<Backend>>,
request: web::Json<registration::ClientRegistrationFinishRequest>,
) -> HttpResponse
where
Backend: TcpBackendHandler + BackendHandler + OpaqueHandler + 'static,
{
if let Err(e) = data
.backend_handler
.registration_finish(request.into_inner())
.await
{
return error_to_http_response(e);
}
HttpResponse::Ok().finish()
}
pub struct CookieToHeaderTranslatorFactory;
impl<S> Transform<S, ServiceRequest> for CookieToHeaderTranslatorFactory
where
S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error>,
S::Future: 'static,
{
type Response = ServiceResponse;
type Error = actix_web::Error;
type InitError = ();
type Transform = CookieToHeaderTranslator<S>;
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ok(CookieToHeaderTranslator { service })
}
}
pub struct CookieToHeaderTranslator<S> {
service: S,
}
impl<S> Service<ServiceRequest> for CookieToHeaderTranslator<S>
where
S: Service<ServiceRequest, Response = ServiceResponse, Error = actix_web::Error>,
S::Future: 'static,
{
type Response = ServiceResponse;
type Error = actix_web::Error;
#[allow(clippy::type_complexity)]
type Future = Pin<Box<dyn core::future::Future<Output = Result<Self::Response, Self::Error>>>>;
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}
fn call(&self, mut req: ServiceRequest) -> Self::Future {
if let Some(token_cookie) = req.cookie("token") {
if let Ok(header_value) = actix_http::header::HeaderValue::from_str(&format!(
"Bearer {}",
token_cookie.value()
)) {
req.headers_mut()
.insert(actix_http::header::AUTHORIZATION, header_value);
} else {
return async move {
Ok(req.error_response(ErrorBadRequest("Invalid token cookie")))
}
.boxed_local();
}
};
Box::pin(self.service.call(req))
}
}
pub struct ValidationResults {
pub user: String,
pub is_admin: bool,
}
impl ValidationResults {
#[cfg(test)]
pub fn admin() -> Self {
Self {
user: "admin".to_string(),
is_admin: true,
}
}
pub fn can_access(&self, user: &str) -> bool {
self.is_admin || self.user == user
}
}
pub(crate) fn check_if_token_is_valid<Backend>(
state: &AppState<Backend>,
token_str: &str,
) -> Result<ValidationResults, actix_web::Error> {
let token: Token<_> = VerifyWithKey::verify_with_key(token_str, &state.jwt_key)
.map_err(|_| ErrorUnauthorized("Invalid JWT"))?;
if token.claims().exp.lt(&Utc::now()) {
return Err(ErrorUnauthorized("Expired JWT"));
}
if token.header().algorithm != jwt::AlgorithmType::Hs512 {
return Err(ErrorUnauthorized(format!(
"Unsupported JWT algorithm: '{:?}'. Supported ones are: ['HS512']",
token.header().algorithm
)));
}
let jwt_hash = {
let mut s = DefaultHasher::new();
token_str.hash(&mut s);
s.finish()
};
if state.jwt_blacklist.read().unwrap().contains(&jwt_hash) {
return Err(ErrorUnauthorized("JWT was logged out"));
}
let is_admin = token.claims().groups.contains("lldap_admin");
Ok(ValidationResults {
user: token.claims().user.clone(),
is_admin,
})
}
pub fn configure_server<Backend>(cfg: &mut web::ServiceConfig)
where
Backend: TcpBackendHandler + LoginHandler + OpaqueHandler + BackendHandler + 'static,
{
cfg.service(web::resource("").route(web::post().to(post_authorize::<Backend>)))
.service(
web::resource("/opaque/login/start")
.route(web::post().to(opaque_login_start::<Backend>)),
)
.service(
web::resource("/opaque/login/finish")
.route(web::post().to(opaque_login_finish::<Backend>)),
)
.service(
web::resource("/opaque/register/start")
.route(web::post().to(opaque_register_start::<Backend>)),
)
.service(
web::resource("/opaque/register/finish")
.route(web::post().to(opaque_register_finish::<Backend>)),
)
.service(web::resource("/refresh").route(web::get().to(get_refresh::<Backend>)))
.service(web::resource("/logout").route(web::get().to(get_logout::<Backend>)));
}

50
server/src/infra/cli.rs Normal file
View File

@@ -0,0 +1,50 @@
use clap::Clap;
/// lldap is a lightweight LDAP server
#[derive(Debug, Clap, Clone)]
#[clap(version = "0.1", author = "The LLDAP team")]
pub struct CLIOpts {
/// Export
#[clap(subcommand)]
pub command: Command,
}
#[derive(Debug, Clap, Clone)]
pub enum Command {
/// Export the GraphQL schema to *.graphql.
#[clap(name = "export_graphql_schema")]
ExportGraphQLSchema(ExportGraphQLSchemaOpts),
/// Run the LDAP and GraphQL server.
#[clap(name = "run")]
Run(RunOpts),
}
#[derive(Debug, Clap, Clone)]
pub struct RunOpts {
/// Change config file name
#[clap(short, long, default_value = "lldap_config.toml")]
pub config_file: String,
/// Change ldap port. Default: 389
#[clap(long)]
pub ldap_port: Option<u16>,
/// Change ldap ssl port. Default: 636
#[clap(long)]
pub ldaps_port: Option<u16>,
/// Set verbose logging
#[clap(short, long)]
pub verbose: bool,
}
#[derive(Debug, Clap, Clone)]
pub struct ExportGraphQLSchemaOpts {
/// Output to a file. If not specified, the config is printed to the standard output.
#[clap(short, long)]
pub output_file: Option<String>,
}
pub fn init() -> CLIOpts {
CLIOpts::parse()
}

View File

@@ -0,0 +1,121 @@
use anyhow::{Context, Result};
use figment::{
providers::{Env, Format, Serialized, Toml},
Figment,
};
use lldap_auth::opaque::{server::ServerSetup, KeyPair};
use serde::{Deserialize, Serialize};
use crate::infra::cli::RunOpts;
#[derive(Clone, Debug, Deserialize, Serialize, derive_builder::Builder)]
#[builder(
pattern = "owned",
default = "Configuration::default()",
build_fn(name = "private_build", validate = "Self::validate")
)]
pub struct Configuration {
pub ldap_port: u16,
pub ldaps_port: u16,
pub http_port: u16,
pub jwt_secret: String,
pub ldap_base_dn: String,
pub ldap_user_dn: String,
pub ldap_user_pass: String,
pub database_url: String,
pub verbose: bool,
pub key_file: String,
#[serde(skip)]
#[builder(field(private), setter(strip_option))]
server_setup: Option<ServerSetup>,
}
impl ConfigurationBuilder {
#[cfg(test)]
pub fn build(self) -> Result<Configuration> {
let server_setup = get_server_setup(self.key_file.as_deref().unwrap_or("server_key"))?;
Ok(self.server_setup(server_setup).private_build()?)
}
fn validate(&self) -> Result<(), String> {
if self.server_setup.is_none() {
Err("Don't use `private_build`, use `build` instead".to_string())
} else {
Ok(())
}
}
}
impl Configuration {
pub fn get_server_setup(&self) -> &ServerSetup {
self.server_setup.as_ref().unwrap()
}
pub fn get_server_keys(&self) -> &KeyPair {
self.get_server_setup().keypair()
}
fn merge_with_cli(mut self: Configuration, cli_opts: RunOpts) -> Configuration {
if cli_opts.verbose {
self.verbose = true;
}
if let Some(port) = cli_opts.ldap_port {
self.ldap_port = port;
}
if let Some(port) = cli_opts.ldaps_port {
self.ldaps_port = port;
}
self
}
pub(super) fn default() -> Self {
Configuration {
ldap_port: 3890,
ldaps_port: 6360,
http_port: 17170,
jwt_secret: String::from("secretjwtsecret"),
ldap_base_dn: String::from("dc=example,dc=com"),
// cn=admin,dc=example,dc=com
ldap_user_dn: String::from("admin"),
ldap_user_pass: String::from("password"),
database_url: String::from("sqlite://users.db?mode=rwc"),
verbose: false,
key_file: String::from("server_key"),
server_setup: None,
}
}
}
fn get_server_setup(file_path: &str) -> Result<ServerSetup> {
use std::path::Path;
let path = Path::new(file_path);
if path.exists() {
let bytes =
std::fs::read(file_path).context(format!("Could not read key file `{}`", file_path))?;
Ok(ServerSetup::deserialize(&bytes)?)
} else {
let mut rng = rand::rngs::OsRng;
let server_setup = ServerSetup::new(&mut rng);
std::fs::write(path, server_setup.serialize()).context(format!(
"Could not write the generated server setup to file `{}`",
file_path,
))?;
Ok(server_setup)
}
}
pub fn init(cli_opts: RunOpts) -> Result<Configuration> {
let config_file = cli_opts.config_file.clone();
let config: Configuration = Figment::from(Serialized::defaults(Configuration::default()))
.merge(Toml::file(config_file))
.merge(Env::prefixed("LLDAP_"))
.extract()?;
let mut config = config.merge_with_cli(cli_opts);
config.server_setup = Some(get_server_setup(&config.key_file)?);
Ok(config)
}

View File

@@ -0,0 +1,82 @@
use crate::{
domain::sql_tables::{DbQueryBuilder, Pool},
infra::jwt_sql_tables::{JwtRefreshStorage, JwtStorage},
};
use actix::prelude::*;
use chrono::Local;
use cron::Schedule;
use sea_query::{Expr, Query};
use std::{str::FromStr, time::Duration};
// Define actor
pub struct Scheduler {
schedule: Schedule,
sql_pool: Pool,
}
// Provide Actor implementation for our actor
impl Actor for Scheduler {
type Context = Context<Self>;
fn started(&mut self, context: &mut Context<Self>) {
log::info!("DB Cleanup Cron started");
context.run_later(self.duration_until_next(), move |this, ctx| {
this.schedule_task(ctx)
});
}
fn stopped(&mut self, _ctx: &mut Context<Self>) {
log::info!("DB Cleanup stopped");
}
}
impl Scheduler {
pub fn new(cron_expression: &str, sql_pool: Pool) -> Self {
let schedule = Schedule::from_str(cron_expression).unwrap();
Self { schedule, sql_pool }
}
fn schedule_task(&self, ctx: &mut Context<Self>) {
log::info!("Cleaning DB");
let future = actix::fut::wrap_future::<_, Self>(Self::cleanup_db(self.sql_pool.clone()));
ctx.spawn(future);
ctx.run_later(self.duration_until_next(), move |this, ctx| {
this.schedule_task(ctx)
});
}
async fn cleanup_db(sql_pool: Pool) {
if let Err(e) = sqlx::query(
&Query::delete()
.from_table(JwtRefreshStorage::Table)
.and_where(Expr::col(JwtRefreshStorage::ExpiryDate).lt(Local::now().naive_utc()))
.to_string(DbQueryBuilder {}),
)
.execute(&sql_pool)
.await
{
log::error!("DB error while cleaning up JWT refresh tokens: {}", e);
};
if let Err(e) = sqlx::query(
&Query::delete()
.from_table(JwtStorage::Table)
.and_where(Expr::col(JwtStorage::ExpiryDate).lt(Local::now().naive_utc()))
.to_string(DbQueryBuilder {}),
)
.execute(&sql_pool)
.await
{
log::error!("DB error while cleaning up JWT storage: {}", e);
};
log::info!("DB cleaned!");
}
fn duration_until_next(&self) -> Duration {
let now = Local::now();
let next = self.schedule.upcoming(Local).next().unwrap();
let duration_until = next.signed_duration_since(now);
duration_until.to_std().unwrap()
}
}

View File

@@ -0,0 +1,100 @@
use crate::{
domain::handler::BackendHandler,
infra::{
auth_service::{check_if_token_is_valid, ValidationResults},
cli::ExportGraphQLSchemaOpts,
tcp_server::AppState,
},
};
use actix_web::{web, Error, HttpResponse};
use actix_web_httpauth::extractors::bearer::BearerAuth;
use juniper::{EmptySubscription, RootNode};
use juniper_actix::{graphiql_handler, graphql_handler, playground_handler};
use super::{mutation::Mutation, query::Query};
pub struct Context<Handler: BackendHandler> {
pub handler: Box<Handler>,
pub validation_result: ValidationResults,
}
impl<Handler: BackendHandler> juniper::Context for Context<Handler> {}
type Schema<Handler> =
RootNode<'static, Query<Handler>, Mutation<Handler>, EmptySubscription<Context<Handler>>>;
fn schema<Handler: BackendHandler + Sync>() -> Schema<Handler> {
Schema::new(
Query::<Handler>::new(),
Mutation::<Handler>::new(),
EmptySubscription::<Context<Handler>>::new(),
)
}
pub fn export_schema(opts: ExportGraphQLSchemaOpts) -> anyhow::Result<()> {
use crate::domain::sql_backend_handler::SqlBackendHandler;
use anyhow::Context;
let output = schema::<SqlBackendHandler>().as_schema_language();
match opts.output_file {
None => println!("{}", output),
Some(path) => {
use std::fs::File;
use std::io::prelude::*;
use std::path::Path;
let path = Path::new(&path);
let mut file =
File::create(&path).context(format!("unable to open '{}'", path.display()))?;
file.write_all(output.as_bytes())
.context(format!("unable to write in '{}'", path.display()))?;
}
}
Ok(())
}
async fn graphiql_route() -> Result<HttpResponse, Error> {
graphiql_handler("/api/graphql", None).await
}
async fn playground_route() -> Result<HttpResponse, Error> {
playground_handler("/api/graphql", None).await
}
async fn graphql_route<Handler: BackendHandler + Sync>(
req: actix_web::HttpRequest,
mut payload: actix_web::web::Payload,
data: web::Data<AppState<Handler>>,
) -> Result<HttpResponse, Error> {
use actix_web::FromRequest;
let bearer = BearerAuth::from_request(&req, &mut payload.0).await?;
let validation_result = check_if_token_is_valid(&data, bearer.token())?;
let context = Context::<Handler> {
handler: Box::new(data.backend_handler.clone()),
validation_result,
};
graphql_handler(&schema(), &context, req, payload).await
}
pub fn configure_endpoint<Backend>(cfg: &mut web::ServiceConfig)
where
Backend: BackendHandler + Sync + 'static,
{
let json_config = web::JsonConfig::default()
.limit(4096)
.error_handler(|err, _req| {
// create custom error response
log::error!("API error: {}", err);
let msg = err.to_string();
actix_web::error::InternalError::from_response(
err,
HttpResponse::BadRequest().body(msg),
)
.into()
});
cfg.app_data(json_config);
cfg.service(
web::resource("/graphql")
.route(web::post().to(graphql_route::<Backend>))
.route(web::get().to(graphql_route::<Backend>)),
);
cfg.service(web::resource("/graphql/playground").route(web::get().to(playground_route)));
cfg.service(web::resource("/graphql/graphiql").route(web::get().to(graphiql_route)));
}

View File

@@ -0,0 +1,3 @@
pub mod api;
pub mod mutation;
pub mod query;

View File

@@ -0,0 +1,55 @@
use crate::domain::handler::{BackendHandler, CreateUserRequest};
use juniper::{graphql_object, FieldResult, GraphQLInputObject};
use super::api::Context;
#[derive(PartialEq, Eq, Debug)]
/// The top-level GraphQL mutation type.
pub struct Mutation<Handler: BackendHandler> {
_phantom: std::marker::PhantomData<Box<Handler>>,
}
impl<Handler: BackendHandler> Mutation<Handler> {
pub fn new() -> Self {
Self {
_phantom: std::marker::PhantomData,
}
}
}
#[derive(PartialEq, Eq, Debug, GraphQLInputObject)]
/// The details required to create a user.
pub struct UserInput {
id: String,
email: String,
display_name: Option<String>,
first_name: Option<String>,
last_name: Option<String>,
}
#[graphql_object(context = Context<Handler>)]
impl<Handler: BackendHandler + Sync> Mutation<Handler> {
async fn create_user(
context: &Context<Handler>,
user: UserInput,
) -> FieldResult<super::query::User<Handler>> {
if !context.validation_result.is_admin {
return Err("Unauthorized user creation".into());
}
context
.handler
.create_user(CreateUserRequest {
user_id: user.id.clone(),
email: user.email,
display_name: user.display_name,
first_name: user.first_name,
last_name: user.last_name,
})
.await?;
Ok(context
.handler
.get_user_details(&user.id)
.await
.map(Into::into)?)
}
}

View File

@@ -0,0 +1,348 @@
use crate::domain::handler::BackendHandler;
use juniper::{graphql_object, FieldResult, GraphQLInputObject};
use serde::{Deserialize, Serialize};
use std::convert::TryInto;
type DomainRequestFilter = crate::domain::handler::RequestFilter;
type DomainUser = crate::domain::handler::User;
use super::api::Context;
#[derive(PartialEq, Eq, Debug, GraphQLInputObject)]
/// A filter for requests, specifying a boolean expression based on field constraints. Only one of
/// the fields can be set at a time.
pub struct RequestFilter {
any: Option<Vec<RequestFilter>>,
all: Option<Vec<RequestFilter>>,
not: Option<Box<RequestFilter>>,
eq: Option<EqualityConstraint>,
}
impl TryInto<DomainRequestFilter> for RequestFilter {
type Error = String;
fn try_into(self) -> Result<DomainRequestFilter, Self::Error> {
let mut field_count = 0;
if self.any.is_some() {
field_count += 1;
}
if self.all.is_some() {
field_count += 1;
}
if self.not.is_some() {
field_count += 1;
}
if self.eq.is_some() {
field_count += 1;
}
if field_count == 0 {
return Err("No field specified in request filter".to_string());
}
if field_count > 1 {
return Err("Multiple fields specified in request filter".to_string());
}
if let Some(e) = self.eq {
return Ok(DomainRequestFilter::Equality(e.field, e.value));
}
if let Some(c) = self.any {
return Ok(DomainRequestFilter::Or(
c.into_iter()
.map(TryInto::try_into)
.collect::<Result<Vec<_>, String>>()?,
));
}
if let Some(c) = self.all {
return Ok(DomainRequestFilter::And(
c.into_iter()
.map(TryInto::try_into)
.collect::<Result<Vec<_>, String>>()?,
));
}
if let Some(c) = self.not {
return Ok(DomainRequestFilter::Not(Box::new((*c).try_into()?)));
}
unreachable!();
}
}
#[derive(PartialEq, Eq, Debug, GraphQLInputObject)]
pub struct EqualityConstraint {
field: String,
value: String,
}
#[derive(PartialEq, Eq, Debug)]
/// The top-level GraphQL query type.
pub struct Query<Handler: BackendHandler> {
_phantom: std::marker::PhantomData<Box<Handler>>,
}
impl<Handler: BackendHandler> Query<Handler> {
pub fn new() -> Self {
Self {
_phantom: std::marker::PhantomData,
}
}
}
#[graphql_object(context = Context<Handler>)]
impl<Handler: BackendHandler + Sync> Query<Handler> {
fn api_version() -> &'static str {
"1.0"
}
pub async fn user(context: &Context<Handler>, user_id: String) -> FieldResult<User<Handler>> {
if !context.validation_result.can_access(&user_id) {
return Err("Unauthorized access to user data".into());
}
Ok(context
.handler
.get_user_details(&user_id)
.await
.map(Into::into)?)
}
async fn users(
context: &Context<Handler>,
#[graphql(name = "where")] filters: Option<RequestFilter>,
) -> FieldResult<Vec<User<Handler>>> {
if !context.validation_result.is_admin {
return Err("Unauthorized access to user list".into());
}
Ok(context
.handler
.list_users(filters.map(TryInto::try_into).transpose()?)
.await
.map(|v| v.into_iter().map(Into::into).collect())?)
}
}
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
/// Represents a single user.
pub struct User<Handler: BackendHandler> {
user: DomainUser,
_phantom: std::marker::PhantomData<Box<Handler>>,
}
impl<Handler: BackendHandler> Default for User<Handler> {
fn default() -> Self {
Self {
user: DomainUser::default(),
_phantom: std::marker::PhantomData,
}
}
}
#[graphql_object(context = Context<Handler>)]
impl<Handler: BackendHandler + Sync> User<Handler> {
fn id(&self) -> &str {
&self.user.user_id
}
fn email(&self) -> &str {
&self.user.email
}
fn display_name(&self) -> Option<&String> {
self.user.display_name.as_ref()
}
fn first_name(&self) -> Option<&String> {
self.user.first_name.as_ref()
}
fn last_name(&self) -> Option<&String> {
self.user.last_name.as_ref()
}
fn creation_date(&self) -> chrono::DateTime<chrono::Utc> {
self.user.creation_date
}
/// The groups to which this user belongs.
async fn groups(&self, context: &Context<Handler>) -> FieldResult<Vec<Group<Handler>>> {
Ok(context
.handler
.get_user_groups(&self.user.user_id)
.await
.map(|set| set.into_iter().map(Into::into).collect())?)
}
}
impl<Handler: BackendHandler> From<DomainUser> for User<Handler> {
fn from(user: DomainUser) -> Self {
Self {
user,
_phantom: std::marker::PhantomData,
}
}
}
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
/// Represents a single group.
pub struct Group<Handler: BackendHandler> {
group_id: String,
_phantom: std::marker::PhantomData<Box<Handler>>,
}
#[graphql_object(context = Context<Handler>)]
impl<Handler: BackendHandler + Sync> Group<Handler> {
fn id(&self) -> String {
self.group_id.clone()
}
/// The groups to which this user belongs.
async fn users(&self, context: &Context<Handler>) -> FieldResult<Vec<User<Handler>>> {
if !context.validation_result.is_admin {
return Err("Unauthorized access to group data".into());
}
unimplemented!()
}
}
impl<Handler: BackendHandler> From<String> for Group<Handler> {
fn from(group_id: String) -> Self {
Self {
group_id,
_phantom: std::marker::PhantomData,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{domain::handler::MockTestBackendHandler, infra::auth_service::ValidationResults};
use juniper::{
execute, graphql_value, DefaultScalarValue, EmptyMutation, EmptySubscription, GraphQLType,
RootNode, Variables,
};
use mockall::predicate::eq;
use std::collections::HashSet;
fn schema<'q, C, Q>(query_root: Q) -> RootNode<'q, Q, EmptyMutation<C>, EmptySubscription<C>>
where
Q: GraphQLType<DefaultScalarValue, Context = C, TypeInfo = ()> + 'q,
{
RootNode::new(
query_root,
EmptyMutation::<C>::new(),
EmptySubscription::<C>::new(),
)
}
#[tokio::test]
async fn get_user_by_id() {
const QUERY: &str = r#"{
user(userId: "bob") {
id
email
groups {
id
}
}
}"#;
let mut mock = MockTestBackendHandler::new();
mock.expect_get_user_details()
.with(eq("bob"))
.return_once(|_| {
Ok(DomainUser {
user_id: "bob".to_string(),
email: "bob@bobbers.on".to_string(),
..Default::default()
})
});
let mut groups = HashSet::<String>::new();
groups.insert("Bobbersons".to_string());
mock.expect_get_user_groups()
.with(eq("bob"))
.return_once(|_| Ok(groups));
let context = Context::<MockTestBackendHandler> {
handler: Box::new(mock),
validation_result: ValidationResults::admin(),
};
let schema = schema(Query::<MockTestBackendHandler>::new());
assert_eq!(
execute(QUERY, None, &schema, &Variables::new(), &context).await,
Ok((
graphql_value!(
{
"user": {
"id": "bob",
"email": "bob@bobbers.on",
"groups": [{"id": "Bobbersons"}]
}
}),
vec![]
))
);
}
#[tokio::test]
async fn list_users() {
const QUERY: &str = r#"{
users(filters: {
any: [
{eq: {
field: "id"
value: "bob"
}},
{eq: {
field: "email"
value: "robert@bobbers.on"
}}
]}) {
id
email
}
}"#;
let mut mock = MockTestBackendHandler::new();
use crate::domain::handler::RequestFilter;
mock.expect_list_users()
.with(eq(Some(RequestFilter::Or(vec![
RequestFilter::Equality("id".to_string(), "bob".to_string()),
RequestFilter::Equality("email".to_string(), "robert@bobbers.on".to_string()),
]))))
.return_once(|_| {
Ok(vec![
DomainUser {
user_id: "bob".to_string(),
email: "bob@bobbers.on".to_string(),
..Default::default()
},
DomainUser {
user_id: "robert".to_string(),
email: "robert@bobbers.on".to_string(),
..Default::default()
},
])
});
let context = Context::<MockTestBackendHandler> {
handler: Box::new(mock),
validation_result: ValidationResults::admin(),
};
let schema = schema(Query::<MockTestBackendHandler>::new());
assert_eq!(
execute(QUERY, None, &schema, &Variables::new(), &context).await,
Ok((
graphql_value!(
{
"users": [
{
"id": "bob",
"email": "bob@bobbers.on"
},
{
"id": "robert",
"email": "robert@bobbers.on"
},
]
}),
vec![]
))
);
}
}

View File

@@ -0,0 +1,99 @@
use sea_query::*;
pub use crate::domain::sql_tables::*;
/// Contains the refresh tokens for a given user.
#[derive(Iden)]
pub enum JwtRefreshStorage {
Table,
RefreshTokenHash,
UserId,
ExpiryDate,
}
/// Contains the blacklisted JWT that haven't expired yet.
#[derive(Iden)]
pub enum JwtStorage {
Table,
JwtHash,
UserId,
ExpiryDate,
Blacklisted,
}
/// This needs to be initialized after the domain tables are.
pub async fn init_table(pool: &Pool) -> sqlx::Result<()> {
sqlx::query(
&Table::create()
.table(JwtRefreshStorage::Table)
.if_not_exists()
.col(
ColumnDef::new(JwtRefreshStorage::RefreshTokenHash)
.big_integer()
.not_null()
.primary_key(),
)
.col(
ColumnDef::new(JwtRefreshStorage::UserId)
.string_len(255)
.not_null(),
)
.col(
ColumnDef::new(JwtRefreshStorage::ExpiryDate)
.date_time()
.not_null(),
)
.foreign_key(
ForeignKey::create()
.name("JwtRefreshStorageUserForeignKey")
.table(JwtRefreshStorage::Table, Users::Table)
.col(JwtRefreshStorage::UserId, Users::UserId)
.on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade),
)
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await?;
sqlx::query(
&Table::create()
.table(JwtStorage::Table)
.if_not_exists()
.col(
ColumnDef::new(JwtStorage::JwtHash)
.big_integer()
.not_null()
.primary_key(),
)
.col(
ColumnDef::new(JwtStorage::UserId)
.string_len(255)
.not_null(),
)
.col(
ColumnDef::new(JwtStorage::ExpiryDate)
.date_time()
.not_null(),
)
.col(
ColumnDef::new(JwtStorage::Blacklisted)
.boolean()
.default(false)
.not_null(),
)
.foreign_key(
ForeignKey::create()
.name("JwtStorageUserForeignKey")
.table(JwtStorage::Table, Users::Table)
.col(JwtStorage::UserId, Users::UserId)
.on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade),
)
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await?;
Ok(())
}

View File

@@ -0,0 +1,632 @@
use crate::domain::handler::{BackendHandler, LoginHandler, RequestFilter, User};
use anyhow::{bail, Result};
use ldap3_server::simple::*;
fn make_dn_pair<I>(mut iter: I) -> Result<(String, String)>
where
I: Iterator<Item = String>,
{
let pair = (
iter.next()
.ok_or_else(|| anyhow::Error::msg("Empty DN element"))?,
iter.next()
.ok_or_else(|| anyhow::Error::msg("Missing DN value"))?,
);
if let Some(e) = iter.next() {
bail!(
r#"Too many elements in distinguished name: "{:?}", "{:?}", "{:?}""#,
pair.0,
pair.1,
e
)
}
Ok(pair)
}
fn parse_distinguished_name(dn: &str) -> Result<Vec<(String, String)>> {
dn.split(',')
.map(|s| make_dn_pair(s.split('=').map(String::from)))
.collect()
}
fn get_user_id_from_distinguished_name(
dn: &str,
base_tree: &[(String, String)],
base_dn_str: &str,
ldap_user_dn: &str,
) -> Result<String> {
let parts = parse_distinguished_name(dn)?;
if !is_subtree(&parts, base_tree) {
bail!("Not a subtree of the base tree");
}
if parts.len() == base_tree.len() + 1 {
if dn != ldap_user_dn {
bail!(r#"Wrong admin DN. Expected: "{}""#, ldap_user_dn);
}
Ok(parts[0].1.to_string())
} else if parts.len() == base_tree.len() + 2 {
if parts[1].0 != "ou" || parts[1].1 != "people" || parts[0].0 != "cn" {
bail!(
r#"Unexpected user DN format. Expected: "cn=username,ou=people,{}""#,
base_dn_str
);
}
Ok(parts[0].1.to_string())
} else {
bail!(
r#"Unexpected user DN format. Expected: "cn=username,ou=people,{}""#,
base_dn_str
);
}
}
fn get_attribute(user: &User, attribute: &str) -> Result<Vec<String>> {
match attribute {
"objectClass" => Ok(vec![
"inetOrgPerson".to_string(),
"posixAccount".to_string(),
"mailAccount".to_string(),
]),
"uid" => Ok(vec![user.user_id.clone()]),
"mail" => Ok(vec![user.email.clone()]),
"givenName" => Ok(vec![user.first_name.clone().unwrap_or_default()]),
"sn" => Ok(vec![user.last_name.clone().unwrap_or_default()]),
"cn" => Ok(vec![user
.display_name
.clone()
.unwrap_or_else(|| user.user_id.clone())]),
_ => bail!("Unsupported attribute: {}", attribute),
}
}
fn make_ldap_search_result_entry(
user: User,
base_dn_str: &str,
attributes: &[String],
) -> Result<LdapSearchResultEntry> {
Ok(LdapSearchResultEntry {
dn: format!("cn={},{}", user.user_id, base_dn_str),
attributes: attributes
.iter()
.map(|a| {
Ok(LdapPartialAttribute {
atype: a.to_string(),
vals: get_attribute(&user, a)?,
})
})
.collect::<Result<Vec<LdapPartialAttribute>>>()?,
})
}
fn is_subtree(subtree: &[(String, String)], base_tree: &[(String, String)]) -> bool {
if subtree.len() < base_tree.len() {
return false;
}
let size_diff = subtree.len() - base_tree.len();
for i in 0..base_tree.len() {
if subtree[size_diff + i] != base_tree[i] {
return false;
}
}
true
}
fn map_field(field: &str) -> Result<String> {
Ok(if field == "uid" {
"user_id".to_string()
} else if field == "mail" {
"email".to_string()
} else if field == "cn" {
"display_name".to_string()
} else if field == "givenName" {
"first_name".to_string()
} else if field == "sn" {
"last_name".to_string()
} else if field == "avatar" {
"avatar".to_string()
} else if field == "creationDate" {
"creation_date".to_string()
} else {
bail!("Unknown field: {}", field);
})
}
fn convert_filter(filter: &LdapFilter) -> Result<RequestFilter> {
match filter {
LdapFilter::And(filters) => Ok(RequestFilter::And(
filters.iter().map(convert_filter).collect::<Result<_>>()?,
)),
LdapFilter::Or(filters) => Ok(RequestFilter::Or(
filters.iter().map(convert_filter).collect::<Result<_>>()?,
)),
LdapFilter::Not(filter) => Ok(RequestFilter::Not(Box::new(convert_filter(&*filter)?))),
LdapFilter::Equality(field, value) => {
Ok(RequestFilter::Equality(map_field(field)?, value.clone()))
}
_ => bail!("Unsupported filter"),
}
}
pub struct LdapHandler<Backend: BackendHandler + LoginHandler> {
dn: String,
backend_handler: Backend,
pub base_dn: Vec<(String, String)>,
base_dn_str: String,
ldap_user_dn: String,
}
impl<Backend: BackendHandler + LoginHandler> LdapHandler<Backend> {
pub fn new(backend_handler: Backend, ldap_base_dn: String, ldap_user_dn: String) -> Self {
Self {
dn: "Unauthenticated".to_string(),
backend_handler,
base_dn: parse_distinguished_name(&ldap_base_dn).unwrap_or_else(|_| {
panic!(
"Invalid value for ldap_base_dn in configuration: {}",
ldap_base_dn
)
}),
ldap_user_dn: format!("cn={},{}", ldap_user_dn, &ldap_base_dn),
base_dn_str: ldap_base_dn,
}
}
pub async fn do_bind(&mut self, sbr: &SimpleBindRequest) -> LdapMsg {
let user_id = match get_user_id_from_distinguished_name(
&sbr.dn,
&self.base_dn,
&self.base_dn_str,
&self.ldap_user_dn,
) {
Ok(s) => s,
Err(e) => return sbr.gen_error(LdapResultCode::NamingViolation, e.to_string()),
};
match self
.backend_handler
.bind(crate::domain::handler::BindRequest {
name: user_id,
password: sbr.pw.clone(),
})
.await
{
Ok(()) => {
self.dn = sbr.dn.clone();
sbr.gen_success()
}
Err(_) => sbr.gen_invalid_cred(),
}
}
pub async fn do_search(&mut self, lsr: &SearchRequest) -> Vec<LdapMsg> {
if self.dn != self.ldap_user_dn {
return vec![lsr.gen_error(
LdapResultCode::InsufficentAccessRights,
r#"Current user is not allowed to query LDAP"#.to_string(),
)];
}
let dn_parts = match parse_distinguished_name(&lsr.base) {
Ok(dn) => dn,
Err(_) => {
return vec![lsr.gen_error(
LdapResultCode::OperationsError,
format!(r#"Could not parse base DN: "{}""#, lsr.base),
)]
}
};
if !is_subtree(&dn_parts, &self.base_dn) {
// Search path is not in our tree, just return an empty success.
return vec![lsr.gen_success()];
}
let filters = match convert_filter(&lsr.filter) {
Ok(f) => Some(f),
Err(_) => {
return vec![lsr.gen_error(
LdapResultCode::UnwillingToPerform,
"Unsupported filter".to_string(),
)]
}
};
let users = match self.backend_handler.list_users(filters).await {
Ok(users) => users,
Err(e) => {
return vec![lsr.gen_error(
LdapResultCode::Other,
format!(r#"Error during search for "{}": {}"#, lsr.base, e),
)]
}
};
users
.into_iter()
.map(|u| make_ldap_search_result_entry(u, &self.base_dn_str, &lsr.attrs))
.map(|entry| Ok(lsr.gen_result_entry(entry?)))
// If the processing succeeds, add a success message at the end.
.chain(std::iter::once(Ok(lsr.gen_success())))
.collect::<Result<Vec<_>>>()
.unwrap_or_else(|e| vec![lsr.gen_error(LdapResultCode::NoSuchAttribute, e.to_string())])
}
pub fn do_whoami(&mut self, wr: &WhoamiRequest) -> LdapMsg {
if self.dn == "Unauthenticated" {
wr.gen_operror("Unauthenticated")
} else {
wr.gen_success(format!("dn: {}", self.dn).as_str())
}
}
pub async fn handle_ldap_message(&mut self, server_op: ServerOps) -> Option<Vec<LdapMsg>> {
let result = match server_op {
ServerOps::SimpleBind(sbr) => vec![self.do_bind(&sbr).await],
ServerOps::Search(sr) => self.do_search(&sr).await,
ServerOps::Unbind(_) => {
// No need to notify on unbind (per rfc4511)
return None;
}
ServerOps::Whoami(wr) => vec![self.do_whoami(&wr)],
};
Some(result)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::domain::handler::BindRequest;
use crate::domain::handler::MockTestBackendHandler;
use mockall::predicate::eq;
use tokio;
async fn setup_bound_handler(
mut mock: MockTestBackendHandler,
) -> LdapHandler<MockTestBackendHandler> {
mock.expect_bind()
.with(eq(BindRequest {
name: "test".to_string(),
password: "pass".to_string(),
}))
.return_once(|_| Ok(()));
let mut ldap_handler =
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string());
let request = SimpleBindRequest {
msgid: 1,
dn: "cn=test,dc=example,dc=com".to_string(),
pw: "pass".to_string(),
};
ldap_handler.do_bind(&request).await;
ldap_handler
}
#[tokio::test]
async fn test_bind() {
let mut mock = MockTestBackendHandler::new();
mock.expect_bind()
.with(eq(crate::domain::handler::BindRequest {
name: "bob".to_string(),
password: "pass".to_string(),
}))
.times(1)
.return_once(|_| Ok(()));
let mut ldap_handler =
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string());
let request = WhoamiRequest { msgid: 1 };
assert_eq!(
ldap_handler.do_whoami(&request),
request.gen_operror("Unauthenticated")
);
let request = SimpleBindRequest {
msgid: 2,
dn: "cn=bob,ou=people,dc=example,dc=com".to_string(),
pw: "pass".to_string(),
};
assert_eq!(ldap_handler.do_bind(&request).await, request.gen_success());
let request = WhoamiRequest { msgid: 3 };
assert_eq!(
ldap_handler.do_whoami(&request),
request.gen_success("dn: cn=bob,ou=people,dc=example,dc=com")
);
}
#[tokio::test]
async fn test_admin_bind() {
let mut mock = MockTestBackendHandler::new();
mock.expect_bind()
.with(eq(crate::domain::handler::BindRequest {
name: "test".to_string(),
password: "pass".to_string(),
}))
.times(1)
.return_once(|_| Ok(()));
let mut ldap_handler =
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "test".to_string());
let request = WhoamiRequest { msgid: 1 };
assert_eq!(
ldap_handler.do_whoami(&request),
request.gen_operror("Unauthenticated")
);
let request = SimpleBindRequest {
msgid: 2,
dn: "cn=test,dc=example,dc=com".to_string(),
pw: "pass".to_string(),
};
assert_eq!(ldap_handler.do_bind(&request).await, request.gen_success());
let request = WhoamiRequest { msgid: 3 };
assert_eq!(
ldap_handler.do_whoami(&request),
request.gen_success("dn: cn=test,dc=example,dc=com")
);
}
#[tokio::test]
async fn test_bind_invalid_credentials() {
let mut mock = MockTestBackendHandler::new();
mock.expect_bind()
.with(eq(crate::domain::handler::BindRequest {
name: "test".to_string(),
password: "pass".to_string(),
}))
.times(1)
.return_once(|_| Ok(()));
let mut ldap_handler =
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "admin".to_string());
let request = WhoamiRequest { msgid: 1 };
assert_eq!(
ldap_handler.do_whoami(&request),
request.gen_operror("Unauthenticated")
);
let request = SimpleBindRequest {
msgid: 2,
dn: "cn=test,ou=people,dc=example,dc=com".to_string(),
pw: "pass".to_string(),
};
assert_eq!(ldap_handler.do_bind(&request).await, request.gen_success());
let request = WhoamiRequest { msgid: 3 };
assert_eq!(
ldap_handler.do_whoami(&request),
request.gen_success("dn: cn=test,ou=people,dc=example,dc=com")
);
let request = SearchRequest {
msgid: 2,
base: "ou=people,dc=example,dc=com".to_string(),
scope: LdapSearchScope::Base,
filter: LdapFilter::And(vec![]),
attrs: vec![],
};
assert_eq!(
ldap_handler.do_search(&request).await,
vec![request.gen_error(
LdapResultCode::InsufficentAccessRights,
r#"Current user is not allowed to query LDAP"#.to_string()
)]
);
}
#[tokio::test]
async fn test_bind_invalid_dn() {
let mock = MockTestBackendHandler::new();
let mut ldap_handler =
LdapHandler::new(mock, "dc=example,dc=com".to_string(), "admin".to_string());
let request = SimpleBindRequest {
msgid: 2,
dn: "cn=bob,dc=example,dc=com".to_string(),
pw: "pass".to_string(),
};
assert_eq!(
ldap_handler.do_bind(&request).await,
request.gen_error(
LdapResultCode::NamingViolation,
r#"Wrong admin DN. Expected: "cn=admin,dc=example,dc=com""#.to_string()
)
);
let request = SimpleBindRequest {
msgid: 2,
dn: "cn=bob,ou=groups,dc=example,dc=com".to_string(),
pw: "pass".to_string(),
};
assert_eq!(
ldap_handler.do_bind(&request).await,
request.gen_error(
LdapResultCode::NamingViolation,
r#"Unexpected user DN format. Expected: "cn=username,ou=people,dc=example,dc=com""#
.to_string()
)
);
}
#[test]
fn test_is_subtree() {
let subtree1 = &[
("ou".to_string(), "people".to_string()),
("dc".to_string(), "example".to_string()),
("dc".to_string(), "com".to_string()),
];
let root = &[
("dc".to_string(), "example".to_string()),
("dc".to_string(), "com".to_string()),
];
assert!(is_subtree(subtree1, root));
assert!(!is_subtree(&[], root));
}
#[test]
fn test_parse_distinguished_name() {
let parsed_dn = &[
("ou".to_string(), "people".to_string()),
("dc".to_string(), "example".to_string()),
("dc".to_string(), "com".to_string()),
];
assert_eq!(
parse_distinguished_name("ou=people,dc=example,dc=com").expect("parsing failed"),
parsed_dn
);
}
#[tokio::test]
async fn test_search() {
let mut mock = MockTestBackendHandler::new();
mock.expect_list_users().times(1).return_once(|_| {
Ok(vec![
User {
user_id: "bob_1".to_string(),
email: "bob@bobmail.bob".to_string(),
display_name: Some("Bôb Böbberson".to_string()),
first_name: Some("Bôb".to_string()),
last_name: Some("Böbberson".to_string()),
..Default::default()
},
User {
user_id: "jim".to_string(),
email: "jim@cricket.jim".to_string(),
display_name: Some("Jimminy Cricket".to_string()),
first_name: Some("Jim".to_string()),
last_name: Some("Cricket".to_string()),
..Default::default()
},
])
});
let mut ldap_handler = setup_bound_handler(mock).await;
let request = SearchRequest {
msgid: 2,
base: "ou=people,dc=example,dc=com".to_string(),
scope: LdapSearchScope::Base,
filter: LdapFilter::And(vec![]),
attrs: vec![
"objectClass".to_string(),
"uid".to_string(),
"mail".to_string(),
"givenName".to_string(),
"sn".to_string(),
"cn".to_string(),
],
};
assert_eq!(
ldap_handler.do_search(&request).await,
vec![
request.gen_result_entry(LdapSearchResultEntry {
dn: "cn=bob_1,dc=example,dc=com".to_string(),
attributes: vec![
LdapPartialAttribute {
atype: "objectClass".to_string(),
vals: vec![
"inetOrgPerson".to_string(),
"posixAccount".to_string(),
"mailAccount".to_string()
]
},
LdapPartialAttribute {
atype: "uid".to_string(),
vals: vec!["bob_1".to_string()]
},
LdapPartialAttribute {
atype: "mail".to_string(),
vals: vec!["bob@bobmail.bob".to_string()]
},
LdapPartialAttribute {
atype: "givenName".to_string(),
vals: vec!["Bôb".to_string()]
},
LdapPartialAttribute {
atype: "sn".to_string(),
vals: vec!["Böbberson".to_string()]
},
LdapPartialAttribute {
atype: "cn".to_string(),
vals: vec!["Bôb Böbberson".to_string()]
}
],
}),
request.gen_result_entry(LdapSearchResultEntry {
dn: "cn=jim,dc=example,dc=com".to_string(),
attributes: vec![
LdapPartialAttribute {
atype: "objectClass".to_string(),
vals: vec![
"inetOrgPerson".to_string(),
"posixAccount".to_string(),
"mailAccount".to_string()
]
},
LdapPartialAttribute {
atype: "uid".to_string(),
vals: vec!["jim".to_string()]
},
LdapPartialAttribute {
atype: "mail".to_string(),
vals: vec!["jim@cricket.jim".to_string()]
},
LdapPartialAttribute {
atype: "givenName".to_string(),
vals: vec!["Jim".to_string()]
},
LdapPartialAttribute {
atype: "sn".to_string(),
vals: vec!["Cricket".to_string()]
},
LdapPartialAttribute {
atype: "cn".to_string(),
vals: vec!["Jimminy Cricket".to_string()]
}
],
}),
request.gen_success()
]
);
}
#[tokio::test]
async fn test_search_filters() {
let mut mock = MockTestBackendHandler::new();
mock.expect_list_users()
.with(eq(Some(RequestFilter::And(vec![RequestFilter::Or(vec![
RequestFilter::Not(Box::new(RequestFilter::Equality(
"user_id".to_string(),
"bob".to_string(),
))),
])]))))
.times(1)
.return_once(|_| Ok(vec![]));
let mut ldap_handler = setup_bound_handler(mock).await;
let request = SearchRequest {
msgid: 2,
base: "ou=people,dc=example,dc=com".to_string(),
scope: LdapSearchScope::Base,
filter: LdapFilter::And(vec![LdapFilter::Or(vec![LdapFilter::Not(Box::new(
LdapFilter::Equality("uid".to_string(), "bob".to_string()),
))])]),
attrs: vec!["objectClass".to_string()],
};
assert_eq!(
ldap_handler.do_search(&request).await,
vec![request.gen_success()]
);
}
#[tokio::test]
async fn test_search_unsupported_filters() {
let mut ldap_handler = setup_bound_handler(MockTestBackendHandler::new()).await;
let request = SearchRequest {
msgid: 2,
base: "ou=people,dc=example,dc=com".to_string(),
scope: LdapSearchScope::Base,
filter: LdapFilter::Present("uid".to_string()),
attrs: vec!["objectClass".to_string()],
};
assert_eq!(
ldap_handler.do_search(&request).await,
vec![request.gen_error(
LdapResultCode::UnwillingToPerform,
"Unsupported filter".to_string()
)]
);
}
}

View File

@@ -0,0 +1,102 @@
use crate::domain::handler::{BackendHandler, LoginHandler};
use crate::infra::configuration::Configuration;
use crate::infra::ldap_handler::LdapHandler;
use actix_rt::net::TcpStream;
use actix_server::ServerBuilder;
use actix_service::{fn_service, ServiceFactoryExt};
use anyhow::{bail, Result};
use futures_util::future::ok;
use ldap3_server::simple::*;
use ldap3_server::LdapCodec;
use log::*;
use tokio::net::tcp::WriteHalf;
use tokio_util::codec::{FramedRead, FramedWrite};
async fn handle_incoming_message<Backend>(
msg: Result<LdapMsg, std::io::Error>,
resp: &mut FramedWrite<WriteHalf<'_>, LdapCodec>,
session: &mut LdapHandler<Backend>,
) -> Result<bool>
where
Backend: BackendHandler + LoginHandler,
{
use futures_util::SinkExt;
use std::convert::TryFrom;
let server_op = match msg.map_err(|_e| ()).and_then(ServerOps::try_from) {
Ok(a_value) => a_value,
Err(an_error) => {
let _err = resp
.send(DisconnectionNotice::gen(
LdapResultCode::Other,
"Internal Server Error",
))
.await;
let _err = resp.flush().await;
bail!("Internal server error: {:?}", an_error);
}
};
match session.handle_ldap_message(server_op).await {
None => return Ok(false),
Some(result) => {
for rmsg in result.into_iter() {
if let Err(e) = resp.send(rmsg).await {
bail!("Error while sending a response: {:?}", e);
}
}
if let Err(e) = resp.flush().await {
bail!("Error while flushing responses: {:?}", e);
}
}
}
Ok(true)
}
pub fn build_ldap_server<Backend>(
config: &Configuration,
backend_handler: Backend,
server_builder: ServerBuilder,
) -> Result<ServerBuilder>
where
Backend: BackendHandler + LoginHandler + 'static,
{
use futures_util::StreamExt;
let ldap_base_dn = config.ldap_base_dn.clone();
let ldap_user_dn = config.ldap_user_dn.clone();
Ok(
server_builder.bind("ldap", ("0.0.0.0", config.ldap_port), move || {
let backend_handler = backend_handler.clone();
let ldap_base_dn = ldap_base_dn.clone();
let ldap_user_dn = ldap_user_dn.clone();
fn_service(move |mut stream: TcpStream| {
let backend_handler = backend_handler.clone();
let ldap_base_dn = ldap_base_dn.clone();
let ldap_user_dn = ldap_user_dn.clone();
async move {
// Configure the codec etc.
let (r, w) = stream.split();
let mut requests = FramedRead::new(r, LdapCodec);
let mut resp = FramedWrite::new(w, LdapCodec);
let mut session = LdapHandler::new(backend_handler, ldap_base_dn, ldap_user_dn);
while let Some(msg) = requests.next().await {
if !handle_incoming_message(msg, &mut resp, &mut session).await? {
break;
}
}
Ok(stream)
}
})
.map_err(|err: anyhow::Error| error!("Service Error: {:?}", err))
// catch
.and_then(move |_| {
// finally
ok(())
})
})?,
)
}

View File

@@ -0,0 +1,25 @@
use crate::infra::configuration::Configuration;
use anyhow::Context;
use tracing::subscriber::set_global_default;
use tracing_log::LogTracer;
pub fn init(config: Configuration) -> anyhow::Result<()> {
let max_log_level = log_level_from_config(config);
let subscriber = tracing_subscriber::fmt()
.with_timer(tracing_subscriber::fmt::time::time())
.with_target(false)
.with_level(true)
.with_max_level(max_log_level)
.finish();
LogTracer::init().context("Failed to set logger")?;
set_global_default(subscriber).context("Failed to set subscriber")?;
Ok(())
}
fn log_level_from_config(config: Configuration) -> tracing::Level {
if config.verbose {
tracing::Level::DEBUG
} else {
tracing::Level::INFO
}
}

12
server/src/infra/mod.rs Normal file
View File

@@ -0,0 +1,12 @@
pub mod auth_service;
pub mod cli;
pub mod configuration;
pub mod db_cleaner;
pub mod graphql;
pub mod jwt_sql_tables;
pub mod ldap_handler;
pub mod ldap_server;
pub mod logging;
pub mod sql_backend_handler;
pub mod tcp_backend_handler;
pub mod tcp_server;

View File

@@ -0,0 +1,105 @@
use super::{jwt_sql_tables::*, tcp_backend_handler::*};
use crate::domain::{error::*, sql_backend_handler::SqlBackendHandler};
use async_trait::async_trait;
use futures_util::StreamExt;
use sea_query::{Expr, Iden, Query, SimpleExpr};
use sqlx::Row;
use std::collections::HashSet;
#[async_trait]
impl TcpBackendHandler for SqlBackendHandler {
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>> {
use sqlx::Result;
let query = Query::select()
.column(JwtStorage::JwtHash)
.from(JwtStorage::Table)
.to_string(DbQueryBuilder {});
sqlx::query(&query)
.map(|row: DbRow| row.get::<i64, _>(&*JwtStorage::JwtHash.to_string()) as u64)
.fetch(&self.sql_pool)
.collect::<Vec<sqlx::Result<u64>>>()
.await
.into_iter()
.collect::<Result<HashSet<u64>>>()
.map_err(|e| anyhow::anyhow!(e))
}
async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)> {
use rand::{distributions::Alphanumeric, rngs::SmallRng, Rng, SeedableRng};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
// TODO: Initialize the rng only once. Maybe Arc<Cell>?
let mut rng = SmallRng::from_entropy();
let refresh_token: String = std::iter::repeat(())
.map(|()| rng.sample(Alphanumeric))
.map(char::from)
.take(100)
.collect();
let refresh_token_hash = {
let mut s = DefaultHasher::new();
refresh_token.hash(&mut s);
s.finish()
};
let duration = chrono::Duration::days(30);
let query = Query::insert()
.into_table(JwtRefreshStorage::Table)
.columns(vec![
JwtRefreshStorage::RefreshTokenHash,
JwtRefreshStorage::UserId,
JwtRefreshStorage::ExpiryDate,
])
.values_panic(vec![
(refresh_token_hash as i64).into(),
user.into(),
(chrono::Utc::now() + duration).naive_utc().into(),
])
.to_string(DbQueryBuilder {});
sqlx::query(&query).execute(&self.sql_pool).await?;
Ok((refresh_token, duration))
}
async fn check_token(&self, refresh_token_hash: u64, user: &str) -> Result<bool> {
let query = Query::select()
.expr(SimpleExpr::Value(1.into()))
.from(JwtRefreshStorage::Table)
.and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash as i64))
.and_where(Expr::col(JwtRefreshStorage::UserId).eq(user))
.to_string(DbQueryBuilder {});
Ok(sqlx::query(&query)
.fetch_optional(&self.sql_pool)
.await?
.is_some())
}
async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>> {
use sqlx::Result;
let query = Query::select()
.column(JwtStorage::JwtHash)
.from(JwtStorage::Table)
.and_where(Expr::col(JwtStorage::UserId).eq(user))
.and_where(Expr::col(JwtStorage::Blacklisted).eq(true))
.to_string(DbQueryBuilder {});
let result = sqlx::query(&query)
.map(|row: DbRow| row.get::<i64, _>(&*JwtStorage::JwtHash.to_string()) as u64)
.fetch(&self.sql_pool)
.collect::<Vec<sqlx::Result<u64>>>()
.await
.into_iter()
.collect::<Result<HashSet<u64>>>();
let query = Query::update()
.table(JwtStorage::Table)
.values(vec![(JwtStorage::Blacklisted, true.into())])
.and_where(Expr::col(JwtStorage::UserId).eq(user))
.to_string(DbQueryBuilder {});
sqlx::query(&query).execute(&self.sql_pool).await?;
Ok(result?)
}
async fn delete_refresh_token(&self, refresh_token_hash: u64) -> DomainResult<()> {
let query = Query::delete()
.from_table(JwtRefreshStorage::Table)
.and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash))
.to_string(DbQueryBuilder {});
sqlx::query(&query).execute(&self.sql_pool).await?;
Ok(())
}
}

View File

@@ -0,0 +1,46 @@
use async_trait::async_trait;
use std::collections::HashSet;
pub type DomainResult<T> = crate::domain::error::Result<T>;
#[async_trait]
pub trait TcpBackendHandler {
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>;
async fn check_token(&self, refresh_token_hash: u64, user: &str) -> DomainResult<bool>;
async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>>;
async fn delete_refresh_token(&self, refresh_token_hash: u64) -> DomainResult<()>;
}
#[cfg(test)]
use crate::domain::handler::*;
#[cfg(test)]
mockall::mock! {
pub TestTcpBackendHandler{}
impl Clone for TestTcpBackendHandler {
fn clone(&self) -> Self;
}
#[async_trait]
impl LoginHandler for TestTcpBackendHandler {
async fn bind(&self, request: BindRequest) -> DomainResult<()>;
}
#[async_trait]
impl BackendHandler for TestTcpBackendHandler {
async fn list_users(&self, filters: Option<RequestFilter>) -> DomainResult<Vec<User>>;
async fn list_groups(&self) -> DomainResult<Vec<Group>>;
async fn get_user_details(&self, user_id: &str) -> DomainResult<User>;
async fn get_user_groups(&self, user: &str) -> DomainResult<HashSet<String>>;
async fn create_user(&self, request: CreateUserRequest) -> DomainResult<()>;
async fn delete_user(&self, user_id: &str) -> DomainResult<()>;
async fn create_group(&self, group_name: &str) -> DomainResult<GroupId>;
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> DomainResult<()>;
}
#[async_trait]
impl TcpBackendHandler for TestTcpBackendHandler {
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>>;
async fn create_refresh_token(&self, user: &str) -> DomainResult<(String, chrono::Duration)>;
async fn check_token(&self, refresh_token_hash: u64, user: &str) -> DomainResult<bool>;
async fn blacklist_jwts(&self, user: &str) -> DomainResult<HashSet<u64>>;
async fn delete_refresh_token(&self, refresh_token_hash: u64) -> DomainResult<()>;
}
}

View File

@@ -0,0 +1,134 @@
use crate::{
domain::{
error::DomainError,
handler::{BackendHandler, LoginHandler},
opaque_handler::OpaqueHandler,
},
infra::{auth_service, configuration::Configuration, tcp_backend_handler::*},
};
use actix_files::{Files, NamedFile};
use actix_http::HttpServiceBuilder;
use actix_server::ServerBuilder;
use actix_service::map_config;
use actix_web::{dev::AppConfig, web, App, HttpRequest, HttpResponse};
use anyhow::{Context, Result};
use hmac::{Hmac, NewMac};
use sha2::Sha512;
use std::collections::HashSet;
use std::path::PathBuf;
use std::sync::RwLock;
async fn index(req: HttpRequest) -> actix_web::Result<NamedFile> {
let mut path = PathBuf::new();
path.push("../app");
let file = req.match_info().query("filename");
path.push(if file.is_empty() { "index.html" } else { file });
Ok(NamedFile::open(path)?)
}
pub(crate) fn error_to_http_response(error: DomainError) -> HttpResponse {
match error {
DomainError::AuthenticationError(_) | DomainError::AuthenticationProtocolError(_) => {
HttpResponse::Unauthorized()
}
DomainError::DatabaseError(_)
| DomainError::InternalError(_)
| DomainError::UnknownCryptoError(_) => HttpResponse::InternalServerError(),
DomainError::Base64DecodeError(_) | DomainError::BinarySerializationError(_) => {
HttpResponse::BadRequest()
}
}
.body(error.to_string())
}
fn http_config<Backend>(
cfg: &mut web::ServiceConfig,
backend_handler: Backend,
jwt_secret: String,
jwt_blacklist: HashSet<u64>,
) where
Backend: TcpBackendHandler + BackendHandler + LoginHandler + OpaqueHandler + Sync + 'static,
{
cfg.app_data(web::Data::new(AppState::<Backend> {
backend_handler,
jwt_key: Hmac::new_varkey(jwt_secret.as_bytes()).unwrap(),
jwt_blacklist: RwLock::new(jwt_blacklist),
}))
// Serve index.html and main.js, and default to index.html.
.route(
"/{filename:(index\\.html|main\\.js)?}",
web::get().to(index),
)
.service(web::scope("/auth").configure(auth_service::configure_server::<Backend>))
// API endpoint.
.service(
web::scope("/api")
.wrap(auth_service::CookieToHeaderTranslatorFactory)
.configure(super::graphql::api::configure_endpoint::<Backend>),
)
// Serve the /pkg path with the compiled WASM app.
.service(Files::new("/pkg", "./app/pkg"))
// Default to serve index.html for unknown routes, to support routing.
.service(web::scope("/").route("/.*", web::get().to(index)));
}
pub(crate) struct AppState<Backend> {
pub backend_handler: Backend,
pub jwt_key: Hmac<Sha512>,
pub jwt_blacklist: RwLock<HashSet<u64>>,
}
pub async fn build_tcp_server<Backend>(
config: &Configuration,
backend_handler: Backend,
server_builder: ServerBuilder,
) -> Result<ServerBuilder>
where
Backend: TcpBackendHandler + BackendHandler + LoginHandler + OpaqueHandler + Sync + 'static,
{
let jwt_secret = config.jwt_secret.clone();
let jwt_blacklist = backend_handler.get_jwt_blacklist().await?;
server_builder
.bind("http", ("0.0.0.0", config.http_port), move || {
let backend_handler = backend_handler.clone();
let jwt_secret = jwt_secret.clone();
let jwt_blacklist = jwt_blacklist.clone();
HttpServiceBuilder::new()
.finish(map_config(
App::new().configure(move |cfg| {
http_config(cfg, backend_handler, jwt_secret, jwt_blacklist)
}),
|_| AppConfig::default(),
))
.tcp()
})
.with_context(|| {
format!(
"While bringing up the TCP server with port {}",
config.http_port
)
})
}
#[cfg(test)]
mod tests {
use super::*;
use actix_web::test::TestRequest;
use std::path::Path;
#[actix_rt::test]
async fn test_index_ok() {
let req = TestRequest::default().to_http_request();
let resp = index(req).await.unwrap();
assert_eq!(resp.path(), Path::new("../app/index.html"));
}
#[actix_rt::test]
async fn test_index_main_js() {
let req = TestRequest::default()
.param("filename", "main.js")
.to_http_request();
let resp = index(req).await.unwrap();
assert_eq!(resp.path(), Path::new("../app/main.js"));
}
}