307 lines
9.9 KiB
Rust
307 lines
9.9 KiB
Rust
use base64::Engine;
|
|
use hmac::{KeyInit, Mac};
|
|
use sha2::Sha256;
|
|
use uuid::Uuid;
|
|
|
|
use crate::{
|
|
ChannelBus, ChannelError, ChannelResult, security::require_cluster,
|
|
};
|
|
|
|
/// HMAC-SHA256 MAC type, used both to derive the token signing key from the
/// configured secret and to compute each token's authentication tag.
type HmacSha256 = hmac::Hmac<Sha256>;
|
|
|
|
/// Format version written as the first byte of every encoded token; tokens whose
/// leading byte differs are rejected on decode.
const VERSION: u8 = 0;

/// Lifetime of the Redis key that marks an access token as live (10 minutes).
pub const TOKEN_TTL_SECS: u64 = 10 * 60;

/// Lifetime of the per-token session hash holding `device_id` / `client_id`
/// (30 minutes).
const SESSION_TTL_SECS: u64 = 30 * 60;

/// Maximum token age, measured from the token's embedded `created_at`, accepted
/// by the check and renew paths (50 minutes).
const MAX_LIFETIME_SECS: i64 = 50 * 60;

/// Redis key prefix for live-token marker keys.
const TOKEN_PREFIX: &str = "token:access:";

/// Redis key prefix for per-token session hashes.
const SESSION_PREFIX: &str = "channel:session:";

