153 lines
3.6 KiB
Rust
153 lines
3.6 KiB
Rust
use std::time::Duration;
|
|
use uuid::Uuid;
|
|
|
|
use crate::{ChannelError, ChannelResult};
|
|
|
|
/// Lua script implementing a fixed-window request counter.
///
/// * `KEYS[1]` – counter key.
/// * `ARGV[1]` – maximum number of requests allowed per window.
/// * `ARGV[2]` – window length in seconds (clamped to at least 1).
///
/// Returns `1` while the request count is within the limit and `0` once it is
/// exceeded. Every call increments the counter, including rejected ones.
const RATE_LIMIT_SCRIPT: &str = r#"
local key = KEYS[1]
local max = tonumber(ARGV[1])
-- EXPIRE with a non-positive timeout deletes the key, which would reset the
-- counter on every call and disable the limit; never go below one second.
local window = math.max(tonumber(ARGV[2]), 1)

local current = redis.call('INCR', key)

-- Arm the expiry on the first hit of a window, and re-arm it if the key has
-- lost its TTL (TTL == -1). Otherwise such a counter would never reset and
-- would block the caller permanently.
if current == 1 or redis.call('TTL', key) == -1 then
    redis.call('EXPIRE', key, window)
end

if current > max then
    return 0
end
return 1
"#;
|
|
|
|
#[derive(Clone)]
|
|
pub struct RateLimiter {
|
|
cache: cache::AppCache,
|
|
max_requests: u32,
|
|
window: Duration,
|
|
}
|
|
|
|
impl RateLimiter {
|
|
pub fn new(cache: cache::AppCache) -> Self {
|
|
Self {
|
|
cache,
|
|
max_requests: 100,
|
|
window: Duration::from_secs(60),
|
|
}
|
|
}
|
|
|
|
pub fn with_config(
|
|
cache: cache::AppCache,
|
|
max_requests: u32,
|
|
window: Duration,
|
|
) -> Self {
|
|
Self {
|
|
cache,
|
|
max_requests,
|
|
window,
|
|
}
|
|
}
|
|
|
|
pub async fn check_rate_limit(
|
|
&self,
|
|
user_id: Uuid,
|
|
action: &str,
|
|
) -> ChannelResult<bool> {
|
|
let cluster = require_cluster(&self.cache)?;
|
|
let key = format!("ratelimit:{}:{}", user_id, action);
|
|
let mut conn = cluster.conn();
|
|
|
|
let allowed: i64 = redis::Cmd::new()
|
|
.arg("EVAL")
|
|
.arg(RATE_LIMIT_SCRIPT)
|
|
.arg(1)
|
|
.arg(&key)
|
|
.arg(self.max_requests)
|
|
.arg(self.window.as_secs())
|
|
.query_async(&mut conn)
|
|
.await
|
|
.map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?;
|
|
|
|
Ok(allowed == 1)
|
|
}
|
|
|
|
pub async fn get_remaining(
|
|
&self,
|
|
user_id: Uuid,
|
|
action: &str,
|
|
) -> ChannelResult<u32> {
|
|
let key = format!("ratelimit:{}:{}", user_id, action);
|
|
let count: Option<u32> = self.cache.get(&key).await?;
|
|
let current = count.unwrap_or(0);
|
|
Ok(self.max_requests.saturating_sub(current))
|
|
}
|
|
}
|
|
|
|
const CSRF_TTL_SECS: u64 = 3600;
|
|
|
|
#[derive(Clone)]
|
|
pub struct CsrfProtection {
|
|
cache: cache::AppCache,
|
|
}
|
|
|
|
impl CsrfProtection {
|
|
pub fn new(cache: cache::AppCache) -> Self {
|
|
Self { cache }
|
|
}
|
|
|
|
pub async fn generate_token(&self, user_id: Uuid) -> ChannelResult<String> {
|
|
let token = Uuid::new_v4().to_string();
|
|
let key = format!("csrf:{}:{}", user_id, token);
|
|
let cluster = require_cluster(&self.cache)?;
|
|
let mut conn = cluster.conn();
|
|
|
|
let _: () = redis::Cmd::new()
|
|
.arg("SET")
|
|
.arg(&key)
|
|
.arg("1")
|
|
.arg("EX")
|
|
.arg(CSRF_TTL_SECS)
|
|
.query_async(&mut conn)
|
|
.await
|
|
.map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?;
|
|
|
|
Ok(token)
|
|
}
|
|
|
|
pub async fn validate_token(
|
|
&self,
|
|
user_id: Uuid,
|
|
token: &str,
|
|
) -> ChannelResult<bool> {
|
|
let key = format!("csrf:{}:{}", user_id, token);
|
|
let cluster = require_cluster(&self.cache)?;
|
|
let mut conn = cluster.conn();
|
|
const VALIDATE_SCRIPT: &str = r#"
|
|
local key = KEYS[1]
|
|
local exists = redis.call('EXISTS', key)
|
|
if exists == 1 then
|
|
redis.call('DEL', key)
|
|
return 1
|
|
end
|
|
return 0
|
|
"#;
|
|
|
|
let valid: i64 = redis::Cmd::new()
|
|
.arg("EVAL")
|
|
.arg(VALIDATE_SCRIPT)
|
|
.arg(1)
|
|
.arg(&key)
|
|
.query_async(&mut conn)
|
|
.await
|
|
.map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?;
|
|
|
|
Ok(valid == 1)
|
|
}
|
|
}
|
|
|
|
pub(crate) fn require_cluster(
|
|
cache: &cache::AppCache,
|
|
) -> ChannelResult<&cache::ClusterCache> {
|
|
cache
|
|
.cluster
|
|
.as_ref()
|
|
.ok_or(ChannelError::Internal("no cluster cache".to_string()))
|
|
}
|