gitdataai/libs/room/src/connection/stream.rs

207 lines
8.1 KiB
Rust

use std::sync::Arc;
use uuid::Uuid;
use tokio::sync::{broadcast, RwLock};
use super::{RoomConnectionManager, RoomMessageStreamChunkEvent, BROADCAST_CAPACITY, REPLAY_BUFFER_SIZE};
impl RoomConnectionManager {
    /// Registers (or joins) the per-message broadcast channel for a streamed message.
    ///
    /// If a channel for `message_id` already exists, a new receiver on it is returned and
    /// nothing else changes. Otherwise a channel with capacity `BROADCAST_CAPACITY` is
    /// created, and an `ActiveStreamMeta` entry with an empty replay buffer is recorded in
    /// `active_streams`. That entry lets room subscribers that join later catch up.
    pub async fn register_stream_channel(
        &self,
        message_id: Uuid,
        room_id: Uuid,
        display_name: Option<String>,
    ) -> broadcast::Receiver<Arc<RoomMessageStreamChunkEvent>> {
        let mut map = self.stream_inner.write().await;
        if let Some(tx) = map.get(&message_id) {
            return tx.subscribe();
        }
        let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
        map.insert(message_id, tx);
        drop(map);

        // Also register in active_streams for late-joiner catch-up.
        let meta = super::ActiveStreamMeta {
            message_id,
            room_id,
            display_name,
            chunks: Arc::new(RwLock::new(Vec::new())),
        };
        self.active_streams.write().await.insert(message_id, meta);
        rx
    }

    /// Subscribes to the per-message channel of `message_id`.
    ///
    /// Returns `None` if no channel is registered for that message.
    pub async fn subscribe_stream(
        &self,
        message_id: Uuid,
    ) -> Option<broadcast::Receiver<Arc<RoomMessageStreamChunkEvent>>> {
        let map = self.stream_inner.read().await;
        map.get(&message_id).map(|tx| tx.subscribe())
    }

    /// Subscribes to the room-wide stream channel of `room_id`, creating it if needed.
    ///
    /// Increments the room's subscriber counter. It then replays every active stream in
    /// the room (a synthetic start event followed by the buffered chunks) into the room
    /// channel, so the new receiver can catch up.
    pub async fn subscribe_room_stream(
        &self,
        room_id: Uuid,
    ) -> broadcast::Receiver<Arc<RoomMessageStreamChunkEvent>> {
        let (tx, rx) = {
            let mut rooms = self.room_stream_inner.write().await;
            let (tx, count) = rooms
                .entry(room_id)
                .or_insert_with(|| (broadcast::channel(BROADCAST_CAPACITY).0, 0));
            *count += 1;
            // Subscribe before replaying so the new receiver sees the replayed events.
            (tx.clone(), tx.subscribe())
        };
        self.replay_active_streams(room_id, &tx).await;
        rx
    }

    /// Re-sends the state of every active stream in `room_id` through `tx`.
    ///
    /// For each stream, this sends a synthetic start event (seq 0, empty content, the
    /// stream's display name) and then every buffered chunk, in buffer order. Send
    /// errors (no receivers) are ignored.
    ///
    /// NOTE(review): `tx` is the shared room channel, so subscribers that were already
    /// connected receive the replayed events as well. If the real seq-0 chunk is still
    /// buffered, it arrives after the synthetic one. Clients presumably de-duplicate by
    /// (message_id, seq) — confirm.
    async fn replay_active_streams(
        &self,
        room_id: Uuid,
        tx: &broadcast::Sender<Arc<RoomMessageStreamChunkEvent>>,
    ) {
        let active = self.active_streams.read().await;
        for (&message_id, meta) in active.iter().filter(|(_, meta)| meta.room_id == room_id) {
            let _ = tx.send(Arc::new(RoomMessageStreamChunkEvent {
                message_id,
                room_id,
                seq: 0,
                content: String::new(),
                done: false,
                error: None,
                display_name: meta.display_name.clone(),
                chunk_type: None,
            }));
            let chunks = meta.chunks.read().await;
            for chunk in chunks.iter() {
                let _ = tx.send(Arc::new(chunk.clone()));
            }
        }
    }

