use std::sync::Arc; use std::time::Instant; use actix_web::{web, HttpRequest, HttpResponse}; use actix_ws::Message as WsMessage; use uuid::Uuid; use service::AppService; use super::inbound::MessageHandler; use super::poll::{poll_notifications, poll_subscriptions}; use super::session::{TransportSession, WsUserCtx, HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT, MAX_IDLE_TIMEOUT, MAX_TEXT_MESSAGE_LEN, MAX_MESSAGES_PER_SECOND}; use super::types::{WsInMessage, WsOutEvent}; /// Universal WebSocket endpoint: `/ws` /// /// Protocol: /// - Inbound: JSON `WsInMessage` (tagged enum with `type` field) /// - Outbound: JSON `WsOutEvent` (tagged enum with `type` field) /// - Heartbeat: client sends `{"type":"ping"}`, server replies `{"type":"pong"}` /// - Binary frames are rejected /// - Rate limit: 1000 messages/sec per connection pub async fn ws_handler( service: web::Data, req: HttpRequest, stream: web::Payload, ) -> Result { let user_id = authenticate_ws(&service, &req).await?; tracing::info!(user_id = %user_id, "WS transport connection established"); let service_arc = Arc::new(service.get_ref().clone()); let manager = service_arc.room.room_manager.clone(); manager.metrics.ws_connections_active.increment(1.0); manager.metrics.ws_connections_total.increment(1); let mut notif_rx = manager.subscribe_user_notification(user_id).await; let mut shutdown_rx = manager.subscribe_shutdown(); let (response, mut ws_session, mut msg_stream) = actix_ws::handle(&req, stream)?; actix::spawn(async move { let session = TransportSession::new( WsUserCtx { user_id, device_id: String::new(), client_id: String::new() }, service_arc, ); // Split state for tokio::select! borrow safety let mut last_heartbeat = Instant::now(); let mut last_activity = Instant::now(); let mut message_count: u32 = 0; let mut rate_window_start = Instant::now(); let mut heartbeat_interval = tokio::time::interval(HEARTBEAT_INTERVAL); heartbeat_interval.tick().await; loop { tokio::select! { // ── Heartbeat ── _ = heartbeat_interval.tick() => { if last_heartbeat.elapsed() > HEARTBEAT_TIMEOUT { tracing::warn!(user_id = %user_id, "WS transport heartbeat timeout"); manager.metrics.ws_heartbeat_timeout_total.increment(1); let _ = ws_session.close(Some(actix_ws::CloseCode::Policy.into())).await; break; } if last_activity.elapsed() > MAX_IDLE_TIMEOUT { tracing::info!(user_id = %user_id, "WS transport idle timeout"); manager.metrics.ws_idle_timeout_total.increment(1); let _ = ws_session.close(Some(actix_ws::CloseCode::Normal.into())).await; break; } if ws_session.ping(b"").await.is_err() { break; } manager.metrics.ws_heartbeat_sent_total.increment(1); } // ── Shutdown ── _ = shutdown_rx.recv() => { tracing::info!("WS transport shutdown"); let _ = ws_session.close(Some(actix_ws::CloseCode::Normal.into())).await; break; } // ── Notification push ── notif = poll_notifications(&mut notif_rx) => { if let Some(event) = notif { if send_event(&mut ws_session, &event).await.is_err() { break; } } } // ── Room broadcast push ── push = poll_subscriptions(&session) => { if let Some(event) = push { if send_event(&mut ws_session, &event).await.is_err() { break; } } } // ── Inbound client message ── msg = msg_stream.recv() => { match msg { Some(Ok(WsMessage::Ping(bytes))) => { if ws_session.pong(&bytes).await.is_err() { break; } last_heartbeat = Instant::now(); } Some(Ok(WsMessage::Pong(_))) => { last_heartbeat = Instant::now(); } Some(Ok(WsMessage::Text(text))) => { if last_activity.elapsed() > MAX_IDLE_TIMEOUT { break; } last_activity = Instant::now(); last_heartbeat = Instant::now(); // Rate limit if rate_window_start.elapsed() > super::session::RATE_LIMIT_WINDOW { message_count = 0; rate_window_start = Instant::now(); } message_count += 1; if message_count > MAX_MESSAGES_PER_SECOND { let _ = ws_session.text(serde_json::json!({ "type": "error", "error": "rate_limit_exceeded" }).to_string()).await; continue; } if text.len() > MAX_TEXT_MESSAGE_LEN { let _ = ws_session.text(serde_json::json!({ "type": "error", "error": "message_too_long" }).to_string()).await; continue; } // Application-level ping if text.trim() == r#"{"type":"ping"}"# { if ws_session.text(r#"{"type":"pong"}"#).await.is_err() { break; } continue; } match serde_json::from_str::(&text) { Ok(in_msg) => { if let Ok(response) = MessageHandler::handle(&session, in_msg).await { if let Some(event) = response { if send_event(&mut ws_session, &event).await.is_err() { break; } } } } Err(e) => { tracing::warn!(error = %e, "WS transport parse error"); let _ = ws_session.text(serde_json::json!({ "type": "error", "error": "parse_error" }).to_string()).await; } } } Some(Ok(WsMessage::Binary(_))) => { break; } Some(Ok(WsMessage::Continuation(_))) => {} Some(Ok(WsMessage::Nop)) => {} Some(Ok(WsMessage::Close(reason))) => { let _ = ws_session.close(reason).await; break; } Some(Err(e)) => { tracing::warn!(error = %e, "WS transport error"); break; } None => break, } } } } // Cleanup for sub in session.subscriptions.iter() { manager.unsubscribe(sub.room_id, user_id).await; } manager.unsubscribe_user_notification(user_id).await; manager.metrics.ws_connections_active.decrement(1.0); manager.metrics.ws_disconnections_total.increment(1); }); Ok(response) } async fn send_event(ws_session: &mut actix_ws::Session, event: &WsOutEvent) -> Result<(), ()> { match serde_json::to_string(event) { Ok(json) => ws_session.text(json).await.map_err(|_| {}), Err(e) => { tracing::error!(error = %e, "WS transport serialize error"); Err(()) } } } async fn authenticate_ws( service: &AppService, req: &HttpRequest, ) -> Result { if let Some(token) = req.uri().query().and_then(|q| { q.split('&').find(|p| p.starts_with("token=")).and_then(|p| p.split('=').nth(1)) }) { match service.ws_token.validate_token(token).await { Ok(uid) => return Ok(uid), Err(_) => { service.room.room_manager.metrics.ws_auth_failures.increment(1); return Err(actix_web::error::ErrorUnauthorized("token auth failed")); } } } service.room.room_manager.metrics.ws_auth_failures.increment(1); Err(actix_web::error::ErrorUnauthorized("no auth provided")) }