gitdataai/lib/channel/security.rs

153 lines
3.6 KiB
Rust

use std::time::Duration;
use uuid::Uuid;
use crate::{ChannelError, ChannelResult};
/// Lua script sent to Redis via `EVAL` to count one request against a limit.
///
/// `KEYS[1]` is the counter key, `ARGV[1]` the maximum allowed count and
/// `ARGV[2]` the timeout handed to `EXPIRE`. The script increments the
/// counter, sets the expiry only when the increment produced `1` (i.e. the
/// key was just created), and returns `0` when the counter is above the
/// maximum, otherwise `1`.
const RATE_LIMIT_SCRIPT: &str = r#"
local key = KEYS[1]
local max = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local current = tonumber(redis.call('INCR', key))
if current == 1 then
redis.call('EXPIRE', key, window)
end
if current > max then
return 0
end
return 1
"#;
/// Per-user, per-action request limiter backed by the application cache.
///
/// Stores the cache handle together with the request ceiling and window
/// length that `check_rate_limit` forwards to the Redis counter script.
#[derive(Clone)]
pub struct RateLimiter {
    /// Cache handle; `check_rate_limit` requires its cluster connection.
    cache: cache::AppCache,
    /// Highest counter value that is still allowed within one window.
    max_requests: u32,
    /// Length of the counting window.
    window: Duration,
}
impl RateLimiter {
    /// Request ceiling applied by [`RateLimiter::new`].
    const DEFAULT_MAX_REQUESTS: u32 = 100;
    /// Window length applied by [`RateLimiter::new`].
    const DEFAULT_WINDOW: Duration = Duration::from_secs(60);

    /// Creates a limiter that allows 100 requests per 60-second window.
    pub fn new(cache: cache::AppCache) -> Self {
        Self::with_config(cache, Self::DEFAULT_MAX_REQUESTS, Self::DEFAULT_WINDOW)
    }

    /// Creates a limiter with a caller-supplied request ceiling and window.
    ///
    /// The window is applied with whole-second granularity: fractional
    /// seconds are truncated, and anything shorter than one second is treated
    /// as one second (see [`RateLimiter::check_rate_limit`]).
    pub fn with_config(
        cache: cache::AppCache,
        max_requests: u32,
        window: Duration,
    ) -> Self {
        Self {
            cache,
            max_requests,
            window,
        }
    }

    /// Returns the timeout, in whole seconds, handed to `EXPIRE`.
    ///
    /// Never returns less than one: Redis deletes a key when `EXPIRE` is
    /// given a non-positive timeout, which would wipe the counter right after
    /// each increment and silently disable the limit for sub-second windows.
    fn window_secs(&self) -> u64 {
        self.window.as_secs().max(1)
    }

    /// Counts one request for `user_id`/`action` and reports whether it is
    /// still within the limit.
    ///
    /// The counter lives under the key `ratelimit:{user_id}:{action}` and is
    /// updated by evaluating `RATE_LIMIT_SCRIPT` on the cluster connection.
    ///
    /// # Returns
    /// `Ok(true)` when the script returns `1` (allowed), `Ok(false)` otherwise.
    ///
    /// # Errors
    /// Propagates the error from `require_cluster` when no cluster cache is
    /// available, and returns `ChannelError::Cache(CacheError::Redis(_))` when
    /// the Redis command fails.
    pub async fn check_rate_limit(
        &self,
        user_id: Uuid,
        action: &str,
    ) -> ChannelResult<bool> {
        let cluster = require_cluster(&self.cache)?;
        let key = format!("ratelimit:{}:{}", user_id, action);
        let window_secs = self.window_secs();
        let mut conn = cluster.conn();
        let allowed: i64 = redis::Cmd::new()
            .arg("EVAL")
            .arg(RATE_LIMIT_SCRIPT)
            .arg(1) // number of KEYS passed to the script
            .arg(&key)
            .arg(self.max_requests)
            .arg(window_secs)
            .query_async(&mut conn)
            .await
            .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?;
        Ok(allowed == 1)
    }

