gitdataai/libs/room/src/helpers.rs
ZhenYi abcfc5b3bb refactor(room): simplify room core modules and connection handling
Extract connection pool management and helper utilities.
Remove redundant metrics indirection, expose counters directly.
Trim room.rs boilerplate and move AI queue logic to room_ai_queue.
2026-04-30 19:16:33 +08:00

706 lines
24 KiB
Rust

use crate::error::RoomError;
use crate::service::RoomService;
use models::agents::model as ai_model;
use models::projects::{MemberRole, project, project_history_name, project_members};
use models::rooms::{
MessageContentType, RoomMemberRole, room, room_ai, room_category, room_member, room_message,
room_notifications, room_pin, room_thread,
};
use models::users::user as user_model;
use sea_orm::*;
use uuid::Uuid;
/// Maps a stored room-category row onto its API response (a one-to-one field copy).
impl From<room_category::Model> for super::RoomCategoryResponse {
    fn from(value: room_category::Model) -> Self {
        let room_category::Model {
            id,
            project,
            name,
            position,
            created_by,
            created_at,
            ..
        } = value;
        Self {
            id,
            project,
            name,
            position,
            created_by,
            created_at,
        }
    }
}
/// Maps a stored room row onto its API response.
///
/// `unread_count` and `version` are not read from the row and are left at `0`
/// here — presumably filled in by the caller when needed (TODO confirm).
impl From<room::Model> for super::RoomResponse {
    fn from(value: room::Model) -> Self {
        let room::Model {
            id,
            project,
            room_name,
            public,
            category,
            created_by,
            created_at,
            last_msg_at,
            ..
        } = value;
        Self {
            id,
            project,
            room_name,
            public,
            category,
            created_by,
            created_at,
            last_msg_at,
            unread_count: 0,
            version: 0,
        }
    }
}
/// Maps a stored room-member row onto its API response.
///
/// The role enum is rendered through its `Display` impl, and `user_info` is
/// left empty because this conversion performs no user lookup.
impl From<room_member::Model> for super::RoomMemberResponse {
    fn from(value: room_member::Model) -> Self {
        let room_member::Model {
            room,
            user,
            role,
            first_msg_in,
            joined_at,
            last_read_seq,
            do_not_disturb,
            dnd_start_hour,
            dnd_end_hour,
            ..
        } = value;
        Self {
            room,
            user,
            user_info: None,
            role: role.to_string(),
            first_msg_in,
            joined_at,
            last_read_seq,
            do_not_disturb,
            dnd_start_hour,
            dnd_end_hour,
        }
    }
}
/// Maps a stored message row onto its API response.
///
/// Enum columns are rendered via `Display`. `thinking_is_chunked` is derived from
/// the thinking content. The fields that need extra lookups (`display_name`,
/// `highlighted_content`, `attachment_ids`) start out empty.
impl From<room_message::Model> for super::RoomMessageResponse {
    fn from(value: room_message::Model) -> Self {
        let room_message::Model {
            id,
            seq,
            room,
            sender_type,
            sender_id,
            thread,
            content,
            content_type,
            thinking_content,
            edited_at,
            send_at,
            revoked,
            revoked_by,
            in_reply_to,
            ..
        } = value;
        let thinking_is_chunked = super::RoomMessageResponse::detect_chunked(&thinking_content);
        Self {
            id,
            seq,
            room,
            sender_type: sender_type.to_string(),
            sender_id,
            display_name: None,
            thread,
            content,
            content_type: content_type.to_string(),
            thinking_content,
            thinking_is_chunked,
            edited_at,
            send_at,
            revoked,
            revoked_by,
            in_reply_to,
            highlighted_content: None,
            attachment_ids: Vec::new(),
        }
    }
}
/// Maps a stored thread row onto its API response (a one-to-one field copy).
impl From<room_thread::Model> for super::RoomThreadResponse {
    fn from(value: room_thread::Model) -> Self {
        let room_thread::Model {
            id,
            room,
            parent,
            created_by,
            participants,
            last_message_at,
            last_message_preview,
            created_at,
            updated_at,
            ..
        } = value;
        Self {
            id,
            room,
            parent,
            created_by,
            participants,
            last_message_at,
            last_message_preview,
            created_at,
            updated_at,
        }
    }
}
/// Maps a stored pin row onto its API response (a one-to-one field copy).
impl From<room_pin::Model> for super::RoomPinResponse {
    fn from(value: room_pin::Model) -> Self {
        let room_pin::Model {
            room,
            message,
            pinned_by,
            pinned_at,
            ..
        } = value;
        Self {
            room,
            message,
            pinned_by,
            pinned_at,
        }
    }
}
/// Maps a stored room-AI configuration row onto its API response.
///
/// `model_name` is left empty because this conversion does not look up the model.
impl From<room_ai::Model> for super::RoomAiResponse {
    fn from(value: room_ai::Model) -> Self {
        let room_ai::Model {
            room,
            model,
            version,
            call_count,
            last_call_at,
            history_limit,
            system_prompt,
            temperature,
            max_tokens,
            use_exact,
            think,
            stream,
            min_score,
            agent_type,
            created_at,
            updated_at,
            ..
        } = value;
        Self {
            room,
            model,
            model_name: None,
            version,
            call_count,
            last_call_at,
            history_limit,
            system_prompt,
            temperature,
            max_tokens,
            use_exact,
            think,
            stream,
            min_score,
            agent_type,
            created_at,
            updated_at,
        }
    }
}
/// Maps a stored notification row onto its API response.
///
/// A missing `metadata` column becomes an empty JSON object, and
/// `notification_type` is rendered through its `Display` impl.
impl From<room_notifications::Model> for super::NotificationResponse {
    fn from(value: room_notifications::Model) -> Self {
        let room_notifications::Model {
            id,
            room,
            project,
            user_id,
            notification_type,
            title,
            content,
            related_message_id,
            related_user_id,
            related_room_id,
            metadata,
            is_read,
            is_archived,
            created_at,
            read_at,
            expires_at,
            ..
        } = value;
        Self {
            id,
            room,
            project,
            user_id,
            user_info: None,
            notification_type: notification_type.to_string(),
            title,
            content,
            related_message_id,
            related_user_id,
            related_room_id,
            metadata: metadata.unwrap_or_else(|| serde_json::json!({})),
            is_read,
            is_archived,
            created_at,
            read_at,
            expires_at,
        }
    }
}
impl RoomService {
    /// Parses a room member role from its lowercase wire name.
    ///
    /// Accepts exactly `"owner"`, `"admin"`, `"member"` or `"guest"`. The match
    /// is case-sensitive.
    ///
    /// # Errors
    /// Returns [`RoomError::BadRequest`] for any other input, including `""`.
    pub(crate) fn parse_room_member_role(role: &str) -> Result<RoomMemberRole, RoomError> {
        match role {
            "owner" => Ok(RoomMemberRole::Owner),
            "admin" => Ok(RoomMemberRole::Admin),
            "member" => Ok(RoomMemberRole::Member),
            "guest" => Ok(RoomMemberRole::Guest),
            _ => Err(RoomError::BadRequest("invalid room role".to_string())),
        }
    }

