/* Channel bus built on Socket.IO: tracks connected sockets per user, keeps each socket's room membership in sync, broadcasts room-scoped events and replays missed messages on connect. NOTE(review): this text looks whitespace-collapsed and several generic argument lists appear to have been stripped (for example `Arc,`, `RwLock>>`, `ChannelResult>`); confirm against the original file before building. */ use std::{ collections::{HashMap, HashSet}, sync::Arc, }; use cache::AppCache; use dashmap::DashMap; use db::AppDatabase; use model::room::RoomMessageModel; use serde::Deserialize; use serde::Serialize; use socketio::{Socket, SocketIo}; use tokio::sync::{Mutex, RwLock}; use tracing::warn; use uuid::Uuid; use crate::{ ChannelBusConfig, ChannelError, ChannelResult, circuit_breaker::CircuitBreaker, dedup::DeduplicationManager, event::ChannelEvent, metrics::ChannelMetrics, reconnect::ReconnectManager, rooms::{ active_workspace_users, catchup_messages, refresh_user_rooms_cache, room_socket_name, room_workspace, user_rooms, }, security::{CsrfProtection, RateLimiter}, seq::SeqAllocator, }; /** Event name used for room broadcasts (`publish_event`) and for catch-up replays (`catchup`). */ const ROOM_MESSAGE_EVENT: &str = "room.message"; /** Cloneable handle to the bus; all state is reached through `inner`. */ #[derive(Clone)] pub struct ChannelBus { pub(crate) inner: Arc, } /** State behind a `ChannelBus`: storage handles, the Socket.IO instance, configuration, live-connection bookkeeping and the helper components built in `ChannelBus::new`. */ pub(crate) struct Inner { pub(crate) db: AppDatabase, pub(crate) cache: AppCache, pub(crate) io: SocketIo, pub(crate) config: ChannelBusConfig, /** Connected sockets per user: keyed by user id, then by socket id (see `register_socket`). */ pub(crate) online: RwLock>>, /** One async mutex per user, taken by `sync_online_user_rooms`; the entry is removed when the user's last socket disconnects. */ pub(crate) user_sync_locks: DashMap>>, /** Not read or written in this part of the file. NOTE(review): the key is presumably a (user, room) pair and the token presumably cancels a typing timeout; confirm at the use site. */ pub(crate) typing_states: DashMap<(Uuid, Uuid), (crate::event::UserInfo, crate::event::RoomInfo, tokio_util::sync::CancellationToken)>, pub(crate) seq: SeqAllocator, pub(crate) dedup: DeduplicationManager, pub(crate) metrics: ChannelMetrics, pub(crate) reconnect: ReconnectManager, pub(crate) rate_limiter: RateLimiter, pub(crate) csrf: CsrfProtection, pub(crate) circuit_breaker: CircuitBreaker, } /** Handshake auth payload as parsed by `handle_connect`. `last_seq` is looked up by room id in `catchup`; a missing field deserializes to an empty map. */ #[derive(Debug, Deserialize)] struct ConnectAuth { #[serde(default)] last_seq: HashMap, } impl ChannelBus { /** Returns the Socket.IO instance held by this bus. */ pub fn io(&self) -> &SocketIo { &self.inner.io } /** Returns one workspace id for which `user` has a `wk_member` row with `leave_at IS NULL`, or `None` when there is no such row. The query uses `LIMIT 1` without `ORDER BY`, so which workspace comes back is left to the database. Database errors propagate through `?`. */ pub async fn first_workspace_id( &self, user: Uuid, ) -> ChannelResult> { let row = db::sqlx::query_as::<_, (Uuid,)>( "SELECT wk FROM wk_member WHERE \"user\" = $1 AND leave_at IS NULL LIMIT 1", ) .bind(user) .fetch_optional(self.inner.db.reader()) .await?; Ok(row.map(|r| r.0)) } /** Loads the name of `room`. When no row matches, the returned `RoomInfo` carries an empty name rather than an error. */ pub async fn lookup_room( &self, room: Uuid, ) -> ChannelResult { let row = db::sqlx::query_as::<_, (String,)>( "SELECT name FROM room WHERE id 
= $1", ) .bind(room) .fetch_optional(self.inner.db.reader()) .await? .map(|(name,)| name) .unwrap_or_default(); /* despite its name, `row` holds the room name string at this point */ Ok(crate::event::RoomInfo { id: room, name: row, }) } /** Lists the members of `workspace` whose `leave_at` is NULL as (id, username, display_name, avatar_url) tuples, ordered by username. */ pub async fn list_workspace_members( &self, workspace: Uuid, ) -> ChannelResult> { let rows = db::sqlx::query_as::<_, (Uuid, String, String, String)>( r#"SELECT u.id, u.username, u.display_name, u.avatar_url FROM wk_member wm JOIN "user" u ON u.id = wm."user" WHERE wm.wk = $1 AND wm.leave_at IS NULL ORDER BY u.username"#, ) .bind(workspace) .fetch_all(self.inner.db.reader()) .await?; Ok(rows) } /** Loads the name and avatar URL of workspace `wk`. An unknown workspace yields empty strings for both fields rather than an error. NOTE(review): with sqlx, `Row::get` panics when a column cannot be decoded (for example a NULL `avatar_url`); confirm both columns are non-null. */ pub async fn lookup_workspace( &self, wk: Uuid, ) -> ChannelResult { use db::sqlx::Row; let row = db::sqlx::query( "SELECT name, avatar_url FROM workspace WHERE id = $1", ) .bind(wk) .fetch_optional(self.inner.db.reader()) .await?; let (name, avatar_url) = match row { Some(r) => (r.get::(0), r.get::(1)), None => (String::new(), String::new()), }; Ok(crate::event::WorkspaceInfo { id: wk, name, avatar_url, }) } /** Batch lookup of users by id. An empty `users` slice returns an empty map without touching the database; ids with no matching row are simply absent from the result. NOTE(review): the backslash followed by a space inside the SQL literal looks like a collapsed backslash-newline continuation; confirm against the original file. */ pub async fn lookup_users( &self, users: &[Uuid], ) -> ChannelResult> { if users.is_empty() { return Ok(std::collections::HashMap::new()); } let rows = db::sqlx::query_as::<_, model::users::UserModel>( "SELECT id, username, display_name, avatar_url, website_url, \ allow_use, can_search, last_sign_in_at, created_at, updated_at \ FROM \"user\" WHERE id = ANY($1)", ) .bind(users) .fetch_all(self.inner.db.reader()) .await?; Ok(rows .into_iter() .map(|m| (m.id, crate::event::UserInfo::from_model(&m))) .collect()) } /** Single-user lookup. A missing row is not an error: it yields `UserInfo::unknown(user)`. Database errors propagate through `?`. */ pub async fn lookup_user( &self, user: Uuid, ) -> ChannelResult { let row = db::sqlx::query_as::<_, model::users::UserModel>( "SELECT id, username, display_name, avatar_url, website_url, \ allow_use, can_search, last_sign_in_at, created_at, updated_at \ FROM \"user\" WHERE id = $1", ) .bind(user) .fetch_optional(self.inner.db.reader()) .await? 
.map(|m| crate::event::UserInfo::from_model(&m)) .unwrap_or_else(|| crate::event::UserInfo::unknown(user)); Ok(row) } /** Thin wrapper: forwards to `crate::rooms::user_rooms_for_api` with this bus's database, cache and config. */ pub async fn list_user_rooms( &self, user: Uuid, ) -> ChannelResult> { crate::rooms::user_rooms_for_api( &self.inner.db, &self.inner.cache, &self.inner.config, user, ) .await } /** Thin wrapper: forwards to `crate::rooms::user_categories_for_api` with this bus's database, cache and config. */ pub async fn list_user_categories( &self, user: Uuid, ) -> ChannelResult> { crate::rooms::user_categories_for_api( &self.inner.db, &self.inner.cache, &self.inner.config, user, ) .await } /** Builds a bus and its helper components from `config`. Optional settings fall back as follows: the sequence allocator uses its default constructor when `seq_segment_size` is unset; the dedup window defaults to 300 seconds; the rate limiter is only configured when both of its settings are present, and the circuit breaker only when all four are present, otherwise their default constructors are used. No handlers are registered here; see `attach`. */ pub fn new( db: AppDatabase, cache: AppCache, io: SocketIo, config: ChannelBusConfig, ) -> Self { let seq = match config.seq_segment_size { Some(size) => { SeqAllocator::with_segment_size(cache.clone(), db.clone(), size) } None => SeqAllocator::new(cache.clone(), db.clone()), }; let dedup = DeduplicationManager::with_config( cache.clone(), std::time::Duration::from_secs( config.dedup_window_secs.unwrap_or(300), ), ); let reconnect = ReconnectManager::new(cache.clone(), db.clone()); /* a partially specified rate-limit config is ignored as a whole */ let rate_limiter = match ( config.rate_limit_max_requests, config.rate_limit_window_secs, ) { (Some(max), Some(secs)) => RateLimiter::with_config( cache.clone(), max, std::time::Duration::from_secs(secs), ), _ => RateLimiter::new(cache.clone()), }; let csrf = CsrfProtection::new(cache.clone()); /* likewise, all four circuit-breaker settings must be present to take effect */ let circuit_breaker = match ( config.circuit_breaker_failure_threshold, config.circuit_breaker_success_threshold, config.circuit_breaker_timeout_secs, config.circuit_breaker_half_open_max_calls, ) { (Some(failure), Some(success), Some(secs), Some(half_open)) => { CircuitBreaker::with_config( failure, success, std::time::Duration::from_secs(secs), half_open, ) } _ => CircuitBreaker::new(), }; Self { inner: Arc::new(Inner { db, cache, io, config, online: RwLock::new(HashMap::new()), user_sync_locks: DashMap::new(), typing_states: DashMap::new(), seq, dedup, metrics: ChannelMetrics::new(), reconnect, rate_limiter, csrf, circuit_breaker, }), } } /** Registers the bus on the configured namespace: an auth middleware, a connect handler, a disconnect handler, and finally the message handler from `crate::http::ws`. The middleware accepts a socket that already has a session user; otherwise it requires a string `access_token` in the auth payload, validates it with `check_access_token` and stores the resulting user id on the socket. A missing token is rejected as "unauthorized" and a failed validation as "token invalid or expired". Only the message-handler registration can make this function return an error. */ pub async fn attach(&self) -> ChannelResult<()> { let namespace = 
self.inner.io.namespace(&self.inner.config.namespace).await; let auth_bus = self.clone(); namespace .use_middleware(move |socket, auth| { let bus = auth_bus.clone(); async move { /* already authenticated by the session: nothing more to check */ if socket.session_user().is_some() { return Ok(()); } let token = auth .as_ref() .and_then(|v| v.get("access_token")) .and_then(|v| v.as_str()); if let Some(token) = token { /* the underlying validation error is discarded and replaced by a generic message */ let ctx = bus .check_access_token(token.to_owned()) .await .map_err(|_| { socketio::SocketIoError::Adapter( "token invalid or expired".to_owned(), ) })?; socket.set_user(ctx.user_id); return Ok(()); } Err(socketio::SocketIoError::Adapter( "unauthorized".to_owned(), )) } }) .await; let on_connect_bus = self.clone(); namespace .on_connect(move |socket| { let bus = on_connect_bus.clone(); async move { /* counted before setup runs, so a socket that fails `handle_connect` is still counted here */ bus.inner.metrics.increment_connections(); if let Err(error) = bus.handle_connect(socket.clone()).await { warn!(%error, "channel socket connect failed, disconnecting"); let _ = socket.disconnect().await; } } }) .await; let on_disconnect_bus = self.clone(); namespace .on_disconnect(move |socket, _reason| { let bus = on_disconnect_bus.clone(); async move { bus.inner.metrics.decrement_connections(); bus.handle_disconnect(&socket).await; } }) .await; crate::http::ws::register_message_handler(self).await?; Ok(()) } /** Broadcasts a newly created room message. The message is first passed to the dedup manager keyed by (message id, room); when it reports the message as already seen, the call returns `Ok(())` without broadcasting. When `sender` is provided it is attached to the event. */ pub async fn publish_room_message( &self, message: RoomMessageModel, sender: Option, ) -> ChannelResult<()> { let is_new = self .inner .dedup .check_and_mark(message.id, message.room) .await?; if !is_new { return Ok(()); } let event = match sender { Some(s) => ChannelEvent::message_created_with_sender(message, s), None => ChannelEvent::message_created(message), }; self.publish_event(event).await } /** Broadcasts a custom event to `room`. `payload` is converted with `serde_json::to_value`; a serialization failure is returned to the caller, while broadcast failures are absorbed by `publish_event`. */ pub async fn publish_room_event( &self, room: Uuid, event_type: impl Into, payload: T, ) -> ChannelResult<()> where T: Serialize, { let payload = serde_json::to_value(payload)?; self.publish_event(ChannelEvent::custom(room, event_type, payload)) .await } /** Emits `event` with `data` to every socket currently registered for `user`. The socket list is cloned out of `online` within a single statement, so the read guard is not held while emitting. Sockets are emitted to one after another and the first emit error returns immediately, so later sockets are skipped in that case. A user with no sockets is a no-op returning `Ok(())`. */ pub async fn emit_to_user( &self, user: Uuid, event: &str, data: &T, ) 
-> ChannelResult<()> { let sockets = self .inner .online .read() .await .get(&user) .map(|sockets| sockets.values().cloned().collect::>()) .unwrap_or_default(); for socket in sockets { socket.emit(event, data).await?; } Ok(()) } /** Recomputes the room list for `user` through `refresh_user_rooms_cache` and then brings the user's connected sockets in line with it. */ pub async fn refresh_user(&self, user: Uuid) -> ChannelResult<()> { let rooms = refresh_user_rooms_cache( &self.inner.db, &self.inner.cache, &self.inner.config, user, ) .await?; self.sync_online_user_rooms(user, &rooms).await } /** Refreshes every user returned by `active_workspace_users` for `wk`. All refreshes are driven together with `join_all` and every one runs to completion; each failure is logged, and the first failure in result order is returned after all have finished. */ pub async fn workspace_changed(&self, wk: Uuid) -> ChannelResult<()> { let users = active_workspace_users(&self.inner.db, wk).await?; let bus = self.clone(); let results = futures::future::join_all(users.into_iter().map(|user| { let bus = bus.clone(); async move { bus.refresh_user(user).await } })) .await; let mut first_error = None; for result in results { if let Err(e) = result { tracing::warn!(error = %e, "workspace refresh failed for user"); if first_error.is_none() { first_error = Some(e); } } } if let Some(e) = first_error { Err(e) } else { Ok(()) } } /** Treats a room change as a change of its workspace: when `room_workspace` finds a workspace, `workspace_changed` is run for it; when it finds none, nothing happens and `Ok(())` is returned. */ pub async fn room_changed(&self, room: Uuid) -> ChannelResult<()> { if let Some(wk) = room_workspace(&self.inner.db, room).await? { self.workspace_changed(wk).await?; } Ok(()) } /** Emits `event` under `ROOM_MESSAGE_EVENT` to the socket room named by `room_socket_name(event.room)`. Always returns `Ok(())`: an emit failure is logged and counted but not propagated. Metrics: `increment_sent` on every call, then `increment_received` on success or `increment_failed` on error. */ async fn publish_event(&self, event: ChannelEvent) -> ChannelResult<()> { self.inner.metrics.increment_sent(); // Best-effort broadcast — individual socket failures are expected // (sockets disconnect) and should not block all broadcasts. 
let result = self .inner .io .namespace(&self.inner.config.namespace) .await .to(room_socket_name(event.room)) .emit(ROOM_MESSAGE_EVENT, event) .await; match result { Ok(()) => { self.inner.metrics.increment_received(); Ok(()) } Err(e) => { tracing::warn!(error = %e, "WS broadcast failed"); self.inner.metrics.increment_failed(); /* NOTE(review): as laid out here, the line comment that follows extends to the end of this physical line, so the code after it on this line is commented out until the original line breaks are restored. */ Ok(()) // best-effort: don't propagate broadcast errors } } } /** Per-socket setup run from the connect handler. Steps, in order: require a session user (else `ChannelError::Unauthorized`); apply the "connect" rate limit (else `ChannelError::RateLimitExceeded`); parse the handshake auth into `ConnectAuth`, falling back to an empty `last_seq` when auth is absent or does not deserialize; join the socket to each of the user's rooms; register the socket in `online`; replay missed messages through `catchup`. A failure in any step returns early, so a socket whose room join fails is never registered. */ async fn handle_connect(&self, socket: Socket) -> ChannelResult<()> { let user = socket.session_user().ok_or(ChannelError::Unauthorized)?; if !self .inner .rate_limiter .check_rate_limit(user, "connect") .await? { return Err(ChannelError::RateLimitExceeded); } let auth = socket .auth() .await .and_then(|value| serde_json::from_value::(value).ok()) .unwrap_or_else(|| ConnectAuth { last_seq: HashMap::new(), }); let rooms = user_rooms( &self.inner.db, &self.inner.cache, &self.inner.config, user, ) .await?; for room in &rooms { socket.join(room_socket_name(*room)).await?; } self.register_socket(user, socket.clone()).await; self.catchup(&socket, &rooms, &auth.last_seq).await?; Ok(()) } /** Per-socket teardown run from the disconnect handler. Does nothing for a socket without a session user. For every joined room whose name is "room:" followed by a parseable UUID, a "voice.channel_left" event is published to that room (result ignored); the code does not check whether the user was actually in a voice channel or still has other sockets connected. The socket is then removed from `online`; when it was the user's last socket, the user's `online` entry and sync lock are removed too. */ async fn handle_disconnect(&self, socket: &Socket) { let Some(user) = socket.session_user() else { return; }; let rooms = socket.rooms().await; for room_name in &rooms { if let Some(room_str) = room_name.strip_prefix("room:") { if let Ok(room_id) = Uuid::parse_str(room_str) { let _ = self.publish_room_event( room_id, "voice.channel_left", serde_json::json!({"user_id": user, "disconnected": true}), ) .await; } } } let mut online = self.inner.online.write().await; if let Some(sockets) = online.get_mut(&user) { sockets.remove(socket.id()); if sockets.is_empty() { online.remove(&user); self.inner.user_sync_locks.remove(&user); } } } /** Records `socket` under `user` in `online`, keyed by the socket's id; an existing entry with the same id is replaced. */ async fn register_socket(&self, user: Uuid, socket: Socket) { self.inner .online .write() .await .entry(user) .or_default() .insert(socket.id().to_owned(), socket); } /** Makes every connected socket of `user` a member of exactly the socket rooms derived from `desired_rooms`. Runs under the user's entry in `user_sync_locks`, created on first use. For each socket, only current rooms whose name starts with "room:" are considered: desired rooms that are missing are joined and extra ones are left. The first join or leave error returns immediately, leaving any remaining sockets unsynced. NOTE(review): this presumes `room_socket_name` yields names starting with "room:"; confirm in `crate::rooms`. */ async fn sync_online_user_rooms( &self, user: Uuid, desired_rooms: &[Uuid], ) -> ChannelResult<()> { let lock = 
self .inner .user_sync_locks .entry(user) .or_insert_with(|| Arc::new(Mutex::new(()))) .clone(); /* the lock handle is cloned out of the map first, then awaited; `_guard` is held until the function returns */ let _guard = lock.lock().await; let sockets = self .inner .online .read() .await .get(&user) .map(|sockets| sockets.values().cloned().collect::>()) .unwrap_or_default(); let desired = desired_rooms .iter() .map(|room| room_socket_name(*room)) .collect::>(); for socket in sockets { let current = socket .rooms() .await .into_iter() .filter(|room| room.starts_with("room:")) .collect::>(); /* NOTE(review): `¤t` below looks like a mis-encoded `&current`; confirm against the original file. */ for room in desired.difference(¤t) { socket.join(room.clone()).await?; } for room in current.difference(&desired) { socket.leave(room).await?; } } Ok(()) } /** Replays missed messages to a single socket. Only rooms that have an entry in `last_seq` are processed; the stored value is passed to `catchup_messages` and rooms without an entry are skipped. Each returned message is emitted to `socket` alone under `ROOM_MESSAGE_EVENT`. The author is looked up with one `lookup_user` call per message; when that lookup fails the event is sent without sender information. Errors from `catchup_messages` or from an emit abort the whole replay. */ async fn catchup( &self, socket: &Socket, rooms: &[Uuid], last_seq: &HashMap, ) -> ChannelResult<()> { for room in rooms { let Some(seq) = last_seq.get(room) else { continue; }; let messages = catchup_messages( &self.inner.db, &self.inner.config, *room, *seq, ) .await?; for message in messages { /* a failed author lookup degrades to a sender-less event instead of failing the replay */ let sender = match self.lookup_user(message.author).await { Ok(s) => Some(s), Err(_) => None, }; let event = match sender { Some(s) => ChannelEvent::message_created_with_sender(message, s), None => ChannelEvent::message_created(message), }; socket.emit(ROOM_MESSAGE_EVENT, event).await?; } } Ok(()) } }