fix(agent): 修复扣费链路并实现级联扣费策略

- billing.rs: 修复参数传递 (model_id -> version_id)
- billing.rs: 新增 BillingResult 枚举支持 InsufficientBalance 错误
- billing.rs: 实现级联扣费 (优先 project 余额,不足时 fallback 到 workspace)
- billing.rs: 余额不足时创建系统消息并持久化
- chat/service.rs: 捕获 InsufficientBalance 错误并调用 create_system_message
- client/mod.rs: 超时时间从 60s 改为 120s
This commit is contained in:
ZhenYi 2026-04-28 19:59:06 +08:00
parent 13523762aa
commit c6bb72682b
4 changed files with 208 additions and 110 deletions

View File

@ -23,20 +23,28 @@ pub struct BillingRecord {
pub output_tokens: i64,
}
/// Record AI usage for a project.
/// Extended result that includes insufficient balance flag for system message creation.
#[derive(Debug)]
pub enum BillingResult {
Success(BillingRecord),
InsufficientBalance { message: String },
}
/// Record AI usage for a project with cascading billing.
///
/// If the project belongs to a workspace, the cost is deducted from the
/// workspace's shared quota. Otherwise it is deducted from the project's own
/// billing balance.
/// Billing strategy:
/// 1. Try to deduct from project balance first
/// 2. If insufficient, fallback to workspace balance (if project belongs to workspace)
/// 3. If both insufficient or no workspace, return InsufficientBalance error with room_id
///
/// Returns an error if there is insufficient balance.
/// Returns BillingError::InsufficientBalance with room_id for system message creation.
pub async fn record_ai_usage(
db: &AppDatabase,
project_uid: Uuid,
model_id: Uuid,
input_tokens: i64,
output_tokens: i64,
) -> Result<BillingRecord, AgentError> {
) -> Result<BillingResult, AgentError> {
// 1. Look up the active price for this model.
let pricing = model_pricing::Entity::find()
.filter(model_pricing::Column::ModelVersionId.eq(model_id))
@ -68,106 +76,27 @@ pub async fn record_ai_usage(
let currency = pricing.currency.clone();
// 3. Determine whether to bill the project or its workspace.
// 3. Cascading billing: project balance first, then workspace if insufficient.
let proj = project::Entity::find_by_id(project_uid)
.one(db)
.await?
.ok_or_else(|| AgentError::Internal("Project not found".into()))?;
if let Some(workspace_id) = proj.workspace_id {
// ── Workspace-shared quota ──────────────────────────────────
let txn = db.begin().await?;
// SELECT FOR UPDATE to prevent race conditions
let current = workspace_billing::Entity::find_by_id(workspace_id)
.lock_exclusive()
.one(&txn)
.await?
.ok_or_else(|| AgentError::Internal("Workspace billing account not found".into()))?;
// Validate balance before any modifications
if current.balance < total_cost {
txn.rollback().await?;
return Err(AgentError::Internal(format!(
"Insufficient workspace billing balance. Required: {:.4} {}, Available: {:.4} {}",
total_cost, currency, current.balance, currency
)));
}
let amount_dec = -total_cost;
let now = chrono::Utc::now();
// Insert workspace billing history AFTER validation
workspace_billing_history::ActiveModel {
uid: Set(Uuid::new_v4()),
workspace_id: Set(workspace_id),
user_id: Set(Some(proj.created_by)),
amount: Set(amount_dec),
currency: Set(currency.clone()),
reason: Set(format!("ai_usage:{}", project_uid)),
extra: Set(Some(serde_json::json!({
"project_id": project_uid.to_string(),
"model_id": model_id.to_string(),
"input_tokens": input_tokens,
"output_tokens": output_tokens,
}))),
created_at: Set(now),
}
.insert(&txn)
.await?;
// Deduct from workspace balance
let new_balance = current.balance - total_cost;
let mut updated: workspace_billing::ActiveModel = current.into();
updated.balance = Set(new_balance);
updated.updated_at = Set(now);
updated.update(&txn).await?;
txn.commit().await?;
let cost_f64 = total_cost.to_string().parse().unwrap_or(0.0);
tracing::info!(
project_id = %project_uid,
model_id = %model_id,
input_tokens = input_tokens,
output_tokens = output_tokens,
cost = %cost_f64,
currency = %currency,
workspace_id = %workspace_id.to_string(),
"ai_usage_recorded"
);
Ok(BillingRecord {
cost: cost_f64,
currency,
input_tokens,
output_tokens,
})
} else {
// ── Project-owned quota ─────────────────────────────────────
let txn = db.begin().await?;
// SELECT FOR UPDATE to prevent race conditions
let current = project_billing::Entity::find_by_id(project_uid)
// Always check project balance first
let project_billing = project_billing::Entity::find_by_id(project_uid)
.lock_exclusive()
.one(&txn)
.await?
.ok_or_else(|| AgentError::Internal("Project billing account not found".into()))?;
// Validate balance before any modifications
if current.balance < total_cost {
txn.rollback().await?;
return Err(AgentError::Internal(format!(
"Insufficient billing balance. Required: {:.4} {}, Available: {:.4} {}",
total_cost, currency, current.balance, currency
)));
}
let amount_dec = -total_cost;
let now = chrono::Utc::now();
// Insert project billing history AFTER validation
if project_billing.balance >= total_cost {
// ── Project has sufficient balance ──────────────────────────
let amount_dec = -total_cost;
project_billing_history::ActiveModel {
uid: Set(Uuid::new_v4()),
project: Set(project_uid),
@ -186,9 +115,8 @@ pub async fn record_ai_usage(
.insert(&txn)
.await?;
// Deduct from project balance
let new_balance = current.balance - total_cost;
let mut updated: project_billing::ActiveModel = current.into();
let new_balance = project_billing.balance - total_cost;
let mut updated: project_billing::ActiveModel = project_billing.into();
updated.balance = Set(new_balance);
updated.update(&txn).await?;
@ -203,14 +131,93 @@ pub async fn record_ai_usage(
output_tokens = output_tokens,
cost = %cost_f64,
currency = %currency,
source = "project",
"ai_usage_recorded"
);
Ok(BillingRecord {
Ok(BillingResult::Success(BillingRecord {
cost: cost_f64,
currency,
input_tokens,
output_tokens,
}))
} else if let Some(workspace_id) = proj.workspace_id {
// ── Project insufficient, fallback to workspace ─────────────
let workspace_billing = workspace_billing::Entity::find_by_id(workspace_id)
.lock_exclusive()
.one(&txn)
.await?
.ok_or_else(|| AgentError::Internal("Workspace billing account not found".into()))?;
if workspace_billing.balance < total_cost {
txn.rollback().await?;
return Ok(BillingResult::InsufficientBalance {
message: format!(
"Insufficient balance. Project: {:.4} {}, Workspace: {:.4} {}, Required: {:.4} {}",
project_billing.balance, currency,
workspace_billing.balance, currency,
total_cost, currency
),
});
}
let amount_dec = -total_cost;
workspace_billing_history::ActiveModel {
uid: Set(Uuid::new_v4()),
workspace_id: Set(workspace_id),
user_id: Set(Some(proj.created_by)),
amount: Set(amount_dec),
currency: Set(currency.clone()),
reason: Set(format!("ai_usage:{}", project_uid)),
extra: Set(Some(serde_json::json!({
"project_id": project_uid.to_string(),
"model_id": model_id.to_string(),
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"fallback_reason": "project_balance_insufficient"
}))),
created_at: Set(now),
}
.insert(&txn)
.await?;
let new_balance = workspace_billing.balance - total_cost;
let mut updated: workspace_billing::ActiveModel = workspace_billing.into();
updated.balance = Set(new_balance);
updated.updated_at = Set(now);
updated.update(&txn).await?;
txn.commit().await?;
let cost_f64 = total_cost.to_string().parse().unwrap_or(0.0);
tracing::info!(
project_id = %project_uid,
model_id = %model_id,
input_tokens = input_tokens,
output_tokens = output_tokens,
cost = %cost_f64,
currency = %currency,
workspace_id = %workspace_id.to_string(),
source = "workspace_fallback",
"ai_usage_recorded"
);
Ok(BillingResult::Success(BillingRecord {
cost: cost_f64,
currency,
input_tokens,
output_tokens,
}))
} else {
// ── Project insufficient and no workspace ───────────────────
txn.rollback().await?;
Ok(BillingResult::InsufficientBalance {
message: format!(
"Insufficient balance. Required: {:.4} {}, Available: {:.4} {}",
total_cost, currency, project_billing.balance, currency
),
})
}
}

View File

@ -58,17 +58,24 @@ async fn record_ai_session(
output_tokens: i64,
latency_ms: i64,
) {
let (cost, currency) = match billing::record_ai_usage(
let (cost, currency, error_msg) = match billing::record_ai_usage(
db,
project_id,
model_id,
version_id,
input_tokens,
output_tokens,
)
.await
{
Ok(record) => (Some(record.cost), Some(record.currency)),
Err(_) => (None, None),
Ok(billing::BillingResult::Success(record)) => {
(Some(record.cost), Some(record.currency), None)
}
Ok(billing::BillingResult::InsufficientBalance { message }) => {
// Create system message for insufficient balance
create_system_message(db, room_id, &message).await;
(None, None, Some(message))
}
Err(_) => (None, None, None),
};
let _ = models::ai::ai_session::ActiveModel {
@ -81,7 +88,7 @@ async fn record_ai_session(
latency_ms: Set(Some(latency_ms)),
cost: Set(cost),
currency: Set(currency),
error_message: Set(None),
error_message: Set(error_msg),
error_code: Set(None),
created_at: Set(chrono::Utc::now()),
}
@ -89,6 +96,71 @@ async fn record_ai_session(
.await;
}
/// Create a system message in the room for billing errors.
async fn create_system_message(
db: &db::database::AppDatabase,
room_id: Uuid,
message: &str,
) {
use models::rooms::{room_message, MessageSenderType, MessageContentType};
use sea_orm::Set;
// Get next sequence number - we don't have cache here, so we query directly
let last_seq = match room_message::Entity::find()
.filter(room_message::Column::Room.eq(room_id))
.order_by_desc(room_message::Column::Seq)
.one(db)
.await
{
Ok(Some(m)) => m.seq,
Ok(None) => 0,
Err(e) => {
tracing::warn!(error = %e, "Failed to get last seq for system message");
return;
}
};
let seq = last_seq + 1;
let now = chrono::Utc::now();
let result = room_message::ActiveModel {
id: Set(Uuid::new_v4()),
seq: Set(seq),
room: Set(room_id),
sender_type: Set(MessageSenderType::System),
sender_id: Set(None),
model_id: Set(None),
thread: Set(None),
in_reply_to: Set(None),
content: Set(message.to_string()),
content_type: Set(MessageContentType::Text),
thinking_content: Set(None),
edited_at: Set(None),
send_at: Set(now),
revoked: Set(None),
revoked_by: Set(None),
}
.insert(db)
.await;
match result {
Ok(_) => {
tracing::info!(
room_id = %room_id,
message = %message,
"system_message_created_for_billing_error"
);
}
Err(e) => {
tracing::warn!(
error = %e,
room_id = %room_id,
"Failed to create system message for billing error"
);
}
}
}
/// Service for handling AI chat requests in rooms.
pub struct ChatService {
ai_base_url: Option<String>,
@ -614,7 +686,26 @@ impl ChatService {
for call in &calls {
let start = std::time::Instant::now();
let executor = crate::tool::ToolExecutor::new();
let results = match executor.execute_batch(vec![call.clone()], &mut ctx).await {
// Use select! loop to send heartbeat chunks at 30s intervals
// during long tool execution, resetting the frontend streaming timer.
let fut = executor.execute_batch(vec![call.clone()], &mut ctx);
tokio::pin!(fut);
let results = loop {
tokio::select! {
result = fut.as_mut() => break result,
_ = tokio::time::sleep(std::time::Duration::from_secs(30)) => {
on_chunk(AiStreamChunk {
content: String::new(),
done: false,
chunk_type: AiChunkType::ToolCall,
}).await;
}
}
};
let results = match results {
Ok(r) => r,
Err(e) => {
let elapsed = start.elapsed().as_millis() as i64;

View File

@ -760,9 +760,9 @@ async fn call_stream_once(
})
};
// 60s timeout for the entire stream
match tokio::time::timeout(std::time::Duration::from_secs(60), stream_fut).await {
// 120s timeout for the entire stream
match tokio::time::timeout(std::time::Duration::from_secs(120), stream_fut).await {
Ok(result) => result,
Err(_) => Err(AgentError::Timeout { task_id: 0, seconds: 60 }),
Err(_) => Err(AgentError::Timeout { task_id: 0, seconds: 120 }),
}
}

View File

@ -13,7 +13,7 @@ pub mod sync;
pub mod task;
pub mod tokent;
pub mod tool;
pub use billing::{BillingRecord, record_ai_usage};
pub use billing::{BillingRecord, BillingResult, record_ai_usage};
pub use sync::list_accessible_models;
pub use task::TaskService;
pub use tokent::{TokenUsage, resolve_usage};