diff --git a/libs/api/room/ws_handler.rs b/libs/api/room/ws_handler.rs index 8a1ae1b..dcf9c6e 100644 --- a/libs/api/room/ws_handler.rs +++ b/libs/api/room/ws_handler.rs @@ -15,6 +15,10 @@ impl WsRequestHandler { Self { service, user_id } } + pub fn service(&self) -> &Arc { + &self.service + } + pub async fn handle(&self, request: WsRequest) -> WsResponse { let request_id = request.request_id; let action_str = request.action.to_string(); diff --git a/libs/api/room/ws_universal.rs b/libs/api/room/ws_universal.rs index 17787de..478893f 100644 --- a/libs/api/room/ws_universal.rs +++ b/libs/api/room/ws_universal.rs @@ -10,6 +10,7 @@ use uuid::Uuid; use crate::error::ApiError; use queue::{ReactionGroup, RoomMessageEvent, RoomMessageStreamChunkEvent}; +use room::connection::RoomConnectionManager; use service::AppService; use super::ws::validate_origin; @@ -195,7 +196,7 @@ pub async fn ws_universal( let _ = session.close(Some(actix_ws::CloseCode::Normal.into())).await; break; } - push_event = poll_push_streams(&mut push_streams) => { + push_event = poll_push_streams(&mut push_streams, &manager, user_id) => { match push_event { Some(WsPushEvent::RoomMessage { room_id, event }) => { let payload = serde_json::json!({ @@ -294,22 +295,29 @@ pub async fn ws_universal( match request.action { WsAction::SubscribeRoom => { if let Some(room_id) = request.params().room_id { - match manager.subscribe(room_id, user_id).await { - Ok(rx) => { - let stream_rx = manager.subscribe_room_stream(room_id).await; - push_streams.insert(room_id, ( - BroadcastStream::new(rx), - BroadcastStream::new(stream_rx), - )); - let _ = session.text(serde_json::to_string(&WsResponse::success( - request.request_id, &action_str, - WsResponseData::subscribed(Some(room_id), None) - )).unwrap_or_default()).await; - } - Err(e) => { - let _ = session.text(serde_json::to_string(&WsResponse::error_response( - request.request_id, &action_str, 403, "subscribe_failed", &format!("{}", e) - )).unwrap_or_default()).await; + // Verify user has access to this room before subscribing + if let Err(e) = handler.service().room.check_room_access(room_id, user_id).await { + let _ = session.text(serde_json::to_string(&WsResponse::error_response( + request.request_id, &action_str, 403, "access_denied", &format!("{}", e) + )).unwrap_or_default()).await; + } else { + match manager.subscribe(room_id, user_id).await { + Ok(rx) => { + let stream_rx = manager.subscribe_room_stream(room_id).await; + push_streams.insert(room_id, ( + BroadcastStream::new(rx), + BroadcastStream::new(stream_rx), + )); + let _ = session.text(serde_json::to_string(&WsResponse::success( + request.request_id, &action_str, + WsResponseData::subscribed(Some(room_id), None) + )).unwrap_or_default()).await; + } + Err(e) => { + let _ = session.text(serde_json::to_string(&WsResponse::error_response( + request.request_id, &action_str, 500, "subscribe_failed", &format!("{}", e) + )).unwrap_or_default()).await; + } } } } else { @@ -361,9 +369,15 @@ pub async fn ws_universal( Ok(response) } -async fn poll_push_streams(streams: &mut PushStreams) -> Option { +async fn poll_push_streams( + streams: &mut PushStreams, + manager: &Arc, + user_id: Uuid, +) -> Option { loop { let room_ids: Vec = streams.keys().copied().collect(); + let mut dead_rooms: Vec = Vec::new(); + for room_id in room_ids { if let Some((msg_stream, chunk_stream)) = streams.get_mut(&room_id) { tokio::select! { @@ -380,12 +394,7 @@ async fn poll_push_streams(streams: &mut PushStreams) -> Option { return Some(WsPushEvent::RoomMessage { room_id, event }); } Some(Err(_)) | None => { - // Stream closed/error — remove and re-subscribe to avoid - // spinning on a closed stream. The manager keeps the - // broadcast sender alive so re-subscribing gets the latest - // receiver. Multiple rapid errors are handled by the - // manager's existing retry/cleanup logic. - streams.remove(&room_id); + dead_rooms.push(room_id); } } } @@ -395,18 +404,33 @@ async fn poll_push_streams(streams: &mut PushStreams) -> Option { return Some(WsPushEvent::AiStreamChunk { room_id, chunk }); } Some(Err(_)) | None => { - streams.remove(&room_id); + dead_rooms.push(room_id); } } } } } } - if streams.is_empty() { - tokio::time::sleep(std::time::Duration::from_millis(50)).await; - return None; + + // Re-subscribe dead rooms so we don't permanently lose events + for room_id in dead_rooms { + if streams.remove(&room_id).is_some() { + if let Ok(rx) = manager.subscribe(room_id, user_id).await { + let stream_rx = manager.subscribe_room_stream(room_id).await; + streams.insert(room_id, ( + BroadcastStream::new(rx), + BroadcastStream::new(stream_rx), + )); + } + } + } + + if streams.is_empty() { + // Yield so the caller can drop us before the next iteration + tokio::time::sleep(Duration::from_millis(50)).await; + } else { + tokio::task::yield_now().await; } - tokio::task::yield_now().await; } } diff --git a/src/hooks/useRoomWs.ts b/src/hooks/useRoomWs.ts index 727ee83..1ad1ce0 100644 --- a/src/hooks/useRoomWs.ts +++ b/src/hooks/useRoomWs.ts @@ -1,4 +1,5 @@ import { useCallback, useEffect, useMemo, useRef, useState } from 'react'; +import { toast } from 'sonner'; import { type AiStreamChunkPayload, type RoomMessagePayload, @@ -176,7 +177,7 @@ export function useRoomWs({ membersRef.current = members; /** Ref for AI streaming RAF batch */ - const streamingBatchRef = useRef>(new Map()); + const streamingBatchRef = useRef>(new Map()); const streamingRafRef = useRef(null); /** Flush streaming batch to state */ @@ -193,7 +194,7 @@ export function useRoomWs({ if (idx === -1) { const placeholder: UiMessage = { id: messageId, - room_id: chunk.content.split('\n')[0]?.slice(0, 100) || '', + room_id: chunk.room_id ?? next.find(() => true)?.room_id ?? '', sender_type: 'ai', content: chunk.done ? chunk.content : '', content_type: 'text', @@ -322,6 +323,7 @@ export function useRoomWs({ streamingBatchRef.current.set(chunk.message_id, { content: chunk.content, done: chunk.done, + room_id: chunk.room_id, }); if (streamingRafRef.current == null) { @@ -442,6 +444,7 @@ export function useRoomWs({ }) .catch(() => { if (activeRoomIdRef.current !== roomId) return; + toast.error('Failed to load message history'); setIsHistoryLoaded(true); }); } @@ -525,7 +528,8 @@ export function useRoomWs({ return [...newOnes, ...prev]; }); } catch { - // Non-critical — just stop trying + // Non-critical — show toast so user knows the load failed + toast.error('Failed to load more messages'); setIsHistoryLoaded(true); } finally { setIsLoadingMore(false);