    /// Returns how many requests remain for `user_id`/`action`.
    ///
    /// Reads the value stored under `ratelimit:{user_id}:{action}`, treats a
    /// missing value as zero, and subtracts it from the ceiling without
    /// underflowing.
    ///
    /// # Errors
    /// Propagates any error returned by the cache lookup.
    pub async fn get_remaining(
        &self,
        user_id: Uuid,
        action: &str,
    ) -> ChannelResult<u32> {
        let key = format!("ratelimit:{}:{}", user_id, action);
        // NOTE(review): this reads through `AppCache::get`, while the counter
        // is written with a raw `INCR` on the cluster connection. Confirm both
        // paths address the same key and agree on the value encoding.
        let used: Option<u32> = self.cache.get(&key).await?;
        Ok(self.max_requests.saturating_sub(used.unwrap_or(0)))
    }
}
/// Lifetime of a stored CSRF token key, passed as the `EX` (seconds) argument
/// of the Redis `SET` issued by `CsrfProtection::generate_token`.
const CSRF_TTL_SECS: u64 = 3600;
/// Issues and validates CSRF tokens stored in the application cache.
#[derive(Clone)]
pub struct CsrfProtection {
    /// Cache handle; token operations require its cluster connection.
    cache: cache::AppCache,
}
impl CsrfProtection {
    /// Wraps the given cache handle.
    pub fn new(cache: cache::AppCache) -> Self {
        Self { cache }
    }

    /// Issues a new CSRF token for `user_id`.
    ///
    /// A random v4 UUID is used as the token and recorded under the key
    /// `csrf:{user_id}:{token}` with an expiry of `CSRF_TTL_SECS` seconds.
    ///
    /// # Errors
    /// Propagates the error from `require_cluster` when no cluster cache is
    /// available, and returns `ChannelError::Cache(CacheError::Redis(_))` when
    /// the Redis `SET` fails.
    pub async fn generate_token(&self, user_id: Uuid) -> ChannelResult<String> {
        let token = Uuid::new_v4().to_string();
        let key = format!("csrf:{}:{}", user_id, token);
        let cluster = require_cluster(&self.cache)?;
        let mut conn = cluster.conn();
        let _: () = redis::Cmd::new()
            .arg("SET")
            .arg(&key)
            .arg("1")
            .arg("EX")
            .arg(CSRF_TTL_SECS)
            .query_async(&mut conn)
            .await
            .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?;
        Ok(token)
    }

    /// Validates and consumes a CSRF token for `user_id`.
    ///
    /// The token is single-use: a successful validation removes its key, so
    /// a second call with the same token returns `Ok(false)`.
    ///
    /// # Returns
    /// `Ok(true)` when the key `csrf:{user_id}:{token}` existed and was
    /// removed, `Ok(false)` otherwise.
    ///
    /// # Errors
    /// Propagates the error from `require_cluster` when no cluster cache is
    /// available, and returns `ChannelError::Cache(CacheError::Redis(_))` when
    /// the Redis command fails.
    pub async fn validate_token(
        &self,
        user_id: Uuid,
        token: &str,
    ) -> ChannelResult<bool> {
        let key = format!("csrf:{}:{}", user_id, token);
        let cluster = require_cluster(&self.cache)?;
        let mut conn = cluster.conn();
        // `DEL` replies with the number of keys it removed, so one command
        // both checks for the token and consumes it atomically; no separate
        // EXISTS-then-DEL script is needed.
        let removed: i64 = redis::Cmd::new()
            .arg("DEL")
            .arg(&key)
            .query_async(&mut conn)
            .await
            .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?;
        Ok(removed == 1)
    }
}
/// Borrows the cluster cache held by `cache`.
///
/// # Errors
/// Returns `ChannelError::Internal("no cluster cache")` when `cache.cluster`
/// is `None`.
pub fn require_cluster(
    cache: &cache::AppCache,
) -> ChannelResult<&cache::ClusterCache> {
    cache
        .cluster
        .as_ref()
        // Build the error lazily so the message is only allocated on failure.
        .ok_or_else(|| ChannelError::Internal("no cluster cache".to_string()))
}