    /// Parses an optional message content type. `None` means `text`.
    ///
    /// Matching is case-insensitive, so `"Image"` and `"IMAGE"` both map to
    /// [`MessageContentType::Image`].
    ///
    /// # Errors
    /// Returns [`RoomError::BadRequest`] for an unrecognised content type.
    pub(crate) fn parse_message_content_type(
        content_type: Option<String>,
    ) -> Result<MessageContentType, RoomError> {
        // Borrow the default instead of allocating a String just to lowercase it.
        let normalized = content_type.as_deref().unwrap_or("text").to_lowercase();
        match normalized.as_str() {
            "text" => Ok(MessageContentType::Text),
            "image" => Ok(MessageContentType::Image),
            "audio" => Ok(MessageContentType::Audio),
            "video" => Ok(MessageContentType::Video),
            "file" => Ok(MessageContentType::File),
            _ => Err(RoomError::BadRequest(
                "invalid message content_type".to_string(),
            )),
        }
    }

    /// Looks up the membership row for `(room_id, user_id)`, if any.
    ///
    /// # Errors
    /// Propagates database errors as [`RoomError`].
    pub(crate) async fn find_room_member(
        &self,
        room_id: Uuid,
        user_id: Uuid,
    ) -> Result<Option<room_member::Model>, RoomError> {
        room_member::Entity::find_by_id((room_id, user_id))
            .one(&self.db)
            .await
            .map_err(RoomError::from)
    }

