use anyhow::Context;
use chrono::Utc;
use deadpool_redis::cluster::Pool;
use redis::AsyncCommands;
use serde_json;
use thiserror::Error;
use uuid::Uuid;

use crate::types::UserSession;

/// Errors returned by [`SessionStorage`] operations.
#[derive(Error, Debug)]
pub enum SessionStorageError {
    /// A pool, Redis command, or (de)serialization failure, carried as an `anyhow::Error`.
    #[error("Redis error: {0}")]
    Redis(#[from] anyhow::Error),
    /// No session record exists for the given session id.
    #[error("session not found: {0}")]
    NotFound(Uuid),
}

/// Key prefix for the serialized session record of a single session id.
const KEY_CONN: &str = "user:conn:";
/// Key prefix for the per-user set of session ids.
const KEY_USER_SESSIONS: &str = "user:user_sessions:";
/// Key prefix for the per-workspace set of session ids.
const KEY_WORKSPACE_SESSIONS: &str = "user:workspace_sessions:";

/// Redis-backed storage for user sessions.
///
/// Each session is stored as a JSON string under its own key with a TTL, and its id is
/// additionally recorded in one set per user and one set per workspace.
#[derive(Clone)]
pub struct SessionStorage {
    pool: Pool,
    /// TTL, in seconds, applied to a session record when it is saved or heartbeated.
    heartbeat_ttl_secs: u64,
}

impl SessionStorage {
    /// Creates a storage handle over `pool` with a 120-second heartbeat TTL.
    pub fn new(pool: Pool) -> Self {
        Self {
            pool,
            heartbeat_ttl_secs: 120,
        }
    }

    /// Returns `self` with the heartbeat TTL replaced by `ttl_secs` (builder style).
    pub fn with_heartbeat_ttl(mut self, ttl_secs: u64) -> Self {
        self.heartbeat_ttl_secs = ttl_secs;
        self
    }

    /// Checks a connection out of the pool, mapping pool failures to
    /// [`SessionStorageError::Redis`].
    async fn get_conn(&self) -> Result<deadpool_redis::cluster::Connection, SessionStorageError> {
        self.pool
            .get()
            .await
            .context("failed to get Redis connection from pool")
            .map_err(SessionStorageError::Redis)
    }

    /// Key holding the serialized record for `session_id`.
    fn conn_key(session_id: &Uuid) -> String {
        format!("{KEY_CONN}{session_id}")
    }

    /// Key of the set listing the session ids that belong to `user_id`.
    fn user_sessions_key(user_id: &Uuid) -> String {
        format!("{KEY_USER_SESSIONS}{user_id}")
    }

    /// Key of the set listing the session ids that belong to `workspace_id`.
    fn workspace_sessions_key(workspace_id: &Uuid) -> String {
        format!("{KEY_WORKSPACE_SESSIONS}{workspace_id}")
    }

    /// Wraps an arbitrary error (in practice a Redis command error) in
    /// [`SessionStorageError::Redis`]; intended for use with `map_err`.
    fn to_err<E>(e: E) -> SessionStorageError
    where
        E: std::error::Error + Send + Sync + 'static,
    {
        SessionStorageError::Redis(anyhow::anyhow!(e))
    }

    /// Store a new user session and associate it with user + workspace indexes.
///
    /// The serialized session is written under its record key with the heartbeat TTL, and
    /// the session id is added to the user's and the workspace's index sets.
    ///
    /// # Errors
    ///
    /// Returns [`SessionStorageError::Redis`] if no connection is available, the session
    /// cannot be serialized, or a Redis command fails.
    pub async fn save_session(&self, session: &UserSession) -> Result<(), SessionStorageError> {
        let mut conn = self.get_conn().await?;
        let key = Self::conn_key(&session.session_id);
        let user_key = Self::user_sessions_key(&session.user_id);
        let ws_key = Self::workspace_sessions_key(&session.workspace_id);
        let id_str = session_id_str(&session.session_id);
        let value = serde_json::to_string(session)
            .context("serialize UserSession")
            .map_err(SessionStorageError::Redis)?;

        // The index sets intentionally carry no TTL. The previous `EXPIRE <index> 0` calls
        // did not mean "never expire": Redis deletes a key when given a non-positive
        // timeout, so each save wiped the index it had just populated.
        let _: () = redis::pipe()
            .set_ex(&key, &value, self.heartbeat_ttl_secs)
            .sadd(&user_key, &id_str)
            .sadd(&ws_key, &id_str)
            .query_async(&mut conn)
            .await
            .map_err(Self::to_err)?;
        Ok(())
    }

