use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant};

use tokio::sync::{RwLock, broadcast};
use uuid::Uuid;

use db::cache::AppCache;
use db::database::AppDatabase;
use models::rooms::{MessageContentType, MessageSenderType, room_message};
use queue::types::TypingEvent;
use queue::{
    AgentTaskEvent, ProjectRoomEvent, RoomMessageEnvelope, RoomMessageEvent,
    RoomMessageStreamChunkEvent,
};
use sea_orm::{ColumnTrait, ConnectionTrait, EntityTrait, QueryFilter, Set};

use crate::error::RoomError;
use crate::metrics::RoomMetrics;
use crate::types::NotificationEvent;

/// Capacity of every event broadcast channel created by the manager.
const BROADCAST_CAPACITY: usize = 100_000;
/// Capacity of the shutdown-signal broadcast channels.
const SHUTDOWN_CHANNEL_CAPACITY: usize = 16;
/// Minimum time between two accepted connection attempts for the same rate-limit key.
const CONNECTION_COOLDOWN: Duration = Duration::from_secs(30);
/// Upper bound checked against the number of entries in the room channel map.
const MAX_CONNECTIONS_PER_ROOM: usize = 50_000;
/// Upper bound checked against the number of entries in the project channel map.
const MAX_CONNECTIONS_PER_PROJECT: usize = 50_000;
/// Upper bound checked against the number of entries in the user channel map.
const MAX_CONNECTIONS_PER_USER: usize = 50_000;
/// Number of envelopes persisted per database batch.
const BATCH_SIZE: usize = 100;
/// A room with no broadcast activity for this long is torn down by `cleanup_idle_rooms`.
const ROOM_IDLE_TIMEOUT: Duration = Duration::from_secs(30 * 60);

/// In-process fan-out hub: keeps one broadcast channel per room / project / user /
/// stream and the bookkeeping (subscriber counts, rate limits, shutdown signals)
/// that goes with them.
///
/// NOTE(review): the generic arguments of the field types were reconstructed from how
/// each field is used in this file. The subscriber-count integer type (`usize`) is an
/// assumption -- confirm against the original definition.
pub struct RoomConnectionManager {
    /// Room id -> channel carrying room message events.
    room_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<RoomMessageEvent>>>>,
    /// Project id -> channel carrying project-level events.
    project_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<ProjectRoomEvent>>>>,
    /// User id -> channel carrying project-level events addressed to one user.
    user_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<ProjectRoomEvent>>>>,
    /// User id -> channel carrying notifications.
    user_notification_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<NotificationEvent>>>>,
    /// Broadcast channel for agent task events per project.
    task_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<AgentTaskEvent>>>>,
    pub metrics: Arc<RoomMetrics>,
    cache: AppCache,
    /// `(scope id, user id)` -> time of the last accepted connection attempt.
    connection_rate: RwLock<HashMap<(Uuid, Uuid), Instant>>,
    /// Process-wide shutdown signal.
    shutdown_tx: broadcast::Sender<()>,
    room_shutdown_txs: RwLock<HashMap<Uuid, broadcast::Sender<()>>>,
    project_shutdown_txs: RwLock<HashMap<Uuid, broadcast::Sender<()>>>,
    user_shutdown_txs: RwLock<HashMap<Uuid, broadcast::Sender<()>>>,
    /// Message id -> channel carrying that message's stream chunks.
    stream_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<RoomMessageStreamChunkEvent>>>>,
    /// Room id -> channel carrying every stream chunk produced in that room.
    room_stream_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<RoomMessageStreamChunkEvent>>>>,
    /// Room id -> channel carrying typing events.
    typing_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<TypingEvent>>>>,
    /// Room id -> time of the last broadcast or stream chunk.
    room_last_activity: RwLock<HashMap<Uuid, Instant>>,
    room_subscriber_count: RwLock<HashMap<Uuid, usize>>,
    project_subscriber_count: RwLock<HashMap<Uuid, usize>>,
    user_subscriber_count: RwLock<HashMap<Uuid, usize>>,
}

impl RoomConnectionManager {
    /// Creates an empty manager that reports to `metrics` and uses `cache` for Redis access.
    pub fn new(metrics: Arc<RoomMetrics>, cache: AppCache) -> Self {
        // The initial receiver is dropped; listeners obtain theirs via `subscribe_shutdown`.
        let (shutdown_tx, _) = broadcast::channel(SHUTDOWN_CHANNEL_CAPACITY);
        Self {
            room_inner: RwLock::new(HashMap::new()),
            project_inner: RwLock::new(HashMap::new()),
            user_inner: RwLock::new(HashMap::new()),
            user_notification_inner: RwLock::new(HashMap::new()),
            task_inner: RwLock::new(HashMap::new()),
            metrics,
            cache,
            connection_rate: RwLock::new(HashMap::new()),
            shutdown_tx,
            room_shutdown_txs: RwLock::new(HashMap::new()),
            project_shutdown_txs: RwLock::new(HashMap::new()),
            user_shutdown_txs: RwLock::new(HashMap::new()),
            stream_inner: RwLock::new(HashMap::new()),
            room_stream_inner: RwLock::new(HashMap::new()),
            typing_inner: RwLock::new(HashMap::new()),
            room_last_activity: RwLock::new(HashMap::new()),
            room_subscriber_count: RwLock::new(HashMap::new()),
            project_subscriber_count: RwLock::new(HashMap::new()),
            user_subscriber_count: RwLock::new(HashMap::new()),
        }
    }