    /// Like [`Self::find_room_member`], but a missing membership is an error.
    ///
    /// # Errors
    /// Returns [`RoomError::NoPower`] when the user is not a room member, or a
    /// database error.
    pub(crate) async fn require_room_member_model(
        &self,
        room_id: Uuid,
        user_id: Uuid,
    ) -> Result<room_member::Model, RoomError> {
        self.find_room_member(room_id, user_id)
            .await?
            .ok_or(RoomError::NoPower)
    }

    /// Returns `true` for roles with room-administration rights (owner or admin).
    pub(crate) fn is_room_admin(role: &RoomMemberRole) -> bool {
        matches!(role, RoomMemberRole::Owner | RoomMemberRole::Admin)
    }

    /// Requires the user to be a room member with an admin-level role.
    ///
    /// # Errors
    /// Returns [`RoomError::NoPower`] if the user is not a member or is not an
    /// owner/admin, or a database error.
    pub(crate) async fn require_room_admin(
        &self,
        room_id: Uuid,
        user_id: Uuid,
    ) -> Result<room_member::Model, RoomError> {
        let member = self.require_room_member_model(room_id, user_id).await?;
        if Self::is_room_admin(&member.role) {
            Ok(member)
        } else {
            Err(RoomError::NoPower)
        }
    }

    /// Requires the user to be an owner or admin of the project.
    ///
    /// # Errors
    /// Returns [`RoomError::NoPower`] if the user is not a project member or lacks
    /// an admin-level role. Returns [`RoomError::RoleParseError`] if the stored
    /// role cannot be parsed. Database errors are propagated.
    pub(crate) async fn require_project_admin(
        &self,
        project_id: Uuid,
        user_id: Uuid,
    ) -> Result<project_members::Model, RoomError> {
        let member = project_members::Entity::find()
            .filter(project_members::Column::Project.eq(project_id))
            .filter(project_members::Column::User.eq(user_id))
            .one(&self.db)
            .await?
            .ok_or(RoomError::NoPower)?;
        let role = member.scope_role().map_err(|_| RoomError::RoleParseError)?;
        if matches!(role, MemberRole::Owner | MemberRole::Admin) {
            Ok(member)
        } else {
            Err(RoomError::NoPower)
        }
    }

    /// Checks that `user_id` may see `room`.
    ///
    /// Room members can always see the room. Anyone else needs the room to be
    /// public and must also be a member of the room's project.
    ///
    /// # Errors
    /// Returns [`RoomError::NoPower`] when neither condition holds, or a database
    /// error.
    pub(crate) async fn ensure_room_visible_for_user(
        &self,
        room: &room::Model,
        user_id: Uuid,
    ) -> Result<(), RoomError> {
        if self.find_room_member(room.id, user_id).await?.is_some() {
            return Ok(());
        }
        // Project membership cannot grant access to a private room, so skip the query.
        if !room.public {
            return Err(RoomError::NoPower);
        }
        let project_member = project_members::Entity::find()
            .filter(project_members::Column::Project.eq(room.project))
            .filter(project_members::Column::User.eq(user_id))
            .one(&self.db)
            .await?;
        if project_member.is_some() {
            Ok(())
        } else {
            Err(RoomError::NoPower)
        }
    }

    /// Resolves a project by its current name, falling back to `project_history_name`.
    ///
    /// The fallback presumably covers names a project used before being renamed
    /// (TODO confirm).
    ///
    /// Database errors are logged at `warn` and then treated as "not found". This
    /// best-effort behaviour is kept as-is.
    ///
    /// # Errors
    /// Returns [`RoomError::NotFound`] if no project matches.
    pub async fn utils_find_project_by_name(
        &self,
        name: String,
    ) -> Result<project::Model, RoomError> {
        let current = project::Entity::find()
            .filter(project::Column::Name.eq(name.clone()))
            .one(&self.db)
            .await
            .inspect_err(|e| {
                tracing::warn!(error = %e, project_name = %name, "utils_find_project_by_name: DB error");
            })
            .ok()
            .flatten();
        if let Some(project) = current {
            return Ok(project);
        }

        let historical = project_history_name::Entity::find()
            .filter(project_history_name::Column::HistoryName.eq(name.clone()))
            .one(&self.db)
            .await
            .inspect_err(|e| tracing::warn!(error = %e, name = %name, "project_history_name lookup failed"))
            .ok()
            .flatten();
        match historical {
            Some(entry) => self.utils_find_project_by_uid(entry.project_uid).await,
            None => Err(RoomError::NotFound("Project not found".to_string())),
        }
    }

