use std::collections::HashMap; use std::sync::Arc; use std::time::{Duration, Instant}; use actix_web::{HttpRequest, HttpResponse, web}; use actix_ws::Message as WsMessage; use tokio_stream::StreamExt; use tokio_stream::wrappers::BroadcastStream; use uuid::Uuid; use crate::error::ApiError; use queue::{ReactionGroup, RoomMessageEvent, RoomMessageStreamChunkEvent}; use service::AppService; use super::ws::validate_origin; use super::ws_handler::WsRequestHandler; use super::ws_types::{WsAction, WsRequest, WsResponse, WsResponseData}; const MAX_TEXT_MESSAGE_LEN: usize = 64 * 1024; const MAX_MESSAGES_PER_SECOND: u32 = 10; const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30); const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(60); const MAX_IDLE_TIMEOUT: Duration = Duration::from_secs(300); const RATE_LIMIT_WINDOW: Duration = Duration::from_secs(1); /// Unified push event from any subscribed room. #[derive(Debug, Clone)] pub enum WsPushEvent { RoomMessage { room_id: Uuid, event: Arc, }, ReactionUpdated { room_id: Uuid, message_id: Uuid, reactions: Vec, }, AiStreamChunk { room_id: Uuid, chunk: Arc, }, } /// Maps room_id -> (room_message_broadcast_stream, stream_chunk_broadcast_stream) type PushStreams = HashMap< Uuid, ( BroadcastStream>, BroadcastStream>, ), >; pub async fn ws_universal( service: web::Data, req: HttpRequest, stream: web::Payload, ) -> Result { let origin_val = req .headers() .get("origin") .and_then(|v| v.to_str().ok()) .unwrap_or("(none)"); if !validate_origin(&req) { slog::warn!( service.logs, "WS universal: origin rejected origin={}", origin_val ); return Err(ApiError(service::error::AppError::BadRequest( "Invalid origin".into(), )) .into()); } // Validate token BEFORE actix_ws::handle() so we can return a proper HTTP // error if validation fails. Returning an HTTP error after handle() has been // called (even if the handler returns an error) sends a 200 OK on what the // browser expects to be a 101 Switching Protocols response — causing // immediate close with readyState=3. let user_id = if let Some(token) = req.uri().query().and_then(|q| { q.split('&') .find(|p| p.starts_with("token=")) .and_then(|p| p.split('=').nth(1)) }) { slog::info!( service.logs, "WS universal: validating token token={} origin={}", token, origin_val ); match service.ws_token.validate_token(token).await { Ok(uid) => { slog::info!( service.logs, "WS universal: token auth successful uid={} origin={}", uid, origin_val ); uid } Err(e) => { slog::warn!( service.logs, "WS universal: token auth failed: {:?} token={}", e, token ); service .room .room_manager .metrics .ws_auth_failures .increment(1); return Err(ApiError(service::error::AppError::Unauthorized).into()); } } } else { let auth_header = req .headers() .get("Authorization") .and_then(|v| v.to_str().ok()); let token = match auth_header { Some(h) if h.starts_with("Bearer ") => &h[7..], _ => { service .room .room_manager .metrics .ws_auth_failures .increment(1); return Err(ApiError(service::error::AppError::Unauthorized).into()); } }; match extract_user_id_from_token(token) { Some(id) => id, None => { service .room .room_manager .metrics .ws_auth_failures .increment(1); return Err(ApiError(service::error::AppError::Unauthorized).into()); } } }; slog::debug!( service.logs, "WS universal connection established user_id={} origin={}", user_id, origin_val ); let service = service.get_ref().clone(); let manager = service.room.room_manager.clone(); manager.metrics.ws_connections_active.increment(1.0); manager.metrics.ws_connections_total.increment(1); let logs = service.logs.clone(); let (response, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?; actix::spawn(async move { let handler = WsRequestHandler::new(Arc::new(service), user_id); let mut push_streams: PushStreams = HashMap::new(); let mut shutdown_rx = manager.subscribe_shutdown(); let mut last_heartbeat = Instant::now(); let mut last_activity = Instant::now(); let mut heartbeat_interval = tokio::time::interval(HEARTBEAT_INTERVAL); heartbeat_interval.tick().await; let mut message_count: u32 = 0; let mut rate_window_start = Instant::now(); loop { tokio::select! { _ = heartbeat_interval.tick() => { if last_heartbeat.elapsed() > HEARTBEAT_TIMEOUT { slog::warn!(logs, "WS universal heartbeat timeout for user {}", user_id); manager.metrics.ws_heartbeat_timeout_total.increment(1); let _ = session.close(Some(actix_ws::CloseCode::Policy.into())).await; break; } if last_activity.elapsed() > MAX_IDLE_TIMEOUT { slog::info!(logs, "WS universal idle timeout for user {}", user_id); manager.metrics.ws_idle_timeout_total.increment(1); let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await; break; } if session.ping(b"").await.is_err() { break; } manager.metrics.ws_heartbeat_sent_total.increment(1); } _ = shutdown_rx.recv() => { slog::info!(logs, "WS universal shutdown"); let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await; break; } push_event = poll_push_streams(&mut push_streams) => { match push_event { Some(WsPushEvent::RoomMessage { room_id, event }) => { let payload = serde_json::json!({ "type": "event", "event": "room.message", "room_id": room_id, "data": { "id": event.id, "room_id": event.room_id, "sender_type": event.sender_type, "sender_id": event.sender_id, "thread_id": event.thread_id, "content": event.content, "content_type": event.content_type, "send_at": event.send_at, "seq": event.seq, "display_name": event.display_name, }, }); if session.text(payload.to_string()).await.is_err() { break; } } Some(WsPushEvent::ReactionUpdated { room_id, message_id, reactions }) => { let payload = serde_json::json!({ "type": "event", "event": "room.reaction_updated", "room_id": room_id, "data": { "message_id": message_id, "reactions": reactions, }, }); if session.text(payload.to_string()).await.is_err() { break; } } Some(WsPushEvent::AiStreamChunk { room_id, chunk }) => { let payload = serde_json::json!({ "type": "event", "event": "ai.stream_chunk", "room_id": room_id, "data": { "message_id": chunk.message_id, "room_id": chunk.room_id, "content": chunk.content, "done": chunk.done, "error": chunk.error, }, }); if session.text(payload.to_string()).await.is_err() { break; } } None => { } } } msg = msg_stream.recv() => { match msg { Some(Ok(WsMessage::Ping(bytes))) => { if session.pong(&bytes).await.is_err() { break; } last_heartbeat = Instant::now(); } Some(Ok(WsMessage::Pong(_))) => { last_heartbeat = Instant::now(); } Some(Ok(WsMessage::Text(text))) => { if last_activity.elapsed() > MAX_IDLE_TIMEOUT { slog::info!(logs, "WS universal idle timeout for user {}", user_id); manager.metrics.ws_idle_timeout_total.increment(1); let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await; break; } last_activity = Instant::now(); if rate_window_start.elapsed() > RATE_LIMIT_WINDOW { message_count = 0; rate_window_start = Instant::now(); } message_count += 1; if message_count > MAX_MESSAGES_PER_SECOND { slog::warn!(logs, "WS universal rate limit exceeded for user {}", user_id); manager.metrics.ws_rate_limit_hits.increment(1); let _ = session.text(serde_json::json!({"type":"error","error":"rate_limit_exceeded"}).to_string()).await; continue; } if text.len() > MAX_TEXT_MESSAGE_LEN { slog::warn!(logs, "WS universal message too long from user {}: {} bytes", user_id, text.len()); let _ = session.text(serde_json::json!({"type":"error","error":"message_too_long"}).to_string()).await; continue; } match serde_json::from_str::(&text) { Ok(request) => { let action_str = request.action.to_string(); match request.action { WsAction::SubscribeRoom => { if let Some(room_id) = request.params().room_id { match manager.subscribe(room_id, user_id).await { Ok(rx) => { let stream_rx = manager.subscribe_room_stream(room_id).await; push_streams.insert(room_id, ( BroadcastStream::new(rx), BroadcastStream::new(stream_rx), )); let _ = session.text(serde_json::to_string(&WsResponse::success( request.request_id, &action_str, WsResponseData::subscribed(Some(room_id), None) )).unwrap_or_default()).await; } Err(e) => { let _ = session.text(serde_json::to_string(&WsResponse::error_response( request.request_id, &action_str, 403, "subscribe_failed", &format!("{}", e) )).unwrap_or_default()).await; } } } else { let _ = session.text(serde_json::to_string(&WsResponse::error_response( request.request_id, &action_str, 400, "bad_request", "room_id required" )).unwrap_or_default()).await; } } WsAction::UnsubscribeRoom => { if let Some(room_id) = request.params().room_id { manager.unsubscribe(room_id, user_id).await; push_streams.remove(&room_id); } let _ = session.text(serde_json::to_string(&WsResponse::success( request.request_id, &action_str, WsResponseData::bool(true) )).unwrap_or_default()).await; } _ => { let resp = handler.handle(request).await; let _ = session.text(serde_json::to_string(&resp).unwrap_or_default()).await; } } } Err(e) => { slog::warn!(logs, "WS universal parse error from user {}: {}", user_id, e); let _ = session.text(serde_json::json!({"type":"error","error":"parse_error"}).to_string()).await; } } } Some(Ok(WsMessage::Binary(_))) => { break; } Some(Ok(WsMessage::Continuation(_))) => {} Some(Ok(WsMessage::Nop)) => {} Some(Ok(WsMessage::Close(reason))) => { let _ = session.close(reason).await; break; } Some(Err(e)) => { slog::warn!(logs, "WS error: {}", e); break; } None => break, } } } } // Clean up subscriptions on disconnect for room_id in push_streams.keys() { manager.unsubscribe(*room_id, user_id).await; } manager.metrics.ws_connections_active.decrement(1.0); manager.metrics.ws_disconnections_total.increment(1); }); Ok(response) } async fn poll_push_streams(streams: &mut PushStreams) -> Option { loop { let room_ids: Vec = streams.keys().copied().collect(); for room_id in room_ids { if let Some((msg_stream, chunk_stream)) = streams.get_mut(&room_id) { tokio::select! { result = msg_stream.next() => { match result { Some(Ok(event)) => { if let Some(reactions) = event.reactions.clone() { return Some(WsPushEvent::ReactionUpdated { room_id: event.room_id, message_id: event.id, reactions, }); } return Some(WsPushEvent::RoomMessage { room_id, event }); } Some(Err(_)) => { streams.remove(&room_id); } None => { streams.remove(&room_id); } } } result = chunk_stream.next() => { match result { Some(Ok(chunk)) => { return Some(WsPushEvent::AiStreamChunk { room_id, chunk }); } Some(Err(_)) | None => { streams.remove(&room_id); } } } } } } if streams.is_empty() { tokio::time::sleep(std::time::Duration::from_millis(50)).await; return None; } tokio::task::yield_now().await; } } fn extract_user_id_from_token(token: &str) -> Option { if token.len() < 64 { return None; } let token_data = base64_decode(token)?; if token_data.len() < 16 { return None; } let bytes: [u8; 16] = token_data[..16].try_into().ok()?; Some(Uuid::from_bytes(bytes)) } fn base64_decode(input: &str) -> Option> { let table = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; let mut result = Vec::with_capacity(input.len() * 3 / 4); let mut buffer: u32 = 0; let mut bits = 0; for byte in input.bytes() { if byte == b'=' || byte == b'\n' || byte == b'\r' || byte == b' ' { continue; } let idx = table.iter().position(|&x| x == byte)?; buffer = (buffer << 6) | (idx as u32); bits += 6; if bits >= 8 { bits -= 8; result.push((buffer >> bits) as u8); } } Some(result) }