fix(room): add access check to ws_universal subscribe and fix stream handling
Some checks are pending
CI / Rust Lint & Check (push) Waiting to run
CI / Rust Tests (push) Waiting to run
CI / Frontend Lint & Type Check (push) Waiting to run
CI / Frontend Build (push) Blocked by required conditions

- Add `check_room_access` before `manager.subscribe()` in ws_universal
  to prevent unauthorized room subscription (security fix)
- Fix busy-wait in `poll_push_streams`: sleep 50ms when streams are
  empty, yield only when there are active streams
- Re-subscribe dead rooms after stream errors so events are no longer
  silently dropped until the user manually reconnects
- Fix streaming message placeholder that misused the first line of the
  chunk content as the room_id: use chunk.room_id from the backend instead
- Show toast error on history load failures instead of silent fallback
This commit is contained in:
ZhenYi 2026-04-17 17:15:34 +08:00
parent afb1bbeb71
commit 5256e72be7
3 changed files with 64 additions and 32 deletions

View File

@@ -15,6 +15,10 @@ impl WsRequestHandler {
Self { service, user_id } Self { service, user_id }
} }
pub fn service(&self) -> &Arc<AppService> {
&self.service
}
pub async fn handle(&self, request: WsRequest) -> WsResponse { pub async fn handle(&self, request: WsRequest) -> WsResponse {
let request_id = request.request_id; let request_id = request.request_id;
let action_str = request.action.to_string(); let action_str = request.action.to_string();

View File

@@ -10,6 +10,7 @@ use uuid::Uuid;
use crate::error::ApiError; use crate::error::ApiError;
use queue::{ReactionGroup, RoomMessageEvent, RoomMessageStreamChunkEvent}; use queue::{ReactionGroup, RoomMessageEvent, RoomMessageStreamChunkEvent};
use room::connection::RoomConnectionManager;
use service::AppService; use service::AppService;
use super::ws::validate_origin; use super::ws::validate_origin;
@@ -195,7 +196,7 @@ pub async fn ws_universal(
let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await; let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await;
break; break;
} }
push_event = poll_push_streams(&mut push_streams) => { push_event = poll_push_streams(&mut push_streams, &manager, user_id) => {
match push_event { match push_event {
Some(WsPushEvent::RoomMessage { room_id, event }) => { Some(WsPushEvent::RoomMessage { room_id, event }) => {
let payload = serde_json::json!({ let payload = serde_json::json!({
@@ -294,6 +295,12 @@ pub async fn ws_universal(
match request.action { match request.action {
WsAction::SubscribeRoom => { WsAction::SubscribeRoom => {
if let Some(room_id) = request.params().room_id { if let Some(room_id) = request.params().room_id {
// Verify user has access to this room before subscribing
if let Err(e) = handler.service().room.check_room_access(room_id, user_id).await {
let _ = session.text(serde_json::to_string(&WsResponse::error_response(
request.request_id, &action_str, 403, "access_denied", &format!("{}", e)
)).unwrap_or_default()).await;
} else {
match manager.subscribe(room_id, user_id).await { match manager.subscribe(room_id, user_id).await {
Ok(rx) => { Ok(rx) => {
let stream_rx = manager.subscribe_room_stream(room_id).await; let stream_rx = manager.subscribe_room_stream(room_id).await;
@@ -308,10 +315,11 @@
} }
Err(e) => { Err(e) => {
let _ = session.text(serde_json::to_string(&WsResponse::error_response( let _ = session.text(serde_json::to_string(&WsResponse::error_response(
request.request_id, &action_str, 403, "subscribe_failed", &format!("{}", e) request.request_id, &action_str, 500, "subscribe_failed", &format!("{}", e)
)).unwrap_or_default()).await; )).unwrap_or_default()).await;
} }
} }
}
} else { } else {
let _ = session.text(serde_json::to_string(&WsResponse::error_response( let _ = session.text(serde_json::to_string(&WsResponse::error_response(
request.request_id, &action_str, 400, "bad_request", "room_id required" request.request_id, &action_str, 400, "bad_request", "room_id required"
@@ -361,9 +369,15 @@ pub async fn ws_universal(
Ok(response) Ok(response)
} }
async fn poll_push_streams(streams: &mut PushStreams) -> Option<WsPushEvent> { async fn poll_push_streams(
streams: &mut PushStreams,
manager: &Arc<RoomConnectionManager>,
user_id: Uuid,
) -> Option<WsPushEvent> {
loop { loop {
let room_ids: Vec<Uuid> = streams.keys().copied().collect(); let room_ids: Vec<Uuid> = streams.keys().copied().collect();
let mut dead_rooms: Vec<Uuid> = Vec::new();
for room_id in room_ids { for room_id in room_ids {
if let Some((msg_stream, chunk_stream)) = streams.get_mut(&room_id) { if let Some((msg_stream, chunk_stream)) = streams.get_mut(&room_id) {
tokio::select! { tokio::select! {
@@ -380,12 +394,7 @@ async fn poll_push_streams(streams: &mut PushStreams) -> Option<WsPushEvent> {
return Some(WsPushEvent::RoomMessage { room_id, event }); return Some(WsPushEvent::RoomMessage { room_id, event });
} }
Some(Err(_)) | None => { Some(Err(_)) | None => {
// Stream closed/error — remove and re-subscribe to avoid dead_rooms.push(room_id);
// spinning on a closed stream. The manager keeps the
// broadcast sender alive so re-subscribing gets the latest
// receiver. Multiple rapid errors are handled by the
// manager's existing retry/cleanup logic.
streams.remove(&room_id);
} }
} }
} }
@@ -395,19 +404,34 @@ async fn poll_push_streams(streams: &mut PushStreams) -> Option<WsPushEvent> {
return Some(WsPushEvent::AiStreamChunk { room_id, chunk }); return Some(WsPushEvent::AiStreamChunk { room_id, chunk });
} }
Some(Err(_)) | None => { Some(Err(_)) | None => {
streams.remove(&room_id); dead_rooms.push(room_id);
} }
} }
} }
} }
} }
} }
// Re-subscribe dead rooms so we don't permanently lose events
for room_id in dead_rooms {
if streams.remove(&room_id).is_some() {
if let Ok(rx) = manager.subscribe(room_id, user_id).await {
let stream_rx = manager.subscribe_room_stream(room_id).await;
streams.insert(room_id, (
BroadcastStream::new(rx),
BroadcastStream::new(stream_rx),
));
}
}
}
if streams.is_empty() { if streams.is_empty() {
tokio::time::sleep(std::time::Duration::from_millis(50)).await; // Yield so the caller can drop us before the next iteration
return None; tokio::time::sleep(Duration::from_millis(50)).await;
} } else {
tokio::task::yield_now().await; tokio::task::yield_now().await;
} }
}
} }
fn extract_user_id_from_token(token: &str) -> Option<Uuid> { fn extract_user_id_from_token(token: &str) -> Option<Uuid> {

View File

@@ -1,4 +1,5 @@
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'; import { useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { toast } from 'sonner';
import { import {
type AiStreamChunkPayload, type AiStreamChunkPayload,
type RoomMessagePayload, type RoomMessagePayload,
@@ -176,7 +177,7 @@ export function useRoomWs({
membersRef.current = members; membersRef.current = members;
/** Ref for AI streaming RAF batch */ /** Ref for AI streaming RAF batch */
const streamingBatchRef = useRef<Map<string, { content: string; done: boolean }>>(new Map()); const streamingBatchRef = useRef<Map<string, { content: string; done: boolean; room_id: string }>>(new Map());
const streamingRafRef = useRef<number | null>(null); const streamingRafRef = useRef<number | null>(null);
/** Flush streaming batch to state */ /** Flush streaming batch to state */
@@ -193,7 +194,7 @@ export function useRoomWs({
if (idx === -1) { if (idx === -1) {
const placeholder: UiMessage = { const placeholder: UiMessage = {
id: messageId, id: messageId,
room_id: chunk.content.split('\n')[0]?.slice(0, 100) || '', room_id: chunk.room_id ?? next.find(() => true)?.room_id ?? '',
sender_type: 'ai', sender_type: 'ai',
content: chunk.done ? chunk.content : '', content: chunk.done ? chunk.content : '',
content_type: 'text', content_type: 'text',
@@ -322,6 +323,7 @@ export function useRoomWs({
streamingBatchRef.current.set(chunk.message_id, { streamingBatchRef.current.set(chunk.message_id, {
content: chunk.content, content: chunk.content,
done: chunk.done, done: chunk.done,
room_id: chunk.room_id,
}); });
if (streamingRafRef.current == null) { if (streamingRafRef.current == null) {
@@ -442,6 +444,7 @@ export function useRoomWs({
}) })
.catch(() => { .catch(() => {
if (activeRoomIdRef.current !== roomId) return; if (activeRoomIdRef.current !== roomId) return;
toast.error('Failed to load message history');
setIsHistoryLoaded(true); setIsHistoryLoaded(true);
}); });
} }
@@ -525,7 +528,8 @@ export function useRoomWs({
return [...newOnes, ...prev]; return [...newOnes, ...prev];
}); });
} catch { } catch {
// Non-critical — just stop trying // Non-critical — show toast so user knows the load failed
toast.error('Failed to load more messages');
setIsHistoryLoaded(true); setIsHistoryLoaded(true);
} finally { } finally {
setIsLoadingMore(false); setIsLoadingMore(false);