gitdataai/lib/channel/pagination.rs

221 lines
6.8 KiB
Rust

use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::rooms::RM_COLUMNS;
use crate::{ChannelError, ChannelResult};
/// One page of room messages together with the cursors for adjacent pages.
///
/// Cursors are the decimal string form of a message's `seq`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessagePage {
    /// The messages in this page, in the order the producing query returned them.
    pub messages: Vec<MessageItem>,
    /// `get_messages` sets this when rows exist beyond this page;
    /// `get_messages_around` always reports `false`.
    pub has_more: bool,
    /// `seq` of the last message in the page, present only when `has_more` is set
    /// (always `None` from `get_messages_around`).
    pub next_cursor: Option<String>,
    /// `seq` of the first message in the page, if any
    /// (always `None` from `get_messages_around`).
    pub prev_cursor: Option<String>,
}
/// API-facing view of a single `room_message` row.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessageItem {
    /// Message id (`room_message.id`).
    pub id: Uuid,
    /// Room the message belongs to (`room_message.room`).
    pub room_id: Uuid,
    /// Per-room sequence number; pagination cursors are built from this value.
    pub seq: i64,
    /// Thread the message belongs to; `None` for top-level messages.
    pub thread: Option<Uuid>,
    /// Parent message, if any (`room_message.parent`).
    pub parent: Option<Uuid>,
    /// Message body.
    pub content: String,
    /// Content type tag stored alongside the body.
    pub content_type: String,
    /// Whether the message is pinned.
    pub pinned: bool,
    /// System-message type, if this is a system message.
    pub system_type: Option<String>,
    /// Free-form JSON metadata stored with the message.
    pub metadata: serde_json::Value,
    /// Author of the message (`room_message.author`).
    pub sender_id: Uuid,
    /// Creation time of the message (`room_message.created_at`).
    pub send_at: chrono::DateTime<chrono::Utc>,
    /// Last edit time, if the message was edited.
    pub edited_at: Option<chrono::DateTime<chrono::Utc>>,
}
/// Request parameters for [`MessagePagination::get_messages`].
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PaginationParams {
    /// Room whose messages are listed.
    pub room_id: Uuid,
    /// Requested page size; `get_messages` caps it at 100.
    pub limit: u64,
    /// Decimal `seq` of the message to page from; `None` starts at the newest messages.
    pub cursor: Option<String>,
    /// Which side of `cursor` to read from.
    pub direction: PaginationDirection,
}
/// Which side of the cursor a page of messages is read from.
///
/// Serialized in lowercase: `"before"` / `"after"`. Being a fieldless `Copy`
/// enum, it also derives `PartialEq`/`Eq`/`Hash` so callers can compare
/// directions directly or use them as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PaginationDirection {
    /// Messages whose `seq` is smaller than the cursor.
    Before,
    /// Messages whose `seq` is larger than the cursor.
    After,
}
/// Read-side paginator over the `room_message` table.
///
/// All queries go through the database handle's reader connection.
pub struct MessagePagination {
    /// Database handle used for every query.
    db: db::AppDatabase,
}
impl MessagePagination {
    /// Largest page size [`Self::get_messages`] will return.
    const MAX_PAGE_SIZE: u64 = 100;

    /// Largest number of context messages [`Self::get_messages_around`] fetches
    /// on each side of the target; matches the page-size cap.
    const MAX_CONTEXT_SIZE: i64 = 100;

    /// Creates a paginator that queries through `db`'s reader connection.
    pub fn new(db: db::AppDatabase) -> Self {
        Self { db }
    }

    /// Fetches one page of top-level (`thread IS NULL`), non-deleted messages in a room.
    ///
    /// `params.cursor` is the decimal `seq` of a message:
    /// - with [`PaginationDirection::Before`] the page holds messages with a smaller
    ///   `seq`, ordered by `seq` descending;
    /// - with [`PaginationDirection::After`] it holds messages with a larger `seq`,
    ///   ordered by `seq` ascending;
    /// - without a cursor (or with one that does not parse as an `i64`) the
    ///   highest-`seq` messages are returned in descending order, whatever the direction.
    ///
    /// `params.limit` is clamped to `1..=100`. `has_more` is set when another row
    /// exists past the page; `next_cursor` is then the `seq` of the last returned
    /// message. `prev_cursor` is the `seq` of the first returned message, if any.
    ///
    /// # Errors
    /// Propagates any error from the database query.
    pub async fn get_messages(&self, params: PaginationParams) -> ChannelResult<MessagePage> {
        // Clamp to 1..=MAX: with a limit of 0 the probe row would set `has_more`
        // while the page (and therefore `next_cursor`) stayed empty, so the caller
        // could never advance. The clamped value is <= 100, so the cast is lossless.
        let limit = params.limit.clamp(1, Self::MAX_PAGE_SIZE) as i64;
        // Fetch one row beyond the page to learn whether another page exists.
        let fetch = limit + 1;
        // NOTE(review): a cursor that is not an integer is silently ignored and the
        // newest page is returned; confirm callers rely on this rather than an error.
        let cursor_seq = params.cursor.and_then(|c| c.parse::<i64>().ok());

        let mut rows = match (params.direction, cursor_seq) {
            (PaginationDirection::Before, Some(seq)) => {
                db::sqlx::query_as::<_, model::room::RoomMessageModel>(db::sqlx::AssertSqlSafe(
                    format!(
                        "SELECT {RM_COLUMNS} FROM room_message \
                         WHERE room = $1 AND seq < $2 AND deleted_at IS NULL AND thread IS NULL \
                         ORDER BY seq DESC LIMIT $3"
                    ),
                ))
                .bind(params.room_id)
                .bind(seq)
                .bind(fetch)
                .fetch_all(self.db.reader())
                .await?
            }
            (PaginationDirection::After, Some(seq)) => {
                db::sqlx::query_as::<_, model::room::RoomMessageModel>(db::sqlx::AssertSqlSafe(
                    format!(
                        "SELECT {RM_COLUMNS} FROM room_message \
                         WHERE room = $1 AND seq > $2 AND deleted_at IS NULL AND thread IS NULL \
                         ORDER BY seq ASC LIMIT $3"
                    ),
                ))
                .bind(params.room_id)
                .bind(seq)
                .bind(fetch)
                .fetch_all(self.db.reader())
                .await?
            }
            // No usable cursor: start from the newest messages regardless of direction.
            (_, None) => {
                db::sqlx::query_as::<_, model::room::RoomMessageModel>(db::sqlx::AssertSqlSafe(
                    format!(
                        "SELECT {RM_COLUMNS} FROM room_message \
                         WHERE room = $1 AND deleted_at IS NULL AND thread IS NULL \
                         ORDER BY seq DESC LIMIT $2"
                    ),
                ))
                .bind(params.room_id)
                .bind(fetch)
                .fetch_all(self.db.reader())
                .await?
            }
        };

        let has_more = rows.len() > limit as usize;
        // Drop the probe row (if any) so the page holds at most `limit` messages.
        rows.truncate(limit as usize);

        let messages: Vec<MessageItem> = rows.into_iter().map(Self::item_from_row).collect();
        let next_cursor = messages
            .last()
            .filter(|_| has_more)
            .map(|m| m.seq.to_string());
        let prev_cursor = messages.first().map(|m| m.seq.to_string());

        Ok(MessagePage {
            messages,
            has_more,
            next_cursor,
            prev_cursor,
        })
    }

