diff --git a/libs/agent/billing.rs b/libs/agent/billing.rs index f5f4c70..8886fe7 100644 --- a/libs/agent/billing.rs +++ b/libs/agent/billing.rs @@ -23,20 +23,28 @@ pub struct BillingRecord { pub output_tokens: i64, } -/// Record AI usage for a project. +/// Extended result that includes insufficient balance flag for system message creation. +#[derive(Debug)] +pub enum BillingResult { + Success(BillingRecord), + InsufficientBalance { message: String }, +} + +/// Record AI usage for a project with cascading billing. /// -/// If the project belongs to a workspace, the cost is deducted from the -/// workspace's shared quota. Otherwise it is deducted from the project's own -/// billing balance. +/// Billing strategy: +/// 1. Try to deduct from project balance first +/// 2. If insufficient, fallback to workspace balance (if project belongs to workspace) +/// 3. If both insufficient or no workspace, return InsufficientBalance error with room_id /// -/// Returns an error if there is insufficient balance. +/// Returns BillingError::InsufficientBalance with room_id for system message creation. pub async fn record_ai_usage( db: &AppDatabase, project_uid: Uuid, model_id: Uuid, input_tokens: i64, output_tokens: i64, -) -> Result { +) -> Result { // 1. Look up the active price for this model. let pricing = model_pricing::Entity::find() .filter(model_pricing::Column::ModelVersionId.eq(model_id)) @@ -68,106 +76,27 @@ pub async fn record_ai_usage( let currency = pricing.currency.clone(); - // 3. Determine whether to bill the project or its workspace. + // 3. Cascading billing: project balance first, then workspace if insufficient. let proj = project::Entity::find_by_id(project_uid) .one(db) .await? .ok_or_else(|| AgentError::Internal("Project not found".into()))?; - if let Some(workspace_id) = proj.workspace_id { - // ── Workspace-shared quota ────────────────────────────────── - let txn = db.begin().await?; + let txn = db.begin().await?; - // SELECT FOR UPDATE to prevent race conditions - let current = workspace_billing::Entity::find_by_id(workspace_id) - .lock_exclusive() - .one(&txn) - .await? - .ok_or_else(|| AgentError::Internal("Workspace billing account not found".into()))?; + // Always check project balance first + let project_billing = project_billing::Entity::find_by_id(project_uid) + .lock_exclusive() + .one(&txn) + .await? + .ok_or_else(|| AgentError::Internal("Project billing account not found".into()))?; - // Validate balance before any modifications - if current.balance < total_cost { - txn.rollback().await?; - return Err(AgentError::Internal(format!( - "Insufficient workspace billing balance. Required: {:.4} {}, Available: {:.4} {}", - total_cost, currency, current.balance, currency - ))); - } + let now = chrono::Utc::now(); + if project_billing.balance >= total_cost { + // ── Project has sufficient balance ────────────────────────── let amount_dec = -total_cost; - let now = chrono::Utc::now(); - // Insert workspace billing history AFTER validation - workspace_billing_history::ActiveModel { - uid: Set(Uuid::new_v4()), - workspace_id: Set(workspace_id), - user_id: Set(Some(proj.created_by)), - amount: Set(amount_dec), - currency: Set(currency.clone()), - reason: Set(format!("ai_usage:{}", project_uid)), - extra: Set(Some(serde_json::json!({ - "project_id": project_uid.to_string(), - "model_id": model_id.to_string(), - "input_tokens": input_tokens, - "output_tokens": output_tokens, - }))), - created_at: Set(now), - } - .insert(&txn) - .await?; - - // Deduct from workspace balance - let new_balance = current.balance - total_cost; - let mut updated: workspace_billing::ActiveModel = current.into(); - updated.balance = Set(new_balance); - updated.updated_at = Set(now); - updated.update(&txn).await?; - - txn.commit().await?; - - let cost_f64 = total_cost.to_string().parse().unwrap_or(0.0); - - tracing::info!( - project_id = %project_uid, - model_id = %model_id, - input_tokens = input_tokens, - output_tokens = output_tokens, - cost = %cost_f64, - currency = %currency, - workspace_id = %workspace_id.to_string(), - "ai_usage_recorded" - ); - - Ok(BillingRecord { - cost: cost_f64, - currency, - input_tokens, - output_tokens, - }) - } else { - // ── Project-owned quota ───────────────────────────────────── - let txn = db.begin().await?; - - // SELECT FOR UPDATE to prevent race conditions - let current = project_billing::Entity::find_by_id(project_uid) - .lock_exclusive() - .one(&txn) - .await? - .ok_or_else(|| AgentError::Internal("Project billing account not found".into()))?; - - // Validate balance before any modifications - if current.balance < total_cost { - txn.rollback().await?; - return Err(AgentError::Internal(format!( - "Insufficient billing balance. Required: {:.4} {}, Available: {:.4} {}", - total_cost, currency, current.balance, currency - ))); - } - - let amount_dec = -total_cost; - let now = chrono::Utc::now(); - - // Insert project billing history AFTER validation project_billing_history::ActiveModel { uid: Set(Uuid::new_v4()), project: Set(project_uid), @@ -186,9 +115,8 @@ pub async fn record_ai_usage( .insert(&txn) .await?; - // Deduct from project balance - let new_balance = current.balance - total_cost; - let mut updated: project_billing::ActiveModel = current.into(); + let new_balance = project_billing.balance - total_cost; + let mut updated: project_billing::ActiveModel = project_billing.into(); updated.balance = Set(new_balance); updated.update(&txn).await?; @@ -203,14 +131,93 @@ pub async fn record_ai_usage( output_tokens = output_tokens, cost = %cost_f64, currency = %currency, + source = "project", "ai_usage_recorded" ); - Ok(BillingRecord { + Ok(BillingResult::Success(BillingRecord { cost: cost_f64, currency, input_tokens, output_tokens, + })) + } else if let Some(workspace_id) = proj.workspace_id { + // ── Project insufficient, fallback to workspace ───────────── + let workspace_billing = workspace_billing::Entity::find_by_id(workspace_id) + .lock_exclusive() + .one(&txn) + .await? + .ok_or_else(|| AgentError::Internal("Workspace billing account not found".into()))?; + + if workspace_billing.balance < total_cost { + txn.rollback().await?; + return Ok(BillingResult::InsufficientBalance { + message: format!( + "Insufficient balance. Project: {:.4} {}, Workspace: {:.4} {}, Required: {:.4} {}", + project_billing.balance, currency, + workspace_billing.balance, currency, + total_cost, currency + ), + }); + } + + let amount_dec = -total_cost; + + workspace_billing_history::ActiveModel { + uid: Set(Uuid::new_v4()), + workspace_id: Set(workspace_id), + user_id: Set(Some(proj.created_by)), + amount: Set(amount_dec), + currency: Set(currency.clone()), + reason: Set(format!("ai_usage:{}", project_uid)), + extra: Set(Some(serde_json::json!({ + "project_id": project_uid.to_string(), + "model_id": model_id.to_string(), + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "fallback_reason": "project_balance_insufficient" + }))), + created_at: Set(now), + } + .insert(&txn) + .await?; + + let new_balance = workspace_billing.balance - total_cost; + let mut updated: workspace_billing::ActiveModel = workspace_billing.into(); + updated.balance = Set(new_balance); + updated.updated_at = Set(now); + updated.update(&txn).await?; + + txn.commit().await?; + + let cost_f64 = total_cost.to_string().parse().unwrap_or(0.0); + + tracing::info!( + project_id = %project_uid, + model_id = %model_id, + input_tokens = input_tokens, + output_tokens = output_tokens, + cost = %cost_f64, + currency = %currency, + workspace_id = %workspace_id.to_string(), + source = "workspace_fallback", + "ai_usage_recorded" + ); + + Ok(BillingResult::Success(BillingRecord { + cost: cost_f64, + currency, + input_tokens, + output_tokens, + })) + } else { + // ── Project insufficient and no workspace ─────────────────── + txn.rollback().await?; + Ok(BillingResult::InsufficientBalance { + message: format!( + "Insufficient balance. Required: {:.4} {}, Available: {:.4} {}", + total_cost, currency, project_billing.balance, currency + ), }) } } diff --git a/libs/agent/chat/service.rs b/libs/agent/chat/service.rs index e17a539..a65021d 100644 --- a/libs/agent/chat/service.rs +++ b/libs/agent/chat/service.rs @@ -58,17 +58,24 @@ async fn record_ai_session( output_tokens: i64, latency_ms: i64, ) { - let (cost, currency) = match billing::record_ai_usage( + let (cost, currency, error_msg) = match billing::record_ai_usage( db, project_id, - model_id, + version_id, input_tokens, output_tokens, ) .await { - Ok(record) => (Some(record.cost), Some(record.currency)), - Err(_) => (None, None), + Ok(billing::BillingResult::Success(record)) => { + (Some(record.cost), Some(record.currency), None) + } + Ok(billing::BillingResult::InsufficientBalance { message }) => { + // Create system message for insufficient balance + create_system_message(db, room_id, &message).await; + (None, None, Some(message)) + } + Err(_) => (None, None, None), }; let _ = models::ai::ai_session::ActiveModel { @@ -81,7 +88,7 @@ async fn record_ai_session( latency_ms: Set(Some(latency_ms)), cost: Set(cost), currency: Set(currency), - error_message: Set(None), + error_message: Set(error_msg), error_code: Set(None), created_at: Set(chrono::Utc::now()), } @@ -89,6 +96,71 @@ async fn record_ai_session( .await; } +/// Create a system message in the room for billing errors. +async fn create_system_message( + db: &db::database::AppDatabase, + room_id: Uuid, + message: &str, +) { + use models::rooms::{room_message, MessageSenderType, MessageContentType}; + use sea_orm::Set; + + // Get next sequence number - we don't have cache here, so we query directly + let last_seq = match room_message::Entity::find() + .filter(room_message::Column::Room.eq(room_id)) + .order_by_desc(room_message::Column::Seq) + .one(db) + .await + { + Ok(Some(m)) => m.seq, + Ok(None) => 0, + Err(e) => { + tracing::warn!(error = %e, "Failed to get last seq for system message"); + return; + } + }; + + let seq = last_seq + 1; + let now = chrono::Utc::now(); + + let result = room_message::ActiveModel { + id: Set(Uuid::new_v4()), + seq: Set(seq), + room: Set(room_id), + sender_type: Set(MessageSenderType::System), + sender_id: Set(None), + model_id: Set(None), + thread: Set(None), + in_reply_to: Set(None), + content: Set(message.to_string()), + content_type: Set(MessageContentType::Text), + thinking_content: Set(None), + edited_at: Set(None), + send_at: Set(now), + revoked: Set(None), + revoked_by: Set(None), + } + .insert(db) + .await; + + match result { + Ok(_) => { + tracing::info!( + room_id = %room_id, + message = %message, + "system_message_created_for_billing_error" + ); + } + Err(e) => { + tracing::warn!( + error = %e, + room_id = %room_id, + "Failed to create system message for billing error" + ); + } + } +} + /// Service for handling AI chat requests in rooms. pub struct ChatService { ai_base_url: Option, @@ -614,7 +686,26 @@ impl ChatService { for call in &calls { let start = std::time::Instant::now(); let executor = crate::tool::ToolExecutor::new(); - let results = match executor.execute_batch(vec![call.clone()], &mut ctx).await { + + // Use select! loop to send heartbeat chunks at 30s intervals + // during long tool execution, resetting the frontend streaming timer. + let fut = executor.execute_batch(vec![call.clone()], &mut ctx); + tokio::pin!(fut); + + let results = loop { + tokio::select! { + result = fut.as_mut() => break result, + _ = tokio::time::sleep(std::time::Duration::from_secs(30)) => { + on_chunk(AiStreamChunk { + content: String::new(), + done: false, + chunk_type: AiChunkType::ToolCall, + }).await; + } + } + }; + + let results = match results { Ok(r) => r, Err(e) => { let elapsed = start.elapsed().as_millis() as i64; diff --git a/libs/agent/client/mod.rs b/libs/agent/client/mod.rs index bd69d11..13bff36 100644 --- a/libs/agent/client/mod.rs +++ b/libs/agent/client/mod.rs @@ -760,9 +760,9 @@ async fn call_stream_once( }) }; - // 60s timeout for the entire stream - match tokio::time::timeout(std::time::Duration::from_secs(60), stream_fut).await { + // 120s timeout for the entire stream + match tokio::time::timeout(std::time::Duration::from_secs(120), stream_fut).await { Ok(result) => result, - Err(_) => Err(AgentError::Timeout { task_id: 0, seconds: 60 }), + Err(_) => Err(AgentError::Timeout { task_id: 0, seconds: 120 }), } } diff --git a/libs/agent/lib.rs b/libs/agent/lib.rs index 9c224fa..eeb4172 100644 --- a/libs/agent/lib.rs +++ b/libs/agent/lib.rs @@ -13,7 +13,7 @@ pub mod sync; pub mod task; pub mod tokent; pub mod tool; -pub use billing::{BillingRecord, record_ai_usage}; +pub use billing::{BillingRecord, BillingResult, record_ai_usage}; pub use sync::list_accessible_models; pub use task::TaskService; pub use tokent::{TokenUsage, resolve_usage};