From 504227eb13c855fc3bfb961895622b82b8db517a Mon Sep 17 00:00:00 2001 From: Valentin Tolmer Date: Mon, 30 Oct 2023 21:15:31 +0100 Subject: [PATCH] server: Add JWTs to the DB Otherwise, logging out doesn't actually blacklist the JWT --- server/src/infra/auth_service.rs | 96 +++++++++++++------------ server/src/infra/sql_backend_handler.rs | 20 ++++++ server/src/infra/tcp_backend_handler.rs | 7 ++ 3 files changed, 77 insertions(+), 46 deletions(-) diff --git a/server/src/infra/auth_service.rs b/server/src/infra/auth_service.rs index 7dea093..b5bbac0 100644 --- a/server/src/infra/auth_service.rs +++ b/server/src/infra/auth_service.rs @@ -1,8 +1,3 @@ -use std::collections::{hash_map::DefaultHasher, HashSet}; -use std::hash::{Hash, Hasher}; -use std::pin::Pin; -use std::task::{Context, Poll}; - use actix_web::{ cookie::{Cookie, SameSite}, dev::{Service, ServiceRequest, ServiceResponse, Transform}, @@ -17,6 +12,12 @@ use futures_util::FutureExt; use hmac::Hmac; use jwt::{SignWithKey, VerifyWithKey}; use sha2::Sha512; +use std::{ + collections::HashSet, + hash::Hash, + pin::Pin, + task::{Context, Poll}, +}; use time::ext::NumericalDuration; use tracing::{debug, info, instrument, warn}; @@ -39,31 +40,43 @@ use crate::{ type Token = jwt::Token; type SignedToken = Token; -fn create_jwt(key: &Hmac, user: String, groups: HashSet) -> SignedToken { +fn default_hash(token: &T) -> u64 { + use std::collections::hash_map::DefaultHasher; + use std::hash::Hasher; + let mut s = DefaultHasher::new(); + token.hash(&mut s); + s.finish() +} + +async fn create_jwt( + handler: &Handler, + key: &Hmac, + user: &UserId, + groups: HashSet, +) -> SignedToken { let claims = JWTClaims { exp: Utc::now() + chrono::Duration::days(1), iat: Utc::now(), - user, + user: user.to_string(), groups: groups.into_iter().map(|g| g.display_name).collect(), }; + let expiry = claims.exp.naive_utc(); let header = jwt::Header { algorithm: jwt::AlgorithmType::Hs512, ..Default::default() }; - jwt::Token::new(header, claims).sign_with_key(key).unwrap() + let token = jwt::Token::new(header, claims).sign_with_key(key).unwrap(); + handler + .register_jwt(user, default_hash(token.as_str()), expiry) + .await + .unwrap(); + token } fn parse_refresh_token(token: &str) -> TcpResult<(u64, UserId)> { match token.split_once('+') { None => Err(DomainError::AuthenticationError("Invalid refresh token".to_string()).into()), - Some((token, u)) => { - let refresh_token_hash = { - let mut s = DefaultHasher::new(); - token.hash(&mut s); - s.finish() - }; - Ok((refresh_token_hash, UserId::new(u))) - } + Some((token, u)) => Ok((default_hash(token), UserId::new(u))), } } @@ -99,26 +112,21 @@ where "Invalid refresh token".to_string(), ))); } - Ok(data - .get_readonly_handler() - .get_user_groups(&user) - .await - .map(|groups| create_jwt(jwt_key, user.to_string(), groups)) - .map(|token| { - HttpResponse::Ok() - .cookie( - Cookie::build("token", token.as_str()) - .max_age(1.days()) - .path("/") - .http_only(true) - .same_site(SameSite::Strict) - .finish(), - ) - .json(&login::ServerLoginResponse { - token: token.as_str().to_owned(), - refresh_token: None, - }) - })?) + let groups = data.get_readonly_handler().get_user_groups(&user).await?; + let token = create_jwt(data.get_tcp_handler(), jwt_key, &user, groups).await; + Ok(HttpResponse::Ok() + .cookie( + Cookie::build("token", token.as_str()) + .max_age(1.days()) + .path("/") + .http_only(true) + .same_site(SameSite::Strict) + .finish(), + ) + .json(&login::ServerLoginResponse { + token: token.as_str().to_owned(), + refresh_token: None, + })) } async fn get_refresh_handler( @@ -230,7 +238,7 @@ where .delete_password_reset_token(token) .await; let groups = HashSet::new(); - let token = create_jwt(&data.jwt_key, user_id.to_string(), groups); + let token = create_jwt(data.get_tcp_handler(), &data.jwt_key, &user_id, groups).await; Ok(HttpResponse::Ok() .cookie( Cookie::build("token", token.as_str()) @@ -271,10 +279,10 @@ where data.get_tcp_handler() .delete_refresh_token(refresh_token_hash) .await?; - let new_blacklisted_jwts = data.get_tcp_handler().blacklist_jwts(&user).await?; + let new_blacklisted_jwt_hashes = data.get_tcp_handler().blacklist_jwts(&user).await?; let mut jwt_blacklist = data.jwt_blacklist.write().unwrap(); - for jwt in new_blacklisted_jwts { - jwt_blacklist.insert(jwt); + for jwt_hash in new_blacklisted_jwt_hashes { + jwt_blacklist.insert(jwt_hash); } Ok(HttpResponse::Ok() .cookie( @@ -341,7 +349,7 @@ where // token. let groups = data.get_readonly_handler().get_user_groups(name).await?; let (refresh_token, max_age) = data.get_tcp_handler().create_refresh_token(name).await?; - let token = create_jwt(&data.jwt_key, name.to_string(), groups); + let token = create_jwt(data.get_tcp_handler(), &data.jwt_key, name, groups).await; let refresh_token_plus_name = refresh_token + "+" + name.as_str(); Ok(HttpResponse::Ok() @@ -604,11 +612,7 @@ pub(crate) fn check_if_token_is_valid( token.header().algorithm ))); } - let jwt_hash = { - let mut s = DefaultHasher::new(); - token_str.hash(&mut s); - s.finish() - }; + let jwt_hash = default_hash(token_str); if state.jwt_blacklist.read().unwrap().contains(&jwt_hash) { return Err(ErrorUnauthorized("JWT was logged out")); } diff --git a/server/src/infra/sql_backend_handler.rs b/server/src/infra/sql_backend_handler.rs index 54ca857..b64d13d 100644 --- a/server/src/infra/sql_backend_handler.rs +++ b/server/src/infra/sql_backend_handler.rs @@ -6,6 +6,7 @@ use crate::domain::{ types::UserId, }; use async_trait::async_trait; +use chrono::NaiveDateTime; use sea_orm::{ sea_query::{Cond, Expr}, ActiveModelTrait, ColumnTrait, EntityTrait, IntoActiveModel, QueryFilter, QuerySelect, @@ -62,6 +63,25 @@ impl TcpBackendHandler for SqlBackendHandler { Ok((refresh_token, duration)) } + #[instrument(skip_all, level = "debug")] + async fn register_jwt( + &self, + user: &UserId, + jwt_hash: u64, + expiry_date: NaiveDateTime, + ) -> Result<()> { + debug!(?user, ?jwt_hash); + let new_token = model::jwt_storage::Model { + jwt_hash: jwt_hash as i64, + user_id: user.clone(), + blacklisted: false, + expiry_date, + } + .into_active_model(); + new_token.insert(&self.sql_pool).await?; + Ok(()) + } + #[instrument(skip_all, level = "debug")] async fn check_token(&self, refresh_token_hash: u64, user: &UserId) -> Result { debug!(?user); diff --git a/server/src/infra/tcp_backend_handler.rs b/server/src/infra/tcp_backend_handler.rs index e01e531..58299c1 100644 --- a/server/src/infra/tcp_backend_handler.rs +++ b/server/src/infra/tcp_backend_handler.rs @@ -1,4 +1,5 @@ use async_trait::async_trait; +use chrono::NaiveDateTime; use std::collections::HashSet; use crate::domain::{error::Result, types::UserId}; @@ -7,6 +8,12 @@ use crate::domain::{error::Result, types::UserId}; pub trait TcpBackendHandler: Sync { async fn get_jwt_blacklist(&self) -> anyhow::Result>; async fn create_refresh_token(&self, user: &UserId) -> Result<(String, chrono::Duration)>; + async fn register_jwt( + &self, + user: &UserId, + jwt_hash: u64, + expiry_date: NaiveDateTime, + ) -> Result<()>; async fn check_token(&self, refresh_token_hash: u64, user: &UserId) -> Result; async fn blacklist_jwts(&self, user: &UserId) -> Result>; async fn delete_refresh_token(&self, refresh_token_hash: u64) -> Result<()>;