gitdataai/lib/channel/bus.rs

607 lines
18 KiB
Rust

use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use cache::AppCache;
use dashmap::DashMap;
use db::AppDatabase;
use model::room::RoomMessageModel;
use serde::Deserialize;
use serde::Serialize;
use socketio::{Socket, SocketIo};
use tokio::sync::{Mutex, RwLock};
use tracing::warn;
use uuid::Uuid;
use crate::{
ChannelBusConfig, ChannelError, ChannelResult,
circuit_breaker::CircuitBreaker,
dedup::DeduplicationManager,
event::ChannelEvent,
metrics::ChannelMetrics,
reconnect::ReconnectManager,
rooms::{
active_workspace_users, catchup_messages, refresh_user_rooms_cache,
room_socket_name, room_workspace, user_rooms,
},
security::{CsrfProtection, RateLimiter},
seq::SeqAllocator,
};
/// Socket.IO event name used both for room broadcasts (`publish_event`) and
/// for per-socket catch-up replay (`catchup`).
const ROOM_MESSAGE_EVENT: &str = "room.message";
/// Cheaply cloneable handle to the realtime channel bus.
///
/// All state lives in [`Inner`] behind an `Arc`, so every clone shares the
/// same database/cache handles, online-socket registry and helper services.
/// `attach` hands a separate clone to each Socket.IO callback it registers.
#[derive(Clone)]
pub struct ChannelBus {
    pub inner: Arc<Inner>,
}
/// Shared state behind a [`ChannelBus`] handle.
pub struct Inner {
    /// Database handle; reads in this file go through `db.reader()`.
    pub db: AppDatabase,
    /// Cache handle shared with the room, seq, dedup, rate-limit and CSRF helpers.
    pub cache: AppCache,
    /// Socket.IO server whose `config.namespace` namespace this bus drives.
    pub io: SocketIo,
    /// Bus configuration (namespace name, optional tuning knobs).
    pub config: ChannelBusConfig,
    /// Sockets registered on this bus instance: user id -> (socket id -> socket).
    /// Filled by `register_socket`, pruned by `handle_disconnect`.
    pub online: RwLock<HashMap<Uuid, HashMap<String, Socket>>>,
    /// Per-user mutexes serializing room-membership updates
    /// (`sync_online_user_rooms`). An entry is removed when the user's last
    /// socket disconnects.
    pub user_sync_locks: DashMap<Uuid, Arc<Mutex<()>>>,
    /// Typing indicator state. NOTE(review): key looks like (room, user) and the
    /// token presumably cancels a pending typing-expiry task — confirm at the
    /// code that writes this map (not visible here).
    pub typing_states: DashMap<
        (Uuid, Uuid),
        (
            crate::event::UserInfo,
            crate::event::RoomInfo,
            tokio_util::sync::CancellationToken,
        ),
    >,
    /// Sequence allocator; uses `config.seq_segment_size` when it is set.
    pub seq: SeqAllocator,
    /// Message de-duplication used by `publish_room_message`. The window is
    /// `config.dedup_window_secs`, or 300 s when that is unset.
    pub dedup: DeduplicationManager,
    /// Connection and broadcast counters.
    pub metrics: ChannelMetrics,
    /// Reconnect helper. NOTE(review): not used in this file; see its callers.
    pub reconnect: ReconnectManager,
    /// Rate limiter; `handle_connect` checks the `"connect"` bucket per user.
    pub rate_limiter: RateLimiter,
    /// CSRF protection. NOTE(review): not used in this file.
    pub csrf: CsrfProtection,
    /// Circuit breaker, configured only when all four `circuit_breaker_*`
    /// settings are present. NOTE(review): not used in this file.
    pub circuit_breaker: CircuitBreaker,
}
/// Handshake auth payload fields read by `handle_connect`.
///
/// If the payload is absent or fails to deserialize, `handle_connect` treats
/// it as an empty `last_seq` map (no catch-up).
#[derive(Debug, Deserialize)]
struct ConnectAuth {
    /// Last sequence number the client has seen, keyed by room id. Only rooms
    /// listed here are replayed by `catchup`. Empty when the field is missing.
    #[serde(default)]
    last_seq: HashMap<Uuid, i64>,
}
impl ChannelBus {
pub fn io(&self) -> &SocketIo {
&self.inner.io
}
/// Return a workspace the user is currently a member of (`leave_at IS NULL`),
/// or `None` if there is none.
///
/// NOTE(review): the query has `LIMIT 1` without `ORDER BY`, so which
/// membership counts as "first" is left to the database — add an ordering if
/// callers rely on a stable choice.
pub async fn first_workspace_id(
    &self,
    user: Uuid,
) -> ChannelResult<Option<Uuid>> {
    const SQL: &str =
        "SELECT wk FROM wk_member WHERE \"user\" = $1 AND leave_at IS NULL LIMIT 1";
    let membership = db::sqlx::query_as::<_, (Uuid,)>(SQL)
        .bind(user)
        .fetch_optional(self.inner.db.reader())
        .await?;
    Ok(membership.map(|(wk,)| wk))
}
/// Load the display info for `room`.
///
/// An unknown room id is not an error: it yields a `RoomInfo` with an empty
/// name.
pub async fn lookup_room(
    &self,
    room: Uuid,
) -> ChannelResult<crate::event::RoomInfo> {
    let found = db::sqlx::query_as::<_, (String,)>("SELECT name FROM room WHERE id = $1")
        .bind(room)
        .fetch_optional(self.inner.db.reader())
        .await?;
    let name = match found {
        Some((name,)) => name,
        None => String::new(),
    };
    Ok(crate::event::RoomInfo { id: room, name })
}
/// List current members (`leave_at IS NULL`) of `workspace`, ordered by
/// username.
///
/// Each tuple is `(user id, username, display_name, avatar_url)`.
pub async fn list_workspace_members(
    &self,
    workspace: Uuid,
) -> ChannelResult<Vec<(Uuid, String, String, String)>> {
    type MemberRow = (Uuid, String, String, String);
    let query = db::sqlx::query_as::<_, MemberRow>(
        r#"SELECT u.id, u.username, u.display_name, u.avatar_url
FROM wk_member wm
JOIN "user" u ON u.id = wm."user"
WHERE wm.wk = $1 AND wm.leave_at IS NULL
ORDER BY u.username"#,
    );
    let members = query
        .bind(workspace)
        .fetch_all(self.inner.db.reader())
        .await?;
    Ok(members)
}
/// Load the name and avatar URL of workspace `wk`.
///
/// An unknown workspace id is not an error: it yields empty `name` and
/// `avatar_url`. A NULL `avatar_url` is likewise returned as an empty string.
///
/// # Errors
/// Returns the database error if the query fails or a column cannot be
/// decoded. Previously, `Row::get` panicked on decode failures.
pub async fn lookup_workspace(
    &self,
    wk: Uuid,
) -> ChannelResult<crate::event::WorkspaceInfo> {
    use db::sqlx::Row;
    let row = db::sqlx::query(
        "SELECT name, avatar_url FROM workspace WHERE id = $1",
    )
    .bind(wk)
    .fetch_optional(self.inner.db.reader())
    .await?;
    let (name, avatar_url) = match row {
        // `try_get` rather than `get`: `Row::get` panics when a value cannot
        // be decoded (e.g. an unexpected NULL), which would abort the caller's
        // task instead of surfacing an error.
        Some(r) => (
            r.try_get::<String, _>(0)?,
            // Treat a NULL avatar the same as a missing workspace: empty string.
            r.try_get::<Option<String>, _>(1)?.unwrap_or_default(),
        ),
        None => (String::new(), String::new()),
    };
    Ok(crate::event::WorkspaceInfo {
        id: wk,
        name,
        avatar_url,
    })
}
/// Batch-load `UserInfo` for the given user ids in a single query.
///
/// Ids with no matching row are simply absent from the returned map. An empty
/// input skips the database entirely.
pub async fn lookup_users(
    &self,
    users: &[Uuid],
) -> ChannelResult<std::collections::HashMap<Uuid, crate::event::UserInfo>>
{
    if users.is_empty() {
        return Ok(Default::default());
    }
    let rows = db::sqlx::query_as::<_, model::users::UserModel>(
        "SELECT id, username, display_name, avatar_url, website_url, \
         allow_use, can_search, last_sign_in_at, created_at, updated_at \
         FROM \"user\" WHERE id = ANY($1)",
    )
    .bind(users)
    .fetch_all(self.inner.db.reader())
    .await?;
    let mut by_id = std::collections::HashMap::with_capacity(rows.len());
    for row in &rows {
        by_id.insert(row.id, crate::event::UserInfo::from_model(row));
    }
    Ok(by_id)
}
/// Load `UserInfo` for a single user.
///
/// A missing user is not an error: it yields `UserInfo::unknown(user)`.
pub async fn lookup_user(
    &self,
    user: Uuid,
) -> ChannelResult<crate::event::UserInfo> {
    let found = db::sqlx::query_as::<_, model::users::UserModel>(
        "SELECT id, username, display_name, avatar_url, website_url, \
         allow_use, can_search, last_sign_in_at, created_at, updated_at \
         FROM \"user\" WHERE id = $1",
    )
    .bind(user)
    .fetch_optional(self.inner.db.reader())
    .await?;
    let info = match found {
        Some(model_row) => crate::event::UserInfo::from_model(&model_row),
        None => crate::event::UserInfo::unknown(user),
    };
    Ok(info)
}
pub async fn list_user_rooms(
&self,
user: Uuid,
) -> ChannelResult<Vec<crate::rooms::RoomListItem>> {
crate::rooms::user_rooms_for_api(
&self.inner.db,
&self.inner.cache,
&self.inner.config,
user,
)
.await
}
pub async fn list_user_categories(
&self,
user: Uuid,
) -> ChannelResult<Vec<crate::rooms::CategoryListItem>> {
crate::rooms::user_categories_for_api(
&self.inner.db,
&self.inner.cache,
&self.inner.config,
user,
)
.await
}
/// Build a bus from its dependencies and configuration.
///
/// Optional tuning settings fall back to each helper's defaults:
/// * `seq_segment_size` — allocator default segment size when unset;
/// * `dedup_window_secs` — 300 seconds when unset;
/// * rate limiting — custom only when both `rate_limit_max_requests` and
///   `rate_limit_window_secs` are set;
/// * circuit breaker — custom only when all four `circuit_breaker_*` settings
///   are set.
///
/// Nothing is wired into Socket.IO here; call `attach` for that.
pub fn new(
    db: AppDatabase,
    cache: AppCache,
    io: SocketIo,
    config: ChannelBusConfig,
) -> Self {
    let seq = if let Some(segment) = config.seq_segment_size {
        SeqAllocator::with_segment_size(cache.clone(), db.clone(), segment)
    } else {
        SeqAllocator::new(cache.clone(), db.clone())
    };

    let dedup_window =
        std::time::Duration::from_secs(config.dedup_window_secs.unwrap_or(300));
    let dedup = DeduplicationManager::with_config(cache.clone(), dedup_window);

    let reconnect = ReconnectManager::new(cache.clone(), db.clone());

    let rate_limiter = if let (Some(max_requests), Some(window_secs)) = (
        config.rate_limit_max_requests,
        config.rate_limit_window_secs,
    ) {
        RateLimiter::with_config(
            cache.clone(),
            max_requests,
            std::time::Duration::from_secs(window_secs),
        )
    } else {
        RateLimiter::new(cache.clone())
    };

    let csrf = CsrfProtection::new(cache.clone());

    let circuit_breaker = if let (
        Some(failure),
        Some(success),
        Some(timeout_secs),
        Some(half_open),
    ) = (
        config.circuit_breaker_failure_threshold,
        config.circuit_breaker_success_threshold,
        config.circuit_breaker_timeout_secs,
        config.circuit_breaker_half_open_max_calls,
    ) {
        CircuitBreaker::with_config(
            failure,
            success,
            std::time::Duration::from_secs(timeout_secs),
            half_open,
        )
    } else {
        CircuitBreaker::new()
    };

    let inner = Inner {
        db,
        cache,
        io,
        config,
        online: RwLock::new(HashMap::new()),
        user_sync_locks: DashMap::new(),
        typing_states: DashMap::new(),
        seq,
        dedup,
        metrics: ChannelMetrics::new(),
        reconnect,
        rate_limiter,
        csrf,
        circuit_breaker,
    };
    Self {
        inner: Arc::new(inner),
    }
}
/// Wire this bus into the Socket.IO namespace named by `config.namespace`.
///
/// Installs, in order:
/// * an auth middleware that admits sockets that already carry a session
///   user, or that present a valid `access_token` string in the handshake
///   auth payload (the token's user is then stored on the socket);
/// * a connect handler that bumps the connection metric and runs
///   `handle_connect`, disconnecting the socket if that fails;
/// * a disconnect handler that decrements the metric and runs
///   `handle_disconnect`;
///
/// and finally registers the inbound message handler from `crate::http::ws`.
///
/// # Errors
/// Propagates any error from `register_message_handler`.
pub async fn attach(&self) -> ChannelResult<()> {
    let ns = self.inner.io.namespace(&self.inner.config.namespace).await;

    let middleware_bus = self.clone();
    ns.use_middleware(move |socket, auth| {
        let bus = middleware_bus.clone();
        async move {
            // Already authenticated (e.g. through a session): admit as-is.
            if socket.session_user().is_some() {
                return Ok(());
            }
            let Some(token) = auth
                .as_ref()
                .and_then(|payload| payload.get("access_token"))
                .and_then(|value| value.as_str())
            else {
                return Err(socketio::SocketIoError::Adapter(
                    "unauthorized".to_owned(),
                ));
            };
            match bus.check_access_token(token.to_owned()).await {
                Ok(ctx) => {
                    socket.set_user(ctx.user_id);
                    Ok(())
                }
                Err(_) => Err(socketio::SocketIoError::Adapter(
                    "token invalid or expired".to_owned(),
                )),
            }
        }
    })
    .await;

    let connect_bus = self.clone();
    ns.on_connect(move |socket| {
        let bus = connect_bus.clone();
        async move {
            bus.inner.metrics.increment_connections();
            let outcome = bus.handle_connect(socket.clone()).await;
            if let Err(error) = outcome {
                warn!(%error, "channel socket connect failed, disconnecting");
                let _ = socket.disconnect().await;
            }
        }
    })
    .await;

    let disconnect_bus = self.clone();
    ns.on_disconnect(move |socket, _reason| {
        let bus = disconnect_bus.clone();
        async move {
            bus.inner.metrics.decrement_connections();
            bus.handle_disconnect(&socket).await;
        }
    })
    .await;

    crate::http::ws::register_message_handler(self).await?;
    Ok(())
}
/// Broadcast a newly created room message to the room's sockets.
///
/// The message is first checked against the dedup manager keyed by
/// `(message.id, message.room)`. A message already seen is dropped silently
/// (`Ok(())`). When `sender` is provided, it is embedded in the event.
pub async fn publish_room_message(
    &self,
    message: RoomMessageModel,
    sender: Option<crate::event::UserInfo>,
) -> ChannelResult<()> {
    let first_seen = self
        .inner
        .dedup
        .check_and_mark(message.id, message.room)
        .await?;
    if !first_seen {
        return Ok(());
    }
    let event = if let Some(author) = sender {
        ChannelEvent::message_created_with_sender(message, author)
    } else {
        ChannelEvent::message_created(message)
    };
    self.publish_event(event).await
}
/// Broadcast a custom event of type `event_type` to `room`.
///
/// `payload` is serialized to JSON first.
///
/// # Errors
/// Returns an error only if serialization fails. Broadcast failures are
/// swallowed by `publish_event`.
pub async fn publish_room_event<T>(
    &self,
    room: Uuid,
    event_type: impl Into<String>,
    payload: T,
) -> ChannelResult<()>
where
    T: Serialize,
{
    let json_payload = serde_json::to_value(payload)?;
    let event = ChannelEvent::custom(room, event_type, json_payload);
    self.publish_event(event).await
}
pub async fn emit_to_user<T: Serialize>(
&self,
user: Uuid,
event: &str,
data: &T,
) -> ChannelResult<()> {
let sockets = self
.inner
.online
.read()
.await
.get(&user)
.map(|sockets| sockets.values().cloned().collect::<Vec<_>>())
.unwrap_or_default();
for socket in sockets {
socket.emit(event, data).await?;
}
Ok(())
}
pub async fn refresh_user(&self, user: Uuid) -> ChannelResult<()> {
let rooms = refresh_user_rooms_cache(
&self.inner.db,
&self.inner.cache,
&self.inner.config,
user,
)
.await?;
self.sync_online_user_rooms(user, &rooms).await
}
pub async fn workspace_changed(&self, wk: Uuid) -> ChannelResult<()> {
let users = active_workspace_users(&self.inner.db, wk).await?;
let bus = self.clone();
let results =
futures::future::join_all(users.into_iter().map(|user| {
let bus = bus.clone();
async move { bus.refresh_user(user).await }
}))
.await;
let mut first_error = None;
for result in results {
if let Err(e) = result {
tracing::warn!(error = %e, "workspace refresh failed for user");
if first_error.is_none() {
first_error = Some(e);
}
}
}
if let Some(e) = first_error {
Err(e)
} else {
Ok(())
}
}
/// React to a change in `room` by refreshing its whole workspace.
///
/// A room with no resolvable workspace is a no-op.
pub async fn room_changed(&self, room: Uuid) -> ChannelResult<()> {
    match room_workspace(&self.inner.db, room).await? {
        Some(wk) => self.workspace_changed(wk).await,
        None => Ok(()),
    }
}
/// Broadcast `event` as `ROOM_MESSAGE_EVENT` to the Socket.IO room of
/// `event.room`.
///
/// This is best-effort and always returns `Ok(())`. A broadcast failure is
/// logged and counted via `increment_failed` instead of being propagated.
/// NOTE(review): a successful broadcast bumps `increment_received`. Confirm
/// this is the intended metric for "broadcast delivered".
async fn publish_event(&self, event: ChannelEvent) -> ChannelResult<()> {
    let metrics = &self.inner.metrics;
    metrics.increment_sent();
    let namespace = self.inner.io.namespace(&self.inner.config.namespace).await;
    let target = room_socket_name(event.room);
    // Individual socket failures are expected (sockets disconnect) and must
    // not block broadcasting.
    if let Err(e) = namespace.to(target).emit(ROOM_MESSAGE_EVENT, event).await {
        tracing::warn!(error = %e, "WS broadcast failed");
        metrics.increment_failed();
    } else {
        metrics.increment_received();
    }
    Ok(())
}
async fn handle_connect(&self, socket: Socket) -> ChannelResult<()> {
let user = socket.session_user().ok_or(ChannelError::Unauthorized)?;
if !self
.inner
.rate_limiter
.check_rate_limit(user, "connect")
.await?
{
return Err(ChannelError::RateLimitExceeded);
}
let auth = socket
.auth()
.await
.and_then(|value| serde_json::from_value::<ConnectAuth>(value).ok())
.unwrap_or_else(|| ConnectAuth {
last_seq: HashMap::new(),
});
let rooms = user_rooms(
&self.inner.db,
&self.inner.cache,
&self.inner.config,
user,
)
.await?;
for room in &rooms {
socket.join(room_socket_name(*room)).await?;
}
self.register_socket(user, socket.clone()).await;
self.catchup(&socket, &rooms, &auth.last_seq).await?;
Ok(())
}
/// Clean up after a socket disconnects.
///
/// Sockets without a session user are ignored. For every `room:<uuid>`
/// Socket.IO room the socket was in, a `voice.channel_left` event is
/// published, with failures ignored. The socket is then removed from
/// `online`. When it was the user's last socket, the user's entries in
/// `online` and `user_sync_locks` are dropped.
///
/// NOTE(review): `voice.channel_left` is published even if the user still has
/// other sockets connected to the same rooms. Confirm that is intended.
async fn handle_disconnect(&self, socket: &Socket) {
    let Some(user) = socket.session_user() else {
        return;
    };
    let joined = socket.rooms().await;
    for name in &joined {
        let Some(room_id) = name
            .strip_prefix("room:")
            .and_then(|raw| Uuid::parse_str(raw).ok())
        else {
            continue;
        };
        let _ = self
            .publish_room_event(
                room_id,
                "voice.channel_left",
                serde_json::json!({"user_id": user, "disconnected": true}),
            )
            .await;
    }
    let mut online = self.inner.online.write().await;
    let last_socket_gone = match online.get_mut(&user) {
        Some(sockets) => {
            sockets.remove(socket.id());
            sockets.is_empty()
        }
        None => false,
    };
    if last_socket_gone {
        online.remove(&user);
        self.inner.user_sync_locks.remove(&user);
    }
}
/// Record `socket` as one of `user`'s online sockets, keyed by socket id.
///
/// An existing entry with the same socket id is replaced.
async fn register_socket(&self, user: Uuid, socket: Socket) {
    let socket_id = socket.id().to_owned();
    let mut online = self.inner.online.write().await;
    online.entry(user).or_default().insert(socket_id, socket);
}
async fn sync_online_user_rooms(
&self,
user: Uuid,
desired_rooms: &[Uuid],
) -> ChannelResult<()> {
let lock = self
.inner
.user_sync_locks
.entry(user)
.or_insert_with(|| Arc::new(Mutex::new(())))
.clone();
let _guard = lock.lock().await;
let sockets = self
.inner
.online
.read()
.await
.get(&user)
.map(|sockets| sockets.values().cloned().collect::<Vec<_>>())
.unwrap_or_default();
let desired = desired_rooms
.iter()
.map(|room| room_socket_name(*room))
.collect::<HashSet<_>>();
for socket in sockets {
let current = socket
.rooms()
.await
.into_iter()
.filter(|room| room.starts_with("room:"))
.collect::<HashSet<_>>();
for room in desired.difference(&current) {
socket.join(room.clone()).await?;
}
for room in current.difference(&desired) {
socket.leave(room).await?;
}
}
Ok(())
}
async fn catchup(
&self,
socket: &Socket,
rooms: &[Uuid],
last_seq: &HashMap<Uuid, i64>,
) -> ChannelResult<()> {
for room in rooms {
let Some(seq) = last_seq.get(room) else {
continue;
};
let messages = catchup_messages(
&self.inner.db,
&self.inner.config,
*room,
*seq,
)
.await?;
for message in messages {
let sender = match self.lookup_user(message.author).await {
Ok(s) => Some(s),
Err(_) => None,
};
let event = match sender {
Some(s) => {
ChannelEvent::message_created_with_sender(message, s)
}
None => ChannelEvent::message_created(message),
};
socket.emit(ROOM_MESSAGE_EVENT, event).await?;
}
}
Ok(())
}
}