191 lines
5.1 KiB
Rust
191 lines
5.1 KiB
Rust
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use std::sync::atomic::{AtomicI64, Ordering};
|
|
|
|
use dashmap::DashMap;
|
|
use tokio::sync::Mutex;
|
|
use uuid::Uuid;
|
|
|
|
use crate::{ChannelError, ChannelResult, security::require_cluster};
|
|
|
|
/// Prefix of the Redis key that holds a room's sequence counter; the full key
/// is this prefix followed by the room's UUID.
const SEQ_KEY_PREFIX: &str = "room:seq:";
/// Count of sequence numbers reserved per Redis round-trip when the caller
/// does not pick a size (or passes 0).
const DEFAULT_SEGMENT_SIZE: u64 = 1_024;
/// Number of refresh rounds `SeqAllocator::seq` attempts before giving up.
const MAX_REFRESH_RETRIES: u32 = 3;
|
|
|
|
/// Lua script that raises a room's Redis sequence counter to at least the
/// highest `seq` found in the database, never lowering it.
///
/// * `KEYS[1]` — the counter key.
/// * `ARGV[1]` — the database maximum, as an integer string.
///
/// Returns the counter value after the update. `ARGV[1]` is stored exactly as
/// received rather than re-serialised from a Lua number, so the stored value
/// does not pass through a double-precision round-trip.
const BOOTSTRAP_SCRIPT: &str = r#"
local key = KEYS[1]
local db_max = tonumber(ARGV[1])
local current = tonumber(redis.call('GET', key) or '0')
if current < db_max then
    redis.call('SET', key, ARGV[1])
    return db_max