    /// Get a session by its ID.
    ///
    /// Returns `Ok(None)` when no record exists under the session's key.
    ///
    /// # Errors
    ///
    /// Returns [`SessionStorageError::Redis`] if no connection is available, the `GET`
    /// fails, or the stored value cannot be deserialized into a `UserSession`.
    pub async fn get_session(
        &self,
        session_id: &Uuid,
    ) -> Result<Option<UserSession>, SessionStorageError> {
        let mut conn = self.get_conn().await?;
        let key = Self::conn_key(session_id);
        let value: Option<String> = conn.get(&key).await.map_err(Self::to_err)?;

        value
            .map(|raw| {
                serde_json::from_str::<UserSession>(&raw)
                    .context("deserialize UserSession")
                    .map_err(SessionStorageError::Redis)
            })
            .transpose()
    }

    /// Update the heartbeat timestamp and refresh TTL.
    ///
    /// The read-modify-write runs inside a Lua script: it loads the record, sets its
    /// `last_heartbeat` field to the current UTC time (RFC 3339 string), and writes it back
    /// with the heartbeat TTL.
    ///
    /// # Errors
    ///
    /// Returns [`SessionStorageError::NotFound`] if the script finds no record for
    /// `session_id`, and [`SessionStorageError::Redis`] if no connection is available or the
    /// script fails.
    pub async fn heartbeat(&self, session_id: &Uuid) -> Result<(), SessionStorageError> {
        let mut conn = self.get_conn().await?;
        let key = Self::conn_key(session_id);
        let ttl = self.heartbeat_ttl_secs;
        let updated = Utc::now();

        // KEYS[1] = session record key, ARGV[1] = new timestamp, ARGV[2] = TTL in seconds.
        // NOTE(review): the record is re-encoded by cjson, so this assumes `UserSession`
        // survives a cjson decode/encode round trip unchanged — confirm for the real type.
        let script = redis::Script::new(
            r#"
            local v = redis.call('GET', KEYS[1])
            if not v then return 0 end
            local session = cjson.decode(v)
            session.last_heartbeat = ARGV[1]
            redis.call('SETEX', KEYS[1], ARGV[2], cjson.encode(session))
            return 1
            "#,
        );

        let result: i64 = script
            .key(&key)
            .arg(updated.to_rfc3339())
            .arg(ttl)
            .invoke_async(&mut conn)
            .await
            .map_err(Self::to_err)?;

        // The script returns 0 only when the GET found nothing.
        if result == 0 {
            return Err(SessionStorageError::NotFound(*session_id));
        }
        Ok(())
    }

    /// Delete a session by ID and clean up indexes.
///
    /// The record is read first because it is the only place the owning user and workspace
    /// are recorded. If the record no longer exists, the `DEL` is still issued but the index
    /// sets are left untouched, since the ids needed to locate them are unknown.
    ///
    /// # Errors
    ///
    /// Returns [`SessionStorageError::Redis`] if no connection is available, the stored
    /// record cannot be deserialized, or a Redis command fails.
    pub async fn delete_session(&self, session_id: &Uuid) -> Result<(), SessionStorageError> {
        let session = self.get_session(session_id).await?;

        // One connection serves every write below instead of a fresh checkout per command.
        let mut conn = self.get_conn().await?;
        let key = Self::conn_key(session_id);
        let _: () = conn.del(&key).await.map_err(Self::to_err)?;

        if let Some(s) = session {
            let user_key = Self::user_sessions_key(&s.user_id);
            let ws_key = Self::workspace_sessions_key(&s.workspace_id);
            let id_str = session_id_str(session_id);
            let _: () = conn
                .srem::<_, _, ()>(&user_key, &id_str)
                .await
                .map_err(Self::to_err)?;
            let _: () = conn
                .srem::<_, _, ()>(&ws_key, &id_str)
                .await
                .map_err(Self::to_err)?;
        }
        Ok(())
    }