/// Upper bound on the base64 text length accepted before decoding is attempted.
/// A well-formed token (57 bytes, unpadded URL-safe base64) is 76 characters.
const MAX_TOKEN_BASE64_LEN: usize = 256;
|
|
|
|
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
|
|
pub struct ChannelAccessToken {
|
|
pub access_token: String,
|
|
}
|
|
|
|
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
|
|
pub struct ChannelTokenApply {
|
|
pub client_id: String,
|
|
pub device_id: String,
|
|
}
|
|
|
|
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
|
|
pub struct ChannelTokenContext {
|
|
pub user_id: Uuid,
|
|
pub device_id: String,
|
|
pub client_id: String,
|
|
}
|
|
|
|
struct TokenPayload {
|
|
user_id: Uuid,
|
|
created_at: i64,
|
|
}
|
|
|
|
impl TokenPayload {
|
|
const LEN: usize = 57;
|
|
|
|
fn encode(&self, signing_key: &[u8]) -> ChannelResult<Vec<u8>> {
|
|
let mut buf = Vec::with_capacity(Self::LEN);
|
|
buf.push(VERSION);
|
|
buf.extend_from_slice(self.user_id.as_bytes());
|
|
buf.extend_from_slice(&self.created_at.to_be_bytes());
|
|
|
|
let tag = hmac_sign(signing_key, &buf)?;
|
|
buf.extend_from_slice(&tag);
|
|
|
|
Ok(buf)
|
|
}
|
|
|
|
fn decode(bytes: &[u8], signing_key: &[u8]) -> ChannelResult<Self> {
|
|
if bytes.len() != Self::LEN || bytes[0] != VERSION {
|
|
return Err(ChannelError::TokenInvalidOrExpired);
|
|
}
|
|
|
|
let expected_tag = hmac_sign(signing_key, &bytes[..25])?;
|
|
if !constant_time_eq(&expected_tag, &bytes[25..]) {
|
|
return Err(ChannelError::TokenInvalidOrExpired);
|
|
}
|
|
|
|
let user_id_bytes: [u8; 16] = bytes[1..17].try_into().map_err(
|
|
|_: std::array::TryFromSliceError| {
|
|
ChannelError::TokenInvalidOrExpired
|
|
},
|
|
)?;
|
|
let user_id = Uuid::from_bytes(user_id_bytes);
|
|
let created_at_bytes: [u8; 8] = bytes[17..25].try_into().map_err(
|
|
|_: std::array::TryFromSliceError| {
|
|
ChannelError::TokenInvalidOrExpired
|
|
},
|
|
)?;
|
|
let created_at = i64::from_be_bytes(created_at_bytes);
|
|
|
|
Ok(TokenPayload {
|
|
user_id,
|
|
created_at,
|
|
})
|
|
}
|
|
}
|
|
|
|
impl ChannelBus {
|
|
fn signing_key(&self) -> ChannelResult<[u8; 32]> {
|
|
let secret =
|
|
self.inner.config.signing_secret.as_deref().ok_or(
|
|
ChannelError::Internal("no signing secret".to_string()),
|
|
)?;
|
|
let mut mac =
|
|
HmacSha256::new_from_slice(secret.as_bytes()).map_err(|_| {
|
|
ChannelError::Internal("hmac init failed".to_string())
|
|
})?;
|
|
mac.update(b"channel-access-token-signing-key");
|
|
let result = mac.finalize().into_bytes();
|
|
Ok(result.into())
|
|
}
|
|
|
|
fn session_hash_key(&self, user_id: &Uuid, created_at: i64) -> String {
|
|
format!("{}{}:{}", SESSION_PREFIX, user_id, created_at)
|
|
}
|
|
|
|
fn token_redis_key(&self, token_str: &str) -> String {
|
|
format!("{}{}", TOKEN_PREFIX, token_str)
|
|
}
|
|
|
|
pub async fn apply_access_token(
|
|
&self,
|
|
user_id: Uuid,
|
|
apply: ChannelTokenApply,
|
|
) -> ChannelResult<ChannelAccessToken> {
|
|
let created_at = chrono::Utc::now().timestamp();
|
|
let signing_key = self.signing_key()?;
|
|
|
|
let payload = TokenPayload {
|
|
user_id,
|
|
created_at,
|
|
};
|
|
let token_bytes = payload.encode(&signing_key)?;
|
|
let access_token = base64::engine::general_purpose::URL_SAFE_NO_PAD
|
|
.encode(&token_bytes);
|
|
|
|
let session_key = self.session_hash_key(&user_id, created_at);
|
|
let token_key = self.token_redis_key(&access_token);
|
|
|
|
let cluster = require_cluster(&self.inner.cache)?;
|
|
let mut conn = cluster.conn();
|
|
let mut pipe = redis::Pipeline::new();
|
|
pipe.hset(&session_key, "device_id", &apply.device_id)
|
|
.hset(&session_key, "client_id", &apply.client_id)
|
|
.expire(&session_key, SESSION_TTL_SECS as i64);
|
|
pipe.query_async::<()>(&mut conn)
|
|
.await
|
|
.map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?;
|
|
redis::Cmd::set_ex(&token_key, &session_key, TOKEN_TTL_SECS)
|
|
.query_async::<()>(&mut conn)
|
|
.await
|
|
.map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?;
|
|
|
|
Ok(ChannelAccessToken { access_token })
|
|
}
|
|
|
|
pub async fn check_access_token(
|
|
&self,
|
|
access_token: String,
|
|
) -> ChannelResult<ChannelTokenContext> {
|
|
let token_bytes = decode_token_bytes(&access_token)?;
|
|
|
|
let signing_key = self.signing_key()?;
|
|
let payload = TokenPayload::decode(&token_bytes, &signing_key)?;
|
|
|
|
let elapsed = chrono::Utc::now().timestamp() - payload.created_at;
|
|
if elapsed > MAX_LIFETIME_SECS {
|
|
return Err(ChannelError::TokenInvalidOrExpired);
|
|
}
|
|
|
|
let session_key =
|
|
self.session_hash_key(&payload.user_id, payload.created_at);
|
|
|
|
let cluster = require_cluster(&self.inner.cache)?;
|
|
let mut conn = cluster.conn();
|
|
|
|
let token_key = self.token_redis_key(&access_token);
|
|
let token_exists: bool = redis::Cmd::new()
|
|
.arg("EXISTS")
|
|
.arg(&token_key)
|
|
.query_async(&mut conn)
|
|
.await
|
|
.map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?;
|
|
|
|
if !token_exists {
|
|
return Err(ChannelError::TokenInvalidOrExpired);
|
|
}
|
|
|
|
let hash_data: std::collections::HashMap<String, String> =
|
|
redis::Cmd::new()
|
|
.arg("HGETALL")
|
|
.arg(&session_key)
|
|
.query_async(&mut conn)
|
|
.await
|
|
.map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?;
|
|
|
|
let device_id = hash_data
|
|
.get("device_id")
|
|
.cloned()
|
|
.ok_or(ChannelError::TokenInvalidOrExpired)?;
|
|
let client_id = hash_data
|
|
.get("client_id")
|
|
.cloned()
|
|
.ok_or(ChannelError::TokenInvalidOrExpired)?;
|
|
|
|
Ok(ChannelTokenContext {
|
|
user_id: payload.user_id,
|
|
device_id,
|
|
client_id,
|
|
})
|
|
}
|
|
|
|
pub async fn renew_access_token(
|
|
&self,
|
|
access_token: String,
|
|
) -> ChannelResult<ChannelAccessToken> {
|
|
let token_bytes = decode_token_bytes(&access_token)?;
|
|
|
|
let signing_key = self.signing_key()?;
|
|
let payload = TokenPayload::decode(&token_bytes, &signing_key)?;
|
|
|
|
let elapsed = chrono::Utc::now().timestamp() - payload.created_at;
|
|
if elapsed > MAX_LIFETIME_SECS {
|
|
return Err(ChannelError::RenewalLimitExceeded);
|
|
}
|
|
|
|
let session_key =
|
|
self.session_hash_key(&payload.user_id, payload.created_at);
|
|
let token_key = self.token_redis_key(&access_token);
|
|
|
|
let cluster = require_cluster(&self.inner.cache)?;
|
|
let mut conn = cluster.conn();
|
|
let hash_data: std::collections::HashMap<String, String> =
|
|
redis::Cmd::new()
|
|
.arg("HGETALL")
|
|
.arg(&session_key)
|
|
.query_async(&mut conn)
|
|
.await
|
|
.map_err(|e| {
|
|
ChannelError::Cache(cache::CacheError::Redis(e))
|
|
})?;
|
|
|
|
let device_id = hash_data
|
|
.get("device_id")
|
|
.cloned()
|
|
.ok_or(ChannelError::TokenInvalidOrExpired)?;
|
|
let client_id = hash_data
|
|
.get("client_id")
|
|
.cloned()
|
|
.ok_or(ChannelError::TokenInvalidOrExpired)?;
|
|
let _: () = redis::Cmd::new()
|
|
.arg("DEL")
|
|
.arg(&token_key)
|
|
.query_async(&mut conn)
|
|
.await
|
|
.map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?;
|
|
let created_at = chrono::Utc::now().timestamp();
|
|
let new_payload = TokenPayload {
|
|
user_id: payload.user_id,
|
|
created_at,
|
|
};
|
|
let new_token_bytes = new_payload.encode(&signing_key)?;
|
|
let new_access_token =
|
|
base64::engine::general_purpose::URL_SAFE_NO_PAD
|
|
.encode(&new_token_bytes);
|
|
|
|
let new_session_key =
|
|
self.session_hash_key(&payload.user_id, created_at);
|
|
let new_token_key = self.token_redis_key(&new_access_token);
|
|
let mut pipe = redis::Pipeline::new();
|
|
pipe.hset(&new_session_key, "device_id", &device_id)
|
|
.hset(&new_session_key, "client_id", &client_id)
|
|
.expire(&new_session_key, SESSION_TTL_SECS as i64);
|
|
|
|
pipe.query_async::<()>(&mut conn)
|
|
.await
|
|
.map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?;
|
|
redis::Cmd::set_ex(&new_token_key, &new_session_key, TOKEN_TTL_SECS)
|
|
.query_async::<()>(&mut conn)
|
|
.await
|
|
.map_err(|e| ChannelError::Cache(cache::CacheError::Redis(e)))?;
|
|
|
|
Ok(ChannelAccessToken {
|
|
access_token: new_access_token,
|
|
})
|
|
}
|
|
}
|
|
|
|
fn decode_token_bytes(token: &str) -> ChannelResult<Vec<u8>> {
|
|
if token.len() > MAX_TOKEN_BASE64_LEN {
|
|
return Err(ChannelError::TokenInvalidOrExpired);
|
|
}
|
|
base64::engine::general_purpose::URL_SAFE_NO_PAD
|
|
.decode(token)
|
|
.map_err(|_| ChannelError::TokenInvalidOrExpired)
|
|
}
|
|
|
|
fn hmac_sign(key: &[u8], payload: &[u8]) -> ChannelResult<[u8; 32]> {
|
|
let mut mac = HmacSha256::new_from_slice(key)
|
|
.map_err(|_| ChannelError::Internal("hmac sign failed".to_string()))?;
|
|
mac.update(payload);
|
|
Ok(mac.finalize().into_bytes().into())
|
|
}
|
|
|
|
/// Compares a 32-byte expected tag with `actual`, returning `true` only when
/// `actual` is exactly 32 bytes long and every byte matches.
///
/// The byte loop accumulates all differences rather than stopping at the first
/// mismatch. This is written to avoid a data-dependent early exit, but the compiler
/// is not formally prevented from optimising it.
fn constant_time_eq(expected: &[u8; 32], actual: &[u8]) -> bool {
    if actual.len() != expected.len() {
        return false;
    }
    let difference = expected
        .iter()
        .zip(actual)
        .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs));
    difference == 0
}
|