    /// Resolves a project by primary key.
    ///
    /// Database errors are logged at `warn` and then treated as "not found".
    ///
    /// # Errors
    /// Returns [`RoomError::NotFound`] if no project matches.
    pub async fn utils_find_project_by_uid(&self, uid: Uuid) -> Result<project::Model, RoomError> {
        project::Entity::find_by_id(uid)
            .one(&self.db)
            .await
            .inspect_err(|e| tracing::warn!(error = %e, project_uid = %uid, "utils_find_project_by_uid: DB error"))
            .ok()
            .flatten()
            .ok_or_else(|| RoomError::NotFound("Project not found".to_string()))
    }

    /// Checks that `user_uid` can access the project.
    ///
    /// Public projects are open to everyone. Private projects require a
    /// `project_members` row for the user.
    ///
    /// # Errors
    /// Returns [`RoomError::NotFound`] if the project lookup yields nothing; a
    /// database error during that lookup is logged and reported the same way.
    /// Returns [`RoomError::NoPower`] for non-members of a private project.
    pub async fn check_project_access(
        &self,
        project_uid: Uuid,
        user_uid: Uuid,
    ) -> Result<(), RoomError> {
        let project = project::Entity::find_by_id(project_uid)
            .one(&self.db)
            .await
            .inspect_err(|e| tracing::warn!(error = %e, project_uid = %project_uid, "check_project_access: DB error"))
            .ok()
            .flatten()
            .ok_or_else(|| RoomError::NotFound("Project not found".to_string()))?;
        if project.is_public {
            return Ok(());
        }
        let member = project_members::Entity::find()
            .filter(project_members::Column::Project.eq(project_uid))
            .filter(project_members::Column::User.eq(user_uid))
            .one(&self.db)
            .await?;
        if member.is_some() {
            Ok(())
        } else {
            Err(RoomError::NoPower)
        }
    }

    /// Validates a user-supplied name.
    ///
    /// The name must not be blank and must be at most `max_len` characters
    /// (Unicode scalar values, not bytes).
    ///
    /// # Errors
    /// Returns [`RoomError::BadRequest`] describing the violated rule.
    pub(crate) fn validate_name(name: &str, max_len: usize) -> Result<(), RoomError> {
        validate_text_field("name", name, max_len)
    }

    /// Validates user-supplied message content.
    ///
    /// The content must not be blank and must be at most `max_len` characters
    /// (Unicode scalar values, not bytes).
    ///
    /// # Errors
    /// Returns [`RoomError::BadRequest`] describing the violated rule.
    pub(crate) fn validate_content(content: &str, max_len: usize) -> Result<(), RoomError> {
        validate_text_field("content", content, max_len)
    }

    /// Strips unsafe HTML from user-supplied content using [`ammonia::clean`].
    ///
    /// `ammonia::clean` applies ammonia's *default* allow-list. That list is
    /// broader than basic inline formatting: `<img>`, for example, is kept.
    /// It drops `<script>`/`<style>` elements and event-handler attributes such
    /// as `onerror`, and it removes URLs with non-allow-listed schemes such as
    /// `javascript:`. Plain text passes through unchanged.
    pub(crate) fn sanitize_content(content: &str) -> String {
        ammonia::clean(content)
    }

