use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant};
use tokio::sync::{RwLock, broadcast};
use uuid::Uuid;

use db::database::AppDatabase;
use models::rooms::{MessageContentType, MessageSenderType, room_message};
use queue::{
    AgentTaskEvent, ProjectRoomEvent, RoomMessageEnvelope, RoomMessageEvent,
    RoomMessageStreamChunkEvent,
};
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter, Set};

use crate::error::RoomError;
use crate::metrics::RoomMetrics;
use crate::types::NotificationEvent;

const BROADCAST_CAPACITY: usize = 10000;
const SHUTDOWN_CHANNEL_CAPACITY: usize = 16;
const CONNECTION_COOLDOWN: Duration = Duration::from_secs(30);
const MAX_CONNECTIONS_PER_ROOM: usize = 50000;
const MAX_CONNECTIONS_PER_PROJECT: usize = 50000;
const MAX_CONNECTIONS_PER_USER: usize = 50000;
const BATCH_SIZE: usize = 100;
const ROOM_IDLE_TIMEOUT: Duration = Duration::from_secs(30 * 60);

/// Shared connection/broadcast state for room, project, and user WebSocket
/// subscribers, plus per-connection rate limiting and shutdown signalling.
pub struct RoomConnectionManager {
    /// Broadcast channel for room message events per room.
    room_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<RoomMessageEvent>>>>,
    /// Broadcast channel for project-level room events per project.
    project_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<ProjectRoomEvent>>>>,
    /// Project room events fanned out directly to a specific user.
    user_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<ProjectRoomEvent>>>>,
    /// Broadcast channel for notification events per user.
    user_notification_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<NotificationEvent>>>>,
    /// Broadcast channel for agent task events per project.
    task_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<AgentTaskEvent>>>>,
    pub metrics: Arc<RoomMetrics>,
    /// Last connection attempt per (room/project, user) pair, used for cooldowns.
    connection_rate: RwLock<HashMap<(Uuid, Uuid), Instant>>,
    /// Global shutdown signal.
    shutdown_tx: broadcast::Sender<()>,
    /// Per-room, per-project, and per-user shutdown signals.
    room_shutdown_txs: RwLock<HashMap<Uuid, broadcast::Sender<()>>>,
    project_shutdown_txs: RwLock<HashMap<Uuid, broadcast::Sender<()>>>,
    user_shutdown_txs: RwLock<HashMap<Uuid, broadcast::Sender<()>>>,
    /// Stream-chunk channels keyed by message id and by room id.
    stream_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<RoomMessageStreamChunkEvent>>>>,
    room_stream_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<RoomMessageStreamChunkEvent>>>>,
    /// Bookkeeping for idle-room cleanup and subscriber counting.
    room_last_activity: RwLock<HashMap<Uuid, Instant>>,
    room_subscriber_count: RwLock<HashMap<Uuid, usize>>,
    project_subscriber_count: RwLock<HashMap<Uuid, usize>>,
    user_subscriber_count: RwLock<HashMap<Uuid, usize>>,
}

impl RoomConnectionManager {
    pub fn new(metrics: Arc<RoomMetrics>) -> Self {
        let (shutdown_tx, _) = broadcast::channel(SHUTDOWN_CHANNEL_CAPACITY);
        Self {
            room_inner: RwLock::new(HashMap::new()),
            project_inner: RwLock::new(HashMap::new()),
            user_inner: RwLock::new(HashMap::new()),
            user_notification_inner: RwLock::new(HashMap::new()),
            task_inner: RwLock::new(HashMap::new()),
            metrics,
            connection_rate: RwLock::new(HashMap::new()),
            shutdown_tx,
            room_shutdown_txs: RwLock::new(HashMap::new()),
            project_shutdown_txs: RwLock::new(HashMap::new()),
            user_shutdown_txs: RwLock::new(HashMap::new()),
            stream_inner: RwLock::new(HashMap::new()),
            room_stream_inner: RwLock::new(HashMap::new()),
            room_last_activity: RwLock::new(HashMap::new()),
            room_subscriber_count: RwLock::new(HashMap::new()),
            project_subscriber_count: RwLock::new(HashMap::new()),
            user_subscriber_count: RwLock::new(HashMap::new()),
        }
    }

    pub async fn check_room_connection_rate(
        &self,
        room_id: Uuid,
        user_id: Uuid,
    ) -> Result<(), RoomError> {
        let mut map = self.connection_rate.write().await;
        let key = (room_id, user_id);
        if let Some(last) = map.get(&key) {
            if last.elapsed() < CONNECTION_COOLDOWN {
                return Err(RoomError::RateLimited(format!(
                    "Connection cooldown active, retry in {}s",
                    CONNECTION_COOLDOWN.saturating_sub(last.elapsed()).as_secs()
                )));
            }
        }
        map.insert(key, Instant::now());
        Ok(())
    }

    pub async fn check_project_connection_rate(
        &self,
        project_id: Uuid,
        user_id: Uuid,
    ) -> Result<(), RoomError> {
        let mut map = self.connection_rate.write().await;
        let key = (project_id, user_id);
        if let Some(last) = map.get(&key) {
            if last.elapsed() < CONNECTION_COOLDOWN {
                return Err(RoomError::RateLimited(format!(
                    "Connection cooldown active, retry in {}s",
                    CONNECTION_COOLDOWN.saturating_sub(last.elapsed()).as_secs()
                )));
            }
        }
        map.insert(key, Instant::now());
        Ok(())
    }

    pub async fn check_user_connection_rate(&self, user_id: Uuid) -> Result<(), RoomError> {
        let mut map = self.connection_rate.write().await;
        // User-level connections share the rate map; Uuid::nil() is the sentinel
        // "room" component of the key.
        let key = (Uuid::nil(), user_id);
        if let Some(last) = map.get(&key) {
            if last.elapsed() < CONNECTION_COOLDOWN {
                return Err(RoomError::RateLimited(format!(
                    "Connection cooldown active, retry in {}s",
                    CONNECTION_COOLDOWN.saturating_sub(last.elapsed()).as_secs()
                )));
            }
        }
        map.insert(key, Instant::now());
        Ok(())
    }

    pub async fn cleanup_rate_limit(&self) {
        let mut map = self.connection_rate.write().await;
        map.retain(|_, instant| instant.elapsed() < CONNECTION_COOLDOWN * 2);
        const MAX_RATE_ENTRIES: usize = MAX_CONNECTIONS_PER_ROOM * 10;
        if map.len() > MAX_RATE_ENTRIES {
            // Still oversized after the TTL pass: drop the oldest half of the entries.
            let mut entries: Vec<_> = map.iter().collect();
            entries.sort_by(|a, b| a.1.cmp(b.1));
            let remove_count = entries.len() / 2;
            let to_remove: Vec<_> = entries
                .into_iter()
                .take(remove_count)
                .map(|(k, _)| *k)
                .collect();
            for key in to_remove {
                map.remove(&key);
            }
        }
        drop(map);
        self.cleanup_idle_rooms().await;
    }

    pub async fn cleanup_idle_rooms(&self) {
        let now = Instant::now();
        let activity = self.room_last_activity.read().await;
        let idle_room_ids: Vec<Uuid> = activity
            .iter()
            .filter(|(_, last_time)| now.duration_since(**last_time) > ROOM_IDLE_TIMEOUT)
            .map(|(room_id, _)| *room_id)
            .collect();
        drop(activity);

        if idle_room_ids.is_empty() {
            return;
        }

        {
            let mut counts = self.room_subscriber_count.write().await;
            let mut rooms = self.room_inner.write().await;
            for room_id in &idle_room_ids {
                if let Some(sender) = rooms.remove(room_id) {
                    let count = counts.remove(room_id).unwrap_or(1);
                    self.metrics.users_online.decrement(count as f64);
                    drop(sender);
                }
            }
        }

        {
            let mut txs = self.room_shutdown_txs.write().await;
            for room_id in &idle_room_ids {
                txs.remove(room_id);
            }
        }

        {
            let mut activity = self.room_last_activity.write().await;
            for room_id in &idle_room_ids {
                activity.remove(room_id);
            }
        }
    }

    pub fn subscribe_shutdown(&self) -> broadcast::Receiver<()> {
        self.shutdown_tx.subscribe()
    }

    pub fn trigger_shutdown(&self) {
        let _ = self.shutdown_tx.send(());
    }

    pub async fn register_room(&self, room_id: Uuid) -> broadcast::Receiver<()> {
        let mut txs = self.room_shutdown_txs.write().await;
        if let Some(tx) = txs.get(&room_id) {
            return tx.subscribe();
        }
        let (tx, rx) = broadcast::channel(SHUTDOWN_CHANNEL_CAPACITY);
        txs.insert(room_id, tx);
        rx
    }

    pub async fn shutdown_room(&self, room_id: Uuid) {
        {
            let txs = self.room_shutdown_txs.read().await;
            if let Some(tx) = txs.get(&room_id) {
                let _ = tx.send(());
            }
        }
        {
            let mut counts = self.room_subscriber_count.write().await;
            let count = counts.remove(&room_id).unwrap_or(0) as f64;
            if count > 0.0 {
                self.metrics.users_online.decrement(count);
            }
        }
        {
            let mut map = self.room_inner.write().await;
            map.remove(&room_id);
        }
        {
            let mut txs = self.room_shutdown_txs.write().await;
            txs.remove(&room_id);
        }
    }

    pub async fn prune_stale_rooms(&self, active_room_ids: &[Uuid]) {
        let mut txs = self.room_shutdown_txs.write().await;
        txs.retain(|room_id, _| active_room_ids.contains(room_id));
        drop(txs);
        let mut counts = self.room_subscriber_count.write().await;
        counts.retain(|room_id, _| active_room_ids.contains(room_id));
    }

    pub async fn register_project(&self, project_id: Uuid) -> broadcast::Receiver<()> {
        let mut txs = self.project_shutdown_txs.write().await;
        if let Some(tx) = txs.get(&project_id) {
            return tx.subscribe();
        }
        let (tx, rx) = broadcast::channel(SHUTDOWN_CHANNEL_CAPACITY);
        txs.insert(project_id, tx);
        rx
    }

    pub async fn shutdown_project(&self, project_id: Uuid) {
        {
            let txs = self.project_shutdown_txs.read().await;
            if let Some(tx) = txs.get(&project_id) {
                let _ = tx.send(());
            }
        }
        {
            let mut map = self.project_inner.write().await;
            map.remove(&project_id);
        }
        {
            let mut txs = self.project_shutdown_txs.write().await;
            txs.remove(&project_id);
        }
    }

    pub async fn prune_stale_projects(&self, active_project_ids: &[Uuid]) {
        let mut txs = self.project_shutdown_txs.write().await;
        txs.retain(|project_id, _| active_project_ids.contains(project_id));
    }

    pub async fn register_user(&self, user_id: Uuid) -> broadcast::Receiver<()> {
        let mut txs = self.user_shutdown_txs.write().await;
        if let Some(tx) = txs.get(&user_id) {
            return tx.subscribe();
        }
        let (tx, rx) = broadcast::channel(SHUTDOWN_CHANNEL_CAPACITY);
        txs.insert(user_id, tx);
        rx
    }

    pub async fn shutdown_user(&self, user_id: Uuid) {
        {
            let txs = self.user_shutdown_txs.read().await;
            if let Some(tx) = txs.get(&user_id) {
                let _ = tx.send(());
            }
        }
        {
            let mut map = self.user_inner.write().await;
            map.remove(&user_id);
        }
        {
            let mut txs = self.user_shutdown_txs.write().await;
            txs.remove(&user_id);
        }
    }

    pub async fn subscribe_user_notification(
        &self,
        user_id: Uuid,
    ) -> broadcast::Receiver<Arc<NotificationEvent>> {
        let mut map = self.user_notification_inner.write().await;

        if let Some(sender) = map.get(&user_id) {
            return sender.subscribe();
        }

        let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
        map.insert(user_id, tx);
        rx
    }

    pub async fn unsubscribe_user_notification(&self, user_id: Uuid) {
        let mut map = self.user_notification_inner.write().await;
        map.remove(&user_id);
    }

    pub async fn push_user_notification(&self, user_id: Uuid, event: Arc<NotificationEvent>) {
        let map = self.user_notification_inner.read().await;
        if let Some(sender) = map.get(&user_id) {
            let _ = sender.send(event);
        }
    }

    pub async fn subscribe(
        &self,
        room_id: Uuid,
        _user_id: Uuid,
    ) -> Result<broadcast::Receiver<Arc<RoomMessageEvent>>, RoomError> {
        let mut map = self.room_inner.write().await;
        if map.contains_key(&room_id) {
            // Existing room: bump the subscriber count, then re-acquire the map
            // read lock to subscribe (the channel may have been removed in between).
            drop(map);
            let mut counts = self.room_subscriber_count.write().await;
            *counts.entry(room_id).or_insert(0) += 1;
            // Keep the users_online gauge in step with the per-subscriber
            // decrement performed in `unsubscribe`.
            self.metrics.users_online.increment(1.0);
            let map = self.room_inner.read().await;
            if let Some(sender) = map.get(&room_id) {
                return Ok(sender.subscribe());
            }
            return Err(RoomError::Internal(
                "room disappeared during subscribe".into(),
            ));
        }

        if map.len() >= MAX_CONNECTIONS_PER_ROOM {
            return Err(RoomError::RateLimited(format!(
                "Room connection limit reached ({})",
                MAX_CONNECTIONS_PER_ROOM
            )));
        }

        let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
        map.insert(room_id, tx);
        drop(map);
        let mut counts = self.room_subscriber_count.write().await;
        counts.insert(room_id, 1);
        self.metrics.users_online.increment(1.0);
        Ok(rx)
    }

    pub async fn unsubscribe(&self, room_id: Uuid, _user_id: Uuid) {
        let mut counts = self.room_subscriber_count.write().await;
        let count = counts.entry(room_id).or_insert(0);
        if *count > 0 {
            *count -= 1;
            self.metrics.users_online.decrement(1.0);
        }
        if *count == 0 {
            counts.remove(&room_id);
            drop(counts);
            let mut map = self.room_inner.write().await;
            map.remove(&room_id);
        }
    }

    pub async fn broadcast(&self, room_id: Uuid, event: RoomMessageEvent) {
        {
            let mut activity = self.room_last_activity.write().await;
            activity.insert(room_id, Instant::now());
        }

        let map = self.room_inner.read().await;
        if let Some(sender) = map.get(&room_id) {
            let event = Arc::new(event);
            if sender.send(event).is_err() {
                self.metrics.broadcasts_dropped.increment(1);
            }
        }
    }

    pub async fn subscribe_project(
        &self,
        project_id: Uuid,
        _user_id: Uuid,
    ) -> Result<broadcast::Receiver<Arc<ProjectRoomEvent>>, RoomError> {
        let mut map = self.project_inner.write().await;
        if map.contains_key(&project_id) {
            drop(map);
            let mut counts = self.project_subscriber_count.write().await;
            *counts.entry(project_id).or_insert(0) += 1;
            let map = self.project_inner.read().await;
            if let Some(sender) = map.get(&project_id) {
                return Ok(sender.subscribe());
            }
            return Err(RoomError::Internal("project channel disappeared".into()));
        }

        if map.len() >= MAX_CONNECTIONS_PER_PROJECT {
            return Err(RoomError::RateLimited(format!(
                "Project connection limit reached ({})",
                MAX_CONNECTIONS_PER_PROJECT
            )));
        }

        let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
        map.insert(project_id, tx);
        drop(map);
        let mut counts = self.project_subscriber_count.write().await;
        counts.insert(project_id, 1);
        Ok(rx)
    }

    pub async fn unsubscribe_project(&self, project_id: Uuid, _user_id: Uuid) {
        let mut counts = self.project_subscriber_count.write().await;
        let count = counts.entry(project_id).or_insert(0);
        if *count > 0 {
            *count -= 1;
        }
        if *count == 0 {
            counts.remove(&project_id);
            drop(counts);
            let mut map = self.project_inner.write().await;
            map.remove(&project_id);
        }
    }

    pub async fn broadcast_project(&self, project_id: Uuid, event: ProjectRoomEvent) {
        let map = self.project_inner.read().await;
        if let Some(sender) = map.get(&project_id) {
            let event = Arc::new(event);
            if sender.send(event).is_err() {
                self.metrics.broadcasts_dropped.increment(1);
            }
        }
    }

    /// Broadcast an agent task event to all WS clients subscribed to this project.
    pub async fn broadcast_agent_task(&self, project_id: Uuid, event: AgentTaskEvent) {
        let map = self.task_inner.read().await;
        if let Some(sender) = map.get(&project_id) {
            let event = Arc::new(event);
            if sender.send(event).is_err() {
                self.metrics.broadcasts_dropped.increment(1);
            }
        }
    }

    /// Subscribe to agent task events for a project.
    /// Returns a broadcast receiver that yields task events as they occur.
    pub async fn subscribe_task_events(
        &self,
        project_id: Uuid,
    ) -> Result<broadcast::Receiver<Arc<AgentTaskEvent>>, RoomError> {
        let mut map = self.task_inner.write().await;

        if let Some(sender) = map.get(&project_id).cloned() {
            drop(map);
            return Ok(sender.subscribe());
        }

        let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
        map.insert(project_id, tx);
        Ok(rx)
    }

    pub async fn subscribe_user(
        &self,
        user_id: Uuid,
    ) -> Result<broadcast::Receiver<Arc<ProjectRoomEvent>>, RoomError> {
        let mut map = self.user_inner.write().await;

        if map.contains_key(&user_id) {
            drop(map);
            let mut counts = self.user_subscriber_count.write().await;
            *counts.entry(user_id).or_insert(0) += 1;
            // Keep the users_online gauge in step with the per-subscriber
            // decrement performed in `unsubscribe_user`.
            self.metrics.users_online.increment(1.0);
            let map = self.user_inner.read().await;
            if let Some(sender) = map.get(&user_id) {
                return Ok(sender.subscribe());
            }
            return Err(RoomError::Internal("user channel disappeared".into()));
        }

        if map.len() >= MAX_CONNECTIONS_PER_USER {
            return Err(RoomError::RateLimited(format!(
                "User connection limit reached ({})",
                MAX_CONNECTIONS_PER_USER
            )));
        }

        let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
        map.insert(user_id, tx);
        drop(map);
        let mut counts = self.user_subscriber_count.write().await;
        counts.insert(user_id, 1);
        self.metrics.users_online.increment(1.0);
        Ok(rx)
    }

    pub async fn unsubscribe_user(&self, user_id: Uuid) {
        let mut counts = self.user_subscriber_count.write().await;
        let count = counts.entry(user_id).or_insert(0);
        if *count > 0 {
            *count -= 1;
            self.metrics.users_online.decrement(1.0);
        }
        if *count == 0 {
            counts.remove(&user_id);
            drop(counts);
            let mut map = self.user_inner.write().await;
            map.remove(&user_id);
        }
    }

    pub async fn broadcast_to_user(&self, user_id: Uuid, event: ProjectRoomEvent) {
        let map = self.user_inner.read().await;
        if let Some(sender) = map.get(&user_id) {
            let event = Arc::new(event);
            if sender.send(event).is_err() {
                self.metrics.broadcasts_dropped.increment(1);
            }
        }
    }

    pub async fn register_stream_channel(
        &self,
        message_id: Uuid,
    ) -> broadcast::Receiver<Arc<RoomMessageStreamChunkEvent>> {
        let mut map = self.stream_inner.write().await;
        if let Some(tx) = map.get(&message_id) {
            return tx.subscribe();
        }
        let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
        map.insert(message_id, tx);
        rx
    }

    pub async fn subscribe_stream(
        &self,
        message_id: Uuid,
    ) -> Option<broadcast::Receiver<Arc<RoomMessageStreamChunkEvent>>> {
        let map = self.stream_inner.read().await;
        map.get(&message_id).map(|tx| tx.subscribe())
    }

    pub async fn subscribe_room_stream(
        &self,
        room_id: Uuid,
    ) -> broadcast::Receiver<Arc<RoomMessageStreamChunkEvent>> {
        let mut map = self.room_stream_inner.write().await;
        if let Some(tx) = map.get(&room_id) {
            return tx.subscribe();
        }
        let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
        map.insert(room_id, tx);
        rx
    }

    pub async fn broadcast_stream_chunk(&self, event: RoomMessageStreamChunkEvent) {
        let event = Arc::new(event);
        let is_final_chunk = event.done;

        // Fan out to subscribers of this specific message...
        let map = self.stream_inner.read().await;
        if let Some(tx) = map.get(&event.message_id) {
            let _ = tx.send(Arc::clone(&event));
        }

        drop(map);
        // ...and to subscribers of the whole room.
        let map = self.room_stream_inner.read().await;
        if let Some(tx) = map.get(&event.room_id) {
            let _ = tx.send(Arc::clone(&event));
        }

        if is_final_chunk {
            // Final chunk: tear down the per-room stream channel. The per-message
            // channel is closed separately via `close_stream_channel`.
            drop(map);
            let mut map = self.room_stream_inner.write().await;
            map.remove(&event.room_id);
        }
    }

    pub async fn close_stream_channel(&self, message_id: Uuid) {
        let mut map = self.stream_inner.write().await;
        map.remove(&message_id);
    }
}

fn parse_sender_type(s: &str) -> MessageSenderType {
    match s {
        "member" => MessageSenderType::Member,
        "admin" => MessageSenderType::Admin,
        "owner" => MessageSenderType::Owner,
        "ai" => MessageSenderType::Ai,
        "system" => MessageSenderType::System,
        "tool" => MessageSenderType::Tool,
        "guest" => MessageSenderType::Guest,
        _ => MessageSenderType::Member,
    }
}

fn parse_content_type(s: &str) -> MessageContentType {
    match s {
        "text" => MessageContentType::Text,
        "image" => MessageContentType::Image,
        "audio" => MessageContentType::Audio,
        "video" => MessageContentType::Video,
        "file" => MessageContentType::File,
        _ => MessageContentType::Text,
    }
}

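// A minimal test sketch of the fallback behaviour of the two parsers above:
// unrecognized strings map to `Member` and `Text` respectively. It assumes the
// enum variants are unit variants, as their use above suggests.
#[cfg(test)]
mod parse_fallback_tests {
    use super::*;

    #[test]
    fn unknown_strings_fall_back_to_defaults() {
        assert!(matches!(parse_sender_type("robot"), MessageSenderType::Member));
        assert!(matches!(parse_content_type("hologram"), MessageContentType::Text));
    }
}
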
/// Callback that persists a batch of room message envelopes.
pub type PersistFn = Arc<
    dyn Fn(Vec<RoomMessageEnvelope>) -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send>>
        + Send
        + Sync,
>;

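// A minimal sketch of how a caller-supplied `PersistFn` can be built and driven
// in isolation (e.g. from a test); the production implementation is
// `make_persist_fn` below. The explicit cast pins the boxed future to the
// trait-object type named in the alias. Nothing outside this file is assumed
// beyond the tokio and anyhow dependencies it already uses.
#[cfg(test)]
mod persist_fn_shape_tests {
    use super::*;

    #[test]
    fn noop_persist_fn_matches_alias() {
        // A no-op closure that satisfies the `PersistFn` signature.
        let persist: PersistFn = Arc::new(|_envelopes: Vec<RoomMessageEnvelope>| {
            Box::pin(async move { Ok::<(), anyhow::Error>(()) })
                as Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send>>
        });

        // Call through a reference to the trait object and drive the returned
        // future on a throw-away current-thread runtime.
        let fut = persist.as_ref()(Vec::new());
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("test runtime");
        rt.block_on(fut).expect("no-op persist succeeds");
    }
}
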
use dashmap::DashMap;

/// Cache of recently persisted message ids, used to skip duplicate envelopes.
pub type DedupCache = Arc<DashMap<uuid::Uuid, Instant>>;

const DEDUP_CACHE_TTL: Duration = Duration::from_secs(300);

pub fn cleanup_dedup_cache(cache: &DedupCache) {
    // Using `elapsed()` avoids the potential panic of `Instant::now() - TTL`
    // when the process has been running for less than the TTL.
    cache.retain(|_, inserted_at| inserted_at.elapsed() < DEDUP_CACHE_TTL);
}

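// A small test-only sketch of the dedup-cache lifecycle: entries are keyed by
// message id and pruned once older than DEDUP_CACHE_TTL. Only the "fresh entry
// survives" path is exercised here, since `Instant` cannot be fast-forwarded;
// `Uuid::new_v4()` assumes the uuid crate's `v4` feature is enabled.
#[cfg(test)]
mod dedup_cache_tests {
    use super::*;

    #[test]
    fn fresh_entries_survive_cleanup() {
        let cache: DedupCache = Arc::new(DashMap::new());
        let id = Uuid::new_v4();
        cache.insert(id, Instant::now());

        cleanup_dedup_cache(&cache);

        assert!(cache.contains_key(&id), "entry inside the TTL must be retained");
    }
}
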
pub fn make_persist_fn(
    db: AppDatabase,
    metrics: Arc<RoomMetrics>,
    dedup_cache: DedupCache,
) -> PersistFn {
    Arc::new(move |envelopes: Vec<RoomMessageEnvelope>| {
        let db = db.clone();
        let metrics = metrics.clone();
        let cache = dedup_cache.clone();
        Box::pin(async move {
            for chunk in envelopes.chunks(BATCH_SIZE) {
                let mut models_to_insert = Vec::new();
                let mut ids_to_dedup: Vec<uuid::Uuid> = Vec::new();

                // First pass: skip ids already seen in the in-memory cache and
                // collect the remainder for a database existence check.
                for env in chunk {
                    if cache.contains_key(&env.id) {
                        metrics.incr_duplicates_skipped();
                        continue;
                    }
                    ids_to_dedup.push(env.id);
                }

                let existing_ids: std::collections::HashSet<uuid::Uuid> =
                    if !ids_to_dedup.is_empty() {
                        room_message::Entity::find()
                            .filter(room_message::Column::Id.is_in(ids_to_dedup))
                            .into_model::<room_message::Model>()
                            .all(&db)
                            .await?
                            .into_iter()
                            .map(|m| m.id)
                            .collect()
                    } else {
                        std::collections::HashSet::new()
                    };

                // Second pass: mark ids as seen and build the insert batch,
                // skipping anything already persisted.
                for env in chunk {
                    if cache.contains_key(&env.id) {
                        continue;
                    }
                    cache.insert(env.id, Instant::now());

                    if existing_ids.contains(&env.id) {
                        metrics.incr_duplicates_skipped();
                        continue;
                    }

                    let sender_type = parse_sender_type(&env.sender_type);
                    let content_type = parse_content_type(&env.content_type);

                    models_to_insert.push(room_message::ActiveModel {
                        id: Set(env.id),
                        seq: Set(env.seq),
                        room: Set(env.room_id),
                        sender_type: Set(sender_type),
                        sender_id: Set(env.sender_id),
                        thread: Set(env.thread_id),
                        content: Set(env.content.clone()),
                        content_type: Set(content_type),
                        edited_at: Set(None),
                        send_at: Set(env.send_at.clone()),
                        revoked: Set(None),
                        revoked_by: Set(None),
                        in_reply_to: Set(env.in_reply_to),
                    });
                }

                if !models_to_insert.is_empty() {
                    let count = models_to_insert.len() as u64;
                    room_message::Entity::insert_many(models_to_insert)
                        .exec(&db)
                        .await?;
                    metrics.messages_persisted.increment(count);
                }
            }
            Ok(())
        })
    })
}

pub type RedisFuture =
    Pin<Box<dyn Future<Output = anyhow::Result<deadpool_redis::cluster::Connection>> + Send>>;

pub fn extract_get_redis(
    producer: queue::MessageProducer,
) -> Arc<dyn Fn() -> RedisFuture + Send + Sync> {
    Arc::new(move || {
        let get_redis_fn = producer.get_redis.clone();
        Box::pin(async move {
            let handle = get_redis_fn();
            // Surface a panicked pool task as an error instead of unwrapping.
            match handle.await {
                Ok(conn) => conn,
                Err(_) => anyhow::bail!("redis pool task panicked"),
            }
        }) as RedisFuture
    })
}

fn start_pubsub_thread<F, Fut>(
    redis_url: String,
    channel: String,
    relay_tx: tokio::sync::mpsc::Sender<Vec<u8>>,
    mut shutdown_rx: broadcast::Receiver<()>,
    log: slog::Logger,
    _on_msg: F,
) where
    F: Fn(Vec<u8>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ()> + Send,
{
    thread::Builder::new()
        .name(format!("redis-pubsub-{}", &channel[..channel.len().min(32)]))
        .spawn(move || {
            // Dedicated single-threaded runtime so the blocking reconnect loop
            // stays off the main Tokio runtime.
            let rt = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .expect("pubsub thread runtime");
            rt.block_on(async {
                let redis_url = redis_url.clone();
                // Outer loop: (re)connect and (re)subscribe until shutdown.
                loop {
                    if shutdown_rx.try_recv().is_ok() {
                        slog::info!(log, "pubsub thread shutting down before connect"; "channel" => %channel);
                        break;
                    }

                    let client = match redis::Client::open(redis_url.as_str()) {
                        Ok(c) => c,
                        Err(e) => {
                            slog::error!(log, "pubsub redis client open failed"; "channel" => %channel, "error" => %e);
                            thread::sleep(Duration::from_secs(1));
                            continue;
                        }
                    };

                    let mut pubsub = match client.get_async_pubsub().await {
                        Ok(p) => p,
                        Err(e) => {
                            slog::error!(log, "pubsub connection failed"; "channel" => %channel, "error" => %e);
                            thread::sleep(Duration::from_secs(1));
                            continue;
                        }
                    };

                    match pubsub.subscribe(&channel).await {
                        Ok(_) => slog::info!(log, "pubsub subscribed"; "channel" => %channel),
                        Err(e) => {
                            slog::error!(log, "pubsub subscribe failed"; "channel" => %channel, "error" => %e);
                            thread::sleep(Duration::from_secs(1));
                            continue;
                        }
                    }

                    let mut stream = pubsub.on_message();

                    // Inner loop: relay payloads, polling in 500ms slices so the
                    // shutdown signal is observed promptly.
                    loop {
                        if shutdown_rx.try_recv().is_ok() {
                            slog::info!(log, "pubsub thread shutting down"; "channel" => %channel);
                            return;
                        }

                        let msg = tokio::time::timeout(
                            Duration::from_millis(500),
                            futures::StreamExt::next(&mut stream),
                        )
                        .await;

                        match msg {
                            Ok(Some(msg)) => {
                                let payload = msg.get_payload_bytes();
                                slog::debug!(log, "pubsub received"; "channel" => %channel, "len" => payload.len());
                                if relay_tx.send(payload.to_vec()).await.is_err() {
                                    slog::warn!(log, "pubsub relay channel closed"; "channel" => %channel);
                                    return;
                                }
                            }
                            Ok(None) => {
                                slog::warn!(log, "pubsub stream ended, will reconnect"; "channel" => %channel);
                                break;
                            }
                            // Timeout elapsed with no message: loop and re-check shutdown.
                            Err(_) => {}
                        }
                    }

                    slog::warn!(log, "pubsub connection lost, reconnecting"; "channel" => %channel);
                }
            });
        })
        .expect("pubsub thread spawn");
}

pub async fn subscribe_room_events(
    redis_url: String,
    manager: Arc<RoomConnectionManager>,
    room_id: Uuid,
    log: slog::Logger,
    mut shutdown_rx: broadcast::Receiver<()>,
) {
    let channel = format!("room:pub:{}", room_id);
    let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(1024);

    slog::info!(log, "starting room pubsub subscriber"; "room_id" => %room_id, "channel" => %channel);

    let thread_log = log.clone();
    let thread_channel = channel.clone();
    let thread_shutdown = shutdown_rx.resubscribe();
    start_pubsub_thread(
        redis_url,
        thread_channel,
        tx,
        thread_shutdown,
        thread_log,
        |_| async {},
    );

    loop {
        tokio::select! {
            _ = shutdown_rx.recv() => {
                slog::info!(log, "room subscriber shutting down"; "room_id" => %room_id);
                break;
            }
            payload = rx.recv() => {
                match payload {
                    Some(data) => {
                        match serde_json::from_slice::<RoomMessageEvent>(&data) {
                            Ok(event) => {
                                manager.broadcast(room_id, event).await;
                            }
                            Err(e) => {
                                slog::warn!(log, "malformed RoomMessageEvent"; "error" => %e);
                            }
                        }
                    }
                    None => {
                        slog::warn!(log, "pubsub relay channel closed"; "room_id" => %room_id);
                        break;
                    }
                }
            }
        }
    }
    slog::info!(log, "room subscriber stopped"; "room_id" => %room_id);
}

pub async fn subscribe_project_room_events(
    redis_url: String,
    manager: Arc<RoomConnectionManager>,
    project_id: Uuid,
    log: slog::Logger,
    mut shutdown_rx: broadcast::Receiver<()>,
) {
    let channel = format!("project:pub:{}", project_id);
    let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(1024);

    slog::info!(log, "starting project pubsub subscriber"; "project_id" => %project_id, "channel" => %channel);

    let thread_log = log.clone();
    let thread_channel = channel.clone();
    let thread_shutdown = shutdown_rx.resubscribe();
    start_pubsub_thread(
        redis_url,
        thread_channel,
        tx,
        thread_shutdown,
        thread_log,
        |_| async {},
    );

    loop {
        tokio::select! {
            _ = shutdown_rx.recv() => {
                slog::info!(log, "project subscriber shutting down"; "project_id" => %project_id);
                break;
            }
            payload = rx.recv() => {
                match payload {
                    Some(data) => {
                        match serde_json::from_slice::<ProjectRoomEvent>(&data) {
                            Ok(event) => {
                                manager.broadcast_project(project_id, event).await;
                            }
                            Err(e) => {
                                slog::warn!(log, "malformed ProjectRoomEvent"; "error" => %e);
                            }
                        }
                    }
                    None => {
                        slog::warn!(log, "project pubsub relay channel closed"; "project_id" => %project_id);
                        break;
                    }
                }
            }
        }
    }
    slog::info!(log, "project subscriber stopped"; "project_id" => %project_id);
}

/// Subscribe to Redis Pub/Sub `task:pub:{project_id}` and relay events to
/// `RoomConnectionManager::broadcast_agent_task()` so all WS clients get notified.
pub async fn subscribe_task_events_fn(
    redis_url: String,
    manager: Arc<RoomConnectionManager>,
    project_id: Uuid,
    log: slog::Logger,
    mut shutdown_rx: broadcast::Receiver<()>,
) {
    let channel = format!("task:pub:{}", project_id);
    let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(1024);

    slog::info!(log, "starting task pubsub subscriber"; "project_id" => %project_id, "channel" => %channel);

    let thread_log = log.clone();
    let thread_channel = channel.clone();
    let thread_shutdown = shutdown_rx.resubscribe();
    start_pubsub_thread(
        redis_url,
        thread_channel,
        tx,
        thread_shutdown,
        thread_log,
        |_| async {},
    );

    loop {
        tokio::select! {
            _ = shutdown_rx.recv() => {
                slog::info!(log, "task subscriber shutting down"; "project_id" => %project_id);
                break;
            }
            payload = rx.recv() => {
                match payload {
                    Some(data) => {
                        match serde_json::from_slice::<AgentTaskEvent>(&data) {
                            Ok(event) => {
                                manager.broadcast_agent_task(project_id, event).await;
                            }
                            Err(e) => {
                                slog::warn!(log, "malformed AgentTaskEvent"; "error" => %e);
                            }
                        }
                    }
                    None => {
                        slog::warn!(log, "task pubsub relay channel closed"; "project_id" => %project_id);
                        break;
                    }
                }
            }
        }
    }
    slog::info!(log, "task subscriber stopped"; "project_id" => %project_id);
}