221 lines
6.8 KiB
Rust
221 lines
6.8 KiB
Rust
use serde::{Deserialize, Serialize};
|
|
use uuid::Uuid;
|
|
|
|
use crate::rooms::RM_COLUMNS;
|
|
use crate::{ChannelError, ChannelResult};
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct MessagePage {
|
|
pub messages: Vec<MessageItem>,
|
|
pub has_more: bool,
|
|
pub next_cursor: Option<String>,
|
|
pub prev_cursor: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct MessageItem {
|
|
pub id: Uuid,
|
|
pub room_id: Uuid,
|
|
pub seq: i64,
|
|
pub thread: Option<Uuid>,
|
|
pub parent: Option<Uuid>,
|
|
pub content: String,
|
|
pub content_type: String,
|
|
pub pinned: bool,
|
|
pub system_type: Option<String>,
|
|
pub metadata: serde_json::Value,
|
|
pub sender_id: Uuid,
|
|
pub send_at: chrono::DateTime<chrono::Utc>,
|
|
pub edited_at: Option<chrono::DateTime<chrono::Utc>>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct PaginationParams {
|
|
pub room_id: Uuid,
|
|
pub limit: u64,
|
|
pub cursor: Option<String>,
|
|
pub direction: PaginationDirection,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
|
#[serde(rename_all = "lowercase")]
|
|
pub enum PaginationDirection {
|
|
Before,
|
|
After,
|
|
}
|
|
|
|
pub struct MessagePagination {
|
|
db: db::AppDatabase,
|
|
}
|
|
|
|
impl MessagePagination {
|
|
pub fn new(db: db::AppDatabase) -> Self {
|
|
Self { db }
|
|
}
|
|
|
|
pub async fn get_messages(
|
|
&self,
|
|
params: PaginationParams,
|
|
) -> ChannelResult<MessagePage> {
|
|
let limit = std::cmp::Ord::min(params.limit, 100) as i64;
|
|
let cursor_seq = params.cursor.and_then(|c| c.parse::<i64>().ok());
|
|
|
|
let messages = match (params.direction, cursor_seq) {
|
|
(PaginationDirection::Before, Some(seq)) => {
|
|
db::sqlx::query_as::<_, model::room::RoomMessageModel>(
|
|
db::sqlx::AssertSqlSafe(format!(
|
|
"SELECT {RM_COLUMNS} FROM room_message \
|
|
WHERE room = $1 AND seq < $2 AND deleted_at IS NULL AND thread IS NULL \
|
|
ORDER BY seq DESC LIMIT $3"
|
|
)),
|
|
)
|
|
.bind(params.room_id)
|
|
.bind(seq)
|
|
.bind(limit + 1)
|
|
.fetch_all(self.db.reader())
|
|
.await?
|
|
}
|
|
(PaginationDirection::After, Some(seq)) => {
|
|
db::sqlx::query_as::<_, model::room::RoomMessageModel>(
|
|
db::sqlx::AssertSqlSafe(format!(
|
|
"SELECT {RM_COLUMNS} FROM room_message \
|
|
WHERE room = $1 AND seq > $2 AND deleted_at IS NULL AND thread IS NULL \
|
|
ORDER BY seq ASC LIMIT $3"
|
|
)),
|
|
)
|
|
.bind(params.room_id)
|
|
.bind(seq)
|
|
.bind(limit + 1)
|
|
.fetch_all(self.db.reader())
|
|
.await?
|
|
}
|
|
_ => {
|
|
db::sqlx::query_as::<_, model::room::RoomMessageModel>(
|
|
db::sqlx::AssertSqlSafe(format!(
|
|
"SELECT {RM_COLUMNS} FROM room_message \
|
|
WHERE room = $1 AND deleted_at IS NULL AND thread IS NULL \
|
|
ORDER BY seq DESC LIMIT $2"
|
|
)),
|
|
)
|
|
.bind(params.room_id)
|
|
.bind(limit + 1)
|
|
.fetch_all(self.db.reader())
|
|
.await?
|
|
}
|
|
};
|
|
|
|
let has_more = messages.len() > limit as usize;
|
|
let messages: Vec<_> =
|
|
messages.into_iter().take(limit as usize).collect();
|
|
|
|
let next_cursor = if has_more {
|
|
messages.last().map(|m| m.seq.to_string())
|
|
} else {
|
|
None
|
|
};
|
|
let prev_cursor = messages.first().map(|m| m.seq.to_string());
|
|
|
|
let items: Vec<MessageItem> = messages
|
|
.into_iter()
|
|
.map(|m| MessageItem {
|
|
id: m.id,
|
|
room_id: m.room,
|
|
seq: m.seq,
|
|
thread: m.thread,
|
|
parent: m.parent,
|
|
content: m.content,
|
|
content_type: m.content_type,
|
|
pinned: m.pinned,
|
|
system_type: m.system_type,
|
|
metadata: m.metadata,
|
|
sender_id: m.author,
|
|
send_at: m.created_at,
|
|
edited_at: m.edited_at,
|
|
})
|
|
.collect();
|
|
|
|
Ok(MessagePage {
|
|
messages: items,
|
|
has_more,
|
|
next_cursor,
|
|
prev_cursor,
|
|
})
|
|
}
|
|
|
|
pub async fn get_messages_around(
|
|
&self,
|
|
room_id: Uuid,
|
|
message_id: Uuid,
|
|
context_size: i64,
|
|
) -> ChannelResult<MessagePage> {
|
|
let target = db::sqlx::query_as::<_, model::room::RoomMessageModel>(
|
|
db::sqlx::AssertSqlSafe(format!(
|
|
"SELECT {RM_COLUMNS} FROM room_message \
|
|
WHERE id = $1 AND room = $2 AND deleted_at IS NULL"
|
|
)),
|
|
)
|
|
.bind(message_id)
|
|
.bind(room_id)
|
|
.fetch_optional(self.db.reader())
|
|
.await?
|
|
.ok_or(ChannelError::Internal("message not found".to_string()))?;
|
|
|
|
let before = db::sqlx::query_as::<_, model::room::RoomMessageModel>(
|
|
db::sqlx::AssertSqlSafe(format!(
|
|
"SELECT {RM_COLUMNS} FROM room_message \
|
|
WHERE room = $1 AND seq < $2 AND deleted_at IS NULL \
|
|
ORDER BY seq DESC LIMIT $3"
|
|
)),
|
|
)
|
|
.bind(room_id)
|
|
.bind(target.seq)
|
|
.bind(context_size)
|
|
.fetch_all(self.db.reader())
|
|
.await?;
|
|
|
|
let after = db::sqlx::query_as::<_, model::room::RoomMessageModel>(
|
|
db::sqlx::AssertSqlSafe(format!(
|
|
"SELECT {RM_COLUMNS} FROM room_message \
|
|
WHERE room = $1 AND seq > $2 AND deleted_at IS NULL \
|
|
ORDER BY seq ASC LIMIT $3"
|
|
)),
|
|
)
|
|
.bind(room_id)
|
|
.bind(target.seq)
|
|
.bind(context_size)
|
|
.fetch_all(self.db.reader())
|
|
.await?;
|
|
|
|
let mut all_messages = before;
|
|
all_messages.reverse();
|
|
all_messages.push(target);
|
|
all_messages.extend(after);
|
|
|
|
let items: Vec<MessageItem> = all_messages
|
|
.into_iter()
|
|
.map(|m| MessageItem {
|
|
id: m.id,
|
|
room_id: m.room,
|
|
seq: m.seq,
|
|
thread: m.thread,
|
|
parent: m.parent,
|
|
content: m.content,
|
|
content_type: m.content_type,
|
|
pinned: m.pinned,
|
|
system_type: m.system_type,
|
|
metadata: m.metadata,
|
|
sender_id: m.author,
|
|
send_at: m.created_at,
|
|
edited_at: m.edited_at,
|
|
})
|
|
.collect();
|
|
|
|
Ok(MessagePage {
|
|
messages: items,
|
|
has_more: false,
|
|
next_cursor: None,
|
|
prev_cursor: None,
|
|
})
|
|
}
|
|
}
|