// gitdataai/lib/channel/security.rs
//
// 185 lines
// 4.6 KiB
// Rust

use std::time::Duration;
use track::CounterVec;
use uuid::Uuid;
use crate::{ChannelError, ChannelResult};
/// Lua script implementing an atomic fixed-window request counter.
///
/// `KEYS[1]` is the counter key, `ARGV[1]` the maximum allowed count and `ARGV[2]`
/// the window length in seconds. Returns `1` while the request is within budget
/// and `0` once it is exceeded. Every call increments the counter, rejected ones
/// included.
///
/// The expiry is applied whenever the key has no TTL (`TTL == -1`), not only when
/// the counter is first created. A counter that ended up without an expiry (for
/// example, written by other code) would otherwise never reset and would lock the
/// user out permanently. The window is clamped to at least one second because
/// `EXPIRE` with a non-positive timeout deletes the key, which would reset the
/// counter on every call.
const RATE_LIMIT_SCRIPT: &str = r#"
local key = KEYS[1]
local max = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
if window == nil or window < 1 then
window = 1
end
local current = tonumber(redis.call('INCR', key))
if current == 1 or redis.call('TTL', key) == -1 then
redis.call('EXPIRE', key, window)
end
if current > max then
return 0
end
return 1
"#;
/// Per-user, per-action fixed-window request limiter backed by the cluster cache.
///
/// Each `(user_id, action)` pair gets its own counter key of the form
/// `ratelimit:{user_id}:{action}`.
#[derive(Clone)]
pub struct RateLimiter {
    /// Number of requests allowed within one window.
    max_requests: u32,
    /// Length of the counting window.
    window: Duration,
    /// Cache handle; the Redis-backed checks require its `cluster` member to be set.
    cache: cache::AppCache,
    /// Decision counters; `None` until `set_metrics` is called.
    metrics: Option<RateLimiterMetrics>,
}
/// Metric handles that [`RateLimiter`] updates when metrics are enabled.
#[derive(Clone)]
struct RateLimiterMetrics {
    /// Decision counter labelled by `action` and `outcome` (`"allowed"` / `"blocked"`).
    outcomes: CounterVec,
}
impl RateLimiterMetrics {
    /// Registers the `rate_limiter_decisions_total` counter vector on `registry`.
    ///
    /// # Panics
    ///
    /// Panics if `register_counter_vec` returns an error. NOTE(review): presumably
    /// this happens when the name is already registered, e.g. `set_metrics` called
    /// on two limiters sharing one registry; confirm against the `track` crate.
    fn new(registry: &track::MetricsRegistry) -> Self {
        let label_names = ["action", "outcome"];
        let outcomes = registry
            .register_counter_vec(
                "rate_limiter_decisions_total",
                "Rate limiter decisions",
                &label_names,
            )
            .expect("failed to register rate_limiter_decisions_total");
        Self { outcomes }
    }
}
impl RateLimiter {
    /// Creates a limiter that allows 100 requests per 60-second window for each
    /// user/action pair.
    pub fn new(cache: cache::AppCache) -> Self {
        Self::with_config(cache, 100, Duration::from_secs(60))
    }

    /// Creates a limiter that allows `max_requests` requests per `window` for each
    /// user/action pair.
    ///
    /// The window is enforced with a whole-second Redis expiry. It is therefore
    /// rounded up to the next second and is never shorter than one second.
    pub fn with_config(
        cache: cache::AppCache,
        max_requests: u32,
        window: Duration,
    ) -> Self {
        Self {
            cache,
            max_requests,
            window,
            metrics: None,
        }
    }

    /// Enables decision metrics, registering them on `registry`.
    ///
    /// # Panics
    ///
    /// Panics if the `rate_limiter_decisions_total` counter cannot be registered.
    pub fn set_metrics(&mut self, registry: &track::MetricsRegistry) {
        self.metrics = Some(RateLimiterMetrics::new(registry));
    }

    /// Counts one request by `user_id` for `action` and reports whether it is allowed.
    ///
    /// Returns `Ok(true)` while the counter for the current window is at most
    /// `max_requests`, and `Ok(false)` once it is exceeded. Rejected requests are
    /// still counted. When metrics are enabled, the decision is recorded under the
    /// `action` / `outcome` labels.
    ///
    /// # Errors
    ///
    /// Returns `ChannelError::Internal` if the cache has no cluster backend, and
    /// `ChannelError::Cache` if the Redis call fails.
    pub async fn check_rate_limit(
        &self,
        user_id: Uuid,
        action: &str,
    ) -> ChannelResult<bool> {
        let cluster = require_cluster(&self.cache)?;
        let key = Self::counter_key(user_id, action);
        let mut conn = cluster.conn();
        let allowed: i64 = redis::Cmd::new()
            .arg("EVAL")
            .arg(RATE_LIMIT_SCRIPT)
            .arg(1)
            .arg(&key)
            .arg(self.max_requests)
            .arg(self.window_secs())
            .query_async(&mut conn)
            .await
            .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?;
        let is_allowed = allowed == 1;
        if let Some(m) = &self.metrics {
            let outcome = if is_allowed { "allowed" } else { "blocked" };
            m.outcomes.with_label_values(&[action, outcome]).inc();
        }
        Ok(is_allowed)
    }

