gitdataai/libs/api/room/ws_universal.rs
2026-04-14 19:02:01 +08:00

443 lines
19 KiB
Rust

use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use actix_web::{HttpRequest, HttpResponse, web};
use actix_ws::Message as WsMessage;
use tokio_stream::StreamExt;
use tokio_stream::wrappers::BroadcastStream;
use uuid::Uuid;
use crate::error::ApiError;
use queue::{ReactionGroup, RoomMessageEvent, RoomMessageStreamChunkEvent};
use service::AppService;
use super::ws::validate_origin;
use super::ws_handler::WsRequestHandler;
use super::ws_types::{WsAction, WsRequest, WsResponse, WsResponseData};
/// Largest inbound text frame, in bytes; longer frames are answered with a
/// `message_too_long` error and otherwise ignored.
const MAX_TEXT_MESSAGE_LEN: usize = 64 << 10;
/// Text frames allowed per `RATE_LIMIT_WINDOW`; extra frames get a
/// `rate_limit_exceeded` error and are dropped.
const MAX_MESSAGES_PER_SECOND: u32 = 10;
/// How often the server pings the client.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30);
/// The socket is closed (policy violation) when no ping/pong was received for
/// longer than this, as checked on each heartbeat tick.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(2 * 30);
/// The socket is closed (normal closure) when no text frame was received for
/// longer than this.
const MAX_IDLE_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Length of the fixed window used by the inbound rate limiter.
const RATE_LIMIT_WINDOW: Duration = Duration::from_millis(1_000);
/// A server-to-client push originating from one of the rooms a connection is
/// subscribed to.
#[derive(Debug, Clone)]
pub enum WsPushEvent {
    /// A message event from `room_id`; sent to the client as `room.message`.
    RoomMessage {
        room_id: Uuid,
        event: Arc<RoomMessageEvent>,
    },
    /// The reaction groups of `message_id` changed; sent as
    /// `room.reaction_updated`.
    ReactionUpdated {
        room_id: Uuid,
        message_id: Uuid,
        reactions: Vec<ReactionGroup>,
    },
    /// An incremental stream chunk for a message; sent as `ai.stream_chunk`.
    AiStreamChunk {
        room_id: Uuid,
        chunk: Arc<RoomMessageStreamChunkEvent>,
    },
}
/// Live push sources of one connection, keyed by room id.
///
/// Each value is `(message stream, stream-chunk stream)`: the first carries
/// room message events (including reaction updates), the second carries the
/// room's incremental stream chunks.
type PushStreams = HashMap<
    Uuid,
    (
        BroadcastStream<Arc<RoomMessageEvent>>,
        BroadcastStream<Arc<RoomMessageStreamChunkEvent>>,
    ),