    /// Shared implementation of the three public rate checks.
    ///
    /// Rejects the attempt when the previous accepted attempt for `key` is younger than
    /// `CONNECTION_COOLDOWN`; otherwise records "now" as the last accepted attempt.
    async fn check_connection_rate(&self, key: (Uuid, Uuid)) -> Result<(), RoomError> {
        let mut map = self.connection_rate.write().await;
        if let Some(last) = map.get(&key) {
            // Read the clock once so the comparison and the message agree.
            let elapsed = last.elapsed();
            if elapsed < CONNECTION_COOLDOWN {
                return Err(RoomError::RateLimited(format!(
                    "Connection cooldown active, retry in {}s",
                    CONNECTION_COOLDOWN.saturating_sub(elapsed).as_secs()
                )));
            }
        }
        map.insert(key, Instant::now());
        Ok(())
    }

    /// Rate-limits connection attempts of `user_id` to `room_id`.
    ///
    /// # Errors
    /// `RoomError::RateLimited` while the cooldown for this pair is active.
    pub async fn check_room_connection_rate(
        &self,
        room_id: Uuid,
        user_id: Uuid,
    ) -> Result<(), RoomError> {
        self.check_connection_rate((room_id, user_id)).await
    }

    /// Rate-limits connection attempts of `user_id` to `project_id`.
    ///
    /// # Errors
    /// `RoomError::RateLimited` while the cooldown for this pair is active.
    pub async fn check_project_connection_rate(
        &self,
        project_id: Uuid,
        user_id: Uuid,
    ) -> Result<(), RoomError> {
        self.check_connection_rate((project_id, user_id)).await
    }

    /// Rate-limits user-level connection attempts; the nil UUID stands in for the scope id.
    ///
    /// # Errors
    /// `RoomError::RateLimited` while the cooldown for this user is active.
    pub async fn check_user_connection_rate(&self, user_id: Uuid) -> Result<(), RoomError> {
        self.check_connection_rate((Uuid::nil(), user_id)).await
    }

    /// Drops expired rate-limit entries, caps the table size, then cleans up idle rooms.
    pub async fn cleanup_rate_limit(&self) {
        const MAX_RATE_ENTRIES: usize = MAX_CONNECTIONS_PER_ROOM * 10;
        {
            let mut map = self.connection_rate.write().await;
            map.retain(|_, instant| instant.elapsed() < CONNECTION_COOLDOWN * 2);

            if map.len() > MAX_RATE_ENTRIES {
                // Still too large: evict the older half of the remaining entries.
                let mut entries: Vec<((Uuid, Uuid), Instant)> =
                    map.iter().map(|(key, at)| (*key, *at)).collect();
                entries.sort_unstable_by_key(|&(_, at)| at);
                let remove_count = entries.len() / 2;
                for (key, _) in entries.into_iter().take(remove_count) {
                    map.remove(&key);
                }
            }
        } // rate-limit lock released before touching room state
        self.cleanup_idle_rooms().await;
    }

    /// Removes every room whose last recorded activity is older than `ROOM_IDLE_TIMEOUT`,
    /// together with its stream channel, shutdown sender and activity entry.
    pub async fn cleanup_idle_rooms(&self) {
        let now = Instant::now();
        let idle_room_ids: Vec<Uuid> = {
            let activity = self.room_last_activity.read().await;
            activity
                .iter()
                .filter(|(_, last_time)| now.duration_since(**last_time) > ROOM_IDLE_TIMEOUT)
                .map(|(room_id, _)| *room_id)
                .collect()
        };
        if idle_room_ids.is_empty() {
            return;
        }

        {
            let mut counts = self.room_subscriber_count.write().await;
            let mut rooms = self.room_inner.write().await;
            for room_id in &idle_room_ids {
                if rooms.remove(room_id).is_some() {
                    // `room_id` is already a `&Uuid`; passing `&room_id` would hand
                    // `HashMap::remove` a `&&Uuid`, which does not satisfy `Borrow`.
                    let count = counts.remove(room_id).unwrap_or(1);
                    self.metrics.users_online.decrement(count as f64);
                }
            }
        }
        {
            let mut stream_map = self.room_stream_inner.write().await;
            for room_id in &idle_room_ids {
                stream_map.remove(room_id);
            }
        }
        {
            // Dropping the sender closes the channel for any remaining shutdown listeners.
            let mut txs = self.room_shutdown_txs.write().await;
            for room_id in &idle_room_ids {
                txs.remove(room_id);
            }
        }
        {
            let mut activity = self.room_last_activity.write().await;
            for room_id in &idle_room_ids {
                activity.remove(room_id);
            }
        }
    }

    /// Returns a receiver for the process-wide shutdown signal.
    pub fn subscribe_shutdown(&self) -> broadcast::Receiver<()> {
        self.shutdown_tx.subscribe()
    }

    /// Sends the process-wide shutdown signal; having no listeners is not an error.
    pub fn trigger_shutdown(&self) {
        let _ = self.shutdown_tx.send(());
    }

    /// Returns a shutdown receiver for `room_id`, creating the room's shutdown channel
    /// on first use.
    pub async fn register_room(&self, room_id: Uuid) -> broadcast::Receiver<()> {
        let mut txs = self.room_shutdown_txs.write().await;
        if let Some(tx) = txs.get(&room_id) {
            return tx.subscribe();
        }
        let (tx, rx) = broadcast::channel(SHUTDOWN_CHANNEL_CAPACITY);
        txs.insert(room_id, tx);
        rx
    }

