gitdataai/libs/api/room/ws.rs
2026-04-15 09:08:09 +08:00

706 lines
26 KiB
Rust

use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant};
use actix_web::{HttpMessage, HttpRequest, HttpResponse, web};
use actix_ws::Message as WsMessage;
use serde::Serialize;
use uuid::Uuid;
use queue::{ProjectRoomEvent, RoomMessageEvent, RoomMessageStreamChunkEvent};
use service::AppService;
use session::Session;
/// Largest inbound text frame, in bytes, that `ws_room` accepts before answering
/// with a `message_too_long` error.
const MAX_TEXT_MESSAGE_LEN: usize = 64 * 1024;
/// Number of inbound frames `check_ws_rate_limit` tolerates per `RATE_LIMIT_WINDOW`.
const MAX_MESSAGES_PER_SECOND: u32 = 10;
/// Period of the heartbeat timer; a ping is sent to the client on each tick.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30);
/// The connection is closed when no ping/pong has been received from the client
/// for longer than this.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(60);
/// The connection is closed when its `last_activity` timestamp is older than this.
const MAX_IDLE_TIMEOUT: Duration = Duration::from_secs(300);
/// Length of the window over which `check_ws_rate_limit` counts inbound frames.
const RATE_LIMIT_WINDOW: Duration = Duration::from_secs(1);
/// Authenticate a WebSocket upgrade request.
///
/// A `token=<value>` query parameter takes precedence: when present it is handed to
/// `service.ws_token.validate_token` and the outcome is final (there is no fallback
/// to the session if the token is rejected). Without a token the user id is read
/// from the session stored in the request extensions.
///
/// # Errors
///
/// Returns `AppError::Unauthorized` (wrapped in `ApiError`) when the token is
/// rejected or when the session carries no user. Both failures bump the
/// `ws_auth_failures` counter.
async fn authenticate_ws_request(
    service: &AppService,
    req: &HttpRequest,
) -> Result<Uuid, actix_web::Error> {
    // Take everything after the `token=` prefix. Splitting the pair on '=' would
    // silently truncate tokens that themselves contain '=' (e.g. base64 padding).
    let query_token = req
        .uri()
        .query()
        .and_then(|query| query.split('&').find_map(|pair| pair.strip_prefix("token=")));

    // Try query parameter token first (the original note says these tokens are
    // one-time use via Redis; that is not visible from this function).
    if let Some(token) = query_token {
        return match service.ws_token.validate_token(token).await {
            Ok(uid) => {
                slog::debug!(service.logs, "WS: token auth successful for uid={}", uid);
                Ok(uid)
            }
            Err(_) => {
                slog::warn!(service.logs, "WS: token auth failed");
                service
                    .room
                    .room_manager
                    .metrics
                    .ws_auth_failures
                    .increment(1);
                Err(crate::error::ApiError(service::error::AppError::Unauthorized).into())
            }
        };
    }

    // Fall back to session-based auth
    let session = Session::get_session(&mut req.extensions_mut());
    match session.user() {
        Some(uid) => Ok(uid),
        None => {
            service
                .room
                .room_manager
                .metrics
                .ws_auth_failures
                .increment(1);
            Err(crate::error::ApiError(service::error::AppError::Unauthorized).into())
        }
    }
}
/// Record one inbound frame against the per-connection rate limit.
///
/// `message_count` and `rate_window_start` are the caller-owned counters for the
/// connection. When the current window is older than `RATE_LIMIT_WINDOW` both are
/// reset before the frame is counted.
///
/// Returns `true` when the frame pushes the count above `MAX_MESSAGES_PER_SECOND`
/// (a warning is logged and `ws_rate_limit_hits` is incremented), `false` otherwise.
async fn check_ws_rate_limit(
    log: &slog::Logger,
    manager: &Arc<room::connection::RoomConnectionManager>,
    message_count: &mut u32,
    rate_window_start: &mut Instant,
) -> bool {
    let window_expired = rate_window_start.elapsed() > RATE_LIMIT_WINDOW;
    if window_expired {
        // Start a fresh window before counting the current frame.
        *rate_window_start = Instant::now();
        *message_count = 0;
    }

    *message_count += 1;

    let within_limit = *message_count <= MAX_MESSAGES_PER_SECOND;
    if within_limit {
        return false;
    }

    slog::warn!(log, "WS rate limit exceeded");
    manager.metrics.ws_rate_limit_hits.increment(1);
    true
}
/// Event body carried inside a `WsOutEvent`.
///
/// Serialized with an inline `type` field whose value is the variant name in
/// snake_case (per the serde attributes below).
#[derive(Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum WsEventPayload {
/// A room message, built from a `RoomMessageEvent`.
RoomMessage(RoomMessagePayload),
/// A project-level event, built from a `ProjectRoomEvent`.
ProjectEvent(ProjectEventPayload),
/// One chunk of a streamed message, built from a `RoomMessageStreamChunkEvent`.
AiStreamChunk(AiStreamChunkPayload),
}
/// Wire representation of a `RoomMessageStreamChunkEvent`; every field is copied
/// verbatim from the event.
// NOTE(review): field meanings below are inferred from their names — confirm
// against the producer of `RoomMessageStreamChunkEvent`.
#[derive(Clone, Serialize)]
pub struct AiStreamChunkPayload {
// Presumably the id of the message being streamed.
pub message_id: Uuid,
pub room_id: Uuid,
// Presumably the text carried by this chunk.
pub content: String,
// Presumably set on the final chunk of a stream.
pub done: bool,
pub error: Option<String>,
}
impl From<RoomMessageStreamChunkEvent> for AiStreamChunkPayload {
fn from(e: RoomMessageStreamChunkEvent) -> Self {
Self {
message_id: e.message_id,
room_id: e.room_id,
content: e.content,
done: e.done,
error: e.error,
}
}
}
impl From<Arc<RoomMessageStreamChunkEvent>> for AiStreamChunkPayload {
fn from(e: Arc<RoomMessageStreamChunkEvent>) -> Self {
AiStreamChunkPayload::from((&*e).clone())
}
}
/// Wire representation of a `RoomMessageEvent`; every field is copied verbatim
/// from the event.
// NOTE(review): field meanings below are inferred from their names — confirm
// against the producer of `RoomMessageEvent`.
#[derive(Clone, Serialize)]
pub struct RoomMessagePayload {
pub id: Uuid,
pub room_id: Uuid,
pub sender_type: String,
pub sender_id: Option<Uuid>,
pub thread_id: Option<Uuid>,
pub content: String,
pub content_type: String,
// UTC timestamp taken from the event's `send_at`.
pub send_at: chrono::DateTime<chrono::Utc>,
// Presumably a per-room sequence number — TODO confirm.
pub seq: i64,
pub display_name: Option<String>,
}
impl From<RoomMessageEvent> for RoomMessagePayload {
fn from(e: RoomMessageEvent) -> Self {
Self {
id: e.id,
room_id: e.room_id,
sender_type: e.sender_type,
sender_id: e.sender_id,
thread_id: e.thread_id,
content: e.content,
content_type: e.content_type,
send_at: e.send_at,
seq: e.seq,
display_name: e.display_name,
}
}
}
impl From<Arc<RoomMessageEvent>> for RoomMessagePayload {
fn from(e: Arc<RoomMessageEvent>) -> Self {
RoomMessagePayload::from((&*e).clone())
}
}
impl From<&RoomMessageEvent> for RoomMessagePayload {
fn from(e: &RoomMessageEvent) -> Self {
Self {
id: e.id,
room_id: e.room_id,
sender_type: e.sender_type.clone(),
sender_id: e.sender_id,
thread_id: e.thread_id,
content: e.content.clone(),
content_type: e.content_type.clone(),
send_at: e.send_at,
seq: e.seq,
display_name: e.display_name.clone(),
}
}
}
/// Wire representation of a `ProjectRoomEvent`; every field is copied verbatim
/// from the event.
// NOTE(review): which optional ids are populated for which `event_type` is not
// visible from this file — confirm against the event producer.
#[derive(Clone, Serialize)]
pub struct ProjectEventPayload {
pub event_type: String,
pub project_id: Uuid,
pub room_id: Option<Uuid>,
pub category_id: Option<Uuid>,
pub message_id: Option<Uuid>,
pub seq: Option<i64>,
// UTC timestamp taken from the event's `timestamp`.
pub timestamp: chrono::DateTime<chrono::Utc>,
}
impl From<ProjectRoomEvent> for ProjectEventPayload {
fn from(e: ProjectRoomEvent) -> Self {
Self {
event_type: e.event_type,
project_id: e.project_id,
room_id: e.room_id,
category_id: e.category_id,
message_id: e.message_id,
seq: e.seq,
timestamp: e.timestamp,
}
}
}
impl From<Arc<ProjectRoomEvent>> for ProjectEventPayload {
fn from(e: Arc<ProjectRoomEvent>) -> Self {
ProjectEventPayload::from((&*e).clone())
}
}
impl From<&ProjectRoomEvent> for ProjectEventPayload {
fn from(e: &ProjectRoomEvent) -> Self {
Self {
event_type: e.event_type.clone(),
project_id: e.project_id,
room_id: e.room_id,
category_id: e.category_id,
message_id: e.message_id,
seq: e.seq,
timestamp: e.timestamp,
}
}
}
/// Envelope for every JSON frame built from a subscription event in this file.
///
/// Every field is optional and is omitted from the JSON when it is `None`.
#[derive(Clone, Serialize)]
pub struct WsOutEvent {
// Set by `ws_room` to the subscribed room; `ws_project` copies it from the event.
#[serde(skip_serializing_if = "Option::is_none")]
pub room_id: Option<Uuid>,
// Set by `ws_project` to the subscribed project; `ws_room` leaves it `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub project_id: Option<Uuid>,
#[serde(skip_serializing_if = "Option::is_none")]
pub event: Option<WsEventPayload>,
// Always `None` in the handlers of this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
}
/// Decide whether the request's `Origin` header is acceptable for a WebSocket upgrade.
///
/// The allow-list is read once from the comma-separated `WS_ALLOWED_ORIGINS`
/// environment variable; when the variable is unset a set of localhost origins is
/// used. Entries are trimmed, a trailing `/` is dropped (an `Origin` header never
/// carries one) and empty entries are discarded.
///
/// Returns `true` when:
/// * the request has no `Origin` header (unchanged from the previous behaviour —
///   presumably to admit non-browser clients), or
/// * the origin equals an allowed entry, either verbatim or after removing a
///   numeric `:port` suffix.
///
/// Returns `false` for a non-UTF-8 header and for every other origin.
///
/// Security: origins are compared for equality only. The earlier prefix match let
/// `http://localhost.evil.example` pass for the entry `http://localhost`, and an
/// empty entry (empty variable or trailing comma) matched every origin.
pub(crate) fn validate_origin(req: &HttpRequest) -> bool {
    static ALLOWED_ORIGINS: LazyLock<Vec<String>> = LazyLock::new(|| {
        std::env::var("WS_ALLOWED_ORIGINS")
            .map(|v| {
                v.split(',')
                    .map(|entry| entry.trim().trim_end_matches('/'))
                    .filter(|entry| !entry.is_empty())
                    .map(String::from)
                    .collect()
            })
            .unwrap_or_else(|_| {
                vec![
                    "http://localhost".to_string(),
                    "https://localhost".to_string(),
                    "http://127.0.0.1".to_string(),
                    "https://127.0.0.1".to_string(),
                    "ws://localhost".to_string(),
                    "wss://localhost".to_string(),
                    "ws://127.0.0.1".to_string(),
                    "wss://127.0.0.1".to_string(),
                ]
            })
    });

    let Some(origin) = req.headers().get("origin") else {
        return true;
    };
    let Ok(origin_str) = origin.to_str() else {
        return false;
    };

    // Strip port: http://localhost:5173 -> http://localhost, http://[::1]:5173 -> http://[::1].
    // The suffix must be non-empty and all digits, so the ':' after the scheme and
    // the ':' inside a bare IPv6 literal are left alone.
    let origin_without_port = match origin_str.rsplit_once(':') {
        Some((scheme_host, port))
            if !port.is_empty() && port.bytes().all(|b| b.is_ascii_digit()) =>
        {
            scheme_host
        }
        _ => origin_str,
    };

    ALLOWED_ORIGINS.iter().any(|allowed| {
        let allowed = allowed.as_str();
        origin_str == allowed || origin_without_port == allowed
    })
}
pub async fn ws_room(
room_id: web::Path<Uuid>,
service: web::Data<AppService>,
req: HttpRequest,
stream: web::Payload,
) -> Result<HttpResponse, actix_web::Error> {
let room_id = room_id.into_inner();
// Authenticate: try query parameter token first, then session
let user_id = authenticate_ws_request(&service, &req).await?;
let origin_val = req
.headers()
.get("origin")
.and_then(|v| v.to_str().ok())
.unwrap_or("(none)");
slog::debug!(
service.logs,
"WS room connection attempt user_id={} room_id={} origin={}",
user_id,
room_id,
origin_val
);
if !validate_origin(&req) {
slog::warn!(
service.logs,
"WS room: origin rejected user_id={} room_id={} origin={}",
user_id,
room_id,
origin_val
);
service
.room
.room_manager
.metrics
.ws_auth_failures
.increment(1);
return Err(crate::error::ApiError(service::error::AppError::BadRequest(
"Invalid origin".into(),
))
.into());
}
if let Err(e) = service.room.check_room_access(room_id, user_id).await {
slog::warn!(
service.logs,
"WS room: access denied for user_id={} room_id={} error={}",
user_id,
room_id,
e
);
return Err(crate::error::ApiError::from(e).into());
}
let manager = service.room.room_manager.clone();
manager.metrics.ws_connections_active.increment(1.0);
manager.metrics.ws_connections_total.increment(1);
manager.metrics.incr_room_connections(room_id).await;
let (response, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?;
actix::spawn(async move {
let mut receiver = match manager.subscribe(room_id, user_id).await {
Ok(r) => r,
Err(e) => {
slog::error!(service.logs, "Failed to subscribe to room: {}", e);
return;
}
};
let mut stream_rx = manager.subscribe_room_stream(room_id).await;
let mut shutdown_rx = manager.subscribe_shutdown();
let mut last_heartbeat = Instant::now();
let mut last_activity = Instant::now();
let mut heartbeat_interval = tokio::time::interval(HEARTBEAT_INTERVAL);
heartbeat_interval.tick().await;
let mut message_count: u32 = 0;
let mut rate_window_start = Instant::now();
loop {
tokio::select! {
_ = heartbeat_interval.tick() => {
if last_heartbeat.elapsed() > HEARTBEAT_TIMEOUT {
slog::warn!(service.logs, "WS room {} heartbeat timeout for user {}", room_id, user_id);
manager.metrics.ws_heartbeat_timeout_total.increment(1);
let _ = session.close(Some(actix_ws::CloseCode::Policy.into())).await;
break;
}
if last_activity.elapsed() > MAX_IDLE_TIMEOUT {
slog::info!(service.logs, "WS room {} idle timeout for user {}", room_id, user_id);
manager.metrics.ws_idle_timeout_total.increment(1);
let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await;
break;
}
if session.ping(b"").await.is_err() {
break;
}
manager.metrics.ws_heartbeat_sent_total.increment(1);
}
_ = shutdown_rx.recv() => {
slog::info!(service.logs, "WS room {} shutdown", room_id);
let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await;
break;
}
msg = msg_stream.recv() => {
match msg {
Some(Ok(WsMessage::Ping(bytes))) => {
if session.pong(&bytes).await.is_err() {
break;
}
last_heartbeat = Instant::now();
}
Some(Ok(WsMessage::Pong(_))) => {
last_heartbeat = Instant::now();
}
#[allow(unused_assignments)]
Some(Ok(WsMessage::Text(text))) => {
if last_activity.elapsed() > MAX_IDLE_TIMEOUT {
slog::info!(service.logs, "WS room {} idle timeout for user {}", room_id, user_id);
manager.metrics.ws_idle_timeout_total.increment(1);
let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await;
break;
}
last_activity = Instant::now();
if check_ws_rate_limit(&service.logs, &manager, &mut message_count, &mut rate_window_start).await {
let _ = session.text(serde_json::json!({
"type": "error",
"error": "rate_limit_exceeded",
"max_per_second": MAX_MESSAGES_PER_SECOND
}).to_string()).await;
break;
}
if text.len() > MAX_TEXT_MESSAGE_LEN {
slog::warn!(service.logs, "WS room {} message too long from user {}: {} bytes", room_id, user_id, text.len());
let _ = session.text(serde_json::json!({
"type": "error",
"error": "message_too_long",
"max_bytes": MAX_TEXT_MESSAGE_LEN
}).to_string()).await;
break;
}
slog::warn!(service.logs, "WS room {} unexpected text message from user {} ({} bytes) — WS is push-only, use REST to send messages", room_id, user_id, text.len());
let _ = session.text(serde_json::json!({
"type": "error",
"error": "ws_push_only",
"message": "WebSocket is for receiving messages only. Use the REST API to send messages."
}).to_string()).await;
break;
}
Some(Ok(WsMessage::Binary(_))) => {
if check_ws_rate_limit(&service.logs, &manager, &mut message_count, &mut rate_window_start).await {
break;
}
slog::warn!(service.logs, "WS room {} unexpected binary from user {}", room_id, user_id);
break;
}
Some(Ok(WsMessage::Close(reason))) => {
let _ = session.close(reason).await;
break;
}
Some(Ok(_)) => {}
Some(Err(e)) => {
slog::warn!(service.logs, "WS room error: {}", e);
break;
}
None => break,
}
}
event = receiver.recv() => {
match event {
Ok(event) => {
let payload = WsOutEvent {
room_id: Some(room_id),
project_id: None,
event: Some(WsEventPayload::RoomMessage(event.into())),
error: None,
};
match serde_json::to_string(&payload) {
Ok(json) => {
if session.text(json).await.is_err() {
break;
}
}
Err(e) => {
slog::error!(service.logs, "WS serialize error: {}", e);
break;
}
}
}
Err(_) => break,
}
}
chunk_event = stream_rx.recv() => {
match chunk_event {
Ok(chunk) => {
let payload = WsOutEvent {
room_id: Some(room_id),
project_id: None,
event: Some(WsEventPayload::AiStreamChunk(chunk.into())),
error: None,
};
match serde_json::to_string(&payload) {
Ok(json) => {
if session.text(json).await.is_err() {
break;
}
}
Err(e) => {
slog::error!(service.logs, "WS streaming serialize error: {}", e);
}
}
}
Err(_) => {}
}
}
}
}
manager.unsubscribe(room_id, user_id).await;
manager.metrics.ws_connections_active.decrement(1.0);
manager.metrics.ws_disconnections_total.increment(1);
manager.metrics.dec_room_connections(room_id).await;
});
Ok(response)
}
pub async fn ws_project(
project_id: web::Path<Uuid>,
service: web::Data<AppService>,
req: HttpRequest,
stream: web::Payload,
) -> Result<HttpResponse, actix_web::Error> {
let project_id = project_id.into_inner();
// Authenticate: try query parameter token first, then session
let user_id = authenticate_ws_request(&service, &req).await?;
if !validate_origin(&req) {
service
.room
.room_manager
.metrics
.ws_auth_failures
.increment(1);
return Err(crate::error::ApiError(service::error::AppError::BadRequest(
"Invalid origin".into(),
))
.into());
}
if let Err(e) = service.room.check_project_member(project_id, user_id).await {
service
.room
.room_manager
.metrics
.ws_auth_failures
.increment(1);
return Err(crate::error::ApiError::from(e).into());
}
if let Err(e) = service
.room
.room_manager
.check_project_connection_rate(project_id, user_id)
.await
{
service
.room
.room_manager
.metrics
.ws_rate_limit_hits
.increment(1);
return Err(crate::error::ApiError::from(e).into());
}
let manager = service.room.room_manager.clone();
manager.metrics.ws_connections_active.increment(1.0);
manager.metrics.ws_connections_total.increment(1);
let (response, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?;
actix::spawn(async move {
let mut receiver = match manager.subscribe_project(project_id, user_id).await {
Ok(r) => r,
Err(e) => {
slog::error!(service.logs, "Failed to subscribe to project: {}", e);
return;
}
};
let mut shutdown_rx = manager.subscribe_shutdown();
let mut last_heartbeat = Instant::now();
let mut last_activity = Instant::now();
let mut heartbeat_interval = tokio::time::interval(HEARTBEAT_INTERVAL);
heartbeat_interval.tick().await;
let mut message_count: u32 = 0;
let mut rate_window_start = Instant::now();
loop {
tokio::select! {
_ = heartbeat_interval.tick() => {
if last_heartbeat.elapsed() > HEARTBEAT_TIMEOUT {
slog::warn!(service.logs, "WS project {} heartbeat timeout for user {}", project_id, user_id);
manager.metrics.ws_heartbeat_timeout_total.increment(1);
let _ = session.close(Some(actix_ws::CloseCode::Policy.into())).await;
break;
}
if last_activity.elapsed() > MAX_IDLE_TIMEOUT {
slog::info!(service.logs, "WS project {} idle timeout for user {}", project_id, user_id);
manager.metrics.ws_idle_timeout_total.increment(1);
let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await;
break;
}
if session.ping(b"").await.is_err() {
break;
}
manager.metrics.ws_heartbeat_sent_total.increment(1);
}
_ = shutdown_rx.recv() => {
slog::info!(service.logs, "WS project {} shutdown", project_id);
let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await;
break;
}
msg = msg_stream.recv() => {
match msg {
Some(Ok(WsMessage::Ping(bytes))) => {
if session.pong(&bytes).await.is_err() {
break;
}
last_heartbeat = Instant::now();
}
Some(Ok(WsMessage::Pong(_))) => {
last_heartbeat = Instant::now();
}
#[allow(unused_assignments)]
Some(Ok(WsMessage::Text(text))) => {
if last_activity.elapsed() > MAX_IDLE_TIMEOUT {
slog::info!(service.logs, "WS project {} idle timeout for user {}", project_id, user_id);
manager.metrics.ws_idle_timeout_total.increment(1);
let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await;
break;
}
last_activity = Instant::now();
slog::warn!(service.logs, "WS project {} unexpected text from user {} ({} bytes) — WS is push-only", project_id, user_id, text.len());
let _ = session.text(serde_json::json!({
"type": "error",
"error": "ws_push_only",
"message": "WebSocket is for receiving events only."
}).to_string()).await;
break;
}
Some(Ok(WsMessage::Binary(_))) => {
if check_ws_rate_limit(&service.logs, &manager, &mut message_count, &mut rate_window_start).await {
slog::warn!(service.logs, "WS project {} rate limit exceeded for user {}", project_id, user_id);
let _ = session.text(serde_json::json!({
"type": "error",
"error": "rate_limit_exceeded",
"max_per_second": MAX_MESSAGES_PER_SECOND
}).to_string()).await;
break;
}
slog::warn!(service.logs, "WS project {} unexpected binary from user {}", project_id, user_id);
break;
}
Some(Ok(WsMessage::Close(reason))) => {
let _ = session.close(reason).await;
break;
}
Some(Ok(_)) => {}
Some(Err(e)) => {
slog::warn!(service.logs, "WS project error: {}", e);
break;
}
None => break,
}
}
event = receiver.recv() => {
match event {
Ok(event) => {
let payload = WsOutEvent {
room_id: event.room_id,
project_id: Some(project_id),
event: Some(WsEventPayload::ProjectEvent(event.into())),
error: None,
};
match serde_json::to_string(&payload) {
Ok(json) => {
if session.text(json).await.is_err() {
break;
}
}
Err(e) => {
slog::error!(service.logs, "WS serialize error: {}", e);
break;
}
}
}
Err(_) => break,
}
}
}
}
manager.unsubscribe_project(project_id, user_id).await;
manager.metrics.ws_connections_active.decrement(1.0);
manager.metrics.ws_disconnections_total.increment(1);
});
Ok(response)
}