>;
pub async fn ws_universal(
service: web::Data<AppService>,
req: HttpRequest,
stream: web::Payload,
) -> Result<HttpResponse, actix_web::Error> {
let origin_val = req
.headers()
.get("origin")
.and_then(|v| v.to_str().ok())
.unwrap_or("(none)");
if !validate_origin(&req) {
slog::warn!(
service.logs,
"WS universal: origin rejected origin={}",
origin_val
);
return Err(ApiError(service::error::AppError::BadRequest(
"Invalid origin".into(),
))
.into());
}
// Validate token BEFORE actix_ws::handle() so we can return a proper HTTP
// error if validation fails. Returning an HTTP error after handle() has been
// called (even if the handler returns an error) sends a 200 OK on what the
// browser expects to be a 101 Switching Protocols response — causing
// immediate close with readyState=3.
let user_id = if let Some(token) = req.uri().query().and_then(|q| {
q.split('&')
.find(|p| p.starts_with("token="))
.and_then(|p| p.split('=').nth(1))
}) {
slog::info!(
service.logs,
"WS universal: validating token token={} origin={}",
token,
origin_val
);
match service.ws_token.validate_token(token).await {
Ok(uid) => {
slog::info!(
service.logs,
"WS universal: token auth successful uid={} origin={}",
uid,
origin_val
);
uid
}
Err(e) => {
slog::warn!(
service.logs,
"WS universal: token auth failed: {:?} token={}",
e,
token
);
service
.room
.room_manager
.metrics
.ws_auth_failures
.increment(1);
return Err(ApiError(service::error::AppError::Unauthorized).into());
}
}
} else {
let auth_header = req
.headers()
.get("Authorization")
.and_then(|v| v.to_str().ok());
let token = match auth_header {
Some(h) if h.starts_with("Bearer ") => &h[7..],
_ => {
service
.room
.room_manager
.metrics
.ws_auth_failures
.increment(1);
return Err(ApiError(service::error::AppError::Unauthorized).into());
}
};
match extract_user_id_from_token(token) {
Some(id) => id,
None => {
service
.room
.room_manager
.metrics
.ws_auth_failures
.increment(1);
return Err(ApiError(service::error::AppError::Unauthorized).into());
}
}
};
slog::debug!(
service.logs,
"WS universal connection established user_id={} origin={}",
user_id,
origin_val
);
let service = service.get_ref().clone();
let manager = service.room.room_manager.clone();
manager.metrics.ws_connections_active.increment(1.0);
manager.metrics.ws_connections_total.increment(1);
let logs = service.logs.clone();
let (response, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?;
actix::spawn(async move {
let handler = WsRequestHandler::new(Arc::new(service), user_id);
let mut push_streams: PushStreams = HashMap::new();
let mut shutdown_rx = manager.subscribe_shutdown();
let mut last_heartbeat = Instant::now();
let mut last_activity = Instant::now();
let mut heartbeat_interval = tokio::time::interval(HEARTBEAT_INTERVAL);
heartbeat_interval.tick().await;
let mut message_count: u32 = 0;
let mut rate_window_start = Instant::now();
loop {
tokio::select! {
_ = heartbeat_interval.tick() => {
if last_heartbeat.elapsed() > HEARTBEAT_TIMEOUT {
slog::warn!(logs, "WS universal heartbeat timeout for user {}", user_id);
manager.metrics.ws_heartbeat_timeout_total.increment(1);
let _ = session.close(Some(actix_ws::CloseCode::Policy.into())).await;
break;
}
if last_activity.elapsed() > MAX_IDLE_TIMEOUT {
slog::info!(logs, "WS universal idle timeout for user {}", user_id);
manager.metrics.ws_idle_timeout_total.increment(1);
let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await;
break;
}
if session.ping(b"").await.is_err() {
break;
}
manager.metrics.ws_heartbeat_sent_total.increment(1);
}
_ = shutdown_rx.recv() => {
slog::info!(logs, "WS universal shutdown");
let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await;
break;
}
push_event = poll_push_streams(&mut push_streams) => {
match push_event {
Some(WsPushEvent::RoomMessage { room_id, event }) => {
let payload = serde_json::json!({
"type": "event",
"event": "room.message",
"room_id": room_id,
"data": {
"id": event.id,
"room_id": event.room_id,
"sender_type": event.sender_type,
"sender_id": event.sender_id,
"thread_id": event.thread_id,
"content": event.content,
"content_type": event.content_type,
"send_at": event.send_at,
"seq": event.seq,
"display_name": event.display_name,
},
});
if session.text(payload.to_string()).await.is_err() {
break;
}
}
Some(WsPushEvent::ReactionUpdated { room_id, message_id, reactions }) => {
let payload = serde_json::json!({
"type": "event",
"event": "room.reaction_updated",
"room_id": room_id,
"data": {
"message_id": message_id,
"reactions": reactions,
},
});
if session.text(payload.to_string()).await.is_err() {
break;
}
}
Some(WsPushEvent::AiStreamChunk { room_id, chunk }) => {
let payload = serde_json::json!({
"type": "event",
"event": "ai.stream_chunk",
"room_id": room_id,
"data": {
"message_id": chunk.message_id,
"room_id": chunk.room_id,
"content": chunk.content,
"done": chunk.done,
"error": chunk.error,
},
});
if session.text(payload.to_string()).await.is_err() {
break;
}
}
None => {
}
}
}
msg = msg_stream.recv() => {
match msg {
Some(Ok(WsMessage::Ping(bytes))) => {
if session.pong(&bytes).await.is_err() { break; }
last_heartbeat = Instant::now();
}
Some(Ok(WsMessage::Pong(_))) => { last_heartbeat = Instant::now(); }
Some(Ok(WsMessage::Text(text))) => {
if last_activity.elapsed() > MAX_IDLE_TIMEOUT {
slog::info!(logs, "WS universal idle timeout for user {}", user_id);
manager.metrics.ws_idle_timeout_total.increment(1);
let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await;
break;
}
last_activity = Instant::now();
if rate_window_start.elapsed() > RATE_LIMIT_WINDOW {
message_count = 0;
rate_window_start = Instant::now();
}
message_count += 1;
if message_count > MAX_MESSAGES_PER_SECOND {
slog::warn!(logs, "WS universal rate limit exceeded for user {}", user_id);
manager.metrics.ws_rate_limit_hits.increment(1);
let _ = session.text(serde_json::json!({"type":"error","error":"rate_limit_exceeded"}).to_string()).await;
continue;
}
if text.len() > MAX_TEXT_MESSAGE_LEN {
slog::warn!(logs, "WS universal message too long from user {}: {} bytes", user_id, text.len());
let _ = session.text(serde_json::json!({"type":"error","error":"message_too_long"}).to_string()).await;
continue;
}
match serde_json::from_str::<WsRequest>(&text) {
Ok(request) => {
let action_str = request.action.to_string();
match request.action {
WsAction::SubscribeRoom => {
if let Some(room_id) = request.params().room_id {
match manager.subscribe(room_id, user_id).await {
Ok(rx) => {
let stream_rx = manager.subscribe_room_stream(room_id).await;
push_streams.insert(room_id, (
BroadcastStream::new(rx),
BroadcastStream::new(stream_rx),
));
let _ = session.text(serde_json::to_string(&WsResponse::success(
request.request_id, &action_str,
WsResponseData::subscribed(Some(room_id), None)
)).unwrap_or_default()).await;
}
Err(e) => {
let _ = session.text(serde_json::to_string(&WsResponse::error_response(
request.request_id, &action_str, 403, "subscribe_failed", &format!("{}", e)
)).unwrap_or_default()).await;
}
}
} else {
let _ = session.text(serde_json::to_string(&WsResponse::error_response(
request.request_id, &action_str, 400, "bad_request", "room_id required"
)).unwrap_or_default()).await;
}
}
WsAction::UnsubscribeRoom => {
if let Some(room_id) = request.params().room_id {
manager.unsubscribe(room_id, user_id).await;
push_streams.remove(&room_id);
}
let _ = session.text(serde_json::to_string(&WsResponse::success(
request.request_id, &action_str, WsResponseData::bool(true)
)).unwrap_or_default()).await;
}
_ => {
let resp = handler.handle(request).await;
let _ = session.text(serde_json::to_string(&resp).unwrap_or_default()).await;
}
}
}
Err(e) => {
slog::warn!(logs, "WS universal parse error from user {}: {}", user_id, e);
let _ = session.text(serde_json::json!({"type":"error","error":"parse_error"}).to_string()).await;
}
}
}
Some(Ok(WsMessage::Binary(_))) => { break; }
Some(Ok(WsMessage::Continuation(_))) => {}
Some(Ok(WsMessage::Nop)) => {}
Some(Ok(WsMessage::Close(reason))) => { let _ = session.close(reason).await; break; }
Some(Err(e)) => { slog::warn!(logs, "WS error: {}", e); break; }
None => break,
}
}
}
}
// Clean up subscriptions on disconnect
for room_id in push_streams.keys() {
manager.unsubscribe(*room_id, user_id).await;
}
manager.metrics.ws_connections_active.decrement(1.0);
manager.metrics.ws_disconnections_total.increment(1);
});
Ok(response)
}
async fn poll_push_streams(streams: &mut PushStreams) -> Option<WsPushEvent> {
loop {
let room_ids: Vec<Uuid> = streams.keys().copied().collect();
for room_id in room_ids {
if let Some((msg_stream, chunk_stream)) = streams.get_mut(&room_id) {
tokio::select! {
result = msg_stream.next() => {
match result {
Some(Ok(event)) => {
if let Some(reactions) = event.reactions.clone() {
return Some(WsPushEvent::ReactionUpdated {
room_id: event.room_id,
message_id: event.id,
reactions,
});
}
return Some(WsPushEvent::RoomMessage { room_id, event });
}
Some(Err(_)) => {
streams.remove(&room_id);
}
None => {
streams.remove(&room_id);
}
}
}
result = chunk_stream.next() => {
match result {
Some(Ok(chunk)) => {
return Some(WsPushEvent::AiStreamChunk { room_id, chunk });
}
Some(Err(_)) | None => {
streams.remove(&room_id);
}
}
}
}
}
}
if streams.is_empty() {
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
return None;
}
tokio::task::yield_now().await;
}
}
/// Reads a user id out of an opaque token string.
///
/// The token must be at least 64 characters long and decode as standard
/// base64. The first 16 decoded bytes are returned as the UUID. Returns `None`
/// if any of these checks fails.
///
/// SECURITY: this only *decodes* the token. It performs no signature, MAC, or
/// expiry check, so anyone can craft a token carrying any UUID. It must not be
/// used as an authentication step by itself.
fn extract_user_id_from_token(token: &str) -> Option<Uuid> {
    const MIN_TOKEN_LEN: usize = 64;
    if token.len() >= MIN_TOKEN_LEN {
        let decoded = base64_decode(token)?;
        let head: [u8; 16] = decoded.get(..16)?.try_into().ok()?;
        Some(Uuid::from_bytes(head))
    } else {
        None
    }
}
/// Decodes standard-alphabet (`A-Z a-z 0-9 + /`) base64.
///
/// `=`, `\n`, `\r` and spaces are skipped wherever they occur. Any other byte
/// outside the alphabet (including URL-safe `-`/`_`) yields `None`. Leftover
/// bits that do not form a whole byte are discarded.
fn base64_decode(input: &str) -> Option<Vec<u8>> {
    /// Maps one alphabet character to its 6-bit value.
    fn sextet(c: u8) -> Option<u32> {
        let value = match c {
            b'A'..=b'Z' => c - b'A',
            b'a'..=b'z' => c - b'a' + 26,
            b'0'..=b'9' => c - b'0' + 52,
            b'+' => 62,
            b'/' => 63,
            _ => return None,
        };
        Some(u32::from(value))
    }

    let mut out = Vec::with_capacity(input.len() * 3 / 4);
    let mut acc: u32 = 0;
    let mut pending_bits: u32 = 0;
    let significant = input
        .bytes()
        .filter(|&c| !matches!(c, b'=' | b'\n' | b'\r' | b' '));
    for c in significant {
        // Bits above the low 14 are never read again, so losing them to the
        // shift is harmless.
        acc = (acc << 6) | sextet(c)?;
        pending_bits += 6;
        if pending_bits >= 8 {
            pending_bits -= 8;
            out.push((acc >> pending_bits) as u8);
        }
    }
    Some(out)
}