Implement refresh tokens

This commit is contained in:
Valentin Tolmer
2021-05-20 17:40:30 +02:00
parent 312d9b7a6f
commit d5cb53ae8a
8 changed files with 301 additions and 29 deletions

View File

@@ -1,12 +1,13 @@
use super::sql_tables::*;
use crate::domain::{error::*, sql_tables::Pool};
use crate::infra::configuration::Configuration;
use crate::infra::jwt_sql_tables::*;
use async_trait::async_trait;
use futures_util::StreamExt;
use futures_util::TryStreamExt;
use log::*;
use sea_query::Iden;
use sea_query::{Expr, Order, Query, SimpleExpr, SqliteQueryBuilder};
use sea_query::{Expr, Order, Query, SimpleExpr};
use sqlx::Row;
use std::collections::HashSet;
@@ -72,7 +73,7 @@ impl BackendHandler for SqlBackendHandler {
.column(Users::Password)
.from(Users::Table)
.and_where(Expr::col(Users::UserId).eq(request.name.as_str()))
.to_string(SqliteQueryBuilder);
.to_string(DbQueryBuilder {});
if let Ok(row) = sqlx::query(&query).fetch_one(&self.sql_pool).await {
if passwords_match(
&request.password,
@@ -109,7 +110,7 @@ impl BackendHandler for SqlBackendHandler {
}
}
query_builder.to_string(SqliteQueryBuilder)
query_builder.to_string(DbQueryBuilder {})
};
let results = sqlx::query_as::<_, User>(&query)
@@ -132,7 +133,7 @@ impl BackendHandler for SqlBackendHandler {
)
.order_by(Groups::DisplayName, Order::Asc)
.order_by(Memberships::UserId, Order::Asc)
.to_string(SqliteQueryBuilder);
.to_string(DbQueryBuilder {});
let mut results = sqlx::query(&query).fetch(&self.sql_pool);
let mut groups = Vec::new();
@@ -178,7 +179,7 @@ impl BackendHandler for SqlBackendHandler {
.equals(Memberships::Table, Memberships::GroupId),
)
.and_where(Expr::col(Memberships::UserId).eq(user))
.to_string(SqliteQueryBuilder);
.to_string(DbQueryBuilder {});
sqlx::query(&query)
// Extract the group id from the row.
@@ -196,6 +197,80 @@ impl BackendHandler for SqlBackendHandler {
}
}
#[async_trait]
impl crate::infra::tcp_server::TcpBackendHandler for SqlBackendHandler {
async fn get_jwt_blacklist(&self) -> anyhow::Result<HashSet<u64>> {
use sqlx::Result;
let query = Query::select()
.column(JwtBlacklist::JwtHash)
.from(JwtBlacklist::Table)
.to_string(DbQueryBuilder {});
sqlx::query(&query)
.map(|row: DbRow| row.get::<i64, _>(&*JwtBlacklist::JwtHash.to_string()) as u64)
.fetch(&self.sql_pool)
.collect::<Vec<sqlx::Result<u64>>>()
.await
.into_iter()
.collect::<Result<HashSet<u64>>>()
.map_err(|e| anyhow::anyhow!(e))
}
async fn create_refresh_token(&self, user: &str) -> Result<(String, chrono::Duration)> {
use rand::{distributions::Alphanumeric, rngs::SmallRng, Rng, SeedableRng};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
// TODO: Initialize the rng only once. Maybe Arc<Cell>?
let mut rng = SmallRng::from_entropy();
let refresh_token: String = std::iter::repeat(())
.map(|()| rng.sample(Alphanumeric))
.map(char::from)
.take(100)
.collect();
let refresh_token_hash = {
let mut s = DefaultHasher::new();
refresh_token.hash(&mut s);
s.finish()
};
let duration = chrono::Duration::days(30);
let query = Query::insert()
.into_table(JwtRefreshStorage::Table)
.columns(vec![
JwtRefreshStorage::RefreshTokenHash,
JwtRefreshStorage::UserId,
JwtRefreshStorage::ExpiryDate,
])
.values_panic(vec![
(refresh_token_hash as i64).into(),
user.into(),
(chrono::Utc::now() + duration).naive_utc().into(),
])
.to_string(DbQueryBuilder {});
sqlx::query(&query).execute(&self.sql_pool).await?;
Ok((refresh_token, duration))
}
async fn check_token(&self, token: &str, user: &str) -> Result<bool> {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let refresh_token_hash = {
let mut s = DefaultHasher::new();
token.hash(&mut s);
s.finish()
};
let query = Query::select()
.expr(SimpleExpr::Value(1.into()))
.from(JwtRefreshStorage::Table)
.and_where(Expr::col(JwtRefreshStorage::RefreshTokenHash).eq(refresh_token_hash as i64))
.and_where(Expr::col(JwtRefreshStorage::UserId).eq(user))
.to_string(DbQueryBuilder {});
Ok(sqlx::query(&query)
.fetch_optional(&self.sql_pool)
.await?
.is_some())
}
}
#[cfg(test)]
mockall::mock! {
pub TestBackendHandler{}
@@ -247,7 +322,7 @@ mod tests {
chrono::NaiveDateTime::from_timestamp(0, 0).into(),
pass.into(),
])
.to_string(SqliteQueryBuilder);
.to_string(DbQueryBuilder {});
sqlx::query(&query).execute(sql_pool).await.unwrap();
}
@@ -256,7 +331,7 @@ mod tests {
.into_table(Groups::Table)
.columns(vec![Groups::GroupId, Groups::DisplayName])
.values_panic(vec![id.into(), name.into()])
.to_string(SqliteQueryBuilder);
.to_string(DbQueryBuilder {});
sqlx::query(&query).execute(sql_pool).await.unwrap();
}
@@ -265,7 +340,7 @@ mod tests {
.into_table(Memberships::Table)
.columns(vec![Memberships::UserId, Memberships::GroupId])
.values_panic(vec![user_id.into(), group_id.into()])
.to_string(SqliteQueryBuilder);
.to_string(DbQueryBuilder {});
sqlx::query(&query).execute(sql_pool).await.unwrap();
}

View File

@@ -3,6 +3,7 @@ use sea_query::*;
pub type Pool = sqlx::sqlite::SqlitePool;
pub type PoolOptions = sqlx::sqlite::SqlitePoolOptions;
pub type DbRow = sqlx::sqlite::SqliteRow;
pub type DbQueryBuilder = SqliteQueryBuilder;
#[derive(Iden)]
pub enum Users {
@@ -60,7 +61,7 @@ pub async fn init_table(pool: &Pool) -> sqlx::Result<()> {
.col(ColumnDef::new(Users::Password).string_len(255).not_null())
.col(ColumnDef::new(Users::TotpSecret).string_len(64))
.col(ColumnDef::new(Users::MfaType).string_len(64))
.to_string(SqliteQueryBuilder),
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await?;
@@ -79,7 +80,7 @@ pub async fn init_table(pool: &Pool) -> sqlx::Result<()> {
.string_len(255)
.not_null(),
)
.to_string(SqliteQueryBuilder),
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await?;
@@ -109,7 +110,7 @@ pub async fn init_table(pool: &Pool) -> sqlx::Result<()> {
.on_delete(ForeignKeyAction::Cascade)
.on_update(ForeignKeyAction::Cascade),
)
.to_string(SqliteQueryBuilder),
.to_string(DbQueryBuilder {}),
)
.execute(pool)
.await?;