- make_persist_fn now accepts embed_service, collects persisted text messages - Filters non-text, non-empty, non-system/tool messages - Groups by room→project_name, batch-embeds via embed_memories_batch - Removes old per-message synchronous embed_memory call - Workers thread embed_service through to persist_fn
1260 lines
46 KiB
Rust
1260 lines
46 KiB
Rust
use std::collections::HashMap;
|
|
use std::future::Future;
|
|
use std::pin::Pin;
|
|
use std::sync::Arc;
|
|
use std::thread;
|
|
use std::time::{Duration, Instant};
|
|
use tokio::sync::{RwLock, broadcast};
|
|
use uuid::Uuid;
|
|
|
|
use db::cache::AppCache;
|
|
use db::database::AppDatabase;
|
|
use models::rooms::{MessageContentType, MessageSenderType, room_message};
|
|
use queue::types::TypingEvent;
|
|
use queue::{AgentTaskEvent, ProjectRoomEvent, RoomMessageEnvelope, RoomMessageEvent, RoomMessageStreamChunkEvent};
|
|
use sea_orm::{ColumnTrait, ConnectionTrait, EntityTrait, QueryFilter, Set};
|
|
|
|
use crate::error::RoomError;
|
|
use crate::metrics::RoomMetrics;
|
|
use crate::types::NotificationEvent;
|
|
|
|
// ---- Channel sizing ----

/// Buffer size of every data broadcast channel the manager creates (room, project,
/// user, notification, agent-task, stream and typing channels).
const BROADCAST_CAPACITY: usize = 1000;

/// Buffer size of the shutdown-signal channels (global and per room/project/user).
const SHUTDOWN_CHANNEL_CAPACITY: usize = 16;

// ---- Connection limits ----

/// Minimum time between two accepted connections for the same rate-limit key.
const CONNECTION_COOLDOWN: Duration = Duration::from_secs(30);

/// Maximum number of distinct room channels; also scales the rate-limit table cap.
const MAX_CONNECTIONS_PER_ROOM: usize = 50_000;

/// Maximum number of distinct project channels.
const MAX_CONNECTIONS_PER_PROJECT: usize = 50_000;

/// Maximum number of distinct per-user channels.
const MAX_CONNECTIONS_PER_USER: usize = 50_000;

// ---- Persistence / housekeeping ----

/// Number of envelopes written per database round-trip by the persist function.
const BATCH_SIZE: usize = 100;

