// gitdataai/libs/room/src/connection.rs
// Snapshot metadata: 2026-04-14 19:02:01 +08:00 · 999 lines · 35 KiB · Rust

use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant};
use tokio::sync::{RwLock, broadcast};
use uuid::Uuid;
use db::database::AppDatabase;
use models::rooms::{MessageContentType, MessageSenderType, room_message};
use queue::{AgentTaskEvent, ProjectRoomEvent, RoomMessageEnvelope, RoomMessageEvent, RoomMessageStreamChunkEvent};
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter, Set};
use crate::error::RoomError;
use crate::metrics::RoomMetrics;
use crate::types::NotificationEvent;
/// Capacity of every event broadcast channel (room/project/user/task/stream).
const BROADCAST_CAPACITY: usize = 10000;
/// Capacity of the unit `()` shutdown-signal channels.
const SHUTDOWN_CHANNEL_CAPACITY: usize = 16;
/// Minimum interval between reconnects for the same (scope, user) pair.
const CONNECTION_COOLDOWN: Duration = Duration::from_secs(30);
// NOTE(review): the three MAX_CONNECTIONS_* limits are compared against the
// *channel map* length in subscribe()/subscribe_project()/subscribe_user(),
// i.e. they bound the number of distinct rooms/projects/users with an open
// channel — not the number of subscribers per key. Confirm the naming matches
// the intended policy.
const MAX_CONNECTIONS_PER_ROOM: usize = 50000;
const MAX_CONNECTIONS_PER_PROJECT: usize = 50000;
const MAX_CONNECTIONS_PER_USER: usize = 50000;
/// Number of message envelopes persisted per DB insert batch.
const BATCH_SIZE: usize = 100;
/// A room with no broadcast activity for this long is reaped by
/// `cleanup_idle_rooms`.
const ROOM_IDLE_TIMEOUT: Duration = Duration::from_secs(30 * 60);
/// In-process fan-out hub for room/project/user WebSocket events.
///
/// Each map holds one `tokio::sync::broadcast` sender per key; subscribers
/// obtain receivers from it. Separate maps track shutdown signals, subscriber
/// counts, and last-activity timestamps used by the idle-room reaper.
pub struct RoomConnectionManager {
    /// Per-room broadcast channel for chat message events.
    room_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<RoomMessageEvent>>>>,
    /// Per-project broadcast channel for project-level room events.
    project_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<ProjectRoomEvent>>>>,
    /// Per-user broadcast channel for project room events routed to one user.
    user_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<ProjectRoomEvent>>>>,
    /// Per-user broadcast channel for notification events.
    user_notification_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<NotificationEvent>>>>,
    /// Broadcast channel for agent task events per project.
    task_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<AgentTaskEvent>>>>,
    pub metrics: Arc<RoomMetrics>,
    // Rate-limit map keyed by (scope_id, user_id); scope_id is a room id,
    // project id, or Uuid::nil() for user-level connections.
    connection_rate: RwLock<HashMap<(Uuid, Uuid), Instant>>,
    // Global shutdown signal for all subscribers.
    shutdown_tx: broadcast::Sender<()>,
    room_shutdown_txs: RwLock<HashMap<Uuid, broadcast::Sender<()>>>,
    project_shutdown_txs: RwLock<HashMap<Uuid, broadcast::Sender<()>>>,
    user_shutdown_txs: RwLock<HashMap<Uuid, broadcast::Sender<()>>>,
    // Streaming message chunks keyed by message id / by room id.
    stream_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<RoomMessageStreamChunkEvent>>>>,
    room_stream_inner: RwLock<HashMap<Uuid, broadcast::Sender<Arc<RoomMessageStreamChunkEvent>>>>,
    // Last broadcast time per room; feeds the idle-room reaper.
    room_last_activity: RwLock<HashMap<Uuid, Instant>>,
    room_subscriber_count: RwLock<HashMap<Uuid, usize>>,
    project_subscriber_count: RwLock<HashMap<Uuid, usize>>,
    user_subscriber_count: RwLock<HashMap<Uuid, usize>>,
}
impl RoomConnectionManager {
/// Create a manager with empty channel maps and a fresh global shutdown
/// channel.
///
/// Fix: removed the `#[allow(clippy::default_constructed_unit_structs)]`
/// attribute that was stamped on every struct-literal field — none of these
/// values are default-constructed unit structs, so the allows were pure noise.
pub fn new(metrics: Arc<RoomMetrics>) -> Self {
    let (shutdown_tx, _) = broadcast::channel(SHUTDOWN_CHANNEL_CAPACITY);
    Self {
        room_inner: RwLock::new(HashMap::new()),
        project_inner: RwLock::new(HashMap::new()),
        user_inner: RwLock::new(HashMap::new()),
        user_notification_inner: RwLock::new(HashMap::new()),
        task_inner: RwLock::new(HashMap::new()),
        metrics,
        connection_rate: RwLock::new(HashMap::new()),
        shutdown_tx,
        room_shutdown_txs: RwLock::new(HashMap::new()),
        project_shutdown_txs: RwLock::new(HashMap::new()),
        user_shutdown_txs: RwLock::new(HashMap::new()),
        stream_inner: RwLock::new(HashMap::new()),
        room_stream_inner: RwLock::new(HashMap::new()),
        room_last_activity: RwLock::new(HashMap::new()),
        room_subscriber_count: RwLock::new(HashMap::new()),
        project_subscriber_count: RwLock::new(HashMap::new()),
        user_subscriber_count: RwLock::new(HashMap::new()),
    }
}
/// Enforce the per-(room, user) reconnect cooldown.
///
/// Returns `RoomError::RateLimited` while the cooldown is active; otherwise
/// records the current instant as the latest connection attempt.
pub async fn check_room_connection_rate(
    &self,
    room_id: Uuid,
    user_id: Uuid,
) -> Result<(), RoomError> {
    let key = (room_id, user_id);
    let mut rate_map = self.connection_rate.write().await;
    match rate_map.get(&key) {
        Some(last) if last.elapsed() < CONNECTION_COOLDOWN => {
            let remaining = CONNECTION_COOLDOWN.saturating_sub(last.elapsed()).as_secs();
            Err(RoomError::RateLimited(format!(
                "Connection cooldown active, retry in {}s",
                remaining
            )))
        }
        _ => {
            rate_map.insert(key, Instant::now());
            Ok(())
        }
    }
}
/// Enforce the per-(project, user) reconnect cooldown.
///
/// Same policy as `check_room_connection_rate`, keyed by project id.
pub async fn check_project_connection_rate(
    &self,
    project_id: Uuid,
    user_id: Uuid,
) -> Result<(), RoomError> {
    let key = (project_id, user_id);
    let mut rate_map = self.connection_rate.write().await;
    match rate_map.get(&key) {
        Some(last) if last.elapsed() < CONNECTION_COOLDOWN => {
            let remaining = CONNECTION_COOLDOWN.saturating_sub(last.elapsed()).as_secs();
            Err(RoomError::RateLimited(format!(
                "Connection cooldown active, retry in {}s",
                remaining
            )))
        }
        _ => {
            rate_map.insert(key, Instant::now());
            Ok(())
        }
    }
}
/// Enforce the per-user reconnect cooldown.
///
/// Uses `Uuid::nil()` as the scope half of the shared rate-limit key.
pub async fn check_user_connection_rate(&self, user_id: Uuid) -> Result<(), RoomError> {
    let key = (Uuid::nil(), user_id);
    let mut rate_map = self.connection_rate.write().await;
    match rate_map.get(&key) {
        Some(last) if last.elapsed() < CONNECTION_COOLDOWN => {
            let remaining = CONNECTION_COOLDOWN.saturating_sub(last.elapsed()).as_secs();
            Err(RoomError::RateLimited(format!(
                "Connection cooldown active, retry in {}s",
                remaining
            )))
        }
        _ => {
            rate_map.insert(key, Instant::now());
            Ok(())
        }
    }
}
pub async fn cleanup_rate_limit(&self) {
let mut map = self.connection_rate.write().await;
map.retain(|_, instant| instant.elapsed() < CONNECTION_COOLDOWN * 2);
const MAX_RATE_ENTRIES: usize = MAX_CONNECTIONS_PER_ROOM * 10;
if map.len() > MAX_RATE_ENTRIES {
let mut entries: Vec<_> = map.iter().collect();
entries.sort_by(|a, b| a.1.cmp(b.1));
let keep_count = entries.len() / 2;
let to_remove: Vec<_> = entries
.into_iter()
.take(keep_count)
.map(|(k, _)| *k)
.collect();
for key in to_remove {
map.remove(&key);
}
}
drop(map);
self.cleanup_idle_rooms().await;
}
/// Reap rooms with no broadcast activity for longer than `ROOM_IDLE_TIMEOUT`:
/// drop their broadcast channel, shutdown sender, and activity entry, and
/// release their subscribers from the `users_online` gauge.
///
/// Fix: a missing subscriber count previously defaulted to 1
/// (`unwrap_or(1)`), decrementing the gauge for rooms that never registered a
/// subscriber — `shutdown_room` uses 0 for the same case. Default to 0 and
/// only decrement when there were tracked subscribers.
pub async fn cleanup_idle_rooms(&self) {
    let now = Instant::now();
    // Snapshot idle room ids under a short read lock.
    let idle_room_ids: Vec<Uuid> = {
        let activity = self.room_last_activity.read().await;
        activity
            .iter()
            .filter(|(_, last_time)| now.duration_since(**last_time) > ROOM_IDLE_TIMEOUT)
            .map(|(room_id, _)| *room_id)
            .collect()
    };
    if idle_room_ids.is_empty() {
        return;
    }
    {
        // Lock order (counts, then rooms) matches the other cleanup paths.
        let mut counts = self.room_subscriber_count.write().await;
        let mut rooms = self.room_inner.write().await;
        for room_id in &idle_room_ids {
            if rooms.remove(room_id).is_some() {
                // Dropping the sender closes all outstanding receivers.
                let count = counts.remove(room_id).unwrap_or(0);
                if count > 0 {
                    self.metrics.users_online.decrement(count as f64);
                }
            }
        }
    }
    {
        let mut txs = self.room_shutdown_txs.write().await;
        for room_id in &idle_room_ids {
            txs.remove(room_id);
        }
    }
    {
        let mut activity = self.room_last_activity.write().await;
        for room_id in &idle_room_ids {
            activity.remove(room_id);
        }
    }
}
/// Obtain a receiver for the global shutdown signal.
pub fn subscribe_shutdown(&self) -> broadcast::Receiver<()> {
    self.shutdown_tx.subscribe()
}
/// Fire the global shutdown signal; send errors (no receivers) are ignored.
pub fn trigger_shutdown(&self) {
    let _ = self.shutdown_tx.send(());
}
/// Get a shutdown receiver for `room_id`, creating the room's shutdown
/// channel on first use.
pub async fn register_room(&self, room_id: Uuid) -> broadcast::Receiver<()> {
    let mut txs = self.room_shutdown_txs.write().await;
    if let Some(tx) = txs.get(&room_id) {
        return tx.subscribe();
    }
    let (tx, rx) = broadcast::channel(SHUTDOWN_CHANNEL_CAPACITY);
    txs.insert(room_id, tx);
    rx
}
/// Tear down a room: signal its shutdown channel, release its subscribers
/// from the `users_online` gauge, and remove its broadcast and shutdown
/// channels.
pub async fn shutdown_room(&self, room_id: Uuid) {
    // Signal per-room listeners first; ignore "no receivers".
    if let Some(tx) = self.room_shutdown_txs.read().await.get(&room_id) {
        let _ = tx.send(());
    }
    // Release the tracked subscribers from the gauge.
    {
        let mut counts = self.room_subscriber_count.write().await;
        let removed = counts.remove(&room_id).unwrap_or(0) as f64;
        if removed > 0.0 {
            self.metrics.users_online.decrement(removed);
        }
    }
    // Drop the broadcast channel, then the shutdown sender.
    self.room_inner.write().await.remove(&room_id);
    self.room_shutdown_txs.write().await.remove(&room_id);
}
/// Drop shutdown channels and subscriber counts for rooms not present in
/// `active_room_ids`.
///
/// Fix: each `retain` previously ran a linear `slice::contains` scan per key
/// (O(n·m)); build a `HashSet` once so membership checks are O(1).
pub async fn prune_stale_rooms(&self, active_room_ids: &[Uuid]) {
    let active: std::collections::HashSet<&Uuid> = active_room_ids.iter().collect();
    let mut txs = self.room_shutdown_txs.write().await;
    txs.retain(|room_id, _| active.contains(room_id));
    drop(txs);
    let mut counts = self.room_subscriber_count.write().await;
    counts.retain(|room_id, _| active.contains(room_id));
}
/// Get a shutdown receiver for `project_id`, creating the project's shutdown
/// channel on first use.
pub async fn register_project(&self, project_id: Uuid) -> broadcast::Receiver<()> {
    let mut txs = self.project_shutdown_txs.write().await;
    match txs.get(&project_id) {
        Some(existing) => existing.subscribe(),
        None => {
            let (tx, rx) = broadcast::channel(SHUTDOWN_CHANNEL_CAPACITY);
            txs.insert(project_id, tx);
            rx
        }
    }
}
/// Tear down a project: signal its shutdown channel, then remove its
/// broadcast and shutdown channels.
pub async fn shutdown_project(&self, project_id: Uuid) {
    // Signal per-project listeners first; ignore "no receivers".
    if let Some(tx) = self.project_shutdown_txs.read().await.get(&project_id) {
        let _ = tx.send(());
    }
    self.project_inner.write().await.remove(&project_id);
    self.project_shutdown_txs.write().await.remove(&project_id);
}
/// Drop shutdown channels for projects not present in `active_project_ids`.
///
/// Fix: `retain` previously ran a linear `slice::contains` scan per key
/// (O(n·m)); build a `HashSet` once so membership checks are O(1).
pub async fn prune_stale_projects(&self, active_project_ids: &[Uuid]) {
    let active: std::collections::HashSet<&Uuid> = active_project_ids.iter().collect();
    let mut txs = self.project_shutdown_txs.write().await;
    txs.retain(|project_id, _| active.contains(project_id));
}
/// Get a shutdown receiver for `user_id`, creating the user's shutdown
/// channel on first use.
pub async fn register_user(&self, user_id: Uuid) -> broadcast::Receiver<()> {
    let mut txs = self.user_shutdown_txs.write().await;
    if let Some(tx) = txs.get(&user_id) {
        return tx.subscribe();
    }
    let (tx, rx) = broadcast::channel(SHUTDOWN_CHANNEL_CAPACITY);
    txs.insert(user_id, tx);
    rx
}
/// Tear down a user's channels: signal the user's shutdown channel, then
/// remove the user broadcast channel and the shutdown sender.
pub async fn shutdown_user(&self, user_id: Uuid) {
    {
        let txs = self.user_shutdown_txs.read().await;
        if let Some(tx) = txs.get(&user_id) {
            // Ignore "no receivers" — there may be nobody listening.
            let _ = tx.send(());
        }
    }
    {
        let mut map = self.user_inner.write().await;
        map.remove(&user_id);
    }
    {
        let mut txs = self.user_shutdown_txs.write().await;
        txs.remove(&user_id);
    }
}
/// Subscribe to notification events for `user_id`, creating the channel on
/// first use.
pub async fn subscribe_user_notification(
    &self,
    user_id: Uuid,
) -> broadcast::Receiver<Arc<NotificationEvent>> {
    let mut channels = self.user_notification_inner.write().await;
    match channels.get(&user_id) {
        Some(sender) => sender.subscribe(),
        None => {
            let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
            channels.insert(user_id, tx);
            rx
        }
    }
}
/// Remove the notification channel for `user_id`, closing all receivers.
pub async fn unsubscribe_user_notification(&self, user_id: Uuid) {
    self.user_notification_inner.write().await.remove(&user_id);
}
/// Deliver a notification to `user_id` if a channel exists; send errors
/// (no active receivers) are ignored.
pub async fn push_user_notification(&self, user_id: Uuid, event: Arc<NotificationEvent>) {
    if let Some(sender) = self.user_notification_inner.read().await.get(&user_id) {
        let _ = sender.send(event);
    }
}
/// Subscribe to a room's message events, creating the room channel on first
/// use and tracking the subscriber count / `users_online` gauge.
///
/// Fixes:
/// - The old code dropped the lock between the existence check and the
///   subscription; if the channel was removed in between it returned
///   "room disappeared" *after* incrementing the subscriber count, leaking an
///   increment. Subscribing while the lock is held removes the race entirely.
/// - The `users_online` gauge was only incremented for the first subscriber of
///   a room but decremented on every `unsubscribe`, so it drifted negative.
///   It is now incremented on every successful subscribe.
pub async fn subscribe(
    &self,
    room_id: Uuid,
    _user_id: Uuid,
) -> Result<broadcast::Receiver<Arc<RoomMessageEvent>>, RoomError> {
    let mut map = self.room_inner.write().await;
    if let Some(sender) = map.get(&room_id) {
        // Subscribe under the lock so the sender cannot vanish concurrently.
        let rx = sender.subscribe();
        drop(map);
        let mut counts = self.room_subscriber_count.write().await;
        *counts.entry(room_id).or_insert(0) += 1;
        self.metrics.users_online.increment(1.0);
        return Ok(rx);
    }
    // NOTE(review): this bounds the number of distinct room channels, not the
    // subscribers within one room — confirm against the constant's intent.
    if map.len() >= MAX_CONNECTIONS_PER_ROOM {
        return Err(RoomError::RateLimited(format!(
            "Room connection limit reached ({})",
            MAX_CONNECTIONS_PER_ROOM
        )));
    }
    let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
    map.insert(room_id, tx);
    drop(map);
    let mut counts = self.room_subscriber_count.write().await;
    counts.insert(room_id, 1);
    self.metrics.users_online.increment(1.0);
    Ok(rx)
}
/// Decrement the room's subscriber count (and the `users_online` gauge);
/// when the count reaches zero, drop the room's broadcast channel.
pub async fn unsubscribe(&self, room_id: Uuid, _user_id: Uuid) {
    let mut counts = self.room_subscriber_count.write().await;
    let remaining = counts.entry(room_id).or_insert(0);
    if *remaining > 0 {
        *remaining -= 1;
        self.metrics.users_online.decrement(1.0);
    }
    let room_empty = *remaining == 0;
    if room_empty {
        counts.remove(&room_id);
        drop(counts);
        self.room_inner.write().await.remove(&room_id);
    }
}
/// Broadcast a message event to all subscribers of `room_id`, refreshing the
/// room's last-activity timestamp (which keeps the idle reaper at bay).
pub async fn broadcast(&self, room_id: Uuid, event: RoomMessageEvent) {
    self.room_last_activity
        .write()
        .await
        .insert(room_id, Instant::now());
    let channels = self.room_inner.read().await;
    if let Some(sender) = channels.get(&room_id) {
        // send() fails only when there are no active receivers.
        if sender.send(Arc::new(event)).is_err() {
            self.metrics.broadcasts_dropped.increment(1);
        }
    }
}
/// Subscribe to a project's room events, creating the project channel on
/// first use and tracking the subscriber count.
///
/// Fix: the old code dropped the lock between the existence check and the
/// subscription; if the channel was removed in between it returned
/// "project channel disappeared" *after* incrementing the subscriber count,
/// leaking an increment. Subscribing while the lock is held removes the race.
pub async fn subscribe_project(
    &self,
    project_id: Uuid,
    _user_id: Uuid,
) -> Result<broadcast::Receiver<Arc<ProjectRoomEvent>>, RoomError> {
    let mut map = self.project_inner.write().await;
    if let Some(sender) = map.get(&project_id) {
        // Subscribe under the lock so the sender cannot vanish concurrently.
        let rx = sender.subscribe();
        drop(map);
        let mut counts = self.project_subscriber_count.write().await;
        *counts.entry(project_id).or_insert(0) += 1;
        return Ok(rx);
    }
    // NOTE(review): bounds the number of distinct project channels, not the
    // subscribers within one project — confirm against the constant's intent.
    if map.len() >= MAX_CONNECTIONS_PER_PROJECT {
        return Err(RoomError::RateLimited(format!(
            "Project connection limit reached ({})",
            MAX_CONNECTIONS_PER_PROJECT
        )));
    }
    let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
    map.insert(project_id, tx);
    drop(map);
    let mut counts = self.project_subscriber_count.write().await;
    counts.insert(project_id, 1);
    Ok(rx)
}
/// Decrement the project's subscriber count; when it reaches zero, drop the
/// project's broadcast channel.
pub async fn unsubscribe_project(&self, project_id: Uuid, _user_id: Uuid) {
    let mut counts = self.project_subscriber_count.write().await;
    let remaining = counts.entry(project_id).or_insert(0);
    if *remaining > 0 {
        *remaining -= 1;
    }
    let project_empty = *remaining == 0;
    if project_empty {
        counts.remove(&project_id);
        drop(counts);
        self.project_inner.write().await.remove(&project_id);
    }
}
/// Broadcast a project room event to all subscribers of `project_id`;
/// counts a dropped broadcast when there are no active receivers.
pub async fn broadcast_project(&self, project_id: Uuid, event: ProjectRoomEvent) {
    let map = self.project_inner.read().await;
    if let Some(sender) = map.get(&project_id) {
        let event = Arc::new(event);
        // send() fails only when there are no active receivers.
        if sender.send(event).is_err() {
            self.metrics.broadcasts_dropped.increment(1);
        }
    }
}
/// Broadcast an agent task event to all WS clients subscribed to this project.
pub async fn broadcast_agent_task(&self, project_id: Uuid, event: AgentTaskEvent) {
    let map = self.task_inner.read().await;
    if let Some(sender) = map.get(&project_id) {
        let event = Arc::new(event);
        // send() fails only when there are no active receivers.
        if sender.send(event).is_err() {
            self.metrics.broadcasts_dropped.increment(1);
        }
    }
}
/// Subscribe to agent task events for a project.
/// Returns a broadcast receiver that yields task events as they occur.
/// The per-project channel is created on first use.
pub async fn subscribe_task_events(
    &self,
    project_id: Uuid,
) -> Result<broadcast::Receiver<Arc<AgentTaskEvent>>, RoomError> {
    let mut channels = self.task_inner.write().await;
    let rx = match channels.get(&project_id) {
        Some(sender) => sender.subscribe(),
        None => {
            let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
            channels.insert(project_id, tx);
            rx
        }
    };
    Ok(rx)
}
/// Subscribe to project room events routed to `user_id`, creating the user
/// channel on first use and tracking the subscriber count / `users_online`
/// gauge.
///
/// Fixes (mirrors `subscribe`):
/// - Subscribe under the lock instead of dropping and re-acquiring it, which
///   could fail with "user channel disappeared" after already incrementing
///   the subscriber count.
/// - Increment `users_online` on every successful subscribe, matching the
///   per-call decrement in `unsubscribe_user` (the old asymmetry drove the
///   gauge negative over time).
pub async fn subscribe_user(
    &self,
    user_id: Uuid,
) -> Result<broadcast::Receiver<Arc<ProjectRoomEvent>>, RoomError> {
    let mut map = self.user_inner.write().await;
    if let Some(sender) = map.get(&user_id) {
        let rx = sender.subscribe();
        drop(map);
        let mut counts = self.user_subscriber_count.write().await;
        *counts.entry(user_id).or_insert(0) += 1;
        self.metrics.users_online.increment(1.0);
        return Ok(rx);
    }
    // NOTE(review): bounds the number of distinct user channels, not the
    // subscribers per user — confirm against the constant's intent.
    if map.len() >= MAX_CONNECTIONS_PER_USER {
        return Err(RoomError::RateLimited(format!(
            "User connection limit reached ({})",
            MAX_CONNECTIONS_PER_USER
        )));
    }
    let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
    map.insert(user_id, tx);
    drop(map);
    let mut counts = self.user_subscriber_count.write().await;
    counts.insert(user_id, 1);
    self.metrics.users_online.increment(1.0);
    Ok(rx)
}
/// Decrement the user's subscriber count (and the `users_online` gauge);
/// when the count reaches zero, drop the user's broadcast channel.
pub async fn unsubscribe_user(&self, user_id: Uuid) {
    let mut counts = self.user_subscriber_count.write().await;
    let remaining = counts.entry(user_id).or_insert(0);
    if *remaining > 0 {
        *remaining -= 1;
        self.metrics.users_online.decrement(1.0);
    }
    let user_empty = *remaining == 0;
    if user_empty {
        counts.remove(&user_id);
        drop(counts);
        self.user_inner.write().await.remove(&user_id);
    }
}
/// Broadcast a project room event to a single user's channel; counts a
/// dropped broadcast when there are no active receivers.
pub async fn broadcast_to_user(&self, user_id: Uuid, event: ProjectRoomEvent) {
    let channels = self.user_inner.read().await;
    if let Some(sender) = channels.get(&user_id) {
        if sender.send(Arc::new(event)).is_err() {
            self.metrics.broadcasts_dropped.increment(1);
        }
    }
}
/// Create (or reuse) the per-message stream channel and return a receiver
/// for its chunks.
pub async fn register_stream_channel(
    &self,
    message_id: Uuid,
) -> broadcast::Receiver<Arc<RoomMessageStreamChunkEvent>> {
    let mut channels = self.stream_inner.write().await;
    match channels.get(&message_id) {
        Some(tx) => tx.subscribe(),
        None => {
            let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
            channels.insert(message_id, tx);
            rx
        }
    }
}
/// Subscribe to an existing per-message stream channel; `None` when no
/// stream is registered for `message_id`.
pub async fn subscribe_stream(
    &self,
    message_id: Uuid,
) -> Option<broadcast::Receiver<Arc<RoomMessageStreamChunkEvent>>> {
    let channels = self.stream_inner.read().await;
    let tx = channels.get(&message_id)?;
    Some(tx.subscribe())
}
/// Create (or reuse) the per-room stream channel and return a receiver that
/// yields chunks from any message streaming into the room.
pub async fn subscribe_room_stream(
    &self,
    room_id: Uuid,
) -> broadcast::Receiver<Arc<RoomMessageStreamChunkEvent>> {
    let mut channels = self.room_stream_inner.write().await;
    match channels.get(&room_id) {
        Some(tx) => tx.subscribe(),
        None => {
            let (tx, rx) = broadcast::channel(BROADCAST_CAPACITY);
            channels.insert(room_id, tx);
            rx
        }
    }
}
/// Fan a stream chunk out to both the per-message channel and the per-room
/// channel; when the chunk is marked final (`done`), tear down the room-level
/// stream channel.
pub async fn broadcast_stream_chunk(&self, event: RoomMessageStreamChunkEvent) {
    let event = Arc::new(event);
    let is_final_chunk = event.done;
    // Per-message subscribers; send errors (no receivers) are ignored.
    let map = self.stream_inner.read().await;
    if let Some(tx) = map.get(&event.message_id) {
        let _ = tx.send(Arc::clone(&event));
    }
    // Release the stream_inner read lock before touching room_stream_inner.
    drop(map);
    let map = self.room_stream_inner.read().await;
    if let Some(tx) = map.get(&event.room_id) {
        let _ = tx.send(Arc::clone(&event));
    }
    if is_final_chunk {
        // Upgrade read -> write (must drop the read guard first).
        // NOTE(review): this removes the room's stream channel as soon as ANY
        // message in the room finishes — if multiple messages can stream into
        // one room concurrently, other in-flight streams lose their room-level
        // sender. Confirm single-stream-per-room is the intended invariant.
        drop(map);
        let mut map = self.room_stream_inner.write().await;
        map.remove(&event.room_id);
    }
}
/// Remove the per-message stream channel, closing all of its receivers.
pub async fn close_stream_channel(&self, message_id: Uuid) {
    let mut map = self.stream_inner.write().await;
    map.remove(&message_id);
}
}
/// Map a wire-format sender-type string to its enum variant.
/// Unrecognized values (and "member" itself) resolve to `Member`.
fn parse_sender_type(s: &str) -> MessageSenderType {
    match s {
        "admin" => MessageSenderType::Admin,
        "owner" => MessageSenderType::Owner,
        "ai" => MessageSenderType::Ai,
        "system" => MessageSenderType::System,
        "tool" => MessageSenderType::Tool,
        "guest" => MessageSenderType::Guest,
        // "member" plus any unknown value falls back to Member.
        _ => MessageSenderType::Member,
    }
}
/// Map a wire-format content-type string to its enum variant.
/// Unrecognized values (and "text" itself) resolve to `Text`.
fn parse_content_type(s: &str) -> MessageContentType {
    match s {
        "image" => MessageContentType::Image,
        "audio" => MessageContentType::Audio,
        "video" => MessageContentType::Video,
        "file" => MessageContentType::File,
        // "text" plus any unknown value falls back to Text.
        _ => MessageContentType::Text,
    }
}
/// Async callback that persists a batch of message envelopes to the database.
pub type PersistFn = Arc<
    dyn Fn(Vec<RoomMessageEnvelope>) -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send>>
        + Send
        + Sync,
