//! WebSocket transport handler (`gitdataai/libs/transport/handler/ws.rs`).

use std::panic::AssertUnwindSafe;
use std::sync::Arc;
use std::time::Instant;
use actix_web::{HttpRequest, HttpResponse, web};
use actix_ws::Message as WsMessage;
use futures_util::FutureExt;
use uuid::Uuid;
use service::AppService;
use super::inbound::MessageHandler;
use super::poll::{poll_notifications, poll_subscriptions};
use super::session::{
HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT, MAX_IDLE_TIMEOUT, MAX_MESSAGES_PER_SECOND,
MAX_TEXT_MESSAGE_LEN, TransportSession, WsUserCtx,
};
use super::types::{WsInMessage, WsOutEvent};
/// Universal WebSocket endpoint: `/ws`
///
/// Protocol:
/// - Inbound: JSON `WsInMessage` (tagged enum with `type` field)
/// - Outbound: JSON `WsOutEvent` (tagged enum with `type` field)
/// - Heartbeat: client sends `{"type":"ping"}`, server replies `{"type":"pong"}`
/// - Binary frames are rejected
/// - Rate limit: 1000 messages/sec per connection
pub async fn ws_handler(
service: web::Data<AppService>,
req: HttpRequest,
stream: web::Payload,
) -> Result<HttpResponse, actix_web::Error> {
let auth_ctx = authenticate_ws(&service, &req).await?;
let user_id = auth_ctx.user_id;
// Resolve display name for this user (cached in WsUserCtx for typing events, etc.)
let display_name = {
use models::users::user as user_model;
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
let db = &service.db;
user_model::Entity::find()
.filter(user_model::Column::Uid.eq(user_id))
.one(db)
.await
.ok()
.flatten()
.map(|u| u.display_name.unwrap_or_else(|| u.username))
.unwrap_or_else(|| user_id.to_string())
};
tracing::info!(user_id = %user_id, display_name = %display_name, "WS transport connection established");
let service_arc = Arc::new(service.get_ref().clone());
let manager = service_arc.room.room_manager.clone();
manager.metrics.ws_connections_active.increment(1.0);
manager.metrics.ws_connections_total.increment(1);
let mut notif_rx = manager.subscribe_user_notification(user_id).await;
let mut shutdown_rx = manager.subscribe_shutdown();
let (response, mut ws_session, mut msg_stream) = actix_ws::handle(&req, stream)?;
let spawn_handle = actix::spawn(async move {
let panic_result = AssertUnwindSafe(async {
let session = TransportSession::new(
WsUserCtx {
user_id,
device_id: auth_ctx.device_id,
client_id: auth_ctx.client_id,
display_name,
},
service_arc,
);
// Split state for tokio::select! borrow safety
let mut last_heartbeat = Instant::now();
let mut last_activity = Instant::now();
let mut message_count: u32 = 0;
let mut rate_window_start = Instant::now();
let mut heartbeat_interval = tokio::time::interval(HEARTBEAT_INTERVAL);
heartbeat_interval.tick().await;
loop {
tokio::select! {
// ── Heartbeat ──
_ = heartbeat_interval.tick() => {
if last_heartbeat.elapsed() > HEARTBEAT_TIMEOUT {
tracing::warn!(user_id = %user_id, "WS transport heartbeat timeout");
manager.metrics.ws_heartbeat_timeout_total.increment(1);
let _ = ws_session.close(Some(actix_ws::CloseCode::Policy.into())).await;
break;
}
if last_activity.elapsed() > MAX_IDLE_TIMEOUT {
tracing::info!(user_id = %user_id, "WS transport idle timeout");
manager.metrics.ws_idle_timeout_total.increment(1);
let _ = ws_session.close(Some(actix_ws::CloseCode::Normal.into())).await;
break;
}
if ws_session.ping(b"").await.is_err() { break; }
manager.metrics.ws_heartbeat_sent_total.increment(1);
}
// ── Shutdown ──
_ = shutdown_rx.recv() => {
tracing::info!("WS transport shutdown");
let _ = ws_session.close(Some(actix_ws::CloseCode::Normal.into())).await;
break;
}
// ── Notification push ──
notif = poll_notifications(&mut notif_rx) => {
if let Some(event) = notif {
if send_event(&mut ws_session, &event).await.is_err() { break; }
}
}
// ── Room broadcast push ──
push = poll_subscriptions(&session) => {
if let Some(event) = push {
if send_event(&mut ws_session, &event).await.is_err() { break; }
}
}
// ── Inbound client message ──
msg = msg_stream.recv() => {
match msg {
Some(Ok(WsMessage::Ping(bytes))) => {
if ws_session.pong(&bytes).await.is_err() { break; }
last_heartbeat = Instant::now();
}
Some(Ok(WsMessage::Pong(_))) => { last_heartbeat = Instant::now(); }
Some(Ok(WsMessage::Text(text))) => {
if last_activity.elapsed() > MAX_IDLE_TIMEOUT { break; }
last_activity = Instant::now();
last_heartbeat = Instant::now();
// Rate limit
if rate_window_start.elapsed() > super::session::RATE_LIMIT_WINDOW {
message_count = 0;
rate_window_start = Instant::now();
}
message_count += 1;
if message_count > MAX_MESSAGES_PER_SECOND {
let _ = ws_session.text(serde_json::json!({
"type": "error", "error": "rate_limit_exceeded"
}).to_string()).await;
continue;
}
if text.len() > MAX_TEXT_MESSAGE_LEN {
let _ = ws_session.text(serde_json::json!({
"type": "error", "error": "message_too_long"
}).to_string()).await;
continue;
}
// Parse once — extract request_id and deserialize together.
let json_value = serde_json::from_str::<serde_json::Value>(&text);
// Application-level JSON ping (distinguish from WebSocket Ping frame)
if text.trim() == r#"{"type":"ping"}"# || text.trim() == r#"{"type":"ping","_request_id":null}"# {
if ws_session.text(r#"{"type":"pong"}"#).await.is_err() { break; }
continue;
}
// Extract _request_id from the Value, then deserialize WsInMessage
let request_id: Option<Uuid> = json_value
.ok()
.and_then(|v| v.get("_request_id")
.and_then(|r| serde_json::from_value(r.clone()).ok()));
match serde_json::from_str::<WsInMessage>(&text) {
Ok(in_msg) => {
match MessageHandler::handle(&session, in_msg).await {
Ok(Some(event)) => {
let rid = request_id.unwrap_or(Uuid::nil());
let resp = WsOutEvent::Response {
request_id: rid,
data: serde_json::to_value(&event).unwrap_or_default(),
};
if send_event(&mut ws_session, &resp).await.is_err() { break; }
}
Ok(None) => {
let rid = request_id.unwrap_or(Uuid::nil());
let ack = WsOutEvent::Response {
request_id: rid,
data: serde_json::json!({"ok": true}),
};
if send_event(&mut ws_session, &ack).await.is_err() { break; }
}
Err(e) => {
tracing::warn!(user_id = %user_id, error = %e, "WS message processing failed");
let rid = request_id.unwrap_or(Uuid::nil());
let (code, error_type) = e.ws_error_code();
let err_json = serde_json::json!({
"type": "error",
"code": code,
"error": error_type,
"message": e.to_string(),
"_request_id": rid
});
let _ = ws_session.text(err_json.to_string()).await;
}
}
}
Err(e) => {
tracing::warn!(error = %e, "WS transport parse error");
let _ = ws_session.text(serde_json::json!({
"type": "error", "error": "parse_error"
}).to_string()).await;
}
}
}
Some(Ok(WsMessage::Binary(_))) => {
let _ = ws_session.close(Some(actix_ws::CloseCode::Unsupported.into())).await;
break;
}
Some(Ok(WsMessage::Continuation(_))) => {}
Some(Ok(WsMessage::Nop)) => {}
Some(Ok(WsMessage::Close(reason))) => {
let _ = ws_session.close(reason).await;
break;
}
Some(Err(e)) => {
tracing::warn!(error = %e, "WS transport error");
let _ = ws_session.close(Some(actix_ws::CloseCode::Protocol.into())).await;
break;
}
None => break,
}
}
}
}
// Cleanup
for sub in session.subscriptions.iter() {
manager.unsubscribe(sub.room_id, user_id).await;
}
manager.unsubscribe_user_notification(user_id).await;
// Remove presence entry so disconnected users don't appear online for up to 10 minutes
let project_id = session.project_id.lock().await;
session.service.room.remove_user_presence(user_id, *project_id).await;
manager.metrics.ws_connections_active.decrement(1.0);
manager.metrics.ws_disconnections_total.increment(1);
}).catch_unwind();
if let Err(panic_err) = panic_result.await {
let panic_msg = if let Some(s) = panic_err.downcast_ref::<String>() {
s.clone()
} else if let Some(s) = panic_err.downcast_ref::<&str>() {
s.to_string()
} else {
"Unknown panic".to_string()
};
tracing::error!(user_id = %user_id, panic = %panic_msg, "WS transport task panicked");
manager.metrics.ws_connections_active.decrement(1.0);
manager.metrics.ws_disconnections_total.increment(1);
}
});
// Drop the handle intentionally — cleanup is handled inside the spawned task
drop(spawn_handle);
Ok(response)
}
async fn send_event(ws_session: &mut actix_ws::Session, event: &WsOutEvent) -> Result<(), ()> {
match serde_json::to_string(event) {
Ok(json) => ws_session.text(json).await.map_err(|_| {}),
Err(e) => {
tracing::error!(error = %e, "WS transport serialize error");
Err(())
}
}
}
async fn authenticate_ws(
service: &AppService,
req: &HttpRequest,
) -> Result<crate::token::AppTransportTokenContext, actix_web::Error> {
// Prefer Authorization header over query parameter
if let Some(auth_header) = req.headers().get("Authorization") {
if let Ok(auth_str) = auth_header.to_str() {
if let Some(token) = auth_str.strip_prefix("Bearer ") {
match service.ws_token.validate_token_ctx(token).await {
Ok(ctx) => {
return Ok(crate::token::AppTransportTokenContext {
user_id: ctx.user_id,
device_id: ctx.device_id.unwrap_or_default(),
client_id: ctx.client_id.unwrap_or_default(),
});
}
Err(_) => {
service
.room
.room_manager
.metrics
.ws_auth_failures
.increment(1);
return Err(actix_web::error::ErrorUnauthorized("token auth failed"));
}
}
}
}
}
// Fallback: token in query string (deprecated, kept for backward compatibility)
if let Some(token) = req.uri().query().and_then(|q| {
q.split('&')
.find(|p| p.starts_with("token="))
.and_then(|p| p.split('=').nth(1))
}) {
match service.ws_token.validate_token_ctx(token).await {
Ok(ctx) => {
return Ok(crate::token::AppTransportTokenContext {
user_id: ctx.user_id,
device_id: ctx.device_id.unwrap_or_default(),
client_id: ctx.client_id.unwrap_or_default(),
});
}
Err(_) => {
service
.room
.room_manager
.metrics
.ws_auth_failures
.increment(1);
return Err(actix_web::error::ErrorUnauthorized("token auth failed"));
}
}
}
service
.room
.room_manager
.metrics
.ws_auth_failures
.increment(1);
Err(actix_web::error::ErrorUnauthorized("no auth provided"))
}