use std::panic::AssertUnwindSafe; use std::sync::Arc; use std::time::Instant; use actix_web::{HttpRequest, HttpResponse, web}; use actix_ws::Message as WsMessage; use futures_util::FutureExt; use uuid::Uuid; use service::AppService; use super::inbound::MessageHandler; use super::poll::{poll_notifications, poll_subscriptions}; use super::session::{ HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT, MAX_IDLE_TIMEOUT, MAX_MESSAGES_PER_SECOND, MAX_TEXT_MESSAGE_LEN, TransportSession, WsUserCtx, }; use super::types::{WsInMessage, WsOutEvent}; /// Universal WebSocket endpoint: `/ws` /// /// Protocol: /// - Inbound: JSON `WsInMessage` (tagged enum with `type` field) /// - Outbound: JSON `WsOutEvent` (tagged enum with `type` field) /// - Heartbeat: client sends `{"type":"ping"}`, server replies `{"type":"pong"}` /// - Binary frames are rejected /// - Rate limit: 1000 messages/sec per connection pub async fn ws_handler( service: web::Data, req: HttpRequest, stream: web::Payload, ) -> Result { let auth_ctx = authenticate_ws(&service, &req).await?; let user_id = auth_ctx.user_id; // Resolve display name for this user (cached in WsUserCtx for typing events, etc.) let display_name = { use models::users::user as user_model; use sea_orm::{ColumnTrait, EntityTrait, QueryFilter}; let db = &service.db; user_model::Entity::find() .filter(user_model::Column::Uid.eq(user_id)) .one(db) .await .ok() .flatten() .map(|u| u.display_name.unwrap_or_else(|| u.username)) .unwrap_or_else(|| user_id.to_string()) }; tracing::info!(user_id = %user_id, display_name = %display_name, "WS transport connection established"); let service_arc = Arc::new(service.get_ref().clone()); let manager = service_arc.room.room_manager.clone(); manager.metrics.ws_connections_active.increment(1.0); manager.metrics.ws_connections_total.increment(1); let mut notif_rx = manager.subscribe_user_notification(user_id).await; let mut shutdown_rx = manager.subscribe_shutdown(); let (response, mut ws_session, mut msg_stream) = actix_ws::handle(&req, stream)?; let spawn_handle = actix::spawn(async move { let panic_result = AssertUnwindSafe(async { let session = TransportSession::new( WsUserCtx { user_id, device_id: auth_ctx.device_id, client_id: auth_ctx.client_id, display_name, }, service_arc, ); // Split state for tokio::select! borrow safety let mut last_heartbeat = Instant::now(); let mut last_activity = Instant::now(); let mut message_count: u32 = 0; let mut rate_window_start = Instant::now(); let mut heartbeat_interval = tokio::time::interval(HEARTBEAT_INTERVAL); heartbeat_interval.tick().await; loop { tokio::select! { // ── Heartbeat ── _ = heartbeat_interval.tick() => { if last_heartbeat.elapsed() > HEARTBEAT_TIMEOUT { tracing::warn!(user_id = %user_id, "WS transport heartbeat timeout"); manager.metrics.ws_heartbeat_timeout_total.increment(1); let _ = ws_session.close(Some(actix_ws::CloseCode::Policy.into())).await; break; } if last_activity.elapsed() > MAX_IDLE_TIMEOUT { tracing::info!(user_id = %user_id, "WS transport idle timeout"); manager.metrics.ws_idle_timeout_total.increment(1); let _ = ws_session.close(Some(actix_ws::CloseCode::Normal.into())).await; break; } if ws_session.ping(b"").await.is_err() { break; } manager.metrics.ws_heartbeat_sent_total.increment(1); } // ── Shutdown ── _ = shutdown_rx.recv() => { tracing::info!("WS transport shutdown"); let _ = ws_session.close(Some(actix_ws::CloseCode::Normal.into())).await; break; } // ── Notification push ── notif = poll_notifications(&mut notif_rx) => { if let Some(event) = notif { if send_event(&mut ws_session, &event).await.is_err() { break; } } } // ── Room broadcast push ── push = poll_subscriptions(&session) => { if let Some(event) = push { if send_event(&mut ws_session, &event).await.is_err() { break; } } } // ── Inbound client message ── msg = msg_stream.recv() => { match msg { Some(Ok(WsMessage::Ping(bytes))) => { if ws_session.pong(&bytes).await.is_err() { break; } last_heartbeat = Instant::now(); } Some(Ok(WsMessage::Pong(_))) => { last_heartbeat = Instant::now(); } Some(Ok(WsMessage::Text(text))) => { if last_activity.elapsed() > MAX_IDLE_TIMEOUT { break; } last_activity = Instant::now(); last_heartbeat = Instant::now(); // Rate limit if rate_window_start.elapsed() > super::session::RATE_LIMIT_WINDOW { message_count = 0; rate_window_start = Instant::now(); } message_count += 1; if message_count > MAX_MESSAGES_PER_SECOND { let _ = ws_session.text(serde_json::json!({ "type": "error", "error": "rate_limit_exceeded" }).to_string()).await; continue; } if text.len() > MAX_TEXT_MESSAGE_LEN { let _ = ws_session.text(serde_json::json!({ "type": "error", "error": "message_too_long" }).to_string()).await; continue; } // Parse once — extract request_id and deserialize together. let json_value = serde_json::from_str::(&text); // Application-level JSON ping (distinguish from WebSocket Ping frame) if text.trim() == r#"{"type":"ping"}"# || text.trim() == r#"{"type":"ping","_request_id":null}"# { if ws_session.text(r#"{"type":"pong"}"#).await.is_err() { break; } continue; } // Extract _request_id from the Value, then deserialize WsInMessage let request_id: Option = json_value .ok() .and_then(|v| v.get("_request_id") .and_then(|r| serde_json::from_value(r.clone()).ok())); match serde_json::from_str::(&text) { Ok(in_msg) => { match MessageHandler::handle(&session, in_msg).await { Ok(Some(event)) => { let rid = request_id.unwrap_or(Uuid::nil()); let resp = WsOutEvent::Response { request_id: rid, data: serde_json::to_value(&event).unwrap_or_default(), }; if send_event(&mut ws_session, &resp).await.is_err() { break; } } Ok(None) => { let rid = request_id.unwrap_or(Uuid::nil()); let ack = WsOutEvent::Response { request_id: rid, data: serde_json::json!({"ok": true}), }; if send_event(&mut ws_session, &ack).await.is_err() { break; } } Err(e) => { tracing::warn!(user_id = %user_id, error = %e, "WS message processing failed"); let rid = request_id.unwrap_or(Uuid::nil()); let (code, error_type) = e.ws_error_code(); let err_json = serde_json::json!({ "type": "error", "code": code, "error": error_type, "message": e.to_string(), "_request_id": rid }); let _ = ws_session.text(err_json.to_string()).await; } } } Err(e) => { tracing::warn!(error = %e, "WS transport parse error"); let _ = ws_session.text(serde_json::json!({ "type": "error", "error": "parse_error" }).to_string()).await; } } } Some(Ok(WsMessage::Binary(_))) => { let _ = ws_session.close(Some(actix_ws::CloseCode::Unsupported.into())).await; break; } Some(Ok(WsMessage::Continuation(_))) => {} Some(Ok(WsMessage::Nop)) => {} Some(Ok(WsMessage::Close(reason))) => { let _ = ws_session.close(reason).await; break; } Some(Err(e)) => { tracing::warn!(error = %e, "WS transport error"); let _ = ws_session.close(Some(actix_ws::CloseCode::Protocol.into())).await; break; } None => break, } } } } // Cleanup for sub in session.subscriptions.iter() { manager.unsubscribe(sub.room_id, user_id).await; } manager.unsubscribe_user_notification(user_id).await; // Remove presence entry so disconnected users don't appear online for up to 10 minutes let project_id = session.project_id.lock().await; session.service.room.remove_user_presence(user_id, *project_id).await; manager.metrics.ws_connections_active.decrement(1.0); manager.metrics.ws_disconnections_total.increment(1); }).catch_unwind(); if let Err(panic_err) = panic_result.await { let panic_msg = if let Some(s) = panic_err.downcast_ref::() { s.clone() } else if let Some(s) = panic_err.downcast_ref::<&str>() { s.to_string() } else { "Unknown panic".to_string() }; tracing::error!(user_id = %user_id, panic = %panic_msg, "WS transport task panicked"); manager.metrics.ws_connections_active.decrement(1.0); manager.metrics.ws_disconnections_total.increment(1); } }); // Drop the handle intentionally — cleanup is handled inside the spawned task drop(spawn_handle); Ok(response) } async fn send_event(ws_session: &mut actix_ws::Session, event: &WsOutEvent) -> Result<(), ()> { match serde_json::to_string(event) { Ok(json) => ws_session.text(json).await.map_err(|_| {}), Err(e) => { tracing::error!(error = %e, "WS transport serialize error"); Err(()) } } } async fn authenticate_ws( service: &AppService, req: &HttpRequest, ) -> Result { // Prefer Authorization header over query parameter if let Some(auth_header) = req.headers().get("Authorization") { if let Ok(auth_str) = auth_header.to_str() { if let Some(token) = auth_str.strip_prefix("Bearer ") { match service.ws_token.validate_token_ctx(token).await { Ok(ctx) => { return Ok(crate::token::AppTransportTokenContext { user_id: ctx.user_id, device_id: ctx.device_id.unwrap_or_default(), client_id: ctx.client_id.unwrap_or_default(), }); } Err(_) => { service .room .room_manager .metrics .ws_auth_failures .increment(1); return Err(actix_web::error::ErrorUnauthorized("token auth failed")); } } } } } // Fallback: token in query string (deprecated, kept for backward compatibility) if let Some(token) = req.uri().query().and_then(|q| { q.split('&') .find(|p| p.starts_with("token=")) .and_then(|p| p.split('=').nth(1)) }) { match service.ws_token.validate_token_ctx(token).await { Ok(ctx) => { return Ok(crate::token::AppTransportTokenContext { user_id: ctx.user_id, device_id: ctx.device_id.unwrap_or_default(), client_id: ctx.client_id.unwrap_or_default(), }); } Err(_) => { service .room .room_manager .metrics .ws_auth_failures .increment(1); return Err(actix_web::error::ErrorUnauthorized("token auth failed")); } } } service .room .room_manager .metrics .ws_auth_failures .increment(1); Err(actix_web::error::ErrorUnauthorized("no auth provided")) }