>;
use dashmap::DashMap;
/// Message-id -> insertion-time cache used to skip re-persisting duplicates.
pub type DedupCache = Arc<DashMap<uuid::Uuid, Instant>>;
/// How long a message id stays in the dedup cache.
const DEDUP_CACHE_TTL: Duration = Duration::from_secs(300);
/// Evict dedup-cache entries older than `DEDUP_CACHE_TTL`.
pub fn cleanup_dedup_cache(cache: &DedupCache) {
    let cutoff = Instant::now() - DEDUP_CACHE_TTL;
    cache.retain(|_, inserted_at| *inserted_at > cutoff);
}
/// Build a `PersistFn` that batch-inserts room messages, skipping duplicates
/// via the in-memory dedup cache and a DB existence check.
///
/// Per `BATCH_SIZE` chunk: (1) collect ids not yet in the cache, (2) query
/// the DB once for which of those ids already exist, (3) insert the rest in
/// a single `insert_many`.
pub fn make_persist_fn(
    db: AppDatabase,
    metrics: Arc<RoomMetrics>,
    dedup_cache: DedupCache,
) -> PersistFn {
    Arc::new(move |envelopes: Vec<RoomMessageEnvelope>| {
        let db = db.clone();
        let metrics = metrics.clone();
        let cache = dedup_cache.clone();
        Box::pin(async move {
            for chunk in envelopes.chunks(BATCH_SIZE) {
                let mut models_to_insert = Vec::new();
                // Pass 1: ids that are not cache-known duplicates.
                let mut ids_to_dedup: Vec<uuid::Uuid> = Vec::new();
                for env in chunk {
                    if cache.contains_key(&env.id) {
                        metrics.incr_duplicates_skipped();
                        continue;
                    }
                    ids_to_dedup.push(env.id);
                }
                // Single DB round-trip: which candidate ids already exist?
                let existing_ids: std::collections::HashSet<uuid::Uuid> =
                    if !ids_to_dedup.is_empty() {
                        room_message::Entity::find()
                            .filter(room_message::Column::Id.is_in(ids_to_dedup))
                            .into_model::<room_message::Model>()
                            .all(&db)
                            .await?
                            .into_iter()
                            .map(|m| m.id)
                            .collect()
                    } else {
                        std::collections::HashSet::new()
                    };
                // Pass 2: mark cache + build models for the genuinely new ids.
                for env in chunk {
                    // NOTE(review): a second occurrence of the same id within
                    // this chunk hits this check (pass 2 inserts into the
                    // cache below) and is skipped WITHOUT bumping the
                    // duplicates metric — confirm that is intended.
                    if cache.contains_key(&env.id) {
                        continue;
                    }
                    cache.insert(env.id, Instant::now());
                    if existing_ids.contains(&env.id) {
                        metrics.incr_duplicates_skipped();
                        continue;
                    }
                    let sender_type = parse_sender_type(&env.sender_type);
                    let content_type = parse_content_type(&env.content_type);
                    models_to_insert.push(room_message::ActiveModel {
                        id: Set(env.id),
                        seq: Set(env.seq),
                        room: Set(env.room_id),
                        sender_type: Set(sender_type),
                        sender_id: Set(env.sender_id),
                        thread: Set(env.thread_id),
                        content: Set(env.content.clone()),
                        content_type: Set(content_type),
                        edited_at: Set(None),
                        send_at: Set(env.send_at.clone()),
                        revoked: Set(None),
                        revoked_by: Set(None),
                        in_reply_to: Set(env.in_reply_to),
                    });
                }
                if !models_to_insert.is_empty() {
                    let count = models_to_insert.len() as u64;
                    room_message::Entity::insert_many(models_to_insert)
                        .exec(&db)
                        .await?;
                    metrics.messages_persisted.increment(count);
                }
            }
            Ok(())
        })
    })
}
/// Boxed future resolving to a pooled Redis cluster connection.
pub type RedisFuture =
    Pin<Box<dyn Future<Output = anyhow::Result<deadpool_redis::cluster::Connection>> + Send>>;
