gitdataai/libs/room/src/connection.rs
ZhenYi 8b47f677bb
Some checks are pending
CI / Rust Lint & Check (push) Waiting to run
CI / Rust Tests (push) Waiting to run
CI / Frontend Lint & Type Check (push) Waiting to run
CI / Frontend Build (push) Blocked by required conditions
fix(avatar): add upload API routes and fix URL path prefix
- Add /api/users/me/avatar and /api/projects/{name}/avatar multipart upload endpoints
- Fix avatar URL path: missing /avatar prefix (static.gitdata.ai/avatar/{file})
- Fix project avatar: Utc::now() → .timestamp(), missing extension, wrong return type
- Replace broken SkipNoisyPaths middleware with self-contained RequestLogger
  (actix-web 4.13 body type incompatibility with newer actix-http)
- Exclude /assets/* requests from main app logger
- Exclude /avatar/*, /blob/*, /media/*, /static/* from static server logger
- Fix TypingEvent missing sender_type field in ws_universal.rs and connection.rs
- Wire real fetch-based upload in user profile settings
- Add project avatar upload UI to project settings page
2026-04-25 23:19:22 +08:00

1133 lines
41 KiB
Rust

use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant};
use tokio::sync::{RwLock, broadcast};
use uuid::Uuid;
use db::cache::AppCache;
use db::database::AppDatabase;
use models::rooms::{MessageContentType, MessageSenderType, room_message};
use queue::types::TypingEvent;
use queue::{AgentTaskEvent, ProjectRoomEvent, RoomMessageEnvelope, RoomMessageEvent, RoomMessageStreamChunkEvent};
use sea_orm::{ColumnTrait, ConnectionTrait, EntityTrait, QueryFilter, Set};
use crate::error::RoomError;
use crate::metrics::RoomMetrics;
use crate::types::NotificationEvent;
/// Buffer capacity used for every event `broadcast` channel created in this file.
const BROADCAST_CAPACITY: usize = 100_000;
/// Buffer capacity of the shutdown-signal `broadcast` channels.
const SHUTDOWN_CHANNEL_CAPACITY: usize = 16;
/// Minimum interval between two connection attempts recorded under the same rate-limit key.
const CONNECTION_COOLDOWN: Duration = Duration::from_secs(30);
/// Limit compared against the number of entries in the room channel map
/// (also used to derive the rate-limit table cap in `cleanup_rate_limit`).
const MAX_CONNECTIONS_PER_ROOM: usize = 50000;
/// Limit compared against the number of entries in the project channel map.
const MAX_CONNECTIONS_PER_PROJECT: usize = 50000;
/// Limit compared against the number of entries in the user channel map.
const MAX_CONNECTIONS_PER_USER: usize = 50000;
/// Number of envelopes processed per database round-trip by the persist function.
const BATCH_SIZE: usize = 100;
/// A room whose last recorded activity is older than this is removed by `cleanup_idle_rooms`.
const ROOM_IDLE_TIMEOUT: Duration = Duration::from_secs(30 * 60);
/// In-process registry of `tokio::sync::broadcast` channels used to fan events out
/// to connected clients, together with the bookkeeping (subscriber counts, shutdown
/// signals, rate-limit timestamps, activity timestamps) that accompanies them.
///
/// Every map is wrapped in its own async `RwLock`.
pub struct RoomConnectionManager {
/// Room message channels, keyed by room id.
room_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<RoomMessageEvent>>>>,
/// Project event channels, keyed by project id.
project_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<ProjectRoomEvent>>>>,
/// Per-user event channels, keyed by user id.
user_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<ProjectRoomEvent>>>>,
/// Per-user notification channels, keyed by user id.
user_notification_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<NotificationEvent>>>>,
/// Broadcast channel for agent task events per project.
task_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<AgentTaskEvent>>>>,
/// Shared metrics handle (gauges/counters updated by the methods below).
pub metrics: Arc<RoomMetrics>,
/// Cache handle; used here to obtain Redis connections for typing state.
cache: AppCache,
/// Timestamp of the last accepted connection attempt per `(scope id, user id)` key.
/// The user-level check uses `Uuid::nil()` as the scope id.
connection_rate: RwLock<HashMap<(Uuid, Uuid), Instant>>,
/// Global shutdown signal.
shutdown_tx: broadcast::Sender<()>,
/// Per-room shutdown signals.
room_shutdown_txs: RwLock<HashMap<Uuid, broadcast::Sender<()>>>,
/// Per-project shutdown signals.
project_shutdown_txs: RwLock<HashMap<Uuid, broadcast::Sender<()>>>,
/// Per-user shutdown signals.
user_shutdown_txs: RwLock<HashMap<Uuid, broadcast::Sender<()>>>,
/// Stream-chunk channels keyed by message id.
stream_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<RoomMessageStreamChunkEvent>>>>,
/// Stream-chunk channels keyed by room id.
room_stream_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<RoomMessageStreamChunkEvent>>>>,
/// Typing-indicator channels keyed by room id.
typing_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<TypingEvent>>>>,
/// Time of the most recent broadcast or stream chunk per room; drives idle cleanup.
room_last_activity: RwLock<HashMap<Uuid, Instant>>,
/// Number of subscriptions currently counted per room.
room_subscriber_count: RwLock<HashMap<Uuid, usize>>,
/// Number of subscriptions currently counted per project.
project_subscriber_count: RwLock<HashMap<Uuid, usize>>,
/// Number of subscriptions currently counted per user.
user_subscriber_count: RwLock<HashMap<Uuid, usize>>,
}
impl RoomConnectionManager {
pub fn new(metrics: Arc<RoomMetrics>, cache: AppCache) -> Self {
let (shutdown_tx, _) = broadcast::channel(SHUTDOWN_CHANNEL_CAPACITY);
Self {
#[allow(clippy::default_constructed_unit_structs)]
room_inner: RwLock::new(HashMap::new()),
#[allow(clippy::default_constructed_unit_structs)]
project_inner: RwLock::new(HashMap::new()),
#[allow(clippy::default_constructed_unit_structs)]
user_inner: RwLock::new(HashMap::new()),
#[allow(clippy::default_constructed_unit_structs)]
user_notification_inner: RwLock::new(HashMap::new()),
#[allow(clippy::default_constructed_unit_structs)]
task_inner: RwLock::new(HashMap::new()),
metrics,
cache,
#[allow(clippy::default_constructed_unit_structs)]
connection_rate: RwLock::new(HashMap::new()),
shutdown_tx,
#[allow(clippy::default_constructed_unit_structs)]
room_shutdown_txs: RwLock::new(HashMap::new()),
#[allow(clippy::default_constructed_unit_structs)]
project_shutdown_txs: RwLock::new(HashMap::new()),
#[allow(clippy::default_constructed_unit_structs)]
user_shutdown_txs: RwLock::new(HashMap::new()),
#[allow(clippy::default_constructed_unit_structs)]
stream_inner: RwLock::new(HashMap::new()),
#[allow(clippy::default_constructed_unit_structs)]
room_stream_inner: RwLock::new(HashMap::new()),
#[allow(clippy::default_constructed_unit_structs)]
typing_inner: RwLock::new(HashMap::new()),
#[allow(clippy::default_constructed_unit_structs)]
room_last_activity: RwLock::new(HashMap::new()),
#[allow(clippy::default_constructed_unit_structs)]
room_subscriber_count: RwLock::new(HashMap::new()),
#[allow(clippy::default_constructed_unit_structs)]
project_subscriber_count: RwLock::new(HashMap::new()),
#[allow(clippy::default_constructed_unit_structs)]
user_subscriber_count: RwLock::new(HashMap::new()),
}
}
/// Applies the connection cooldown for a `(room, user)` pair.
///
/// Returns `RoomError::RateLimited` if the previous accepted attempt for this pair
/// is younger than `CONNECTION_COOLDOWN`; otherwise records the current instant
/// as the latest attempt and returns `Ok(())`.
pub async fn check_room_connection_rate(
    &self,
    room_id: Uuid,
    user_id: Uuid,
) -> Result<(), RoomError> {
    let key = (room_id, user_id);
    let mut attempts = self.connection_rate.write().await;
    let cooling_since = attempts
        .get(&key)
        .filter(|previous| previous.elapsed() < CONNECTION_COOLDOWN)
        .copied();
    if let Some(previous) = cooling_since {
        let remaining = CONNECTION_COOLDOWN
            .saturating_sub(previous.elapsed())
            .as_secs();
        return Err(RoomError::RateLimited(format!(
            "Connection cooldown active, retry in {}s",
            remaining
        )));
    }
    attempts.insert(key, Instant::now());
    Ok(())
}
/// Applies the connection cooldown for a `(project, user)` pair.
///
/// Shares the rate-limit table with the room and user checks. Fails with
/// `RoomError::RateLimited` while the pair is still cooling down; otherwise stamps
/// the pair with the current instant.
pub async fn check_project_connection_rate(
    &self,
    project_id: Uuid,
    user_id: Uuid,
) -> Result<(), RoomError> {
    let key = (project_id, user_id);
    let mut attempts = self.connection_rate.write().await;
    let cooling_since = attempts
        .get(&key)
        .filter(|previous| previous.elapsed() < CONNECTION_COOLDOWN)
        .copied();
    if let Some(previous) = cooling_since {
        let remaining = CONNECTION_COOLDOWN
            .saturating_sub(previous.elapsed())
            .as_secs();
        return Err(RoomError::RateLimited(format!(
            "Connection cooldown active, retry in {}s",
            remaining
        )));
    }
    attempts.insert(key, Instant::now());
    Ok(())
}
/// Applies the connection cooldown for a user-level connection.
///
/// The entry is stored under `(Uuid::nil(), user_id)` in the shared rate-limit
/// table. Fails with `RoomError::RateLimited` during the cooldown; otherwise
/// records the current instant.
pub async fn check_user_connection_rate(&self, user_id: Uuid) -> Result<(), RoomError> {
    let key = (Uuid::nil(), user_id);
    let mut attempts = self.connection_rate.write().await;
    let cooling_since = attempts
        .get(&key)
        .filter(|previous| previous.elapsed() < CONNECTION_COOLDOWN)
        .copied();
    if let Some(previous) = cooling_since {
        let remaining = CONNECTION_COOLDOWN
            .saturating_sub(previous.elapsed())
            .as_secs();
        return Err(RoomError::RateLimited(format!(
            "Connection cooldown active, retry in {}s",
            remaining
        )));
    }
    attempts.insert(key, Instant::now());
    Ok(())
}
pub async fn cleanup_rate_limit(&self) {
let mut map = self.connection_rate.write().await;
map.retain(|_, instant| instant.elapsed() < CONNECTION_COOLDOWN * 2);
const MAX_RATE_ENTRIES: usize = MAX_CONNECTIONS_PER_ROOM * 10;
if map.len() > MAX_RATE_ENTRIES {
let mut entries: Vec<_> = map.iter().collect();
entries.sort_by(|a, b| a.1.cmp(b.1));
let keep_count = entries.len() / 2;
let to_remove: Vec<_> = entries
.into_iter()
.take(keep_count)
.map(|(k, _)| *k)
.collect();
for key in to_remove {
map.remove(&key);
}
}
drop(map);
self.cleanup_idle_rooms().await;
}
/// Tears down every room whose last recorded activity is older than
/// `ROOM_IDLE_TIMEOUT`.
///
/// For each idle room this removes the message channel (adjusting the
/// `users_online` gauge by the room's recorded subscriber count), the room-level
/// stream channel, the room shutdown sender and the activity timestamp.
/// Dropping the senders closes the corresponding receivers.
pub async fn cleanup_idle_rooms(&self) {
    let now = Instant::now();
    let idle_room_ids: Vec<Uuid> = {
        let activity = self.room_last_activity.read().await;
        activity
            .iter()
            .filter(|(_, last_time)| now.duration_since(**last_time) > ROOM_IDLE_TIMEOUT)
            .map(|(room_id, _)| *room_id)
            .collect()
    };
    if idle_room_ids.is_empty() {
        return;
    }
    {
        // Lock order: subscriber counts first, then the channel map.
        let mut counts = self.room_subscriber_count.write().await;
        let mut rooms = self.room_inner.write().await;
        for room_id in &idle_room_ids {
            if rooms.remove(room_id).is_some() {
                // `room_id` is already a reference; pass it straight through so the
                // lookup key type is `Uuid` rather than `&Uuid`.
                // NOTE(review): a room without a recorded count is treated as one
                // subscriber here, whereas `shutdown_room` treats it as zero — confirm
                // which is intended.
                let count = counts.remove(room_id).unwrap_or(1);
                self.metrics.users_online.decrement(count as f64);
            }
        }
    }
    {
        let mut stream_map = self.room_stream_inner.write().await;
        for room_id in &idle_room_ids {
            stream_map.remove(room_id);
        }
    }
    {
        let mut txs = self.room_shutdown_txs.write().await;
        for room_id in &idle_room_ids {
            txs.remove(room_id);
        }
    }
    {
        let mut activity = self.room_last_activity.write().await;
        for room_id in &idle_room_ids {
            activity.remove(room_id);
        }
    }
}
/// Returns a fresh receiver for the global shutdown signal.
pub fn subscribe_shutdown(&self) -> broadcast::Receiver<()> {
    broadcast::Sender::subscribe(&self.shutdown_tx)
}
/// Sends the global shutdown signal.
///
/// A send error (no live receivers) is deliberately ignored.
pub fn trigger_shutdown(&self) {
    self.shutdown_tx.send(()).ok();
}
pub async fn register_room(&self, room_id: Uuid) -> broadcast::Receiver<()> {
let mut txs = self.room_shutdown_txs.write().await;
if let Some(tx) = txs.get(&room_id) {
return tx.subscribe();
}
let (tx, rx) = broadcast::channel(SHUTDOWN_CHANNEL_CAPACITY);
txs.insert(room_id, tx);
rx
}
/// Signals shutdown for a room and removes its channels.
///
/// Steps, each under its own lock scope: send on the room's shutdown channel (if
/// any), drop the subscriber count and lower `users_online` by it, remove the
/// message channel, remove the room-level stream channel, and finally remove the
/// shutdown sender itself.
pub async fn shutdown_room(&self, room_id: Uuid) {
    {
        let senders = self.room_shutdown_txs.read().await;
        if let Some(tx) = senders.get(&room_id) {
            // No listeners is not an error here.
            let _ = tx.send(());
        }
    }
    {
        let mut counts = self.room_subscriber_count.write().await;
        if let Some(subscribers) = counts.remove(&room_id).filter(|n| *n > 0) {
            self.metrics.users_online.decrement(subscribers as f64);
        }
    }
    self.room_inner.write().await.remove(&room_id);
    self.room_stream_inner.write().await.remove(&room_id);
    self.room_shutdown_txs.write().await.remove(&room_id);
}
/// Drops shutdown senders and subscriber counts for every room that is not
/// listed in `active_room_ids`.
///
/// The active ids are indexed in a hash set once so that pruning is linear in the
/// size of the maps instead of scanning the slice for every entry. An empty slice
/// prunes everything.
pub async fn prune_stale_rooms(&self, active_room_ids: &[Uuid]) {
    let active: std::collections::HashSet<Uuid> = active_room_ids.iter().copied().collect();
    {
        let mut txs = self.room_shutdown_txs.write().await;
        txs.retain(|room_id, _| active.contains(room_id));
    }
    let mut counts = self.room_subscriber_count.write().await;
    counts.retain(|room_id, _| active.contains(room_id));
}
pub async fn register_project(&self, project_id: Uuid) -> broadcast::Receiver<()> {
let mut txs = self.project_shutdown_txs.write().await;
if let Some(tx) = txs.get(&project_id) {
return tx.subscribe();
}
let (tx, rx) = broadcast::channel(SHUTDOWN_CHANNEL_CAPACITY);
txs.insert(project_id, tx);
rx
}
/// Signals shutdown for a project, then removes its event channel and its
/// shutdown sender.
pub async fn shutdown_project(&self, project_id: Uuid) {
    {
        let senders = self.project_shutdown_txs.read().await;
        if let Some(tx) = senders.get(&project_id) {
            // No listeners is not an error here.
            let _ = tx.send(());
        }
    }
    self.project_inner.write().await.remove(&project_id);
    self.project_shutdown_txs.write().await.remove(&project_id);
}
/// Drops the shutdown sender of every project that is not listed in
/// `active_project_ids`.
///
/// The active ids are indexed in a hash set once so each retained-or-not decision
/// is a constant-time lookup. An empty slice prunes everything.
pub async fn prune_stale_projects(&self, active_project_ids: &[Uuid]) {
    let active: std::collections::HashSet<Uuid> =
        active_project_ids.iter().copied().collect();
    let mut txs = self.project_shutdown_txs.write().await;
    txs.retain(|project_id, _| active.contains(project_id));
}
pub async fn register_user(&self, user_id: Uuid) -> broadcast::Receiver<()> {
let mut txs = self.user_shutdown_txs.write().await;
if let Some(tx) = txs.get(&user_id) {
return tx.subscribe();
}
let (tx, rx) = broadcast::channel(SHUTDOWN_CHANNEL_CAPACITY);
txs.insert(user_id, tx);
rx
}
/// Signals shutdown for a user, then removes the user's event channel and
/// shutdown sender.
pub async fn shutdown_user(&self, user_id: Uuid) {
    {
        let senders = self.user_shutdown_txs.read().await;
        if let Some(tx) = senders.get(&user_id) {
            // No listeners is not an error here.
            let _ = tx.send(());
        }
    }
    self.user_inner.write().await.remove(&user_id);
    self.user_shutdown_txs.write().await.remove(&user_id);
}
pub async fn subscribe_user_notification(
&self,
user_id: Uuid,
) -> broadcast::Receiver<Arc<NotificationEvent>> {
let mut map = self.user_notification_inner.write().await;
if let Some(sender) = map.get(&user_id) {
return sender.subscribe();
}
let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
map.insert(user_id, tx);
rx
}
/// Removes the notification channel for `user_id`.
///
/// The sender is dropped unconditionally; no subscriber count is consulted.
pub async fn unsubscribe_user_notification(&self, user_id: Uuid) {
    self.user_notification_inner.write().await.remove(&user_id);
}
/// Sends `event` on the notification channel of `user_id`, if one exists.
///
/// A missing channel or a send failure (no receivers) is silently ignored.
pub async fn push_user_notification(&self, user_id: Uuid, event: Arc<NotificationEvent>) {
    let channels = self.user_notification_inner.read().await;
    let Some(sender) = channels.get(&user_id) else {
        return;
    };
    let _ = sender.send(event);
}
/// Subscribes to the message channel of `room_id`, creating the channel if needed.
///
/// Every successful subscription increments both the room's subscriber count and
/// the `users_online` gauge, mirroring `unsubscribe`, which decrements both once
/// per subscriber.
///
/// The count lock and the channel lock are held together for the whole operation
/// (counts first, then channels — the same order `cleanup_idle_rooms` uses), so a
/// room cannot disappear between the existence check and the subscription and no
/// count is left behind for a room that no longer exists.
///
/// # Errors
///
/// Returns `RoomError::RateLimited` when a new channel would have to be created
/// but the channel map already holds `MAX_CONNECTIONS_PER_ROOM` entries.
pub async fn subscribe(
    &self,
    room_id: Uuid,
    _user_id: Uuid,
) -> Result<broadcast::Receiver<Arc<RoomMessageEvent>>, RoomError> {
    let mut counts = self.room_subscriber_count.write().await;
    let mut map = self.room_inner.write().await;
    if let Some(sender) = map.get(&room_id) {
        let rx = sender.subscribe();
        *counts.entry(room_id).or_insert(0) += 1;
        self.metrics.users_online.increment(1.0);
        return Ok(rx);
    }
    // NOTE(review): this compares the number of *rooms* against a constant named
    // "connections per room" — confirm whether a per-room cap was intended.
    if map.len() >= MAX_CONNECTIONS_PER_ROOM {
        return Err(RoomError::RateLimited(format!(
            "Room connection limit reached ({})",
            MAX_CONNECTIONS_PER_ROOM
        )));
    }
    let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
    map.insert(room_id, tx);
    counts.insert(room_id, 1);
    self.metrics.users_online.increment(1.0);
    Ok(rx)
}
/// Releases one subscription on `room_id`.
///
/// Decrements the room's subscriber count and the `users_online` gauge when the
/// count is positive. When no subscribers remain (or none were recorded), the
/// count entry and the room's message channel are removed.
///
/// The count lock stays held while the channel is removed so that a concurrent
/// subscriber cannot be counted against a channel that is about to be dropped.
/// Lock order is counts first, then channels, as in `cleanup_idle_rooms`.
pub async fn unsubscribe(&self, room_id: Uuid, _user_id: Uuid) {
    let mut counts = self.room_subscriber_count.write().await;
    let remaining = match counts.get_mut(&room_id) {
        Some(count) => {
            if *count > 0 {
                *count -= 1;
                self.metrics.users_online.decrement(1.0);
            }
            *count
        }
        None => 0,
    };
    if remaining == 0 {
        counts.remove(&room_id);
        self.room_inner.write().await.remove(&room_id);
    }
}
/// Publishes `event` to the subscribers of `room_id`.
///
/// The room's activity timestamp is refreshed first, whether or not a channel
/// exists. If the room has a channel and the send fails, the
/// `broadcasts_dropped` counter is incremented.
pub async fn broadcast(&self, room_id: Uuid, event: RoomMessageEvent) {
    self.room_last_activity
        .write()
        .await
        .insert(room_id, Instant::now());

    let rooms = self.room_inner.read().await;
    let Some(sender) = rooms.get(&room_id) else {
        return;
    };
    if sender.send(Arc::new(event)).is_err() {
        self.metrics.broadcasts_dropped.increment(1);
    }
}
/// Subscribes to the event channel of `project_id`, creating it if needed.
///
/// The subscriber-count lock and the channel lock are held together (counts
/// first, then channels), so the channel cannot be removed between the existence
/// check and the subscription, and the count is only incremented for a
/// subscription that was actually handed out.
///
/// # Errors
///
/// Returns `RoomError::RateLimited` when a new channel would have to be created
/// but the channel map already holds `MAX_CONNECTIONS_PER_PROJECT` entries.
pub async fn subscribe_project(
    &self,
    project_id: Uuid,
    _user_id: Uuid,
) -> Result<broadcast::Receiver<Arc<ProjectRoomEvent>>, RoomError> {
    let mut counts = self.project_subscriber_count.write().await;
    let mut map = self.project_inner.write().await;
    if let Some(sender) = map.get(&project_id) {
        let rx = sender.subscribe();
        *counts.entry(project_id).or_insert(0) += 1;
        return Ok(rx);
    }
    if map.len() >= MAX_CONNECTIONS_PER_PROJECT {
        return Err(RoomError::RateLimited(format!(
            "Project connection limit reached ({})",
            MAX_CONNECTIONS_PER_PROJECT
        )));
    }
    let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
    map.insert(project_id, tx);
    counts.insert(project_id, 1);
    Ok(rx)
}
/// Releases one subscription on `project_id`.
///
/// Decrements the project's subscriber count when positive. When no subscribers
/// remain (or none were recorded), the count entry and the project's event
/// channel are removed. The count lock stays held during the channel removal so a
/// concurrent subscriber cannot be counted against a channel that is about to be
/// dropped.
pub async fn unsubscribe_project(&self, project_id: Uuid, _user_id: Uuid) {
    let mut counts = self.project_subscriber_count.write().await;
    let remaining = match counts.get_mut(&project_id) {
        Some(count) => {
            *count = count.saturating_sub(1);
            *count
        }
        None => 0,
    };
    if remaining == 0 {
        counts.remove(&project_id);
        self.project_inner.write().await.remove(&project_id);
    }
}
/// Publishes `event` to the subscribers of `project_id`.
///
/// Does nothing when the project has no channel. A failed send increments the
/// `broadcasts_dropped` counter.
pub async fn broadcast_project(&self, project_id: Uuid, event: ProjectRoomEvent) {
    let projects = self.project_inner.read().await;
    let Some(sender) = projects.get(&project_id) else {
        return;
    };
    if sender.send(Arc::new(event)).is_err() {
        self.metrics.broadcasts_dropped.increment(1);
    }
}
/// Publishes an agent task event on the task channel of `project_id`.
///
/// Does nothing when no task channel exists for the project. A failed send
/// increments the `broadcasts_dropped` counter.
pub async fn broadcast_agent_task(&self, project_id: Uuid, event: AgentTaskEvent) {
    let task_channels = self.task_inner.read().await;
    match task_channels.get(&project_id) {
        None => {}
        Some(sender) => {
            let shared = Arc::new(event);
            if sender.send(shared).is_err() {
                self.metrics.broadcasts_dropped.increment(1);
            }
        }
    }
}
/// Returns a receiver for agent task events of `project_id`.
///
/// Reuses the project's task channel when one exists (subscribing after the map
/// lock has been released); otherwise creates the channel and returns its initial
/// receiver. This function currently never returns an error.
pub async fn subscribe_task_events(
    &self,
    project_id: Uuid,
) -> Result<broadcast::Receiver<Arc<AgentTaskEvent>>, RoomError> {
    let mut task_channels = self.task_inner.write().await;
    match task_channels.get(&project_id).cloned() {
        Some(existing) => {
            drop(task_channels);
            Ok(existing.subscribe())
        }
        None => {
            let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
            task_channels.insert(project_id, tx);
            Ok(rx)
        }
    }
}
/// Subscribes to the per-user event channel of `user_id`, creating it if needed.
///
/// Every successful subscription increments both the user's subscriber count and
/// the `users_online` gauge, mirroring `unsubscribe_user`, which decrements both
/// once per subscriber.
///
/// The count lock and the channel lock are held together (counts first, then
/// channels), so the channel cannot be removed between the existence check and
/// the subscription.
///
/// # Errors
///
/// Returns `RoomError::RateLimited` when a new channel would have to be created
/// but the channel map already holds `MAX_CONNECTIONS_PER_USER` entries.
pub async fn subscribe_user(
    &self,
    user_id: Uuid,
) -> Result<broadcast::Receiver<Arc<ProjectRoomEvent>>, RoomError> {
    let mut counts = self.user_subscriber_count.write().await;
    let mut map = self.user_inner.write().await;
    if let Some(sender) = map.get(&user_id) {
        let rx = sender.subscribe();
        *counts.entry(user_id).or_insert(0) += 1;
        self.metrics.users_online.increment(1.0);
        return Ok(rx);
    }
    if map.len() >= MAX_CONNECTIONS_PER_USER {
        return Err(RoomError::RateLimited(format!(
            "User connection limit reached ({})",
            MAX_CONNECTIONS_PER_USER
        )));
    }
    let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
    map.insert(user_id, tx);
    counts.insert(user_id, 1);
    self.metrics.users_online.increment(1.0);
    Ok(rx)
}
/// Releases one subscription on the per-user channel of `user_id`.
///
/// Decrements the user's subscriber count and the `users_online` gauge when the
/// count is positive. When no subscribers remain (or none were recorded), the
/// count entry and the user's channel are removed. The count lock stays held
/// during the channel removal so a concurrent subscriber cannot be counted
/// against a channel that is about to be dropped.
pub async fn unsubscribe_user(&self, user_id: Uuid) {
    let mut counts = self.user_subscriber_count.write().await;
    let remaining = match counts.get_mut(&user_id) {
        Some(count) => {
            if *count > 0 {
                *count -= 1;
                self.metrics.users_online.decrement(1.0);
            }
            *count
        }
        None => 0,
    };
    if remaining == 0 {
        counts.remove(&user_id);
        self.user_inner.write().await.remove(&user_id);
    }
}
/// Publishes `event` on the per-user channel of `user_id`.
///
/// Does nothing when the user has no channel. A failed send increments the
/// `broadcasts_dropped` counter.
pub async fn broadcast_to_user(&self, user_id: Uuid, event: ProjectRoomEvent) {
    let users = self.user_inner.read().await;
    let delivery_failed = match users.get(&user_id) {
        Some(sender) => sender.send(Arc::new(event)).is_err(),
        None => false,
    };
    if delivery_failed {
        self.metrics.broadcasts_dropped.increment(1);
    }
}
pub async fn register_stream_channel(
&self,
message_id: Uuid,
) -> broadcast::Receiver<Arc<RoomMessageStreamChunkEvent>> {
let mut map = self.stream_inner.write().await;
if let Some(tx) = map.get(&message_id) {
return tx.subscribe();
}
let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
map.insert(message_id, tx);
rx
}
/// Returns a receiver for stream chunks of `message_id`, or `None` when no
/// per-message stream channel has been registered.
pub async fn subscribe_stream(
    &self,
    message_id: Uuid,
) -> Option<broadcast::Receiver<Arc<RoomMessageStreamChunkEvent>>> {
    let streams = self.stream_inner.read().await;
    let sender = streams.get(&message_id)?;
    Some(sender.subscribe())
}
pub async fn subscribe_room_stream(
&self,
room_id: Uuid,
) -> broadcast::Receiver<Arc<RoomMessageStreamChunkEvent>> {
let mut map = self.room_stream_inner.write().await;
if let Some(tx) = map.get(&room_id) {
return tx.subscribe();
}
let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
map.insert(room_id, tx);
rx
}
/// Publishes a stream chunk on both the per-message and the per-room stream
/// channel, where those exist.
///
/// The room's activity timestamp is refreshed first so that idle cleanup does not
/// remove a room while it is streaming. Send failures are ignored. When the chunk
/// has `done` set, the room-level stream channel is removed afterwards.
pub async fn broadcast_stream_chunk(&self, event: RoomMessageStreamChunkEvent) {
    let room_id = event.room_id;
    let message_id = event.message_id;
    let is_final_chunk = event.done;

    self.room_last_activity
        .write()
        .await
        .insert(room_id, Instant::now());

    let chunk = Arc::new(event);
    {
        let per_message = self.stream_inner.read().await;
        if let Some(tx) = per_message.get(&message_id) {
            let _ = tx.send(Arc::clone(&chunk));
        }
    }
    {
        let per_room = self.room_stream_inner.read().await;
        if let Some(tx) = per_room.get(&room_id) {
            let _ = tx.send(Arc::clone(&chunk));
        }
    }
    if is_final_chunk {
        // The read guard above is already released, so taking the write lock is safe.
        self.room_stream_inner.write().await.remove(&room_id);
    }
}
/// Removes the per-message stream channel of `message_id`, if present.
pub async fn close_stream_channel(&self, message_id: Uuid) {
    self.stream_inner.write().await.remove(&message_id);
}
/// Subscribes to typing events of `room_id` and replays the currently active
/// typing state to the new subscriber.
///
/// The receiver is created *before* the replay is sent: a broadcast receiver only
/// observes values sent after it subscribed, so sending the replay first would
/// never reach the caller. The channel-map lock is released before Redis is
/// queried so that typing broadcasts for other rooms are not blocked on I/O.
///
/// Because the replay goes through the shared room channel, already-connected
/// subscribers of the same room also receive the replayed "start" events.
pub async fn subscribe_typing(
    &self,
    room_id: Uuid,
) -> broadcast::Receiver<Arc<TypingEvent>> {
    let (tx, rx) = {
        let mut map = self.typing_inner.write().await;
        let tx = map
            .entry(room_id)
            .or_insert_with(|| {
                let (tx, _) = broadcast::channel(BROADCAST_CAPACITY);
                tx
            })
            .clone();
        let rx = tx.subscribe();
        (tx, rx)
    };
    // Replay active typing state from Redis so a newly connected client sees who
    // is currently typing.
    for event in self.get_active_typing_events(room_id).await {
        let _ = tx.send(Arc::new(event));
    }
    rx
}
/// Broadcasts a typing event to the room and mirrors it into Redis.
///
/// Redis side (performed in a detached task, errors ignored):
/// - action `"start"`: `SETEX typing:{room_id}:{user_id}` with a 60 second expiry
///   and a JSON value holding `username`, `avatar_url` and `sender_type`
///   (`sender_type` falls back to `"user"` when absent);
/// - any other action: `DEL` of the same key.
///
/// If no Redis connection can be obtained the Redis step is skipped. The event is
/// then sent on the room's typing channel, if one exists.
pub async fn broadcast_typing(&self, room_id: Uuid, event: TypingEvent) {
    let redis_key = format!("typing:{}:{}", room_id, event.user_id);
    let is_start = event.action == "start";
    let username = event.username.clone();
    let avatar_url = event.avatar_url.clone();
    let sender_type = event
        .sender_type
        .clone()
        .unwrap_or_else(|| "user".to_string());

    if let Ok(mut conn) = self.cache.conn().await {
        // The command itself runs in the background so the broadcast is not delayed.
        tokio::spawn(async move {
            if is_start {
                let value = serde_json::json!({
                    "username": username,
                    "avatar_url": avatar_url,
                    "sender_type": sender_type,
                })
                .to_string();
                let _: Result<(), _> = redis::cmd("SETEX")
                    .arg(&redis_key)
                    .arg(60i64)
                    .arg(&value)
                    .query_async(&mut conn)
                    .await;
            } else {
                let _: Result<(), _> = redis::cmd("DEL")
                    .arg(&redis_key)
                    .query_async(&mut conn)
                    .await;
            }
        });
    }

    let typing_channels = self.typing_inner.read().await;
    if let Some(tx) = typing_channels.get(&room_id) {
        let _ = tx.send(Arc::new(event));
    }
}
/// Reads every `typing:{room_id}:*` key from Redis and converts each entry into a
/// `TypingEvent` with action `"start"`.
///
/// Returns an empty vector when no Redis connection is available or the key
/// listing fails. Keys whose third `:`-separated segment is not a UUID, whose
/// value cannot be fetched, or whose value is not valid JSON are skipped.
// NOTE(review): the key listing uses the Redis `KEYS` command — confirm this is
// acceptable for the size of the keyspace in production.
pub async fn get_active_typing_events(&self, room_id: Uuid) -> Vec<TypingEvent> {
    let pattern = format!("typing:{}:*", room_id);
    let Ok(mut conn) = self.cache.conn().await else {
        return Vec::new();
    };
    let keys: Vec<String> = redis::cmd("KEYS")
        .arg(&pattern)
        .query_async(&mut conn)
        .await
        .unwrap_or_default();

    let mut events = Vec::new();
    for key in keys {
        // Key layout: typing:{room_id}:{user_id}
        let user_id = key.split(':').nth(2).and_then(|s| Uuid::parse_str(s).ok());
        let raw = redis::cmd("GET")
            .arg(&key)
            .query_async::<String>(&mut conn)
            .await
            .ok();
        let (Some(raw), Some(user_id)) = (raw, user_id) else {
            continue;
        };
        let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&raw) else {
            continue;
        };
        let text_field = |name: &str| parsed.get(name).and_then(|v| v.as_str());
        events.push(TypingEvent {
            room_id,
            user_id,
            username: text_field("username").unwrap_or("").to_string(),
            avatar_url: text_field("avatar_url").map(String::from),
            action: "start".to_string(),
            sender_type: text_field("sender_type").map(String::from),
        });
    }
    events
}
}
/// Maps the textual sender type of an envelope to `MessageSenderType`.
///
/// Unrecognised values fall back to `MessageSenderType::Member`.
fn parse_sender_type(s: &str) -> MessageSenderType {
    match s {
        // Human roles.
        "owner" => MessageSenderType::Owner,
        "admin" => MessageSenderType::Admin,
        "member" => MessageSenderType::Member,
        "guest" => MessageSenderType::Guest,
        // Automated senders.
        "ai" => MessageSenderType::Ai,
        "tool" => MessageSenderType::Tool,
        "system" => MessageSenderType::System,
        // Anything else is treated as a regular member.
        _ => MessageSenderType::Member,
    }
}
/// Maps the textual content type of an envelope to `MessageContentType`.
///
/// Unrecognised values fall back to `MessageContentType::Text`.
fn parse_content_type(s: &str) -> MessageContentType {
    match s {
        "file" => MessageContentType::File,
        "video" => MessageContentType::Video,
        "audio" => MessageContentType::Audio,
        "image" => MessageContentType::Image,
        "text" => MessageContentType::Text,
        // Anything else is stored as plain text.
        _ => MessageContentType::Text,
    }
}
/// Shared, thread-safe callback that takes a batch of message envelopes and
/// returns a boxed `Send` future resolving to `anyhow::Result<()>`.
///
/// `make_persist_fn` builds the database-backed implementation.
pub type PersistFn = Arc<
dyn Fn(Vec<RoomMessageEnvelope>) -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send>>
+ Send
+ Sync,
>;
use dashmap::DashMap;
/// Shared concurrent map from message id to the instant the id was recorded;
/// used by the persist function to skip envelopes it has already seen.
pub type DedupCache = Arc<DashMap<uuid::Uuid, Instant>>;
/// Age after which `cleanup_dedup_cache` evicts an entry from the dedup cache.
const DEDUP_CACHE_TTL: Duration = Duration::from_secs(300);
/// Evicts dedup-cache entries that are at least `DEDUP_CACHE_TTL` old.
///
/// The age is computed with `saturating_duration_since` rather than by
/// subtracting the TTL from `Instant::now()`: `Instant - Duration` panics when the
/// result cannot be represented (for example shortly after the platform's
/// monotonic clock started), whereas the saturating form cannot panic.
pub fn cleanup_dedup_cache(cache: &DedupCache) {
    let now = Instant::now();
    cache.retain(|_, inserted_at| now.saturating_duration_since(*inserted_at) < DEDUP_CACHE_TTL);
}
pub fn make_persist_fn(
db: AppDatabase,
metrics: Arc<RoomMetrics>,
dedup_cache: DedupCache,
) -> PersistFn {
Arc::new(move |envelopes: Vec<RoomMessageEnvelope>| {
let db = db.clone();
let metrics = metrics.clone();
let cache = dedup_cache.clone();
Box::pin(async move {
for chunk in envelopes.chunks(BATCH_SIZE) {
let mut models_to_insert = Vec::new();
let mut ids_to_dedup: Vec<uuid::Uuid> = Vec::new();
for env in chunk {
if cache.contains_key(&env.id) {
metrics.incr_duplicates_skipped();
continue;
}
ids_to_dedup.push(env.id);
}
let existing_ids: std::collections::HashSet<uuid::Uuid> =
if !ids_to_dedup.is_empty() {
room_message::Entity::find()
.filter(room_message::Column::Id.is_in(ids_to_dedup))
.into_model::<room_message::Model>()
.all(&db)
.await?
.into_iter()
.map(|m| m.id)
.collect()
} else {
std::collections::HashSet::new()
};
for env in chunk {
if cache.contains_key(&env.id) {
continue;
}
cache.insert(env.id, Instant::now());
if existing_ids.contains(&env.id) {
metrics.incr_duplicates_skipped();
continue;
}
let sender_type = parse_sender_type(&env.sender_type);
let content_type = parse_content_type(&env.content_type);
models_to_insert.push(room_message::ActiveModel {
id: Set(env.id),
seq: Set(env.seq),
room: Set(env.room_id),
sender_type: Set(sender_type),
sender_id: Set(env.sender_id),
model_id: Set(env.model_id),
thread: Set(env.thread_id),
content: Set(env.content.clone()),
content_type: Set(content_type),
edited_at: Set(None),
send_at: Set(env.send_at.clone()),
revoked: Set(None),
revoked_by: Set(None),
in_reply_to: Set(env.in_reply_to),
});
}
if !models_to_insert.is_empty() {
let count = models_to_insert.len() as u64;
room_message::Entity::insert_many(models_to_insert)
.exec(&db)
.await?;
// Batch update content_tsv using a single UPDATE with subquery
// instead of N individual UPDATE statements (N=chunk size, up to 100)
let ids: Vec<String> = chunk
.iter()
.filter(|e| !existing_ids.contains(&e.id))
.map(|e| format!("'{}'", e.id))
.collect();
if !ids.is_empty() {
let batch_sql = format!(
"UPDATE room_message AS t \
SET content_tsv = to_tsvector('simple', content) \
WHERE t.id IN ({})",
ids.join(",")
);
let stmt = sea_orm::Statement::from_sql_and_values(
sea_orm::DbBackend::Postgres,
&batch_sql,
vec![],
);
let _ = db.execute_raw(stmt).await;
}
metrics.messages_persisted.increment(count);
}
}
Ok(())
})
})
}
/// Boxed `Send` future that resolves to a Redis cluster connection or an
/// `anyhow` error.
pub type RedisFuture =
Pin<Box<dyn Future<Output = anyhow::Result<deadpool_redis::cluster::Connection>> + Send>>;
/// Wraps the producer's `get_redis` accessor in a shareable closure that yields a
/// [`RedisFuture`].
///
/// Each call clones the accessor, invokes it and awaits the returned handle. A
/// successful handle yields the inner connection result unchanged; a failed
/// handle is reported as `"redis pool task panicked"`.
pub fn extract_get_redis(
    producer: queue::MessageProducer,
) -> Arc<dyn Fn() -> RedisFuture + Send + Sync> {
    Arc::new(move || {
        let get_redis_fn = producer.get_redis.clone();
        let fut: RedisFuture = Box::pin(async move {
            match get_redis_fn().await {
                Ok(conn) => conn,
                Err(_) => Err(anyhow::anyhow!("redis pool task panicked")),
            }
        });
        fut
    })
}
fn start_pubsub_thread<F, Fut>(
redis_url: String,
channel: String,
relay_tx: tokio::sync::mpsc::Sender<Vec<u8>>,
mut shutdown_rx: broadcast::Receiver<()>,
_on_msg: F,
) where
F: Fn(Vec<u8>) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send,
{
thread::Builder::new()
.name(format!("redis-pubsub-{}", &channel[..channel.len().min(32)]))
.spawn(move || {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("pubsub thread runtime");
rt.block_on(async {
let redis_url = redis_url.clone();
loop {
if shutdown_rx.try_recv().is_ok() {
tracing::info!(channel = %channel, "pubsub thread shutting down before connect");
break;
}
let client = match redis::Client::open(redis_url.as_str()) {
Ok(c) => c,
Err(e) => {
tracing::error!(channel = %channel, error = %e, "pubsub redis client open failed");
thread::sleep(Duration::from_secs(1));
continue;
}
};
let mut pubsub = match client.get_async_pubsub().await {
Ok(p) => p,
Err(e) => {
tracing::error!(channel = %channel, error = %e, "pubsub connection failed");
thread::sleep(Duration::from_secs(1));
continue;
}
};
match pubsub.subscribe(&channel).await {
Ok(_) => tracing::info!(channel = %channel, "pubsub subscribed"),
Err(e) => {
tracing::error!(channel = %channel, error = %e, "pubsub subscribe failed");
thread::sleep(Duration::from_secs(1));
continue;
}
}
let mut stream = pubsub.on_message();
loop {
if shutdown_rx.try_recv().is_ok() {
tracing::info!(channel = %channel, "pubsub thread shutting down");
return;
}
let msg = tokio::time::timeout(
Duration::from_millis(500),
futures::StreamExt::next(&mut stream),
)
.await;
match msg {
Ok(Some(msg)) => {
let payload = msg.get_payload_bytes();
tracing::debug!(channel = %channel, len = payload.len(), "pubsub received");
if relay_tx.send(payload.to_vec()).await.is_err() {
tracing::warn!(channel = %channel, "pubsub relay channel closed");
return;
}
}
Ok(None) => {
tracing::warn!(channel = %channel, "pubsub stream ended, will reconnect");
break;
}
Err(_) => {}
}
}
tracing::warn!(channel = %channel, "pubsub connection lost, reconnecting");
}
});
})
.expect("pubsub thread spawn");
}
/// Relays Redis Pub/Sub messages from `room:pub:{room_id}` to
/// `RoomConnectionManager::broadcast`.
///
/// A pubsub thread is started with its own resubscribed shutdown receiver; this
/// task then forwards every payload that deserialises as `RoomMessageEvent`.
/// Malformed payloads are logged and skipped. The loop ends when `shutdown_rx`
/// yields (value or error) or when the relay channel is closed.
pub async fn subscribe_room_events(
    redis_url: String,
    manager: Arc<RoomConnectionManager>,
    room_id: Uuid,
    mut shutdown_rx: broadcast::Receiver<()>,
) {
    let channel = format!("room:pub:{}", room_id);
    let (relay_tx, mut relay_rx) = tokio::sync::mpsc::channel::<Vec<u8>>(1024);
    tracing::info!(room_id = %room_id, channel = %channel, "starting room pubsub subscriber");
    start_pubsub_thread(
        redis_url,
        channel,
        relay_tx,
        shutdown_rx.resubscribe(),
        |_| async {},
    );
    loop {
        tokio::select! {
            _ = shutdown_rx.recv() => {
                tracing::info!(room_id = %room_id, "room subscriber shutting down");
                break;
            }
            payload = relay_rx.recv() => {
                let Some(data) = payload else {
                    tracing::warn!(room_id = %room_id, "pubsub relay channel closed");
                    break;
                };
                match serde_json::from_slice::<RoomMessageEvent>(&data) {
                    Ok(event) => manager.broadcast(room_id, event).await,
                    Err(e) => tracing::warn!(error = %e, "malformed RoomMessageEvent"),
                }
            }
        }
    }
    tracing::info!(room_id = %room_id, "room subscriber stopped");
}
/// Relays Redis Pub/Sub messages from `project:pub:{project_id}` to
/// `RoomConnectionManager::broadcast_project`.
///
/// A pubsub thread is started with its own resubscribed shutdown receiver; this
/// task then forwards every payload that deserialises as `ProjectRoomEvent`.
/// Malformed payloads are logged and skipped. The loop ends when `shutdown_rx`
/// yields (value or error) or when the relay channel is closed.
pub async fn subscribe_project_room_events(
    redis_url: String,
    manager: Arc<RoomConnectionManager>,
    project_id: Uuid,
    mut shutdown_rx: broadcast::Receiver<()>,
) {
    let channel = format!("project:pub:{}", project_id);
    let (relay_tx, mut relay_rx) = tokio::sync::mpsc::channel::<Vec<u8>>(1024);
    tracing::info!(project_id = %project_id, channel = %channel, "starting project pubsub subscriber");
    start_pubsub_thread(
        redis_url,
        channel,
        relay_tx,
        shutdown_rx.resubscribe(),
        |_| async {},
    );
    loop {
        tokio::select! {
            _ = shutdown_rx.recv() => {
                tracing::info!(project_id = %project_id, "project subscriber shutting down");
                break;
            }
            payload = relay_rx.recv() => {
                let Some(data) = payload else {
                    tracing::warn!(project_id = %project_id, "project pubsub relay channel closed");
                    break;
                };
                match serde_json::from_slice::<ProjectRoomEvent>(&data) {
                    Ok(event) => manager.broadcast_project(project_id, event).await,
                    Err(e) => tracing::warn!(error = %e, "malformed ProjectRoomEvent"),
                }
            }
        }
    }
    tracing::info!(project_id = %project_id, "project subscriber stopped");
}
/// Relays Redis Pub/Sub messages from `task:pub:{project_id}` to
/// `RoomConnectionManager::broadcast_agent_task()` so subscribed clients are
/// notified.
///
/// A pubsub thread is started with its own resubscribed shutdown receiver; this
/// task then forwards every payload that deserialises as `AgentTaskEvent`.
/// Malformed payloads are logged and skipped. The loop ends when `shutdown_rx`
/// yields (value or error) or when the relay channel is closed.
pub async fn subscribe_task_events_fn(
    redis_url: String,
    manager: Arc<RoomConnectionManager>,
    project_id: Uuid,
    mut shutdown_rx: broadcast::Receiver<()>,
) {
    let channel = format!("task:pub:{}", project_id);
    let (relay_tx, mut relay_rx) = tokio::sync::mpsc::channel::<Vec<u8>>(1024);
    tracing::info!(project_id = %project_id, channel = %channel, "starting task pubsub subscriber");
    start_pubsub_thread(
        redis_url,
        channel,
        relay_tx,
        shutdown_rx.resubscribe(),
        |_| async {},
    );
    loop {
        tokio::select! {
            _ = shutdown_rx.recv() => {
                tracing::info!(project_id = %project_id, "task subscriber shutting down");
                break;
            }
            payload = relay_rx.recv() => {
                let Some(data) = payload else {
                    tracing::warn!(project_id = %project_id, "task pubsub relay channel closed");
                    break;
                };
                match serde_json::from_slice::<AgentTaskEvent>(&data) {
                    Ok(event) => manager.broadcast_agent_task(project_id, event).await,
                    Err(e) => tracing::warn!(error = %e, "malformed AgentTaskEvent"),
                }
            }
        }
    }
    tracing::info!(project_id = %project_id, "task subscriber stopped");
}