Implement refresh tokens
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
use super::sql_tables::*;
|
||||
use crate::domain::{error::*, sql_tables::Pool};
|
||||
use crate::infra::configuration::Configuration;
|
||||
use crate::infra::jwt_sql_tables::*;
|
||||
use async_trait::async_trait;
|
||||
use futures_util::StreamExt;
|
||||
use futures_util::TryStreamExt;
|
||||
use log::*;
|
||||
use sea_query::Iden;
|
||||
use sea_query::{Expr, Order, Query, SimpleExpr, SqliteQueryBuilder};
|
||||
use sea_query::{Expr, Order, Query, SimpleExpr};
|
||||
use sqlx::Row;
|
||||
use std::collections::HashSet;
|
||||
|
||||
@@ -72,7 +73,7 @@ impl BackendHandler for SqlBackendHandler {
|
||||
.column(Users::Password)
|
||||
.from(Users::Table)
|
||||
.and_where(Expr::col(Users::UserId).eq(request.name.as_str()))
|
||||
.to_string(SqliteQueryBuilder);
|
||||
.to_string(DbQueryBuilder {});
|
||||
if let Ok(row) = sqlx::query(&query).fetch_one(&self.sql_pool).await {
|
||||
if passwords_match(
|
||||
&request.password,
|
||||
@@ -109,7 +110,7 @@ impl BackendHandler for SqlBackendHandler {
|
||||
}
|
||||
}
|
||||
|
||||
query_builder.to_string(SqliteQueryBuilder)
|
||||
query_builder.to_string(DbQueryBuilder {})
|
||||
};
|
||||
|
||||
let results = sqlx::query_as::<_, User>(&query)
|
||||
@@ -132,7 +133,7 @@ impl BackendHandler for SqlBackendHandler {
|
||||
)
|
||||
.order_by(Groups::DisplayName, Order::Asc)
|
||||
.order_by(Memberships::UserId, Order::Asc)
|
||||
.to_string(SqliteQueryBuilder);
|
||||
.to_string(DbQueryBuilder {});
|
||||
|
||||
let mut results = sqlx::query(&query).fetch(&self.sql_pool);
|
||||
let mut groups = Vec::new();
|
||||
@@ -178,7 +179,7 @@ impl BackendHandler for SqlBackendHandler {
|
||||
.equals(Memberships::Table, Memberships::GroupId),
|
||||
)
|
||||
.and_where(Expr::col(Memberships::UserId).eq(user))
|
||||
.to_string(SqliteQueryBuilder);
|
||||
.to_string(DbQueryBuilder {});
|
||||
|
||||
sqlx::query(&query)
|
||||
// Extract the group id from the row.
|
||||
@@ -196,6 +197,80 @@ impl BackendHandler for SqlBackendHandler {
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl crate::infra::tcp_server::TcpBackendHandler for SqlBackendHandler {
|
||||
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>> {
|
||||
use sqlx::Result;
|
||||
let query = Query::select()
|
||||
.column(JwtBlacklist::JwtHash)
|
||||
.from(JwtBlacklist::Table)
|
||||
.to_string(DbQueryBuilder {});
|
||||
|
||||
sqlx::query(&query)
|
||||
.map(|row: DbRow| row.get::<i64, _>(&*JwtBlacklist::JwtHash.to_string()) as u64)
|
||||
.fetch(&self.sql_pool)
|
||||
.collect::<Vec<sqlx::Result<u64>>>()
|
||||
.await
|
||||
.into_iter()
|
||||
.collect::<Result<HashSet<u64>>>()
|
||||
.map_err(|e| anyhow::anyhow!(e))
|
||||
}
|
||||
|
||||
async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)> {
|
||||
use rand::{distributions::Alphanumeric, rngs::SmallRng, Rng, SeedableRng};
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
// TODO: Initialize the rng only once. Maybe Arc<Cell>?
|
||||
let mut rng = SmallRng::from_entropy();
|
||||
let refresh_token: String = std::iter::repeat(())
|
||||
.map(|()| rng.sample(Alphanumeric))
|
||||
.map(char::from)
|
||||
.take(100)
|
||||
.collect();
|
||||
let refresh_token_hash = {
|
||||
let mut s = DefaultHasher::new();
|
||||
refresh_token.hash(&mut s);
|
||||
s.finish()
|
||||
};
|
||||
let duration = chrono::Duration::days(30);
|
||||
let query = Query::insert()
|
||||
.into_table(JwtRefreshStorage::Table)
|
||||
.columns(vec![
|
||||
JwtRefreshStorage::RefreshTokenHash,
|
||||
JwtRefreshStorage::UserId,
|
||||
JwtRefreshStorage::ExpiryDate,
|
||||
])
|
||||
.values_panic(vec![
|
||||
(refresh_token_hash as i64).into(),
|
||||
user.into(),
|
||||
(chrono::Utc::now() + duration).naive_utc().into(),
|
||||
])
|
||||
.to_string(DbQueryBuilder {});
|
||||
sqlx::query(&query).execute(&self.sql_pool).await?;
|
||||
Ok((refresh_token, duration))
|
||||
}
|
||||
|
||||
async fn check_token(&self, token: &str, user: &str) -> Result<bool> {
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
let refresh_token_hash = {
|
||||
let mut s = DefaultHasher::new();
|
||||
token.hash(&mut s);
|
||||
s.finish()
|
||||
};
|
||||
let query = Query::select()
|
||||
.expr(SimpleExpr::Value(1.into()))
|
||||
.from(JwtRefreshStorage::Table)
|
||||
.and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash as i64))
|
||||
.and_where(Expr::col(JwtRefreshStorage::UserId).eq(user))
|
||||
.to_string(DbQueryBuilder {});
|
||||
Ok(sqlx::query(&query)
|
||||
.fetch_optional(&self.sql_pool)
|
||||
.await?
|
||||
.is_some())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mockall::mock! {
|
||||
pub TestBackendHandler{}
|
||||
@@ -247,7 +322,7 @@ mod tests {
|
||||
chrono::NaiveDateTime::from_timestamp(0, 0).into(),
|
||||
pass.into(),
|
||||
])
|
||||
.to_string(SqliteQueryBuilder);
|
||||
.to_string(DbQueryBuilder {});
|
||||
sqlx::query(&query).execute(sql_pool).await.unwrap();
|
||||
}
|
||||
|
||||
@@ -256,7 +331,7 @@ mod tests {
|
||||
.into_table(Groups::Table)
|
||||
.columns(vec![Groups::GroupId, Groups::DisplayName])
|
||||
.values_panic(vec![id.into(), name.into()])
|
||||
.to_string(SqliteQueryBuilder);
|
||||
.to_string(DbQueryBuilder {});
|
||||
sqlx::query(&query).execute(sql_pool).await.unwrap();
|
||||
}
|
||||
|
||||
@@ -265,7 +340,7 @@ mod tests {
|
||||
.into_table(Memberships::Table)
|
||||
.columns(vec![Memberships::UserId, Memberships::GroupId])
|
||||
.values_panic(vec![user_id.into(), group_id.into()])
|
||||
.to_string(SqliteQueryBuilder);
|
||||
.to_string(DbQueryBuilder {});
|
||||
sqlx::query(&query).execute(sql_pool).await.unwrap();
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ use sea_query::*;
|
||||
pub type Pool = sqlx::sqlite::SqlitePool;
|
||||
pub type PoolOptions = sqlx::sqlite::SqlitePoolOptions;
|
||||
pub type DbRow = sqlx::sqlite::SqliteRow;
|
||||
pub type DbQueryBuilder = SqliteQueryBuilder;
|
||||
|
||||
#[derive(Iden)]
|
||||
pub enum Users {
|
||||
@@ -60,7 +61,7 @@ pub async fn init_table(pool: &Pool) -> sqlx::Result<()> {
|
||||
.col(ColumnDef::new(Users::Password).string_len(255).not_null())
|
||||
.col(ColumnDef::new(Users::TotpSecret).string_len(64))
|
||||
.col(ColumnDef::new(Users::MfaType).string_len(64))
|
||||
.to_string(SqliteQueryBuilder),
|
||||
.to_string(DbQueryBuilder {}),
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
@@ -79,7 +80,7 @@ pub async fn init_table(pool: &Pool) -> sqlx::Result<()> {
|
||||
.string_len(255)
|
||||
.not_null(),
|
||||
)
|
||||
.to_string(SqliteQueryBuilder),
|
||||
.to_string(DbQueryBuilder {}),
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
@@ -109,7 +110,7 @@ pub async fn init_table(pool: &Pool) -> sqlx::Result<()> {
|
||||
.on_delete(ForeignKeyAction::Cascade)
|
||||
.on_update(ForeignKeyAction::Cascade),
|
||||
)
|
||||
.to_string(SqliteQueryBuilder),
|
||||
.to_string(DbQueryBuilder {}),
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Reference in New Issue
Block a user