    /// Signals shutdown to the room's listeners and removes all per-room state.
    pub async fn shutdown_room(&self, room_id: Uuid) {
        {
            let txs = self.room_shutdown_txs.read().await;
            if let Some(tx) = txs.get(&room_id) {
                let _ = tx.send(());
            }
        }
        {
            // Subscribers that never unsubscribed are still counted as online; undo that.
            let mut counts = self.room_subscriber_count.write().await;
            let count = counts.remove(&room_id).unwrap_or(0) as f64;
            if count > 0.0 {
                self.metrics.users_online.decrement(count);
            }
        }
        {
            let mut map = self.room_inner.write().await;
            map.remove(&room_id);
        }
        {
            let mut stream_map = self.room_stream_inner.write().await;
            stream_map.remove(&room_id);
        }
        {
            let mut txs = self.room_shutdown_txs.write().await;
            txs.remove(&room_id);
        }
    }

    /// Drops shutdown senders and subscriber counts of rooms not listed in `active_room_ids`.
    pub async fn prune_stale_rooms(&self, active_room_ids: &[Uuid]) {
        // Build the lookup set once instead of scanning the slice for every map entry.
        let active: std::collections::HashSet<Uuid> = active_room_ids.iter().copied().collect();
        {
            let mut txs = self.room_shutdown_txs.write().await;
            txs.retain(|room_id, _| active.contains(room_id));
        }
        let mut counts = self.room_subscriber_count.write().await;
        counts.retain(|room_id, _| active.contains(room_id));
    }

    /// Returns a shutdown receiver for `project_id`, creating the channel on first use.
    pub async fn register_project(&self, project_id: Uuid) -> broadcast::Receiver<()> {
        let mut txs = self.project_shutdown_txs.write().await;
        if let Some(tx) = txs.get(&project_id) {
            return tx.subscribe();
        }
        let (tx, rx) = broadcast::channel(SHUTDOWN_CHANNEL_CAPACITY);
        txs.insert(project_id, tx);
        rx
    }

    /// Signals shutdown to the project's listeners and removes its channels.
    pub async fn shutdown_project(&self, project_id: Uuid) {
        {
            let txs = self.project_shutdown_txs.read().await;
            if let Some(tx) = txs.get(&project_id) {
                let _ = tx.send(());
            }
        }
        {
            let mut map = self.project_inner.write().await;
            map.remove(&project_id);
        }
        {
            let mut txs = self.project_shutdown_txs.write().await;
            txs.remove(&project_id);
        }
    }

    /// Drops shutdown senders of projects not listed in `active_project_ids`.
    pub async fn prune_stale_projects(&self, active_project_ids: &[Uuid]) {
        let active: std::collections::HashSet<Uuid> =
            active_project_ids.iter().copied().collect();
        let mut txs = self.project_shutdown_txs.write().await;
        txs.retain(|project_id, _| active.contains(project_id));
    }

    /// Returns a shutdown receiver for `user_id`, creating the channel on first use.
    pub async fn register_user(&self, user_id: Uuid) -> broadcast::Receiver<()> {
        let mut txs = self.user_shutdown_txs.write().await;
        if let Some(tx) = txs.get(&user_id) {
            return tx.subscribe();
        }
        let (tx, rx) = broadcast::channel(SHUTDOWN_CHANNEL_CAPACITY);
        txs.insert(user_id, tx);
        rx
    }

    /// Signals shutdown to the user's listeners and removes the user's channels.
    pub async fn shutdown_user(&self, user_id: Uuid) {
        {
            let txs = self.user_shutdown_txs.read().await;
            if let Some(tx) = txs.get(&user_id) {
                let _ = tx.send(());
            }
        }
        {
            let mut map = self.user_inner.write().await;
            map.remove(&user_id);
        }
        {
            let mut txs = self.user_shutdown_txs.write().await;
            txs.remove(&user_id);
        }
    }

    /// Returns a notification receiver for `user_id`, creating the channel on first use.
    pub async fn subscribe_user_notification(
        &self,
        user_id: Uuid,
    ) -> broadcast::Receiver<Arc<NotificationEvent>> {
        let mut map = self.user_notification_inner.write().await;
        if let Some(sender) = map.get(&user_id) {
            return sender.subscribe();
        }
        let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
        map.insert(user_id, tx);
        rx
    }

    /// Removes the user's notification channel.
    ///
    /// There is no subscriber counting here, so this closes the channel for every
    /// receiver obtained for `user_id`.
    pub async fn unsubscribe_user_notification(&self, user_id: Uuid) {
        let mut map = self.user_notification_inner.write().await;
        map.remove(&user_id);
    }

    /// Delivers `event` to the user's notification channel, if one exists.
    pub async fn push_user_notification(&self, user_id: Uuid, event: Arc<NotificationEvent>) {
        let map = self.user_notification_inner.read().await;
        if let Some(sender) = map.get(&user_id) {
            let _ = sender.send(event);
        }
    }

    /// Subscribes to the message events of `room_id`, creating the room channel on first use.
    ///
    /// Every successful call adds one to the room's subscriber count and to the
    /// `users_online` gauge, mirroring the per-call decrement in `unsubscribe`.
    ///
    /// # Errors
    /// `RoomError::RateLimited` when a new channel would be needed but the channel map
    /// already holds `MAX_CONNECTIONS_PER_ROOM` entries.
    pub async fn subscribe(
        &self,
        room_id: Uuid,
        _user_id: Uuid,
    ) -> Result<broadcast::Receiver<Arc<RoomMessageEvent>>, RoomError> {
        // Obtain the receiver while the write lock is held so the channel cannot be
        // removed between the lookup and the `subscribe()` call.
        let rx = {
            let mut map = self.room_inner.write().await;
            match map.get(&room_id) {
                Some(sender) => sender.subscribe(),
                None => {
                    // The limit is compared with the number of room channels, not with
                    // the number of subscribers of this room.
                    if map.len() >= MAX_CONNECTIONS_PER_ROOM {
                        return Err(RoomError::RateLimited(format!(
                            "Room connection limit reached ({})",
                            MAX_CONNECTIONS_PER_ROOM
                        )));
                    }
                    let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
                    map.insert(room_id, tx);
                    rx
                }
            }
        };

        {
            let mut counts = self.room_subscriber_count.write().await;
            *counts.entry(room_id).or_insert(0) += 1;
        }
        // One increment per subscriber keeps the gauge balanced with `unsubscribe`.
        self.metrics.users_online.increment(1.0);
        Ok(rx)
    }

