From 079ea3a5cf446c3184b88d87726b4c2e12204ecb Mon Sep 17 00:00:00 2001 From: zhenyi <434836402@qq.com> Date: Mon, 1 Jun 2026 22:04:31 +0800 Subject: [PATCH] refactor: update channel and model layers --- lib/channel/bus.rs | 14 +- lib/channel/circuit_breaker.rs | 61 ++++++- lib/channel/event/article.rs | 2 +- lib/channel/http/handler/article.rs | 43 ++--- lib/channel/http/handler/message.rs | 116 ++++++++++++- lib/channel/http/handler/mod.rs | 41 ++++- lib/channel/http/handler/thread.rs | 1 + lib/channel/http/out_event.rs | 6 +- lib/channel/http/types.rs | 1 + lib/channel/http/ws.rs | 2 +- lib/channel/lib.rs | 2 + lib/channel/metrics.rs | 51 +++++- lib/channel/richtext.rs | 230 +++++++++++++++----------- lib/channel/security.rs | 34 +++- lib/model/channel/channel_article.rs | 4 - lib/model/channel/mod.rs | 10 +- lib/model/channel/room_attachments.rs | 2 +- lib/model/channel/room_mention.rs | 2 +- lib/model/lib.rs | 2 +- 19 files changed, 473 insertions(+), 151 deletions(-) diff --git a/lib/channel/bus.rs b/lib/channel/bus.rs index 080a573..a3e21ce 100644 --- a/lib/channel/bus.rs +++ b/lib/channel/bus.rs @@ -41,6 +41,7 @@ pub struct Inner { pub cache: AppCache, pub io: SocketIo, pub config: ChannelBusConfig, + pub cdn: crate::CdnManager, pub online: RwLock>>, pub user_sync_locks: DashMap>>, pub typing_states: DashMap< @@ -203,6 +204,8 @@ impl ChannelBus { cache: AppCache, io: SocketIo, config: ChannelBusConfig, + cdn: crate::CdnManager, + metrics_registry: Option, ) -> Self { let seq = match config.seq_segment_size { Some(size) => { @@ -217,7 +220,7 @@ impl ChannelBus { ), ); let reconnect = ReconnectManager::new(cache.clone(), db.clone()); - let rate_limiter = match ( + let mut rate_limiter = match ( config.rate_limit_max_requests, config.rate_limit_window_secs, ) { @@ -229,7 +232,7 @@ impl ChannelBus { _ => RateLimiter::new(cache.clone()), }; let csrf = CsrfProtection::new(cache.clone()); - let circuit_breaker = match ( + let mut circuit_breaker = match ( config.circuit_breaker_failure_threshold, config.circuit_breaker_success_threshold, config.circuit_breaker_timeout_secs, @@ -245,18 +248,23 @@ impl ChannelBus { } _ => CircuitBreaker::new(), }; + if let Some(ref reg) = metrics_registry { + rate_limiter.set_metrics(reg); + circuit_breaker.set_metrics(reg); + } Self { inner: Arc::new(Inner { db, cache, io, config, + cdn, online: RwLock::new(HashMap::new()), user_sync_locks: DashMap::new(), typing_states: DashMap::new(), seq, dedup, - metrics: ChannelMetrics::new(), + metrics: ChannelMetrics::new(metrics_registry), reconnect, rate_limiter, csrf, diff --git a/lib/channel/circuit_breaker.rs b/lib/channel/circuit_breaker.rs index 9fab2dd..f2c4214 100644 --- a/lib/channel/circuit_breaker.rs +++ b/lib/channel/circuit_breaker.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::Mutex; +use track::CounterVec; use crate::ChannelError; @@ -11,6 +12,34 @@ const STATUS_HALF_OPEN: u8 = 2; #[derive(Clone)] pub struct CircuitBreaker { inner: Arc, + metrics: Option, +} + +#[derive(Clone)] +struct CircuitBreakerMetrics { + transitions: CounterVec, + calls: CounterVec, +} + +impl CircuitBreakerMetrics { + fn new(registry: &track::MetricsRegistry) -> Self { + Self { + transitions: registry + .register_counter_vec( + "circuit_breaker_transitions_total", + "Circuit breaker state transitions", + &["transition"], + ) + .expect("failed to register circuit_breaker_transitions_total"), + calls: registry + .register_counter_vec( + "circuit_breaker_calls_total", + "Circuit breaker call outcomes", + &["outcome"], + ) + .expect("failed to register circuit_breaker_calls_total"), + } + } } struct CircuitState { @@ -61,9 +90,14 @@ impl CircuitBreaker { half_open_max_calls, }, }), + metrics: None, } } + pub fn set_metrics(&mut self, registry: &track::MetricsRegistry) { + self.metrics = Some(CircuitBreakerMetrics::new(registry)); + } + pub async fn call(&self, f: F) -> Result where F: std::future::Future>, @@ -90,20 +124,29 @@ impl CircuitBreaker { false } } - _ => true, // Closed → allow + _ => true, } - }; // Lock released before executing the call. + }; if !slot_reserved { + if let Some(m) = &self.metrics { + m.calls.with_label_values(&["rejected"]).inc(); + } return Err(CircuitBreakerError::Open); } match f.await { Ok(result) => { + if let Some(m) = &self.metrics { + m.calls.with_label_values(&["success"]).inc(); + } self.on_success().await; Ok(result) } Err(e) => { + if let Some(m) = &self.metrics { + m.calls.with_label_values(&["failure"]).inc(); + } self.on_failure().await; Err(CircuitBreakerError::Inner(e)) } @@ -120,9 +163,15 @@ impl CircuitBreaker { state.status = STATUS_CLOSED; state.success_count = 0; state.half_open_calls = 0; + if let Some(m) = &self.metrics { + m.transitions + .with_label_values(&["half_open_to_closed"]) + .inc(); + } } } } + async fn on_failure(&self) { let mut state = self.inner.state.lock().await; state.failure_count += 1; @@ -132,10 +181,18 @@ impl CircuitBreaker { state.status = STATUS_OPEN; state.success_count = 0; state.half_open_calls = 0; + if let Some(m) = &self.metrics { + m.transitions + .with_label_values(&["half_open_to_open"]) + .inc(); + } } else if state.status == STATUS_CLOSED && state.failure_count >= self.inner.config.failure_threshold { state.status = STATUS_OPEN; + if let Some(m) = &self.metrics { + m.transitions.with_label_values(&["closed_to_open"]).inc(); + } } } diff --git a/lib/channel/event/article.rs b/lib/channel/event/article.rs index f2d2487..8ac4938 100644 --- a/lib/channel/event/article.rs +++ b/lib/channel/event/article.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; use uuid::Uuid; -use super::common::{UserInfo, RoomInfo}; +use super::common::{RoomInfo, UserInfo}; /// Created when a user publishes an article in an article channel. #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/lib/channel/http/handler/article.rs b/lib/channel/http/handler/article.rs index 9a94c4d..15362b2 100644 --- a/lib/channel/http/handler/article.rs +++ b/lib/channel/http/handler/article.rs @@ -60,12 +60,10 @@ impl WsHandler { author: author.clone(), }; - bus.publish_room_event(channel, "article.created", &data).await?; + bus.publish_room_event(channel, "article.created", &data) + .await?; - Ok(Some(WsOutEvent::ArticleCreated { - room, - data, - })) + Ok(Some(WsOutEvent::ArticleCreated { room, data })) } pub(super) async fn article_update( @@ -142,12 +140,10 @@ impl WsHandler { channel: room.clone(), }; - bus.publish_room_event(row.channel, "article.updated", &data).await?; + bus.publish_room_event(row.channel, "article.updated", &data) + .await?; - Ok(Some(WsOutEvent::ArticleUpdated { - room, - data, - })) + Ok(Some(WsOutEvent::ArticleUpdated { room, data })) } pub(super) async fn article_delete( @@ -193,12 +189,10 @@ impl WsHandler { deleted_by: deleted_by.clone(), }; - bus.publish_room_event(old.channel, "article.deleted", &data).await?; + bus.publish_room_event(old.channel, "article.deleted", &data) + .await?; - Ok(Some(WsOutEvent::ArticleDeleted { - room, - data, - })) + Ok(Some(WsOutEvent::ArticleDeleted { room, data })) } pub(super) async fn article_list( @@ -383,7 +377,8 @@ impl WsHandler { user, like_count: new_count, }; - bus.publish_room_event(art.channel, "article.liked", &data).await?; + bus.publish_room_event(art.channel, "article.liked", &data) + .await?; Ok(Some(WsOutEvent::ArticleLiked { room, data })) } else { let data = article::ArticleUnlikedService { @@ -392,7 +387,8 @@ impl WsHandler { user, like_count: new_count, }; - bus.publish_room_event(art.channel, "article.unliked", &data).await?; + bus.publish_room_event(art.channel, "article.unliked", &data) + .await?; Ok(Some(WsOutEvent::ArticleUnliked { room, data })) } } @@ -462,7 +458,8 @@ impl WsHandler { comment_count: count_row.0, }; - bus.publish_room_event(art.channel, "article.comment.created", &data).await?; + bus.publish_room_event(art.channel, "article.comment.created", &data) + .await?; Ok(Some(WsOutEvent::ArticleCommentCreated { room, data })) } @@ -601,7 +598,8 @@ impl WsHandler { comment_count: count_row.0, }; - bus.publish_room_event(art.channel, "article.comment.deleted", &data).await?; + bus.publish_room_event(art.channel, "article.comment.deleted", &data) + .await?; Ok(Some(WsOutEvent::ArticleCommentDeleted { room, data })) } @@ -657,7 +655,12 @@ impl WsHandler { let users_map = bus.lookup_users(&user_ids).await?; let users: Vec = user_ids .iter() - .map(|id| users_map.get(id).cloned().unwrap_or_else(|| UserInfo::unknown(*id))) + .map(|id| { + users_map + .get(id) + .cloned() + .unwrap_or_else(|| UserInfo::unknown(*id)) + }) .collect(); Ok(Some(WsOutEvent::ArticleLikedUsers { diff --git a/lib/channel/http/handler/message.rs b/lib/channel/http/handler/message.rs index f702e4c..344d28a 100644 --- a/lib/channel/http/handler/message.rs +++ b/lib/channel/http/handler/message.rs @@ -103,6 +103,89 @@ impl WsHandler { Ok(()) } + /// Store parsed mentions in the room_mention table. + /// For `@all` mentions, also insert notifications for all workspace members. + async fn persist_mentions( + bus: &ChannelBus, + message_id: Uuid, + seq: i64, + mentions: &[crate::richtext::Mention], + sender_id: Uuid, + room: Uuid, + workspace: Uuid, + ) -> ChannelResult<()> { + let mut has_all = false; + + for mention in mentions { + db::sqlx::query( + "INSERT INTO room_mention (message, seq, mention_type, target_id) \ + VALUES ($1, $2, $3, $4)", + ) + .bind(message_id) + .bind(seq) + .bind(&mention.mention_type) + .bind(&mention.target_id) + .execute(bus.inner.db.writer()) + .await?; + + if mention.mention_type == "all" { + has_all = true; + } + } + + // When @all is used, insert notification records for every workspace member + if has_all && !workspace.is_nil() { + let sender = bus + .lookup_user(sender_id) + .await + .unwrap_or_else(|_| crate::event::UserInfo::unknown(sender_id)); + let room_info = bus + .lookup_room(room) + .await + .unwrap_or_else(|_| crate::event::RoomInfo::unknown(room)); + + let sender_name: &str = if sender.display_name.is_empty() { + &sender.username + } else { + &sender.display_name + }; + let title = format!( + "{} mentioned @everyone in #{}", + sender_name, room_info.name + ); + + let members = bus + .list_workspace_members(workspace) + .await + .unwrap_or_default(); + + for (member_id, _username, _display_name, _avatar_url) in &members { + if *member_id == sender_id { + continue; // Don't notify the sender + } + + let notify_id = Uuid::now_v7(); + db::sqlx::query( + "INSERT INTO user_app_notify \ + (id, \"user\", title, body, notify_type, target_type, target_id, \ + created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, $6, $7, now(), now())", + ) + .bind(notify_id) + .bind(member_id) + .bind(&title) + .bind(&format!("Message from {}", sender_name)) + .bind("mention_all") + .bind("room") + .bind(room) + .execute(bus.inner.db.writer()) + .await?; + } + } + + Ok(()) + } + pub(super) async fn message_create( bus: &ChannelBus, user_id: Uuid, @@ -111,6 +194,7 @@ impl WsHandler { content_type: Option, thread: Option, in_reply_to: Option, + attachment_ids: Option>, ) -> ChannelResult> { Self::ensure_room_access(bus, user_id, room).await?; if content.len() > MAX_TEXT_LEN { @@ -226,6 +310,10 @@ impl WsHandler { let seq = bus.inner.seq.seq(room).await?; let sender = bus.lookup_user(user_id).await?; let sender_for_response = sender.clone(); + + // Parse mentions from content before inserting + let mentions = crate::richtext::parse_mentions(&content); + let row = db::sqlx::query_as::<_, model::channel::RoomMessageModel>( "INSERT INTO room_message (room, seq, thread, parent, author, content, content_type) \ VALUES ($1, $2, $3, $4, $5, $6, $7) \ @@ -237,11 +325,35 @@ impl WsHandler { .bind(effective_thread) .bind(in_reply_to) .bind(user_id) - .bind(content) - .bind(content_type.unwrap_or_else(|| "text".to_string())) + .bind(&content) + .bind(content_type.clone().unwrap_or_else(|| "text".to_string())) .fetch_one(bus.inner.db.writer()) .await?; + // Store mentions in the room_mention table + let workspace = crate::rooms::room_workspace(&bus.inner.db, room) + .await? + .unwrap_or(Uuid::nil()); + Self::persist_mentions( + bus, row.id, row.seq, &mentions, user_id, room, workspace, + ) + .await?; + + // Link attachments to the created message + if let Some(ref att_ids) = attachment_ids { + if !att_ids.is_empty() { + db::sqlx::query( + "UPDATE room_attachment SET message = $1, seq = $2 WHERE id = ANY($3) AND uploaded_by = $4", + ) + .bind(row.id) + .bind(row.seq) + .bind(att_ids) + .bind(user_id) + .execute(bus.inner.db.writer()) + .await?; + } + } + bus.publish_room_message(row.clone(), Some(sender)).await?; let msg_room = bus .lookup_room(room) diff --git a/lib/channel/http/handler/mod.rs b/lib/channel/http/handler/mod.rs index 3214a37..8301b00 100644 --- a/lib/channel/http/handler/mod.rs +++ b/lib/channel/http/handler/mod.rs @@ -87,6 +87,7 @@ impl WsHandler { content_type, thread, in_reply_to, + attachment_ids, } => { Self::message_create( bus, @@ -96,6 +97,7 @@ impl WsHandler { content_type, thread, in_reply_to, + attachment_ids, ) .await } @@ -117,8 +119,14 @@ impl WsHandler { channel_type, } => { Self::room_create( - bus, user_id, workspace, room_name, public, category, - ai_enabled, channel_type, + bus, + user_id, + workspace, + room_name, + public, + category, + ai_enabled, + channel_type, ) .await } @@ -357,8 +365,16 @@ impl WsHandler { status, } => { Self::article_create( - bus, user_id, channel, title, cover_url, content, - content_type, summary, tags, status, + bus, + user_id, + channel, + title, + cover_url, + content, + content_type, + summary, + tags, + status, ) .await } @@ -374,8 +390,17 @@ impl WsHandler { status, } => { Self::article_update( - bus, user_id, article_id, title, cover_url, content, - content_type, summary, tags, is_pinned, status, + bus, + user_id, + article_id, + title, + cover_url, + content, + content_type, + summary, + tags, + is_pinned, + status, ) .await } @@ -386,9 +411,7 @@ impl WsHandler { channel, before, limit, - } => { - Self::article_list(bus, user_id, channel, before, limit).await - } + } => Self::article_list(bus, user_id, channel, before, limit).await, WsInMessage::ArticleGet { article_id } => { Self::article_get(bus, user_id, article_id).await } diff --git a/lib/channel/http/handler/thread.rs b/lib/channel/http/handler/thread.rs index b02b05b..ea9a3d0 100644 --- a/lib/channel/http/handler/thread.rs +++ b/lib/channel/http/handler/thread.rs @@ -9,6 +9,7 @@ use super::WsOutEvent; /// Helper struct for thread_list JOIN query result #[derive(db::sqlx::FromRow)] +#[allow(dead_code)] struct ThreadListRow { id: Uuid, room: Uuid, diff --git a/lib/channel/http/out_event.rs b/lib/channel/http/out_event.rs index ff5db8a..3a9a649 100644 --- a/lib/channel/http/out_event.rs +++ b/lib/channel/http/out_event.rs @@ -2,9 +2,9 @@ use serde::Serialize; use uuid::Uuid; use crate::event::{ - RoomInfo, WorkspaceInfo, article, attachment, ban, category, conversation, draft, - forward, invite, member, message, message_read, notify, pin, presence, - reaction, rooms, search, star, thread, voice, workspace, + RoomInfo, WorkspaceInfo, article, attachment, ban, category, conversation, + draft, forward, invite, member, message, message_read, notify, pin, + presence, reaction, rooms, search, star, thread, voice, workspace, }; #[derive(Debug, Clone, Serialize)] diff --git a/lib/channel/http/types.rs b/lib/channel/http/types.rs index b80c187..015df5b 100644 --- a/lib/channel/http/types.rs +++ b/lib/channel/http/types.rs @@ -44,6 +44,7 @@ pub enum WsInMessage { content_type: Option, thread: Option, in_reply_to: Option, + attachment_ids: Option>, }, MessageUpdate { message: Uuid, diff --git a/lib/channel/http/ws.rs b/lib/channel/http/ws.rs index 2f34e4e..dd63816 100644 --- a/lib/channel/http/ws.rs +++ b/lib/channel/http/ws.rs @@ -1,7 +1,7 @@ use socketio::{EventPayload, Socket}; use uuid::Uuid; -use crate::{ChannelBus, ChannelError, ChannelResult}; +use crate::{ChannelBus, ChannelError}; use super::handler::WsHandler; use super::out_event::{WsError, WsOutEvent}; diff --git a/lib/channel/lib.rs b/lib/channel/lib.rs index 21baf07..ed8e71d 100644 --- a/lib/channel/lib.rs +++ b/lib/channel/lib.rs @@ -11,6 +11,7 @@ pub mod http; mod metrics; mod pagination; mod reconnect; +pub mod richtext; pub mod rooms; mod search; mod security; @@ -32,6 +33,7 @@ pub use pagination::{ PaginationParams, }; pub use reconnect::{ClientState, MissedMessage, ReconnectManager}; +pub use richtext::{Mention, parse_mentions}; pub use search::{SearchEngine, SearchHit, SearchQuery, SearchResult}; pub use security::{CsrfProtection, RateLimiter}; pub use seq::SeqAllocator; diff --git a/lib/channel/metrics.rs b/lib/channel/metrics.rs index 7f32113..f4d63de 100644 --- a/lib/channel/metrics.rs +++ b/lib/channel/metrics.rs @@ -1,45 +1,94 @@ use std::sync::Arc; +use track::{CounterVec, Gauge}; + #[derive(Clone)] pub struct ChannelMetrics { pub messages_sent: Arc, pub messages_received: Arc, pub messages_failed: Arc, pub active_connections: Arc, + events_total: Option, + active_connections_gauge: Option, } impl ChannelMetrics { - pub fn new() -> Self { + pub fn new(registry: Option) -> Self { + let events_total = registry.as_ref().and_then(|registry| { + registry + .register_counter_vec( + "channel_events_total", + "Total channel socket and message events", + &["event"], + ) + .map_err(|error| { + tracing::warn!(%error, "failed to register channel_events_total"); + error + }) + .ok() + }); + let active_connections_gauge = registry.as_ref().and_then(|registry| { + registry + .register_gauge( + "channel_active_connections", + "Current active channel socket connections", + ) + .map_err(|error| { + tracing::warn!(%error, "failed to register channel_active_connections"); + error + }) + .ok() + }); + Self { messages_sent: Arc::new(std::sync::atomic::AtomicU64::new(0)), messages_received: Arc::new(std::sync::atomic::AtomicU64::new(0)), messages_failed: Arc::new(std::sync::atomic::AtomicU64::new(0)), active_connections: Arc::new(std::sync::atomic::AtomicI64::new(0)), + events_total, + active_connections_gauge, } } pub fn increment_sent(&self) { self.messages_sent .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + self.record_event("sent"); } pub fn increment_received(&self) { self.messages_received .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + self.record_event("received"); } pub fn increment_failed(&self) { self.messages_failed .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + self.record_event("failed"); } pub fn increment_connections(&self) { self.active_connections .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + if let Some(gauge) = &self.active_connections_gauge { + gauge.inc(); + } + self.record_event("connected"); } pub fn decrement_connections(&self) { self.active_connections .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + if let Some(gauge) = &self.active_connections_gauge { + gauge.dec(); + } + self.record_event("disconnected"); + } + + fn record_event(&self, event: &str) { + if let Some(counter) = &self.events_total { + counter.with_label_values(&[event]).inc(); + } } } diff --git a/lib/channel/richtext.rs b/lib/channel/richtext.rs index 3dc4743..a42ea35 100644 --- a/lib/channel/richtext.rs +++ b/lib/channel/richtext.rs @@ -1,112 +1,150 @@ use serde::{Deserialize, Serialize}; -use uuid::Uuid; -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RichTextBlock { - pub block_type: BlockType, - pub content: String, - pub attributes: Option, +/// Parsed mention from `@[type:id:label]` IR format. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct Mention { + pub mention_type: String, + pub target_id: String, + pub label: String, } -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "lowercase")] -pub enum BlockType { - Text, - Code, - Quote, - Link, - Mention, - Emoji, - Image, -} +/// Parse all `@[type:id:label]` mentions from content. +/// Returns deduplicated mentions in order of first appearance. +pub fn parse_mentions(content: &str) -> Vec { + let mut mentions = Vec::new(); + let mut seen = std::collections::HashSet::new(); -pub struct RichTextRenderer; + // Simple manual parser for @[type:id:label] + let bytes = content.as_bytes(); + let len = bytes.len(); + let mut i = 0; -impl RichTextRenderer { - pub fn new() -> Self { - Self {} - } + while i < len { + // Look for "@[" + if i + 2 < len && bytes[i] == b'@' && bytes[i + 1] == b'[' { + let start = i + 2; // after "@[" - pub fn parse_markdown(&self, content: &str) -> Vec { - vec![RichTextBlock { - block_type: BlockType::Text, - content: content.to_string(), - attributes: None, - }] - } + // Find first ':'' after start + if let Some(type_end) = content[start..].find(':') { + let mention_type = &content[start..start + type_end]; + let after_type = start + type_end + 1; // after first ':' - pub fn parse_mentions(&self, content: &str) -> Vec { - content - .split_whitespace() - .filter(|w| w.starts_with('@')) - .filter_map(|w| Uuid::parse_str(&w[1..]).ok()) - .collect() - } + // Find second ':' (between id and label) + if let Some(id_end) = content[after_type..].find(':') { + let target_id = &content[after_type..after_type + id_end]; + let after_id = after_type + id_end + 1; // after second ':' - pub fn highlight_code(&self, code: &str, language: &str) -> String { - format!("```{}\n{}\n```", language, code) - } + // Find closing ']' + if let Some(close) = content[after_id..].find(']') { + let label = &content[after_id..after_id + close]; - pub fn render_to_html(&self, blocks: &[RichTextBlock]) -> String { - blocks - .iter() - .map(|block| match block.block_type { - BlockType::Text => { - format!("

{}

", html_escape(&block.content)) - } - BlockType::Code => format!( - "
{}
", - html_escape(&block.content) - ), - BlockType::Quote => format!( - "
{}
", - html_escape(&block.content) - ), - BlockType::Link => { - let safe_href = sanitize_uri(&block.content); - format!( - "{}", - html_escape(&safe_href), - html_escape(&block.content) - ) - } - BlockType::Mention => format!( - "@{}", - html_escape(&block.content) - ), - BlockType::Emoji => format!( - "{}", - html_escape(&block.content) - ), - BlockType::Image => { - let safe_src = sanitize_uri(&block.content); - if safe_src.is_empty() { - String::new() - } else { - format!("", html_escape(&safe_src)) + if !mention_type.is_empty() && !target_id.is_empty() { + let key = format!("{}:{}", mention_type, target_id); + if seen.insert(key) { + mentions.push(Mention { + mention_type: mention_type.to_string(), + target_id: target_id.to_string(), + label: label.to_string(), + }); + } + } + + i = after_id + close + 1; // skip past ']' + continue; } } - }) - .collect::>() - .join("\n") - } -} -fn sanitize_uri(uri: &str) -> String { - let lower = uri.to_lowercase(); - if lower.starts_with("http://") - || lower.starts_with("https://") - || lower.starts_with("mailto:") - { - uri.to_string() - } else { - String::new() + } + } + i += 1; } + + mentions } -fn html_escape(s: &str) -> String { - s.replace('&', "&") - .replace('<', "<") - .replace('>', ">") - .replace('"', """) - .replace('\'', "'") +/// Extract unique target IDs of a specific mention type. +pub fn extract_mention_ids( + mentions: &[Mention], + mention_type: &str, +) -> Vec { + mentions + .iter() + .filter(|m| m.mention_type == mention_type) + .map(|m| m.target_id.clone()) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_single_room_mention() { + let input = "hey check out @[room:abc123:general]"; + let mentions = parse_mentions(input); + assert_eq!(mentions.len(), 1); + assert_eq!(mentions[0].mention_type, "room"); + assert_eq!(mentions[0].target_id, "abc123"); + assert_eq!(mentions[0].label, "general"); + } + + #[test] + fn test_parse_single_repo_mention() { + let input = "look at @[repo:my-repo:my-repo]"; + let mentions = parse_mentions(input); + assert_eq!(mentions.len(), 1); + assert_eq!(mentions[0].mention_type, "repo"); + assert_eq!(mentions[0].target_id, "my-repo"); + } + + #[test] + fn test_parse_multiple_mentions() { + let input = + "compare @[repo:backend:backend] with @[repo:frontend:frontend]"; + let mentions = parse_mentions(input); + assert_eq!(mentions.len(), 2); + assert_eq!(mentions[0].target_id, "backend"); + assert_eq!(mentions[1].target_id, "frontend"); + } + + #[test] + fn test_deduplicate() { + let input = "look at @[repo:a:a] and also @[repo:a:a] please"; + let mentions = parse_mentions(input); + assert_eq!(mentions.len(), 1); + } + + #[test] + fn test_no_mentions() { + let input = "hello world no mentions here"; + let mentions = parse_mentions(input); + assert_eq!(mentions.len(), 0); + } + + #[test] + fn test_mixed_mentions() { + let input = "@[user:abc:John] and @[room:xyz:general] and @[repo:r:r]"; + let mentions = parse_mentions(input); + assert_eq!(mentions.len(), 3); + } + + #[test] + fn test_incomplete_mention_ignored() { + let input = "this @[incomplete is just text"; + let mentions = parse_mentions(input); + assert_eq!(mentions.len(), 0); + } + + #[test] + fn test_empty_input() { + let mentions = parse_mentions(""); + assert_eq!(mentions.len(), 0); + } + + #[test] + fn test_extract_mention_ids() { + let input = "@[repo:a:a] and @[room:b:general] and @[repo:c:c]"; + let mentions = parse_mentions(input); + let repo_ids = extract_mention_ids(&mentions, "repo"); + assert_eq!(repo_ids, vec!["a", "c"]); + } } diff --git a/lib/channel/security.rs b/lib/channel/security.rs index 97a59eb..fb37f1e 100644 --- a/lib/channel/security.rs +++ b/lib/channel/security.rs @@ -1,4 +1,5 @@ use std::time::Duration; +use track::CounterVec; use uuid::Uuid; use crate::{ChannelError, ChannelResult}; @@ -22,6 +23,26 @@ pub struct RateLimiter { cache: cache::AppCache, max_requests: u32, window: Duration, + metrics: Option, +} + +#[derive(Clone)] +struct RateLimiterMetrics { + outcomes: CounterVec, +} + +impl RateLimiterMetrics { + fn new(registry: &track::MetricsRegistry) -> Self { + Self { + outcomes: registry + .register_counter_vec( + "rate_limiter_decisions_total", + "Rate limiter decisions", + &["action", "outcome"], + ) + .expect("failed to register rate_limiter_decisions_total"), + } + } } impl RateLimiter { @@ -30,6 +51,7 @@ impl RateLimiter { cache, max_requests: 100, window: Duration::from_secs(60), + metrics: None, } } @@ -42,9 +64,14 @@ impl RateLimiter { cache, max_requests, window, + metrics: None, } } + pub fn set_metrics(&mut self, registry: &track::MetricsRegistry) { + self.metrics = Some(RateLimiterMetrics::new(registry)); + } + pub async fn check_rate_limit( &self, user_id: Uuid, @@ -65,7 +92,12 @@ impl RateLimiter { .await .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?; - Ok(allowed == 1) + let is_allowed = allowed == 1; + if let Some(m) = &self.metrics { + let outcome = if is_allowed { "allowed" } else { "blocked" }; + m.outcomes.with_label_values(&[action, outcome]).inc(); + } + Ok(is_allowed) } pub async fn get_remaining( diff --git a/lib/model/channel/channel_article.rs b/lib/model/channel/channel_article.rs index e7a699f..131f0d3 100644 --- a/lib/model/channel/channel_article.rs +++ b/lib/model/channel/channel_article.rs @@ -3,7 +3,6 @@ use serde::{Deserialize, Serialize}; use sqlx::FromRow; use uuid::Uuid; - #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] pub struct ChannelArticleModel { pub id: Uuid, @@ -26,7 +25,6 @@ pub struct ChannelArticleModel { pub deleted_at: Option>, } - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CreateArticlePayload { pub channel: Uuid, @@ -39,7 +37,6 @@ pub struct CreateArticlePayload { pub status: Option, } - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct UpdateArticlePayload { pub title: Option, @@ -52,7 +49,6 @@ pub struct UpdateArticlePayload { pub status: Option, } - #[derive(Debug, Clone, Serialize, Deserialize, FromRow)] pub struct ChannelArticleCard { pub id: Uuid, diff --git a/lib/model/channel/mod.rs b/lib/model/channel/mod.rs index 47fc29e..999e61d 100644 --- a/lib/model/channel/mod.rs +++ b/lib/model/channel/mod.rs @@ -15,19 +15,19 @@ pub mod room_server_label; pub mod room_threads; pub mod user_room_state; -pub use message_read::MessageReadModel; -pub use message_star::MessageStarModel; pub use channel::ChannelModel; pub use channel::ChannelType; -pub use channel_article::ChannelArticleModel; pub use channel_article::ChannelArticleCard; +pub use channel_article::ChannelArticleModel; pub use channel_article::CreateArticlePayload; pub use channel_article::UpdateArticlePayload; -pub use channel_article_interact::ArticleLikeModel; -pub use channel_article_interact::ArticleCommentModel; pub use channel_article_interact::ArticleCommentItem; pub use channel_article_interact::ArticleCommentList; +pub use channel_article_interact::ArticleCommentModel; +pub use channel_article_interact::ArticleLikeModel; pub use channel_article_interact::CreateCommentPayload; +pub use message_read::MessageReadModel; +pub use message_star::MessageStarModel; pub use room_attachments::RoomAttachmentModel; pub use room_categories::RoomCategoryModel; pub use room_mention::RoomMentionModel; diff --git a/lib/model/channel/room_attachments.rs b/lib/model/channel/room_attachments.rs index 58cb77e..d396eaf 100644 --- a/lib/model/channel/room_attachments.rs +++ b/lib/model/channel/room_attachments.rs @@ -6,7 +6,7 @@ use uuid::Uuid; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] pub struct RoomAttachmentModel { pub id: Uuid, - pub message: Uuid, + pub message: Option, pub seq: i64, pub file_name: String, pub content_type: Option, diff --git a/lib/model/channel/room_mention.rs b/lib/model/channel/room_mention.rs index 2e73bf0..2e650ac 100644 --- a/lib/model/channel/room_mention.rs +++ b/lib/model/channel/room_mention.rs @@ -9,6 +9,6 @@ pub struct RoomMentionModel { pub message: Uuid, pub seq: i64, pub mention_type: String, - pub target_id: Uuid, + pub target_id: String, pub created_at: DateTime, } diff --git a/lib/model/lib.rs b/lib/model/lib.rs index 2aecdbe..b49cddf 100644 --- a/lib/model/lib.rs +++ b/lib/model/lib.rs @@ -2,12 +2,12 @@ use db::AppDatabase; pub mod agent; pub mod ai; +pub mod channel; pub mod issues; pub mod logs; pub mod notify; pub mod pull_request; pub mod repos; -pub mod channel; pub mod system; pub mod users; pub mod workspace;