// gitdataai/lib/channel/http/handler/message.rs
// 2026-05-30 01:38:40 +08:00
//
// 591 lines
// 23 KiB
// Rust

use chrono::Utc;
use uuid::Uuid;
use crate::event::{RoomInfo, UserInfo, message, thread};
use crate::{
ChannelBus, ChannelError, ChannelResult,
pagination::{MessagePagination, PaginationDirection, PaginationParams},
};
use super::{MAX_MESSAGES_PER_REQUEST, MAX_TEXT_LEN};
use super::WsOutEvent;
use super::WsHandler;
/// Minimum number of existing (non-deleted) sibling replies to a parent that
/// triggers automatic thread creation for a new reply.
const AUTO_THREAD_SIBLING_THRESHOLD: i64 = 3;

/// Minimum reply-chain depth (the parent itself counts as depth 1) that
/// triggers automatic thread creation for a new reply.
const AUTO_THREAD_DEPTH_THRESHOLD: i32 = 5;

impl WsHandler {
    /// Count non-deleted replies whose `parent` is `parent_id`.
    async fn count_sibling_replies(bus: &ChannelBus, parent_id: Uuid) -> ChannelResult<i64> {
        let (count,): (i64,) = db::sqlx::query_as(
            "SELECT COUNT(*) FROM room_message WHERE parent = $1 AND deleted_at IS NULL",
        )
        .bind(parent_id)
        .fetch_one(bus.inner.db.reader())
        .await?;
        Ok(count)
    }