    /// Releases one subscription of `room_id`; the room channel is removed when the
    /// subscriber count reaches zero (or no count is recorded).
    pub async fn unsubscribe(&self, room_id: Uuid, _user_id: Uuid) {
        let mut counts = self.room_subscriber_count.write().await;
        let remaining = match counts.get_mut(&room_id) {
            Some(count) if *count > 0 => {
                *count -= 1;
                self.metrics.users_online.decrement(1.0);
                *count
            }
            _ => 0,
        };
        if remaining == 0 {
            counts.remove(&room_id);
            drop(counts);
            let mut map = self.room_inner.write().await;
            map.remove(&room_id);
        }
    }

    /// Records room activity and fans `event` out to the room's subscribers.
    ///
    /// A failed send (no live receiver) is counted in `broadcasts_dropped`.
    pub async fn broadcast(&self, room_id: Uuid, event: RoomMessageEvent) {
        {
            let mut activity = self.room_last_activity.write().await;
            activity.insert(room_id, Instant::now());
        }
        let map = self.room_inner.read().await;
        if let Some(sender) = map.get(&room_id) {
            if sender.send(Arc::new(event)).is_err() {
                self.metrics.broadcasts_dropped.increment(1);
            }
        }
    }

    /// Subscribes to the events of `project_id`, creating the project channel on first use.
    ///
    /// # Errors
    /// `RoomError::RateLimited` when a new channel would be needed but the channel map
    /// already holds `MAX_CONNECTIONS_PER_PROJECT` entries.
    pub async fn subscribe_project(
        &self,
        project_id: Uuid,
        _user_id: Uuid,
    ) -> Result<broadcast::Receiver<Arc<ProjectRoomEvent>>, RoomError> {
        // Same shape as `subscribe`: take the receiver under the write lock.
        let rx = {
            let mut map = self.project_inner.write().await;
            match map.get(&project_id) {
                Some(sender) => sender.subscribe(),
                None => {
                    if map.len() >= MAX_CONNECTIONS_PER_PROJECT {
                        return Err(RoomError::RateLimited(format!(
                            "Project connection limit reached ({})",
                            MAX_CONNECTIONS_PER_PROJECT
                        )));
                    }
                    let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
                    map.insert(project_id, tx);
                    rx
                }
            }
        };

        let mut counts = self.project_subscriber_count.write().await;
        *counts.entry(project_id).or_insert(0) += 1;
        Ok(rx)
    }

    /// Releases one subscription of `project_id`; the project channel is removed when the
    /// subscriber count reaches zero (or no count is recorded).
    pub async fn unsubscribe_project(&self, project_id: Uuid, _user_id: Uuid) {
        let mut counts = self.project_subscriber_count.write().await;
        let remaining = match counts.get_mut(&project_id) {
            Some(count) if *count > 0 => {
                *count -= 1;
                *count
            }
            _ => 0,
        };
        if remaining == 0 {
            counts.remove(&project_id);
            drop(counts);
            let mut map = self.project_inner.write().await;
            map.remove(&project_id);
        }
    }

    /// Fans `event` out to the project's subscribers; a failed send is counted in
    /// `broadcasts_dropped`.
    pub async fn broadcast_project(&self, project_id: Uuid, event: ProjectRoomEvent) {
        let map = self.project_inner.read().await;
        if let Some(sender) = map.get(&project_id) {
            if sender.send(Arc::new(event)).is_err() {
                self.metrics.broadcasts_dropped.increment(1);
            }
        }
    }

