use std::sync::Arc; use chrono::{Duration, Utc}; use deadpool_redis::cluster::Connection; use serde::{Deserialize, Serialize}; use uuid::Uuid; use crate::error::AppError; /// Token payload stored in Redis #[derive(Debug, Clone, Serialize, Deserialize)] pub struct WsTokenData { pub user_id: Uuid, pub expires_at: chrono::DateTime, pub created_at: chrono::DateTime, } const WS_TOKEN_PREFIX: &str = "ws_token:"; pub const WS_TOKEN_TTL_SECONDS: i64 = 300; // Token valid for 5 minutes /// Service for managing WebSocket connection tokens pub struct WsTokenService { get_redis: Arc tokio::task::JoinHandle> + Send + Sync>, } impl WsTokenService { pub fn new( get_redis: Arc< dyn Fn() -> tokio::task::JoinHandle> + Send + Sync, >, ) -> Self { Self { get_redis } } /// Generate a new WebSocket token for the given user pub async fn generate_token(&self, user_id: Uuid) -> Result { let token = Self::random_token(); let now = Utc::now(); let token_data = WsTokenData { user_id, expires_at: now + Duration::seconds(WS_TOKEN_TTL_SECONDS), created_at: now, }; let json = serde_json::to_string(&token_data).map_err(|e| { AppError::InternalServerError(format!("Failed to serialize ws token: {}", e)) })?; let key = format!("{}{}", WS_TOKEN_PREFIX, token); let mut conn = self.get_connection().await?; // Set token in Redis with TTL redis::cmd("SETEX") .arg(&key) .arg(WS_TOKEN_TTL_SECONDS) .arg(&json) .query_async::<()>(&mut conn) .await .map_err(|e| { AppError::InternalServerError(format!("Failed to store ws token: {}", e)) })?; Ok(token) } pub async fn validate_token(&self, token: &str) -> Result { let key = format!("{}{}", WS_TOKEN_PREFIX, token); let mut conn = self.get_connection().await?; // Get and delete token atomically (one-time use) let json: Option = redis::cmd("GETDEL") .arg(&key) .query_async(&mut conn) .await .map_err(|e| { AppError::InternalServerError(format!("Failed to validate ws token: {}", e)) })?; let token_data = json.ok_or_else(|| AppError::Unauthorized)?; let ws_token_data: WsTokenData = serde_json::from_str(&token_data).map_err(|e| { AppError::InternalServerError(format!("Failed to deserialize ws token: {}", e)) })?; if Utc::now() > ws_token_data.expires_at { return Err(AppError::Unauthorized); } Ok(ws_token_data.user_id) } fn random_token() -> String { let bytes: [u8; 32] = rand::random(); hex::encode(bytes) } async fn get_connection(&self) -> Result { (self.get_redis)() .await .map_err(|e| AppError::InternalServerError(format!("Redis join error: {}", e)))? .map_err(|e| AppError::InternalServerError(format!("Redis connection error: {}", e))) } }