1. WS disconnect now unsubscribes from user_notification_inner. Previously, every WebSocket connection created a broadcast channel for user notifications that was never removed on disconnect, causing unbounded growth proportional to unique connected users over time. 2. Room worker tasks now use the manager's room_shutdown_txs channel instead of a local broadcast channel. shutdown_room() sends on this channel, so when a room is deleted the worker task receives the signal and terminates, releasing its DashMap (capacity 10,000) and all captured closures. Previously the worker ran forever.
571 lines
27 KiB
Rust
571 lines
27 KiB
Rust
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use std::time::{Duration, Instant};
|
|
|
|
use actix_web::{HttpRequest, HttpResponse, web};
|
|
use actix_ws::Message as WsMessage;
|
|
use tokio_stream::StreamExt;
|
|
use tokio_stream::wrappers::BroadcastStream;
|
|
use uuid::Uuid;
|
|
|
|
use crate::error::ApiError;
|
|
use queue::{ReactionGroup, RoomMessageEvent, RoomMessageStreamChunkEvent, TypingEvent};
|
|
use room::types::NotificationEvent;
|
|
use room::connection::RoomConnectionManager;
|
|
use service::AppService;
|
|
|
|
use super::ws::validate_origin;
|
|
use super::ws_handler::WsRequestHandler;
|
|
use super::ws_types::{WsAction, WsRequest, WsResponse, WsResponseData};
|
|
|
|
/// Largest accepted inbound text frame, in bytes (compared against `text.len()`).
const MAX_TEXT_MESSAGE_LEN: usize = 64 * 1024;

/// Number of text frames a connection may send per [`RATE_LIMIT_WINDOW`]
/// before further frames are answered with a `rate_limit_exceeded` error.
const MAX_MESSAGES_PER_SECOND: u32 = 10;

/// Period of the server-side heartbeat tick that sends WebSocket pings.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30);

/// A connection is closed when no ping/pong has been seen for longer than this.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(60);

/// A connection is closed when no text frame has been received for longer than this.
const MAX_IDLE_TIMEOUT: Duration = Duration::from_secs(5 * 60);