    /// Walk the reply parent chain upward from `parent_id` and return
    /// `(root_message_id, root_message_seq, chain_depth)`.
    ///
    /// `chain_depth` counts `parent_id` itself as depth 1. The walk stops at the
    /// first deleted or missing ancestor. If `parent_id` itself is deleted or
    /// missing, `(parent_id, 0, 1)` is returned.
    async fn reply_chain_info(
        bus: &ChannelBus,
        parent_id: Uuid,
    ) -> ChannelResult<(Uuid, i64, i32)> {
        // Only the deepest row (the root) is used, so let the database select it
        // instead of shipping the whole chain back to the client.
        let root: Option<(Uuid, i64, i32)> = db::sqlx::query_as(
            r#"WITH RECURSIVE chain AS (
                SELECT id, parent, seq, 1 AS depth
                FROM room_message
                WHERE id = $1 AND deleted_at IS NULL
                UNION ALL
                SELECT m.id, m.parent, m.seq, c.depth + 1
                FROM room_message m
                JOIN chain c ON m.id = c.parent
                WHERE m.deleted_at IS NULL
            )
            SELECT id, seq, depth FROM chain ORDER BY depth DESC LIMIT 1"#,
        )
        .bind(parent_id)
        .fetch_optional(bus.inner.db.reader())
        .await?;
        Ok(root.unwrap_or((parent_id, 0, 1)))
    }

    /// Return the thread of the nearest message in the reply parent chain
    /// (starting at `parent_id` itself) that already belongs to a thread.
    async fn find_thread_in_chain(
        bus: &ChannelBus,
        parent_id: Uuid,
    ) -> ChannelResult<Option<Uuid>> {
        // `c.thread IS NULL` stops the walk as soon as a threaded ancestor is
        // reached; ORDER BY depth makes "nearest ancestor" deterministic.
        let row: Option<(Uuid,)> = db::sqlx::query_as(
            r#"WITH RECURSIVE chain AS (
                SELECT id, parent, thread, 1 AS depth
                FROM room_message
                WHERE id = $1 AND deleted_at IS NULL
                UNION ALL
                SELECT m.id, m.parent, m.thread, c.depth + 1
                FROM room_message m
                JOIN chain c ON m.id = c.parent
                WHERE m.deleted_at IS NULL AND c.thread IS NULL
            )
            SELECT thread FROM chain WHERE thread IS NOT NULL ORDER BY depth LIMIT 1"#,
        )
        .bind(parent_id)
        .fetch_optional(bus.inner.db.reader())
        .await?;
        Ok(row.map(|(thread_id,)| thread_id))
    }

    /// Point every non-deleted message in the reply parent chain of `parent_id`
    /// (including `parent_id`) at `thread_id`, leaving messages that already
    /// have a thread untouched.
    async fn attach_chain_to_thread(
        bus: &ChannelBus,
        parent_id: Uuid,
        thread_id: Uuid,
    ) -> ChannelResult<()> {
        db::sqlx::query(
            r#"WITH RECURSIVE chain AS (
                SELECT id FROM room_message
                WHERE id = $1 AND deleted_at IS NULL
                UNION ALL
                SELECT m.id FROM room_message m
                JOIN chain c ON m.id = c.parent
                WHERE m.deleted_at IS NULL
            )
            UPDATE room_message SET thread = $2, updated_at = now()
            WHERE id IN (SELECT id FROM chain) AND thread IS NULL"#,
        )
        .bind(parent_id)
        .bind(thread_id)
        .execute(bus.inner.db.writer())
        .await?;
        Ok(())
    }

    /// Fail with `ChannelError::RoomNotFound` unless `thread_id` is a thread of `room`.
    async fn check_reply_thread_in_room(
        bus: &ChannelBus,
        room: Uuid,
        thread_id: Uuid,
    ) -> ChannelResult<()> {
        let exists: Option<(Uuid,)> =
            db::sqlx::query_as("SELECT id FROM room_thread WHERE id = $1 AND room = $2")
                .bind(thread_id)
                .bind(room)
                .fetch_optional(bus.inner.db.reader())
                .await?;
        exists.map(|_| ()).ok_or(ChannelError::RoomNotFound)
    }

    /// Pick the thread for a reply to `parent_id` when the client did not name one.
    ///
    /// Reuses the nearest thread already present in the parent chain. Otherwise,
    /// once the parent has enough sibling replies or the chain is deep enough, it
    /// creates a thread rooted at the chain's root message. It then attaches the
    /// chain to that thread and publishes `thread.created` to the room bus.
    /// Returns `None` when the reply should stay unthreaded.
    async fn auto_thread_for_reply(
        bus: &ChannelBus,
        user_id: Uuid,
        room: Uuid,
        parent_id: Uuid,
        room_info: &RoomInfo,
        creator: &UserInfo,
    ) -> ChannelResult<Option<Uuid>> {
        if let Some(existing) = Self::find_thread_in_chain(bus, parent_id).await? {
            return Ok(Some(existing));
        }
        let sibling_count = Self::count_sibling_replies(bus, parent_id).await?;
        let (root_id, root_seq, chain_depth) = Self::reply_chain_info(bus, parent_id).await?;
        if sibling_count < AUTO_THREAD_SIBLING_THRESHOLD
            && chain_depth < AUTO_THREAD_DEPTH_THRESHOLD
        {
            return Ok(None);
        }

        let seq = bus.inner.seq.seq(room).await?;
        let thread_row = db::sqlx::query_as::<_, model::room::RoomThreadModel>(
            "INSERT INTO room_thread (room, seq, starter_message, title, created_by, created_at, updated_at) \
             VALUES ($1, $2, $3, '', $4, now(), now()) \
             RETURNING id, room, seq, starter_message, title, created_by, archived, locked, \
             last_message_at, created_at, updated_at, archived_at",
        )
        .bind(room)
        .bind(seq)
        .bind(root_id) // UUID of the root message
        .bind(user_id)
        .fetch_one(bus.inner.db.writer())
        .await?;
        let new_thread_id = thread_row.id;
        Self::attach_chain_to_thread(bus, parent_id, new_thread_id).await?;

        let data = thread::ThreadCreatedService {
            id: new_thread_id,
            room: room_info.clone(),
            parent: root_seq,
            created_by: creator.clone(),
            participants: serde_json::Value::Null,
            created_at: thread_row.created_at,
        };
        // Clients learn about the new thread through the room bus; the handler's
        // own return value only carries the new message.
        bus.publish_room_event(room, "thread.created", &data).await?;
        Ok(Some(new_thread_id))
    }

    /// Create a message in `room` on behalf of `user_id` and publish it.
    ///
    /// Validates room access, the content length (`MAX_TEXT_LEN`, measured in
    /// bytes), that `in_reply_to` is a message of the room, and that an explicit
    /// `thread` belongs to the room. A reply without an explicit thread may be
    /// auto-threaded (see `auto_thread_for_reply`). Returns the `MessageNew` event
    /// for the sender.
    pub(super) async fn message_create(
        bus: &ChannelBus,
        user_id: Uuid,
        room: Uuid,
        content: String,
        content_type: Option<String>,
        thread: Option<Uuid>,
        in_reply_to: Option<Uuid>,
    ) -> ChannelResult<Option<WsOutEvent>> {
        Self::ensure_room_access(bus, user_id, room).await?;
        if content.len() > MAX_TEXT_LEN {
            return Err(ChannelError::Validation(
                "message exceeds maximum length".to_string(),
            ));
        }
        if let Some(parent_message) = in_reply_to {
            Self::ensure_message_in_room(bus, room, parent_message).await?;
        }
        if let Some(thread_id) = thread {
            Self::check_reply_thread_in_room(bus, room, thread_id).await?;
        }
        // Resolve the sender before any rows are written, so that a failed lookup
        // cannot leave an orphan auto-created thread behind.
        let sender = bus.lookup_user(user_id).await?;
        let room_info = bus
            .lookup_room(room)
            .await
            .unwrap_or_else(|_| RoomInfo::unknown(room));

        let effective_thread = match (thread, in_reply_to) {
            (Some(thread_id), _) => Some(thread_id),
            (None, Some(parent_id)) => {
                Self::auto_thread_for_reply(bus, user_id, room, parent_id, &room_info, &sender)
                    .await?
            }
            (None, None) => None,
        };

        let seq = bus.inner.seq.seq(room).await?;
        let row = db::sqlx::query_as::<_, model::room::RoomMessageModel>(
            "INSERT INTO room_message (room, seq, thread, parent, author, content, content_type) \
             VALUES ($1, $2, $3, $4, $5, $6, $7) \
             RETURNING id, room, seq, thread, parent, author, content, content_type, pinned, \
             system_type, metadata, edited_at, created_at, updated_at, deleted_at",
        )
        .bind(room)
        .bind(seq)
        .bind(effective_thread)
        .bind(in_reply_to)
        .bind(user_id)
        .bind(content)
        .bind(content_type.unwrap_or_else(|| "text".to_string()))
        .fetch_one(bus.inner.db.writer())
        .await?;
        bus.publish_room_message(row.clone(), Some(sender.clone()))
            .await?;

        Ok(Some(WsOutEvent::MessageNew {
            room: room_info.clone(),
            data: message::MessageNewService {
                id: row.id,
                seq: row.seq,
                room: room_info,
                sender_type: "user".to_string(),
                sender,
                thread: row.thread,
                in_reply_to: row.parent,
                content: row.content,
                content_type: row.content_type,
                pinned: row.pinned,
                system_type: row.system_type,
                metadata: row.metadata,
                thinking_content: None,
                thinking_is_chunked: None,
                send_at: row.created_at,
                reactions: vec![],
            },
        }))
    }

    /// Replace the content of `message_id`, which must be authored by `user_id`.
    /// Records the old and new content in the edit history and publishes
    /// `message.edited`.
    pub(super) async fn message_update(
        bus: &ChannelBus,
        user_id: Uuid,
        message_id: Uuid,
        content: String,
    ) -> ChannelResult<Option<WsOutEvent>> {
        if content.len() > MAX_TEXT_LEN {
            return Err(ChannelError::Validation(
                "message exceeds maximum length".to_string(),
            ));
        }
        // One lookup yields the room (for the access check), the author, and the
        // previous content. A missing or deleted message maps to RoomNotFound.
        let old = Self::load_message(bus, message_id).await?;
        Self::ensure_room_access(bus, user_id, old.room).await?;
        if old.author != user_id {
            return Err(ChannelError::Unauthorized);
        }
        let row = db::sqlx::query_as::<_, model::room::RoomMessageModel>(
            "UPDATE room_message SET content = $2, edited_at = now(), updated_at = now() \
             WHERE id = $1 AND deleted_at IS NULL \
             RETURNING id, room, seq, thread, parent, author, content, content_type, pinned, \
             system_type, metadata, edited_at, created_at, updated_at, deleted_at",
        )
        .bind(message_id)
        .bind(content)
        .fetch_one(bus.inner.db.writer())
        .await?;
        db::sqlx::query(
            "INSERT INTO room_message_edit_history (message, seq, editor, old_content, new_content) \
             VALUES ($1, $2, $3, $4, $5)",
        )
        .bind(message_id)
        .bind(row.seq)
        .bind(user_id)
        .bind(old.content)
        .bind(row.content.clone())
        .execute(bus.inner.db.writer())
        .await?;
        let sender = bus
            .lookup_user(row.author)
            .await
            .unwrap_or_else(|_| UserInfo::unknown(row.author));
        let room = bus
            .lookup_room(row.room)
            .await
            .unwrap_or_else(|_| RoomInfo::unknown(row.room));
        let data = message::MessageEditedService {
            id: row.id,
            seq: row.seq,
            room,
            sender,
            content: row.content,
            edited_at: row.edited_at.unwrap_or_else(Utc::now),
        };
        bus.publish_room_event(row.room, "message.edited", &data)
            .await?;
        Ok(Some(WsOutEvent::MessageEdited {
            room: data.room.clone(),
            data,
        }))
    }

    /// Soft-delete `message_id`, which must be authored by `user_id`.
    /// Respects the optional `revoke_window_secs` config and publishes `message.revoked`.
    pub(super) async fn message_revoke(
        bus: &ChannelBus,
        user_id: Uuid,
        message_id: Uuid,
    ) -> ChannelResult<Option<WsOutEvent>> {
        // One lookup yields both the room (for the access check) and the author.
        let old = Self::load_message(bus, message_id).await?;
        Self::ensure_room_access(bus, user_id, old.room).await?;
        if old.author != user_id {
            return Err(ChannelError::Unauthorized);
        }
        if let Some(window_secs) = bus.inner.config.revoke_window_secs {
            let elapsed = Utc::now().signed_duration_since(old.created_at);
            if elapsed.num_seconds() > window_secs as i64 {
                return Err(ChannelError::Validation(
                    "message revoke window expired".to_string(),
                ));
            }
        }
        let row = db::sqlx::query_as::<_, model::room::RoomMessageModel>(
            "UPDATE room_message SET deleted_at = now(), updated_at = now() \
             WHERE id = $1 AND deleted_at IS NULL \
             RETURNING id, room, seq, thread, parent, author, content, content_type, pinned, \
             system_type, metadata, edited_at, created_at, updated_at, deleted_at",
        )
        .bind(message_id)
        .fetch_one(bus.inner.db.writer())
        .await?;
        let revoked_by = bus
            .lookup_user(user_id)
            .await
            .unwrap_or_else(|_| UserInfo::unknown(user_id));
        let room = bus
            .lookup_room(row.room)
            .await
            .unwrap_or_else(|_| RoomInfo::unknown(row.room));
        let data = message::MessageRevokedService {
            id: row.id,
            seq: row.seq,
            room,
            revoked_by,
            revoked_at: row.deleted_at.unwrap_or_else(Utc::now),
        };
        bus.publish_room_event(row.room, "message.revoked", &data)
            .await?;
        Ok(Some(WsOutEvent::MessageRevoked {
            room: data.room.clone(),
            data,
        }))
    }

    /// Load a non-deleted message by id.
    ///
    /// # Errors
    /// `ChannelError::RoomNotFound` when no such non-deleted message exists;
    /// `ChannelError::Database` for any other database failure.
    pub(super) async fn load_message(
        bus: &ChannelBus,
        message_id: Uuid,
    ) -> ChannelResult<model::room::RoomMessageModel> {
        db::sqlx::query_as::<_, model::room::RoomMessageModel>(
            "SELECT id, room, seq, thread, parent, author, content, content_type, pinned, \
             system_type, metadata, edited_at, created_at, updated_at, deleted_at \
             FROM room_message WHERE id = $1 AND deleted_at IS NULL",
        )
        .bind(message_id)
        .fetch_one(bus.inner.db.reader())
        .await
        .map_err(|e| match e {
            db::sqlx::Error::RowNotFound => ChannelError::RoomNotFound,
            other => ChannelError::Database(other),
        })
    }

    /// Return one page of messages for `room`, paginated by `before_seq` or
    /// `after_seq` (`before_seq` wins if both are given). `limit` defaults to, and
    /// is capped at, `MAX_MESSAGES_PER_REQUEST`.
    pub(super) async fn message_list(
        bus: &ChannelBus,
        user_id: Uuid,
        room: Uuid,
        before_seq: Option<i64>,
        after_seq: Option<i64>,
        limit: Option<u64>,
    ) -> ChannelResult<Option<WsOutEvent>> {
        Self::ensure_room_access(bus, user_id, room).await?;
        // Clamp client-supplied limits the same way `message_around` does.
        let limit = limit
            .unwrap_or(MAX_MESSAGES_PER_REQUEST)
            .min(MAX_MESSAGES_PER_REQUEST);
        let direction = if before_seq.is_some() {
            PaginationDirection::Before
        } else {
            PaginationDirection::After
        };
        let cursor = before_seq.or(after_seq).map(|s| s.to_string());
        let page = MessagePagination::new(bus.inner.db.clone())
            .get_messages(PaginationParams {
                room_id: room,
                limit,
                cursor,
                direction,
            })
            .await?;
        let mut page_messages = page.messages;
        // NOTE(review): assumes the pagination layer returns backward pages and
        // the cursor-less default page newest-first; reversing yields ascending
        // order for the client. Confirm against MessagePagination.
        if before_seq.is_some() || after_seq.is_none() {
            page_messages.reverse();
        }
        let message_ids: Vec<Uuid> = page_messages.iter().map(|m| m.id).collect();
        let author_ids: Vec<Uuid> = page_messages.iter().map(|m| m.sender_id).collect();
        // Batch the author lookup instead of issuing one lookup per message.
        let user_map = bus.lookup_users(&author_ids).await.unwrap_or_default();
        let reactions = Self::reaction_groups_for_messages(bus, user_id, &message_ids)
            .await
            .unwrap_or_default();
        let list_room = bus
            .lookup_room(room)
            .await
            .unwrap_or_else(|_| RoomInfo::unknown(room));
        let messages: Vec<message::MessageNewService> = page_messages
            .into_iter()
            .map(|m| {
                let sender = user_map
                    .get(&m.sender_id)
                    .cloned()
                    .unwrap_or_else(|| UserInfo::unknown(m.sender_id));
                message::MessageNewService {
                    id: m.id,
                    seq: m.seq,
                    room: list_room.clone(),
                    sender_type: "user".to_string(),
                    sender,
                    thread: m.thread,
                    in_reply_to: m.parent,
                    content: m.content,
                    content_type: m.content_type,
                    pinned: m.pinned,
                    system_type: m.system_type,
                    metadata: m.metadata,
                    thinking_content: None,
                    thinking_is_chunked: None,
                    send_at: m.send_at,
                    reactions: reactions.get(&m.id).cloned().unwrap_or_default(),
                }
            })
            .collect();
        let total = messages.len() as i64;
        Ok(Some(WsOutEvent::MessageList {
            room: list_room.clone(),
            data: message::MessageListService {
                room: list_room,
                messages,
                total,
            },
        }))
    }

    /// Return up to `limit` messages before `seq` and up to `limit` messages at or
    /// after `seq`, in ascending seq order. `limit` defaults to, and is capped at,
    /// `MAX_MESSAGES_PER_REQUEST`.
    pub(super) async fn message_around(
        bus: &ChannelBus,
        user_id: Uuid,
        room: Uuid,
        seq: i64,
        limit: Option<u64>,
    ) -> ChannelResult<Option<WsOutEvent>> {
        Self::ensure_room_access(bus, user_id, room).await?;
        let size = limit
            .unwrap_or(MAX_MESSAGES_PER_REQUEST)
            .min(MAX_MESSAGES_PER_REQUEST) as i64;
        let rows = db::sqlx::query_as::<_, model::room::RoomMessageModel>(
            "(SELECT id, room, seq, thread, parent, author, content, content_type, pinned, \
             system_type, metadata, edited_at, created_at, updated_at, deleted_at \
             FROM room_message \
             WHERE room = $1 AND seq < $2 AND deleted_at IS NULL \
             ORDER BY seq DESC LIMIT $3) \
             UNION ALL \
             (SELECT id, room, seq, thread, parent, author, content, content_type, pinned, \
             system_type, metadata, edited_at, created_at, updated_at, deleted_at \
             FROM room_message \
             WHERE room = $1 AND seq >= $2 AND deleted_at IS NULL \
             ORDER BY seq ASC LIMIT $3) \
             ORDER BY seq ASC",
        )
        .bind(room)
        .bind(seq)
        .bind(size)
        .fetch_all(bus.inner.db.reader())
        .await?;
        let author_ids: Vec<Uuid> = rows.iter().map(|r| r.author).collect();
        let user_map = bus.lookup_users(&author_ids).await.unwrap_or_default();
        let message_ids: Vec<Uuid> = rows.iter().map(|r| r.id).collect();
        let reactions = Self::reaction_groups_for_messages(bus, user_id, &message_ids)
            .await
            .unwrap_or_default();
        let around_room = bus
            .lookup_room(room)
            .await
            .unwrap_or_else(|_| RoomInfo::unknown(room));
        let messages = rows
            .into_iter()
            .map(|r| {
                let sender = user_map
                    .get(&r.author)
                    .cloned()
                    .unwrap_or_else(|| UserInfo::unknown(r.author));
                message::MessageNewService {
                    id: r.id,
                    seq: r.seq,
                    room: around_room.clone(),
                    sender_type: "user".to_string(),
                    sender,
                    thread: r.thread,
                    in_reply_to: r.parent,
                    content: r.content,
                    content_type: r.content_type,
                    pinned: r.pinned,
                    system_type: r.system_type,
                    metadata: r.metadata,
                    thinking_content: None,
                    thinking_is_chunked: None,
                    send_at: r.created_at,
                    reactions: reactions.get(&r.id).cloned().unwrap_or_default(),
                }
            })
            .collect::<Vec<_>>();
        let total = messages.len() as i64;
        Ok(Some(WsOutEvent::MessageList {
            room: around_room.clone(),
            data: message::MessageListService {
                room: around_room,
                messages,
                total,
            },
        }))
    }

    /// Return messages the reconnect buffer holds for `room` after `after_seq`.
    /// Returns at most `limit` of them, where `limit` defaults to and is capped at
    /// `MAX_MESSAGES_PER_REQUEST` (negative values yield an empty list).
    pub(super) async fn missed_messages(
        bus: &ChannelBus,
        user_id: Uuid,
        room: Uuid,
        after_seq: i64,
        limit: Option<i64>,
    ) -> ChannelResult<Option<WsOutEvent>> {
        Self::ensure_room_access(bus, user_id, room).await?;
        let max = MAX_MESSAGES_PER_REQUEST as i64;
        // clamp(0, max) guarantees a non-negative value that fits in usize.
        let limit = limit.unwrap_or(max).clamp(0, max) as usize;
        // Truncate before the user/reaction lookups so that no work is spent on
        // messages that are not returned.
        let missed: Vec<_> = bus
            .inner
            .reconnect
            .get_missed_messages(room, after_seq)
            .await?
            .into_iter()
            .take(limit)
            .collect();
        let author_ids: Vec<Uuid> = missed.iter().map(|m| m.sender_id).collect();
        let message_ids: Vec<Uuid> = missed.iter().map(|m| m.message_id).collect();
        let user_map = bus.lookup_users(&author_ids).await.unwrap_or_default();
        let reactions = Self::reaction_groups_for_messages(bus, user_id, &message_ids)
            .await
            .unwrap_or_default();
        let missed_room = bus
            .lookup_room(room)
            .await
            .unwrap_or_else(|_| RoomInfo::unknown(room));
        // The reconnect records carry no thread/reply/content-type data, so those
        // fields are filled with defaults.
        let messages = missed
            .into_iter()
            .map(|m| {
                let sender = user_map
                    .get(&m.sender_id)
                    .cloned()
                    .unwrap_or_else(|| UserInfo::unknown(m.sender_id));
                message::MessageNewService {
                    id: m.message_id,
                    seq: m.seq,
                    room: missed_room.clone(),
                    sender_type: "user".to_string(),
                    sender,
                    thread: None,
                    in_reply_to: None,
                    content: m.content,
                    content_type: "text".to_string(),
                    pinned: false,
                    system_type: None,
                    metadata: serde_json::Value::Null,
                    thinking_content: None,
                    thinking_is_chunked: None,
                    send_at: m.send_at,
                    reactions: reactions.get(&m.message_id).cloned().unwrap_or_default(),
                }
            })
            .collect::<Vec<_>>();
        let total = messages.len() as i64;
        let data = message::MessageListService {
            room: missed_room.clone(),
            total,
            messages,
        };
        Ok(Some(WsOutEvent::MessageList {
            room: missed_room,
            data,
        }))
    }
}