336 lines
16 KiB
Rust
336 lines
16 KiB
Rust
use std::panic::AssertUnwindSafe;
|
|
use std::sync::Arc;
|
|
use std::time::Instant;
|
|
|
|
use actix_web::{HttpRequest, HttpResponse, web};
|
|
use actix_ws::Message as WsMessage;
|
|
use futures_util::FutureExt;
|
|
use uuid::Uuid;
|
|
|
|
use service::AppService;
|
|
|
|
use super::inbound::MessageHandler;
|
|
use super::poll::{poll_notifications, poll_subscriptions};
|
|
use super::session::{
|
|
HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT, MAX_IDLE_TIMEOUT, MAX_MESSAGES_PER_SECOND,
|
|
MAX_TEXT_MESSAGE_LEN, TransportSession, WsUserCtx,
|
|
};
|
|
use super::types::{WsInMessage, WsOutEvent};
|
|
|
|
/// Universal WebSocket endpoint: `/ws`
|
|
///
|
|
/// Protocol:
|
|
/// - Inbound: JSON `WsInMessage` (tagged enum with `type` field)
|
|
/// - Outbound: JSON `WsOutEvent` (tagged enum with `type` field)
|
|
/// - Heartbeat: client sends `{"type":"ping"}`, server replies `{"type":"pong"}`
|
|
/// - Binary frames are rejected
|
|
/// - Rate limit: 1000 messages/sec per connection
|
|
pub async fn ws_handler(
|
|
service: web::Data<AppService>,
|
|
req: HttpRequest,
|
|
stream: web::Payload,
|
|
) -> Result<HttpResponse, actix_web::Error> {
|
|
let auth_ctx = authenticate_ws(&service, &req).await?;
|
|
let user_id = auth_ctx.user_id;
|
|
|
|
// Resolve display name for this user (cached in WsUserCtx for typing events, etc.)
|
|
let display_name = {
|
|
use models::users::user as user_model;
|
|
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
|
|
let db = &service.db;
|
|
user_model::Entity::find()
|
|
.filter(user_model::Column::Uid.eq(user_id))
|
|
.one(db)
|
|
.await
|
|
.ok()
|
|
.flatten()
|
|
.map(|u| u.display_name.unwrap_or_else(|| u.username))
|
|
.unwrap_or_else(|| user_id.to_string())
|
|
};
|
|
|
|
tracing::info!(user_id = %user_id, display_name = %display_name, "WS transport connection established");
|
|
|
|
let service_arc = Arc::new(service.get_ref().clone());
|
|
let manager = service_arc.room.room_manager.clone();
|
|
manager.metrics.ws_connections_active.increment(1.0);
|
|
manager.metrics.ws_connections_total.increment(1);
|
|
|
|
let mut notif_rx = manager.subscribe_user_notification(user_id).await;
|
|
let mut shutdown_rx = manager.subscribe_shutdown();
|
|
|
|
let (response, mut ws_session, mut msg_stream) = actix_ws::handle(&req, stream)?;
|
|
|
|
let spawn_handle = actix::spawn(async move {
|
|
let panic_result = AssertUnwindSafe(async {
|
|
let session = TransportSession::new(
|
|
WsUserCtx {
|
|
user_id,
|
|
device_id: auth_ctx.device_id,
|
|
client_id: auth_ctx.client_id,
|
|
display_name,
|
|
},
|
|
service_arc,
|
|
);
|
|
|
|
// Split state for tokio::select! borrow safety
|
|
let mut last_heartbeat = Instant::now();
|
|
let mut last_activity = Instant::now();
|
|
let mut message_count: u32 = 0;
|
|
let mut rate_window_start = Instant::now();
|
|
let mut heartbeat_interval = tokio::time::interval(HEARTBEAT_INTERVAL);
|
|
heartbeat_interval.tick().await;
|
|
|
|
loop {
|
|
tokio::select! {
|
|
// ── Heartbeat ──
|
|
_ = heartbeat_interval.tick() => {
|
|
if last_heartbeat.elapsed() > HEARTBEAT_TIMEOUT {
|
|
tracing::warn!(user_id = %user_id, "WS transport heartbeat timeout");
|
|
manager.metrics.ws_heartbeat_timeout_total.increment(1);
|
|
let _ = ws_session.close(Some(actix_ws::CloseCode::Policy.into())).await;
|
|
break;
|
|
}
|
|
if last_activity.elapsed() > MAX_IDLE_TIMEOUT {
|
|
tracing::info!(user_id = %user_id, "WS transport idle timeout");
|
|
manager.metrics.ws_idle_timeout_total.increment(1);
|
|
let _ = ws_session.close(Some(actix_ws::CloseCode::Normal.into())).await;
|
|
break;
|
|
}
|
|
if ws_session.ping(b"").await.is_err() { break; }
|
|
manager.metrics.ws_heartbeat_sent_total.increment(1);
|
|
}
|
|
// ── Shutdown ──
|
|
_ = shutdown_rx.recv() => {
|
|
tracing::info!("WS transport shutdown");
|
|
let _ = ws_session.close(Some(actix_ws::CloseCode::Normal.into())).await;
|
|
break;
|
|
}
|
|
// ── Notification push ──
|
|
notif = poll_notifications(&mut notif_rx) => {
|
|
if let Some(event) = notif {
|
|
if send_event(&mut ws_session, &event).await.is_err() { break; }
|
|
}
|
|
}
|
|
// ── Room broadcast push ──
|
|
push = poll_subscriptions(&session) => {
|
|
if let Some(event) = push {
|
|
if send_event(&mut ws_session, &event).await.is_err() { break; }
|
|
}
|
|
}
|
|
// ── Inbound client message ──
|
|
msg = msg_stream.recv() => {
|
|
match msg {
|
|
Some(Ok(WsMessage::Ping(bytes))) => {
|
|
if ws_session.pong(&bytes).await.is_err() { break; }
|
|
last_heartbeat = Instant::now();
|
|
}
|
|
Some(Ok(WsMessage::Pong(_))) => { last_heartbeat = Instant::now(); }
|
|
Some(Ok(WsMessage::Text(text))) => {
|
|
if last_activity.elapsed() > MAX_IDLE_TIMEOUT { break; }
|
|
last_activity = Instant::now();
|
|
last_heartbeat = Instant::now();
|
|
|
|
// Rate limit
|
|
if rate_window_start.elapsed() > super::session::RATE_LIMIT_WINDOW {
|
|
message_count = 0;
|
|
rate_window_start = Instant::now();
|
|
}
|
|
message_count += 1;
|
|
if message_count > MAX_MESSAGES_PER_SECOND {
|
|
let _ = ws_session.text(serde_json::json!({
|
|
"type": "error", "error": "rate_limit_exceeded"
|
|
}).to_string()).await;
|
|
continue;
|
|
}
|
|
if text.len() > MAX_TEXT_MESSAGE_LEN {
|
|
let _ = ws_session.text(serde_json::json!({
|
|
"type": "error", "error": "message_too_long"
|
|
}).to_string()).await;
|
|
continue;
|
|
}
|
|
|
|
// Parse once — extract request_id and deserialize together.
|
|
let json_value = serde_json::from_str::<serde_json::Value>(&text);
|
|
|
|
// Application-level JSON ping (distinguish from WebSocket Ping frame)
|
|
if text.trim() == r#"{"type":"ping"}"# || text.trim() == r#"{"type":"ping","_request_id":null}"# {
|
|
if ws_session.text(r#"{"type":"pong"}"#).await.is_err() { break; }
|
|
continue;
|
|
}
|
|
|
|
// Extract _request_id from the Value, then deserialize WsInMessage
|
|
let request_id: Option<Uuid> = json_value
|
|
.ok()
|
|
.and_then(|v| v.get("_request_id")
|
|
.and_then(|r| serde_json::from_value(r.clone()).ok()));
|
|
|
|
match serde_json::from_str::<WsInMessage>(&text) {
|
|
Ok(in_msg) => {
|
|
match MessageHandler::handle(&session, in_msg).await {
|
|
Ok(Some(event)) => {
|
|
let rid = request_id.unwrap_or(Uuid::nil());
|
|
let resp = WsOutEvent::Response {
|
|
request_id: rid,
|
|
data: serde_json::to_value(&event).unwrap_or_default(),
|
|
};
|
|
if send_event(&mut ws_session, &resp).await.is_err() { break; }
|
|
}
|
|
Ok(None) => {
|
|
let rid = request_id.unwrap_or(Uuid::nil());
|
|
let ack = WsOutEvent::Response {
|
|
request_id: rid,
|
|
data: serde_json::json!({"ok": true}),
|
|
};
|
|
if send_event(&mut ws_session, &ack).await.is_err() { break; }
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!(user_id = %user_id, error = %e, "WS message processing failed");
|
|
let rid = request_id.unwrap_or(Uuid::nil());
|
|
let (code, error_type) = e.ws_error_code();
|
|
let err_json = serde_json::json!({
|
|
"type": "error",
|
|
"code": code,
|
|
"error": error_type,
|
|
"message": e.to_string(),
|
|
"_request_id": rid
|
|
});
|
|
let _ = ws_session.text(err_json.to_string()).await;
|
|
}
|
|
}
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!(error = %e, "WS transport parse error");
|
|
let _ = ws_session.text(serde_json::json!({
|
|
"type": "error", "error": "parse_error"
|
|
}).to_string()).await;
|
|
}
|
|
}
|
|
}
|
|
Some(Ok(WsMessage::Binary(_))) => {
|
|
let _ = ws_session.close(Some(actix_ws::CloseCode::Unsupported.into())).await;
|
|
break;
|
|
}
|
|
Some(Ok(WsMessage::Continuation(_))) => {}
|
|
Some(Ok(WsMessage::Nop)) => {}
|
|
Some(Ok(WsMessage::Close(reason))) => {
|
|
let _ = ws_session.close(reason).await;
|
|
break;
|
|
}
|
|
Some(Err(e)) => {
|
|
tracing::warn!(error = %e, "WS transport error");
|
|
let _ = ws_session.close(Some(actix_ws::CloseCode::Protocol.into())).await;
|
|
break;
|
|
}
|
|
None => break,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Cleanup
|
|
for sub in session.subscriptions.iter() {
|
|
manager.unsubscribe(sub.room_id, user_id).await;
|
|
}
|
|
manager.unsubscribe_user_notification(user_id).await;
|
|
// Remove presence entry so disconnected users don't appear online for up to 10 minutes
|
|
let project_id = session.project_id.lock().await;
|
|
session.service.room.remove_user_presence(user_id, *project_id).await;
|
|
manager.metrics.ws_connections_active.decrement(1.0);
|
|
manager.metrics.ws_disconnections_total.increment(1);
|
|
}).catch_unwind();
|
|
|
|
if let Err(panic_err) = panic_result.await {
|
|
let panic_msg = if let Some(s) = panic_err.downcast_ref::<String>() {
|
|
s.clone()
|
|
} else if let Some(s) = panic_err.downcast_ref::<&str>() {
|
|
s.to_string()
|
|
} else {
|
|
"Unknown panic".to_string()
|
|
};
|
|
tracing::error!(user_id = %user_id, panic = %panic_msg, "WS transport task panicked");
|
|
manager.metrics.ws_connections_active.decrement(1.0);
|
|
manager.metrics.ws_disconnections_total.increment(1);
|
|
}
|
|
});
|
|
|
|
// Drop the handle intentionally — cleanup is handled inside the spawned task
|
|
drop(spawn_handle);
|
|
|
|
Ok(response)
|
|
}
|
|
|
|
async fn send_event(ws_session: &mut actix_ws::Session, event: &WsOutEvent) -> Result<(), ()> {
|
|
match serde_json::to_string(event) {
|
|
Ok(json) => ws_session.text(json).await.map_err(|_| {}),
|
|
Err(e) => {
|
|
tracing::error!(error = %e, "WS transport serialize error");
|
|
Err(())
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn authenticate_ws(
|
|
service: &AppService,
|
|
req: &HttpRequest,
|
|
) -> Result<crate::token::AppTransportTokenContext, actix_web::Error> {
|
|
// Prefer Authorization header over query parameter
|
|
if let Some(auth_header) = req.headers().get("Authorization") {
|
|
if let Ok(auth_str) = auth_header.to_str() {
|
|
if let Some(token) = auth_str.strip_prefix("Bearer ") {
|
|
match service.ws_token.validate_token_ctx(token).await {
|
|
Ok(ctx) => {
|
|
return Ok(crate::token::AppTransportTokenContext {
|
|
user_id: ctx.user_id,
|
|
device_id: ctx.device_id.unwrap_or_default(),
|
|
client_id: ctx.client_id.unwrap_or_default(),
|
|
});
|
|
}
|
|
Err(_) => {
|
|
service
|
|
.room
|
|
.room_manager
|
|
.metrics
|
|
.ws_auth_failures
|
|
.increment(1);
|
|
return Err(actix_web::error::ErrorUnauthorized("token auth failed"));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Fallback: token in query string (deprecated, kept for backward compatibility)
|
|
if let Some(token) = req.uri().query().and_then(|q| {
|
|
q.split('&')
|
|
.find(|p| p.starts_with("token="))
|
|
.and_then(|p| p.split('=').nth(1))
|
|
}) {
|
|
match service.ws_token.validate_token_ctx(token).await {
|
|
Ok(ctx) => {
|
|
return Ok(crate::token::AppTransportTokenContext {
|
|
user_id: ctx.user_id,
|
|
device_id: ctx.device_id.unwrap_or_default(),
|
|
client_id: ctx.client_id.unwrap_or_default(),
|
|
});
|
|
}
|
|
Err(_) => {
|
|
service
|
|
.room
|
|
.room_manager
|
|
.metrics
|
|
.ws_auth_failures
|
|
.increment(1);
|
|
return Err(actix_web::error::ErrorUnauthorized("token auth failed"));
|
|
}
|
|
}
|
|
}
|
|
|
|
service
|
|
.room
|
|
.room_manager
|
|
.metrics
|
|
.ws_auth_failures
|
|
.increment(1);
|
|
Err(actix_web::error::ErrorUnauthorized("no auth provided"))
|
|
}
|