use chrono::Utc; use uuid::Uuid; use crate::event::{RoomInfo, UserInfo, message, thread}; use crate::{ ChannelBus, ChannelError, ChannelResult, pagination::{MessagePagination, PaginationDirection, PaginationParams}, }; use super::{MAX_MESSAGES_PER_REQUEST, MAX_TEXT_LEN}; use super::WsOutEvent; use super::WsHandler; impl WsHandler { /// Count non-deleted sibling replies to the same parent message. async fn count_sibling_replies( bus: &ChannelBus, parent_id: Uuid, ) -> ChannelResult { let (count,): (i64,) = db::sqlx::query_as( "SELECT COUNT(*) FROM room_message WHERE parent = $1 AND deleted_at IS NULL", ) .bind(parent_id) .fetch_one(bus.inner.db.reader()) .await?; Ok(count) } /// Walk the reply parent chain and return (root_message_id, root_message_seq, chain_depth). async fn reply_chain_info( bus: &ChannelBus, parent_id: Uuid, ) -> ChannelResult<(Uuid, i64, i32)> { let rows: Vec<(Uuid, i64, i32)> = db::sqlx::query_as( r#"WITH RECURSIVE chain AS ( SELECT id, parent, seq, 1 AS depth FROM room_message WHERE id = $1 AND deleted_at IS NULL UNION ALL SELECT m.id, m.parent, m.seq, c.depth + 1 FROM room_message m JOIN chain c ON m.id = c.parent WHERE m.deleted_at IS NULL ) SELECT id, seq, depth FROM chain ORDER BY depth DESC"#, ) .bind(parent_id) .fetch_all(bus.inner.db.reader()) .await?; let root_id = rows.first().map(|r| r.0).unwrap_or(parent_id); let root_seq = rows.first().map(|r| r.1).unwrap_or(0); let max_depth = rows.first().map(|r| r.2).unwrap_or(1); Ok((root_id, root_seq, max_depth)) } /// Check if any message in the reply parent chain already belongs to a thread. async fn find_thread_in_chain( bus: &ChannelBus, parent_id: Uuid, ) -> ChannelResult> { let row: Option<(Uuid,)> = db::sqlx::query_as( r#"WITH RECURSIVE chain AS ( SELECT id, parent, thread FROM room_message WHERE id = $1 AND deleted_at IS NULL UNION ALL SELECT m.id, m.parent, m.thread FROM room_message m JOIN chain c ON m.id = c.parent WHERE m.deleted_at IS NULL ) SELECT thread FROM chain WHERE thread IS NOT NULL LIMIT 1"#, ) .bind(parent_id) .fetch_optional(bus.inner.db.reader()) .await?; Ok(row.map(|r| r.0)) } /// Update all messages in the reply parent chain to point to the given thread. async fn attach_chain_to_thread( bus: &ChannelBus, parent_id: Uuid, thread_id: Uuid, ) -> ChannelResult<()> { db::sqlx::query( r#"WITH RECURSIVE chain AS ( SELECT id FROM room_message WHERE id = $1 AND deleted_at IS NULL UNION ALL SELECT m.id FROM room_message m JOIN chain c ON m.id = c.parent WHERE m.deleted_at IS NULL ) UPDATE room_message SET thread = $2, updated_at = now() WHERE id IN (SELECT id FROM chain) AND thread IS NULL"#, ) .bind(parent_id) .bind(thread_id) .execute(bus.inner.db.writer()) .await?; Ok(()) } pub(super) async fn message_create( bus: &ChannelBus, user_id: Uuid, room: Uuid, content: String, content_type: Option, thread: Option, in_reply_to: Option, ) -> ChannelResult> { Self::ensure_room_access(bus, user_id, room).await?; if content.len() > MAX_TEXT_LEN { return Err(ChannelError::Validation( "message exceeds maximum length".to_string(), )); } if let Some(parent_message) = in_reply_to { Self::ensure_message_in_room(bus, room, parent_message).await?; } // ── Auto-thread logic ────────────────────────────────────────── let mut events: Vec = Vec::new(); let effective_thread: Option = if let Some(ref parent_id) = in_reply_to { if thread.is_some() { thread } else { let existing = Self::find_thread_in_chain(bus, *parent_id).await?; if let Some(tid) = existing { Some(tid) } else { let sibling_count = Self::count_sibling_replies(bus, *parent_id).await?; let (root_id, root_seq, chain_depth) = Self::reply_chain_info(bus, *parent_id).await?; let should_create = sibling_count >= 3 || chain_depth >= 5; if should_create { let seq = bus.inner.seq.seq(room).await?; let thread_row = db::sqlx::query_as::<_, model::room::RoomThreadModel>( "INSERT INTO room_thread (room, seq, starter_message, title, created_by, created_at, updated_at) \ VALUES ($1, $2, $3, '', $4, now(), now()) \ RETURNING id, room, seq, starter_message, title, created_by, archived, locked, \ last_message_at, created_at, updated_at, archived_at", ) .bind(room) .bind(seq) .bind(root_id) // UUID of the root message .bind(user_id) .fetch_one(bus.inner.db.writer()) .await?; let new_thread_id = thread_row.id; Self::attach_chain_to_thread(bus, *parent_id, new_thread_id).await?; let tc_room = bus.lookup_room(room).await .unwrap_or_else(|_| RoomInfo::unknown(room)); let created_by = bus.lookup_user(user_id).await .unwrap_or_else(|_| UserInfo::unknown(user_id)); let data = thread::ThreadCreatedService { id: new_thread_id, room: tc_room, parent: root_seq, created_by, participants: serde_json::Value::Null, created_at: thread_row.created_at, }; bus.publish_room_event(room, "thread.created", &data).await?; events.push(WsOutEvent::ThreadCreated { room: data.room.clone(), data, }); Some(new_thread_id) } else { None } } } } else { thread }; // ── End auto-thread logic ────────────────────────────────────── if let Some(thread_id) = effective_thread { let thread_row: Option<(bool, bool)> = db::sqlx::query_as( "SELECT locked, archived FROM room_thread WHERE id = $1 AND room = $2", ) .bind(thread_id) .bind(room) .fetch_optional(bus.inner.db.reader()) .await?; match thread_row { None => return Err(ChannelError::RoomNotFound), Some((locked, archived)) => { if locked { return Err(ChannelError::Validation( "thread is resolved".to_string(), )); } if archived { return Err(ChannelError::Validation( "thread is archived".to_string(), )); } } } } let seq = bus.inner.seq.seq(room).await?; let sender = bus.lookup_user(user_id).await?; let sender_for_response = sender.clone(); let row = db::sqlx::query_as::<_, model::room::RoomMessageModel>( "INSERT INTO room_message (room, seq, thread, parent, author, content, content_type) \ VALUES ($1, $2, $3, $4, $5, $6, $7) \ RETURNING id, room, seq, thread, parent, author, content, content_type, pinned, \ system_type, metadata, edited_at, created_at, updated_at, deleted_at", ) .bind(room) .bind(seq) .bind(effective_thread) .bind(in_reply_to) .bind(user_id) .bind(content) .bind(content_type.unwrap_or_else(|| "text".to_string())) .fetch_one(bus.inner.db.writer()) .await?; bus.publish_room_message( row.clone(), Some(sender), ).await?; let msg_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); events.push(WsOutEvent::MessageNew { room: msg_room.clone(), data: message::MessageNewService { id: row.id, seq: row.seq, room: msg_room, sender_type: "user".to_string(), sender: sender_for_response, thread: row.thread, in_reply_to: row.parent, content: row.content, content_type: row.content_type, pinned: row.pinned, system_type: row.system_type, metadata: row.metadata, thinking_content: None, thinking_is_chunked: None, send_at: row.created_at, reactions: vec![], }, }); Ok(events.into_iter().find(|e| matches!(e, WsOutEvent::MessageNew { .. }))) } pub(super) async fn message_update( bus: &ChannelBus, user_id: Uuid, message_id: Uuid, content: String, ) -> ChannelResult> { if content.len() > MAX_TEXT_LEN { return Err(ChannelError::Validation( "message exceeds maximum length".to_string(), )); } let room_id: (Uuid,) = db::sqlx::query_as("SELECT room FROM room_message WHERE id = $1 AND deleted_at IS NULL") .bind(message_id) .fetch_optional(bus.inner.db.reader()) .await? .ok_or(ChannelError::RoomNotFound)?; Self::ensure_room_access(bus, user_id, room_id.0).await?; let old = Self::load_message(bus, message_id).await?; if old.author != user_id { return Err(ChannelError::Unauthorized); } let row = db::sqlx::query_as::<_, model::room::RoomMessageModel>( "UPDATE room_message SET content = $2, edited_at = now(), updated_at = now() \ WHERE id = $1 AND deleted_at IS NULL \ RETURNING id, room, seq, thread, parent, author, content, content_type, pinned, \ system_type, metadata, edited_at, created_at, updated_at, deleted_at", ) .bind(message_id) .bind(content) .fetch_one(bus.inner.db.writer()) .await?; db::sqlx::query( "INSERT INTO room_message_edit_history (message, seq, editor, old_content, new_content) \ VALUES ($1, $2, $3, $4, $5)", ) .bind(message_id) .bind(row.seq) .bind(user_id) .bind(old.content) .bind(row.content.clone()) .execute(bus.inner.db.writer()) .await?; let sender = bus.lookup_user(row.author).await.unwrap_or_else(|_| UserInfo::unknown(row.author)); let room = bus.lookup_room(row.room).await.unwrap_or_else(|_| RoomInfo::unknown(row.room)); let data = message::MessageEditedService { id: row.id, seq: row.seq, room, sender, content: row.content, edited_at: row.edited_at.unwrap_or_else(Utc::now), }; bus.publish_room_event(row.room, "message.edited", &data) .await?; Ok(Some(WsOutEvent::MessageEdited { room: data.room.clone(), data, })) } pub(super) async fn message_revoke( bus: &ChannelBus, user_id: Uuid, message_id: Uuid, ) -> ChannelResult> { let room_id: (Uuid,) = db::sqlx::query_as("SELECT room FROM room_message WHERE id = $1 AND deleted_at IS NULL") .bind(message_id) .fetch_optional(bus.inner.db.reader()) .await? .ok_or(ChannelError::RoomNotFound)?; Self::ensure_room_access(bus, user_id, room_id.0).await?; let old = Self::load_message(bus, message_id).await?; if old.author != user_id { return Err(ChannelError::Unauthorized); } if let Some(window_secs) = bus.inner.config.revoke_window_secs { let elapsed = Utc::now().signed_duration_since(old.created_at); if elapsed.num_seconds() > window_secs as i64 { return Err(ChannelError::Validation( "message revoke window expired".to_string(), )); } } let row = db::sqlx::query_as::<_, model::room::RoomMessageModel>( "UPDATE room_message SET deleted_at = now(), updated_at = now() \ WHERE id = $1 AND deleted_at IS NULL \ RETURNING id, room, seq, thread, parent, author, content, content_type, pinned, \ system_type, metadata, edited_at, created_at, updated_at, deleted_at", ) .bind(message_id) .fetch_one(bus.inner.db.writer()) .await?; let revoked_by = bus.lookup_user(user_id).await.unwrap_or_else(|_| UserInfo::unknown(user_id)); let room = bus.lookup_room(row.room).await.unwrap_or_else(|_| RoomInfo::unknown(row.room)); let data = message::MessageRevokedService { id: row.id, seq: row.seq, room, revoked_by, revoked_at: row.deleted_at.unwrap_or_else(Utc::now), }; bus.publish_room_event(row.room, "message.revoked", &data) .await?; Ok(Some(WsOutEvent::MessageRevoked { room: data.room.clone(), data, })) } pub(super) async fn load_message( bus: &ChannelBus, message_id: Uuid, ) -> ChannelResult { db::sqlx::query_as::<_, model::room::RoomMessageModel>( "SELECT id, room, seq, thread, parent, author, content, content_type, pinned, \ system_type, metadata, edited_at, created_at, updated_at, deleted_at \ FROM room_message WHERE id = $1 AND deleted_at IS NULL", ) .bind(message_id) .fetch_one(bus.inner.db.reader()) .await .map_err(|e| match e { db::sqlx::Error::RowNotFound => ChannelError::RoomNotFound, other => ChannelError::Database(other), }) } pub(super) async fn message_list( bus: &ChannelBus, user_id: Uuid, room: Uuid, before_seq: Option, after_seq: Option, limit: Option, ) -> ChannelResult> { Self::ensure_room_access(bus, user_id, room).await?; let limit = limit.unwrap_or(MAX_MESSAGES_PER_REQUEST); let direction = if before_seq.is_some() { PaginationDirection::Before } else { PaginationDirection::After }; let cursor = before_seq .map(|s| s.to_string()) .or(after_seq.map(|s| s.to_string())); let pagination = MessagePagination::new(bus.inner.db.clone()); let page = pagination .get_messages(PaginationParams { room_id: room, limit, cursor, direction, }) .await?; let mut page_messages = page.messages; if before_seq.is_some() || (before_seq.is_none() && after_seq.is_none()) { page_messages.reverse(); } let message_ids: Vec = page_messages.iter().map(|m| m.id).collect(); let reactions = Self::reaction_groups_for_messages(bus, user_id, &message_ids) .await .unwrap_or_default(); let list_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); let mut messages: Vec = Vec::with_capacity(page_messages.len()); for m in page_messages { let sender = bus.lookup_user(m.sender_id).await .unwrap_or_else(|_| UserInfo::unknown(m.sender_id)); messages.push(message::MessageNewService { id: m.id, seq: m.seq, room: list_room.clone(), sender_type: "user".to_string(), sender, thread: m.thread, in_reply_to: m.parent, content: m.content, content_type: m.content_type, pinned: m.pinned, system_type: m.system_type, metadata: m.metadata, thinking_content: None, thinking_is_chunked: None, send_at: m.send_at, reactions: reactions.get(&m.id).cloned().unwrap_or_default(), }); } let total = messages.len() as i64; Ok(Some(WsOutEvent::MessageList { room: list_room.clone(), data: message::MessageListService { room: list_room, messages, total, }, })) } pub(super) async fn message_around( bus: &ChannelBus, user_id: Uuid, room: Uuid, seq: i64, limit: Option, thread: Option, ) -> ChannelResult> { Self::ensure_room_access(bus, user_id, room).await?; let size = limit .unwrap_or(MAX_MESSAGES_PER_REQUEST) .min(MAX_MESSAGES_PER_REQUEST) as i64; let rows = if let Some(tid) = thread { // Pre-fetch starter_message to avoid per-row subquery let starter: Option<(Uuid,)> = db::sqlx::query_as( "SELECT starter_message FROM room_thread WHERE id = $1", ) .bind(tid) .fetch_optional(bus.inner.db.reader()) .await?; let starter_id = starter.map(|r| r.0); db::sqlx::query_as::<_, model::room::RoomMessageModel>( "(SELECT id, room, seq, thread, parent, author, content, content_type, pinned, \ system_type, metadata, edited_at, created_at, updated_at, deleted_at \ FROM room_message \ WHERE room = $1 AND seq < $2 AND deleted_at IS NULL \ AND (thread = $4 OR ($5 IS NOT NULL AND id = $5)) \ ORDER BY seq DESC LIMIT $3) \ UNION ALL \ (SELECT id, room, seq, thread, parent, author, content, content_type, pinned, \ system_type, metadata, edited_at, created_at, updated_at, deleted_at \ FROM room_message \ WHERE room = $1 AND seq >= $2 AND deleted_at IS NULL \ AND (thread = $4 OR ($5 IS NOT NULL AND id = $5)) \ ORDER BY seq ASC LIMIT $3) \ ORDER BY seq ASC", ) .bind(room) .bind(seq) .bind(size) .bind(tid) .bind(starter_id) .fetch_all(bus.inner.db.reader()) .await? } else { db::sqlx::query_as::<_, model::room::RoomMessageModel>( "(SELECT id, room, seq, thread, parent, author, content, content_type, pinned, \ system_type, metadata, edited_at, created_at, updated_at, deleted_at \ FROM room_message \ WHERE room = $1 AND seq < $2 AND deleted_at IS NULL AND thread IS NULL \ ORDER BY seq DESC LIMIT $3) \ UNION ALL \ (SELECT id, room, seq, thread, parent, author, content, content_type, pinned, \ system_type, metadata, edited_at, created_at, updated_at, deleted_at \ FROM room_message \ WHERE room = $1 AND seq >= $2 AND deleted_at IS NULL AND thread IS NULL \ ORDER BY seq ASC LIMIT $3) \ ORDER BY seq ASC", ) .bind(room) .bind(seq) .bind(size) .fetch_all(bus.inner.db.reader()) .await? }; let author_ids: Vec = rows.iter().map(|r| r.author).collect(); let user_map = bus.lookup_users(&author_ids).await.unwrap_or_default(); let message_ids: Vec = rows.iter().map(|r| r.id).collect(); let reactions = Self::reaction_groups_for_messages(bus, user_id, &message_ids) .await .unwrap_or_default(); let around_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); let messages = rows .into_iter() .map(|r| { let sender = user_map .get(&r.author) .cloned() .unwrap_or_else(|| UserInfo::unknown(r.author)); message::MessageNewService { id: r.id, seq: r.seq, room: around_room.clone(), sender_type: "user".to_string(), sender, thread: r.thread, in_reply_to: r.parent, content: r.content, content_type: r.content_type, pinned: r.pinned, system_type: r.system_type, metadata: r.metadata, thinking_content: None, thinking_is_chunked: None, send_at: r.created_at, reactions: reactions.get(&r.id).cloned().unwrap_or_default(), } }) .collect::>(); let total = messages.len() as i64; Ok(Some(WsOutEvent::MessageList { room: around_room.clone(), data: message::MessageListService { room: around_room, messages, total, }, })) } pub(super) async fn missed_messages( bus: &ChannelBus, user_id: Uuid, room: Uuid, after_seq: i64, limit: Option, ) -> ChannelResult> { Self::ensure_room_access(bus, user_id, room).await?; let limit = limit.unwrap_or(MAX_MESSAGES_PER_REQUEST as i64).max(0) as usize; let messages = bus .inner .reconnect .get_missed_messages(room, after_seq) .await?; let author_ids: Vec = messages.iter().map(|m| m.sender_id).collect(); let message_ids: Vec = messages.iter().map(|m| m.message_id).collect(); let user_map = bus.lookup_users(&author_ids).await.unwrap_or_default(); let reactions = Self::reaction_groups_for_messages(bus, user_id, &message_ids) .await .unwrap_or_default(); let missed_room = bus.lookup_room(room).await.unwrap_or_else(|_| RoomInfo::unknown(room)); let messages = messages .into_iter() .take(limit) .map(|m| { let sender = user_map .get(&m.sender_id) .cloned() .unwrap_or_else(|| UserInfo::unknown(m.sender_id)); message::MessageNewService { id: m.message_id, seq: m.seq, room: missed_room.clone(), sender_type: "user".to_string(), sender, thread: m.thread, in_reply_to: m.parent, content: m.content, content_type: m.content_type, pinned: m.pinned, system_type: m.system_type, metadata: m.metadata, thinking_content: None, thinking_is_chunked: None, send_at: m.send_at, reactions: reactions.get(&m.message_id).cloned().unwrap_or_default(), } }) .collect::>(); let data = message::MessageListService { room: missed_room.clone(), total: messages.len() as i64, messages, }; Ok(Some(WsOutEvent::MessageList { room: missed_room, data, })) } }