    /// Broadcast an agent task event to all WS clients subscribed to this project.
pub async fn broadcast_agent_task(&self, project_id: Uuid, event: AgentTaskEvent) { let map = self.task_inner.read().await; if let Some(sender) = map.get(&project_id) { let event = Arc::new(event); if sender.send(event).is_err() { self.metrics.broadcasts_dropped.increment(1); } } } /// Subscribe to agent task events for a project. /// Returns a broadcast receiver that yields task events as they occur. pub async fn subscribe_task_events( &self, project_id: Uuid, ) -> Result>, RoomError> { let mut map = self.task_inner.write().await; if let Some(sender) = map.get(&project_id).cloned() { drop(map); return Ok(sender.subscribe()); } let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY); map.insert(project_id, tx); Ok(rx) } pub async fn subscribe_user( &self, user_id: Uuid, ) -> Result>, RoomError> { let mut map = self.user_inner.write().await; if let Some(_sender) = map.get(&user_id) { drop(map); let mut counts = self.user_subscriber_count.write().await; *counts.entry(user_id).or_insert(0) += 1; let map = self.user_inner.read().await; if let Some(sender) = map.get(&user_id) { return Ok(sender.subscribe()); } return Err(RoomError::Internal("user channel disappeared".into())); } if map.len() >= MAX_CONNECTIONS_PER_USER { return Err(RoomError::RateLimited(format!( "User connection limit reached ({})", MAX_CONNECTIONS_PER_USER ))); } let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY); map.insert(user_id, tx); drop(map); let mut counts = self.user_subscriber_count.write().await; counts.insert(user_id, 1); self.metrics.users_online.increment(1.0); Ok(rx) } pub async fn unsubscribe_user(&self, user_id: Uuid) { let mut counts = self.user_subscriber_count.write().await; let count = counts.entry(user_id).or_insert(0); if *count > 0 { *count -= 1; self.metrics.users_online.decrement(1.0); } if *count == 0 { counts.remove(&user_id); drop(counts); let mut map = self.user_inner.write().await; map.remove(&user_id); } } pub async fn broadcast_to_user(&self, user_id: Uuid, event: 
ProjectRoomEvent) { let map = self.user_inner.read().await; if let Some(sender) = map.get(&user_id) { let event = Arc::new(event); if sender.send(event).is_err() { self.metrics.broadcasts_dropped.increment(1); } } } pub async fn register_stream_channel( &self, message_id: Uuid, ) -> broadcast::Receiver> { let mut map = self.stream_inner.write().await; if let Some(tx) = map.get(&message_id) { return tx.subscribe(); } let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY); map.insert(message_id, tx); rx } pub async fn subscribe_stream( &self, message_id: Uuid, ) -> Option>> { let map = self.stream_inner.read().await; map.get(&message_id).map(|tx| tx.subscribe()) } pub async fn subscribe_room_stream( &self, room_id: Uuid, ) -> broadcast::Receiver> { let mut map = self.room_stream_inner.write().await; if let Some(tx) = map.get(&room_id) { return tx.subscribe(); } let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY); map.insert(room_id, tx); rx } pub async fn broadcast_stream_chunk(&self, event: RoomMessageStreamChunkEvent) { // Update activity tracker to prevent idle cleanup during active streaming { let mut activity = self.room_last_activity.write().await; activity.insert(event.room_id, Instant::now()); } let event = Arc::new(event); let is_final_chunk = event.done; let map = self.stream_inner.read().await; if let Some(tx) = map.get(&event.message_id) { let _ = tx.send(Arc::clone(&event)); } drop(map); let map = self.room_stream_inner.read().await; if let Some(tx) = map.get(&event.room_id) { let _ = tx.send(Arc::clone(&event)); } if is_final_chunk { drop(map); let mut map = self.room_stream_inner.write().await; map.remove(&event.room_id); } } pub async fn close_stream_channel(&self, message_id: Uuid) { let mut map = self.stream_inner.write().await; map.remove(&message_id); } pub async fn subscribe_typing( &self, room_id: Uuid, ) -> broadcast::Receiver> { let mut map: tokio::sync::RwLockWriteGuard<'_, std::collections::HashMap>>> = self.typing_inner.write().await; let 
tx = map.entry(room_id).or_insert_with(|| { let (tx, _) = broadcast::channel(BROADCAST_CAPACITY); tx }); // Replay active typing state from Redis to the new subscriber. // This ensures newly connected WS clients see who is currently typing. let active_events = self.get_active_typing_events(room_id).await; for event in active_events { let _ = tx.send(Arc::new(event)); } tx.subscribe() } /// Broadcast a typing event and persist it to Redis with 60s TTL. /// - "start": writes key with 60s expiry, broadcasts start event /// - "stop": deletes key, broadcasts stop event pub async fn broadcast_typing(&self, room_id: Uuid, event: TypingEvent) { let user_key = format!("typing:{}:{}", room_id, event.user_id); let action = event.action.clone(); let username = event.username.clone(); let avatar_url = event.avatar_url.clone(); let sender_type = event.sender_type.clone().unwrap_or_else(|| "user".to_string()); // Write/delete Redis key for 60s expiry (non-blocking) if let Ok(mut conn) = self.cache.conn().await { let key = user_key; tokio::spawn(async move { if action == "start" { let value = serde_json::json!({ "username": username, "avatar_url": avatar_url, "sender_type": sender_type, }) .to_string(); let _: Result<(), _> = redis::cmd("SETEX") .arg(&key) .arg(60i64) .arg(&value) .query_async(&mut conn) .await; } else { let _: Result<(), _> = redis::cmd("DEL").arg(&key).query_async(&mut conn).await; } }); } let map: tokio::sync::RwLockReadGuard<'_, std::collections::HashMap>>> = self.typing_inner.read().await; if let Some(tx) = map.get(&room_id) { let event = Arc::new(event); let _ = tx.send(event); } } /// Load all active typing entries for a room from Redis and return as TypingEvents. /// Used to replay current typing state to newly connected WS clients. 
pub async fn get_active_typing_events(&self, room_id: Uuid) -> Vec<TypingEvent> {
    let Ok(mut conn) = self.cache.conn().await else {
        return Vec::new();
    };

    // NOTE(review): KEYS walks the whole keyspace on the Redis server; consider a
    // SCAN-based lookup or a per-room set if the keyspace is large.
    let pattern = format!("typing:{}:*", room_id);
    let keys: Vec<String> = match redis::cmd("KEYS").arg(&pattern).query_async(&mut conn).await {
        Ok(keys) => keys,
        Err(_) => return Vec::new(),
    };

    let mut results = Vec::with_capacity(keys.len());
    for key in keys {
        // Key layout is `typing:{room_id}:{user_id}`; skip keys whose third segment is
        // not a UUID before spending a round-trip on them.
        let Some(user_id) = key.split(':').nth(2).and_then(|s| Uuid::parse_str(s).ok()) else {
            continue;
        };
        // The key may have expired since KEYS ran; a missing value is simply skipped.
        let Ok(value) = redis::cmd("GET")
            .arg(&key)
            .query_async::<String>(&mut conn)
            .await
        else {
            continue;
        };
        let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&value) else {
            continue;
        };

        let text = |field: &str| parsed.get(field).and_then(|v| v.as_str()).map(String::from);
        results.push(TypingEvent {
            room_id,
            user_id,
            username: text("username").unwrap_or_default(),
            avatar_url: text("avatar_url"),
            action: "start".to_string(),
            sender_type: text("sender_type"),
        });
    }
    results
}
} // end of `impl RoomConnectionManager`

