use std::time::Duration; use uuid::Uuid; use crate::{ChannelError, ChannelResult}; const RATE_LIMIT_SCRIPT: &str = r#" local key = KEYS[1] local max = tonumber(ARGV[1]) local window = tonumber(ARGV[2]) local current = tonumber(redis.call('INCR', key)) if current == 1 then redis.call('EXPIRE', key, window) end if current > max then return 0 end return 1 "#; #[derive(Clone)] pub struct RateLimiter { cache: cache::AppCache, max_requests: u32, window: Duration, } impl RateLimiter { pub fn new(cache: cache::AppCache) -> Self { Self { cache, max_requests: 100, window: Duration::from_secs(60), } } pub fn with_config( cache: cache::AppCache, max_requests: u32, window: Duration, ) -> Self { Self { cache, max_requests, window, } } pub async fn check_rate_limit( &self, user_id: Uuid, action: &str, ) -> ChannelResult { let cluster = require_cluster(&self.cache)?; let key = format!("ratelimit:{}:{}", user_id, action); let mut conn = cluster.conn(); let allowed: i64 = redis::Cmd::new() .arg("EVAL") .arg(RATE_LIMIT_SCRIPT) .arg(1) .arg(&key) .arg(self.max_requests) .arg(self.window.as_secs()) .query_async(&mut conn) .await .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?; Ok(allowed == 1) } pub async fn get_remaining( &self, user_id: Uuid, action: &str, ) -> ChannelResult { let key = format!("ratelimit:{}:{}", user_id, action); let count: Option = self.cache.get(&key).await?; let current = count.unwrap_or(0); Ok(self.max_requests.saturating_sub(current)) } } const CSRF_TTL_SECS: u64 = 3600; #[derive(Clone)] pub struct CsrfProtection { cache: cache::AppCache, } impl CsrfProtection { pub fn new(cache: cache::AppCache) -> Self { Self { cache } } pub async fn generate_token(&self, user_id: Uuid) -> ChannelResult { let token = Uuid::new_v4().to_string(); let key = format!("csrf:{}:{}", user_id, token); let cluster = require_cluster(&self.cache)?; let mut conn = cluster.conn(); let _: () = redis::Cmd::new() .arg("SET") .arg(&key) .arg("1") .arg("EX") .arg(CSRF_TTL_SECS) .query_async(&mut conn) .await .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?; Ok(token) } pub async fn validate_token( &self, user_id: Uuid, token: &str, ) -> ChannelResult { let key = format!("csrf:{}:{}", user_id, token); let cluster = require_cluster(&self.cache)?; let mut conn = cluster.conn(); const VALIDATE_SCRIPT: &str = r#" local key = KEYS[1] local exists = redis.call('EXISTS', key) if exists == 1 then redis.call('DEL', key) return 1 end return 0 "#; let valid: i64 = redis::Cmd::new() .arg("EVAL") .arg(VALIDATE_SCRIPT) .arg(1) .arg(&key) .query_async(&mut conn) .await .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?; Ok(valid == 1) } } pub fn require_cluster( cache: &cache::AppCache, ) -> ChannelResult<&cache::ClusterCache> { cache .cluster .as_ref() .ok_or(ChannelError::Internal("no cluster cache".to_string())) }