    /// Delete all sessions for a specific user.
    ///
    /// Every id in the user's index set that parses as a UUID has its record deleted and is
    /// removed from the workspace index named in that record; the user's index set is then
    /// deleted as a whole. Ids that do not parse as UUIDs are skipped.
    ///
    /// Returns the ids whose record keys were deleted (whether or not a record still
    /// existed under them).
    ///
    /// # Errors
    ///
    /// Returns [`SessionStorageError::Redis`] if no connection is available or a Redis
    /// command fails.
    pub async fn delete_user_sessions(
        &self,
        user_id: &Uuid,
    ) -> Result<Vec<Uuid>, SessionStorageError> {
        let mut conn = self.get_conn().await?;
        let user_key = Self::user_sessions_key(user_id);
        let session_ids: Vec<String> = conn.smembers(&user_key).await.map_err(Self::to_err)?;

        let mut deleted = Vec::with_capacity(session_ids.len());
        for id_str in &session_ids {
            if let Ok(sid) = Uuid::parse_str(id_str) {
                let conn_key = Self::conn_key(&sid);

                // Read the record before deleting it: it is the only source of the
                // workspace id needed to drop this session from the workspace index.
                let raw: Option<String> = conn.get(&conn_key).await.map_err(Self::to_err)?;
                let _: () = conn.del(&conn_key).await.map_err(Self::to_err)?;

                // Workspace cleanup is best-effort: an expired or undecodable record
                // must not prevent the rest of the user's sessions from being removed.
                let stored = raw.and_then(|v| serde_json::from_str::<UserSession>(&v).ok());
                if let Some(session) = stored {
                    let ws_key = Self::workspace_sessions_key(&session.workspace_id);
                    let _: () = conn
                        .srem::<_, _, ()>(&ws_key, id_str)
                        .await
                        .map_err(Self::to_err)?;
                }

                deleted.push(sid);
            }
        }

        let _: () = conn.del(&user_key).await.map_err(Self::to_err)?;
        Ok(deleted)
    }

    /// Delete all sessions for a user within a specific workspace.
///
    /// Walks the workspace's index set, loads each session record, and for every record
    /// owned by `user_id` deletes the record and removes its id from both the user index
    /// and the workspace index. Ids that do not parse as UUIDs, and ids whose record no
    /// longer exists, are skipped.
    ///
    /// Returns the ids of the sessions that were deleted.
    ///
    /// # Errors
    ///
    /// Returns [`SessionStorageError::Redis`] if no connection is available, a stored
    /// record cannot be deserialized, or a Redis command fails.
    pub async fn delete_user_workspace_sessions(
        &self,
        user_id: &Uuid,
        workspace_id: &Uuid,
    ) -> Result<Vec<Uuid>, SessionStorageError> {
        let mut conn = self.get_conn().await?;
        let ws_key = Self::workspace_sessions_key(workspace_id);
        let user_key = Self::user_sessions_key(user_id);
        let ws_session_ids: Vec<String> = conn.smembers(&ws_key).await.map_err(Self::to_err)?;

        let mut deleted = Vec::new();
        for id_str in &ws_session_ids {
            let sid = match Uuid::parse_str(id_str) {
                Ok(sid) => sid,
                Err(_) => continue,
            };
            let conn_key = Self::conn_key(&sid);

            // Read on the connection already held rather than via `get_session`, which
            // would check a second connection out of the pool on every iteration.
            let raw: Option<String> = conn.get(&conn_key).await.map_err(Self::to_err)?;
            let session = match raw {
                Some(v) => serde_json::from_str::<UserSession>(&v)
                    .context("deserialize UserSession")
                    .map_err(SessionStorageError::Redis)?,
                None => continue,
            };
            if session.user_id != *user_id {
                continue;
            }

            let _: () = conn.del(&conn_key).await.map_err(Self::to_err)?;
            let _: () = conn
                .srem::<_, _, ()>(&user_key, id_str)
                .await
                .map_err(Self::to_err)?;
            // Also drop the id from the workspace index being walked, so it does not
            // linger there after its record is gone.
            let _: () = conn
                .srem::<_, _, ()>(&ws_key, id_str)
                .await
                .map_err(Self::to_err)?;
            deleted.push(sid);
        }
        Ok(deleted)
    }