    /// Publishes one stream chunk to the message channel and the room channel.
    ///
    /// Steps:
    /// - Records room activity.
    /// - Buffers non-final chunks for late-joiner replay. The buffer is capped at
    ///   `REPLAY_BUFFER_SIZE`, and the oldest chunks are evicted first.
    /// - Indexes the stream under its room when the chunk is a start chunk
    ///   (`seq == 0 && !done`).
    /// - Sends the chunk to both channels.
    ///
    /// On the final chunk (`done == true`), the stream's replay entry is removed. The room
    /// channel is also dropped, but only if nobody is listening to it any more.
    pub async fn broadcast_stream_chunk(&self, event: RoomMessageStreamChunkEvent) {
        self.room_last_activity
            .write()
            .await
            .insert(event.room_id, std::time::Instant::now());

        let is_final_chunk = event.done;
        let is_start = event.seq == 0 && !is_final_chunk;

        if !is_final_chunk {
            {
                // A read lock on the map suffices: the buffer has its own lock.
                let active = self.active_streams.read().await;
                if let Some(meta) = active.get(&event.message_id) {
                    let mut chunks = meta.chunks.write().await;
                    chunks.push(event.clone());
                    if chunks.len() > REPLAY_BUFFER_SIZE {
                        let excess = chunks.len() - REPLAY_BUFFER_SIZE;
                        chunks.drain(..excess);
                    }
                }
            }
            if is_start {
                self.room_to_streams
                    .write()
                    .await
                    .entry(event.room_id)
                    .or_default()
                    .insert(event.message_id);
            }
        }

        let event = Arc::new(event);
        {
            let streams = self.stream_inner.read().await;
            if let Some(tx) = streams.get(&event.message_id) {
                let _ = tx.send(Arc::clone(&event));
            }
        }
        {
            let rooms = self.room_stream_inner.read().await;
            if let Some((tx, _)) = rooms.get(&event.room_id) {
                let _ = tx.send(Arc::clone(&event));
            }
        }

        if is_final_chunk {
            self.remove_active_stream(event.message_id).await;

            // The counter in room_stream_inner tracks room *subscribers*, not streams.
            // Decrementing it here would tear down the channel (and close every live
            // receiver) after the first finished stream. Only drop the channel when it
            // truly has no receivers left.
            let mut rooms = self.room_stream_inner.write().await;
            if matches!(rooms.get(&event.room_id), Some((tx, _)) if tx.receiver_count() == 0) {
                rooms.remove(&event.room_id);
            }
        }
    }

    /// Removes `message_id` from `active_streams` and from the room -> streams index.
    ///
    /// The room entry of the index is removed once it no longer holds any stream. Does
    /// nothing if the stream is not active.
    async fn remove_active_stream(&self, message_id: Uuid) {
        let mut active = self.active_streams.write().await;
        if let Some(meta) = active.remove(&message_id) {
            let mut r2s = self.room_to_streams.write().await;
            if let Some(ids) = r2s.get_mut(&meta.room_id) {
                ids.remove(&message_id);
                if ids.is_empty() {
                    r2s.remove(&meta.room_id);
                }
            }
        }
    }

    /// Drops the per-message channel of `message_id` and forgets its replay state.
    pub async fn close_stream_channel(&self, message_id: Uuid) {
        self.stream_inner.write().await.remove(&message_id);
        self.remove_active_stream(message_id).await;
    }

    /// Creates and stores a fresh cancellation flag for AI streaming in `room_id`.
    ///
    /// NOTE(review): flags are keyed by room, so registering again for the same room
    /// replaces the previous flag. A later `unregister_stream_cancel` for that room
    /// removes whichever flag is current.
    pub async fn register_stream_cancel(&self, room_id: Uuid) -> Arc<std::sync::atomic::AtomicBool> {
        let cancel = Arc::new(std::sync::atomic::AtomicBool::new(false));
        self.stream_cancel_tokens
            .write()
            .await
            .insert(room_id, Arc::clone(&cancel));
        cancel
    }

    /// Sets the cancellation flag of `room_id`.
    ///
    /// Returns `true` if a flag was registered for the room, and `false` otherwise.
    pub async fn cancel_ai_stream(&self, room_id: Uuid) -> bool {
        let map = self.stream_cancel_tokens.read().await;
        match map.get(&room_id) {
            Some(cancel) => {
                cancel.store(true, std::sync::atomic::Ordering::Release);
                true
            }
            None => false,
        }
    }

    /// Removes the cancellation flag registered for `room_id`, if any.
    pub async fn unregister_stream_cancel(&self, room_id: Uuid) {
        self.stream_cancel_tokens.write().await.remove(&room_id);
    }
}