207 lines
8.1 KiB
Rust
207 lines
8.1 KiB
Rust
use std::sync::Arc;
|
|
use uuid::Uuid;
|
|
use tokio::sync::{broadcast, RwLock};
|
|
|
|
use super::{RoomConnectionManager, RoomMessageStreamChunkEvent, BROADCAST_CAPACITY, REPLAY_BUFFER_SIZE};
|
|
|
|
impl RoomConnectionManager {
|
|
pub async fn register_stream_channel(&self, message_id: Uuid, room_id: Uuid, display_name: Option<String>) -> broadcast::Receiver<Arc<RoomMessageStreamChunkEvent>> {
|
|
let mut map = self.stream_inner.write().await;
|
|
if let Some(tx) = map.get(&message_id) {
|
|
return tx.subscribe();
|
|
}
|
|
let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
|
|
map.insert(message_id, tx.clone());
|
|
|
|
// Also register in active_streams for late-joiner catchup
|
|
let meta = super::ActiveStreamMeta {
|
|
message_id,
|
|
room_id,
|
|
display_name: display_name.clone(),
|
|
chunks: Arc::new(RwLock::new(Vec::new())),
|
|
};
|
|
drop(map);
|
|
let mut active = self.active_streams.write().await;
|
|
active.insert(message_id, meta);
|
|
|
|
rx
|
|
}
|
|
|
|
pub async fn subscribe_stream(&self, message_id: Uuid) -> Option<broadcast::Receiver<Arc<RoomMessageStreamChunkEvent>>> {
|
|
let map = self.stream_inner.read().await;
|
|
map.get(&message_id).map(|tx| tx.subscribe())
|
|
}
|
|
|
|
pub async fn subscribe_room_stream(&self, room_id: Uuid) -> broadcast::Receiver<Arc<RoomMessageStreamChunkEvent>> {
|
|
// New subscriber: replay active streams in this room so they catch up,
|
|
// then subscribe to the room's channel.
|
|
|
|
let (_existing_tx, new_rx) = {
|
|
let mut map = self.room_stream_inner.write().await;
|
|
match map.get_mut(&room_id) {
|
|
Some((existing_tx, count)) => {
|
|
*count += 1;
|
|
let tx_clone = existing_tx.clone();
|
|
let rx_clone = existing_tx.subscribe();
|
|
drop(map);
|
|
// Replay buffered chunks to existing channel so all subscribers receive them.
|
|
let active = self.active_streams.read().await;
|
|
for (&msg_id, meta) in active.iter() {
|
|
if meta.room_id != room_id { continue; }
|
|
let start_event = Arc::new(RoomMessageStreamChunkEvent {
|
|
message_id: msg_id,
|
|
room_id,
|
|
seq: 0,
|
|
content: String::new(),
|
|
done: false,
|
|
error: None,
|
|
display_name: meta.display_name.clone(),
|
|
chunk_type: None,
|
|
});
|
|
let _ = tx_clone.send(Arc::clone(&start_event));
|
|
let chunks = meta.chunks.read().await;
|
|
for chunk in chunks.iter() {
|
|
let _ = tx_clone.send(Arc::new(chunk.clone()));
|
|
}
|
|
}
|
|
(tx_clone, rx_clone)
|
|
}
|
|
None => {
|
|
let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
|
|
map.insert(room_id, (tx.clone(), 1));
|
|
drop(map);
|
|
// Replay buffered chunks to new channel.
|
|
let active = self.active_streams.read().await;
|
|
for (&msg_id, meta) in active.iter() {
|
|
if meta.room_id != room_id { continue; }
|
|
let start_event = Arc::new(RoomMessageStreamChunkEvent {
|
|
message_id: msg_id,
|
|
room_id,
|
|
seq: 0,
|
|
content: String::new(),
|
|
done: false,
|
|
error: None,
|
|
display_name: meta.display_name.clone(),
|
|
chunk_type: None,
|
|
});
|
|
let _ = tx.send(Arc::clone(&start_event));
|
|
let chunks = meta.chunks.read().await;
|
|
for chunk in chunks.iter() {
|
|
let _ = tx.send(Arc::new(chunk.clone()));
|
|
}
|
|
}
|
|
(tx, rx)
|
|
}
|
|
}
|
|
};
|
|
new_rx
|
|
}
|
|
|
|
pub async fn broadcast_stream_chunk(&self, event: RoomMessageStreamChunkEvent) {
|
|
{
|
|
let mut activity = self.room_last_activity.write().await;
|
|
activity.insert(event.room_id, std::time::Instant::now());
|
|
}
|
|
|
|
let is_start = event.seq == 0 && !event.done;
|
|
let is_final_chunk = event.done;
|
|
|
|
// Buffer chunk in active_streams for late-joiner replay.
|
|
if !is_final_chunk || is_start {
|
|
let mut active = self.active_streams.write().await;
|
|
if let Some(meta) = active.get_mut(&event.message_id) {
|
|
let mut chunks = meta.chunks.write().await;
|
|
chunks.push(event.clone());
|
|
// Evict oldest if buffer exceeds REPLAY_BUFFER_SIZE.
|
|
if chunks.len() > REPLAY_BUFFER_SIZE {
|
|
chunks.remove(0);
|
|
}
|
|
}
|
|
drop(active);
|
|
// Also update room_to_streams reverse index.
|
|
if is_start {
|
|
let mut r2s = self.room_to_streams.write().await;
|
|
r2s.entry(event.room_id).or_default().insert(event.message_id);
|
|
}
|
|
}
|
|
|
|
let event = Arc::new(event);
|
|
|
|
let map = self.stream_inner.read().await;
|
|
if let Some(tx) = map.get(&event.message_id) {
|
|
let _ = tx.send(Arc::clone(&event));
|
|
}
|
|
|
|
drop(map);
|
|
let map = self.room_stream_inner.read().await;
|
|
if let Some((tx, _)) = map.get(&event.room_id) {
|
|
let _ = tx.send(Arc::clone(&event));
|
|
}
|
|
|
|
if is_final_chunk {
|
|
drop(map);
|
|
// Cleanup active_streams entry.
|
|
let mut active = self.active_streams.write().await;
|
|
if active.remove(&event.message_id).is_some() {
|
|
let mut r2s = self.room_to_streams.write().await;
|
|
if let Some(ids) = r2s.get_mut(&event.room_id) {
|
|
ids.remove(&event.message_id);
|
|
if ids.is_empty() {
|
|
r2s.remove(&event.room_id);
|
|
}
|
|
}
|
|
}
|
|
drop(active);
|
|
// Cleanup room_stream_inner subscriber count.
|
|
let mut map = self.room_stream_inner.write().await;
|
|
if let Some((_, count)) = map.get_mut(&event.room_id) {
|
|
if *count > 0 {
|
|
*count -= 1;
|
|
}
|
|
if *count == 0 {
|
|
map.remove(&event.room_id);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn close_stream_channel(&self, message_id: Uuid) {
|
|
let mut map = self.stream_inner.write().await;
|
|
map.remove(&message_id);
|
|
drop(map);
|
|
// Remove from active_streams (cleanup on stream end).
|
|
let mut active = self.active_streams.write().await;
|
|
if let Some(meta) = active.remove(&message_id) {
|
|
let mut r2s = self.room_to_streams.write().await;
|
|
if let Some(ids) = r2s.get_mut(&meta.room_id) {
|
|
ids.remove(&message_id);
|
|
if ids.is_empty() {
|
|
r2s.remove(&meta.room_id);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn register_stream_cancel(&self, room_id: Uuid) -> Arc<std::sync::atomic::AtomicBool> {
|
|
let cancel = Arc::new(std::sync::atomic::AtomicBool::new(false));
|
|
let mut map = self.stream_cancel_tokens.write().await;
|
|
map.insert(room_id, cancel.clone());
|
|
cancel
|
|
}
|
|
|
|
pub async fn cancel_ai_stream(&self, room_id: Uuid) -> bool {
|
|
let map = self.stream_cancel_tokens.read().await;
|
|
if let Some(cancel) = map.get(&room_id) {
|
|
cancel.store(true, std::sync::atomic::Ordering::Release);
|
|
true
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
|
|
pub async fn unregister_stream_cancel(&self, room_id: Uuid) {
|
|
let mut map = self.stream_cancel_tokens.write().await;
|
|
map.remove(&room_id);
|
|
}
|
|
}
|