use redis::AsyncCommands; use std::time::Duration; use uuid::Uuid; #[derive(Clone)] pub struct RateLimiter { cache: db::cache::AppCache, max_requests: u32, window: Duration, } impl RateLimiter { pub fn new(cache: db::cache::AppCache) -> Self { Self { cache, max_requests: 100, window: Duration::from_secs(60), } } pub async fn check_rate_limit( &self, user_id: Uuid, action: &str, ) -> Result { let key = format!("ratelimit:{}:{}", user_id, action); let mut conn = self .cache .conn() .await .map_err(|_| crate::error::AppTransportError::Internal)?; // Atomic INCR with EX NX — sets TTL only on first creation let count: u32 = redis::Cmd::new() .arg("INCR") .arg(&key) .query_async(&mut conn) .await .map_err(|_| crate::error::AppTransportError::Internal)?; // Set expiry only when the key is newly created (count == 1) if count == 1 { let _: () = redis::Cmd::new() .arg("EXPIRE") .arg(&key) .arg(self.window.as_secs()) .query_async(&mut conn) .await .map_err(|_| crate::error::AppTransportError::Internal)?; } Ok(count <= self.max_requests) } pub async fn get_remaining( &self, user_id: Uuid, action: &str, ) -> Result { let key = format!("ratelimit:{}:{}", user_id, action); let mut conn = self .cache .conn() .await .map_err(|_| crate::error::AppTransportError::Internal)?; let count: Option = conn .get(&key) .await .map_err(|_| crate::error::AppTransportError::Internal)?; let current = count.unwrap_or(0); Ok(self.max_requests.saturating_sub(current)) } } #[derive(Clone)] pub struct CsrfProtection { cache: db::cache::AppCache, } impl CsrfProtection { pub fn new(cache: db::cache::AppCache) -> Self { Self { cache } } pub async fn generate_token( &self, user_id: Uuid, ) -> Result { let token = Uuid::new_v4().to_string(); let key = format!("csrf:{}:{}", user_id, token); let mut conn = self .cache .conn() .await .map_err(|_| crate::error::AppTransportError::Internal)?; let _: () = conn .set_ex(&key, "1", 3600) .await .map_err(|_| crate::error::AppTransportError::Internal)?; Ok(token) } pub async fn validate_token( &self, user_id: Uuid, token: &str, ) -> Result { let key = format!("csrf:{}:{}", user_id, token); let mut conn = self .cache .conn() .await .map_err(|_| crate::error::AppTransportError::Internal)?; let exists: bool = conn .exists(&key) .await .map_err(|_| crate::error::AppTransportError::Internal)?; if exists { let _: () = conn .del(&key) .await .map_err(|_| crate::error::AppTransportError::Internal)?; } Ok(exists) } }