// File: gitdataai/lib/service/auth/totp.rs
// (file-viewer metadata: 431 lines, 13 KiB, Rust)

use argon2::{Argon2, PasswordHash, password_hash::PasswordVerifier};
use db::sqlx;
use hmac::{Hmac, KeyInit, Mac};
use model::users::{User2FaModel, user_pass::UserPasswordModel};
use rand::RngExt;
use serde::{Deserialize, Serialize};
use session::Session;
use sha1::Sha1;
use sha2::{Digest, Sha256};
use uuid::Uuid;
use crate::{AppService, constant_time_eq, error::AppError};
/// Payload returned when a user starts TOTP enrollment.
#[derive(Deserialize, Serialize, Clone, utoipa::ToSchema)]
pub struct Enable2FAResponse {
    /// Base32-encoded TOTP shared secret.
    pub secret: String,
    /// `otpauth://totp/...` provisioning URI (a URI string, not an image).
    pub qr_code: String,
    /// Plaintext one-time backup codes; only their hashes are persisted.
    pub backup_codes: Vec<String>,
}

// `Debug` is written by hand instead of derived: every field is a credential
// (the provisioning URI embeds the secret), and a derived impl would print
// them verbatim into any `{:?}` log line.
impl std::fmt::Debug for Enable2FAResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Enable2FAResponse")
            .field("secret", &"<redacted>")
            .field("qr_code", &"<redacted>")
            .field("backup_codes", &"<redacted>")
            .finish()
    }
}
/// Request body carrying the TOTP code that confirms enrollment.
#[derive(Deserialize, Serialize, Clone, utoipa::ToSchema)]
pub struct Verify2FAParams {
    /// Code entered by the user.
    pub code: String,
}

// Hand-written `Debug` so the submitted code never reaches logs.
impl std::fmt::Debug for Verify2FAParams {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Verify2FAParams")
            .field("code", &"<redacted>")
            .finish()
    }
}
/// Request body for turning two-factor authentication off.
#[derive(Deserialize, Serialize, Clone, utoipa::ToSchema)]
pub struct Disable2FAParams {
    /// Current TOTP code or an unused backup code.
    pub code: String,
    /// The account password as sent by the client (it is passed through
    /// `auth_rsa_decode` before verification).
    pub password: String,
}

