use std::collections::HashMap; use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::RwLock; use tokio::time::interval; #[derive(Debug, Clone)] pub struct RateLimitConfig { pub requests_per_window: u32, pub window_duration: Duration, } impl Default for RateLimitConfig { fn default() -> Self { Self { requests_per_window: 100, window_duration: Duration::from_secs(60), } } } #[derive(Debug)] struct RateLimitState { count: u32, reset_time: Instant, } pub struct RateLimiter { limits: Arc>>, config: RateLimitConfig, } impl RateLimiter { pub fn new(config: RateLimitConfig) -> Self { Self { limits: Arc::new(RwLock::new(HashMap::new())), config, } } pub async fn is_allowed(&self, key: &str) -> bool { let now = Instant::now(); let mut limits = self.limits.write().await; let state = limits .entry(key.to_string()) .or_insert_with(|| RateLimitState { count: 0, reset_time: now + self.config.window_duration, }); if now >= state.reset_time { state.count = 0; state.reset_time = now + self.config.window_duration; } if state.count >= self.config.requests_per_window { return false; } state.count += 1; true } pub async fn remaining_requests(&self, key: &str) -> u32 { let now = Instant::now(); let limits = self.limits.read().await; if let Some(state) = limits.get(key) { if now >= state.reset_time { self.config.requests_per_window } else { self.config.requests_per_window.saturating_sub(state.count) } } else { self.config.requests_per_window } } pub async fn reset_time(&self, key: &str) -> Duration { let now = Instant::now(); let limits = self.limits.read().await; if let Some(state) = limits.get(key) { if now >= state.reset_time { Duration::from_secs(0) } else { state.reset_time.duration_since(now) } } else { Duration::from_secs(0) } } /// Start a background cleanup task that removes expired entries every 5 minutes. /// This prevents unbounded HashMap growth. pub fn start_cleanup(self: Arc) -> tokio::task::JoinHandle<()> { tokio::spawn(async move { let mut ticker = interval(Duration::from_secs(300)); // every 5 minutes loop { ticker.tick().await; let now = Instant::now(); let mut limits = self.limits.write().await; limits.retain(|_, state| now < state.reset_time); } }) } } pub struct SshRateLimiter { limiter: RateLimiter, } impl SshRateLimiter { pub fn new() -> Self { Self { limiter: RateLimiter::new(RateLimitConfig::default()), } } pub async fn is_user_allowed(&self, user_id: &str) -> bool { self.limiter.is_allowed(&format!("user:{}", user_id)).await } pub async fn is_ip_allowed(&self, ip_address: &str) -> bool { self.limiter.is_allowed(&format!("ip:{}", ip_address)).await } pub async fn is_repo_access_allowed(&self, user_id: &str, repo_path: &str) -> bool { self.limiter .is_allowed(&format!("repo_access:{}:{}", user_id, repo_path)) .await } }