    /// Converts a message row into a response with `display_name` resolved.
    ///
    /// - AI messages with a `model_id` get the model's name. If the model row is
    ///   missing or the lookup fails, they get `AI(<first 8 hex chars of the id>)`.
    /// - Every other sender type with a `sender_id` gets the user's display name,
    ///   falling back to the username.
    /// - Otherwise `display_name` stays `None`.
    ///
    /// Lookup failures are logged at `warn` and never fail the call.
    pub async fn resolve_display_name(
        &self,
        msg: room_message::Model,
        _room_id: Uuid,
    ) -> super::RoomMessageResponse {
        let display_name = if msg.sender_type.to_string() == "ai" {
            match msg.model_id {
                Some(mid) => Some(
                    ai_model::Entity::find_by_id(mid)
                        .one(&self.db)
                        .await
                        .inspect_err(|e| tracing::warn!(error = %e, model_id = %mid, "resolve_display_name: AI model lookup failed"))
                        .ok()
                        .flatten()
                        .map(|m| m.name)
                        // Hyphenated UUID strings are 36 ASCII chars, so `[..8]` cannot panic.
                        .unwrap_or_else(|| format!("AI({})", &mid.to_string()[..8])),
                ),
                None => None,
            }
        } else {
            match msg.sender_id {
                Some(sender_id) => user_model::Entity::find()
                    .filter(user_model::Column::Uid.eq(sender_id))
                    .one(&self.db)
                    .await
                    .inspect_err(|e| tracing::warn!(error = %e, user_id = %sender_id, "resolve_display_name: user lookup failed"))
                    .ok()
                    .flatten()
                    .map(|u| u.display_name.unwrap_or_else(|| u.username)),
                None => None,
            }
        };
        // Reuse the plain row conversion so both paths build identical responses.
        let mut response = super::RoomMessageResponse::from(msg);
        response.display_name = display_name;
        response
    }

    /// Reads the current version of a room from Redis.
    ///
    /// Returns 0 if no version has been set (new rooms start at 1).
    ///
    /// # Errors
    /// Returns [`RoomError::Internal`] on connection or command failure.
    pub(crate) async fn get_room_version(&self, room_id: Uuid) -> Result<i64, RoomError> {
        let version_key = room_version_key(room_id);
        let mut conn = self.cache.conn().await.map_err(|e| {
            RoomError::Internal(format!("failed to get redis for version: {}", e))
        })?;
        let version: Option<i64> = redis::cmd("GET")
            .arg(&version_key)
            .query_async(&mut conn)
            .await
            .map_err(|e| RoomError::Internal(format!("version GET: {}", e)))?;
        Ok(version.unwrap_or(0))
    }

    /// Atomically increments the room version and returns the new value.
    ///
    /// Called on every room mutation (rename, move, delete).
    pub(crate) async fn increment_room_version(&self, room_id: Uuid) -> Result<i64, RoomError> {
        Self::raw_increment_room_version(&self.cache, room_id).await
    }

    /// Static form of [`Self::increment_room_version`], usable from
    /// `room_create` without `&self`.
    ///
    /// Uses Redis `INCR`, so a missing key starts from 0 and becomes 1.
    ///
    /// # Errors
    /// Returns [`RoomError::Internal`] on connection or command failure.
    pub(crate) async fn raw_increment_room_version(
        cache: &db::cache::AppCache,
        room_id: Uuid,
    ) -> Result<i64, RoomError> {
        let version_key = room_version_key(room_id);
        let mut conn = cache.conn().await.map_err(|e| {
            RoomError::Internal(format!("failed to get redis for version: {}", e))
        })?;
        let version: i64 = redis::cmd("INCR")
            .arg(&version_key)
            .query_async(&mut conn)
            .await
            .map_err(|e| RoomError::Internal(format!("version INCR: {}", e)))?;
        Ok(version)
    }
}

/// Redis key holding the version counter of `room_id`.
///
/// Shared by the GET and INCR paths so their key formats cannot drift apart.
fn room_version_key(room_id: Uuid) -> String {
    format!("room:version:{room_id}")
}