    /// Fetches a message plus up to `context_size` neighbouring messages on each
    /// side of it, returned in ascending `seq` order.
    ///
    /// Neighbours are non-deleted messages from the same thread as the target
    /// (top-level messages when the target has no thread). This keeps the
    /// context consistent with the `thread IS NULL` timeline produced by
    /// [`Self::get_messages`]. `context_size` is clamped to `0..=100`.
    ///
    /// The returned page always has `has_more == false` and no cursors.
    ///
    /// # Errors
    /// Returns `ChannelError::Internal("message not found")` when no non-deleted
    /// message with `message_id` exists in `room_id`; propagates database errors.
    pub async fn get_messages_around(
        &self,
        room_id: Uuid,
        message_id: Uuid,
        context_size: i64,
    ) -> ChannelResult<MessagePage> {
        // A negative LIMIT is rejected by the database, and an unbounded one lets
        // the caller force arbitrarily large reads.
        let context_size = context_size.clamp(0, Self::MAX_CONTEXT_SIZE);

        let target = db::sqlx::query_as::<_, model::room::RoomMessageModel>(
            db::sqlx::AssertSqlSafe(format!(
                "SELECT {RM_COLUMNS} FROM room_message \
                 WHERE id = $1 AND room = $2 AND deleted_at IS NULL"
            )),
        )
        .bind(message_id)
        .bind(room_id)
        .fetch_optional(self.db.reader())
        .await?
        .ok_or_else(|| ChannelError::Internal("message not found".to_string()))?;

        // `IS NOT DISTINCT FROM` is a NULL-safe equality, so a top-level target
        // (thread = NULL) matches only top-level neighbours.
        let mut before = db::sqlx::query_as::<_, model::room::RoomMessageModel>(
            db::sqlx::AssertSqlSafe(format!(
                "SELECT {RM_COLUMNS} FROM room_message \
                 WHERE room = $1 AND seq < $2 AND thread IS NOT DISTINCT FROM $3 \
                 AND deleted_at IS NULL \
                 ORDER BY seq DESC LIMIT $4"
            )),
        )
        .bind(room_id)
        .bind(target.seq)
        .bind(target.thread)
        .bind(context_size)
        .fetch_all(self.db.reader())
        .await?;

        let after = db::sqlx::query_as::<_, model::room::RoomMessageModel>(
            db::sqlx::AssertSqlSafe(format!(
                "SELECT {RM_COLUMNS} FROM room_message \
                 WHERE room = $1 AND seq > $2 AND thread IS NOT DISTINCT FROM $3 \
                 AND deleted_at IS NULL \
                 ORDER BY seq ASC LIMIT $4"
            )),
        )
        .bind(room_id)
        .bind(target.seq)
        .bind(target.thread)
        .bind(context_size)
        .fetch_all(self.db.reader())
        .await?;

        // `before` came back newest-first; flip it so the whole result ascends by `seq`.
        before.reverse();
        before.reserve(after.len() + 1);
        before.push(target);
        before.extend(after);

        let messages: Vec<MessageItem> = before.into_iter().map(Self::item_from_row).collect();

        Ok(MessagePage {
            messages,
            has_more: false,
            next_cursor: None,
            prev_cursor: None,
        })
    }

    /// Converts a `room_message` row into the API-facing [`MessageItem`].
    fn item_from_row(m: model::room::RoomMessageModel) -> MessageItem {
        MessageItem {
            id: m.id,
            room_id: m.room,
            seq: m.seq,
            thread: m.thread,
            parent: m.parent,
            content: m.content,
            content_type: m.content_type,
            pinned: m.pinned,
            system_type: m.system_type,
            metadata: m.metadata,
            sender_id: m.author,
            send_at: m.created_at,
            edited_at: m.edited_at,
        }
    }
}