diff --git a/libs/room/Cargo.toml b/libs/room/Cargo.toml index 86464f2..f418b48 100644 --- a/libs/room/Cargo.toml +++ b/libs/room/Cargo.toml @@ -43,6 +43,7 @@ redis = { workspace = true, features = ["tokio-comp", "connection-manager"] } hostname = "0.4" dashmap = "7.0.0-rc2" lru = "0.12.0" +ammonia = "4.0" [lints] workspace = true diff --git a/libs/room/src/connection.rs b/libs/room/src/connection.rs index fdb8692..c3e00a7 100644 --- a/libs/room/src/connection.rs +++ b/libs/room/src/connection.rs @@ -48,6 +48,7 @@ pub struct RoomConnectionManager { room_subscriber_count: RwLock>, project_subscriber_count: RwLock>, user_subscriber_count: RwLock>, + stream_cancel_tokens: RwLock>>, } impl RoomConnectionManager { @@ -89,6 +90,8 @@ impl RoomConnectionManager { project_subscriber_count: RwLock::new(HashMap::new()), #[allow(clippy::default_constructed_unit_structs)] user_subscriber_count: RwLock::new(HashMap::new()), + #[allow(clippy::default_constructed_unit_structs)] + stream_cancel_tokens: RwLock::new(HashMap::new()), } } @@ -629,6 +632,35 @@ impl RoomConnectionManager { map.remove(&message_id); } + /// Register a cancel flag for an active AI streaming session. + /// Returns the cancel token that the streaming task should check. + pub async fn register_stream_cancel( + &self, + room_id: Uuid, + ) -> Arc { + let cancel = Arc::new(std::sync::atomic::AtomicBool::new(false)); + let mut map = self.stream_cancel_tokens.write().await; + map.insert(room_id, cancel.clone()); + cancel + } + + /// Cancel an active AI streaming session for a room. + pub async fn cancel_ai_stream(&self, room_id: Uuid) -> bool { + let map = self.stream_cancel_tokens.read().await; + if let Some(cancel) = map.get(&room_id) { + cancel.store(true, std::sync::atomic::Ordering::Release); + true + } else { + false + } + } + + /// Clean up the cancel token for a room when streaming completes. + pub async fn unregister_stream_cancel(&self, room_id: Uuid) { + let mut map = self.stream_cancel_tokens.write().await; + map.remove(&room_id); + } + pub async fn subscribe_typing( &self, room_id: Uuid, @@ -660,24 +692,22 @@ impl RoomConnectionManager { // Write/delete Redis key for 60s expiry (non-blocking) if let Ok(mut conn) = self.cache.conn().await { let key = user_key; - tokio::spawn(async move { - if action == "start" { - let value = serde_json::json!({ - "username": username, - "avatar_url": avatar_url, - "sender_type": sender_type, - }) - .to_string(); - let _: Result<(), _> = redis::cmd("SETEX") - .arg(&key) - .arg(60i64) - .arg(&value) - .query_async(&mut conn) - .await; - } else { - let _: Result<(), _> = redis::cmd("DEL").arg(&key).query_async(&mut conn).await; - } - }); + if action == "start" { + let value = serde_json::json!({ + "username": username, + "avatar_url": avatar_url, + "sender_type": sender_type, + }) + .to_string(); + let _: Result<(), _> = redis::cmd("SETEX") + .arg(&key) + .arg(60i64) + .arg(&value) + .query_async(&mut conn) + .await; + } else { + let _: Result<(), _> = redis::cmd("DEL").arg(&key).query_async(&mut conn).await; + } } let map: tokio::sync::RwLockReadGuard<'_, std::collections::HashMap>>> = self.typing_inner.read().await; @@ -1156,6 +1186,53 @@ pub async fn subscribe_room_events( tracing::info!(room_id = %room_id, "room subscriber stopped"); } +/// Subscribe to stream chunk events for cross-node delivery. +/// When a stream chunk is published via Redis Pub/Sub on +/// `room:stream:chunk:{room_id}`, broadcast it locally. +pub async fn subscribe_room_stream_chunk_events( + redis_url: String, + manager: Arc, + room_id: Uuid, + mut shutdown_rx: broadcast::Receiver<()>, +) { + let channel = format!("room:stream:chunk:{}", room_id); + let (tx, mut rx) = tokio::sync::mpsc::channel::>(1024); + + tracing::info!(room_id = %room_id, channel = %channel, "starting room stream chunk subscriber"); + + let thread_channel = channel.clone(); + let thread_shutdown = shutdown_rx.resubscribe(); + start_pubsub_thread(redis_url, thread_channel, tx, thread_shutdown, |_| async {}); + + loop { + tokio::select! { + _ = shutdown_rx.recv() => { + tracing::info!(room_id = %room_id, "stream chunk subscriber shutting down"); + break; + } + payload = rx.recv() => { + match payload { + Some(data) => { + match serde_json::from_slice::(&data) { + Ok(event) => { + manager.broadcast_stream_chunk(event).await; + } + Err(e) => { + tracing::warn!(error = %e, "malformed RoomMessageStreamChunkEvent"); + } + } + } + None => { + tracing::warn!(room_id = %room_id, "stream chunk relay channel closed"); + break; + } + } + } + } + } + tracing::info!(room_id = %room_id, "stream chunk subscriber stopped"); +} + pub async fn subscribe_project_room_events( redis_url: String, manager: Arc, diff --git a/libs/room/src/helpers.rs b/libs/room/src/helpers.rs index 0afb6df..fd9a2ab 100644 --- a/libs/room/src/helpers.rs +++ b/libs/room/src/helpers.rs @@ -35,6 +35,7 @@ impl From for super::RoomResponse { created_at: value.created_at, last_msg_at: value.last_msg_at, unread_count: 0, + version: 0, } } } @@ -58,6 +59,7 @@ impl From for super::RoomMemberResponse { impl From for super::RoomMessageResponse { fn from(value: room_message::Model) -> Self { + let chunked = super::RoomMessageResponse::detect_chunked(&value.thinking_content); Self { id: value.id, seq: value.seq, @@ -69,6 +71,7 @@ impl From for super::RoomMessageResponse { content: value.content, content_type: value.content_type.to_string(), thinking_content: value.thinking_content, + thinking_is_chunked: chunked, edited_at: value.edited_at, send_at: value.send_at, revoked: value.revoked, @@ -270,14 +273,18 @@ impl RoomService { .filter(project::Column::Name.eq(name.clone())) .one(&self.db) .await + .inspect_err(|e| { + tracing::warn!(error = %e, project_name = %name, "utils_find_project_by_name: DB error"); + }) .ok() .flatten() { Some(project) => Ok(project), None => match project_history_name::Entity::find() - .filter(project_history_name::Column::HistoryName.eq(name)) + .filter(project_history_name::Column::HistoryName.eq(name.clone())) .one(&self.db) .await + .inspect_err(|e| tracing::warn!(error = %e, name = %name, "project_history_name lookup failed")) .ok() .flatten() { @@ -291,6 +298,7 @@ impl RoomService { project::Entity::find_by_id(uid) .one(&self.db) .await + .inspect_err(|e| tracing::warn!(error = %e, project_uid = %uid, "utils_find_project_by_uid: DB error")) .ok() .flatten() .ok_or_else(|| RoomError::NotFound("Project not found".to_string())) @@ -304,6 +312,7 @@ impl RoomService { let project = project::Entity::find_by_id(project_uid) .one(&self.db) .await + .inspect_err(|e| tracing::warn!(error = %e, project_uid = %project_uid, "check_project_access: DB error")) .ok() .flatten() .ok_or_else(|| RoomError::NotFound("Project not found".to_string()))?; @@ -352,36 +361,11 @@ impl RoomService { } pub(crate) fn sanitize_content(content: &str) -> String { - use std::sync::LazyLock; - - static SCRIPT_RE: LazyLock regex_lite::Regex> = - LazyLock::new(|| regex_lite::Regex::new(r"(?i)]*>.*?").unwrap()); - static STYLE_RE: LazyLock regex_lite::Regex> = - LazyLock::new(|| regex_lite::Regex::new(r"(?i)]*>.*?").unwrap()); - static ONERROR_RE: LazyLock regex_lite::Regex> = - LazyLock::new(|| regex_lite::Regex::new(r"(?i)\bonerror\s*=").unwrap()); - static ONLOAD_RE: LazyLock regex_lite::Regex> = - LazyLock::new(|| regex_lite::Regex::new(r"(?i)\bonload\s*=").unwrap()); - static ONCLICK_RE: LazyLock regex_lite::Regex> = - LazyLock::new(|| regex_lite::Regex::new(r"(?i)\bonclick\s*=").unwrap()); - static ONMOUSEOVER_RE: LazyLock regex_lite::Regex> = - LazyLock::new(|| regex_lite::Regex::new(r"(?i)\bonmouseover\s*=").unwrap()); - static JAVASCRIPT_RE: LazyLock regex_lite::Regex> = - LazyLock::new(|| regex_lite::Regex::new(r"(?i)javascript:").unwrap()); - static DATA_RE: LazyLock regex_lite::Regex> = - LazyLock::new(|| regex_lite::Regex::new(r"(?i)data:").unwrap()); - - let mut result = content.to_string(); - result = SCRIPT_RE.replace_all(&result, "").to_string(); - result = STYLE_RE.replace_all(&result, "").to_string(); - result = ONERROR_RE.replace_all(&result, "blocked=").to_string(); - result = ONLOAD_RE.replace_all(&result, "blocked=").to_string(); - result = ONCLICK_RE.replace_all(&result, "blocked=").to_string(); - result = ONMOUSEOVER_RE.replace_all(&result, "blocked=").to_string(); - result = JAVASCRIPT_RE.replace_all(&result, "blocked:").to_string(); - result = DATA_RE.replace_all(&result, "blocked:").to_string(); - - result + // Use ammonia for HTML sanitization (whitelist approach). + // Only allows safe tags: , , , ,
, 
,

,
, , ,

    ,
      ,
    1. + // All other tags (including "; + let result = RoomService::sanitize_content(input); + assert!(!result.contains("