use std::{sync::Arc, time::Duration}; use redis::{ AsyncCommands, cluster::ClusterClient, cluster_async::ClusterConnection, }; use serde::{Serialize, de::DeserializeOwned}; use tokio::time::timeout; use crate::{CacheError, CacheResult}; const DEFAULT_COMMAND_TIMEOUT: Duration = Duration::from_secs(3); #[derive(Clone, Debug)] pub struct ClusterCacheConfig { pub urls: Vec, pub key_prefix: Option, pub command_timeout: Duration, } impl ClusterCacheConfig { pub fn new(urls: Vec) -> Self { Self { urls, key_prefix: None, command_timeout: DEFAULT_COMMAND_TIMEOUT, } } } #[derive(Clone)] pub struct ClusterCache { connection: ClusterConnection, key_prefix: Option>, command_timeout: Duration, } impl ClusterCache { pub async fn connect(config: ClusterCacheConfig) -> CacheResult { if config.urls.is_empty() { return Err(CacheError::Config( "redis cluster urls are empty".to_string(), )); } let client = ClusterClient::new(config.urls).map_err(CacheError::Redis)?; let connection = timeout(config.command_timeout, client.get_async_connection()) .await .map_err(|_| CacheError::Timeout("connect redis cluster"))? .map_err(CacheError::Redis)?; Ok(Self { connection, key_prefix: config.key_prefix.map(Arc::from), command_timeout: config.command_timeout, }) } pub async fn get(&self, key: &str) -> CacheResult> where T: DeserializeOwned, { let key = self.key(key); let mut connection = self.connection.clone(); let value: Option> = self .run(redis::cmd("GET").arg(&key).query_async(&mut connection)) .await?; match value { Some(value) => serde_json::from_slice(&value) .map(Some) .map_err(CacheError::Serialize), None => Ok(None), } } pub async fn get_json( &self, key: &str, ) -> CacheResult> { self.get(key).await } pub async fn set( &self, key: &str, value: &T, ttl: Option, ) -> CacheResult<()> where T: Serialize + ?Sized, { let key = self.key(key); let value = serde_json::to_vec(value).map_err(CacheError::Serialize)?; let mut connection = self.connection.clone(); if let Some(ttl) = ttl { let seconds = ttl.as_secs().max(1); self.run::<(), _>(connection.set_ex(key, value, seconds)) .await } else { self.run::<(), _>(connection.set(key, value)).await } } pub async fn remove(&self, key: &str) -> CacheResult { let key = self.key(key); let mut connection = self.connection.clone(); let removed: u64 = self.run(connection.del(key)).await?; Ok(removed > 0) } pub async fn exists(&self, key: &str) -> CacheResult { let key = self.key(key); let mut connection = self.connection.clone(); self.run(connection.exists(key)).await } pub async fn set_nx_with_ttl( &self, key: &str, value: &T, ttl: Duration, ) -> CacheResult where T: Serialize, { let key = self.key(key); let value = serde_json::to_vec(value).map_err(CacheError::Serialize)?; let mut connection = self.connection.clone(); let result: Option = self .run( redis::cmd("SET") .arg(&key) .arg(&value) .arg("NX") .arg("EX") .arg(ttl.as_secs().max(1)) .query_async(&mut connection), ) .await?; Ok(result.is_some()) } pub async fn expire(&self, key: &str, ttl: Duration) -> CacheResult { let key = self.key(key); let mut connection = self.connection.clone(); self.run(connection.expire(key, ttl.as_secs() as i64)).await } pub async fn delete_pattern(&self, pattern: &str) -> CacheResult { let pattern = self.key(pattern); let mut connection = self.connection.clone(); let keys: Vec = self .run( redis::cmd("KEYS") .arg(&pattern) .query_async(&mut connection), ) .await?; if keys.is_empty() { return Ok(0); } let mut connection = self.connection.clone(); self.run(connection.del(keys)).await } pub async fn ping(&self) -> CacheResult<()> { let mut connection = self.connection.clone(); let pong: String = self .run(redis::cmd("PING").query_async(&mut connection)) .await?; if pong == "PONG" { Ok(()) } else { Err(CacheError::Protocol(format!( "unexpected redis PING response: {pong}" ))) } } pub fn conn(&self) -> ClusterConnection { self.connection.clone() } fn key(&self, key: &str) -> String { match &self.key_prefix { Some(prefix) => format!("{prefix}:{key}"), None => key.to_string(), } } async fn run(&self, future: F) -> CacheResult where F: std::future::Future>, { timeout(self.command_timeout, future) .await .map_err(|_| CacheError::Timeout("redis command"))? .map_err(CacheError::Redis) } }