use std::sync::{Arc, LazyLock}; use std::time::{Duration, Instant}; use actix_web::{HttpMessage, HttpRequest, HttpResponse, web}; use actix_ws::Message as WsMessage; use serde::Serialize; use uuid::Uuid; use queue::{ProjectRoomEvent, RoomMessageEvent, RoomMessageStreamChunkEvent}; use service::AppService; use session::Session; const MAX_TEXT_MESSAGE_LEN: usize = 64 * 1024; const MAX_MESSAGES_PER_SECOND: u32 = 10; const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30); const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(60); const MAX_IDLE_TIMEOUT: Duration = Duration::from_secs(300); const RATE_LIMIT_WINDOW: Duration = Duration::from_secs(1); /// Authenticate WebSocket request: try query parameter token first, then fall back to session. async fn authenticate_ws_request( service: &AppService, req: &HttpRequest, ) -> Result { // Try query parameter token first (one-time use via Redis) if let Some(token) = req.uri().query().and_then(|q| { q.split('&') .find(|p| p.starts_with("token=")) .and_then(|p| p.split('=').nth(1)) }) { match service.ws_token.validate_token(token).await { Ok(uid) => { slog::debug!(service.logs, "WS: token auth successful for uid={}", uid); return Ok(uid); } Err(_) => { slog::warn!(service.logs, "WS: token auth failed"); service .room .room_manager .metrics .ws_auth_failures .increment(1); return Err(crate::error::ApiError(service::error::AppError::Unauthorized).into()); } } } // Fall back to session-based auth let session = Session::get_session(&mut req.extensions_mut()); match session.user() { Some(uid) => Ok(uid), None => { service .room .room_manager .metrics .ws_auth_failures .increment(1); Err(crate::error::ApiError(service::error::AppError::Unauthorized).into()) } } } async fn check_ws_rate_limit( log: &slog::Logger, manager: &Arc, message_count: &mut u32, rate_window_start: &mut Instant, ) -> bool { if rate_window_start.elapsed() > RATE_LIMIT_WINDOW { *message_count = 0; *rate_window_start = Instant::now(); } *message_count += 1; if *message_count > MAX_MESSAGES_PER_SECOND { slog::warn!(log, "WS rate limit exceeded"); manager.metrics.ws_rate_limit_hits.increment(1); true } else { false } } #[derive(Clone, Serialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum WsEventPayload { RoomMessage(RoomMessagePayload), ProjectEvent(ProjectEventPayload), AiStreamChunk(AiStreamChunkPayload), } #[derive(Clone, Serialize)] pub struct AiStreamChunkPayload { pub message_id: Uuid, pub room_id: Uuid, pub content: String, pub done: bool, pub error: Option, } impl From for AiStreamChunkPayload { fn from(e: RoomMessageStreamChunkEvent) -> Self { Self { message_id: e.message_id, room_id: e.room_id, content: e.content, done: e.done, error: e.error, } } } impl From> for AiStreamChunkPayload { fn from(e: Arc) -> Self { AiStreamChunkPayload::from((&*e).clone()) } } #[derive(Clone, Serialize)] pub struct RoomMessagePayload { pub id: Uuid, pub room_id: Uuid, pub sender_type: String, pub sender_id: Option, pub thread_id: Option, pub content: String, pub content_type: String, pub send_at: chrono::DateTime, pub seq: i64, pub display_name: Option, } impl From for RoomMessagePayload { fn from(e: RoomMessageEvent) -> Self { Self { id: e.id, room_id: e.room_id, sender_type: e.sender_type, sender_id: e.sender_id, thread_id: e.thread_id, content: e.content, content_type: e.content_type, send_at: e.send_at, seq: e.seq, display_name: e.display_name, } } } impl From> for RoomMessagePayload { fn from(e: Arc) -> Self { RoomMessagePayload::from((&*e).clone()) } } impl From<&RoomMessageEvent> for RoomMessagePayload { fn from(e: &RoomMessageEvent) -> Self { Self { id: e.id, room_id: e.room_id, sender_type: e.sender_type.clone(), sender_id: e.sender_id, thread_id: e.thread_id, content: e.content.clone(), content_type: e.content_type.clone(), send_at: e.send_at, seq: e.seq, display_name: e.display_name.clone(), } } } #[derive(Clone, Serialize)] pub struct ProjectEventPayload { pub event_type: String, pub project_id: Uuid, pub room_id: Option, pub category_id: Option, pub message_id: Option, pub seq: Option, pub timestamp: chrono::DateTime, } impl From for ProjectEventPayload { fn from(e: ProjectRoomEvent) -> Self { Self { event_type: e.event_type, project_id: e.project_id, room_id: e.room_id, category_id: e.category_id, message_id: e.message_id, seq: e.seq, timestamp: e.timestamp, } } } impl From> for ProjectEventPayload { fn from(e: Arc) -> Self { ProjectEventPayload::from((&*e).clone()) } } impl From<&ProjectRoomEvent> for ProjectEventPayload { fn from(e: &ProjectRoomEvent) -> Self { Self { event_type: e.event_type.clone(), project_id: e.project_id, room_id: e.room_id, category_id: e.category_id, message_id: e.message_id, seq: e.seq, timestamp: e.timestamp, } } } #[derive(Clone, Serialize)] pub struct WsOutEvent { #[serde(skip_serializing_if = "Option::is_none")] pub room_id: Option, #[serde(skip_serializing_if = "Option::is_none")] pub project_id: Option, #[serde(skip_serializing_if = "Option::is_none")] pub event: Option, #[serde(skip_serializing_if = "Option::is_none")] pub error: Option, } pub(crate) fn validate_origin(req: &HttpRequest) -> bool { static ALLOWED_ORIGINS: LazyLock> = LazyLock::new(|| { std::env::var("WS_ALLOWED_ORIGINS") .map(|v| v.split(',').map(|s| s.trim().to_string()).collect()) .unwrap_or_else(|_| { vec![ "http://localhost".to_string(), "https://localhost".to_string(), "http://127.0.0.1".to_string(), "https://127.0.0.1".to_string(), "ws://localhost".to_string(), "wss://localhost".to_string(), "ws://127.0.0.1".to_string(), "wss://127.0.0.1".to_string(), ] }) }); let Some(origin) = req.headers().get("origin") else { return true; }; let Ok(origin_str) = origin.to_str() else { return false; }; // Exact match (with port) if ALLOWED_ORIGINS.iter().any(|allowed| origin_str == *allowed) { return true; } // Strip port: http://localhost:5173 -> http://localhost, http://[::1]:5173 -> http://[::1] let origin_without_port = if let Some((scheme_host, port)) = origin_str.rsplit_once(':') { if port.chars().all(|c| c.is_ascii_digit()) { scheme_host.to_string() } else { origin_str.to_string() } } else { origin_str.to_string() }; if ALLOWED_ORIGINS .iter() .any(|allowed| origin_without_port == *allowed) { return true; } // Also check if the full origin starts with any allowed prefix ALLOWED_ORIGINS .iter() .any(|allowed| origin_str.starts_with(allowed)) } pub async fn ws_room( room_id: web::Path, service: web::Data, req: HttpRequest, stream: web::Payload, ) -> Result { let room_id = room_id.into_inner(); // Authenticate: try query parameter token first, then session let user_id = authenticate_ws_request(&service, &req).await?; let origin_val = req .headers() .get("origin") .and_then(|v| v.to_str().ok()) .unwrap_or("(none)"); slog::debug!( service.logs, "WS room connection attempt user_id={} room_id={} origin={}", user_id, room_id, origin_val ); if !validate_origin(&req) { slog::warn!( service.logs, "WS room: origin rejected user_id={} room_id={} origin={}", user_id, room_id, origin_val ); service .room .room_manager .metrics .ws_auth_failures .increment(1); return Err(crate::error::ApiError(service::error::AppError::BadRequest( "Invalid origin".into(), )) .into()); } if let Err(e) = service.room.check_room_access(room_id, user_id).await { slog::warn!( service.logs, "WS room: access denied for user_id={} room_id={} error={}", user_id, room_id, e ); return Err(crate::error::ApiError::from(e).into()); } let manager = service.room.room_manager.clone(); manager.metrics.ws_connections_active.increment(1.0); manager.metrics.ws_connections_total.increment(1); manager.metrics.incr_room_connections(room_id).await; let (response, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?; actix::spawn(async move { let mut receiver = match manager.subscribe(room_id, user_id).await { Ok(r) => r, Err(e) => { slog::error!(service.logs, "Failed to subscribe to room: {}", e); return; } }; let mut stream_rx = manager.subscribe_room_stream(room_id).await; let mut shutdown_rx = manager.subscribe_shutdown(); let mut last_heartbeat = Instant::now(); let mut last_activity = Instant::now(); let mut heartbeat_interval = tokio::time::interval(HEARTBEAT_INTERVAL); heartbeat_interval.tick().await; let mut message_count: u32 = 0; let mut rate_window_start = Instant::now(); loop { tokio::select! { _ = heartbeat_interval.tick() => { if last_heartbeat.elapsed() > HEARTBEAT_TIMEOUT { slog::warn!(service.logs, "WS room {} heartbeat timeout for user {}", room_id, user_id); manager.metrics.ws_heartbeat_timeout_total.increment(1); let _ = session.close(Some(actix_ws::CloseCode::Policy.into())).await; break; } if last_activity.elapsed() > MAX_IDLE_TIMEOUT { slog::info!(service.logs, "WS room {} idle timeout for user {}", room_id, user_id); manager.metrics.ws_idle_timeout_total.increment(1); let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await; break; } if session.ping(b"").await.is_err() { break; } manager.metrics.ws_heartbeat_sent_total.increment(1); } _ = shutdown_rx.recv() => { slog::info!(service.logs, "WS room {} shutdown", room_id); let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await; break; } msg = msg_stream.recv() => { match msg { Some(Ok(WsMessage::Ping(bytes))) => { if session.pong(&bytes).await.is_err() { break; } last_heartbeat = Instant::now(); } Some(Ok(WsMessage::Pong(_))) => { last_heartbeat = Instant::now(); } #[allow(unused_assignments)] Some(Ok(WsMessage::Text(text))) => { if last_activity.elapsed() > MAX_IDLE_TIMEOUT { slog::info!(service.logs, "WS room {} idle timeout for user {}", room_id, user_id); manager.metrics.ws_idle_timeout_total.increment(1); let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await; break; } last_activity = Instant::now(); if check_ws_rate_limit(&service.logs, &manager, &mut message_count, &mut rate_window_start).await { let _ = session.text(serde_json::json!({ "type": "error", "error": "rate_limit_exceeded", "max_per_second": MAX_MESSAGES_PER_SECOND }).to_string()).await; break; } if text.len() > MAX_TEXT_MESSAGE_LEN { slog::warn!(service.logs, "WS room {} message too long from user {}: {} bytes", room_id, user_id, text.len()); let _ = session.text(serde_json::json!({ "type": "error", "error": "message_too_long", "max_bytes": MAX_TEXT_MESSAGE_LEN }).to_string()).await; break; } slog::warn!(service.logs, "WS room {} unexpected text message from user {} ({} bytes) — WS is push-only, use REST to send messages", room_id, user_id, text.len()); let _ = session.text(serde_json::json!({ "type": "error", "error": "ws_push_only", "message": "WebSocket is for receiving messages only. Use the REST API to send messages." }).to_string()).await; break; } Some(Ok(WsMessage::Binary(_))) => { if check_ws_rate_limit(&service.logs, &manager, &mut message_count, &mut rate_window_start).await { break; } slog::warn!(service.logs, "WS room {} unexpected binary from user {}", room_id, user_id); break; } Some(Ok(WsMessage::Close(reason))) => { let _ = session.close(reason).await; break; } Some(Ok(_)) => {} Some(Err(e)) => { slog::warn!(service.logs, "WS room error: {}", e); break; } None => break, } } event = receiver.recv() => { match event { Ok(event) => { let payload = WsOutEvent { room_id: Some(room_id), project_id: None, event: Some(WsEventPayload::RoomMessage(event.into())), error: None, }; match serde_json::to_string(&payload) { Ok(json) => { if session.text(json).await.is_err() { break; } } Err(e) => { slog::error!(service.logs, "WS serialize error: {}", e); break; } } } Err(_) => break, } } chunk_event = stream_rx.recv() => { match chunk_event { Ok(chunk) => { let payload = WsOutEvent { room_id: Some(room_id), project_id: None, event: Some(WsEventPayload::AiStreamChunk(chunk.into())), error: None, }; match serde_json::to_string(&payload) { Ok(json) => { if session.text(json).await.is_err() { break; } } Err(e) => { slog::error!(service.logs, "WS streaming serialize error: {}", e); } } } Err(_) => {} } } } } manager.unsubscribe(room_id, user_id).await; manager.metrics.ws_connections_active.decrement(1.0); manager.metrics.ws_disconnections_total.increment(1); manager.metrics.dec_room_connections(room_id).await; }); Ok(response) } pub async fn ws_project( project_id: web::Path, service: web::Data, req: HttpRequest, stream: web::Payload, ) -> Result { let project_id = project_id.into_inner(); // Authenticate: try query parameter token first, then session let user_id = authenticate_ws_request(&service, &req).await?; if !validate_origin(&req) { service .room .room_manager .metrics .ws_auth_failures .increment(1); return Err(crate::error::ApiError(service::error::AppError::BadRequest( "Invalid origin".into(), )) .into()); } if let Err(e) = service.room.check_project_member(project_id, user_id).await { service .room .room_manager .metrics .ws_auth_failures .increment(1); return Err(crate::error::ApiError::from(e).into()); } if let Err(e) = service .room .room_manager .check_project_connection_rate(project_id, user_id) .await { service .room .room_manager .metrics .ws_rate_limit_hits .increment(1); return Err(crate::error::ApiError::from(e).into()); } let manager = service.room.room_manager.clone(); manager.metrics.ws_connections_active.increment(1.0); manager.metrics.ws_connections_total.increment(1); let (response, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?; actix::spawn(async move { let mut receiver = match manager.subscribe_project(project_id, user_id).await { Ok(r) => r, Err(e) => { slog::error!(service.logs, "Failed to subscribe to project: {}", e); return; } }; let mut shutdown_rx = manager.subscribe_shutdown(); let mut last_heartbeat = Instant::now(); let mut last_activity = Instant::now(); let mut heartbeat_interval = tokio::time::interval(HEARTBEAT_INTERVAL); heartbeat_interval.tick().await; let mut message_count: u32 = 0; let mut rate_window_start = Instant::now(); loop { tokio::select! { _ = heartbeat_interval.tick() => { if last_heartbeat.elapsed() > HEARTBEAT_TIMEOUT { slog::warn!(service.logs, "WS project {} heartbeat timeout for user {}", project_id, user_id); manager.metrics.ws_heartbeat_timeout_total.increment(1); let _ = session.close(Some(actix_ws::CloseCode::Policy.into())).await; break; } if last_activity.elapsed() > MAX_IDLE_TIMEOUT { slog::info!(service.logs, "WS project {} idle timeout for user {}", project_id, user_id); manager.metrics.ws_idle_timeout_total.increment(1); let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await; break; } if session.ping(b"").await.is_err() { break; } manager.metrics.ws_heartbeat_sent_total.increment(1); } _ = shutdown_rx.recv() => { slog::info!(service.logs, "WS project {} shutdown", project_id); let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await; break; } msg = msg_stream.recv() => { match msg { Some(Ok(WsMessage::Ping(bytes))) => { if session.pong(&bytes).await.is_err() { break; } last_heartbeat = Instant::now(); } Some(Ok(WsMessage::Pong(_))) => { last_heartbeat = Instant::now(); } #[allow(unused_assignments)] Some(Ok(WsMessage::Text(text))) => { if last_activity.elapsed() > MAX_IDLE_TIMEOUT { slog::info!(service.logs, "WS project {} idle timeout for user {}", project_id, user_id); manager.metrics.ws_idle_timeout_total.increment(1); let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await; break; } last_activity = Instant::now(); slog::warn!(service.logs, "WS project {} unexpected text from user {} ({} bytes) — WS is push-only", project_id, user_id, text.len()); let _ = session.text(serde_json::json!({ "type": "error", "error": "ws_push_only", "message": "WebSocket is for receiving events only." }).to_string()).await; break; } Some(Ok(WsMessage::Binary(_))) => { if check_ws_rate_limit(&service.logs, &manager, &mut message_count, &mut rate_window_start).await { slog::warn!(service.logs, "WS project {} rate limit exceeded for user {}", project_id, user_id); let _ = session.text(serde_json::json!({ "type": "error", "error": "rate_limit_exceeded", "max_per_second": MAX_MESSAGES_PER_SECOND }).to_string()).await; break; } slog::warn!(service.logs, "WS project {} unexpected binary from user {}", project_id, user_id); break; } Some(Ok(WsMessage::Close(reason))) => { let _ = session.close(reason).await; break; } Some(Ok(_)) => {} Some(Err(e)) => { slog::warn!(service.logs, "WS project error: {}", e); break; } None => break, } } event = receiver.recv() => { match event { Ok(event) => { let payload = WsOutEvent { room_id: event.room_id, project_id: Some(project_id), event: Some(WsEventPayload::ProjectEvent(event.into())), error: None, }; match serde_json::to_string(&payload) { Ok(json) => { if session.text(json).await.is_err() { break; } } Err(e) => { slog::error!(service.logs, "WS serialize error: {}", e); break; } } } Err(_) => break, } } } } manager.unsubscribe_project(project_id, user_id).await; manager.metrics.ws_connections_active.decrement(1.0); manager.metrics.ws_disconnections_total.increment(1); }); Ok(response) }