use std::time::Duration; use track::CounterVec; use uuid::Uuid; use crate::{ChannelError, ChannelResult}; const RATE_LIMIT_SCRIPT: &str = r#" local key = KEYS[1] local max = tonumber(ARGV[1]) local window = tonumber(ARGV[2]) local current = tonumber(redis.call('INCR', key)) if current == 1 then redis.call('EXPIRE', key, window) end if current > max then return 0 end return 1 "#; #[derive(Clone)] pub struct RateLimiter { cache: cache::AppCache, max_requests: u32, window: Duration, metrics: Option, } #[derive(Clone)] struct RateLimiterMetrics { outcomes: CounterVec, } impl RateLimiterMetrics { fn new(registry: &track::MetricsRegistry) -> Self { Self { outcomes: registry .register_counter_vec( "rate_limiter_decisions_total", "Rate limiter decisions", &["action", "outcome"], ) .expect("failed to register rate_limiter_decisions_total"), } } } impl RateLimiter { pub fn new(cache: cache::AppCache) -> Self { Self { cache, max_requests: 100, window: Duration::from_secs(60), metrics: None, } } pub fn with_config( cache: cache::AppCache, max_requests: u32, window: Duration, ) -> Self { Self { cache, max_requests, window, metrics: None, } } pub fn set_metrics(&mut self, registry: &track::MetricsRegistry) { self.metrics = Some(RateLimiterMetrics::new(registry)); } pub async fn check_rate_limit( &self, user_id: Uuid, action: &str, ) -> ChannelResult { let cluster = require_cluster(&self.cache)?; let key = format!("ratelimit:{}:{}", user_id, action); let mut conn = cluster.conn(); let allowed: i64 = redis::Cmd::new() .arg("EVAL") .arg(RATE_LIMIT_SCRIPT) .arg(1) .arg(&key) .arg(self.max_requests) .arg(self.window.as_secs()) .query_async(&mut conn) .await .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?; let is_allowed = allowed == 1; if let Some(m) = &self.metrics { let outcome = if is_allowed { "allowed" } else { "blocked" }; m.outcomes.with_label_values(&[action, outcome]).inc(); } Ok(is_allowed) } pub async fn get_remaining( &self, user_id: Uuid, action: &str, ) -> ChannelResult { let key = format!("ratelimit:{}:{}", user_id, action); let count: Option = self.cache.get(&key).await?; let current = count.unwrap_or(0); Ok(self.max_requests.saturating_sub(current)) } } const CSRF_TTL_SECS: u64 = 3600; #[derive(Clone)] pub struct CsrfProtection { cache: cache::AppCache, } impl CsrfProtection { pub fn new(cache: cache::AppCache) -> Self { Self { cache } } pub async fn generate_token(&self, user_id: Uuid) -> ChannelResult { let token = Uuid::new_v4().to_string(); let key = format!("csrf:{}:{}", user_id, token); let cluster = require_cluster(&self.cache)?; let mut conn = cluster.conn(); let _: () = redis::Cmd::new() .arg("SET") .arg(&key) .arg("1") .arg("EX") .arg(CSRF_TTL_SECS) .query_async(&mut conn) .await .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?; Ok(token) } pub async fn validate_token( &self, user_id: Uuid, token: &str, ) -> ChannelResult { let key = format!("csrf:{}:{}", user_id, token); let cluster = require_cluster(&self.cache)?; let mut conn = cluster.conn(); const VALIDATE_SCRIPT: &str = r#" local key = KEYS[1] local exists = redis.call('EXISTS', key) if exists == 1 then redis.call('DEL', key) return 1 end return 0 "#; let valid: i64 = redis::Cmd::new() .arg("EVAL") .arg(VALIDATE_SCRIPT) .arg(1) .arg(&key) .query_async(&mut conn) .await .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?; Ok(valid == 1) } } pub fn require_cluster( cache: &cache::AppCache, ) -> ChannelResult<&cache::ClusterCache> { cache .cluster .as_ref() .ok_or(ChannelError::Internal("no cluster cache".to_string())) }