/// Length of the window over which inbound text frames are counted for rate limiting.
const RATE_LIMIT_WINDOW: Duration = Duration::from_secs(1);
|
|
|
|
/// Unified push event from any subscribed room or user notification channel.
|
|
#[derive(Debug, Clone)]
|
|
pub enum WsPushEvent {
|
|
RoomMessage {
|
|
room_id: Uuid,
|
|
event: Arc<RoomMessageEvent>,
|
|
},
|
|
ReactionUpdated {
|
|
room_id: Uuid,
|
|
message_id: Uuid,
|
|
reactions: Vec<ReactionGroup>,
|
|
},
|
|
AiStreamChunk {
|
|
room_id: Uuid,
|
|
chunk: Arc<RoomMessageStreamChunkEvent>,
|
|
},
|
|
TypingIndicator {
|
|
room_id: Uuid,
|
|
event: Arc<TypingEvent>,
|
|
},
|
|
Notification {
|
|
event: Arc<NotificationEvent>,
|
|
},
|
|
}
|
|
|
|
/// Maps room_id -> (room_message_broadcast_stream, stream_chunk_broadcast_stream)
|
|
type PushStreams = HashMap<
|
|
Uuid,
|
|
(
|
|
BroadcastStream<Arc<RoomMessageEvent>>,
|
|
BroadcastStream<Arc<RoomMessageStreamChunkEvent>>,
|
|
BroadcastStream<Arc<TypingEvent>>,
|
|
),
|
|
>;
|
|
|
|
pub async fn ws_universal(
|
|
service: web::Data<AppService>,
|
|
req: HttpRequest,
|
|
stream: web::Payload,
|
|
) -> Result<HttpResponse, actix_web::Error> {
|
|
let origin_val = req
|
|
.headers()
|
|
.get("origin")
|
|
.and_then(|v| v.to_str().ok())
|
|
.unwrap_or("(none)");
|
|
if !validate_origin(&req) {
|
|
tracing::warn!(
|
|
origin = %origin_val,
|
|
"WS universal: origin rejected"
|
|
);
|
|
return Err(ApiError(service::error::AppError::BadRequest(
|
|
"Invalid origin".into(),
|
|
))
|
|
.into());
|
|
}
|
|
|
|
// Validate token BEFORE actix_ws::handle() so we can return a proper HTTP
|
|
// error if validation fails. Returning an HTTP error after handle() has been
|
|
// called (even if the handler returns an error) sends a 200 OK on what the
|
|
// browser expects to be a 101 Switching Protocols response — causing
|
|
// immediate close with readyState=3.
|
|
let user_id = if let Some(token) = req.uri().query().and_then(|q| {
|
|
q.split('&')
|
|
.find(|p| p.starts_with("token="))
|
|
.and_then(|p| p.split('=').nth(1))
|
|
}) {
|
|
tracing::info!(
|
|
token = %token,
|
|
origin = %origin_val,
|
|
"WS universal: validating token"
|
|
);
|
|
match service.ws_token.validate_token(token).await {
|
|
Ok(uid) => {
|
|
tracing::info!(
|
|
uid = %uid,
|
|
origin = %origin_val,
|
|
"WS universal: token auth successful"
|
|
);
|
|
uid
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!(
|
|
error = ?e,
|
|
token = %token,
|
|
"WS universal: token auth failed"
|
|
);
|
|
service
|
|
.room
|
|
.room_manager
|
|
.metrics
|
|
.ws_auth_failures
|
|
.increment(1);
|
|
return Err(ApiError(service::error::AppError::Unauthorized).into());
|
|
}
|
|
}
|
|
} else {
|
|
let auth_header = req
|
|
.headers()
|
|
.get("Authorization")
|
|
.and_then(|v| v.to_str().ok());
|
|
let token = match auth_header {
|
|
Some(h) if h.starts_with("Bearer ") => &h[7..],
|
|
_ => {
|
|
service
|
|
.room
|
|
.room_manager
|
|
.metrics
|
|
.ws_auth_failures
|
|
.increment(1);
|
|
return Err(ApiError(service::error::AppError::Unauthorized).into());
|
|
}
|
|
};
|
|
|
|
match extract_user_id_from_token(token) {
|
|
Some(id) => id,
|
|
None => {
|
|
service
|
|
.room
|
|
.room_manager
|
|
.metrics
|
|
.ws_auth_failures
|
|
.increment(1);
|
|
return Err(ApiError(service::error::AppError::Unauthorized).into());
|
|
}
|
|
}
|
|
};
|
|
|
|
tracing::debug!(
|
|
user_id = %user_id,
|
|
origin = %origin_val,
|
|
"WS universal connection established"
|
|
);
|
|
|
|
let service = service.get_ref().clone();
|
|
let manager = service.room.room_manager.clone();
|
|
manager.metrics.ws_connections_active.increment(1.0);
|
|
manager.metrics.ws_connections_total.increment(1);
|
|
|
|
// Subscribe to user-level notification stream immediately on connect
|
|
let notif_rx = manager.subscribe_user_notification(user_id).await;
|
|
let mut notif_stream = BroadcastStream::new(notif_rx);
|
|
|
|
let (response, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?;
|
|
actix::spawn(async move {
|
|
let handler = WsRequestHandler::new(Arc::new(service), user_id);
|
|
let mut push_streams: PushStreams = HashMap::new();
|
|
let mut shutdown_rx = manager.subscribe_shutdown();
|
|
let mut last_heartbeat = Instant::now();
|
|
let mut last_activity = Instant::now();
|
|
let mut heartbeat_interval = tokio::time::interval(HEARTBEAT_INTERVAL);
|
|
heartbeat_interval.tick().await;
|
|
let mut message_count: u32 = 0;
|
|
let mut rate_window_start = Instant::now();
|
|
loop {
|
|
tokio::select! {
|
|
_ = heartbeat_interval.tick() => {
|
|
if last_heartbeat.elapsed() > HEARTBEAT_TIMEOUT {
|
|
tracing::warn!(user_id = %user_id, "WS universal heartbeat timeout");
|
|
manager.metrics.ws_heartbeat_timeout_total.increment(1);
|
|
let _ = session.close(Some(actix_ws::CloseCode::Policy.into())).await;
|
|
break;
|
|
}
|
|
if last_activity.elapsed() > MAX_IDLE_TIMEOUT {
|
|
tracing::info!(user_id = %user_id, "WS universal idle timeout");
|
|
manager.metrics.ws_idle_timeout_total.increment(1);
|
|
let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await;
|
|
break;
|
|
}
|
|
if session.ping(b"").await.is_err() {
|
|
break;
|
|
}
|
|
manager.metrics.ws_heartbeat_sent_total.increment(1);
|
|
}
|
|
_ = shutdown_rx.recv() => {
|
|
tracing::info!("WS universal shutdown");
|
|
let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await;
|
|
break;
|
|
}
|
|
notif_result = notif_stream.next() => {
|
|
match notif_result {
|
|
Some(Ok(event)) => {
|
|
let payload = serde_json::json!({
|
|
"type": "event",
|
|
"event": "notification_created",
|
|
"data": {
|
|
"event_type": event.event_type,
|
|
"notification": event.notification,
|
|
"deep_link_url": event.deep_link_url,
|
|
"timestamp": event.timestamp,
|
|
},
|
|
});
|
|
if session.text(payload.to_string()).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
Some(Err(_)) | None => {
|
|
// Notification channel lagged or closed — re-subscribe
|
|
let rx = manager.subscribe_user_notification(user_id).await;
|
|
notif_stream = BroadcastStream::new(rx);
|
|
}
|
|
}
|
|
}
|
|
push_event = poll_push_streams(&mut push_streams, &manager, &handler.service(), user_id) => {
|
|
match push_event {
|
|
Some(WsPushEvent::RoomMessage { room_id, event }) => {
|
|
let payload = serde_json::json!({
|
|
"type": "event",
|
|
"event": "room.message",
|
|
"room_id": room_id,
|
|
"data": {
|
|
"id": event.id,
|
|
"room_id": event.room_id,
|
|
"sender_type": event.sender_type,
|
|
"sender_id": event.sender_id,
|
|
"thread_id": event.thread_id,
|
|
"content": event.content,
|
|
"content_type": event.content_type,
|
|
"send_at": event.send_at,
|
|
"seq": event.seq,
|
|
"display_name": event.display_name,
|
|
},
|
|
});
|
|
if session.text(payload.to_string()).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
Some(WsPushEvent::ReactionUpdated { room_id, message_id, reactions }) => {
|
|
let payload = serde_json::json!({
|
|
"type": "event",
|
|
"event": "room.reaction_updated",
|
|
"room_id": room_id,
|
|
"data": {
|
|
"message_id": message_id,
|
|
"reactions": reactions,
|
|
},
|
|
});
|
|
if session.text(payload.to_string()).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
Some(WsPushEvent::AiStreamChunk { room_id, chunk }) => {
|
|
let payload = serde_json::json!({
|
|
"type": "event",
|
|
"event": "ai.stream_chunk",
|
|
"room_id": room_id,
|
|
"data": {
|
|
"message_id": chunk.message_id,
|
|
"room_id": chunk.room_id,
|
|
"content": chunk.content,
|
|
"done": chunk.done,
|
|
"error": chunk.error,
|
|
"display_name": chunk.display_name,
|
|
"chunk_type": chunk.chunk_type,
|
|
},
|
|
});
|
|
if session.text(payload.to_string()).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
Some(WsPushEvent::TypingIndicator { room_id, event }) => {
|
|
let payload = serde_json::json!({
|
|
"type": "event",
|
|
"event": "room.typing",
|
|
"room_id": room_id,
|
|
"data": {
|
|
"user_id": event.user_id,
|
|
"username": event.username,
|
|
"avatar_url": event.avatar_url,
|
|
"action": event.action,
|
|
"sender_type": event.sender_type.as_deref().unwrap_or("user"),
|
|
},
|
|
});
|
|
if session.text(payload.to_string()).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
None => {
|
|
}
|
|
Some(WsPushEvent::Notification { .. }) => {
|
|
// Notification events are handled via the notif_stream branch above
|
|
}
|
|
}
|
|
}
|
|
msg = msg_stream.recv() => {
|
|
match msg {
|
|
Some(Ok(WsMessage::Ping(bytes))) => {
|
|
if session.pong(&bytes).await.is_err() { break; }
|
|
last_heartbeat = Instant::now();
|
|
}
|
|
Some(Ok(WsMessage::Pong(_))) => { last_heartbeat = Instant::now(); }
|
|
Some(Ok(WsMessage::Text(text))) => {
|
|
if last_activity.elapsed() > MAX_IDLE_TIMEOUT {
|
|
tracing::info!(user_id = %user_id, "WS universal idle timeout");
|
|
manager.metrics.ws_idle_timeout_total.increment(1);
|
|
let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await;
|
|
break;
|
|
}
|
|
last_activity = Instant::now();
|
|
|
|
if rate_window_start.elapsed() > RATE_LIMIT_WINDOW {
|
|
message_count = 0;
|
|
rate_window_start = Instant::now();
|
|
}
|
|
message_count += 1;
|
|
if message_count > MAX_MESSAGES_PER_SECOND {
|
|
tracing::warn!(user_id = %user_id, "WS universal rate limit exceeded");
|
|
manager.metrics.ws_rate_limit_hits.increment(1);
|
|
let _ = session.text(serde_json::json!({"type":"error","error":"rate_limit_exceeded"}).to_string()).await;
|
|
continue;
|
|
}
|
|
|
|
if text.len() > MAX_TEXT_MESSAGE_LEN {
|
|
tracing::warn!(user_id = %user_id, bytes = text.len(), "WS universal message too long");
|
|
let _ = session.text(serde_json::json!({"type":"error","error":"message_too_long"}).to_string()).await;
|
|
continue;
|
|
}
|
|
|
|
// Handle JSON-level ping (application heartbeat).
|
|
// Client sends {"type":"ping"} and we reply with {"type":"pong"}.
|
|
if text.trim() == r#"{"type":"ping"}"# {
|
|
if session.text(r#"{"type":"pong"}"#).await.is_err() { break; }
|
|
last_activity = Instant::now();
|
|
last_heartbeat = Instant::now();
|
|
continue;
|
|
}
|
|
|
|
match serde_json::from_str::<WsRequest>(&text) {
|
|
Ok(request) => {
|
|
let action_str = request.action.to_string();
|
|
match request.action {
|
|
WsAction::SubscribeRoom => {
|
|
if let Some(room_id) = request.params().room_id {
|
|
// Verify user has access to this room before subscribing
|
|
if let Err(e) = handler.service().room.check_room_access(room_id, user_id).await {
|
|
let _ = session.text(serde_json::to_string(&WsResponse::error_response(
|
|
request.request_id, &action_str, 403, "access_denied", &format!("{}", e)
|
|
)).unwrap_or_default()).await;
|
|
} else {
|
|
match manager.subscribe(room_id, user_id).await {
|
|
Ok(rx) => {
|
|
let stream_rx = manager.subscribe_room_stream(room_id).await;
|
|
let typing_rx = manager.subscribe_typing(room_id).await;
|
|
push_streams.insert(room_id, (
|
|
BroadcastStream::new(rx),
|
|
BroadcastStream::new(stream_rx),
|
|
BroadcastStream::new(typing_rx),
|
|
));
|
|
let _ = session.text(serde_json::to_string(&WsResponse::success(
|
|
request.request_id, &action_str,
|
|
WsResponseData::subscribed(Some(room_id), None)
|
|
)).unwrap_or_default()).await;
|
|
}
|
|
Err(e) => {
|
|
let _ = session.text(serde_json::to_string(&WsResponse::error_response(
|
|
request.request_id, &action_str, 500, "subscribe_failed", &format!("{}", e)
|
|
)).unwrap_or_default()).await;
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
let _ = session.text(serde_json::to_string(&WsResponse::error_response(
|
|
request.request_id, &action_str, 400, "bad_request", "room_id required"
|
|
)).unwrap_or_default()).await;
|
|
}
|
|
}
|
|
WsAction::UnsubscribeRoom => {
|
|
if let Some(room_id) = request.params().room_id {
|
|
manager.unsubscribe(room_id, user_id).await;
|
|
push_streams.remove(&room_id);
|
|
}
|
|
let _ = session.text(serde_json::to_string(&WsResponse::success(
|
|
request.request_id, &action_str, WsResponseData::bool(true)
|
|
)).unwrap_or_default()).await;
|
|
}
|
|
WsAction::TypingStart | WsAction::TypingStop => {
|
|
if let (Some(room_id), Some(action)) =
|
|
(request.params().room_id, request.params().typing.as_deref())
|
|
{
|
|
let names = handler.service().room.get_user_names(&[user_id]).await;
|
|
let typing_event = TypingEvent {
|
|
room_id,
|
|
user_id,
|
|
username: names.into_values().next().unwrap_or_else(|| "unknown".to_string()),
|
|
avatar_url: None,
|
|
action: action.to_string(),
|
|
sender_type: None,
|
|
};
|
|
manager.broadcast_typing(room_id, typing_event).await;
|
|
}
|
|
let _ = session.text(serde_json::to_string(&WsResponse::success(
|
|
request.request_id, &action_str, WsResponseData::bool(true)
|
|
)).unwrap_or_default()).await;
|
|
}
|
|
_ => {
|
|
let resp = handler.handle(request).await;
|
|
let _ = session.text(serde_json::to_string(&resp).unwrap_or_default()).await;
|
|
}
|
|
}
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!(user_id = %user_id, error = %e, "WS universal parse error");
|
|
let _ = session.text(serde_json::json!({"type":"error","error":"parse_error"}).to_string()).await;
|
|
}
|
|
}
|
|
}
|
|
Some(Ok(WsMessage::Binary(_))) => { break; }
|
|
Some(Ok(WsMessage::Continuation(_))) => {}
|
|
Some(Ok(WsMessage::Nop)) => {}
|
|
Some(Ok(WsMessage::Close(reason))) => { let _ = session.close(reason).await; break; }
|
|
Some(Err(e)) => { tracing::warn!(error = %e, "WS error"); break; }
|
|
None => break,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Clean up subscriptions on disconnect
|
|
for room_id in push_streams.keys() {
|
|
manager.unsubscribe(*room_id, user_id).await;
|
|
}
|
|
manager.unsubscribe_user_notification(user_id).await;
|
|
manager.metrics.ws_connections_active.decrement(1.0);
|
|
manager.metrics.ws_disconnections_total.increment(1);
|
|
});
|
|
|
|
Ok(response)
|
|
}
|
|
|
|
async fn poll_push_streams(
|
|
streams: &mut PushStreams,
|
|
manager: &Arc<RoomConnectionManager>,
|
|
service: &Arc<AppService>,
|
|
user_id: Uuid,
|
|
) -> Option<WsPushEvent> {
|
|
loop {
|
|
let room_ids: Vec<Uuid> = streams.keys().copied().collect();
|
|
let mut dead_rooms: Vec<Uuid> = Vec::new();
|
|
|
|
for room_id in room_ids {
|
|
if let Some((msg_stream, chunk_stream, typing_stream)) = streams.get_mut(&room_id) {
|
|
tokio::select! {
|
|
result = msg_stream.next() => {
|
|
match result {
|
|
Some(Ok(event)) => {
|
|
if let Some(reactions) = event.reactions.clone() {
|
|
return Some(WsPushEvent::ReactionUpdated {
|
|
room_id: event.room_id,
|
|
message_id: event.message_id.unwrap_or(event.id),
|
|
reactions,
|
|
});
|
|
}
|
|
return Some(WsPushEvent::RoomMessage { room_id, event });
|
|
}
|
|
Some(Err(_)) | None => {
|
|
dead_rooms.push(room_id);
|
|
}
|
|
}
|
|
}
|
|
result = chunk_stream.next() => {
|
|
match result {
|
|
Some(Ok(chunk)) => {
|
|
return Some(WsPushEvent::AiStreamChunk { room_id, chunk });
|
|
}
|
|
Some(Err(_)) | None => {
|
|
dead_rooms.push(room_id);
|
|
}
|
|
}
|
|
}
|
|
result = typing_stream.next() => {
|
|
match result {
|
|
Some(Ok(event)) => {
|
|
return Some(WsPushEvent::TypingIndicator { room_id, event });
|
|
}
|
|
Some(Err(_)) | None => {
|
|
// Typing channel going dead is non-fatal — typing is ephemeral
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Re-subscribe dead rooms so we don't permanently lose events.
|
|
// Re-check access in case the user's permissions were revoked while the
|
|
// stream was dead.
|
|
for room_id in dead_rooms {
|
|
if streams.remove(&room_id).is_some() {
|
|
if service.room.check_room_access(room_id, user_id).await.is_ok() {
|
|
if let Ok(rx) = manager.subscribe(room_id, user_id).await {
|
|
let stream_rx = manager.subscribe_room_stream(room_id).await;
|
|
let typing_rx = manager.subscribe_typing(room_id).await;
|
|
streams.insert(room_id, (
|
|
BroadcastStream::new(rx),
|
|
BroadcastStream::new(stream_rx),
|
|
BroadcastStream::new(typing_rx),
|
|
));
|
|
}
|
|
}
|
|
// If access check fails, silently skip re-subscribe (user was removed)
|
|
}
|
|
}
|
|
|
|
if streams.is_empty() {
|
|
// Yield so the caller can drop us before the next iteration
|
|
tokio::time::sleep(Duration::from_millis(50)).await;
|
|
} else {
|
|
tokio::task::yield_now().await;
|
|
}
|
|
}
|
|
}
|
|
|
|
fn extract_user_id_from_token(token: &str) -> Option<Uuid> {
|
|
if token.len() < 64 {
|
|
return None;
|
|
}
|
|
let token_data = base64_decode(token)?;
|
|
if token_data.len() < 16 {
|
|
return None;
|
|
}
|
|
let bytes: [u8; 16] = token_data[..16].try_into().ok()?;
|
|
Some(Uuid::from_bytes(bytes))
|
|
}
|
|
|
|
/// Lenient decoder for the standard base64 alphabet (`A-Z a-z 0-9 + /`).
///
/// `=`, space, `\n` and `\r` are skipped wherever they appear. Any other
/// byte outside the alphabet makes the whole decode fail with `None`.
/// Trailing bits that do not fill a complete byte are discarded.
fn base64_decode(input: &str) -> Option<Vec<u8>> {
    let mut decoded = Vec::with_capacity(input.len() * 3 / 4);
    let mut acc: u32 = 0;
    let mut pending_bits = 0;

    for symbol in input.bytes() {
        let sextet = match symbol {
            b'=' | b'\n' | b'\r' | b' ' => continue,
            b'A'..=b'Z' => symbol - b'A',
            b'a'..=b'z' => symbol - b'a' + 26,
            b'0'..=b'9' => symbol - b'0' + 52,
            b'+' => 62,
            b'/' => 63,
            _ => return None,
        };

        acc = (acc << 6) | u32::from(sextet);
        pending_bits += 6;
        if pending_bits >= 8 {
            pending_bits -= 8;
            // Truncation to `u8` keeps exactly the byte that just completed.
            decoded.push((acc >> pending_bits) as u8);
        }
    }

    Some(decoded)
}
|