/// Maps the wire representation of a sender type to the database enum.
///
/// Unrecognised values (and `"member"` itself) map to `MessageSenderType::Member`.
fn parse_sender_type(s: &str) -> MessageSenderType {
    match s {
        "admin" => MessageSenderType::Admin,
        "owner" => MessageSenderType::Owner,
        "ai" => MessageSenderType::Ai,
        "system" => MessageSenderType::System,
        "tool" => MessageSenderType::Tool,
        "guest" => MessageSenderType::Guest,
        _ => MessageSenderType::Member,
    }
}

/// Maps the wire representation of a content type to the database enum.
///
/// Unrecognised values (and `"text"` itself) map to `MessageContentType::Text`.
fn parse_content_type(s: &str) -> MessageContentType {
    match s {
        "image" => MessageContentType::Image,
        "audio" => MessageContentType::Audio,
        "video" => MessageContentType::Video,
        "file" => MessageContentType::File,
        _ => MessageContentType::Text,
    }
}

/// Callback that persists a batch of message envelopes.
///
/// NOTE(review): the future's output type was lost in the source;
/// `anyhow::Result<()>` is assumed because the body uses `?` and returns `Ok(())` --
/// confirm against the original definition.
pub type PersistFn = Arc<
    dyn Fn(Vec<RoomMessageEnvelope>) -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send>>
        + Send
        + Sync,
>;

use dashmap::DashMap;

/// Message id -> time at which the id was recorded as persisted.
pub type DedupCache = Arc<DashMap<Uuid, Instant>>;

/// How long an id stays in the dedup cache.
const DEDUP_CACHE_TTL: Duration = Duration::from_secs(300);

/// Evicts dedup-cache entries older than `DEDUP_CACHE_TTL`.
pub fn cleanup_dedup_cache(cache: &DedupCache) {
    // Comparing elapsed time avoids computing `Instant::now() - TTL`, which panics
    // when the result is not representable.
    cache.retain(|_, inserted_at| inserted_at.elapsed() < DEDUP_CACHE_TTL);
}

/// Builds the persistence callback used to store room messages.
///
/// For every batch of up to `BATCH_SIZE` envelopes the callback:
/// 1. skips ids already present in `dedup_cache`,
/// 2. skips ids that already exist in the `room_message` table,
/// 3. inserts the remaining rows with one `insert_many`,
/// 4. refreshes `content_tsv` for the inserted rows (best effort, failures are logged).
///
/// Newly inserted ids are added to `dedup_cache` only after the insert succeeded, so a
/// failed batch can be retried without being mistaken for duplicates.
pub fn make_persist_fn(
    db: AppDatabase,
    metrics: Arc<RoomMetrics>,
    dedup_cache: DedupCache,
) -> PersistFn {
    Arc::new(move |envelopes: Vec<RoomMessageEnvelope>| {
        let db = db.clone();
        let metrics = metrics.clone();
        let cache = dedup_cache.clone();
        Box::pin(async move {
            for chunk in envelopes.chunks(BATCH_SIZE) {
                // Pass 1: drop cache hits and repeated ids inside this chunk.
                let mut seen = std::collections::HashSet::with_capacity(chunk.len());
                let mut candidates = Vec::with_capacity(chunk.len());
                for env in chunk {
                    if cache.contains_key(&env.id) {
                        metrics.incr_duplicates_skipped();
                        continue;
                    }
                    if seen.insert(env.id) {
                        candidates.push(env);
                    }
                }
                if candidates.is_empty() {
                    continue;
                }

                // Pass 2: ask the database which of the candidates already exist.
                let candidate_ids: Vec<Uuid> = candidates.iter().map(|env| env.id).collect();
                let existing_ids: std::collections::HashSet<Uuid> = room_message::Entity::find()
                    .filter(room_message::Column::Id.is_in(candidate_ids))
                    .all(&db)
                    .await?
                    .into_iter()
                    .map(|m| m.id)
                    .collect();

                let mut inserted_ids = Vec::with_capacity(candidates.len());
                let mut models_to_insert = Vec::with_capacity(candidates.len());
                for env in candidates {
                    if existing_ids.contains(&env.id) {
                        // Already persisted: safe to remember right away.
                        cache.insert(env.id, Instant::now());
                        metrics.incr_duplicates_skipped();
                        continue;
                    }
                    inserted_ids.push(env.id);
                    models_to_insert.push(room_message::ActiveModel {
                        id: Set(env.id),
                        seq: Set(env.seq),
                        room: Set(env.room_id),
                        sender_type: Set(parse_sender_type(&env.sender_type)),
                        sender_id: Set(env.sender_id),
                        model_id: Set(env.model_id),
                        thread: Set(env.thread_id),
                        content: Set(env.content.clone()),
                        content_type: Set(parse_content_type(&env.content_type)),
                        edited_at: Set(None),
                        send_at: Set(env.send_at.clone()),
                        revoked: Set(None),
                        revoked_by: Set(None),
                        in_reply_to: Set(env.in_reply_to),
                    });
                }
                if models_to_insert.is_empty() {
                    continue;
                }

                let count = models_to_insert.len() as u64;
                room_message::Entity::insert_many(models_to_insert)
                    .exec(&db)
                    .await?;

                // Only now are the rows durable, so only now may they count as seen.
                let persisted_at = Instant::now();
                for id in &inserted_ids {
                    cache.insert(*id, persisted_at);
                }

                // Batch update content_tsv using a single UPDATE instead of one
                // statement per row. The ids are UUIDs rendered by their `Display`
                // impl, so the interpolated list cannot carry SQL syntax.
                let id_list = inserted_ids
                    .iter()
                    .map(|id| format!("'{}'", id))
                    .collect::<Vec<String>>()
                    .join(",");
                let batch_sql = format!(
                    "UPDATE room_message AS t \
                     SET content_tsv = to_tsvector('simple', content) \
                     WHERE t.id IN ({})",
                    id_list
                );
                let stmt = sea_orm::Statement::from_sql_and_values(
                    sea_orm::DbBackend::Postgres,
                    &batch_sql,
                    vec![],
                );
                // Best effort: the rows are already stored, so a failure here is
                // logged rather than propagated.
                if let Err(e) = db.execute_raw(stmt).await {
                    tracing::warn!(error = %e, "failed to update content_tsv for persisted messages");
                }

                metrics.messages_persisted.increment(count);
            }
            Ok(())
        })
    })
}