    /// Loads every session whose id is listed in the set stored at `index_key`.
    ///
    /// Ids with no remaining record and records that fail to deserialize are silently
    /// omitted from the result.
    async fn load_indexed_sessions(
        &self,
        index_key: &str,
    ) -> Result<Vec<UserSession>, SessionStorageError> {
        let mut conn = self.get_conn().await?;
        let session_ids: Vec<String> = conn.smembers(index_key).await.map_err(Self::to_err)?;
        if session_ids.is_empty() {
            return Ok(Vec::new());
        }

        // Batch fetch all sessions in a single MGET instead of N individual GET calls.
        // NOTE(review): with a cluster pool these keys can fall in different hash slots —
        // confirm the client in use splits multi-key commands across shards.
        let keys: Vec<String> = session_ids
            .iter()
            .map(|id| format!("{KEY_CONN}{id}"))
            .collect();
        let values: Vec<Option<String>> = redis::cmd("MGET")
            .arg(&keys)
            .query_async(&mut conn)
            .await
            .map_err(Self::to_err)?;

        Ok(values
            .into_iter()
            .flatten()
            .filter_map(|v| serde_json::from_str::<UserSession>(&v).ok())
            .collect())
    }

    /// Get all active sessions for a user.
    ///
    /// # Errors
    ///
    /// Returns [`SessionStorageError::Redis`] if no connection is available or a Redis
    /// command fails.
    pub async fn get_user_sessions(
        &self,
        user_id: &Uuid,
    ) -> Result<Vec<UserSession>, SessionStorageError> {
        self.load_indexed_sessions(&Self::user_sessions_key(user_id))
            .await
    }

    /// Get all active sessions in a workspace.
    ///
    /// # Errors
    ///
    /// Returns [`SessionStorageError::Redis`] if no connection is available or a Redis
    /// command fails.
    pub async fn get_workspace_sessions(
        &self,
        workspace_id: &Uuid,
    ) -> Result<Vec<UserSession>, SessionStorageError> {
        self.load_indexed_sessions(&Self::workspace_sessions_key(workspace_id))
            .await
    }

    /// Get distinct user IDs active in a workspace.
    ///
    /// Each user id appears once, in the order its first session was returned by
    /// [`Self::get_workspace_sessions`].
    pub async fn get_workspace_online_users(
        &self,
        workspace_id: &Uuid,
    ) -> Result<Vec<Uuid>, SessionStorageError> {
        let sessions = self.get_workspace_sessions(workspace_id).await?;
        let mut seen = std::collections::HashSet::new();
        Ok(sessions
            .into_iter()
            .map(|s| s.user_id)
            // `insert` returns false for an id already seen, which filters duplicates out.
            .filter(|id| seen.insert(*id))
            .collect())
    }

    /// Get the count of online sessions for a user.
    ///
    /// Counts the sessions returned by [`Self::get_user_sessions`], i.e. only ids whose
    /// record still exists and deserializes.
    pub async fn get_user_session_count(
        &self,
        user_id: &Uuid,
    ) -> Result<usize, SessionStorageError> {
        let sessions = self.get_user_sessions(user_id).await?;
        Ok(sessions.len())
    }

    /// Check if a user has any active sessions (online status).
    pub async fn is_user_online(&self, user_id: &Uuid) -> Result<bool, SessionStorageError> {
        let count = self.get_user_session_count(user_id).await?;
        Ok(count > 0)
    }

    /// Returns a reference to the underlying Redis pool.
    pub fn pool(&self) -> &deadpool_redis::cluster::Pool {
        &self.pool
    }

    /// Get online status for a user.
    ///
    /// * `Offline` — the user has no loadable sessions.
    /// * `Online` — at least one session's `last_heartbeat` is less than five minutes old.
    /// * `Idle` — sessions exist, but none has a heartbeat within the last five minutes.
    pub async fn get_user_status(
        &self,
        user_id: &Uuid,
    ) -> Result<crate::types::OnlineStatus, SessionStorageError> {
        let sessions = self.get_user_sessions(user_id).await?;
        if sessions.is_empty() {
            return Ok(crate::types::OnlineStatus::Offline);
        }

        let now = Utc::now();
        let idle_threshold = chrono::Duration::minutes(5);
        if sessions
            .iter()
            .any(|s| now - s.last_heartbeat < idle_threshold)
        {
            Ok(crate::types::OnlineStatus::Online)
        } else {
            Ok(crate::types::OnlineStatus::Idle)
        }
    }
}

/// Renders a session id in the string form stored as a member of the index sets.
fn session_id_str(id: &Uuid) -> String {
    id.to_string()
}