use serde::{Deserialize, Serialize}; use uuid::Uuid; use crate::rooms::RM_COLUMNS; use crate::{ChannelError, ChannelResult}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct MessagePage { pub messages: Vec, pub has_more: bool, pub next_cursor: Option, pub prev_cursor: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct MessageItem { pub id: Uuid, pub room_id: Uuid, pub seq: i64, pub thread: Option, pub parent: Option, pub content: String, pub content_type: String, pub pinned: bool, pub system_type: Option, pub metadata: serde_json::Value, pub sender_id: Uuid, pub send_at: chrono::DateTime, pub edited_at: Option>, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct PaginationParams { pub room_id: Uuid, pub limit: u64, pub cursor: Option, pub direction: PaginationDirection, } #[derive(Debug, Clone, Copy, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum PaginationDirection { Before, After, } pub struct MessagePagination { db: db::AppDatabase, } impl MessagePagination { pub fn new(db: db::AppDatabase) -> Self { Self { db } } pub async fn get_messages( &self, params: PaginationParams, ) -> ChannelResult { let limit = std::cmp::Ord::min(params.limit, 100) as i64; let cursor_seq = params.cursor.and_then(|c| c.parse::().ok()); let messages = match (params.direction, cursor_seq) { (PaginationDirection::Before, Some(seq)) => { db::sqlx::query_as::<_, model::room::RoomMessageModel>( db::sqlx::AssertSqlSafe(format!( "SELECT {RM_COLUMNS} FROM room_message \ WHERE room = $1 AND seq < $2 AND deleted_at IS NULL AND thread IS NULL \ ORDER BY seq DESC LIMIT $3" )), ) .bind(params.room_id) .bind(seq) .bind(limit + 1) .fetch_all(self.db.reader()) .await? } (PaginationDirection::After, Some(seq)) => { db::sqlx::query_as::<_, model::room::RoomMessageModel>( db::sqlx::AssertSqlSafe(format!( "SELECT {RM_COLUMNS} FROM room_message \ WHERE room = $1 AND seq > $2 AND deleted_at IS NULL AND thread IS NULL \ ORDER BY seq ASC LIMIT $3" )), ) .bind(params.room_id) .bind(seq) .bind(limit + 1) .fetch_all(self.db.reader()) .await? } _ => { db::sqlx::query_as::<_, model::room::RoomMessageModel>( db::sqlx::AssertSqlSafe(format!( "SELECT {RM_COLUMNS} FROM room_message \ WHERE room = $1 AND deleted_at IS NULL AND thread IS NULL \ ORDER BY seq DESC LIMIT $2" )), ) .bind(params.room_id) .bind(limit + 1) .fetch_all(self.db.reader()) .await? } }; let has_more = messages.len() > limit as usize; let messages: Vec<_> = messages.into_iter().take(limit as usize).collect(); let next_cursor = if has_more { messages.last().map(|m| m.seq.to_string()) } else { None }; let prev_cursor = messages.first().map(|m| m.seq.to_string()); let items: Vec = messages .into_iter() .map(|m| MessageItem { id: m.id, room_id: m.room, seq: m.seq, thread: m.thread, parent: m.parent, content: m.content, content_type: m.content_type, pinned: m.pinned, system_type: m.system_type, metadata: m.metadata, sender_id: m.author, send_at: m.created_at, edited_at: m.edited_at, }) .collect(); Ok(MessagePage { messages: items, has_more, next_cursor, prev_cursor, }) } pub async fn get_messages_around( &self, room_id: Uuid, message_id: Uuid, context_size: i64, ) -> ChannelResult { let target = db::sqlx::query_as::<_, model::room::RoomMessageModel>( db::sqlx::AssertSqlSafe(format!( "SELECT {RM_COLUMNS} FROM room_message \ WHERE id = $1 AND room = $2 AND deleted_at IS NULL" )), ) .bind(message_id) .bind(room_id) .fetch_optional(self.db.reader()) .await? .ok_or(ChannelError::Internal("message not found".to_string()))?; let before = db::sqlx::query_as::<_, model::room::RoomMessageModel>( db::sqlx::AssertSqlSafe(format!( "SELECT {RM_COLUMNS} FROM room_message \ WHERE room = $1 AND seq < $2 AND deleted_at IS NULL \ ORDER BY seq DESC LIMIT $3" )), ) .bind(room_id) .bind(target.seq) .bind(context_size) .fetch_all(self.db.reader()) .await?; let after = db::sqlx::query_as::<_, model::room::RoomMessageModel>( db::sqlx::AssertSqlSafe(format!( "SELECT {RM_COLUMNS} FROM room_message \ WHERE room = $1 AND seq > $2 AND deleted_at IS NULL \ ORDER BY seq ASC LIMIT $3" )), ) .bind(room_id) .bind(target.seq) .bind(context_size) .fetch_all(self.db.reader()) .await?; let mut all_messages = before; all_messages.reverse(); all_messages.push(target); all_messages.extend(after); let items: Vec = all_messages .into_iter() .map(|m| MessageItem { id: m.id, room_id: m.room, seq: m.seq, thread: m.thread, parent: m.parent, content: m.content, content_type: m.content_type, pinned: m.pinned, system_type: m.system_type, metadata: m.metadata, sender_id: m.author, send_at: m.created_at, edited_at: m.edited_at, }) .collect(); Ok(MessagePage { messages: items, has_more: false, next_cursor: None, prev_cursor: None, }) } }