end
return current
"#;
|
|
|
|
/// A contiguous range of sequence numbers reserved for one room.
///
/// The values `next..end` are still available to hand out; `end` is exclusive.
struct SegmentState {
    /// Next value to hand out; advanced atomically by concurrent allocators.
    next: AtomicI64,
    /// Exclusive upper bound of the reserved range.
    end: i64,
}
|
|
|
|
pub struct SeqAllocator(Arc<SeqAllocatorInner>);
|
|
|
|
struct SeqAllocatorInner {
|
|
cache: cache::AppCache,
|
|
db: db::AppDatabase,
|
|
segments: DashMap<Uuid, Arc<SegmentState>>,
|
|
refresh_locks: DashMap<Uuid, Arc<Mutex<()>>>,
|
|
segment_size: u64,
|
|
}
|
|
|
|
impl Clone for SeqAllocator {
|
|
fn clone(&self) -> Self {
|
|
Self(Arc::clone(&self.0))
|
|
}
|
|
}
|
|
|
|
impl SeqAllocator {
|
|
pub fn new(cache: cache::AppCache, db: db::AppDatabase) -> Self {
|
|
Self::with_segment_size(cache, db, DEFAULT_SEGMENT_SIZE)
|
|
}
|
|
|
|
pub fn with_segment_size(
|
|
cache: cache::AppCache,
|
|
db: db::AppDatabase,
|
|
size: u64,
|
|
) -> Self {
|
|
Self(Arc::new(SeqAllocatorInner {
|
|
cache,
|
|
db,
|
|
segments: DashMap::new(),
|
|
refresh_locks: DashMap::new(),
|
|
segment_size: if size > 0 { size } else { DEFAULT_SEGMENT_SIZE },
|
|
}))
|
|
}
|
|
|
|
pub async fn seq(&self, room: Uuid) -> ChannelResult<i64> {
|
|
for _ in 0..MAX_REFRESH_RETRIES {
|
|
if let Some(seq) = self.try_allocate(&room) {
|
|
return Ok(seq);
|
|
}
|
|
|
|
let lock = self.get_refresh_lock(room);
|
|
let _guard = lock.lock().await;
|
|
|
|
if let Some(seq) = self.try_allocate(&room) {
|
|
return Ok(seq);
|
|
}
|
|
|
|
self.refresh_segment(room).await?;
|
|
self.0.refresh_locks.remove(&room);
|
|
}
|
|
|
|
Err(ChannelError::Internal(
|
|
"seq allocation exhausted".to_string(),
|
|
))
|
|
}
|
|
|
|
pub async fn bootstrap(&self, room: Uuid) -> ChannelResult<i64> {
|
|
let db_max = self.db_max_seq(room).await?;
|
|
if db_max == 0 {
|
|
return Ok(0);
|
|
}
|
|
|
|
let key = format!("{}{}", SEQ_KEY_PREFIX, room);
|
|
let cluster = require_cluster(&self.0.cache)?;
|
|
let mut conn = cluster.conn();
|
|
|
|
let current: i64 = redis::Cmd::new()
|
|
.arg("EVAL")
|
|
.arg(BOOTSTRAP_SCRIPT)
|
|
.arg(1)
|
|
.arg(&key)
|
|
.arg(db_max)
|
|
.query_async(&mut conn)
|
|
.await
|
|
.map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?;
|
|
|
|
self.0.segments.remove(&room);
|
|
Ok(current)
|
|
}
|
|
|
|
pub async fn bootstrap_all(
|
|
&self,
|
|
rooms: Vec<Uuid>,
|
|
) -> ChannelResult<HashMap<Uuid, i64>> {
|
|
let mut results = HashMap::with_capacity(rooms.len());
|
|
for room in rooms {
|
|
results.insert(room, self.bootstrap(room).await?);
|
|
}
|
|
Ok(results)
|
|
}
|
|
|
|
fn try_allocate(&self, room: &Uuid) -> Option<i64> {
|
|
let state = self.0.segments.get(room)?;
|
|
loop {
|
|
let current = state.next.load(Ordering::Acquire);
|
|
if current >= state.end {
|
|
return None;
|
|
}
|
|
if state
|
|
.next
|
|
.compare_exchange_weak(
|
|
current,
|
|
current + 1,
|
|
Ordering::AcqRel,
|
|
Ordering::Acquire,
|
|
)
|
|
.is_ok()
|
|
{
|
|
return Some(current);
|
|
}
|
|
}
|
|
}
|
|
|
|
fn get_refresh_lock(&self, room: Uuid) -> Arc<Mutex<()>> {
|
|
Arc::clone(
|
|
self.0
|
|
.refresh_locks
|
|
.entry(room)
|
|
.or_insert_with(|| Arc::new(Mutex::new(())))
|
|
.value(),
|
|
)
|
|
}
|
|
|
|
async fn refresh_segment(&self, room: Uuid) -> ChannelResult<()> {
|
|
let key = format!("{}{}", SEQ_KEY_PREFIX, room);
|
|
let cluster = require_cluster(&self.0.cache)?;
|
|
let mut conn = cluster.conn();
|
|
|
|
let counter: i64 = redis::Cmd::new()
|
|
.arg("INCRBY")
|
|
.arg(&key)
|
|
.arg(self.0.segment_size as i64)
|
|
.query_async(&mut conn)
|
|
.await
|
|
.map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?;
|
|
|
|
let start = counter - self.0.segment_size as i64 + 1;
|
|
let end = counter + 1;
|
|
|
|
self.0.segments.insert(
|
|
room,
|
|
Arc::new(SegmentState {
|
|
end,
|
|
next: AtomicI64::new(start),
|
|
}),
|
|
);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn db_max_seq(&self, room: Uuid) -> ChannelResult<i64> {
|
|
let row: (i64,) = db::sqlx::query_as(
|
|
"SELECT COALESCE(MAX(seq), 0) FROM room_message WHERE room = $1 AND deleted_at IS NULL",
|
|
)
|
|
.bind(room)
|
|
.fetch_one(self.0.db.reader())
|
|
.await?;
|
|
Ok(row.0)
|
|
}
|
|
}
|