fix(room): add access check to ws_universal subscribe and fix stream handling
Some checks are pending
CI / Rust Lint & Check (push) Waiting to run
CI / Rust Tests (push) Waiting to run
CI / Frontend Lint & Type Check (push) Waiting to run
CI / Frontend Build (push) Blocked by required conditions

- Add `check_room_access` before `manager.subscribe()` in ws_universal
  to prevent unauthorized room subscription (security fix)
- Fix busy-wait in `poll_push_streams`: sleep 50ms when streams are
  empty, yield only when there are active streams
- Re-subscribe dead rooms after stream errors so events are not
  permanently lost until manual reconnect
- Fix streaming message placeholder using fake content as room_id:
  use chunk.room_id from backend instead
- Show toast error on history load failures instead of silent fallback
This commit is contained in:
ZhenYi 2026-04-17 17:15:34 +08:00
parent afb1bbeb71
commit 5256e72be7
3 changed files with 64 additions and 32 deletions

View File

@ -15,6 +15,10 @@ impl WsRequestHandler {
Self { service, user_id } Self { service, user_id }
} }
pub fn service(&self) -> &Arc<AppService> {
&self.service
}
pub async fn handle(&self, request: WsRequest) -> WsResponse { pub async fn handle(&self, request: WsRequest) -> WsResponse {
let request_id = request.request_id; let request_id = request.request_id;
let action_str = request.action.to_string(); let action_str = request.action.to_string();

View File

@ -10,6 +10,7 @@ use uuid::Uuid;
use crate::error::ApiError; use crate::error::ApiError;
use queue::{ReactionGroup, RoomMessageEvent, RoomMessageStreamChunkEvent}; use queue::{ReactionGroup, RoomMessageEvent, RoomMessageStreamChunkEvent};
use room::connection::RoomConnectionManager;
use service::AppService; use service::AppService;
use super::ws::validate_origin; use super::ws::validate_origin;
@ -195,7 +196,7 @@ pub async fn ws_universal(
let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await; let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await;
break; break;
} }
push_event = poll_push_streams(&mut push_streams) => { push_event = poll_push_streams(&mut push_streams, &manager, user_id) => {
match push_event { match push_event {
Some(WsPushEvent::RoomMessage { room_id, event }) => { Some(WsPushEvent::RoomMessage { room_id, event }) => {
let payload = serde_json::json!({ let payload = serde_json::json!({
@ -294,22 +295,29 @@ pub async fn ws_universal(
match request.action { match request.action {
WsAction::SubscribeRoom => { WsAction::SubscribeRoom => {
if let Some(room_id) = request.params().room_id { if let Some(room_id) = request.params().room_id {
match manager.subscribe(room_id, user_id).await { // Verify user has access to this room before subscribing
Ok(rx) => { if let Err(e) = handler.service().room.check_room_access(room_id, user_id).await {
let stream_rx = manager.subscribe_room_stream(room_id).await; let _ = session.text(serde_json::to_string(&WsResponse::error_response(
push_streams.insert(room_id, ( request.request_id, &action_str, 403, "access_denied", &format!("{}", e)
BroadcastStream::new(rx), )).unwrap_or_default()).await;
BroadcastStream::new(stream_rx), } else {
)); match manager.subscribe(room_id, user_id).await {
let _ = session.text(serde_json::to_string(&WsResponse::success( Ok(rx) => {
request.request_id, &action_str, let stream_rx = manager.subscribe_room_stream(room_id).await;
WsResponseData::subscribed(Some(room_id), None) push_streams.insert(room_id, (
)).unwrap_or_default()).await; BroadcastStream::new(rx),
} BroadcastStream::new(stream_rx),
Err(e) => { ));
let _ = session.text(serde_json::to_string(&WsResponse::error_response( let _ = session.text(serde_json::to_string(&WsResponse::success(
request.request_id, &action_str, 403, "subscribe_failed", &format!("{}", e) request.request_id, &action_str,
)).unwrap_or_default()).await; WsResponseData::subscribed(Some(room_id), None)
)).unwrap_or_default()).await;
}
Err(e) => {
let _ = session.text(serde_json::to_string(&WsResponse::error_response(
request.request_id, &action_str, 500, "subscribe_failed", &format!("{}", e)
)).unwrap_or_default()).await;
}
} }
} }
} else { } else {
@ -361,9 +369,15 @@ pub async fn ws_universal(
Ok(response) Ok(response)
} }
async fn poll_push_streams(streams: &mut PushStreams) -> Option<WsPushEvent> { async fn poll_push_streams(
streams: &mut PushStreams,
manager: &Arc<RoomConnectionManager>,
user_id: Uuid,
) -> Option<WsPushEvent> {
loop { loop {
let room_ids: Vec<Uuid> = streams.keys().copied().collect(); let room_ids: Vec<Uuid> = streams.keys().copied().collect();
let mut dead_rooms: Vec<Uuid> = Vec::new();
for room_id in room_ids { for room_id in room_ids {
if let Some((msg_stream, chunk_stream)) = streams.get_mut(&room_id) { if let Some((msg_stream, chunk_stream)) = streams.get_mut(&room_id) {
tokio::select! { tokio::select! {
@ -380,12 +394,7 @@ async fn poll_push_streams(streams: &mut PushStreams) -> Option<WsPushEvent> {
return Some(WsPushEvent::RoomMessage { room_id, event }); return Some(WsPushEvent::RoomMessage { room_id, event });
} }
Some(Err(_)) | None => { Some(Err(_)) | None => {
// Stream closed/error — remove and re-subscribe to avoid dead_rooms.push(room_id);
// spinning on a closed stream. The manager keeps the
// broadcast sender alive so re-subscribing gets the latest
// receiver. Multiple rapid errors are handled by the
// manager's existing retry/cleanup logic.
streams.remove(&room_id);
} }
} }
} }
@ -395,18 +404,33 @@ async fn poll_push_streams(streams: &mut PushStreams) -> Option<WsPushEvent> {
return Some(WsPushEvent::AiStreamChunk { room_id, chunk }); return Some(WsPushEvent::AiStreamChunk { room_id, chunk });
} }
Some(Err(_)) | None => { Some(Err(_)) | None => {
streams.remove(&room_id); dead_rooms.push(room_id);
} }
} }
} }
} }
} }
} }
if streams.is_empty() {
tokio::time::sleep(std::time::Duration::from_millis(50)).await; // Re-subscribe dead rooms so we don't permanently lose events
return None; for room_id in dead_rooms {
if streams.remove(&room_id).is_some() {
if let Ok(rx) = manager.subscribe(room_id, user_id).await {
let stream_rx = manager.subscribe_room_stream(room_id).await;
streams.insert(room_id, (
BroadcastStream::new(rx),
BroadcastStream::new(stream_rx),
));
}
}
}
if streams.is_empty() {
// Yield so the caller can drop us before the next iteration
tokio::time::sleep(Duration::from_millis(50)).await;
} else {
tokio::task::yield_now().await;
} }
tokio::task::yield_now().await;
} }
} }

View File

@ -1,4 +1,5 @@
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'; import { useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { toast } from 'sonner';
import { import {
type AiStreamChunkPayload, type AiStreamChunkPayload,
type RoomMessagePayload, type RoomMessagePayload,
@ -176,7 +177,7 @@ export function useRoomWs({
membersRef.current = members; membersRef.current = members;
/** Ref for AI streaming RAF batch */ /** Ref for AI streaming RAF batch */
const streamingBatchRef = useRef<Map<string, { content: string; done: boolean }>>(new Map()); const streamingBatchRef = useRef<Map<string, { content: string; done: boolean; room_id: string }>>(new Map());
const streamingRafRef = useRef<number | null>(null); const streamingRafRef = useRef<number | null>(null);
/** Flush streaming batch to state */ /** Flush streaming batch to state */
@ -193,7 +194,7 @@ export function useRoomWs({
if (idx === -1) { if (idx === -1) {
const placeholder: UiMessage = { const placeholder: UiMessage = {
id: messageId, id: messageId,
room_id: chunk.content.split('\n')[0]?.slice(0, 100) || '', room_id: chunk.room_id ?? next.find(() => true)?.room_id ?? '',
sender_type: 'ai', sender_type: 'ai',
content: chunk.done ? chunk.content : '', content: chunk.done ? chunk.content : '',
content_type: 'text', content_type: 'text',
@ -322,6 +323,7 @@ export function useRoomWs({
streamingBatchRef.current.set(chunk.message_id, { streamingBatchRef.current.set(chunk.message_id, {
content: chunk.content, content: chunk.content,
done: chunk.done, done: chunk.done,
room_id: chunk.room_id,
}); });
if (streamingRafRef.current == null) { if (streamingRafRef.current == null) {
@ -442,6 +444,7 @@ export function useRoomWs({
}) })
.catch(() => { .catch(() => {
if (activeRoomIdRef.current !== roomId) return; if (activeRoomIdRef.current !== roomId) return;
toast.error('Failed to load message history');
setIsHistoryLoaded(true); setIsHistoryLoaded(true);
}); });
} }
@ -525,7 +528,8 @@ export function useRoomWs({
return [...newOnes, ...prev]; return [...newOnes, ...prev];
}); });
} catch { } catch {
// Non-critical — just stop trying // Non-critical — show toast so user knows the load failed
toast.error('Failed to load more messages');
setIsHistoryLoaded(true); setIsHistoryLoaded(true);
} finally { } finally {
setIsLoadingMore(false); setIsLoadingMore(false);