use std::collections::HashMap; use std::sync::Arc; use std::sync::atomic::{AtomicI64, Ordering}; use dashmap::DashMap; use tokio::sync::Mutex; use uuid::Uuid; use crate::{ChannelError, ChannelResult, security::require_cluster}; const SEQ_KEY_PREFIX: &str = "room:seq:"; const DEFAULT_SEGMENT_SIZE: u64 = 1024; const MAX_REFRESH_RETRIES: u32 = 3; const BOOTSTRAP_SCRIPT: &str = r#" local key = KEYS[1] local db_max = tonumber(ARGV[1]) local current = tonumber(redis.call('GET', key) or '0') if current < db_max then redis.call('SET', key, db_max) end return tonumber(redis.call('GET', key)) "#; struct SegmentState { end: i64, next: AtomicI64, } pub struct SeqAllocator(Arc); struct SeqAllocatorInner { cache: cache::AppCache, db: db::AppDatabase, segments: DashMap>, refresh_locks: DashMap>>, segment_size: u64, } impl Clone for SeqAllocator { fn clone(&self) -> Self { Self(Arc::clone(&self.0)) } } impl SeqAllocator { pub fn new(cache: cache::AppCache, db: db::AppDatabase) -> Self { Self::with_segment_size(cache, db, DEFAULT_SEGMENT_SIZE) } pub fn with_segment_size( cache: cache::AppCache, db: db::AppDatabase, size: u64, ) -> Self { Self(Arc::new(SeqAllocatorInner { cache, db, segments: DashMap::new(), refresh_locks: DashMap::new(), segment_size: if size > 0 { size } else { DEFAULT_SEGMENT_SIZE }, })) } pub async fn seq(&self, room: Uuid) -> ChannelResult { for _ in 0..MAX_REFRESH_RETRIES { if let Some(seq) = self.try_allocate(&room) { return Ok(seq); } let lock = self.get_refresh_lock(room); let _guard = lock.lock().await; if let Some(seq) = self.try_allocate(&room) { return Ok(seq); } self.refresh_segment(room).await?; self.0.refresh_locks.remove(&room); } Err(ChannelError::Internal( "seq allocation exhausted".to_string(), )) } pub async fn bootstrap(&self, room: Uuid) -> ChannelResult { let db_max = self.db_max_seq(room).await?; if db_max == 0 { return Ok(0); } let key = format!("{}{}", SEQ_KEY_PREFIX, room); let cluster = require_cluster(&self.0.cache)?; let mut conn = cluster.conn(); let current: i64 = redis::Cmd::new() .arg("EVAL") .arg(BOOTSTRAP_SCRIPT) .arg(1) .arg(&key) .arg(db_max) .query_async(&mut conn) .await .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?; self.0.segments.remove(&room); Ok(current) } pub async fn bootstrap_all( &self, rooms: Vec, ) -> ChannelResult> { let mut results = HashMap::with_capacity(rooms.len()); for room in rooms { results.insert(room, self.bootstrap(room).await?); } Ok(results) } fn try_allocate(&self, room: &Uuid) -> Option { let state = self.0.segments.get(room)?; loop { let current = state.next.load(Ordering::Acquire); if current >= state.end { return None; } if state .next .compare_exchange_weak(current, current + 1, Ordering::AcqRel, Ordering::Acquire) .is_ok() { return Some(current); } } } fn get_refresh_lock(&self, room: Uuid) -> Arc> { Arc::clone( self.0 .refresh_locks .entry(room) .or_insert_with(|| Arc::new(Mutex::new(()))) .value(), ) } async fn refresh_segment(&self, room: Uuid) -> ChannelResult<()> { let key = format!("{}{}", SEQ_KEY_PREFIX, room); let cluster = require_cluster(&self.0.cache)?; let mut conn = cluster.conn(); let counter: i64 = redis::Cmd::new() .arg("INCRBY") .arg(&key) .arg(self.0.segment_size as i64) .query_async(&mut conn) .await .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?; let start = counter - self.0.segment_size as i64 + 1; let end = counter + 1; self.0.segments.insert( room, Arc::new(SegmentState { end, next: AtomicI64::new(start), }), ); Ok(()) } async fn db_max_seq(&self, room: Uuid) -> ChannelResult { let row: (i64,) = db::sqlx::query_as( "SELECT COALESCE(MAX(seq), 0) FROM room_message WHERE room = $1 AND deleted_at IS NULL", ) .bind(room) .fetch_one(self.0.db.reader()) .await?; Ok(row.0) } }