// Source listing metadata: 126 lines, 3.4 KiB, Rust.
use std::collections::HashMap;
|
|
use uuid::Uuid;
|
|
|
|
use model::room::RoomMessageModel;
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
use crate::rooms::RM_COLUMNS;
|
|
use crate::ChannelResult;
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ClientState {
|
|
pub user_id: Uuid,
|
|
pub last_seq: HashMap<Uuid, i64>,
|
|
pub last_seen: chrono::DateTime<chrono::Utc>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct MissedMessage {
|
|
pub room_id: Uuid,
|
|
pub message_id: Uuid,
|
|
pub seq: i64,
|
|
pub content: String,
|
|
pub sender_id: Uuid,
|
|
pub send_at: chrono::DateTime<chrono::Utc>,
|
|
pub thread: Option<Uuid>,
|
|
pub parent: Option<Uuid>,
|
|
pub content_type: String,
|
|
pub pinned: bool,
|
|
pub system_type: Option<String>,
|
|
pub metadata: serde_json::Value,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct ReconnectManager {
|
|
cache: cache::AppCache,
|
|
db: db::AppDatabase,
|
|
}
|
|
|
|
impl ReconnectManager {
|
|
pub fn new(cache: cache::AppCache, db: db::AppDatabase) -> Self {
|
|
Self { cache, db }
|
|
}
|
|
|
|
pub async fn save_client_state(
|
|
&self,
|
|
user_id: Uuid,
|
|
room_id: Uuid,
|
|
last_seq: i64,
|
|
) -> ChannelResult<()> {
|
|
let key = format!("client:state:{}:{}", user_id, room_id);
|
|
self.cache.set(&key, &last_seq.to_string()).await?;
|
|
if let Some(cluster) = &self.cache.cluster {
|
|
cluster
|
|
.expire(&key, std::time::Duration::from_secs(86400))
|
|
.await?;
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn get_last_seq(
|
|
&self,
|
|
user_id: Uuid,
|
|
room_id: Uuid,
|
|
) -> ChannelResult<Option<i64>> {
|
|
let key = format!("client:state:{}:{}", user_id, room_id);
|
|
let value: Option<String> = self.cache.get(&key).await?;
|
|
Ok(value.and_then(|v| v.parse::<i64>().ok()))
|
|
}
|
|
|
|
pub async fn get_missed_messages(
|
|
&self,
|
|
room_id: Uuid,
|
|
since_seq: i64,
|
|
) -> ChannelResult<Vec<MissedMessage>> {
|
|
let messages = db::sqlx::query_as::<_, RoomMessageModel>(
|
|
db::sqlx::AssertSqlSafe(format!(
|
|
"SELECT {RM_COLUMNS} FROM room_message \
|
|
WHERE room = $1 AND seq > $2 AND deleted_at IS NULL AND thread IS NULL \
|
|
ORDER BY seq ASC \
|
|
LIMIT 100"
|
|
)),
|
|
)
|
|
.bind(room_id)
|
|
.bind(since_seq)
|
|
.fetch_all(self.db.reader())
|
|
.await?;
|
|
|
|
let missed: Vec<MissedMessage> = messages
|
|
.into_iter()
|
|
.map(|m| MissedMessage {
|
|
room_id: m.room,
|
|
message_id: m.id,
|
|
seq: m.seq,
|
|
content: m.content,
|
|
sender_id: m.author,
|
|
send_at: m.created_at,
|
|
thread: m.thread,
|
|
parent: m.parent,
|
|
content_type: m.content_type,
|
|
pinned: m.pinned,
|
|
system_type: m.system_type,
|
|
metadata: m.metadata,
|
|
})
|
|
.collect();
|
|
|
|
Ok(missed)
|
|
}
|
|
|
|
pub async fn handle_reconnect(
|
|
&self,
|
|
_user_id: Uuid,
|
|
room_states: HashMap<Uuid, i64>,
|
|
) -> ChannelResult<HashMap<Uuid, Vec<MissedMessage>>> {
|
|
let mut result = HashMap::new();
|
|
|
|
for (room_id, client_seq) in room_states {
|
|
let missed = self.get_missed_messages(room_id, client_seq).await?;
|
|
if !missed.is_empty() {
|
|
result.insert(room_id, missed);
|
|
}
|
|
}
|
|
|
|
Ok(result)
|
|
}
|
|
}
|