use std::{ collections::HashMap, sync::Arc, time::{Duration, Instant}, }; use tokio::{sync::RwLock, time::interval}; #[derive(Debug, Clone)] pub struct RateLimitConfig { pub requests_per_window: u32, pub window_duration: Duration, } impl Default for RateLimitConfig { fn default() -> Self { Self { requests_per_window: 100, window_duration: Duration::from_secs(60), } } } #[derive(Debug)] struct RateLimitState { count: u32, reset_time: Instant, } #[derive(Debug, Clone)] pub struct RateLimiter { limits: Arc>>, config: RateLimitConfig, } impl RateLimiter { pub fn new(config: RateLimitConfig) -> Self { Self { limits: Arc::new(RwLock::new(HashMap::new())), config, } } pub async fn is_allowed(&self, key: &str) -> bool { let now = Instant::now(); let mut limits = self.limits.write().await; let state = limits .entry(key.to_string()) .or_insert_with(|| RateLimitState { count: 0, reset_time: now + self.config.window_duration, }); if now >= state.reset_time { state.count = 0; state.reset_time = now + self.config.window_duration; } if state.count >= self.config.requests_per_window { return false; } state.count += 1; true } pub fn start_cleanup(self: Arc) -> tokio::task::JoinHandle<()> { tokio::spawn(async move { let mut ticker = interval(Duration::from_secs(300)); loop { ticker.tick().await; let now = Instant::now(); let mut limits = self.limits.write().await; limits.retain(|_, state| now < state.reset_time); } }) } } pub struct SshRateLimiter { limiter: RateLimiter, } impl SshRateLimiter { pub fn new() -> Self { Self { limiter: RateLimiter::new(RateLimitConfig::default()), } } pub async fn is_user_allowed(&self, user_id: &str) -> bool { self.limiter.is_allowed(&format!("user:{}", user_id)).await } pub async fn is_ip_allowed(&self, ip_address: &str) -> bool { self.limiter.is_allowed(&format!("ip:{}", ip_address)).await } pub async fn is_repo_access_allowed( &self, user_id: &str, repo_path: &str, ) -> bool { self.limiter .is_allowed(&format!("repo_access:{}:{}", user_id, repo_path)) .await } pub fn start_cleanup(self: Arc) -> tokio::task::JoinHandle<()> { RateLimiter::start_cleanup(Arc::new(self.limiter.clone())) } }