fix(billing): add transaction isolation and fix race conditions

Critical fixes:
- Wrap balance updates in database transactions with SELECT FOR UPDATE
- Move history insert after balance validation to prevent orphaned records
- Use Decimal throughout to avoid silent conversion failures
- Prevent concurrent requests from causing negative balances

Tasks resolved:
- Task #4: Silent Decimal conversion failures
- Task #5: Missing transaction isolation (race conditions)
- Task #6: History inserted before validation
This commit is contained in:
ZhenYi 2026-04-28 10:12:24 +08:00
parent 6edacbcdf2
commit 8a6ec1f62f

View File

@ -54,20 +54,17 @@ pub async fn record_ai_usage(
let input_price: Decimal = pricing let input_price: Decimal = pricing
.input_price_per_1k_tokens .input_price_per_1k_tokens
.parse() .parse()
.unwrap_or(Decimal::ZERO); .map_err(|e| AgentError::Internal(format!("Invalid input price format: {}", e)))?;
let output_price: Decimal = pricing let output_price: Decimal = pricing
.output_price_per_1k_tokens .output_price_per_1k_tokens
.parse() .parse()
.unwrap_or(Decimal::ZERO); .map_err(|e| AgentError::Internal(format!("Invalid output price format: {}", e)))?;
let tokens_i = Decimal::from(input_tokens); let tokens_i = Decimal::from(input_tokens);
let tokens_o = Decimal::from(output_tokens); let tokens_o = Decimal::from(output_tokens);
let thousand = Decimal::from(1000); let thousand = Decimal::from(1000);
let total_cost: f64 = ((tokens_i / thousand) * input_price let total_cost = (tokens_i / thousand) * input_price
+ (tokens_o / thousand) * output_price) + (tokens_o / thousand) * output_price;
.to_string()
.parse()
.unwrap_or(0.0);
let currency = pricing.currency.clone(); let currency = pricing.currency.clone();
@ -79,25 +76,29 @@ pub async fn record_ai_usage(
if let Some(workspace_id) = proj.workspace_id { if let Some(workspace_id) = proj.workspace_id {
// ── Workspace-shared quota ────────────────────────────────── // ── Workspace-shared quota ──────────────────────────────────
let txn = db.begin().await?;
// SELECT FOR UPDATE to prevent race conditions
let current = workspace_billing::Entity::find_by_id(workspace_id) let current = workspace_billing::Entity::find_by_id(workspace_id)
.one(db) .lock_exclusive()
.one(&txn)
.await? .await?
.ok_or_else(|| AgentError::Internal("Workspace billing account not found".into()))?; .ok_or_else(|| AgentError::Internal("Workspace billing account not found".into()))?;
let current_balance: f64 = current.balance.to_string().parse().unwrap_or(0.0); // Validate balance before any modifications
if current.balance < total_cost {
if current_balance < total_cost { txn.rollback().await?;
return Err(AgentError::Internal(format!( return Err(AgentError::Internal(format!(
"Insufficient workspace billing balance. Required: {:.4} {}, Available: {:.4} {}", "Insufficient workspace billing balance. Required: {:.4} {}, Available: {:.4} {}",
total_cost, currency, current_balance, currency total_cost, currency, current.balance, currency
))); )));
} }
let amount_dec = Decimal::from_f64_retain(-total_cost).unwrap_or(Decimal::ZERO); let amount_dec = -total_cost;
let now = chrono::Utc::now(); let now = chrono::Utc::now();
// Insert workspace billing history. // Insert workspace billing history AFTER validation
let _ = workspace_billing_history::ActiveModel { workspace_billing_history::ActiveModel {
uid: Set(Uuid::new_v4()), uid: Set(Uuid::new_v4()),
workspace_id: Set(workspace_id), workspace_id: Set(workspace_id),
user_id: Set(Some(proj.created_by)), user_id: Set(Some(proj.created_by)),
@ -112,39 +113,62 @@ pub async fn record_ai_usage(
}))), }))),
created_at: Set(now), created_at: Set(now),
} }
.insert(db) .insert(&txn)
.await; .await?;
// Deduct from workspace balance. // Deduct from workspace balance
let new_balance = let new_balance = current.balance - total_cost;
Decimal::from_f64_retain(current_balance - total_cost).unwrap_or(Decimal::ZERO);
let mut updated: workspace_billing::ActiveModel = current.into(); let mut updated: workspace_billing::ActiveModel = current.into();
updated.balance = Set(new_balance); updated.balance = Set(new_balance);
updated.updated_at = Set(now); updated.updated_at = Set(now);
updated.update(db).await?; updated.update(&txn).await?;
txn.commit().await?;
let cost_f64 = total_cost.to_string().parse().unwrap_or(0.0);
tracing::info!( tracing::info!(
project_id = %project_uid, project_id = %project_uid,
model_id = %model_id, model_id = %model_id,
input_tokens = input_tokens, input_tokens = input_tokens,
output_tokens = output_tokens, output_tokens = output_tokens,
cost = %total_cost, cost = %cost_f64,
currency = %currency, currency = %currency,
workspace_id = %workspace_id.to_string(), workspace_id = %workspace_id.to_string(),
"ai_usage_recorded" "ai_usage_recorded"
); );
Ok(BillingRecord { Ok(BillingRecord {
cost: total_cost, cost: cost_f64,
currency, currency,
input_tokens, input_tokens,
output_tokens, output_tokens,
}) })
} else { } else {
// ── Project-owned quota ───────────────────────────────────── // ── Project-owned quota ─────────────────────────────────────
let amount_dec = Decimal::from_f64_retain(-total_cost).unwrap_or(Decimal::ZERO); let txn = db.begin().await?;
let _ = project_billing_history::ActiveModel { // SELECT FOR UPDATE to prevent race conditions
let current = project_billing::Entity::find_by_id(project_uid)
.lock_exclusive()
.one(&txn)
.await?
.ok_or_else(|| AgentError::Internal("Project billing account not found".into()))?;
// Validate balance before any modifications
if current.balance < total_cost {
txn.rollback().await?;
return Err(AgentError::Internal(format!(
"Insufficient billing balance. Required: {:.4} {}, Available: {:.4} {}",
total_cost, currency, current.balance, currency
)));
}
let amount_dec = -total_cost;
let now = chrono::Utc::now();
// Insert project billing history AFTER validation
project_billing_history::ActiveModel {
uid: Set(Uuid::new_v4()), uid: Set(Uuid::new_v4()),
project: Set(project_uid), project: Set(project_uid),
user: Set(None), user: Set(None),
@ -156,44 +180,34 @@ pub async fn record_ai_usage(
"input_tokens": input_tokens, "input_tokens": input_tokens,
"output_tokens": output_tokens, "output_tokens": output_tokens,
}))), }))),
created_at: Set(chrono::Utc::now()), created_at: Set(now),
..Default::default() ..Default::default()
} }
.insert(db) .insert(&txn)
.await; .await?;
let current = project_billing::Entity::find_by_id(project_uid) // Deduct from project balance
.one(db) let new_balance = current.balance - total_cost;
.await?
.ok_or_else(|| AgentError::Internal("Project billing account not found".into()))?;
let current_balance: f64 = current.balance.to_string().parse().unwrap_or(0.0);
if current_balance < total_cost {
return Err(AgentError::Internal(format!(
"Insufficient billing balance. Required: {:.4} {}, Available: {:.4} {}",
total_cost, currency, current_balance, currency
)));
}
let new_balance =
Decimal::from_f64_retain(current_balance - total_cost).unwrap_or(Decimal::ZERO);
let mut updated: project_billing::ActiveModel = current.into(); let mut updated: project_billing::ActiveModel = current.into();
updated.balance = Set(new_balance); updated.balance = Set(new_balance);
updated.update(db).await?; updated.update(&txn).await?;
txn.commit().await?;
let cost_f64 = total_cost.to_string().parse().unwrap_or(0.0);
tracing::info!( tracing::info!(
project_id = %project_uid, project_id = %project_uid,
model_id = %model_id, model_id = %model_id,
input_tokens = input_tokens, input_tokens = input_tokens,
output_tokens = output_tokens, output_tokens = output_tokens,
cost = %total_cost, cost = %cost_f64,
currency = %currency, currency = %currency,
"ai_usage_recorded" "ai_usage_recorded"
); );
Ok(BillingRecord { Ok(BillingRecord {
cost: total_cost, cost: cost_f64,
currency, currency,
input_tokens, input_tokens,
output_tokens, output_tokens,