443 lines
19 KiB
Rust
443 lines
19 KiB
Rust
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use std::time::{Duration, Instant};
|
|
|
|
use actix_web::{HttpRequest, HttpResponse, web};
|
|
use actix_ws::Message as WsMessage;
|
|
use tokio_stream::StreamExt;
|
|
use tokio_stream::wrappers::BroadcastStream;
|
|
use uuid::Uuid;
|
|
|
|
use crate::error::ApiError;
|
|
use queue::{ReactionGroup, RoomMessageEvent, RoomMessageStreamChunkEvent};
|
|
use service::AppService;
|
|
|
|
use super::ws::validate_origin;
|
|
use super::ws_handler::WsRequestHandler;
|
|
use super::ws_types::{WsAction, WsRequest, WsResponse, WsResponseData};
|
|
|
|
/// Largest inbound text frame accepted from a client, in bytes (64 KiB).
const MAX_TEXT_MESSAGE_LEN: usize = 64 << 10;

/// Number of inbound text frames allowed per `RATE_LIMIT_WINDOW`.
const MAX_MESSAGES_PER_SECOND: u32 = 10;
/// Length of the fixed rate-limiting window.
const RATE_LIMIT_WINDOW: Duration = Duration::from_secs(1);

/// Period of the server-side ping / liveness check.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30);
/// The socket is closed if no ping or pong has been seen for longer than this.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(2 * 30);
/// The socket is closed if no text frame has been received for longer than this.
const MAX_IDLE_TIMEOUT: Duration = Duration::from_secs(5 * 60);
|
|
|
|
/// Unified push event from any subscribed room.
|
|
#[derive(Debug, Clone)]
|
|
pub enum WsPushEvent {
|
|
RoomMessage {
|
|
room_id: Uuid,
|
|
event: Arc<RoomMessageEvent>,
|
|
},
|
|
ReactionUpdated {
|
|
room_id: Uuid,
|
|
message_id: Uuid,
|
|
reactions: Vec<ReactionGroup>,
|
|
},
|
|
AiStreamChunk {
|
|
room_id: Uuid,
|
|
chunk: Arc<RoomMessageStreamChunkEvent>,
|
|
},
|
|
}
|
|
|
|
/// Maps room_id -> (room_message_broadcast_stream, stream_chunk_broadcast_stream)
|
|
type PushStreams = HashMap<
|
|
Uuid,
|
|
(
|
|
BroadcastStream<Arc<RoomMessageEvent>>,
|
|
BroadcastStream<Arc<RoomMessageStreamChunkEvent>>,
|
|
),
|
|
>;
|
|
|
|
pub async fn ws_universal(
|
|
service: web::Data<AppService>,
|
|
req: HttpRequest,
|
|
stream: web::Payload,
|
|
) -> Result<HttpResponse, actix_web::Error> {
|
|
let origin_val = req
|
|
.headers()
|
|
.get("origin")
|
|
.and_then(|v| v.to_str().ok())
|
|
.unwrap_or("(none)");
|
|
if !validate_origin(&req) {
|
|
slog::warn!(
|
|
service.logs,
|
|
"WS universal: origin rejected origin={}",
|
|
origin_val
|
|
);
|
|
return Err(ApiError(service::error::AppError::BadRequest(
|
|
"Invalid origin".into(),
|
|
))
|
|
.into());
|
|
}
|
|
|
|
// Validate token BEFORE actix_ws::handle() so we can return a proper HTTP
|
|
// error if validation fails. Returning an HTTP error after handle() has been
|
|
// called (even if the handler returns an error) sends a 200 OK on what the
|
|
// browser expects to be a 101 Switching Protocols response — causing
|
|
// immediate close with readyState=3.
|
|
let user_id = if let Some(token) = req.uri().query().and_then(|q| {
|
|
q.split('&')
|
|
.find(|p| p.starts_with("token="))
|
|
.and_then(|p| p.split('=').nth(1))
|
|
}) {
|
|
slog::info!(
|
|
service.logs,
|
|
"WS universal: validating token token={} origin={}",
|
|
token,
|
|
origin_val
|
|
);
|
|
match service.ws_token.validate_token(token).await {
|
|
Ok(uid) => {
|
|
slog::info!(
|
|
service.logs,
|
|
"WS universal: token auth successful uid={} origin={}",
|
|
uid,
|
|
origin_val
|
|
);
|
|
uid
|
|
}
|
|
Err(e) => {
|
|
slog::warn!(
|
|
service.logs,
|
|
"WS universal: token auth failed: {:?} token={}",
|
|
e,
|
|
token
|
|
);
|
|
service
|
|
.room
|
|
.room_manager
|
|
.metrics
|
|
.ws_auth_failures
|
|
.increment(1);
|
|
return Err(ApiError(service::error::AppError::Unauthorized).into());
|
|
}
|
|
}
|
|
} else {
|
|
let auth_header = req
|
|
.headers()
|
|
.get("Authorization")
|
|
.and_then(|v| v.to_str().ok());
|
|
let token = match auth_header {
|
|
Some(h) if h.starts_with("Bearer ") => &h[7..],
|
|
_ => {
|
|
service
|
|
.room
|
|
.room_manager
|
|
.metrics
|
|
.ws_auth_failures
|
|
.increment(1);
|
|
return Err(ApiError(service::error::AppError::Unauthorized).into());
|
|
}
|
|
};
|
|
|
|
match extract_user_id_from_token(token) {
|
|
Some(id) => id,
|
|
None => {
|
|
service
|
|
.room
|
|
.room_manager
|
|
.metrics
|
|
.ws_auth_failures
|
|
.increment(1);
|
|
return Err(ApiError(service::error::AppError::Unauthorized).into());
|
|
}
|
|
}
|
|
};
|
|
|
|
slog::debug!(
|
|
service.logs,
|
|
"WS universal connection established user_id={} origin={}",
|
|
user_id,
|
|
origin_val
|
|
);
|
|
|
|
let service = service.get_ref().clone();
|
|
let manager = service.room.room_manager.clone();
|
|
manager.metrics.ws_connections_active.increment(1.0);
|
|
manager.metrics.ws_connections_total.increment(1);
|
|
|
|
let logs = service.logs.clone();
|
|
let (response, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?;
|
|
actix::spawn(async move {
|
|
let handler = WsRequestHandler::new(Arc::new(service), user_id);
|
|
let mut push_streams: PushStreams = HashMap::new();
|
|
let mut shutdown_rx = manager.subscribe_shutdown();
|
|
let mut last_heartbeat = Instant::now();
|
|
let mut last_activity = Instant::now();
|
|
let mut heartbeat_interval = tokio::time::interval(HEARTBEAT_INTERVAL);
|
|
heartbeat_interval.tick().await;
|
|
let mut message_count: u32 = 0;
|
|
let mut rate_window_start = Instant::now();
|
|
loop {
|
|
tokio::select! {
|
|
_ = heartbeat_interval.tick() => {
|
|
if last_heartbeat.elapsed() > HEARTBEAT_TIMEOUT {
|
|
slog::warn!(logs, "WS universal heartbeat timeout for user {}", user_id);
|
|
manager.metrics.ws_heartbeat_timeout_total.increment(1);
|
|
let _ = session.close(Some(actix_ws::CloseCode::Policy.into())).await;
|
|
break;
|
|
}
|
|
if last_activity.elapsed() > MAX_IDLE_TIMEOUT {
|
|
slog::info!(logs, "WS universal idle timeout for user {}", user_id);
|
|
manager.metrics.ws_idle_timeout_total.increment(1);
|
|
let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await;
|
|
break;
|
|
}
|
|
if session.ping(b"").await.is_err() {
|
|
break;
|
|
}
|
|
manager.metrics.ws_heartbeat_sent_total.increment(1);
|
|
}
|
|
_ = shutdown_rx.recv() => {
|
|
slog::info!(logs, "WS universal shutdown");
|
|
let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await;
|
|
break;
|
|
}
|
|
push_event = poll_push_streams(&mut push_streams) => {
|
|
match push_event {
|
|
Some(WsPushEvent::RoomMessage { room_id, event }) => {
|
|
let payload = serde_json::json!({
|
|
"type": "event",
|
|
"event": "room.message",
|
|
"room_id": room_id,
|
|
"data": {
|
|
"id": event.id,
|
|
"room_id": event.room_id,
|
|
"sender_type": event.sender_type,
|
|
"sender_id": event.sender_id,
|
|
"thread_id": event.thread_id,
|
|
"content": event.content,
|
|
"content_type": event.content_type,
|
|
"send_at": event.send_at,
|
|
"seq": event.seq,
|
|
"display_name": event.display_name,
|
|
},
|
|
});
|
|
if session.text(payload.to_string()).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
Some(WsPushEvent::ReactionUpdated { room_id, message_id, reactions }) => {
|
|
let payload = serde_json::json!({
|
|
"type": "event",
|
|
"event": "room.reaction_updated",
|
|
"room_id": room_id,
|
|
"data": {
|
|
"message_id": message_id,
|
|
"reactions": reactions,
|
|
},
|
|
});
|
|
if session.text(payload.to_string()).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
Some(WsPushEvent::AiStreamChunk { room_id, chunk }) => {
|
|
let payload = serde_json::json!({
|
|
"type": "event",
|
|
"event": "ai.stream_chunk",
|
|
"room_id": room_id,
|
|
"data": {
|
|
"message_id": chunk.message_id,
|
|
"room_id": chunk.room_id,
|
|
"content": chunk.content,
|
|
"done": chunk.done,
|
|
"error": chunk.error,
|
|
},
|
|
});
|
|
if session.text(payload.to_string()).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
None => {
|
|
}
|
|
}
|
|
}
|
|
msg = msg_stream.recv() => {
|
|
match msg {
|
|
Some(Ok(WsMessage::Ping(bytes))) => {
|
|
if session.pong(&bytes).await.is_err() { break; }
|
|
last_heartbeat = Instant::now();
|
|
}
|
|
Some(Ok(WsMessage::Pong(_))) => { last_heartbeat = Instant::now(); }
|
|
Some(Ok(WsMessage::Text(text))) => {
|
|
if last_activity.elapsed() > MAX_IDLE_TIMEOUT {
|
|
slog::info!(logs, "WS universal idle timeout for user {}", user_id);
|
|
manager.metrics.ws_idle_timeout_total.increment(1);
|
|
let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await;
|
|
break;
|
|
}
|
|
last_activity = Instant::now();
|
|
|
|
if rate_window_start.elapsed() > RATE_LIMIT_WINDOW {
|
|
message_count = 0;
|
|
rate_window_start = Instant::now();
|
|
}
|
|
message_count += 1;
|
|
if message_count > MAX_MESSAGES_PER_SECOND {
|
|
slog::warn!(logs, "WS universal rate limit exceeded for user {}", user_id);
|
|
manager.metrics.ws_rate_limit_hits.increment(1);
|
|
let _ = session.text(serde_json::json!({"type":"error","error":"rate_limit_exceeded"}).to_string()).await;
|
|
continue;
|
|
}
|
|
|
|
if text.len() > MAX_TEXT_MESSAGE_LEN {
|
|
slog::warn!(logs, "WS universal message too long from user {}: {} bytes", user_id, text.len());
|
|
let _ = session.text(serde_json::json!({"type":"error","error":"message_too_long"}).to_string()).await;
|
|
continue;
|
|
}
|
|
|
|
match serde_json::from_str::<WsRequest>(&text) {
|
|
Ok(request) => {
|
|
let action_str = request.action.to_string();
|
|
match request.action {
|
|
WsAction::SubscribeRoom => {
|
|
if let Some(room_id) = request.params().room_id {
|
|
match manager.subscribe(room_id, user_id).await {
|
|
Ok(rx) => {
|
|
let stream_rx = manager.subscribe_room_stream(room_id).await;
|
|
push_streams.insert(room_id, (
|
|
BroadcastStream::new(rx),
|
|
BroadcastStream::new(stream_rx),
|
|
));
|
|
let _ = session.text(serde_json::to_string(&WsResponse::success(
|
|
request.request_id, &action_str,
|
|
WsResponseData::subscribed(Some(room_id), None)
|
|
)).unwrap_or_default()).await;
|
|
}
|
|
Err(e) => {
|
|
let _ = session.text(serde_json::to_string(&WsResponse::error_response(
|
|
request.request_id, &action_str, 403, "subscribe_failed", &format!("{}", e)
|
|
)).unwrap_or_default()).await;
|
|
}
|
|
}
|
|
} else {
|
|
let _ = session.text(serde_json::to_string(&WsResponse::error_response(
|
|
request.request_id, &action_str, 400, "bad_request", "room_id required"
|
|
)).unwrap_or_default()).await;
|
|
}
|
|
}
|
|
WsAction::UnsubscribeRoom => {
|
|
if let Some(room_id) = request.params().room_id {
|
|
manager.unsubscribe(room_id, user_id).await;
|
|
push_streams.remove(&room_id);
|
|
}
|
|
let _ = session.text(serde_json::to_string(&WsResponse::success(
|
|
request.request_id, &action_str, WsResponseData::bool(true)
|
|
)).unwrap_or_default()).await;
|
|
}
|
|
_ => {
|
|
let resp = handler.handle(request).await;
|
|
let _ = session.text(serde_json::to_string(&resp).unwrap_or_default()).await;
|
|
}
|
|
}
|
|
}
|
|
Err(e) => {
|
|
slog::warn!(logs, "WS universal parse error from user {}: {}", user_id, e);
|
|
let _ = session.text(serde_json::json!({"type":"error","error":"parse_error"}).to_string()).await;
|
|
}
|
|
}
|
|
}
|
|
Some(Ok(WsMessage::Binary(_))) => { break; }
|
|
Some(Ok(WsMessage::Continuation(_))) => {}
|
|
Some(Ok(WsMessage::Nop)) => {}
|
|
Some(Ok(WsMessage::Close(reason))) => { let _ = session.close(reason).await; break; }
|
|
Some(Err(e)) => { slog::warn!(logs, "WS error: {}", e); break; }
|
|
None => break,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Clean up subscriptions on disconnect
|
|
for room_id in push_streams.keys() {
|
|
manager.unsubscribe(*room_id, user_id).await;
|
|
}
|
|
manager.metrics.ws_connections_active.decrement(1.0);
|
|
manager.metrics.ws_disconnections_total.increment(1);
|
|
});
|
|
|
|
Ok(response)
|
|
}
|
|
|
|
async fn poll_push_streams(streams: &mut PushStreams) -> Option<WsPushEvent> {
|
|
loop {
|
|
let room_ids: Vec<Uuid> = streams.keys().copied().collect();
|
|
for room_id in room_ids {
|
|
if let Some((msg_stream, chunk_stream)) = streams.get_mut(&room_id) {
|
|
tokio::select! {
|
|
result = msg_stream.next() => {
|
|
match result {
|
|
Some(Ok(event)) => {
|
|
if let Some(reactions) = event.reactions.clone() {
|
|
return Some(WsPushEvent::ReactionUpdated {
|
|
room_id: event.room_id,
|
|
message_id: event.id,
|
|
reactions,
|
|
});
|
|
}
|
|
return Some(WsPushEvent::RoomMessage { room_id, event });
|
|
}
|
|
Some(Err(_)) => {
|
|
streams.remove(&room_id);
|
|
}
|
|
None => {
|
|
streams.remove(&room_id);
|
|
}
|
|
}
|
|
}
|
|
result = chunk_stream.next() => {
|
|
match result {
|
|
Some(Ok(chunk)) => {
|
|
return Some(WsPushEvent::AiStreamChunk { room_id, chunk });
|
|
}
|
|
Some(Err(_)) | None => {
|
|
streams.remove(&room_id);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if streams.is_empty() {
|
|
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
|
return None;
|
|
}
|
|
tokio::task::yield_now().await;
|
|
}
|
|
}
|
|
|
|
fn extract_user_id_from_token(token: &str) -> Option<Uuid> {
|
|
if token.len() < 64 {
|
|
return None;
|
|
}
|
|
let token_data = base64_decode(token)?;
|
|
if token_data.len() < 16 {
|
|
return None;
|
|
}
|
|
let bytes: [u8; 16] = token_data[..16].try_into().ok()?;
|
|
Some(Uuid::from_bytes(bytes))
|
|
}
|
|
|
|
/// Decodes standard-alphabet base64 (`A-Z a-z 0-9 + /`) into raw bytes.
///
/// The decoder is deliberately lenient:
/// * `=`, `\n`, `\r`, and ` ` are skipped wherever they appear;
/// * trailing bits that do not complete a whole byte are discarded.
///
/// Returns `None` as soon as any other character is encountered.
fn base64_decode(input: &str) -> Option<Vec<u8>> {
    /// Maps one base64 alphabet character to its 6-bit value.
    fn sextet(c: u8) -> Option<u32> {
        let value = match c {
            b'A'..=b'Z' => c - b'A',
            b'a'..=b'z' => c - b'a' + 26,
            b'0'..=b'9' => c - b'0' + 52,
            b'+' => 62,
            b'/' => 63,
            _ => return None,
        };
        Some(u32::from(value))
    }

    let mut out = Vec::with_capacity(input.len() * 3 / 4);
    // Only the low `pending` bits of `acc` are meaningful; higher bits are
    // ignored by the `as u8` truncation below.
    let mut acc: u32 = 0;
    let mut pending: u32 = 0;

    let significant = input
        .bytes()
        .filter(|&c| !matches!(c, b'=' | b'\n' | b'\r' | b' '));
    for c in significant {
        acc = (acc << 6) | sextet(c)?;
        pending += 6;
        if pending >= 8 {
            pending -= 8;
            out.push((acc >> pending) as u8);
        }
    }
    Some(out)
}
|