graphql: Add a method to list groups
This commit is contained in:
committed by
nitnelave
parent
e4d6b122c5
commit
480f48f820
@@ -31,6 +31,7 @@ impl Default for User {
|
||||
|
||||
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
|
||||
pub struct Group {
|
||||
pub id: GroupId,
|
||||
pub display_name: String,
|
||||
pub users: Vec<String>,
|
||||
}
|
||||
@@ -74,9 +75,12 @@ pub trait LoginHandler: Clone + Send {
|
||||
async fn bind(&self, request: BindRequest) -> Result<()>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct GroupId(pub i32);
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct GroupIdAndName(pub GroupId, pub String);
|
||||
|
||||
#[async_trait]
|
||||
pub trait BackendHandler: Clone + Send {
|
||||
async fn list_users(&self, filters: Option<RequestFilter>) -> Result<Vec<User>>;
|
||||
@@ -88,7 +92,7 @@ pub trait BackendHandler: Clone + Send {
|
||||
async fn create_group(&self, group_name: &str) -> Result<GroupId>;
|
||||
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
|
||||
async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
|
||||
async fn get_user_groups(&self, user: &str) -> Result<HashSet<String>>;
|
||||
async fn get_user_groups(&self, user: &str) -> Result<HashSet<GroupIdAndName>>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -106,7 +110,7 @@ mockall::mock! {
|
||||
async fn update_user(&self, request: UpdateUserRequest) -> Result<()>;
|
||||
async fn delete_user(&self, user_id: &str) -> Result<()>;
|
||||
async fn create_group(&self, group_name: &str) -> Result<GroupId>;
|
||||
async fn get_user_groups(&self, user: &str) -> Result<HashSet<String>>;
|
||||
async fn get_user_groups(&self, user: &str) -> Result<HashSet<GroupIdAndName>>;
|
||||
async fn add_user_to_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
|
||||
async fn remove_user_from_group(&self, user_id: &str, group_id: GroupId) -> Result<()>;
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ use super::{error::*, handler::*, sql_tables::*};
|
||||
use crate::infra::configuration::Configuration;
|
||||
use async_trait::async_trait;
|
||||
use futures_util::StreamExt;
|
||||
use futures_util::TryStreamExt;
|
||||
use sea_query::{Expr, Iden, Order, Query, SimpleExpr};
|
||||
use sqlx::Row;
|
||||
use std::collections::HashSet;
|
||||
@@ -76,6 +75,7 @@ impl BackendHandler for SqlBackendHandler {
|
||||
|
||||
async fn list_groups(&self) -> Result<Vec<Group>> {
|
||||
let query: String = Query::select()
|
||||
.column((Groups::Table, Groups::GroupId))
|
||||
.column(Groups::DisplayName)
|
||||
.column(Memberships::UserId)
|
||||
.from(Groups::Table)
|
||||
@@ -88,32 +88,33 @@ impl BackendHandler for SqlBackendHandler {
|
||||
.order_by(Memberships::UserId, Order::Asc)
|
||||
.to_string(DbQueryBuilder {});
|
||||
|
||||
let mut results = sqlx::query(&query).fetch(&self.sql_pool);
|
||||
// For group_by.
|
||||
use itertools::Itertools;
|
||||
let mut groups = Vec::new();
|
||||
// The rows are ordered by group, user, so we need to group them into vectors.
|
||||
// The rows are returned sorted by display_name, equivalent to group_id. We group them by
|
||||
// this key which gives us one element (`rows`) per group.
|
||||
for ((group_id, display_name), rows) in &sqlx::query(&query)
|
||||
.fetch_all(&self.sql_pool)
|
||||
.await?
|
||||
.into_iter()
|
||||
.group_by(|row| {
|
||||
(
|
||||
GroupId(row.get::<i32, _>(&*Groups::GroupId.to_string())),
|
||||
row.get::<String, _>(&*Groups::DisplayName.to_string()),
|
||||
)
|
||||
})
|
||||
{
|
||||
let mut current_group = String::new();
|
||||
let mut current_users = Vec::new();
|
||||
while let Some(row) = results.try_next().await? {
|
||||
let display_name = row.get::<String, _>(&*Groups::DisplayName.to_string());
|
||||
if display_name != current_group {
|
||||
if !current_group.is_empty() {
|
||||
groups.push(Group {
|
||||
display_name: current_group,
|
||||
users: current_users,
|
||||
});
|
||||
current_users = Vec::new();
|
||||
}
|
||||
current_group = display_name.clone();
|
||||
}
|
||||
current_users.push(row.get::<String, _>(&*Memberships::UserId.to_string()));
|
||||
}
|
||||
groups.push(Group {
|
||||
display_name: current_group,
|
||||
users: current_users,
|
||||
id: group_id,
|
||||
display_name,
|
||||
users: rows
|
||||
.map(|row| row.get::<String, _>(&*Memberships::UserId.to_string()))
|
||||
// If a group has no users, an empty string is returned because of the left
|
||||
// join.
|
||||
.filter(|s| !s.is_empty())
|
||||
.collect(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(groups)
|
||||
}
|
||||
|
||||
@@ -135,13 +136,14 @@ impl BackendHandler for SqlBackendHandler {
|
||||
.await?)
|
||||
}
|
||||
|
||||
async fn get_user_groups(&self, user: &str) -> Result<HashSet<String>> {
|
||||
async fn get_user_groups(&self, user: &str) -> Result<HashSet<GroupIdAndName>> {
|
||||
if user == self.config.ldap_user_dn {
|
||||
let mut groups = HashSet::new();
|
||||
groups.insert("lldap_admin".to_string());
|
||||
groups.insert(GroupIdAndName(GroupId(1), "lldap_admin".to_string()));
|
||||
return Ok(groups);
|
||||
}
|
||||
let query: String = Query::select()
|
||||
.column((Groups::Table, Groups::GroupId))
|
||||
.column(Groups::DisplayName)
|
||||
.from(Groups::Table)
|
||||
.inner_join(
|
||||
@@ -154,10 +156,15 @@ impl BackendHandler for SqlBackendHandler {
|
||||
|
||||
sqlx::query(&query)
|
||||
// Extract the group id from the row.
|
||||
.map(|row: DbRow| row.get::<String, _>(&*Groups::DisplayName.to_string()))
|
||||
.map(|row: DbRow| {
|
||||
GroupIdAndName(
|
||||
row.get::<GroupId, _>(&*Groups::GroupId.to_string()),
|
||||
row.get::<String, _>(&*Groups::DisplayName.to_string()),
|
||||
)
|
||||
})
|
||||
.fetch(&self.sql_pool)
|
||||
// Collect the vector of rows, each potentially an error.
|
||||
.collect::<Vec<sqlx::Result<String>>>()
|
||||
.collect::<Vec<sqlx::Result<GroupIdAndName>>>()
|
||||
.await
|
||||
.into_iter()
|
||||
// Transform it into a single result (the first error if any), and group the group_ids
|
||||
@@ -468,6 +475,7 @@ mod tests {
|
||||
insert_user(&handler, "John", "Pa33w0rd!").await;
|
||||
let group_1 = insert_group(&handler, "Best Group").await;
|
||||
let group_2 = insert_group(&handler, "Worst Group").await;
|
||||
let group_3 = insert_group(&handler, "Empty Group").await;
|
||||
insert_membership(&handler, group_1, "bob").await;
|
||||
insert_membership(&handler, group_1, "patrick").await;
|
||||
insert_membership(&handler, group_2, "patrick").await;
|
||||
@@ -476,13 +484,20 @@ mod tests {
|
||||
handler.list_groups().await.unwrap(),
|
||||
vec![
|
||||
Group {
|
||||
id: group_1,
|
||||
display_name: "Best Group".to_string(),
|
||||
users: vec!["bob".to_string(), "patrick".to_string()]
|
||||
},
|
||||
Group {
|
||||
id: group_3,
|
||||
display_name: "Empty Group".to_string(),
|
||||
users: vec![]
|
||||
},
|
||||
Group {
|
||||
id: group_2,
|
||||
display_name: "Worst Group".to_string(),
|
||||
users: vec!["John".to_string(), "patrick".to_string()]
|
||||
}
|
||||
},
|
||||
]
|
||||
);
|
||||
}
|
||||
@@ -515,10 +530,10 @@ mod tests {
|
||||
insert_membership(&handler, group_1, "patrick").await;
|
||||
insert_membership(&handler, group_2, "patrick").await;
|
||||
let mut bob_groups = HashSet::new();
|
||||
bob_groups.insert("Group1".to_string());
|
||||
bob_groups.insert(GroupIdAndName(group_1, "Group1".to_string()));
|
||||
let mut patrick_groups = HashSet::new();
|
||||
patrick_groups.insert("Group1".to_string());
|
||||
patrick_groups.insert("Group2".to_string());
|
||||
patrick_groups.insert(GroupIdAndName(group_1, "Group1".to_string()));
|
||||
patrick_groups.insert(GroupIdAndName(group_2, "Group2".to_string()));
|
||||
assert_eq!(handler.get_user_groups("bob").await.unwrap(), bob_groups);
|
||||
assert_eq!(
|
||||
handler.get_user_groups("patrick").await.unwrap(),
|
||||
|
||||
@@ -12,6 +12,31 @@ impl From<GroupId> for Value {
|
||||
}
|
||||
}
|
||||
|
||||
impl<DB> sqlx::Type<DB> for GroupId
|
||||
where
|
||||
DB: sqlx::Database,
|
||||
i32: sqlx::Type<DB>,
|
||||
{
|
||||
fn type_info() -> <DB as sqlx::Database>::TypeInfo {
|
||||
<i32 as sqlx::Type<DB>>::type_info()
|
||||
}
|
||||
fn compatible(ty: &<DB as sqlx::Database>::TypeInfo) -> bool {
|
||||
<i32 as sqlx::Type<DB>>::compatible(ty)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r, DB> sqlx::Decode<'r, DB> for GroupId
|
||||
where
|
||||
DB: sqlx::Database,
|
||||
i32: sqlx::Decode<'r, DB>,
|
||||
{
|
||||
fn decode(
|
||||
value: <DB as sqlx::database::HasValueRef<'r>>::ValueRef,
|
||||
) -> Result<Self, Box<dyn std::error::Error + Sync + Send + 'static>> {
|
||||
<i32 as sqlx::Decode<'r, DB>>::decode(value).map(GroupId)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Iden)]
|
||||
pub enum Users {
|
||||
Table,
|
||||
|
||||
Reference in New Issue
Block a user