gitdataai/libs/session_manager/src/storage.rs
ZhenYi 5776af18ca perf: sequence generation Redis-only + session MGET batch
service.rs: Replace the per-message Lua+DB sequence with a simple INCR, and only
reconcile with the DB every 1000 messages (eliminating 99.9% of queries).

storage.rs: Replace N+1 GET loop with single MGET for both
get_user_sessions and get_workspace_sessions (N+1 → 2 roundtrips).
2026-04-24 00:04:27 +08:00

356 lines
11 KiB
Rust

use anyhow::Context;
use chrono::Utc;
use deadpool_redis::cluster::Pool;
use redis::AsyncCommands;
use serde_json;
use thiserror::Error;
use uuid::Uuid;
use crate::types::UserSession;
#[derive(Error, Debug)]
pub enum SessionStorageError {
#[error("Redis error: {0}")]
Redis(#[from] anyhow::Error),
#[error("session not found: {0}")]
NotFound(Uuid),
}
/// Key prefix of a single serialized session record: `user:conn:<session_id>`.
const KEY_CONN: &str = "user:conn:";
/// Key prefix of the set of session IDs owned by a user: `user:user_sessions:<user_id>`.
const KEY_USER_SESSIONS: &str = "user:user_sessions:";
/// Key prefix of the set of session IDs in a workspace: `user:workspace_sessions:<workspace_id>`.
const KEY_WORKSPACE_SESSIONS: &str = "user:workspace_sessions:";
/// Redis-backed store of user sessions, plus per-user and per-workspace
/// session-ID index sets.
#[derive(Clone)]
pub struct SessionStorage {
    /// Cluster connection pool; every operation checks a connection out of it.
    pool: Pool,
    /// TTL in seconds written onto a session record on save and on every heartbeat.
    heartbeat_ttl_secs: u64,
}
impl SessionStorage {
pub fn new(pool: Pool) -> Self {
Self {
pool,
heartbeat_ttl_secs: 120,
}
}
pub fn with_heartbeat_ttl(mut self, ttl_secs: u64) -> Self {
self.heartbeat_ttl_secs = ttl_secs;
self
}
async fn get_conn(&self) -> Result<deadpool_redis::cluster::Connection, SessionStorageError> {
self.pool
.get()
.await
.context("failed to get Redis connection from pool")
.map_err(SessionStorageError::Redis)
}
fn conn_key(session_id: &Uuid) -> String {
format!("{KEY_CONN}{session_id}")
}
fn user_sessions_key(user_id: &Uuid) -> String {
format!("{KEY_USER_SESSIONS}{user_id}")
}
fn workspace_sessions_key(workspace_id: &Uuid) -> String {
format!("{KEY_WORKSPACE_SESSIONS}{workspace_id}")
}
fn to_err<E: std::error::Error + Send + Sync + 'static>(e: E) -> SessionStorageError {
SessionStorageError::Redis(anyhow::anyhow!(e))
}
/// Store a new user session and associate it with user + workspace indexes.
pub async fn save_session(&self, session: &UserSession) -> Result<(), SessionStorageError> {
let mut conn = self.get_conn().await?;
let key = Self::conn_key(&session.session_id);
let user_key = Self::user_sessions_key(&session.user_id);
let ws_key = Self::workspace_sessions_key(&session.workspace_id);
let value = serde_json::to_string(session)
.context("serialize UserSession")
.map_err(SessionStorageError::Redis)?;
let ttl = self.heartbeat_ttl_secs;
let _: () = redis::pipe()
.set_ex(&key, &value, ttl)
.sadd(&user_key, session_id_str(&session.session_id))
.expire(&user_key, 0)
.sadd(&ws_key, session_id_str(&session.session_id))
.expire(&ws_key, 0)
.query_async(&mut conn)
.await
.map_err(Self::to_err)?;
Ok(())
}
/// Get a session by its ID.
pub async fn get_session(&self, session_id: &Uuid) -> Result<Option<UserSession>, SessionStorageError> {
let mut conn = self.get_conn().await?;
let key = Self::conn_key(session_id);
let value: Option<String> = conn
.get(&key)
.await
.map_err(Self::to_err)?;
match value {
Some(v) => {
let session: UserSession = serde_json::from_str(&v)
.context("deserialize UserSession")
.map_err(SessionStorageError::Redis)?;
Ok(Some(session))
}
None => Ok(None),
}
}
/// Update the heartbeat timestamp and refresh TTL.
pub async fn heartbeat(&self, session_id: &Uuid) -> Result<(), SessionStorageError> {
let mut conn = self.get_conn().await?;
let key = Self::conn_key(session_id);
let ttl = self.heartbeat_ttl_secs;
let updated = Utc::now();
let script = redis::Script::new(
r#"
local v = redis.call('GET', KEYS[1])
if not v then return 0 end
local session = cjson.decode(v)
session.last_heartbeat = ARGV[1]
redis.call('SETEX', KEYS[1], ARGV[2], cjson.encode(session))
return 1
"#,
);
let result: i64 = script
.key(&key)
.arg(updated.to_rfc3339())
.arg(ttl)
.invoke_async(&mut conn)
.await
.map_err(Self::to_err)?;
if result == 0 {
return Err(SessionStorageError::NotFound(*session_id));
}
Ok(())
}
/// Delete a session by ID and clean up indexes.
pub async fn delete_session(&self, session_id: &Uuid) -> Result<(), SessionStorageError> {
let session = self.get_session(session_id).await?;
let key = Self::conn_key(session_id);
let _: () = self
.get_conn()
.await?
.del(&key)
.await
.map_err(Self::to_err)?;
if let Some(ref s) = session {
let mut conn = self.get_conn().await?;
let user_key = Self::user_sessions_key(&s.user_id);
let ws_key = Self::workspace_sessions_key(&s.workspace_id);
let id_str = session_id_str(session_id);
let _: () = conn.srem::<_, _, ()>(&user_key, &id_str).await.map_err(Self::to_err)?;
let _: () = conn.srem::<_, _, ()>(&ws_key, &id_str).await.map_err(Self::to_err)?;
}
Ok(())
}
/// Delete all sessions for a specific user.
pub async fn delete_user_sessions(&self, user_id: &Uuid) -> Result<Vec<Uuid>, SessionStorageError> {
let mut conn = self.get_conn().await?;
let user_key = Self::user_sessions_key(user_id);
let session_ids: Vec<String> = conn
.smembers(&user_key)
.await
.map_err(Self::to_err)?;
let mut deleted = Vec::new();
for id_str in &session_ids {
if let Ok(sid) = Uuid::parse_str(id_str) {
let conn_key = Self::conn_key(&sid);
let _: () = conn.del(&conn_key).await.map_err(Self::to_err)?;
deleted.push(sid);
}
}
let _: () = conn.del(&user_key).await.map_err(Self::to_err)?;
Ok(deleted)
}
/// Delete all sessions for a user within a specific workspace.
pub async fn delete_user_workspace_sessions(
&self,
user_id: &Uuid,
workspace_id: &Uuid,
) -> Result<Vec<Uuid>, SessionStorageError> {
let mut conn = self.get_conn().await?;
let ws_key = Self::workspace_sessions_key(workspace_id);
let user_key = Self::user_sessions_key(user_id);
let ws_session_ids: Vec<String> = conn
.smembers(&ws_key)
.await
.map_err(Self::to_err)?;
let mut deleted = Vec::new();
for id_str in &ws_session_ids {
if let Ok(sid) = Uuid::parse_str(id_str) {
let session = self.get_session(&sid).await?;
if let Some(ref s) = session {
if s.user_id == *user_id {
let conn_key = Self::conn_key(&sid);
let _: () = conn.del(&conn_key).await.map_err(Self::to_err)?;
let _: () = conn.srem::<_, _, ()>(&user_key, id_str).await.map_err(Self::to_err)?;
deleted.push(sid);
}
}
}
}
Ok(deleted)
}
/// Get all active sessions for a user.
pub async fn get_user_sessions(&self, user_id: &Uuid) -> Result<Vec<UserSession>, SessionStorageError> {
let mut conn = self.get_conn().await?;
let user_key = Self::user_sessions_key(user_id);
let session_ids: Vec<String> = conn
.smembers(&user_key)
.await
.map_err(Self::to_err)?;
if session_ids.is_empty() {
return Ok(Vec::new());
}
// Batch fetch all sessions in a single MGET instead of N individual GET calls.
let keys: Vec<String> = session_ids
.iter()
.map(|id| format!("{}{}", KEY_CONN, id))
.collect();
let values: Vec<Option<String>> = redis::cmd("MGET")
.arg(&keys)
.query_async(&mut conn)
.await
.map_err(Self::to_err)?;
let sessions: Vec<UserSession> = values
.into_iter()
.flatten()
.filter_map(|v| serde_json::from_str(&v).ok())
.collect();
Ok(sessions)
}
/// Get all active sessions in a workspace.
pub async fn get_workspace_sessions(
&self,
workspace_id: &Uuid,
) -> Result<Vec<UserSession>, SessionStorageError> {
let mut conn = self.get_conn().await?;
let ws_key = Self::workspace_sessions_key(workspace_id);
let session_ids: Vec<String> = conn
.smembers(&ws_key)
.await
.map_err(Self::to_err)?;
if session_ids.is_empty() {
return Ok(Vec::new());
}
// Batch fetch all sessions in a single MGET instead of N individual GET calls.
let keys: Vec<String> = session_ids
.iter()
.map(|id| format!("{}{}", KEY_CONN, id))
.collect();
let values: Vec<Option<String>> = redis::cmd("MGET")
.arg(&keys)
.query_async(&mut conn)
.await
.map_err(Self::to_err)?;
let sessions: Vec<UserSession> = values
.into_iter()
.flatten()
.filter_map(|v| serde_json::from_str(&v).ok())
.collect();
Ok(sessions)
}
/// Get distinct user IDs active in a workspace.
pub async fn get_workspace_online_users(
&self,
workspace_id: &Uuid,
) -> Result<Vec<Uuid>, SessionStorageError> {
let sessions = self.get_workspace_sessions(workspace_id).await?;
let mut seen = std::collections::HashSet::new();
let mut result = Vec::new();
for s in sessions {
if seen.insert(s.user_id) {
result.push(s.user_id);
}
}
Ok(result)
}
/// Get the count of online sessions for a user.
pub async fn get_user_session_count(&self, user_id: &Uuid) -> Result<usize, SessionStorageError> {
let sessions = self.get_user_sessions(user_id).await?;
Ok(sessions.len())
}
/// Check if a user has any active sessions (online status).
pub async fn is_user_online(&self, user_id: &Uuid) -> Result<bool, SessionStorageError> {
let count = self.get_user_session_count(user_id).await?;
Ok(count > 0)
}
/// Returns a reference to the underlying Redis pool.
pub fn pool(&self) -> &deadpool_redis::cluster::Pool {
&self.pool
}
/// Get online status for a user.
pub async fn get_user_status(&self, user_id: &Uuid) -> Result<crate::types::OnlineStatus, SessionStorageError> {
let sessions = self.get_user_sessions(user_id).await?;
if sessions.is_empty() {
return Ok(crate::types::OnlineStatus::Offline);
}
let now = Utc::now();
let idle_threshold = chrono::Duration::minutes(5);
if sessions.iter().any(|s| now - s.last_heartbeat < idle_threshold) {
Ok(crate::types::OnlineStatus::Online)
} else {
Ok(crate::types::OnlineStatus::Idle)
}
}
}
/// Render a session ID in the textual form stored as a member of the index sets.
fn session_id_str(id: &Uuid) -> String {
    ToString::to_string(id)
}