104 lines
3.2 KiB
Rust
104 lines
3.2 KiB
Rust
use std::sync::Arc;
|
|
|
|
use chrono::{Duration, Utc};
|
|
use deadpool_redis::cluster::Connection;
|
|
use serde::{Deserialize, Serialize};
|
|
use uuid::Uuid;
|
|
|
|
use crate::error::AppError;
|
|
|
|
/// Token payload stored in Redis
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct WsTokenData {
|
|
pub user_id: Uuid,
|
|
pub expires_at: chrono::DateTime<Utc>,
|
|
pub created_at: chrono::DateTime<Utc>,
|
|
}
|
|
|
|
/// Namespace prepended to a token to form its Redis key.
const WS_TOKEN_PREFIX: &str = "ws_token:";

/// Lifetime of a freshly issued token, in seconds (five minutes).
pub const WS_TOKEN_TTL_SECONDS: i64 = 5 * 60;
|
|
|
|
/// Service for managing WebSocket connection tokens
|
|
pub struct WsTokenService {
|
|
get_redis: Arc<dyn Fn() -> tokio::task::JoinHandle<anyhow::Result<Connection>> + Send + Sync>,
|
|
}
|
|
|
|
impl WsTokenService {
|
|
pub fn new(
|
|
get_redis: Arc<
|
|
dyn Fn() -> tokio::task::JoinHandle<anyhow::Result<Connection>> + Send + Sync,
|
|
>,
|
|
) -> Self {
|
|
Self { get_redis }
|
|
}
|
|
|
|
/// Generate a new WebSocket token for the given user
|
|
pub async fn generate_token(&self, user_id: Uuid) -> Result<String, AppError> {
|
|
let token = Self::random_token();
|
|
let now = Utc::now();
|
|
let token_data = WsTokenData {
|
|
user_id,
|
|
expires_at: now + Duration::seconds(WS_TOKEN_TTL_SECONDS),
|
|
created_at: now,
|
|
};
|
|
|
|
let json = serde_json::to_string(&token_data).map_err(|e| {
|
|
AppError::InternalServerError(format!("Failed to serialize ws token: {}", e))
|
|
})?;
|
|
|
|
let key = format!("{}{}", WS_TOKEN_PREFIX, token);
|
|
let mut conn = self.get_connection().await?;
|
|
|
|
// Set token in Redis with TTL
|
|
redis::cmd("SETEX")
|
|
.arg(&key)
|
|
.arg(WS_TOKEN_TTL_SECONDS)
|
|
.arg(&json)
|
|
.query_async::<()>(&mut conn)
|
|
.await
|
|
.map_err(|e| {
|
|
AppError::InternalServerError(format!("Failed to store ws token: {}", e))
|
|
})?;
|
|
|
|
Ok(token)
|
|
}
|
|
|
|
pub async fn validate_token(&self, token: &str) -> Result<Uuid, AppError> {
|
|
let key = format!("{}{}", WS_TOKEN_PREFIX, token);
|
|
let mut conn = self.get_connection().await?;
|
|
|
|
// Get and delete token atomically (one-time use)
|
|
let json: Option<String> = redis::cmd("GETDEL")
|
|
.arg(&key)
|
|
.query_async(&mut conn)
|
|
.await
|
|
.map_err(|e| {
|
|
AppError::InternalServerError(format!("Failed to validate ws token: {}", e))
|
|
})?;
|
|
|
|
let token_data = json.ok_or_else(|| AppError::Unauthorized)?;
|
|
|
|
let ws_token_data: WsTokenData = serde_json::from_str(&token_data).map_err(|e| {
|
|
AppError::InternalServerError(format!("Failed to deserialize ws token: {}", e))
|
|
})?;
|
|
|
|
if Utc::now() > ws_token_data.expires_at {
|
|
return Err(AppError::Unauthorized);
|
|
}
|
|
|
|
Ok(ws_token_data.user_id)
|
|
}
|
|
|
|
fn random_token() -> String {
|
|
let bytes: [u8; 32] = rand::random();
|
|
hex::encode(bytes)
|
|
}
|
|
|
|
async fn get_connection(&self) -> Result<Connection, AppError> {
|
|
(self.get_redis)()
|
|
.await
|
|
.map_err(|e| AppError::InternalServerError(format!("Redis join error: {}", e)))?
|
|
.map_err(|e| AppError::InternalServerError(format!("Redis connection error: {}", e)))
|
|
}
|
|
}
|