    /// Returns how many more requests `user_id` may make for `action` in the current
    /// window, without consuming one.
    ///
    /// A missing counter (no requests yet, or the window has expired) counts as zero
    /// used. The result saturates at 0 because blocked requests keep incrementing the
    /// counter past `max_requests`.
    ///
    /// NOTE(review): this reads through `AppCache::get` rather than the raw cluster
    /// connection used by `check_rate_limit`. Confirm that `get` applies no key prefix
    /// or value encoding that would make it miss the counter.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the cache lookup.
    pub async fn get_remaining(
        &self,
        user_id: Uuid,
        action: &str,
    ) -> ChannelResult<u32> {
        let key = Self::counter_key(user_id, action);
        let used: Option<u32> = self.cache.get(&key).await?;
        Ok(self.max_requests.saturating_sub(used.unwrap_or(0)))
    }

    /// Redis key holding the request counter for `user_id` / `action`.
    fn counter_key(user_id: Uuid, action: &str) -> String {
        format!("ratelimit:{}:{}", user_id, action)
    }

    /// `window` in whole seconds for Redis `EXPIRE`: rounded up, and never below 1.
    ///
    /// `EXPIRE` with a non-positive timeout deletes the key. `Duration::as_secs`
    /// returns 0 for any sub-second window, so passing it unchanged would reset the
    /// counter on every request and the limiter would never block.
    fn window_secs(&self) -> u64 {
        let whole = self.window.as_secs();
        let rounded = if self.window.subsec_nanos() > 0 {
            whole.saturating_add(1)
        } else {
            whole
        };
        rounded.max(1)
    }
}
/// Lifetime of an issued CSRF token, in seconds (one hour).
const CSRF_TTL_SECS: u64 = 60 * 60;
/// Issues and consumes single-use CSRF tokens stored in the cluster cache.
#[derive(Clone)]
pub struct CsrfProtection {
    /// Cache handle; token operations require its `cluster` member to be set.
    cache: cache::AppCache,
}
impl CsrfProtection {
    /// Creates a token store backed by `cache`.
    pub fn new(cache: cache::AppCache) -> Self {
        Self { cache }
    }

    /// Issues a new single-use CSRF token for `user_id`.
    ///
    /// The token is a random v4 UUID in hyphenated string form. It is stored under
    /// `csrf:{user_id}:{token}` and expires after `CSRF_TTL_SECS` seconds unless
    /// `validate_token` consumes it first.
    ///
    /// # Errors
    ///
    /// Returns `ChannelError::Internal` if the cache has no cluster backend, and
    /// `ChannelError::Cache` if the Redis `SET` fails.
    pub async fn generate_token(&self, user_id: Uuid) -> ChannelResult<String> {
        let cluster = require_cluster(&self.cache)?;
        let token = Uuid::new_v4().to_string();
        let key = Self::token_key(user_id, &token);
        let mut conn = cluster.conn();
        let _: () = redis::Cmd::new()
            .arg("SET")
            .arg(&key)
            .arg("1")
            .arg("EX")
            .arg(CSRF_TTL_SECS)
            .query_async(&mut conn)
            .await
            .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?;
        Ok(token)
    }

    /// Checks `token` for `user_id` and consumes it.
    ///
    /// Returns `Ok(true)` at most once per issued, unexpired token. Repeated use of the
    /// same token, unknown or expired tokens, and strings that are not UUIDs all
    /// return `Ok(false)`.
    ///
    /// # Errors
    ///
    /// Returns `ChannelError::Internal` if the cache has no cluster backend, and
    /// `ChannelError::Cache` if the Redis `DEL` fails.
    pub async fn validate_token(
        &self,
        user_id: Uuid,
        token: &str,
    ) -> ChannelResult<bool> {
        let cluster = require_cluster(&self.cache)?;
        // `generate_token` only issues UUIDs, so anything that does not parse as one
        // cannot match. Rejecting it here keeps arbitrary caller-supplied strings out
        // of Redis key names and saves a round-trip.
        if Uuid::parse_str(token).is_err() {
            return Ok(false);
        }
        let key = Self::token_key(user_id, token);
        let mut conn = cluster.conn();
        // DEL is atomic and replies with the number of keys it removed, so one command
        // both checks and consumes the token. Of several concurrent validations of the
        // same token, exactly one sees `1`.
        let removed: i64 = redis::Cmd::new()
            .arg("DEL")
            .arg(&key)
            .query_async(&mut conn)
            .await
            .map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?;
        Ok(removed == 1)
    }

    /// Redis key under which `token` for `user_id` is stored.
    fn token_key(user_id: Uuid, token: &str) -> String {
        format!("csrf:{}:{}", user_id, token)
    }
}
/// Returns the cluster (Redis) cache behind `cache`.
///
/// # Errors
///
/// Returns `ChannelError::Internal("no cluster cache")` when `cache` has no
/// cluster backend configured.
pub fn require_cluster(
    cache: &cache::AppCache,
) -> ChannelResult<&cache::ClusterCache> {
    // `ok_or_else` builds the error (and its String) only on the failure path;
    // `ok_or` would allocate it on every successful call.
    cache
        .cluster
        .as_ref()
        .ok_or_else(|| ChannelError::Internal("no cluster cache".to_string()))
}