// Source listing: 609 lines, 19 KiB, Rust.
use std::{
|
|
collections::{HashMap, HashSet},
|
|
sync::Arc,
|
|
};
|
|
|
|
use cache::AppCache;
|
|
use dashmap::DashMap;
|
|
use db::AppDatabase;
|
|
use model::room::RoomMessageModel;
|
|
use serde::Deserialize;
|
|
use serde::Serialize;
|
|
use socketio::{Socket, SocketIo};
|
|
use tokio::sync::{Mutex, RwLock};
|
|
use tracing::warn;
|
|
use uuid::Uuid;
|
|
|
|
use crate::{
|
|
ChannelBusConfig, ChannelError, ChannelResult,
|
|
circuit_breaker::CircuitBreaker,
|
|
dedup::DeduplicationManager,
|
|
event::ChannelEvent,
|
|
metrics::ChannelMetrics,
|
|
reconnect::ReconnectManager,
|
|
rooms::{
|
|
active_workspace_users, catchup_messages, refresh_user_rooms_cache,
|
|
room_socket_name, room_workspace, user_rooms,
|
|
},
|
|
security::{CsrfProtection, RateLimiter},
|
|
seq::SeqAllocator,
|
|
};
|
|
|
|
/// Socket.IO event name used both for room broadcasts (`publish_event`)
/// and for per-socket catch-up replays (`catchup`).
const ROOM_MESSAGE_EVENT: &'static str = "room.message";
|
|
|
|
#[derive(Clone)]
|
|
pub struct ChannelBus {
|
|
pub(crate) inner: Arc<Inner>,
|
|
}
|
|
|
|
pub(crate) struct Inner {
|
|
pub(crate) db: AppDatabase,
|
|
pub(crate) cache: AppCache,
|
|
pub(crate) io: SocketIo,
|
|
pub(crate) config: ChannelBusConfig,
|
|
pub(crate) online: RwLock<HashMap<Uuid, HashMap<String, Socket>>>,
|
|
pub(crate) user_sync_locks: DashMap<Uuid, Arc<Mutex<()>>>,
|
|
pub(crate) typing_states: DashMap<(Uuid, Uuid), (crate::event::UserInfo, crate::event::RoomInfo, tokio_util::sync::CancellationToken)>,
|
|
pub(crate) seq: SeqAllocator,
|
|
pub(crate) dedup: DeduplicationManager,
|
|
pub(crate) metrics: ChannelMetrics,
|
|
pub(crate) reconnect: ReconnectManager,
|
|
pub(crate) rate_limiter: RateLimiter,
|
|
pub(crate) csrf: CsrfProtection,
|
|
pub(crate) circuit_breaker: CircuitBreaker,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct ConnectAuth {
|
|
#[serde(default)]
|
|
last_seq: HashMap<Uuid, i64>,
|
|
}
|
|
|
|
impl ChannelBus {
|
|
pub fn io(&self) -> &SocketIo {
|
|
&self.inner.io
|
|
}
|
|
pub async fn first_workspace_id(
|
|
&self,
|
|
user: Uuid,
|
|
) -> ChannelResult<Option<Uuid>> {
|
|
let row = db::sqlx::query_as::<_, (Uuid,)>(
|
|
"SELECT wk FROM wk_member WHERE \"user\" = $1 AND leave_at IS NULL LIMIT 1",
|
|
)
|
|
.bind(user)
|
|
.fetch_optional(self.inner.db.reader())
|
|
.await?;
|
|
Ok(row.map(|r| r.0))
|
|
}
|
|
pub async fn lookup_room(
|
|
&self,
|
|
room: Uuid,
|
|
) -> ChannelResult<crate::event::RoomInfo> {
|
|
let row = db::sqlx::query_as::<_, (String,)>(
|
|
"SELECT name FROM room WHERE id = $1",
|
|
)
|
|
.bind(room)
|
|
.fetch_optional(self.inner.db.reader())
|
|
.await?
|
|
.map(|(name,)| name)
|
|
.unwrap_or_default();
|
|
Ok(crate::event::RoomInfo {
|
|
id: room,
|
|
name: row,
|
|
})
|
|
}
|
|
pub async fn list_workspace_members(
|
|
&self,
|
|
workspace: Uuid,
|
|
) -> ChannelResult<Vec<(Uuid, String, String, String)>> {
|
|
let rows = db::sqlx::query_as::<_, (Uuid, String, String, String)>(
|
|
r#"SELECT u.id, u.username, u.display_name, u.avatar_url
|
|
FROM wk_member wm
|
|
JOIN "user" u ON u.id = wm."user"
|
|
WHERE wm.wk = $1 AND wm.leave_at IS NULL
|
|
ORDER BY u.username"#,
|
|
)
|
|
.bind(workspace)
|
|
.fetch_all(self.inner.db.reader())
|
|
.await?;
|
|
Ok(rows)
|
|
}
|
|
pub async fn lookup_workspace(
|
|
&self,
|
|
wk: Uuid,
|
|
) -> ChannelResult<crate::event::WorkspaceInfo> {
|
|
use db::sqlx::Row;
|
|
let row = db::sqlx::query(
|
|
"SELECT name, avatar_url FROM workspace WHERE id = $1",
|
|
)
|
|
.bind(wk)
|
|
.fetch_optional(self.inner.db.reader())
|
|
.await?;
|
|
let (name, avatar_url) = match row {
|
|
Some(r) => (r.get::<String, _>(0), r.get::<String, _>(1)),
|
|
None => (String::new(), String::new()),
|
|
};
|
|
Ok(crate::event::WorkspaceInfo {
|
|
id: wk,
|
|
name,
|
|
avatar_url,
|
|
})
|
|
}
|
|
pub async fn lookup_users(
|
|
&self,
|
|
users: &[Uuid],
|
|
) -> ChannelResult<std::collections::HashMap<Uuid, crate::event::UserInfo>> {
|
|
if users.is_empty() {
|
|
return Ok(std::collections::HashMap::new());
|
|
}
|
|
let rows = db::sqlx::query_as::<_, model::users::UserModel>(
|
|
"SELECT id, username, display_name, avatar_url, website_url, \
|
|
allow_use, can_search, last_sign_in_at, created_at, updated_at \
|
|
FROM \"user\" WHERE id = ANY($1)",
|
|
)
|
|
.bind(users)
|
|
.fetch_all(self.inner.db.reader())
|
|
.await?;
|
|
Ok(rows
|
|
.into_iter()
|
|
.map(|m| (m.id, crate::event::UserInfo::from_model(&m)))
|
|
.collect())
|
|
}
|
|
pub async fn lookup_user(
|
|
&self,
|
|
user: Uuid,
|
|
) -> ChannelResult<crate::event::UserInfo> {
|
|
let row = db::sqlx::query_as::<_, model::users::UserModel>(
|
|
"SELECT id, username, display_name, avatar_url, website_url, \
|
|
allow_use, can_search, last_sign_in_at, created_at, updated_at \
|
|
FROM \"user\" WHERE id = $1",
|
|
)
|
|
.bind(user)
|
|
.fetch_optional(self.inner.db.reader())
|
|
.await?
|
|
.map(|m| crate::event::UserInfo::from_model(&m))
|
|
.unwrap_or_else(|| crate::event::UserInfo::unknown(user));
|
|
Ok(row)
|
|
}
|
|
pub async fn list_user_rooms(
|
|
&self,
|
|
user: Uuid,
|
|
) -> ChannelResult<Vec<crate::rooms::RoomListItem>> {
|
|
crate::rooms::user_rooms_for_api(
|
|
&self.inner.db,
|
|
&self.inner.cache,
|
|
&self.inner.config,
|
|
user,
|
|
)
|
|
.await
|
|
}
|
|
pub async fn list_user_categories(
|
|
&self,
|
|
user: Uuid,
|
|
) -> ChannelResult<Vec<crate::rooms::CategoryListItem>> {
|
|
crate::rooms::user_categories_for_api(
|
|
&self.inner.db,
|
|
&self.inner.cache,
|
|
&self.inner.config,
|
|
user,
|
|
)
|
|
.await
|
|
}
|
|
|
|
pub fn new(
|
|
db: AppDatabase,
|
|
cache: AppCache,
|
|
io: SocketIo,
|
|
config: ChannelBusConfig,
|
|
) -> Self {
|
|
let seq = match config.seq_segment_size {
|
|
Some(size) => {
|
|
SeqAllocator::with_segment_size(cache.clone(), db.clone(), size)
|
|
}
|
|
None => SeqAllocator::new(cache.clone(), db.clone()),
|
|
};
|
|
let dedup = DeduplicationManager::with_config(
|
|
cache.clone(),
|
|
std::time::Duration::from_secs(
|
|
config.dedup_window_secs.unwrap_or(300),
|
|
),
|
|
);
|
|
let reconnect = ReconnectManager::new(cache.clone(), db.clone());
|
|
let rate_limiter = match (
|
|
config.rate_limit_max_requests,
|
|
config.rate_limit_window_secs,
|
|
) {
|
|
(Some(max), Some(secs)) => RateLimiter::with_config(
|
|
cache.clone(),
|
|
max,
|
|
std::time::Duration::from_secs(secs),
|
|
),
|
|
_ => RateLimiter::new(cache.clone()),
|
|
};
|
|
let csrf = CsrfProtection::new(cache.clone());
|
|
let circuit_breaker = match (
|
|
config.circuit_breaker_failure_threshold,
|
|
config.circuit_breaker_success_threshold,
|
|
config.circuit_breaker_timeout_secs,
|
|
config.circuit_breaker_half_open_max_calls,
|
|
) {
|
|
(Some(failure), Some(success), Some(secs), Some(half_open)) => {
|
|
CircuitBreaker::with_config(
|
|
failure,
|
|
success,
|
|
std::time::Duration::from_secs(secs),
|
|
half_open,
|
|
)
|
|
}
|
|
_ => CircuitBreaker::new(),
|
|
};
|
|
Self {
|
|
inner: Arc::new(Inner {
|
|
db,
|
|
cache,
|
|
io,
|
|
config,
|
|
online: RwLock::new(HashMap::new()),
|
|
user_sync_locks: DashMap::new(),
|
|
typing_states: DashMap::new(),
|
|
seq,
|
|
dedup,
|
|
metrics: ChannelMetrics::new(),
|
|
reconnect,
|
|
rate_limiter,
|
|
csrf,
|
|
circuit_breaker,
|
|
}),
|
|
}
|
|
}
|
|
|
|
pub async fn attach(&self) -> ChannelResult<()> {
|
|
let namespace =
|
|
self.inner.io.namespace(&self.inner.config.namespace).await;
|
|
|
|
let auth_bus = self.clone();
|
|
namespace
|
|
.use_middleware(move |socket, auth| {
|
|
let bus = auth_bus.clone();
|
|
async move {
|
|
if socket.session_user().is_some() {
|
|
return Ok(());
|
|
}
|
|
let token = auth
|
|
.as_ref()
|
|
.and_then(|v| v.get("access_token"))
|
|
.and_then(|v| v.as_str());
|
|
if let Some(token) = token {
|
|
let ctx = bus
|
|
.check_access_token(token.to_owned())
|
|
.await
|
|
.map_err(|_| {
|
|
socketio::SocketIoError::Adapter(
|
|
"token invalid or expired".to_owned(),
|
|
)
|
|
})?;
|
|
socket.set_user(ctx.user_id);
|
|
return Ok(());
|
|
}
|
|
Err(socketio::SocketIoError::Adapter(
|
|
"unauthorized".to_owned(),
|
|
))
|
|
}
|
|
})
|
|
.await;
|
|
|
|
let on_connect_bus = self.clone();
|
|
namespace
|
|
.on_connect(move |socket| {
|
|
let bus = on_connect_bus.clone();
|
|
async move {
|
|
bus.inner.metrics.increment_connections();
|
|
if let Err(error) = bus.handle_connect(socket.clone()).await {
|
|
warn!(%error, "channel socket connect failed, disconnecting");
|
|
let _ = socket.disconnect().await;
|
|
}
|
|
}
|
|
})
|
|
.await;
|
|
|
|
let on_disconnect_bus = self.clone();
|
|
namespace
|
|
.on_disconnect(move |socket, _reason| {
|
|
let bus = on_disconnect_bus.clone();
|
|
async move {
|
|
bus.inner.metrics.decrement_connections();
|
|
bus.handle_disconnect(&socket).await;
|
|
}
|
|
})
|
|
.await;
|
|
crate::http::ws::register_message_handler(self).await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn publish_room_message(
|
|
&self,
|
|
message: RoomMessageModel,
|
|
sender: Option<crate::event::UserInfo>,
|
|
) -> ChannelResult<()> {
|
|
let is_new = self
|
|
.inner
|
|
.dedup
|
|
.check_and_mark(message.id, message.room)
|
|
.await?;
|
|
if !is_new {
|
|
return Ok(());
|
|
}
|
|
let event = match sender {
|
|
Some(s) => ChannelEvent::message_created_with_sender(message, s),
|
|
None => ChannelEvent::message_created(message),
|
|
};
|
|
self.publish_event(event).await
|
|
}
|
|
|
|
pub async fn publish_room_event<T>(
|
|
&self,
|
|
room: Uuid,
|
|
event_type: impl Into<String>,
|
|
payload: T,
|
|
) -> ChannelResult<()>
|
|
where
|
|
T: Serialize,
|
|
{
|
|
let payload = serde_json::to_value(payload)?;
|
|
self.publish_event(ChannelEvent::custom(room, event_type, payload))
|
|
.await
|
|
}
|
|
pub async fn emit_to_user<T: Serialize>(
|
|
&self,
|
|
user: Uuid,
|
|
event: &str,
|
|
data: &T,
|
|
) -> ChannelResult<()> {
|
|
let sockets = self
|
|
.inner
|
|
.online
|
|
.read()
|
|
.await
|
|
.get(&user)
|
|
.map(|sockets| sockets.values().cloned().collect::<Vec<_>>())
|
|
.unwrap_or_default();
|
|
|
|
for socket in sockets {
|
|
socket.emit(event, data).await?;
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn refresh_user(&self, user: Uuid) -> ChannelResult<()> {
|
|
let rooms = refresh_user_rooms_cache(
|
|
&self.inner.db,
|
|
&self.inner.cache,
|
|
&self.inner.config,
|
|
user,
|
|
)
|
|
.await?;
|
|
self.sync_online_user_rooms(user, &rooms).await
|
|
}
|
|
|
|
pub async fn workspace_changed(&self, wk: Uuid) -> ChannelResult<()> {
|
|
let users = active_workspace_users(&self.inner.db, wk).await?;
|
|
let bus = self.clone();
|
|
let results =
|
|
futures::future::join_all(users.into_iter().map(|user| {
|
|
let bus = bus.clone();
|
|
async move { bus.refresh_user(user).await }
|
|
}))
|
|
.await;
|
|
let mut first_error = None;
|
|
for result in results {
|
|
if let Err(e) = result {
|
|
tracing::warn!(error = %e, "workspace refresh failed for user");
|
|
if first_error.is_none() {
|
|
first_error = Some(e);
|
|
}
|
|
}
|
|
}
|
|
if let Some(e) = first_error {
|
|
Err(e)
|
|
} else {
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
pub async fn room_changed(&self, room: Uuid) -> ChannelResult<()> {
|
|
if let Some(wk) = room_workspace(&self.inner.db, room).await? {
|
|
self.workspace_changed(wk).await?;
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
async fn publish_event(&self, event: ChannelEvent) -> ChannelResult<()> {
|
|
self.inner.metrics.increment_sent();
|
|
|
|
let result = self
|
|
.inner
|
|
.circuit_breaker
|
|
.call(async {
|
|
self.inner
|
|
.io
|
|
.namespace(&self.inner.config.namespace)
|
|
.await
|
|
.to(room_socket_name(event.room))
|
|
.emit(ROOM_MESSAGE_EVENT, event)
|
|
.await
|
|
.map_err(ChannelError::SocketIo)
|
|
})
|
|
.await;
|
|
|
|
match result {
|
|
Ok(()) => {
|
|
self.inner.metrics.increment_received();
|
|
Ok(())
|
|
}
|
|
Err(e) => {
|
|
self.inner.metrics.increment_failed();
|
|
match e {
|
|
crate::circuit_breaker::CircuitBreakerError::Open => {
|
|
Err(ChannelError::Internal(
|
|
"circuit breaker open".to_string(),
|
|
))
|
|
}
|
|
crate::circuit_breaker::CircuitBreakerError::Inner(e) => {
|
|
Err(e)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn handle_connect(&self, socket: Socket) -> ChannelResult<()> {
|
|
let user = socket.session_user().ok_or(ChannelError::Unauthorized)?;
|
|
|
|
if !self
|
|
.inner
|
|
.rate_limiter
|
|
.check_rate_limit(user, "connect")
|
|
.await?
|
|
{
|
|
return Err(ChannelError::RateLimitExceeded);
|
|
}
|
|
|
|
let auth = socket
|
|
.auth()
|
|
.await
|
|
.and_then(|value| serde_json::from_value::<ConnectAuth>(value).ok())
|
|
.unwrap_or_else(|| ConnectAuth {
|
|
last_seq: HashMap::new(),
|
|
});
|
|
let rooms = user_rooms(
|
|
&self.inner.db,
|
|
&self.inner.cache,
|
|
&self.inner.config,
|
|
user,
|
|
)
|
|
.await?;
|
|
|
|
for room in &rooms {
|
|
socket.join(room_socket_name(*room)).await?;
|
|
}
|
|
self.register_socket(user, socket.clone()).await;
|
|
self.catchup(&socket, &rooms, &auth.last_seq).await?;
|
|
Ok(())
|
|
}
|
|
|
|
async fn handle_disconnect(&self, socket: &Socket) {
|
|
let Some(user) = socket.session_user() else {
|
|
return;
|
|
};
|
|
let rooms = socket.rooms().await;
|
|
for room_name in &rooms {
|
|
if let Some(room_str) = room_name.strip_prefix("room:") {
|
|
if let Ok(room_id) = Uuid::parse_str(room_str) {
|
|
let _ = self.publish_room_event(
|
|
room_id,
|
|
"voice.channel_left",
|
|
serde_json::json!({"user_id": user, "disconnected": true}),
|
|
)
|
|
.await;
|
|
}
|
|
}
|
|
}
|
|
let mut online = self.inner.online.write().await;
|
|
if let Some(sockets) = online.get_mut(&user) {
|
|
sockets.remove(socket.id());
|
|
if sockets.is_empty() {
|
|
online.remove(&user);
|
|
self.inner.user_sync_locks.remove(&user);
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn register_socket(&self, user: Uuid, socket: Socket) {
|
|
self.inner
|
|
.online
|
|
.write()
|
|
.await
|
|
.entry(user)
|
|
.or_default()
|
|
.insert(socket.id().to_owned(), socket);
|
|
}
|
|
|
|
async fn sync_online_user_rooms(
|
|
&self,
|
|
user: Uuid,
|
|
desired_rooms: &[Uuid],
|
|
) -> ChannelResult<()> {
|
|
let lock = self
|
|
.inner
|
|
.user_sync_locks
|
|
.entry(user)
|
|
.or_insert_with(|| Arc::new(Mutex::new(())))
|
|
.clone();
|
|
let _guard = lock.lock().await;
|
|
|
|
let sockets = self
|
|
.inner
|
|
.online
|
|
.read()
|
|
.await
|
|
.get(&user)
|
|
.map(|sockets| sockets.values().cloned().collect::<Vec<_>>())
|
|
.unwrap_or_default();
|
|
|
|
let desired = desired_rooms
|
|
.iter()
|
|
.map(|room| room_socket_name(*room))
|
|
.collect::<HashSet<_>>();
|
|
|
|
for socket in sockets {
|
|
let current = socket
|
|
.rooms()
|
|
.await
|
|
.into_iter()
|
|
.filter(|room| room.starts_with("room:"))
|
|
.collect::<HashSet<_>>();
|
|
|
|
for room in desired.difference(¤t) {
|
|
socket.join(room.clone()).await?;
|
|
}
|
|
for room in current.difference(&desired) {
|
|
socket.leave(room).await?;
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
async fn catchup(
|
|
&self,
|
|
socket: &Socket,
|
|
rooms: &[Uuid],
|
|
last_seq: &HashMap<Uuid, i64>,
|
|
) -> ChannelResult<()> {
|
|
for room in rooms {
|
|
let Some(seq) = last_seq.get(room) else {
|
|
continue;
|
|
};
|
|
let messages = catchup_messages(
|
|
&self.inner.db,
|
|
&self.inner.config,
|
|
*room,
|
|
*seq,
|
|
)
|
|
.await?;
|
|
for message in messages {
|
|
let sender = match self.lookup_user(message.author).await {
|
|
Ok(s) => Some(s),
|
|
Err(_) => None,
|
|
};
|
|
let event = match sender {
|
|
Some(s) => ChannelEvent::message_created_with_sender(message, s),
|
|
None => ChannelEvent::message_created(message),
|
|
};
|
|
socket.emit(ROOM_MESSAGE_EVENT, event).await?;
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|