/// A room with no broadcast activity for this long is torn down by idle cleanup.
const ROOM_IDLE_TIMEOUT: Duration = Duration::from_secs(30 * 60);
|
|
|
|
pub struct RoomConnectionManager {
|
|
room_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<RoomMessageEvent>>>>,
|
|
project_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<ProjectRoomEvent>>>>,
|
|
user_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<ProjectRoomEvent>>>>,
|
|
user_notification_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<NotificationEvent>>>>,
|
|
/// Broadcast channel for agent task events per project.
|
|
task_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<AgentTaskEvent>>>>,
|
|
pub metrics: Arc<RoomMetrics>,
|
|
cache: AppCache,
|
|
connection_rate: RwLock<HashMap<(Uuid, Uuid), Instant>>,
|
|
shutdown_tx: broadcast::Sender<()>,
|
|
room_shutdown_txs: RwLock<HashMap<Uuid, broadcast::Sender<()>>>,
|
|
project_shutdown_txs: RwLock<HashMap<Uuid, broadcast::Sender<()>>>,
|
|
user_shutdown_txs: RwLock<HashMap<Uuid, broadcast::Sender<()>>>,
|
|
stream_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<RoomMessageStreamChunkEvent>>>>,
|
|
room_stream_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<RoomMessageStreamChunkEvent>>>>,
|
|
typing_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<TypingEvent>>>>,
|
|
room_last_activity: RwLock<HashMap<Uuid, Instant>>,
|
|
room_subscriber_count: RwLock<HashMap<Uuid, usize>>,
|
|
project_subscriber_count: RwLock<HashMap<Uuid, usize>>,
|
|
user_subscriber_count: RwLock<HashMap<Uuid, usize>>,
|
|
}
|
|
|
|
impl RoomConnectionManager {
|
|
pub fn new(metrics: Arc<RoomMetrics>, cache: AppCache) -> Self {
|
|
let (shutdown_tx, _) = broadcast::channel(SHUTDOWN_CHANNEL_CAPACITY);
|
|
Self {
|
|
#[allow(clippy::default_constructed_unit_structs)]
|
|
room_inner: RwLock::new(HashMap::new()),
|
|
#[allow(clippy::default_constructed_unit_structs)]
|
|
project_inner: RwLock::new(HashMap::new()),
|
|
#[allow(clippy::default_constructed_unit_structs)]
|
|
user_inner: RwLock::new(HashMap::new()),
|
|
#[allow(clippy::default_constructed_unit_structs)]
|
|
user_notification_inner: RwLock::new(HashMap::new()),
|
|
#[allow(clippy::default_constructed_unit_structs)]
|
|
task_inner: RwLock::new(HashMap::new()),
|
|
metrics,
|
|
cache,
|
|
#[allow(clippy::default_constructed_unit_structs)]
|
|
connection_rate: RwLock::new(HashMap::new()),
|
|
shutdown_tx,
|
|
#[allow(clippy::default_constructed_unit_structs)]
|
|
room_shutdown_txs: RwLock::new(HashMap::new()),
|
|
#[allow(clippy::default_constructed_unit_structs)]
|
|
project_shutdown_txs: RwLock::new(HashMap::new()),
|
|
#[allow(clippy::default_constructed_unit_structs)]
|
|
user_shutdown_txs: RwLock::new(HashMap::new()),
|
|
#[allow(clippy::default_constructed_unit_structs)]
|
|
stream_inner: RwLock::new(HashMap::new()),
|
|
#[allow(clippy::default_constructed_unit_structs)]
|
|
room_stream_inner: RwLock::new(HashMap::new()),
|
|
#[allow(clippy::default_constructed_unit_structs)]
|
|
typing_inner: RwLock::new(HashMap::new()),
|
|
#[allow(clippy::default_constructed_unit_structs)]
|
|
room_last_activity: RwLock::new(HashMap::new()),
|
|
#[allow(clippy::default_constructed_unit_structs)]
|
|
room_subscriber_count: RwLock::new(HashMap::new()),
|
|
#[allow(clippy::default_constructed_unit_structs)]
|
|
project_subscriber_count: RwLock::new(HashMap::new()),
|
|
#[allow(clippy::default_constructed_unit_structs)]
|
|
user_subscriber_count: RwLock::new(HashMap::new()),
|
|
}
|
|
}
|
|
|
|
pub async fn check_room_connection_rate(
|
|
&self,
|
|
room_id: Uuid,
|
|
user_id: Uuid,
|
|
) -> Result<(), RoomError> {
|
|
let mut map = self.connection_rate.write().await;
|
|
let key = (room_id, user_id);
|
|
if let Some(last) = map.get(&key) {
|
|
if last.elapsed() < CONNECTION_COOLDOWN {
|
|
return Err(RoomError::RateLimited(format!(
|
|
"Connection cooldown active, retry in {}s",
|
|
CONNECTION_COOLDOWN.saturating_sub(last.elapsed()).as_secs()
|
|
)));
|
|
}
|
|
}
|
|
map.insert(key, Instant::now());
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn check_project_connection_rate(
|
|
&self,
|
|
project_id: Uuid,
|
|
user_id: Uuid,
|
|
) -> Result<(), RoomError> {
|
|
let mut map = self.connection_rate.write().await;
|
|
let key = (project_id, user_id);
|
|
if let Some(last) = map.get(&key) {
|
|
if last.elapsed() < CONNECTION_COOLDOWN {
|
|
return Err(RoomError::RateLimited(format!(
|
|
"Connection cooldown active, retry in {}s",
|
|
CONNECTION_COOLDOWN.saturating_sub(last.elapsed()).as_secs()
|
|
)));
|
|
}
|
|
}
|
|
map.insert(key, Instant::now());
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn check_user_connection_rate(&self, user_id: Uuid) -> Result<(), RoomError> {
|
|
let mut map = self.connection_rate.write().await;
|
|
let key = (Uuid::nil(), user_id);
|
|
if let Some(last) = map.get(&key) {
|
|
if last.elapsed() < CONNECTION_COOLDOWN {
|
|
return Err(RoomError::RateLimited(format!(
|
|
"Connection cooldown active, retry in {}s",
|
|
CONNECTION_COOLDOWN.saturating_sub(last.elapsed()).as_secs()
|
|
)));
|
|
}
|
|
}
|
|
map.insert(key, Instant::now());
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn cleanup_rate_limit(&self) {
|
|
let mut map = self.connection_rate.write().await;
|
|
map.retain(|_, instant| instant.elapsed() < CONNECTION_COOLDOWN * 2);
|
|
const MAX_RATE_ENTRIES: usize = MAX_CONNECTIONS_PER_ROOM * 10;
|
|
if map.len() > MAX_RATE_ENTRIES {
|
|
let mut entries: Vec<_> = map.iter().collect();
|
|
entries.sort_by(|a, b| a.1.cmp(b.1));
|
|
let keep_count = entries.len() / 2;
|
|
let to_remove: Vec<_> = entries
|
|
.into_iter()
|
|
.take(keep_count)
|
|
.map(|(k, _)| *k)
|
|
.collect();
|
|
for key in to_remove {
|
|
map.remove(&key);
|
|
}
|
|
}
|
|
drop(map);
|
|
self.cleanup_idle_rooms().await;
|
|
}
|
|
|
|
pub async fn cleanup_idle_rooms(&self) {
|
|
let now = Instant::now();
|
|
let activity = self.room_last_activity.read().await;
|
|
let idle_room_ids: Vec<Uuid> = activity
|
|
.iter()
|
|
.filter(|(_, last_time)| now.duration_since(**last_time) > ROOM_IDLE_TIMEOUT)
|
|
.map(|(room_id, _)| *room_id)
|
|
.collect();
|
|
drop(activity);
|
|
|
|
if idle_room_ids.is_empty() {
|
|
return;
|
|
}
|
|
|
|
{
|
|
let mut counts = self.room_subscriber_count.write().await;
|
|
let mut rooms = self.room_inner.write().await;
|
|
for room_id in &idle_room_ids {
|
|
if let Some(sender) = rooms.remove(room_id) {
|
|
let count = counts.remove(&room_id).unwrap_or(1);
|
|
self.metrics.users_online.decrement(count as f64);
|
|
drop(sender);
|
|
}
|
|
}
|
|
}
|
|
|
|
{
|
|
let mut stream_map = self.room_stream_inner.write().await;
|
|
for room_id in &idle_room_ids {
|
|
stream_map.remove(room_id);
|
|
}
|
|
}
|
|
|
|
{
|
|
let mut txs = self.room_shutdown_txs.write().await;
|
|
for room_id in &idle_room_ids {
|
|
txs.remove(room_id);
|
|
}
|
|
}
|
|
|
|
{
|
|
let mut activity = self.room_last_activity.write().await;
|
|
for room_id in &idle_room_ids {
|
|
activity.remove(room_id);
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn subscribe_shutdown(&self) -> broadcast::Receiver<()> {
|
|
self.shutdown_tx.subscribe()
|
|
}
|
|
|
|
pub fn trigger_shutdown(&self) {
|
|
let _ = self.shutdown_tx.send(());
|
|
}
|
|
|
|
pub async fn register_room(&self, room_id: Uuid) -> broadcast::Receiver<()> {
|
|
let mut txs = self.room_shutdown_txs.write().await;
|
|
if let Some(tx) = txs.get(&room_id) {
|
|
return tx.subscribe();
|
|
}
|
|
let (tx, rx) = broadcast::channel(SHUTDOWN_CHANNEL_CAPACITY);
|
|
txs.insert(room_id, tx);
|
|
rx
|
|
}
|
|
|
|
/// Signals shutdown to `room_id`'s listeners, then removes all of its per-room state.
///
/// The room's subscriber count is subtracted from `users_online` when it is non-zero.
/// The message channel, room-wide stream channel and shutdown sender are removed.
pub async fn shutdown_room(&self, room_id: Uuid) {
    // Notify first, while the shutdown sender is still registered.
    if let Some(tx) = self.room_shutdown_txs.read().await.get(&room_id) {
        let _ = tx.send(());
    }

    let dropped_subscribers = self
        .room_subscriber_count
        .write()
        .await
        .remove(&room_id)
        .unwrap_or(0);
    if dropped_subscribers > 0 {
        self.metrics.users_online.decrement(dropped_subscribers as f64);
    }

    self.room_inner.write().await.remove(&room_id);
    self.room_stream_inner.write().await.remove(&room_id);
    self.room_shutdown_txs.write().await.remove(&room_id);
}
|
|
|
|
/// Drops the shutdown senders and subscriber counts of rooms not in `active_room_ids`.
///
/// The active ids are indexed in a `HashSet` first, so the pass is O(n + m) instead
/// of a linear `slice::contains` per stored room.
pub async fn prune_stale_rooms(&self, active_room_ids: &[Uuid]) {
    let active: std::collections::HashSet<Uuid> = active_room_ids.iter().copied().collect();

    {
        let mut txs = self.room_shutdown_txs.write().await;
        txs.retain(|room_id, _| active.contains(room_id));
    }

    let mut counts = self.room_subscriber_count.write().await;
    counts.retain(|room_id, _| active.contains(room_id));
}
|
|
|
|
pub async fn register_project(&self, project_id: Uuid) -> broadcast::Receiver<()> {
|
|
let mut txs = self.project_shutdown_txs.write().await;
|
|
if let Some(tx) = txs.get(&project_id) {
|
|
return tx.subscribe();
|
|
}
|
|
let (tx, rx) = broadcast::channel(SHUTDOWN_CHANNEL_CAPACITY);
|
|
txs.insert(project_id, tx);
|
|
rx
|
|
}
|
|
|
|
/// Signals shutdown to `project_id`'s listeners, then drops its event channel and
/// shutdown sender.
pub async fn shutdown_project(&self, project_id: Uuid) {
    // Notify first, while the shutdown sender is still registered.
    if let Some(tx) = self.project_shutdown_txs.read().await.get(&project_id) {
        let _ = tx.send(());
    }

    self.project_inner.write().await.remove(&project_id);
    self.project_shutdown_txs.write().await.remove(&project_id);
}
|
|
|
|
/// Drops the shutdown senders of projects not in `active_project_ids`.
///
/// The active ids are indexed in a `HashSet` first, so the pass is O(n + m) instead
/// of a linear `slice::contains` per stored project.
pub async fn prune_stale_projects(&self, active_project_ids: &[Uuid]) {
    let active: std::collections::HashSet<Uuid> = active_project_ids.iter().copied().collect();
    let mut txs = self.project_shutdown_txs.write().await;
    txs.retain(|project_id, _| active.contains(project_id));
}
|
|
|
|
pub async fn register_user(&self, user_id: Uuid) -> broadcast::Receiver<()> {
|
|
let mut txs = self.user_shutdown_txs.write().await;
|
|
if let Some(tx) = txs.get(&user_id) {
|
|
return tx.subscribe();
|
|
}
|
|
let (tx, rx) = broadcast::channel(SHUTDOWN_CHANNEL_CAPACITY);
|
|
txs.insert(user_id, tx);
|
|
rx
|
|
}
|
|
|
|
/// Signals shutdown to `user_id`'s listeners, then drops the user's event channel
/// and shutdown sender.
///
/// NOTE(review): `user_subscriber_count` and the `users_online` gauge are left untouched.
pub async fn shutdown_user(&self, user_id: Uuid) {
    // Notify first, while the shutdown sender is still registered.
    if let Some(tx) = self.user_shutdown_txs.read().await.get(&user_id) {
        let _ = tx.send(());
    }

    self.user_inner.write().await.remove(&user_id);
    self.user_shutdown_txs.write().await.remove(&user_id);
}
|
|
|
|
pub async fn subscribe_user_notification(
|
|
&self,
|
|
user_id: Uuid,
|
|
) -> broadcast::Receiver<Arc<NotificationEvent>> {
|
|
let mut map = self.user_notification_inner.write().await;
|
|
|
|
if let Some(sender) = map.get(&user_id) {
|
|
return sender.subscribe();
|
|
}
|
|
|
|
let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
|
|
map.insert(user_id, tx);
|
|
rx
|
|
}
|
|
|
|
/// Drops `user_id`'s notification channel.
///
/// NOTE(review): no subscriber count is kept for this channel, so this closes it for
/// every receiver of the user, not only the caller's — confirm that is intended.
pub async fn unsubscribe_user_notification(&self, user_id: Uuid) {
    self.user_notification_inner.write().await.remove(&user_id);
}
|
|
|
|
/// Sends `event` to `user_id`'s notification channel if one exists.
///
/// Best effort: a missing channel or one without receivers is silently ignored.
pub async fn push_user_notification(&self, user_id: Uuid, event: Arc<NotificationEvent>) {
    let channels = self.user_notification_inner.read().await;
    match channels.get(&user_id) {
        Some(sender) => {
            let _ = sender.send(event);
        }
        None => {}
    }
}
|
|
|
|
pub async fn subscribe(
|
|
&self,
|
|
room_id: Uuid,
|
|
_user_id: Uuid,
|
|
) -> Result<broadcast::Receiver<Arc<RoomMessageEvent>>, RoomError> {
|
|
let mut map = self.room_inner.write().await;
|
|
if let Some(_sender) = map.get(&room_id) {
|
|
drop(map);
|
|
let mut counts = self.room_subscriber_count.write().await;
|
|
*counts.entry(room_id).or_insert(0) += 1;
|
|
let map = self.room_inner.read().await;
|
|
if let Some(sender) = map.get(&room_id) {
|
|
return Ok(sender.subscribe());
|
|
}
|
|
return Err(RoomError::Internal(
|
|
"room disappeared during subscribe".into(),
|
|
));
|
|
}
|
|
|
|
if map.len() >= MAX_CONNECTIONS_PER_ROOM {
|
|
return Err(RoomError::RateLimited(format!(
|
|
"Room connection limit reached ({})",
|
|
MAX_CONNECTIONS_PER_ROOM
|
|
)));
|
|
}
|
|
|
|
let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
|
|
map.insert(room_id, tx);
|
|
drop(map);
|
|
let mut counts = self.room_subscriber_count.write().await;
|
|
counts.insert(room_id, 1);
|
|
self.metrics.users_online.increment(1.0);
|
|
Ok(rx)
|
|
}
|
|
|
|
/// Releases one subscription to `room_id`.
///
/// A tracked, non-zero count is decremented together with the `users_online` gauge.
/// When no subscriber remains (or none was tracked), the count entry and the room's
/// message channel are removed.
pub async fn unsubscribe(&self, room_id: Uuid, _user_id: Uuid) {
    let mut counts = self.room_subscriber_count.write().await;
    let remaining = match counts.get_mut(&room_id) {
        Some(count) if *count > 0 => {
            *count -= 1;
            self.metrics.users_online.decrement(1.0);
            *count
        }
        _ => 0,
    };

    if remaining == 0 {
        counts.remove(&room_id);
        // Keep `counts` locked while dropping the channel. Otherwise a concurrent
        // `subscribe` could join the channel we are about to remove and leave a
        // phantom count behind.
        self.room_inner.write().await.remove(&room_id);
    }
}
|
|
|
|
pub async fn broadcast(&self, room_id: Uuid, event: RoomMessageEvent) {
|
|
{
|
|
let mut activity = self.room_last_activity.write().await;
|
|
activity.insert(room_id, Instant::now());
|
|
}
|
|
|
|
let map = self.room_inner.read().await;
|
|
if let Some(sender) = map.get(&room_id) {
|
|
let event = Arc::new(event);
|
|
if sender.send(event).is_err() {
|
|
self.metrics.broadcasts_dropped.increment(1);
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn subscribe_project(
|
|
&self,
|
|
project_id: Uuid,
|
|
_user_id: Uuid,
|
|
) -> Result<broadcast::Receiver<Arc<ProjectRoomEvent>>, RoomError> {
|
|
let mut map = self.project_inner.write().await;
|
|
if map.get(&project_id).is_some() {
|
|
drop(map);
|
|
let mut counts = self.project_subscriber_count.write().await;
|
|
*counts.entry(project_id).or_insert(0) += 1;
|
|
let map = self.project_inner.read().await;
|
|
if let Some(sender) = map.get(&project_id) {
|
|
return Ok(sender.subscribe());
|
|
}
|
|
return Err(RoomError::Internal("project channel disappeared".into()));
|
|
}
|
|
|
|
if map.len() >= MAX_CONNECTIONS_PER_PROJECT {
|
|
return Err(RoomError::RateLimited(format!(
|
|
"Project connection limit reached ({})",
|
|
MAX_CONNECTIONS_PER_PROJECT
|
|
)));
|
|
}
|
|
|
|
let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
|
|
map.insert(project_id, tx);
|
|
drop(map);
|
|
let mut counts = self.project_subscriber_count.write().await;
|
|
counts.insert(project_id, 1);
|
|
Ok(rx)
|
|
}
|
|
|
|
/// Releases one subscription to `project_id`.
///
/// When no subscriber remains (or none was tracked), the count entry and the
/// project's event channel are removed.
pub async fn unsubscribe_project(&self, project_id: Uuid, _user_id: Uuid) {
    let mut counts = self.project_subscriber_count.write().await;
    let remaining = match counts.get_mut(&project_id) {
        Some(count) if *count > 0 => {
            *count -= 1;
            *count
        }
        _ => 0,
    };

    if remaining == 0 {
        counts.remove(&project_id);
        // Still holding `counts`: a concurrent `subscribe_project` cannot attach to
        // the channel being removed.
        self.project_inner.write().await.remove(&project_id);
    }
}
|
|
|
|
pub async fn broadcast_project(&self, project_id: Uuid, event: ProjectRoomEvent) {
|
|
let map = self.project_inner.read().await;
|
|
if let Some(sender) = map.get(&project_id) {
|
|
let event = Arc::new(event);
|
|
if sender.send(event).is_err() {
|
|
self.metrics.broadcasts_dropped.increment(1);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Broadcast an agent task event to all WS clients subscribed to this project.
|
|
pub async fn broadcast_agent_task(&self, project_id: Uuid, event: AgentTaskEvent) {
|
|
let map = self.task_inner.read().await;
|
|
if let Some(sender) = map.get(&project_id) {
|
|
let event = Arc::new(event);
|
|
if sender.send(event).is_err() {
|
|
self.metrics.broadcasts_dropped.increment(1);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Subscribe to agent task events for a project.
|
|
/// Returns a broadcast receiver that yields task events as they occur.
|
|
pub async fn subscribe_task_events(
|
|
&self,
|
|
project_id: Uuid,
|
|
) -> Result<broadcast::Receiver<Arc<AgentTaskEvent>>, RoomError> {
|
|
let mut map = self.task_inner.write().await;
|
|
|
|
if let Some(sender) = map.get(&project_id).cloned() {
|
|
drop(map);
|
|
return Ok(sender.subscribe());
|
|
}
|
|
|
|
let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
|
|
map.insert(project_id, tx);
|
|
Ok(rx)
|
|
}
|
|
|
|
pub async fn subscribe_user(
|
|
&self,
|
|
user_id: Uuid,
|
|
) -> Result<broadcast::Receiver<Arc<ProjectRoomEvent>>, RoomError> {
|
|
let mut map = self.user_inner.write().await;
|
|
|
|
if let Some(_sender) = map.get(&user_id) {
|
|
drop(map);
|
|
let mut counts = self.user_subscriber_count.write().await;
|
|
*counts.entry(user_id).or_insert(0) += 1;
|
|
let map = self.user_inner.read().await;
|
|
if let Some(sender) = map.get(&user_id) {
|
|
return Ok(sender.subscribe());
|
|
}
|
|
return Err(RoomError::Internal("user channel disappeared".into()));
|
|
}
|
|
|
|
if map.len() >= MAX_CONNECTIONS_PER_USER {
|
|
return Err(RoomError::RateLimited(format!(
|
|
"User connection limit reached ({})",
|
|
MAX_CONNECTIONS_PER_USER
|
|
)));
|
|
}
|
|
|
|
let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
|
|
map.insert(user_id, tx);
|
|
drop(map);
|
|
let mut counts = self.user_subscriber_count.write().await;
|
|
counts.insert(user_id, 1);
|
|
self.metrics.users_online.increment(1.0);
|
|
Ok(rx)
|
|
}
|
|
|
|
/// Releases one subscription of `user_id`.
///
/// When the last tracked subscription goes away, `users_online` is decremented once.
/// When no subscriber remains (or none was tracked), the count entry and the user's
/// event channel are removed. An untracked user no longer decrements the gauge,
/// because it was never incremented for them.
pub async fn unsubscribe_user(&self, user_id: Uuid) {
    let mut counts = self.user_subscriber_count.write().await;
    let (remaining, was_tracked) = match counts.get_mut(&user_id) {
        Some(count) if *count > 0 => {
            *count -= 1;
            (*count, true)
        }
        _ => (0, false),
    };

    if remaining == 0 {
        if was_tracked {
            self.metrics.users_online.decrement(1.0);
        }
        counts.remove(&user_id);
        // Still holding `counts`: a concurrent `subscribe_user` cannot attach to
        // the channel being removed.
        self.user_inner.write().await.remove(&user_id);
    }
}
|
|
|
|
pub async fn broadcast_to_user(&self, user_id: Uuid, event: ProjectRoomEvent) {
|
|
let map = self.user_inner.read().await;
|
|
if let Some(sender) = map.get(&user_id) {
|
|
let event = Arc::new(event);
|
|
if sender.send(event).is_err() {
|
|
self.metrics.broadcasts_dropped.increment(1);
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn register_stream_channel(
|
|
&self,
|
|
message_id: Uuid,
|
|
) -> broadcast::Receiver<Arc<RoomMessageStreamChunkEvent>> {
|
|
let mut map = self.stream_inner.write().await;
|
|
if let Some(tx) = map.get(&message_id) {
|
|
return tx.subscribe();
|
|
}
|
|
let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
|
|
map.insert(message_id, tx);
|
|
rx
|
|
}
|
|
|
|
pub async fn subscribe_stream(
|
|
&self,
|
|
message_id: Uuid,
|
|
) -> Option<broadcast::Receiver<Arc<RoomMessageStreamChunkEvent>>> {
|
|
let map = self.stream_inner.read().await;
|
|
map.get(&message_id).map(|tx| tx.subscribe())
|
|
}
|
|
|
|
pub async fn subscribe_room_stream(
|
|
&self,
|
|
room_id: Uuid,
|
|
) -> broadcast::Receiver<Arc<RoomMessageStreamChunkEvent>> {
|
|
let mut map = self.room_stream_inner.write().await;
|
|
if let Some(tx) = map.get(&room_id) {
|
|
return tx.subscribe();
|
|
}
|
|
let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
|
|
map.insert(room_id, tx);
|
|
rx
|
|
}
|
|
|
|
pub async fn broadcast_stream_chunk(&self, event: RoomMessageStreamChunkEvent) {
|
|
// Update activity tracker to prevent idle cleanup during active streaming
|
|
{
|
|
let mut activity = self.room_last_activity.write().await;
|
|
activity.insert(event.room_id, Instant::now());
|
|
}
|
|
|
|
let event = Arc::new(event);
|
|
let is_final_chunk = event.done;
|
|
|
|
let map = self.stream_inner.read().await;
|
|
if let Some(tx) = map.get(&event.message_id) {
|
|
let _ = tx.send(Arc::clone(&event));
|
|
}
|
|
|
|
drop(map);
|
|
let map = self.room_stream_inner.read().await;
|
|
if let Some(tx) = map.get(&event.room_id) {
|
|
let _ = tx.send(Arc::clone(&event));
|
|
}
|
|
|
|
if is_final_chunk {
|
|
drop(map);
|
|
let mut map = self.room_stream_inner.write().await;
|
|
map.remove(&event.room_id);
|
|
}
|
|
}
|
|
|
|
/// Drops the stream channel of `message_id`, closing it for its receivers.
pub async fn close_stream_channel(&self, message_id: Uuid) {
    let _removed = self.stream_inner.write().await.remove(&message_id);
}
|
|
|
|
pub async fn subscribe_typing(
|
|
&self,
|
|
room_id: Uuid,
|
|
) -> broadcast::Receiver<Arc<TypingEvent>> {
|
|
let mut map: tokio::sync::RwLockWriteGuard<'_, std::collections::HashMap<Uuid, broadcast::Sender<Arc<TypingEvent>>>> = self.typing_inner.write().await;
|
|
let tx = map.entry(room_id).or_insert_with(|| {
|
|
let (tx, _) = broadcast::channel(BROADCAST_CAPACITY);
|
|
tx
|
|
});
|
|
// Replay active typing state from Redis to the new subscriber.
|
|
// This ensures newly connected WS clients see who is currently typing.
|
|
let active_events = self.get_active_typing_events(room_id).await;
|
|
for event in active_events {
|
|
let _ = tx.send(Arc::new(event));
|
|
}
|
|
tx.subscribe()
|
|
}
|
|
|
|
/// Broadcast a typing event and persist it to Redis with 60s TTL.
|
|
/// - "start": writes key with 60s expiry, broadcasts start event
|
|
/// - "stop": deletes key, broadcasts stop event
|
|
pub async fn broadcast_typing(&self, room_id: Uuid, event: TypingEvent) {
|
|
let user_key = format!("typing:{}:{}", room_id, event.user_id);
|
|
let action = event.action.clone();
|
|
let username = event.username.clone();
|
|
let avatar_url = event.avatar_url.clone();
|
|
let sender_type = event.sender_type.clone().unwrap_or_else(|| "user".to_string());
|
|
|
|
// Write/delete Redis key for 60s expiry (non-blocking)
|
|
if let Ok(mut conn) = self.cache.conn().await {
|
|
let key = user_key;
|
|
tokio::spawn(async move {
|
|
if action == "start" {
|
|
let value = serde_json::json!({
|
|
"username": username,
|
|
"avatar_url": avatar_url,
|
|
"sender_type": sender_type,
|
|
})
|
|
.to_string();
|
|
let _: Result<(), _> = redis::cmd("SETEX")
|
|
.arg(&key)
|
|
.arg(60i64)
|
|
.arg(&value)
|
|
.query_async(&mut conn)
|
|
.await;
|
|
} else {
|
|
let _: Result<(), _> = redis::cmd("DEL").arg(&key).query_async(&mut conn).await;
|
|
}
|
|
});
|
|
}
|
|
|
|
let map: tokio::sync::RwLockReadGuard<'_, std::collections::HashMap<Uuid, broadcast::Sender<Arc<TypingEvent>>>> = self.typing_inner.read().await;
|
|
if let Some(tx) = map.get(&room_id) {
|
|
let event = Arc::new(event);
|
|
let _ = tx.send(event);
|
|
}
|
|
}
|
|
|
|
/// Load all active typing entries for a room from Redis and return as TypingEvents.
|
|
/// Used to replay current typing state to newly connected WS clients.
|
|
/// Uses SCAN instead of KEYS to avoid blocking the Redis server.
|
|
pub async fn get_active_typing_events(&self, room_id: Uuid) -> Vec<TypingEvent> {
|
|
let pattern = format!("typing:{}:*", room_id);
|
|
if let Ok(mut conn) = self.cache.conn().await {
|
|
let mut cursor: u64 = 0;
|
|
let mut all_keys: Vec<String> = Vec::new();
|
|
loop {
|
|
let (next_cursor, keys): (u64, Vec<String>) = match redis::cmd("SCAN")
|
|
.arg(cursor)
|
|
.arg("MATCH")
|
|
.arg(&pattern)
|
|
.arg("COUNT")
|
|
.arg(100)
|
|
.query_async(&mut conn)
|
|
.await
|
|
{
|
|
Ok(r) => r,
|
|
Err(_) => return vec![],
|
|
};
|
|
all_keys.extend(keys);
|
|
if next_cursor == 0 {
|
|
break;
|
|
}
|
|
cursor = next_cursor;
|
|
}
|
|
if all_keys.is_empty() {
|
|
return vec![];
|
|
}
|
|
let mut results = Vec::new();
|
|
for key in all_keys {
|
|
let parts: Vec<&str> = key.split(':').collect();
|
|
let user_id = parts.get(2).and_then(|s| Uuid::parse_str(s).ok());
|
|
if let (Some(value), Some(user_uuid)) = (
|
|
redis::cmd("GET").arg(&key).query_async::<String>(&mut conn).await.ok(),
|
|
user_id,
|
|
) {
|
|
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&value) {
|
|
results.push(TypingEvent {
|
|
room_id,
|
|
user_id: user_uuid,
|
|
username: parsed.get("username").and_then(|v| v.as_str()).unwrap_or("").to_string(),
|
|
avatar_url: parsed.get("avatar_url").and_then(|v| v.as_str()).map(String::from),
|
|
action: "start".to_string(),
|
|
sender_type: parsed.get("sender_type").and_then(|v| v.as_str()).map(String::from),
|
|
});
|
|
}
|
|
}
|
|
}
|
|
return results;
|
|
}
|
|
vec![]
|
|
}
|
|
}
|
|
|
|
fn parse_sender_type(s: &str) -> MessageSenderType {
|
|
match s {
|
|
"member" => MessageSenderType::Member,
|
|
"admin" => MessageSenderType::Admin,
|
|
"owner" => MessageSenderType::Owner,
|
|
"ai" => MessageSenderType::Ai,
|
|
"system" => MessageSenderType::System,
|
|
"tool" => MessageSenderType::Tool,
|
|
"guest" => MessageSenderType::Guest,
|
|
_ => MessageSenderType::Member,
|
|
}
|
|
}
|
|
|
|
fn parse_content_type(s: &str) -> MessageContentType {
|
|
match s {
|
|
"text" => MessageContentType::Text,
|
|
"image" => MessageContentType::Image,
|
|
"audio" => MessageContentType::Audio,
|
|
"video" => MessageContentType::Video,
|
|
"file" => MessageContentType::File,
|
|
_ => MessageContentType::Text,
|
|
}
|
|
}
|
|
|
|
pub type PersistFn = Arc<
|
|
dyn Fn(Vec<RoomMessageEnvelope>) -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send>>
|
|
+ Send
|
|
+ Sync,
|
|
>;
|
|
|
|
use dashmap::DashMap;
|
|
|
|
pub type DedupCache = Arc<DashMap<uuid::Uuid, Instant>>;
|
|
|
|
const DEDUP_CACHE_TTL: Duration = Duration::from_secs(300);
|
|
|
|
pub fn cleanup_dedup_cache(cache: &DedupCache) {
|
|
let cutoff = Instant::now() - DEDUP_CACHE_TTL;
|
|
cache.retain(|_, inserted_at| *inserted_at > cutoff);
|
|
}
|
|
|
|
pub fn make_persist_fn(
|
|
db: AppDatabase,
|
|
metrics: Arc<RoomMetrics>,
|
|
dedup_cache: DedupCache,
|
|
embed_service: Option<Arc<agent::embed::EmbedService>>,
|
|
) -> PersistFn {
|
|
Arc::new(move |envelopes: Vec<RoomMessageEnvelope>| {
|
|
let db = db.clone();
|
|
let metrics = metrics.clone();
|
|
let cache = dedup_cache.clone();
|
|
let embed = embed_service.clone();
|
|
Box::pin(async move {
|
|
let mut persisted: Vec<RoomMessageEnvelope> = Vec::new();
|
|
|
|
for chunk in envelopes.chunks(BATCH_SIZE) {
|
|
let mut models_to_insert = Vec::new();
|
|
let mut ids_to_dedup: Vec<uuid::Uuid> = Vec::new();
|
|
|
|
for env in chunk {
|
|
if cache.contains_key(&env.id) {
|
|
metrics.incr_duplicates_skipped();
|
|
continue;
|
|
}
|
|
ids_to_dedup.push(env.id);
|
|
}
|
|
|
|
let existing_ids: std::collections::HashSet<uuid::Uuid> =
|
|
if !ids_to_dedup.is_empty() {
|
|
room_message::Entity::find()
|
|
.filter(room_message::Column::Id.is_in(ids_to_dedup))
|
|
.into_model::<room_message::Model>()
|
|
.all(&db)
|
|
.await?
|
|
.into_iter()
|
|
.map(|m| m.id)
|
|
.collect()
|
|
} else {
|
|
std::collections::HashSet::new()
|
|
};
|
|
|
|
for env in chunk {
|
|
if cache.contains_key(&env.id) {
|
|
continue;
|
|
}
|
|
cache.insert(env.id, Instant::now());
|
|
|
|
if existing_ids.contains(&env.id) {
|
|
metrics.incr_duplicates_skipped();
|
|
continue;
|
|
}
|
|
|
|
let sender_type = parse_sender_type(&env.sender_type);
|
|
let content_type = parse_content_type(&env.content_type);
|
|
|
|
models_to_insert.push(room_message::ActiveModel {
|
|
id: Set(env.id),
|
|
seq: Set(env.seq),
|
|
room: Set(env.room_id),
|
|
sender_type: Set(sender_type),
|
|
sender_id: Set(env.sender_id),
|
|
model_id: Set(env.model_id),
|
|
thread: Set(env.thread_id),
|
|
content: Set(env.content.clone()),
|
|
content_type: Set(content_type),
|
|
thinking_content: Set(env.thinking_content.clone()),
|
|
edited_at: Set(None),
|
|
send_at: Set(env.send_at.clone()),
|
|
revoked: Set(None),
|
|
revoked_by: Set(None),
|
|
in_reply_to: Set(env.in_reply_to),
|
|
});
|
|
}
|
|
|
|
if !models_to_insert.is_empty() {
|
|
let count = models_to_insert.len() as u64;
|
|
room_message::Entity::insert_many(models_to_insert)
|
|
.exec(&db)
|
|
.await?;
|
|
|
|
// Batch update content_tsv using a single UPDATE with subquery
|
|
// instead of N individual UPDATE statements (N=chunk size, up to 100)
|
|
let ids: Vec<String> = chunk
|
|
.iter()
|
|
.filter(|e| !existing_ids.contains(&e.id))
|
|
.map(|e| format!("'{}'", e.id))
|
|
.collect();
|
|
|
|
if !ids.is_empty() {
|
|
let batch_sql = format!(
|
|
"UPDATE room_message AS t \
|
|
SET content_tsv = to_tsvector('simple', content) \
|
|
WHERE t.id IN ({})",
|
|
ids.join(",")
|
|
);
|
|
let stmt = sea_orm::Statement::from_sql_and_values(
|
|
sea_orm::DbBackend::Postgres,
|
|
&batch_sql,
|
|
vec![],
|
|
);
|
|
if let Err(e) = db.execute_raw(stmt).await {
|
|
tracing::warn!(error = %e, "full text index update failed");
|
|
}
|
|
}
|
|
|
|
metrics.messages_persisted.increment(count);
|
|
|
|
// Collect persisted messages for Qdrant embedding
|
|
for env in chunk {
|
|
if existing_ids.contains(&env.id) {
|
|
continue;
|
|
}
|
|
persisted.push(env.clone());
|
|
}
|
|
}
|
|
}
|
|
|
|
// Batch-embed text messages into Qdrant (non-blocking, fire-and-forget)
|
|
if let Some(embed) = embed {
|
|
if !persisted.is_empty() {
|
|
let embed_db = db.clone();
|
|
tokio::spawn(async move {
|
|
embed_persisted_messages(embed, embed_db, persisted).await;
|
|
});
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
})
|
|
})
|
|
}
|
|
|
|
/// Filter and batch-embed persisted messages into Qdrant per-room collections.
|
|
/// Only embeds text-type, non-empty messages (filters system/tool/non-text).
|
|
async fn embed_persisted_messages(
|
|
embed: Arc<agent::embed::EmbedService>,
|
|
db: AppDatabase,
|
|
messages: Vec<RoomMessageEnvelope>,
|
|
) {
|
|
// Filter: only text content, non-empty, skip system messages
|
|
let to_embed: Vec<&RoomMessageEnvelope> = messages
|
|
.iter()
|
|
.filter(|m| {
|
|
m.content_type == "text"
|
|
&& !m.content.trim().is_empty()
|
|
&& m.sender_type != "system"
|
|
&& m.sender_type != "tool"
|
|
})
|
|
.collect();
|
|
|
|
if to_embed.is_empty() {
|
|
return;
|
|
}
|
|
|
|
// Batch-lookup room → project_id → project_name
|
|
let room_ids: Vec<Uuid> = to_embed.iter().map(|m| m.room_id).collect();
|
|
let rooms = match models::rooms::room::Entity::find()
|
|
.filter(models::rooms::room::Column::Id.is_in(room_ids.clone()))
|
|
.all(&db)
|
|
.await
|
|
{
|
|
Ok(r) => r,
|
|
Err(e) => {
|
|
tracing::warn!(error = %e, "embed: failed to lookup rooms");
|
|
return;
|
|
}
|
|
};
|
|
|
|
let project_ids: Vec<Uuid> = rooms.iter().map(|r| r.project).collect();
|
|
let projects = match models::projects::project::Entity::find()
|
|
.filter(models::projects::project::Column::Id.is_in(project_ids))
|
|
.all(&db)
|
|
.await
|
|
{
|
|
Ok(p) => p,
|
|
Err(e) => {
|
|
tracing::warn!(error = %e, "embed: failed to lookup projects");
|
|
return;
|
|
}
|
|
};
|
|
|
|
// Build room_id → project_name map
|
|
use std::collections::HashMap;
|
|
let mut room_project: HashMap<Uuid, String> = HashMap::new();
|
|
for room in &rooms {
|
|
if let Some(proj) = projects.iter().find(|p| p.id == room.project) {
|
|
room_project.insert(room.id, proj.display_name.clone());
|
|
}
|
|
}
|
|
|
|
// Build EmbedMemoryInput list
|
|
let inputs: Vec<agent::embed::EmbedMemoryInput> = to_embed
|
|
.into_iter()
|
|
.filter_map(|m| {
|
|
let project_name = room_project.get(&m.room_id)?;
|
|
Some(agent::embed::EmbedMemoryInput {
|
|
message_id: m.id.to_string(),
|
|
content: m.content.clone(),
|
|
project_name: project_name.clone(),
|
|
room_id: m.room_id.to_string(),
|
|
user_id: m.sender_id.map(|id| id.to_string()),
|
|
sender_type: m.sender_type.clone(),
|
|
})
|
|
})
|
|
.collect();
|
|
|
|
if inputs.is_empty() {
|
|
return;
|
|
}
|
|
|
|
if let Err(e) = embed.embed_memories_batch(inputs).await {
|
|
tracing::warn!(error = %e, "batch memory embed failed");
|
|
}
|
|
}
|
|
|
|
pub type RedisFuture =
|
|
Pin<Box<dyn Future<Output = anyhow::Result<deadpool_redis::cluster::Connection>> + Send>>;
|
|
|
|
pub fn extract_get_redis(
|
|
producer: queue::MessageProducer,
|
|
) -> Arc<dyn Fn() -> RedisFuture + Send + Sync> {
|
|
Arc::new(move || {
|
|
let get_redis_fn = producer.get_redis.clone();
|
|
Box::pin(async move {
|
|
let handle = get_redis_fn();
|
|
match handle.await {
|
|
Ok(conn) => conn,
|
|
Err(_) => anyhow::bail!("redis pool task panicked"),
|
|
}
|
|
}) as RedisFuture
|
|
})
|
|
}
|
|
|
|
fn start_pubsub_thread<F, Fut>(
|
|
redis_url: String,
|
|
channel: String,
|
|
relay_tx: tokio::sync::mpsc::Sender<Vec<u8>>,
|
|
mut shutdown_rx: broadcast::Receiver<()>,
|
|
_on_msg: F,
|
|
) where
|
|
F: Fn(Vec<u8>) -> Fut + Send + Sync + 'static,
|
|
Fut: Future<Output = ()> + Send,
|
|
{
|
|
thread::Builder::new()
|
|
.name(format!("redis-pubsub-{}", &channel[..channel.len().min(32)]))
|
|
.spawn(move || {
|
|
let rt = tokio::runtime::Builder::new_current_thread()
|
|
.enable_all()
|
|
.build()
|
|
.expect("pubsub thread runtime");
|
|
rt.block_on(async {
|
|
let redis_url = redis_url.clone();
|
|
loop {
|
|
if shutdown_rx.try_recv().is_ok() {
|
|
tracing::info!(channel = %channel, "pubsub thread shutting down before connect");
|
|
break;
|
|
}
|
|
|
|
let client = match redis::Client::open(redis_url.as_str()) {
|
|
Ok(c) => c,
|
|
Err(e) => {
|
|
tracing::error!(channel = %channel, error = %e, "pubsub redis client open failed");
|
|
thread::sleep(Duration::from_secs(1));
|
|
continue;
|
|
}
|
|
};
|
|
|
|
let mut pubsub = match client.get_async_pubsub().await {
|
|
Ok(p) => p,
|
|
Err(e) => {
|
|
tracing::error!(channel = %channel, error = %e, "pubsub connection failed");
|
|
thread::sleep(Duration::from_secs(1));
|
|
continue;
|
|
}
|
|
};
|
|
|
|
match pubsub.subscribe(&channel).await {
|
|
Ok(_) => tracing::info!(channel = %channel, "pubsub subscribed"),
|
|
Err(e) => {
|
|
tracing::error!(channel = %channel, error = %e, "pubsub subscribe failed");
|
|
thread::sleep(Duration::from_secs(1));
|
|
continue;
|
|
}
|
|
}
|
|
|
|
let mut stream = pubsub.on_message();
|
|
|
|
loop {
|
|
if shutdown_rx.try_recv().is_ok() {
|
|
tracing::info!(channel = %channel, "pubsub thread shutting down");
|
|
return;
|
|
}
|
|
|
|
let msg = tokio::time::timeout(
|
|
Duration::from_millis(500),
|
|
futures::StreamExt::next(&mut stream),
|
|
)
|
|
.await;
|
|
|
|
match msg {
|
|
Ok(Some(msg)) => {
|
|
let payload = msg.get_payload_bytes();
|
|
tracing::debug!(channel = %channel, len = payload.len(), "pubsub received");
|
|
if relay_tx.send(payload.to_vec()).await.is_err() {
|
|
tracing::warn!(channel = %channel, "pubsub relay channel closed");
|
|
return;
|
|
}
|
|
}
|
|
Ok(None) => {
|
|
tracing::warn!(channel = %channel, "pubsub stream ended, will reconnect");
|
|
break;
|
|
}
|
|
Err(_) => {}
|
|
}
|
|
}
|
|
|
|
tracing::warn!(channel = %channel, "pubsub connection lost, reconnecting");
|
|
}
|
|
});
|
|
})
|
|
.expect("pubsub thread spawn");
|
|
}
|
|
|
|
pub async fn subscribe_room_events(
|
|
redis_url: String,
|
|
manager: Arc<RoomConnectionManager>,
|
|
room_id: Uuid,
|
|
mut shutdown_rx: broadcast::Receiver<()>,
|
|
) {
|
|
let channel = format!("room:pub:{}", room_id);
|
|
let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(1024);
|
|
|
|
tracing::info!(room_id = %room_id, channel = %channel, "starting room pubsub subscriber");
|
|
|
|
let thread_channel = channel.clone();
|
|
let thread_shutdown = shutdown_rx.resubscribe();
|
|
start_pubsub_thread(
|
|
redis_url,
|
|
thread_channel,
|
|
tx,
|
|
thread_shutdown,
|
|
|_| async {},
|
|
);
|
|
|
|
loop {
|
|
tokio::select! {
|
|
_ = shutdown_rx.recv() => {
|
|
tracing::info!(room_id = %room_id, "room subscriber shutting down");
|
|
break;
|
|
}
|
|
payload = rx.recv() => {
|
|
match payload {
|
|
Some(data) => {
|
|
match serde_json::from_slice::<RoomMessageEvent>(&data) {
|
|
Ok(event) => {
|
|
manager.broadcast(room_id, event).await;
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!(error = %e, "malformed RoomMessageEvent");
|
|
}
|
|
}
|
|
}
|
|
None => {
|
|
tracing::warn!(room_id = %room_id, "pubsub relay channel closed");
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
tracing::info!(room_id = %room_id, "room subscriber stopped");
|
|
}
|
|
|
|
pub async fn subscribe_project_room_events(
|
|
redis_url: String,
|
|
manager: Arc<RoomConnectionManager>,
|
|
project_id: Uuid,
|
|
mut shutdown_rx: broadcast::Receiver<()>,
|
|
) {
|
|
let channel = format!("project:pub:{}", project_id);
|
|
let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(1024);
|
|
|
|
tracing::info!(project_id = %project_id, channel = %channel, "starting project pubsub subscriber");
|
|
|
|
let thread_channel = channel.clone();
|
|
let thread_shutdown = shutdown_rx.resubscribe();
|
|
start_pubsub_thread(
|
|
redis_url,
|
|
thread_channel,
|
|
tx,
|
|
thread_shutdown,
|
|
|_| async {},
|
|
);
|
|
|
|
loop {
|
|
tokio::select! {
|
|
_ = shutdown_rx.recv() => {
|
|
tracing::info!(project_id = %project_id, "project subscriber shutting down");
|
|
break;
|
|
}
|
|
payload = rx.recv() => {
|
|
match payload {
|
|
Some(data) => {
|
|
match serde_json::from_slice::<ProjectRoomEvent>(&data) {
|
|
Ok(event) => {
|
|
manager.broadcast_project(project_id, event).await;
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!(error = %e, "malformed ProjectRoomEvent");
|
|
}
|
|
}
|
|
}
|
|
None => {
|
|
tracing::warn!(project_id = %project_id, "project pubsub relay channel closed");
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
tracing::info!(project_id = %project_id, "project subscriber stopped");
|
|
}
|
|
|
|
/// Subscribe to Redis Pub/Sub `task:pub:{project_id}` and relay events to
|
|
/// `RoomConnectionManager::broadcast_agent_task()` so all WS clients get notified.
|
|
pub async fn subscribe_task_events_fn(
|
|
redis_url: String,
|
|
manager: Arc<RoomConnectionManager>,
|
|
project_id: Uuid,
|
|
mut shutdown_rx: broadcast::Receiver<()>,
|
|
) {
|
|
let channel = format!("task:pub:{}", project_id);
|
|
let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(1024);
|
|
|
|
tracing::info!(project_id = %project_id, channel = %channel, "starting task pubsub subscriber");
|
|
|
|
let thread_channel = channel.clone();
|
|
let thread_shutdown = shutdown_rx.resubscribe();
|
|
start_pubsub_thread(
|
|
redis_url,
|
|
thread_channel,
|
|
tx,
|
|
thread_shutdown,
|
|
|_| async {},
|
|
);
|
|
|
|
loop {
|
|
tokio::select! {
|
|
_ = shutdown_rx.recv() => {
|
|
tracing::info!(project_id = %project_id, "task subscriber shutting down");
|
|
break;
|
|
}
|
|
payload = rx.recv() => {
|
|
match payload {
|
|
Some(data) => {
|
|
match serde_json::from_slice::<AgentTaskEvent>(&data) {
|
|
Ok(event) => {
|
|
manager.broadcast_agent_task(project_id, event).await;
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!(error = %e, "malformed AgentTaskEvent");
|
|
}
|
|
}
|
|
}
|
|
None => {
|
|
tracing::warn!(project_id = %project_id, "task pubsub relay channel closed");
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
tracing::info!(project_id = %project_id, "task subscriber stopped");
|
|
}
|