/// Wrap the producer's `get_redis` handle factory into a plain closure,
/// converting a panicked pool task into an `anyhow` error instead of
/// propagating the panic.
pub fn extract_get_redis(
    producer: queue::MessageProducer,
) -> Arc<dyn Fn() -> RedisFuture + Send + Sync> {
    Arc::new(move || {
        let get_redis_fn = producer.get_redis.clone();
        Box::pin(async move {
            let handle = get_redis_fn();
            match handle.await {
                Ok(conn) => conn,
                // The join handle errors when the pool task panicked.
                Err(_) => anyhow::bail!("redis pool task panicked"),
            }
        }) as RedisFuture
    })
}
/// Spawn a dedicated OS thread running a single-threaded tokio runtime that
/// subscribes to a Redis Pub/Sub channel and relays raw payload bytes into
/// `relay_tx`. Reconnects with a 1s backoff on any failure; exits on a
/// shutdown signal or when the relay channel closes.
///
/// `_on_msg` is currently unused — messages are delivered only via
/// `relay_tx`. NOTE(review): either wire it up or drop the parameter.
fn start_pubsub_thread<F, Fut>(
    redis_url: String,
    channel: String,
    relay_tx: tokio::sync::mpsc::Sender<Vec<u8>>,
    mut shutdown_rx: broadcast::Receiver<()>,
    log: slog::Logger,
    _on_msg: F,
) where
    F: Fn(Vec<u8>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ()> + Send,
{
    thread::Builder::new()
        // Thread names are capped; truncate the channel to 32 chars.
        .name(format!("redis-pubsub-{}", &channel[..channel.len().min(32)]))
        .spawn(move || {
            // This thread owns its own current-thread runtime, so the
            // blocking `thread::sleep` backoffs below only stall this
            // thread's reconnect loop, nothing else.
            let rt = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .expect("pubsub thread runtime");
            rt.block_on(async {
                let redis_url = redis_url.clone();
                // Outer loop: (re)connect + (re)subscribe.
                loop {
                    if shutdown_rx.try_recv().is_ok() {
                        slog::info!(log, "pubsub thread shutting down before connect"; "channel" => %channel);
                        break;
                    }
                    let client = match redis::Client::open(redis_url.as_str()) {
                        Ok(c) => c,
                        Err(e) => {
                            slog::error!(log, "pubsub redis client open failed"; "channel" => %channel, "error" => %e);
                            thread::sleep(Duration::from_secs(1));
                            continue;
                        }
                    };
                    let mut pubsub = match client.get_async_pubsub().await {
                        Ok(p) => p,
                        Err(e) => {
                            slog::error!(log, "pubsub connection failed"; "channel" => %channel, "error" => %e);
                            thread::sleep(Duration::from_secs(1));
                            continue;
                        }
                    };
                    match pubsub.subscribe(&channel).await {
                        Ok(_) => slog::info!(log, "pubsub subscribed"; "channel" => %channel),
                        Err(e) => {
                            slog::error!(log, "pubsub subscribe failed"; "channel" => %channel, "error" => %e);
                            thread::sleep(Duration::from_secs(1));
                            continue;
                        }
                    }
                    let mut stream = pubsub.on_message();
                    // Inner loop: pump messages; the 500ms timeout bounds how
                    // long a shutdown signal can go unnoticed.
                    loop {
                        if shutdown_rx.try_recv().is_ok() {
                            slog::info!(log, "pubsub thread shutting down"; "channel" => %channel);
                            return;
                        }
                        let msg = tokio::time::timeout(
                            Duration::from_millis(500),
                            futures::StreamExt::next(&mut stream),
                        )
                        .await;
                        match msg {
                            Ok(Some(msg)) => {
                                let payload = msg.get_payload_bytes();
                                slog::debug!(log, "pubsub received"; "channel" => %channel, "len" => payload.len());
                                // Relay receiver gone -> nothing to deliver to; stop the thread.
                                if relay_tx.send(payload.to_vec()).await.is_err() {
                                    slog::warn!(log, "pubsub relay channel closed"; "channel" => %channel);
                                    return;
                                }
                            }
                            Ok(None) => {
                                // Stream exhausted: break to the outer loop to reconnect.
                                slog::warn!(log, "pubsub stream ended, will reconnect"; "channel" => %channel);
                                break;
                            }
                            // Timeout: loop back to re-check shutdown.
                            Err(_) => {}
                        }
                    }
                    slog::warn!(log, "pubsub connection lost, reconnecting"; "channel" => %channel);
                }
            });
        })
        .expect("pubsub thread spawn");
}
/// Relay Redis Pub/Sub `room:pub:{room_id}` messages into the in-process
/// room broadcast channel until shutdown or relay closure.
pub async fn subscribe_room_events(
    redis_url: String,
    manager: Arc<RoomConnectionManager>,
    room_id: Uuid,
    log: slog::Logger,
    mut shutdown_rx: broadcast::Receiver<()>,
) {
    let channel = format!("room:pub:{}", room_id);
    let (relay_tx, mut relay_rx) = tokio::sync::mpsc::channel::<Vec<u8>>(1024);
    slog::info!(log, "starting room pubsub subscriber"; "room_id" => %room_id, "channel" => %channel);
    // Background thread pushes raw payloads into relay_rx.
    start_pubsub_thread(
        redis_url,
        channel.clone(),
        relay_tx,
        shutdown_rx.resubscribe(),
        log.clone(),
        |_| async {},
    );
    loop {
        tokio::select! {
            _ = shutdown_rx.recv() => {
                slog::info!(log, "room subscriber shutting down"; "room_id" => %room_id);
                break;
            }
            payload = relay_rx.recv() => {
                match payload {
                    Some(data) => match serde_json::from_slice::<RoomMessageEvent>(&data) {
                        Ok(event) => manager.broadcast(room_id, event).await,
                        Err(e) => slog::warn!(log, "malformed RoomMessageEvent"; "error" => %e),
                    },
                    None => {
                        slog::warn!(log, "pubsub relay channel closed"; "room_id" => %room_id);
                        break;
                    }
                }
            }
        }
    }
    slog::info!(log, "room subscriber stopped"; "room_id" => %room_id);
}
/// Relay Redis Pub/Sub `project:pub:{project_id}` messages into the
/// in-process project broadcast channel until shutdown or relay closure.
pub async fn subscribe_project_room_events(
    redis_url: String,
    manager: Arc<RoomConnectionManager>,
    project_id: Uuid,
    log: slog::Logger,
    mut shutdown_rx: broadcast::Receiver<()>,
) {
    let channel = format!("project:pub:{}", project_id);
    let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(1024);
    slog::info!(log, "starting project pubsub subscriber"; "project_id" => %project_id, "channel" => %channel);
    let thread_log = log.clone();
    let thread_channel = channel.clone();
    let thread_shutdown = shutdown_rx.resubscribe();
    // Background thread pushes raw payloads into `rx`.
    start_pubsub_thread(
        redis_url,
        thread_channel,
        tx,
        thread_shutdown,
        thread_log,
        |_| async {},
    );
    loop {
        tokio::select! {
            _ = shutdown_rx.recv() => {
                slog::info!(log, "project subscriber shutting down"; "project_id" => %project_id);
                break;
            }
            payload = rx.recv() => {
                match payload {
                    Some(data) => {
                        // Malformed payloads are logged and dropped, never fatal.
                        match serde_json::from_slice::<ProjectRoomEvent>(&data) {
                            Ok(event) => {
                                manager.broadcast_project(project_id, event).await;
                            }
                            Err(e) => {
                                slog::warn!(log, "malformed ProjectRoomEvent"; "error" => %e);
                            }
                        }
                    }
                    None => {
                        // Relay sender dropped (pubsub thread exited).
                        slog::warn!(log, "project pubsub relay channel closed"; "project_id" => %project_id);
                        break;
                    }
                }
            }
        }
    }
    slog::info!(log, "project subscriber stopped"; "project_id" => %project_id);
}
/// Subscribe to Redis Pub/Sub `task:pub:{project_id}` and relay events to
/// `RoomConnectionManager::broadcast_agent_task()` so all WS clients get notified.
pub async fn subscribe_task_events_fn(
    redis_url: String,
    manager: Arc<RoomConnectionManager>,
    project_id: Uuid,
    log: slog::Logger,
    mut shutdown_rx: broadcast::Receiver<()>,
) {
    let channel = format!("task:pub:{}", project_id);
    let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(1024);
    slog::info!(log, "starting task pubsub subscriber"; "project_id" => %project_id, "channel" => %channel);
    let thread_log = log.clone();
    let thread_channel = channel.clone();
    let thread_shutdown = shutdown_rx.resubscribe();
    // Background thread pushes raw payloads into `rx`.
    start_pubsub_thread(
        redis_url,
        thread_channel,
        tx,
        thread_shutdown,
        thread_log,
        |_| async {},
    );
    loop {
        tokio::select! {
            _ = shutdown_rx.recv() => {
                slog::info!(log, "task subscriber shutting down"; "project_id" => %project_id);
                break;
            }
            payload = rx.recv() => {
                match payload {
                    Some(data) => {
                        // Malformed payloads are logged and dropped, never fatal.
                        match serde_json::from_slice::<AgentTaskEvent>(&data) {
                            Ok(event) => {
                                manager.broadcast_agent_task(project_id, event).await;
                            }
                            Err(e) => {
                                slog::warn!(log, "malformed AgentTaskEvent"; "error" => %e);
                            }
                        }
                    }
                    None => {
                        // Relay sender dropped (pubsub thread exited).
                        slog::warn!(log, "task pubsub relay channel closed"; "project_id" => %project_id);
                        break;
                    }
                }
            }
        }
    }
    slog::info!(log, "task subscriber stopped"; "project_id" => %project_id);
}