/// Future that resolves to a Redis connection.
///
/// NOTE(review): the connection type was lost in the source and cannot be inferred
/// from this file; `redis::aio::MultiplexedConnection` is a placeholder guess. Replace
/// it with the type that `queue::MessageProducer::get_redis` actually yields.
pub type RedisFuture =
    Pin<Box<dyn Future<Output = anyhow::Result<redis::aio::MultiplexedConnection>> + Send>>;

/// Wraps the producer's `get_redis` hook into a reusable factory of `RedisFuture`s.
///
/// Each returned future awaits the handle produced by `get_redis`; a failed join is
/// reported as "redis pool task panicked".
pub fn extract_get_redis(
    producer: queue::MessageProducer,
) -> Arc<dyn Fn() -> RedisFuture + Send + Sync> {
    Arc::new(move || {
        let get_redis_fn = producer.get_redis.clone();
        Box::pin(async move {
            match get_redis_fn().await {
                Ok(conn) => conn,
                Err(_) => anyhow::bail!("redis pool task panicked"),
            }
        }) as RedisFuture
    })
}

/// Spawns a dedicated OS thread that subscribes to the Redis Pub/Sub `channel` and
/// forwards every payload to `relay_tx`, reconnecting after connection failures.
///
/// The thread stops when a shutdown signal is observed, when the shutdown channel is
/// closed (its sender was dropped), or when the receiving side of `relay_tx` is gone.
/// `_on_msg` is accepted for interface compatibility and is not invoked.
///
/// # Panics
/// Panics if the thread cannot be spawned.
fn start_pubsub_thread<F, Fut>(
    redis_url: String,
    channel: String,
    relay_tx: tokio::sync::mpsc::Sender<Vec<u8>>,
    mut shutdown_rx: broadcast::Receiver<()>,
    _on_msg: F,
) where
    F: Fn(Vec<u8>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ()> + Send,
{
    /// Pause between reconnection attempts.
    const RETRY_DELAY: Duration = Duration::from_secs(1);
    /// Upper bound on how long one wait for a message may delay a shutdown check.
    const POLL_INTERVAL: Duration = Duration::from_millis(500);

    /// True when the thread should exit: a signal was sent (`Ok` / `Lagged`), the
    /// shutdown sender was dropped (`Closed`), or nobody reads the relay any more.
    fn stop_requested(
        shutdown_rx: &mut broadcast::Receiver<()>,
        relay_tx: &tokio::sync::mpsc::Sender<Vec<u8>>,
    ) -> bool {
        let signalled = !matches!(
            shutdown_rx.try_recv(),
            Err(broadcast::error::TryRecvError::Empty)
        );
        signalled || relay_tx.is_closed()
    }

    thread::Builder::new()
        .name(format!("redis-pubsub-{}", &channel[..channel.len().min(32)]))
        .spawn(move || {
            let rt = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .expect("pubsub thread runtime");
            rt.block_on(async {
                loop {
                    if stop_requested(&mut shutdown_rx, &relay_tx) {
                        tracing::info!(channel = %channel, "pubsub thread shutting down before connect");
                        break;
                    }

                    let client = match redis::Client::open(redis_url.as_str()) {
                        Ok(c) => c,
                        Err(e) => {
                            tracing::error!(channel = %channel, error = %e, "pubsub redis client open failed");
                            tokio::time::sleep(RETRY_DELAY).await;
                            continue;
                        }
                    };
                    let mut pubsub = match client.get_async_pubsub().await {
                        Ok(p) => p,
                        Err(e) => {
                            tracing::error!(channel = %channel, error = %e, "pubsub connection failed");
                            tokio::time::sleep(RETRY_DELAY).await;
                            continue;
                        }
                    };
                    match pubsub.subscribe(&channel).await {
                        Ok(_) => tracing::info!(channel = %channel, "pubsub subscribed"),
                        Err(e) => {
                            tracing::error!(channel = %channel, error = %e, "pubsub subscribe failed");
                            tokio::time::sleep(RETRY_DELAY).await;
                            continue;
                        }
                    }

                    let mut stream = pubsub.on_message();
                    loop {
                        if stop_requested(&mut shutdown_rx, &relay_tx) {
                            tracing::info!(channel = %channel, "pubsub thread shutting down");
                            return;
                        }
                        // Bounded wait so the shutdown check above runs regularly.
                        let msg = tokio::time::timeout(
                            POLL_INTERVAL,
                            futures::StreamExt::next(&mut stream),
                        )
                        .await;
                        match msg {
                            Ok(Some(msg)) => {
                                let payload = msg.get_payload_bytes();
                                tracing::debug!(channel = %channel, len = payload.len(), "pubsub received");
                                if relay_tx.send(payload.to_vec()).await.is_err() {
                                    tracing::warn!(channel = %channel, "pubsub relay channel closed");
                                    return;
                                }
                            }
                            Ok(None) => {
                                tracing::warn!(channel = %channel, "pubsub stream ended, will reconnect");
                                break;
                            }
                            // Timeout: nothing arrived, go back to the shutdown check.
                            Err(_) => {}
                        }
                    }
                    tracing::warn!(channel = %channel, "pubsub connection lost, reconnecting");
                }
            });
        })
        .expect("pubsub thread spawn");
}

/// Subscribes to Redis Pub/Sub `room:pub:{room_id}` and relays every decoded
/// `RoomMessageEvent` to `RoomConnectionManager::broadcast()`.
///
/// Runs until `shutdown_rx` yields (a signal or a closed channel) or the relay from
/// the pubsub thread is closed. Payloads that fail to decode are logged and skipped.
pub async fn subscribe_room_events(
    redis_url: String,
    manager: Arc<RoomConnectionManager>,
    room_id: Uuid,
    mut shutdown_rx: broadcast::Receiver<()>,
) {
    let channel = format!("room:pub:{}", room_id);
    let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(1024);
    tracing::info!(room_id = %room_id, channel = %channel, "starting room pubsub subscriber");

    start_pubsub_thread(redis_url, channel, tx, shutdown_rx.resubscribe(), |_| async {});

    loop {
        tokio::select! {
            _ = shutdown_rx.recv() => {
                tracing::info!(room_id = %room_id, "room subscriber shutting down");
                break;
            }
            payload = rx.recv() => {
                let Some(data) = payload else {
                    tracing::warn!(room_id = %room_id, "pubsub relay channel closed");
                    break;
                };
                match serde_json::from_slice::<RoomMessageEvent>(&data) {
                    Ok(event) => manager.broadcast(room_id, event).await,
                    Err(e) => tracing::warn!(error = %e, "malformed RoomMessageEvent"),
                }
            }
        }
    }
    tracing::info!(room_id = %room_id, "room subscriber stopped");
}

/// Subscribes to Redis Pub/Sub `project:pub:{project_id}` and relays every decoded
/// `ProjectRoomEvent` to `RoomConnectionManager::broadcast_project()`.
///
/// Runs until `shutdown_rx` yields (a signal or a closed channel) or the relay from
/// the pubsub thread is closed. Payloads that fail to decode are logged and skipped.
pub async fn subscribe_project_room_events(
    redis_url: String,
    manager: Arc<RoomConnectionManager>,
    project_id: Uuid,
    mut shutdown_rx: broadcast::Receiver<()>,
) {
    let channel = format!("project:pub:{}", project_id);
    let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(1024);
    tracing::info!(project_id = %project_id, channel = %channel, "starting project pubsub subscriber");

    start_pubsub_thread(redis_url, channel, tx, shutdown_rx.resubscribe(), |_| async {});

    loop {
        tokio::select! {
            _ = shutdown_rx.recv() => {
                tracing::info!(project_id = %project_id, "project subscriber shutting down");
                break;
            }
            payload = rx.recv() => {
                let Some(data) = payload else {
                    tracing::warn!(project_id = %project_id, "project pubsub relay channel closed");
                    break;
                };
                match serde_json::from_slice::<ProjectRoomEvent>(&data) {
                    Ok(event) => manager.broadcast_project(project_id, event).await,
                    Err(e) => tracing::warn!(error = %e, "malformed ProjectRoomEvent"),
                }
            }
        }
    }
    tracing::info!(project_id = %project_id, "project subscriber stopped");
}

/// Subscribe to Redis Pub/Sub `task:pub:{project_id}` and relay events to
/// `RoomConnectionManager::broadcast_agent_task()` so all WS clients get notified.
pub async fn subscribe_task_events_fn( redis_url: String, manager: Arc, project_id: Uuid, mut shutdown_rx: broadcast::Receiver<()>, ) { let channel = format!("task:pub:{}", project_id); let (tx, mut rx) = tokio::sync::mpsc::channel::>(1024); tracing::info!(project_id = %project_id, channel = %channel, "starting task pubsub subscriber"); let thread_channel = channel.clone(); let thread_shutdown = shutdown_rx.resubscribe(); start_pubsub_thread( redis_url, thread_channel, tx, thread_shutdown, |_| async {}, ); loop { tokio::select! { _ = shutdown_rx.recv() => { tracing::info!(project_id = %project_id, "task subscriber shutting down"); break; } payload = rx.recv() => { match payload { Some(data) => { match serde_json::from_slice::(&data) { Ok(event) => { manager.broadcast_agent_task(project_id, event).await; } Err(e) => { tracing::warn!(error = %e, "malformed AgentTaskEvent"); } } } None => { tracing::warn!(project_id = %project_id, "task pubsub relay channel closed"); break; } } } } } tracing::info!(project_id = %project_id, "task subscriber stopped"); }