/// Shared check behind `RoomService::validate_name` and
/// `RoomService::validate_content`.
///
/// Rejects blank values and values longer than `max_len` `char`s. `field` is
/// the noun used in the error messages.
fn validate_text_field(field: &str, value: &str, max_len: usize) -> Result<(), RoomError> {
    if value.trim().is_empty() {
        return Err(RoomError::BadRequest(format!("{field} cannot be empty")));
    }
    // `str::len` counts UTF-8 bytes: a 40-char CJK name is 120 bytes and would
    // fail a 100 "character" limit. Count chars so the check matches the message.
    if value.chars().count() > max_len {
        return Err(RoomError::BadRequest(format!(
            "{field} exceeds maximum length of {max_len} characters"
        )));
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    //! Unit tests for the pure helpers on `RoomService` plus the mention regexes
    //! and `RoomEventType` string round-trip.
    use super::*;
    #[test]
    fn test_parse_room_member_role_valid() {
        assert!(matches!(
            RoomService::parse_room_member_role("owner").unwrap(),
            RoomMemberRole::Owner
        ));
        assert!(matches!(
            RoomService::parse_room_member_role("admin").unwrap(),
            RoomMemberRole::Admin
        ));
        assert!(matches!(
            RoomService::parse_room_member_role("member").unwrap(),
            RoomMemberRole::Member
        ));
        assert!(matches!(
            RoomService::parse_room_member_role("guest").unwrap(),
            RoomMemberRole::Guest
        ));
    }
    #[test]
    fn test_parse_room_member_role_invalid() {
        assert!(RoomService::parse_room_member_role("superadmin").is_err());
        assert!(RoomService::parse_room_member_role("").is_err());
    }
    #[test]
    fn test_parse_message_content_type_valid() {
        assert!(matches!(
            RoomService::parse_message_content_type(Some("text".into())).unwrap(),
            MessageContentType::Text
        ));
        assert!(matches!(
            RoomService::parse_message_content_type(Some("image".into())).unwrap(),
            MessageContentType::Image
        ));
        assert!(matches!(
            RoomService::parse_message_content_type(Some("audio".into())).unwrap(),
            MessageContentType::Audio
        ));
        assert!(matches!(
            RoomService::parse_message_content_type(Some("video".into())).unwrap(),
            MessageContentType::Video
        ));
        assert!(matches!(
            RoomService::parse_message_content_type(Some("file".into())).unwrap(),
            MessageContentType::File
        ));
    }
    #[test]
    fn test_parse_message_content_type_case_insensitive() {
        assert!(matches!(
            RoomService::parse_message_content_type(Some("TEXT".into())).unwrap(),
            MessageContentType::Text
        ));
        assert!(matches!(
            RoomService::parse_message_content_type(Some("Image".into())).unwrap(),
            MessageContentType::Image
        ));
    }
    #[test]
    fn test_parse_message_content_type_none_defaults_to_text() {
        assert!(matches!(
            RoomService::parse_message_content_type(None).unwrap(),
            MessageContentType::Text
        ));
    }
    #[test]
    fn test_parse_message_content_type_invalid() {
        assert!(RoomService::parse_message_content_type(Some("pdf".into())).is_err());
    }
    #[test]
    fn test_validate_name_valid() {
        assert!(RoomService::validate_name("test-room", 100).is_ok());
        assert!(RoomService::validate_name("a", 100).is_ok());
    }
    #[test]
    fn test_validate_name_empty() {
        assert!(RoomService::validate_name("", 100).is_err());
        assert!(RoomService::validate_name(" ", 100).is_err());
    }
    #[test]
    fn test_validate_name_too_long() {
        let long = "x".repeat(101);
        assert!(RoomService::validate_name(&long, 100).is_err());
    }
    #[test]
    fn test_validate_content_valid() {
        assert!(RoomService::validate_content("hello", 10000).is_ok());
    }
    #[test]
    fn test_validate_content_empty() {
        assert!(RoomService::validate_content("", 10000).is_err());
        assert!(RoomService::validate_content(" ", 10000).is_err());
    }
    #[test]
    fn test_validate_content_too_long() {
        let long = "x".repeat(10001);
        assert!(RoomService::validate_content(&long, 10000).is_err());
    }
    #[test]
    fn test_sanitize_content_removes_script_tag() {
        let input = "<script>alert('xss')</script>";
        let result = RoomService::sanitize_content(input);
        assert!(!result.contains("<script>"));
    }
    #[test]
    fn test_sanitize_content_blocks_javascript_uri() {
        // Bare text is not markup, so ammonia leaves it unchanged; it cannot execute.
        let plain = "javascript:alert(1)";
        assert_eq!(RoomService::sanitize_content(plain), "javascript:alert(1)");
        // Inside an attribute, `javascript:` is not an allowed URL scheme, so
        // ammonia drops the attribute while keeping the link text.
        let link = r#"<a href="javascript:alert(1)">click</a>"#;
        let result = RoomService::sanitize_content(link);
        assert!(!result.contains("javascript:"));
        assert!(result.contains("click"));
    }
    #[test]
    fn test_sanitize_content_blocks_onerror() {
        let input = r#"<img src=x onerror="alert(1)">"#;
        let result = RoomService::sanitize_content(input);
        // ammonia removes event handler attributes from allowed tags
        assert!(!result.contains("onerror"));
        // ammonia keeps the img tag but with onerror removed
        assert!(result.contains("<img"));
        assert!(!result.contains("alert"));
    }
    #[test]
    fn test_sanitize_content_preserves_safe_content() {
        let input = "Hello <strong>world</strong>";
        let result = RoomService::sanitize_content(input);
        assert!(result.contains("Hello"));
        assert!(result.contains("<strong>"));
    }
    #[test]
    fn test_is_room_admin() {
        assert!(RoomService::is_room_admin(&RoomMemberRole::Owner));
        assert!(RoomService::is_room_admin(&RoomMemberRole::Admin));
        assert!(!RoomService::is_room_admin(&RoomMemberRole::Member));
        assert!(!RoomService::is_room_admin(&RoomMemberRole::Guest));
    }
    #[test]
    fn test_room_event_type_from_str_roundtrip() {
        for variant in [
            crate::RoomEventType::RoomCreated,
            crate::RoomEventType::RoomDeleted,
            crate::RoomEventType::NewMessage,
            crate::RoomEventType::MessageEdited,
            crate::RoomEventType::MessageRevoked,
            crate::RoomEventType::MemberJoined,
        ] {
            let s = variant.as_str();
            let parsed = crate::RoomEventType::from_str(s);
            assert_eq!(parsed, Some(variant));
        }
    }
    #[test]
    fn test_room_event_type_from_str_unknown() {
        assert_eq!(crate::RoomEventType::from_str("unknown_event"), None);
    }
    #[test]
    fn test_mention_bracket_re_matches_ai_model() {
        let re = crate::service::mention_bracket_re();
        let caps: Vec<_> = re.captures_iter("@[ai:550e8400-0000-0000-0000-000000000001:GPT-4]").collect();
        assert_eq!(caps.len(), 1);
        assert_eq!(&caps[0][1], "ai");
        assert_eq!(&caps[0][2], "550e8400-0000-0000-0000-000000000001");
    }
    #[test]
    fn test_mention_bracket_re_matches_user() {
        let re = crate::service::mention_bracket_re();
        let caps: Vec<_> = re.captures_iter("@[user:850e8400-0000-0000-0000-000000000002:John]").collect();
        assert_eq!(caps.len(), 1);
        assert_eq!(&caps[0][1], "user");
    }
    #[test]
    fn test_mention_bracket_re_matches_repo() {
        let re = crate::service::mention_bracket_re();
        let caps: Vec<_> = re.captures_iter("@[repo:my-repo:My Repository]").collect();
        assert_eq!(caps.len(), 1);
        assert_eq!(&caps[0][1], "repo");
    }
    #[test]
    fn test_mention_bracket_re_no_match_plain_text() {
        let re = crate::service::mention_bracket_re();
        let caps: Vec<_> = re.captures_iter("Hello world").collect();
        assert_eq!(caps.len(), 0);
    }
    #[test]
    fn test_mention_multiple_in_same_message() {
        let re = crate::service::mention_bracket_re();
        let content = "@[ai:uuid1:Model1] and @[user:uuid2:User2]";
        let caps: Vec<_> = re.captures_iter(content).collect();
        assert_eq!(caps.len(), 2);
    }
    #[test]
    fn test_mention_tag_re_legacy_format() {
        let re = crate::service::mention_tag_re();
        let content = r#"<mention type="ai" id="model-uuid">GPT-4</mention>"#;
        let caps: Vec<_> = re.captures_iter(content).collect();
        assert_eq!(caps.len(), 1);
        assert_eq!(&caps[0][1], "ai");
        assert_eq!(&caps[0][2], "model-uuid");
    }
    #[test]
    fn test_mention_combined_brackets_and_tags() {
        let bracket_re = crate::service::mention_bracket_re();
        let tag_re = crate::service::mention_tag_re();
        let content = r#"@[ai:uuid1:A] <mention type="ai" id="uuid2">B</mention>"#;
        assert_eq!(bracket_re.captures_iter(content).count(), 1);
        assert_eq!(tag_re.captures_iter(content).count(), 1);
    }
}