gitdataai/lib/service/auth/rsa.rs
2026-05-30 01:38:40 +08:00

151 lines
5.4 KiB
Rust

use base64::Engine;
use chacha20poly1305::{ChaCha20Poly1305, KeyInit, Nonce, aead::Aead};
use hkdf::Hkdf;
use rand_chacha::{ChaCha12Rng, rand_core::SeedableRng};
use rsa::{
Pkcs1v15Encrypt, RsaPrivateKey, RsaPublicKey,
pkcs1::{DecodeRsaPrivateKey, EncodeRsaPrivateKey, EncodeRsaPublicKey},
};
use serde::{Deserialize, Serialize};
use session::Session;
use sha2::Sha256;
use crate::{AppService, error::AppError};
/// Payload returned by `AppService::auth_rsa`: the session's RSA public key.
#[derive(Clone, Debug, Serialize, Deserialize, utoipa::ToSchema)]
pub struct RsaResponse {
    /// RSA public key encoded as PKCS#1 PEM by `AppService::auth_rsa`.
    pub public_key: String,
}
impl AppService {
pub const RSA_PRIVATE_KEY: &'static str = "rsa:private";
pub const RSA_PUBLIC_KEY: &'static str = "rsa:public";
const RSA_BIT_SIZE: usize = 2048;
fn derive_rsa_encryption_key(&self) -> [u8; 32] {
let secret = self
.config
.env
.get("APP_SESSION_SECRET")
.map(|s| s.as_str())
.expect("APP_SESSION_SECRET must be set in production. Do not use fallback keys.");
let hk = Hkdf::<Sha256>::new(
Some(b"rsa-session-encryption"),
secret.as_bytes(),
);
let mut okm = [0u8; 32];
hk.expand(b"rsa-private-key-aead", &mut okm)
.expect("HKDF expand within hash length");
okm
}
fn encrypt_rsa_key(&self, plaintext: &str) -> Result<String, AppError> {
let key = self.derive_rsa_encryption_key();
let cipher = ChaCha20Poly1305::new_from_slice(&key)
.expect("32-byte key is valid for ChaCha20Poly1305");
let nonce_bytes: [u8; 12] = rand::random();
let nonce = Nonce::from(nonce_bytes);
let ciphertext = cipher
.encrypt(&nonce, plaintext.as_bytes())
.map_err(|_| AppError::RsaGenerationError)?;
let mut combined = nonce_bytes.to_vec();
combined.extend_from_slice(&ciphertext);
Ok(base64::engine::general_purpose::STANDARD.encode(&combined))
}
fn decrypt_rsa_key(&self, encrypted: &str) -> Result<String, AppError> {
let key = self.derive_rsa_encryption_key();
let cipher = ChaCha20Poly1305::new_from_slice(&key)
.expect("32-byte key is valid for ChaCha20Poly1305");
let combined = base64::engine::general_purpose::STANDARD
.decode(encrypted)
.map_err(|_| AppError::RsaDecodeError)?;
if combined.len() < 12 {
return Err(AppError::RsaDecodeError);
}
let mut nonce_bytes = [0u8; 12];
nonce_bytes.copy_from_slice(&combined[..12]);
let nonce = Nonce::from(nonce_bytes);
let plaintext = cipher
.decrypt(&nonce, &combined[12..])
.map_err(|_| AppError::RsaDecodeError)?;
Ok(String::from_utf8(plaintext)
.map_err(|_| AppError::RsaDecodeError)?)
}
pub async fn auth_rsa(
&self,
context: &Session,
) -> Result<RsaResponse, AppError> {
if context
.get::<String>(Self::RSA_PRIVATE_KEY)
.ok()
.flatten()
.is_some()
&& context
.get::<String>(Self::RSA_PUBLIC_KEY)
.ok()
.flatten()
.is_some()
{
let public_key = context
.get::<String>(Self::RSA_PUBLIC_KEY)
.ok()
.flatten()
.expect("checked above");
return Ok(RsaResponse { public_key });
}
let seed: [u8; 32] = rand::random();
let mut rng = ChaCha12Rng::from_seed(seed);
let priv_key = RsaPrivateKey::new(&mut rng, Self::RSA_BIT_SIZE)
.map_err(|_| {
tracing::error!("RSA key generation failed");
AppError::RsaGenerationError
})?;
let pub_key = RsaPublicKey::from(&priv_key);
let priv_pem = priv_key
.to_pkcs1_pem(Default::default())
.map_err(|_| AppError::RsaGenerationError)?
.to_string();
let public_key = pub_key
.to_pkcs1_pem(Default::default())
.map_err(|_| AppError::RsaGenerationError)?
.to_string();
context
.insert(Self::RSA_PRIVATE_KEY, self.encrypt_rsa_key(&priv_pem)?)
.map_err(|_| AppError::RsaGenerationError)?;
context
.insert(Self::RSA_PUBLIC_KEY, public_key.clone())
.map_err(|_| AppError::RsaGenerationError)?;
Ok(RsaResponse { public_key })
}
pub async fn auth_rsa_decode(
&self,
context: &Session,
data: String,
) -> Result<String, AppError> {
let encrypted_priv = context
.get::<String>(Self::RSA_PRIVATE_KEY)
.map_err(|_| AppError::RsaDecodeError)?
.ok_or(AppError::RsaDecodeError)?;
let priv_pem = self.decrypt_rsa_key(&encrypted_priv)?;
let priv_key = RsaPrivateKey::from_pkcs1_pem(&priv_pem).map_err(|_| {
tracing::warn!(ip = ?context.ip_address(), "RSA decode failed: invalid private key");
AppError::RsaDecodeError
})?;
let cipher = base64::engine::general_purpose::STANDARD
.decode(&data)
.map_err(|_| AppError::RsaDecodeError)?;
let decrypted = priv_key.decrypt(Pkcs1v15Encrypt, &cipher).map_err(|_| {
tracing::warn!(ip = ?context.ip_address(), "RSA decrypt failed");
AppError::RsaDecodeError
})?;
Ok(String::from_utf8_lossy(&decrypted).to_string())
}
}