use std::{sync::Arc, time::Duration as StdDuration}; use actix_web::cookie::time::Duration; use redis::{ cluster::{ClusterClient, ClusterClientBuilder}, cluster_async::ClusterConnection, }; use tokio::sync::Mutex; use super::SessionKey; use crate::storage::{ SessionStore, format::{deserialize_session_state, serialize_session_state}, interface::{LoadError, SaveError, SessionState, UpdateError}, utils::generate_session_key, }; #[derive(Clone)] pub struct RedisClusterSessionStore { client: ClusterClient, connection: Arc>, configuration: CacheConfiguration, } #[derive(Clone)] struct CacheConfiguration { cache_keygen: Arc String + Send + Sync>, } impl Default for CacheConfiguration { fn default() -> Self { Self { cache_keygen: Arc::new(str::to_owned), } } } impl RedisClusterSessionStore { const DEFAULT_CONNECTION_TIMEOUT: StdDuration = StdDuration::from_secs(2); const DEFAULT_RESPONSE_TIMEOUT: StdDuration = StdDuration::from_secs(2); const DEFAULT_COMMAND_TIMEOUT: StdDuration = StdDuration::from_secs(3); const DEFAULT_RETRIES: u32 = 1; const DEFAULT_RETRY_MIN_WAIT_MS: u64 = 25; const DEFAULT_RETRY_MAX_WAIT_MS: u64 = 100; const DEFAULT_RETRY_FACTOR: u64 = 10; const DEFAULT_RETRY_EXPONENT_BASE: u64 = 2; pub fn builder( connection_strings: Vec, ) -> RedisClusterSessionStoreBuilder { RedisClusterSessionStoreBuilder { configuration: CacheConfiguration::default(), connection_strings, } } pub async fn new( connection_strings: Vec, ) -> anyhow::Result { Self::builder(connection_strings).build().await } fn client_builder(connection_strings: Vec) -> ClusterClientBuilder { ClusterClient::builder(connection_strings) .connection_timeout(Self::DEFAULT_CONNECTION_TIMEOUT) .response_timeout(Self::DEFAULT_RESPONSE_TIMEOUT) .retries(Self::DEFAULT_RETRIES) .min_retry_wait(Self::DEFAULT_RETRY_MIN_WAIT_MS) .max_retry_wait(Self::DEFAULT_RETRY_MAX_WAIT_MS) .retry_wait_formula( Self::DEFAULT_RETRY_FACTOR, Self::DEFAULT_RETRY_EXPONENT_BASE, ) } async fn connect( client: &ClusterClient, ) -> anyhow::Result { let started = std::time::Instant::now(); let connection = tokio::time::timeout( Self::DEFAULT_COMMAND_TIMEOUT, client.get_async_connection(), ) .await .map_err(|_| anyhow::anyhow!("session redis async connect timed out"))? .map_err(|e| anyhow::anyhow!(e))?; tracing::debug!( elapsed_ms = started.elapsed().as_millis() as u64, "session redis async connect finished" ); Ok(connection) } async fn execute_cmd( &self, op_name: &'static str, make_cmd: F, ) -> anyhow::Result where T: redis::FromRedisValue, F: Fn() -> redis::Cmd, { let first_try: anyhow::Result = { let mut connection = self.connection.lock().await; let started = std::time::Instant::now(); tracing::debug!(op = op_name, "session redis command start"); match tokio::time::timeout( Self::DEFAULT_COMMAND_TIMEOUT, make_cmd().query_async(&mut *connection), ) .await .map_err(|_| anyhow::anyhow!("session redis command timed out"))? { Ok(value) => { tracing::debug!( op = op_name, elapsed_ms = started.elapsed().as_millis() as u64, "session redis command finished" ); return Ok(value); } Err(err) => Err(anyhow::anyhow!(err)), } }; if let Err(error) = &first_try { tracing::warn!(op = op_name, error = %error, "session redis command failed, reconnecting"); } let new_connection = Self::connect(&self.client).await?; { let mut connection = self.connection.lock().await; *connection = new_connection; } let mut connection = self.connection.lock().await; let started = std::time::Instant::now(); tracing::debug!(op = op_name, "session redis command retry start"); let result = tokio::time::timeout( Self::DEFAULT_COMMAND_TIMEOUT, make_cmd().query_async(&mut *connection), ) .await .map_err(|_| anyhow::anyhow!("session redis command retry timed out"))? .map_err(|e| anyhow::anyhow!(e))?; tracing::debug!( op = op_name, elapsed_ms = started.elapsed().as_millis() as u64, "session redis command retry finished" ); Ok(result) } fn ttl_seconds(ttl: &Duration) -> anyhow::Result { let ttl_secs = ttl.whole_seconds(); if ttl_secs <= 0 { anyhow::bail!("session TTL must be positive"); } u64::try_from(ttl_secs).map_err(anyhow::Error::new) } } #[must_use] pub struct RedisClusterSessionStoreBuilder { configuration: CacheConfiguration, connection_strings: Vec, } impl RedisClusterSessionStoreBuilder { pub fn cache_keygen(mut self, keygen: F) -> Self where F: Fn(&str) -> String + 'static + Send + Sync, { self.configuration.cache_keygen = Arc::new(keygen); self } pub async fn build(self) -> anyhow::Result { let client = RedisClusterSessionStore::client_builder(self.connection_strings) .build()?; let connection = RedisClusterSessionStore::connect(&client).await?; Ok(RedisClusterSessionStore { client, connection: Arc::new(Mutex::new(connection)), configuration: self.configuration, }) } } impl SessionStore for RedisClusterSessionStore { async fn load( &self, session_key: &SessionKey, ) -> Result, LoadError> { let cache_key = self.configuration.cache_keygen.as_ref()(session_key.as_ref()); let value: Option = self .execute_cmd("get", move || { let mut cmd = redis::cmd("GET"); cmd.arg(&cache_key); cmd }) .await .map_err(LoadError::Other)?; match value { None => Ok(None), Some(value) => Ok(Some( deserialize_session_state(&value) .map_err(LoadError::Deserialization)?, )), } } async fn save( &self, session_state: SessionState, ttl: &Duration, ) -> Result { let body = serialize_session_state(&session_state) .map_err(SaveError::Serialization)?; let session_key = generate_session_key(); let cache_key = self.configuration.cache_keygen.as_ref()(session_key.as_ref()); let ttl_secs = Self::ttl_seconds(ttl).map_err(SaveError::Other)?; self.execute_cmd::<(), _>("set_ex", move || { let mut cmd = redis::cmd("SETEX"); cmd.arg(&cache_key).arg(ttl_secs).arg(&body); cmd }) .await .map_err(SaveError::Other)?; Ok(session_key) } async fn update( &self, session_key: SessionKey, session_state: SessionState, ttl: &Duration, ) -> Result { let body = serialize_session_state(&session_state) .map_err(UpdateError::Serialization)?; let cache_key = self.configuration.cache_keygen.as_ref()(session_key.as_ref()); let ttl_secs = Self::ttl_seconds(ttl).map_err(UpdateError::Other)?; self.execute_cmd::<(), _>("set_ex", move || { let mut cmd = redis::cmd("SETEX"); cmd.arg(&cache_key).arg(ttl_secs).arg(&body); cmd }) .await .map_err(UpdateError::Other)?; Ok(session_key) } async fn update_ttl( &self, session_key: &SessionKey, ttl: &Duration, ) -> anyhow::Result<()> { let cache_key = self.configuration.cache_keygen.as_ref()(session_key.as_ref()); let ttl_secs = Self::ttl_seconds(ttl)?; self.execute_cmd("expire", move || { let mut cmd = redis::cmd("EXPIRE"); cmd.arg(&cache_key).arg(ttl_secs); cmd }) .await .map(|_: bool| ()) } async fn delete( &self, session_key: &SessionKey, ) -> Result<(), anyhow::Error> { let cache_key = self.configuration.cache_keygen.as_ref()(session_key.as_ref()); self.execute_cmd("del", move || { let mut cmd = redis::cmd("DEL"); cmd.arg(&cache_key); cmd }) .await .map(|_: i64| ()) } }