// Hand-written `Debug` so neither credential is printed by `{:?}` logging.
impl std::fmt::Debug for Disable2FAParams {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Disable2FAParams")
            .field("code", &"<redacted>")
            .field("password", &"<redacted>")
            .finish()
    }
}
/// Two-factor state of a single user.
#[derive(Clone, Debug, Deserialize, Serialize, utoipa::ToSchema)]
pub struct Get2FAStatusResponse {
    /// Whether TOTP has been confirmed and is currently enforced.
    pub is_enabled: bool,
    /// `"totp"` whenever a 2FA record exists (including a pending, not yet
    /// confirmed enrollment); `None` when the user has no 2FA record.
    pub method: Option<String>,
    /// Whether at least one backup code is still stored.
    pub has_backup_codes: bool,
}
impl AppService {
pub async fn auth_2fa_enable(
&self,
context: &Session,
) -> Result<Enable2FAResponse, AppError> {
let user_uid = context.user().ok_or(AppError::Unauthorized)?;
let user = self.auth_find_user_by_uid(user_uid).await?;
let existing = self.find_2fa(user_uid).await?;
if existing.as_ref().is_some_and(|two_fa| two_fa.enabled) {
return Err(AppError::TwoFactorAlreadyEnabled);
}
let secret = self.generate_totp_secret();
let backup_codes = self.generate_backup_codes(10);
let qr_code = format!(
"otpauth://totp/GitDataAI:{}?secret={}&issuer=GitDataAI",
user.username, secret
);
let now = chrono::Utc::now();
let hashed_backup_codes =
Self::hash_backup_codes(&backup_codes).join(".");
if existing.is_some() {
sqlx::query(
"UPDATE user_2fa SET secret = $1, backup_codes = $2, enabled = false, updated_at = $3 \
WHERE \"user\" = $4",
)
.bind(&secret)
.bind(&hashed_backup_codes)
.bind(now)
.bind(user_uid)
.execute(self.db.writer())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
} else {
sqlx::query(
"INSERT INTO user_2fa (\"user\", secret, backup_codes, enabled, created_at, updated_at) \
VALUES ($1, $2, $3, false, $4, $4)",
)
.bind(user_uid)
.bind(&secret)
.bind(&hashed_backup_codes)
.bind(now)
.execute(self.db.writer())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
}
Ok(Enable2FAResponse {
secret,
qr_code,
backup_codes,
})
}
pub async fn auth_2fa_verify_and_enable(
&self,
context: &Session,
params: Verify2FAParams,
) -> Result<(), AppError> {
let user_uid = context.user().ok_or(AppError::Unauthorized)?;
let two_fa = self
.find_2fa(user_uid)
.await?
.ok_or(AppError::TwoFactorNotSetup)?;
if two_fa.enabled {
return Err(AppError::TwoFactorAlreadyEnabled);
}
let secret =
two_fa.secret.as_ref().ok_or(AppError::TwoFactorNotSetup)?;
if !self.verify_totp_code(secret, &params.code)? {
return Err(AppError::InvalidTwoFactorCode);
}
sqlx::query("UPDATE user_2fa SET enabled = true, updated_at = $1 WHERE \"user\" = $2")
.bind(chrono::Utc::now())
.bind(user_uid)
.execute(self.db.writer())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
Ok(())
}
pub async fn auth_2fa_disable(
&self,
context: &Session,
params: Disable2FAParams,
) -> Result<(), AppError> {
let user_uid = context.user().ok_or(AppError::Unauthorized)?;
let password = self.auth_rsa_decode(context, params.password).await?;
self.verify_user_password(user_uid, &password).await?;
let two_fa = self
.find_2fa(user_uid)
.await?
.ok_or(AppError::TwoFactorNotSetup)?;
if !two_fa.enabled {
return Err(AppError::TwoFactorNotEnabled);
}
if !self
.verify_2fa_or_backup_code(&two_fa, &params.code)
.await?
{
return Err(AppError::InvalidTwoFactorCode);
}
sqlx::query("DELETE FROM user_2fa WHERE \"user\" = $1")
.bind(user_uid)
.execute(self.db.writer())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
Ok(())
}
pub async fn auth_2fa_verify(
&self,
user_uid: Uuid,
code: &str,
) -> Result<bool, AppError> {
let Some(two_fa) = self.find_2fa(user_uid).await? else {
return Ok(true);
};
if !two_fa.enabled {
return Ok(true);
}
self.verify_2fa_or_backup_code(&two_fa, code).await
}
pub async fn auth_2fa_status_by_uid(
&self,
user_uid: Uuid,
) -> Result<Get2FAStatusResponse, AppError> {
let Some(two_fa) = self.find_2fa(user_uid).await? else {
return Ok(Get2FAStatusResponse {
is_enabled: false,
method: None,
has_backup_codes: false,
});
};
Ok(Get2FAStatusResponse {
is_enabled: two_fa.enabled,
method: Some("totp".to_string()),
has_backup_codes: !two_fa.backup_codes.is_empty(),
})
}
pub async fn auth_2fa_status(
&self,
context: &Session,
) -> Result<Get2FAStatusResponse, AppError> {
let user_uid = context.user().ok_or(AppError::Unauthorized)?;
self.auth_2fa_status_by_uid(user_uid).await
}
pub async fn auth_2fa_verify_login(
&self,
context: &Session,
code: &str,
) -> Result<bool, AppError> {
let Some(totp_key) =
context.get::<String>(Self::TOTP_KEY).ok().flatten()
else {
return Ok(false);
};
let Some(user_uid) = self
.cache
.get::<Uuid>(&totp_key)
.await
.map_err(|e| AppError::InternalServerError(e.to_string()))?
else {
return Ok(false);
};
let verified = self.auth_2fa_verify(user_uid, code).await?;
if verified {
context.remove(Self::TOTP_KEY);
let _ = self.cache.remove(&totp_key).await;
context.set_user(user_uid);
}
Ok(verified)
}
pub async fn auth_2fa_regenerate_backup_codes(
&self,
context: &Session,
password: String,
) -> Result<Vec<String>, AppError> {
let user_uid = context.user().ok_or(AppError::Unauthorized)?;
let password = self.auth_rsa_decode(context, password).await?;
self.verify_user_password(user_uid, &password).await?;
let two_fa = self
.find_2fa(user_uid)
.await?
.ok_or(AppError::TwoFactorNotSetup)?;
if !two_fa.enabled {
return Err(AppError::TwoFactorNotEnabled);
}
let backup_codes = self.generate_backup_codes(10);
sqlx::query("UPDATE user_2fa SET backup_codes = $1, updated_at = $2 WHERE \"user\" = $3")
.bind(Self::hash_backup_codes(&backup_codes).join("."))
.bind(chrono::Utc::now())
.bind(user_uid)
.execute(self.db.writer())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
Ok(backup_codes)
}
fn generate_totp_secret(&self) -> String {
const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
#[allow(deprecated)]
let mut rng = rand::rng();
(0..32)
.map(|_| {
#[allow(deprecated)]
let idx = rng.random_range(0..CHARSET.len());
CHARSET[idx] as char
})
.collect()
}
fn generate_backup_codes(&self, count: usize) -> Vec<String> {
#[allow(deprecated)]
let mut rng = rand::rng();
(0..count)
.map(|_| {
format!(
"{:04}-{:04}-{:04}",
rng.random_range(0..10000),
rng.random_range(0..10000),
rng.random_range(0..10000)
)
})
.collect()
}
fn hash_backup_code(code: &str) -> String {
let mut hasher = Sha256::new();
hasher.update(code.as_bytes());
hasher
.finalize()
.iter()
.map(|b| format!("{:02x}", b))
.collect::<String>()
}
fn hash_backup_codes(codes: &[String]) -> Vec<String> {
codes.iter().map(|c| Self::hash_backup_code(c)).collect()
}
fn verify_totp_code(
&self,
secret: &str,
code: &str,
) -> Result<bool, AppError> {
let now = chrono::Utc::now().timestamp() as u64;
let time_step = 30;
let counter = now / time_step;
for offset in [-1i64, 0, 1] {
let test_counter = (counter as i64 + offset) as u64;
let expected_code =
self.generate_totp_code(secret, test_counter)?;
if constant_time_eq(&expected_code, code) {
return Ok(true);
}
}
Ok(false)
}
fn generate_totp_code(
&self,
secret: &str,
counter: u64,
) -> Result<String, AppError> {
let secret_bytes = self.decode_base32(secret)?;
let counter_bytes = counter.to_be_bytes();
let mut mac = Hmac::<Sha1>::new_from_slice(&secret_bytes)
.map_err(|_| AppError::InvalidTwoFactorCode)?;
mac.update(&counter_bytes);
let result = mac.finalize().into_bytes();
let offset = (result[19] & 0x0f) as usize;
let code = u32::from_be_bytes([
result[offset] & 0x7f,
result[offset + 1],
result[offset + 2],
result[offset + 3],
]);
Ok(format!("{:06}", code % 1_000_000))
}
fn decode_base32(&self, input: &str) -> Result<Vec<u8>, AppError> {
const CHARSET: &str = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
let input = input.to_uppercase().replace("=", "");
let mut bits = 0u64;
let mut bit_count = 0;
let mut output = Vec::new();
for c in input.chars() {
let val =
CHARSET.find(c).ok_or(AppError::InvalidTwoFactorCode)? as u64;
bits = (bits << 5) | val;
bit_count += 5;
if bit_count >= 8 {
bit_count -= 8;
output.push((bits >> bit_count) as u8);
bits &= (1 << bit_count) - 1;
}
}
Ok(output)
}
async fn verify_user_password(
&self,
user_uid: Uuid,
password: &str,
) -> Result<(), AppError> {
let user_password = sqlx::query_as::<_, UserPasswordModel>(
"SELECT \"user\", hash, salt, is_active, reason, created_at, updated_at \
FROM user_password WHERE \"user\" = $1 AND is_active = true",
)
.bind(user_uid)
.fetch_optional(self.db.reader())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))?
.ok_or(AppError::UserNotFound)?;
let password_hash = PasswordHash::new(&user_password.hash)
.map_err(|_| AppError::InvalidPassword)?;
Argon2::default()
.verify_password(password.as_bytes(), &password_hash)
.map_err(|_| AppError::InvalidPassword)?;
Ok(())
}
async fn find_2fa(
&self,
user_uid: Uuid,
) -> Result<Option<User2FaModel>, AppError> {
sqlx::query_as::<_, User2FaModel>(
"SELECT \"user\", secret, backup_codes, enabled, created_at, updated_at \
FROM user_2fa WHERE \"user\" = $1",
)
.bind(user_uid)
.fetch_optional(self.db.reader())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))
}
async fn verify_2fa_or_backup_code(
&self,
two_fa: &User2FaModel,
code: &str,
) -> Result<bool, AppError> {
let secret =
two_fa.secret.as_ref().ok_or(AppError::TwoFactorNotSetup)?;
if self.verify_totp_code(secret, code)? {
return Ok(true);
}
let hashed_code = Self::hash_backup_code(code);
let mut backup_codes: Vec<String> = two_fa
.backup_codes
.split('.')
.filter(|code| !code.is_empty())
.map(ToOwned::to_owned)
.collect();
if backup_codes.contains(&hashed_code) {
backup_codes.retain(|stored| stored != &hashed_code);
sqlx::query(
"UPDATE user_2fa SET backup_codes = $1, updated_at = $2 WHERE \"user\" = $3",
)
.bind(backup_codes.join("."))
.bind(chrono::Utc::now())
.bind(two_fa.user)
.execute(self.db.writer())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
return Ok(true);
}
Ok(false)
}
}