// gitdataai/libs/service/ws_token.rs
// 2026-04-14 19:02:01 +08:00 · 104 lines · 3.2 KiB · Rust

use std::sync::Arc;
use chrono::{Duration, Utc};
use deadpool_redis::cluster::Connection;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::error::AppError;
/// Token payload stored in Redis.
///
/// `WsTokenService::generate_token` serializes this with `serde_json` and stores it
/// under the key `ws_token:<token>`. Fields serialize in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WsTokenData {
    /// User the token was issued to; returned by a successful validation.
    pub user_id: Uuid,
    /// Instant after which validation rejects the token, even if the Redis entry
    /// has not been evicted yet.
    pub expires_at: chrono::DateTime<Utc>,
    /// Instant the token was issued.
    pub created_at: chrono::DateTime<Utc>,
}
/// Namespace prepended to every token to form its Redis key.
const WS_TOKEN_PREFIX: &str = "ws_token:";
/// Lifetime of an issued token, in seconds (five minutes). Used both as the Redis
/// key TTL and to compute the payload's `expires_at`.
pub const WS_TOKEN_TTL_SECONDS: i64 = 5 * 60;
/// Issues and redeems short-lived WebSocket connection tokens kept in Redis.
pub struct WsTokenService {
    /// Connection factory: each call returns a task handle that resolves to a Redis
    /// cluster connection, or to the error raised while acquiring one.
    get_redis: Arc<
        dyn Fn() -> tokio::task::JoinHandle<anyhow::Result<Connection>>
            + Send
            + Sync,
    >,
}
impl WsTokenService {
pub fn new(
get_redis: Arc<
dyn Fn() -> tokio::task::JoinHandle<anyhow::Result<Connection>> + Send + Sync,
>,
) -> Self {
Self { get_redis }
}
/// Generate a new WebSocket token for the given user
pub async fn generate_token(&self, user_id: Uuid) -> Result<String, AppError> {
let token = Self::random_token();
let now = Utc::now();
let token_data = WsTokenData {
user_id,
expires_at: now + Duration::seconds(WS_TOKEN_TTL_SECONDS),
created_at: now,
};
let json = serde_json::to_string(&token_data).map_err(|e| {
AppError::InternalServerError(format!("Failed to serialize ws token: {}", e))
})?;
let key = format!("{}{}", WS_TOKEN_PREFIX, token);
let mut conn = self.get_connection().await?;
// Set token in Redis with TTL
redis::cmd("SETEX")
.arg(&key)
.arg(WS_TOKEN_TTL_SECONDS)
.arg(&json)
.query_async::<()>(&mut conn)
.await
.map_err(|e| {
AppError::InternalServerError(format!("Failed to store ws token: {}", e))
})?;
Ok(token)
}
pub async fn validate_token(&self, token: &str) -> Result<Uuid, AppError> {
let key = format!("{}{}", WS_TOKEN_PREFIX, token);
let mut conn = self.get_connection().await?;
// Get and delete token atomically (one-time use)
let json: Option<String> = redis::cmd("GETDEL")
.arg(&key)
.query_async(&mut conn)
.await
.map_err(|e| {
AppError::InternalServerError(format!("Failed to validate ws token: {}", e))
})?;
let token_data = json.ok_or_else(|| AppError::Unauthorized)?;
let ws_token_data: WsTokenData = serde_json::from_str(&token_data).map_err(|e| {
AppError::InternalServerError(format!("Failed to deserialize ws token: {}", e))
})?;
if Utc::now() > ws_token_data.expires_at {
return Err(AppError::Unauthorized);
}
Ok(ws_token_data.user_id)
}
fn random_token() -> String {
let bytes: [u8; 32] = rand::random();
hex::encode(bytes)
}
async fn get_connection(&self) -> Result<Connection, AppError> {
(self.get_redis)()
.await
.map_err(|e| AppError::InternalServerError(format!("Redis join error: {}", e)))?
.map_err(|e| AppError::InternalServerError(format!("Redis connection error: {}", e)))
}
}