use std::sync::Arc; use uuid::Uuid; use tokio::sync::{broadcast, RwLock}; use super::{RoomConnectionManager, RoomMessageStreamChunkEvent, BROADCAST_CAPACITY, REPLAY_BUFFER_SIZE}; impl RoomConnectionManager { pub async fn register_stream_channel(&self, message_id: Uuid, room_id: Uuid, display_name: Option) -> broadcast::Receiver> { let mut map = self.stream_inner.write().await; if let Some(tx) = map.get(&message_id) { return tx.subscribe(); } let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY); map.insert(message_id, tx.clone()); // Also register in active_streams for late-joiner catchup let meta = super::ActiveStreamMeta { message_id, room_id, display_name: display_name.clone(), chunks: Arc::new(RwLock::new(Vec::new())), }; drop(map); let mut active = self.active_streams.write().await; active.insert(message_id, meta); rx } pub async fn subscribe_stream(&self, message_id: Uuid) -> Option>> { let map = self.stream_inner.read().await; map.get(&message_id).map(|tx| tx.subscribe()) } pub async fn subscribe_room_stream(&self, room_id: Uuid) -> broadcast::Receiver> { // New subscriber: replay active streams in this room so they catch up, // then subscribe to the room's channel. let (_existing_tx, new_rx) = { let mut map = self.room_stream_inner.write().await; match map.get_mut(&room_id) { Some((existing_tx, count)) => { *count += 1; let tx_clone = existing_tx.clone(); let rx_clone = existing_tx.subscribe(); drop(map); // Replay buffered chunks to existing channel so all subscribers receive them. let active = self.active_streams.read().await; for (&msg_id, meta) in active.iter() { if meta.room_id != room_id { continue; } let start_event = Arc::new(RoomMessageStreamChunkEvent { message_id: msg_id, room_id, seq: 0, content: String::new(), done: false, error: None, display_name: meta.display_name.clone(), chunk_type: None, }); let _ = tx_clone.send(Arc::clone(&start_event)); let chunks = meta.chunks.read().await; for chunk in chunks.iter() { let _ = tx_clone.send(Arc::new(chunk.clone())); } } (tx_clone, rx_clone) } None => { let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY); map.insert(room_id, (tx.clone(), 1)); drop(map); // Replay buffered chunks to new channel. let active = self.active_streams.read().await; for (&msg_id, meta) in active.iter() { if meta.room_id != room_id { continue; } let start_event = Arc::new(RoomMessageStreamChunkEvent { message_id: msg_id, room_id, seq: 0, content: String::new(), done: false, error: None, display_name: meta.display_name.clone(), chunk_type: None, }); let _ = tx.send(Arc::clone(&start_event)); let chunks = meta.chunks.read().await; for chunk in chunks.iter() { let _ = tx.send(Arc::new(chunk.clone())); } } (tx, rx) } } }; new_rx } pub async fn broadcast_stream_chunk(&self, event: RoomMessageStreamChunkEvent) { { let mut activity = self.room_last_activity.write().await; activity.insert(event.room_id, std::time::Instant::now()); } let is_start = event.seq == 0 && !event.done; let is_final_chunk = event.done; // Buffer chunk in active_streams for late-joiner replay. if !is_final_chunk || is_start { let mut active = self.active_streams.write().await; if let Some(meta) = active.get_mut(&event.message_id) { let mut chunks = meta.chunks.write().await; chunks.push(event.clone()); // Evict oldest if buffer exceeds REPLAY_BUFFER_SIZE. if chunks.len() > REPLAY_BUFFER_SIZE { chunks.remove(0); } } drop(active); // Also update room_to_streams reverse index. if is_start { let mut r2s = self.room_to_streams.write().await; r2s.entry(event.room_id).or_default().insert(event.message_id); } } let event = Arc::new(event); let map = self.stream_inner.read().await; if let Some(tx) = map.get(&event.message_id) { let _ = tx.send(Arc::clone(&event)); } drop(map); let map = self.room_stream_inner.read().await; if let Some((tx, _)) = map.get(&event.room_id) { let _ = tx.send(Arc::clone(&event)); } if is_final_chunk { drop(map); // Cleanup active_streams entry. let mut active = self.active_streams.write().await; if active.remove(&event.message_id).is_some() { let mut r2s = self.room_to_streams.write().await; if let Some(ids) = r2s.get_mut(&event.room_id) { ids.remove(&event.message_id); if ids.is_empty() { r2s.remove(&event.room_id); } } } drop(active); // Cleanup room_stream_inner subscriber count. let mut map = self.room_stream_inner.write().await; if let Some((_, count)) = map.get_mut(&event.room_id) { if *count > 0 { *count -= 1; } if *count == 0 { map.remove(&event.room_id); } } } } pub async fn close_stream_channel(&self, message_id: Uuid) { let mut map = self.stream_inner.write().await; map.remove(&message_id); drop(map); // Remove from active_streams (cleanup on stream end). let mut active = self.active_streams.write().await; if let Some(meta) = active.remove(&message_id) { let mut r2s = self.room_to_streams.write().await; if let Some(ids) = r2s.get_mut(&meta.room_id) { ids.remove(&message_id); if ids.is_empty() { r2s.remove(&meta.room_id); } } } } pub async fn register_stream_cancel(&self, room_id: Uuid) -> Arc { let cancel = Arc::new(std::sync::atomic::AtomicBool::new(false)); let mut map = self.stream_cancel_tokens.write().await; map.insert(room_id, cancel.clone()); cancel } pub async fn cancel_ai_stream(&self, room_id: Uuid) -> bool { let map = self.stream_cancel_tokens.read().await; if let Some(cancel) = map.get(&room_id) { cancel.store(true, std::sync::atomic::Ordering::Release); true } else { false } } pub async fn unregister_stream_cancel(&self, room_id: Uuid) { let mut map = self.stream_cancel_tokens.write().await; map.remove(&room_id); } }