From 8a6ec1f62f248668065aeffc06847fb19935bdae Mon Sep 17 00:00:00 2001 From: ZhenYi <434836402@qq.com> Date: Tue, 28 Apr 2026 10:12:24 +0800 Subject: [PATCH] fix(billing): add transaction isolation and fix race conditions Critical fixes: - Wrap balance updates in database transactions with SELECT FOR UPDATE - Move history insert after balance validation to prevent orphaned records - Use Decimal throughout to avoid silent conversion failures - Prevent concurrent requests from causing negative balances Tasks resolved: - Task #4: Silent Decimal conversion failures - Task #5: Missing transaction isolation (race conditions) - Task #6: History inserted before validation --- libs/agent/billing.rs | 108 ++++++++++++++++++++++++------------------ 1 file changed, 61 insertions(+), 47 deletions(-) diff --git a/libs/agent/billing.rs b/libs/agent/billing.rs index 9bced81..f5f4c70 100644 --- a/libs/agent/billing.rs +++ b/libs/agent/billing.rs @@ -54,20 +54,17 @@ pub async fn record_ai_usage( let input_price: Decimal = pricing .input_price_per_1k_tokens .parse() - .unwrap_or(Decimal::ZERO); + .map_err(|e| AgentError::Internal(format!("Invalid input price format: {}", e)))?; let output_price: Decimal = pricing .output_price_per_1k_tokens .parse() - .unwrap_or(Decimal::ZERO); + .map_err(|e| AgentError::Internal(format!("Invalid output price format: {}", e)))?; let tokens_i = Decimal::from(input_tokens); let tokens_o = Decimal::from(output_tokens); let thousand = Decimal::from(1000); - let total_cost: f64 = ((tokens_i / thousand) * input_price - + (tokens_o / thousand) * output_price) - .to_string() - .parse() - .unwrap_or(0.0); + let total_cost = (tokens_i / thousand) * input_price + + (tokens_o / thousand) * output_price; let currency = pricing.currency.clone(); @@ -79,25 +76,29 @@ pub async fn record_ai_usage( if let Some(workspace_id) = proj.workspace_id { // ── Workspace-shared quota ────────────────────────────────── + let txn = db.begin().await?; + + // SELECT FOR UPDATE to prevent race conditions let current = workspace_billing::Entity::find_by_id(workspace_id) - .one(db) + .lock_exclusive() + .one(&txn) .await? .ok_or_else(|| AgentError::Internal("Workspace billing account not found".into()))?; - let current_balance: f64 = current.balance.to_string().parse().unwrap_or(0.0); - - if current_balance < total_cost { + // Validate balance before any modifications + if current.balance < total_cost { + txn.rollback().await?; return Err(AgentError::Internal(format!( "Insufficient workspace billing balance. Required: {:.4} {}, Available: {:.4} {}", - total_cost, currency, current_balance, currency + total_cost, currency, current.balance, currency ))); } - let amount_dec = Decimal::from_f64_retain(-total_cost).unwrap_or(Decimal::ZERO); + let amount_dec = -total_cost; let now = chrono::Utc::now(); - // Insert workspace billing history. - let _ = workspace_billing_history::ActiveModel { + // Insert workspace billing history AFTER validation + workspace_billing_history::ActiveModel { uid: Set(Uuid::new_v4()), workspace_id: Set(workspace_id), user_id: Set(Some(proj.created_by)), @@ -112,39 +113,62 @@ pub async fn record_ai_usage( }))), created_at: Set(now), } - .insert(db) - .await; + .insert(&txn) + .await?; - // Deduct from workspace balance. - let new_balance = - Decimal::from_f64_retain(current_balance - total_cost).unwrap_or(Decimal::ZERO); + // Deduct from workspace balance + let new_balance = current.balance - total_cost; let mut updated: workspace_billing::ActiveModel = current.into(); updated.balance = Set(new_balance); updated.updated_at = Set(now); - updated.update(db).await?; + updated.update(&txn).await?; + + txn.commit().await?; + + let cost_f64 = total_cost.to_string().parse().unwrap_or(0.0); tracing::info!( project_id = %project_uid, model_id = %model_id, input_tokens = input_tokens, output_tokens = output_tokens, - cost = %total_cost, + cost = %cost_f64, currency = %currency, workspace_id = %workspace_id.to_string(), "ai_usage_recorded" ); Ok(BillingRecord { - cost: total_cost, + cost: cost_f64, currency, input_tokens, output_tokens, }) } else { // ── Project-owned quota ───────────────────────────────────── - let amount_dec = Decimal::from_f64_retain(-total_cost).unwrap_or(Decimal::ZERO); + let txn = db.begin().await?; - let _ = project_billing_history::ActiveModel { + // SELECT FOR UPDATE to prevent race conditions + let current = project_billing::Entity::find_by_id(project_uid) + .lock_exclusive() + .one(&txn) + .await? + .ok_or_else(|| AgentError::Internal("Project billing account not found".into()))?; + + // Validate balance before any modifications + if current.balance < total_cost { + txn.rollback().await?; + return Err(AgentError::Internal(format!( + "Insufficient billing balance. Required: {:.4} {}, Available: {:.4} {}", + total_cost, currency, current.balance, currency + ))); + } + + let amount_dec = -total_cost; + let now = chrono::Utc::now(); + + // Insert project billing history AFTER validation + project_billing_history::ActiveModel { uid: Set(Uuid::new_v4()), project: Set(project_uid), user: Set(None), @@ -156,44 +180,34 @@ pub async fn record_ai_usage( "input_tokens": input_tokens, "output_tokens": output_tokens, }))), - created_at: Set(chrono::Utc::now()), + created_at: Set(now), ..Default::default() } - .insert(db) - .await; + .insert(&txn) + .await?; - let current = project_billing::Entity::find_by_id(project_uid) - .one(db) - .await? - .ok_or_else(|| AgentError::Internal("Project billing account not found".into()))?; - - let current_balance: f64 = current.balance.to_string().parse().unwrap_or(0.0); - - if current_balance < total_cost { - return Err(AgentError::Internal(format!( - "Insufficient billing balance. Required: {:.4} {}, Available: {:.4} {}", - total_cost, currency, current_balance, currency - ))); - } - - let new_balance = - Decimal::from_f64_retain(current_balance - total_cost).unwrap_or(Decimal::ZERO); + // Deduct from project balance + let new_balance = current.balance - total_cost; let mut updated: project_billing::ActiveModel = current.into(); updated.balance = Set(new_balance); - updated.update(db).await?; + updated.update(&txn).await?; + + txn.commit().await?; + + let cost_f64 = total_cost.to_string().parse().unwrap_or(0.0); tracing::info!( project_id = %project_uid, model_id = %model_id, input_tokens = input_tokens, output_tokens = output_tokens, - cost = %total_cost, + cost = %cost_f64, currency = %currency, "ai_usage_recorded" ); Ok(BillingRecord { - cost: total_cost, + cost: cost_f64, currency, input_tokens, output_tokens,