fix(agent): 修复扣费链路并实现级联扣费策略

- billing.rs: 修复参数传递 (model_id -> version_id)
- billing.rs: 新增 BillingResult 枚举支持 InsufficientBalance 错误
- billing.rs: 实现级联扣费 (优先 project 余额,不足时 fallback 到 workspace)
- billing.rs: 余额不足时创建系统消息并持久化
- chat/service.rs: 捕获 InsufficientBalance 错误并调用 create_system_message
- client/mod.rs: 超时时间从 60s 改为 120s
This commit is contained in:
ZhenYi 2026-04-28 19:59:06 +08:00
parent 13523762aa
commit c6bb72682b
4 changed files with 208 additions and 110 deletions

View File

@ -23,20 +23,28 @@ pub struct BillingRecord {
pub output_tokens: i64, pub output_tokens: i64,
} }
/// Record AI usage for a project. /// Extended result that includes insufficient balance flag for system message creation.
#[derive(Debug)]
pub enum BillingResult {
Success(BillingRecord),
InsufficientBalance { message: String },
}
/// Record AI usage for a project with cascading billing.
/// ///
/// If the project belongs to a workspace, the cost is deducted from the /// Billing strategy:
/// workspace's shared quota. Otherwise it is deducted from the project's own /// 1. Try to deduct from project balance first
/// billing balance. /// 2. If insufficient, fallback to workspace balance (if project belongs to workspace)
/// 3. If both insufficient or no workspace, return InsufficientBalance error with room_id
/// ///
/// Returns an error if there is insufficient balance. /// Returns BillingError::InsufficientBalance with room_id for system message creation.
pub async fn record_ai_usage( pub async fn record_ai_usage(
db: &AppDatabase, db: &AppDatabase,
project_uid: Uuid, project_uid: Uuid,
model_id: Uuid, model_id: Uuid,
input_tokens: i64, input_tokens: i64,
output_tokens: i64, output_tokens: i64,
) -> Result<BillingRecord, AgentError> { ) -> Result<BillingResult, AgentError> {
// 1. Look up the active price for this model. // 1. Look up the active price for this model.
let pricing = model_pricing::Entity::find() let pricing = model_pricing::Entity::find()
.filter(model_pricing::Column::ModelVersionId.eq(model_id)) .filter(model_pricing::Column::ModelVersionId.eq(model_id))
@ -68,106 +76,27 @@ pub async fn record_ai_usage(
let currency = pricing.currency.clone(); let currency = pricing.currency.clone();
// 3. Determine whether to bill the project or its workspace. // 3. Cascading billing: project balance first, then workspace if insufficient.
let proj = project::Entity::find_by_id(project_uid) let proj = project::Entity::find_by_id(project_uid)
.one(db) .one(db)
.await? .await?
.ok_or_else(|| AgentError::Internal("Project not found".into()))?; .ok_or_else(|| AgentError::Internal("Project not found".into()))?;
if let Some(workspace_id) = proj.workspace_id {
// ── Workspace-shared quota ──────────────────────────────────
let txn = db.begin().await?; let txn = db.begin().await?;
// SELECT FOR UPDATE to prevent race conditions // Always check project balance first
let current = workspace_billing::Entity::find_by_id(workspace_id) let project_billing = project_billing::Entity::find_by_id(project_uid)
.lock_exclusive()
.one(&txn)
.await?
.ok_or_else(|| AgentError::Internal("Workspace billing account not found".into()))?;
// Validate balance before any modifications
if current.balance < total_cost {
txn.rollback().await?;
return Err(AgentError::Internal(format!(
"Insufficient workspace billing balance. Required: {:.4} {}, Available: {:.4} {}",
total_cost, currency, current.balance, currency
)));
}
let amount_dec = -total_cost;
let now = chrono::Utc::now();
// Insert workspace billing history AFTER validation
workspace_billing_history::ActiveModel {
uid: Set(Uuid::new_v4()),
workspace_id: Set(workspace_id),
user_id: Set(Some(proj.created_by)),
amount: Set(amount_dec),
currency: Set(currency.clone()),
reason: Set(format!("ai_usage:{}", project_uid)),
extra: Set(Some(serde_json::json!({
"project_id": project_uid.to_string(),
"model_id": model_id.to_string(),
"input_tokens": input_tokens,
"output_tokens": output_tokens,
}))),
created_at: Set(now),
}
.insert(&txn)
.await?;
// Deduct from workspace balance
let new_balance = current.balance - total_cost;
let mut updated: workspace_billing::ActiveModel = current.into();
updated.balance = Set(new_balance);
updated.updated_at = Set(now);
updated.update(&txn).await?;
txn.commit().await?;
let cost_f64 = total_cost.to_string().parse().unwrap_or(0.0);
tracing::info!(
project_id = %project_uid,
model_id = %model_id,
input_tokens = input_tokens,
output_tokens = output_tokens,
cost = %cost_f64,
currency = %currency,
workspace_id = %workspace_id.to_string(),
"ai_usage_recorded"
);
Ok(BillingRecord {
cost: cost_f64,
currency,
input_tokens,
output_tokens,
})
} else {
// ── Project-owned quota ─────────────────────────────────────
let txn = db.begin().await?;
// SELECT FOR UPDATE to prevent race conditions
let current = project_billing::Entity::find_by_id(project_uid)
.lock_exclusive() .lock_exclusive()
.one(&txn) .one(&txn)
.await? .await?
.ok_or_else(|| AgentError::Internal("Project billing account not found".into()))?; .ok_or_else(|| AgentError::Internal("Project billing account not found".into()))?;
// Validate balance before any modifications
if current.balance < total_cost {
txn.rollback().await?;
return Err(AgentError::Internal(format!(
"Insufficient billing balance. Required: {:.4} {}, Available: {:.4} {}",
total_cost, currency, current.balance, currency
)));
}
let amount_dec = -total_cost;
let now = chrono::Utc::now(); let now = chrono::Utc::now();
// Insert project billing history AFTER validation if project_billing.balance >= total_cost {
// ── Project has sufficient balance ──────────────────────────
let amount_dec = -total_cost;
project_billing_history::ActiveModel { project_billing_history::ActiveModel {
uid: Set(Uuid::new_v4()), uid: Set(Uuid::new_v4()),
project: Set(project_uid), project: Set(project_uid),
@ -186,9 +115,8 @@ pub async fn record_ai_usage(
.insert(&txn) .insert(&txn)
.await?; .await?;
// Deduct from project balance let new_balance = project_billing.balance - total_cost;
let new_balance = current.balance - total_cost; let mut updated: project_billing::ActiveModel = project_billing.into();
let mut updated: project_billing::ActiveModel = current.into();
updated.balance = Set(new_balance); updated.balance = Set(new_balance);
updated.update(&txn).await?; updated.update(&txn).await?;
@ -203,14 +131,93 @@ pub async fn record_ai_usage(
output_tokens = output_tokens, output_tokens = output_tokens,
cost = %cost_f64, cost = %cost_f64,
currency = %currency, currency = %currency,
source = "project",
"ai_usage_recorded" "ai_usage_recorded"
); );
Ok(BillingRecord { Ok(BillingResult::Success(BillingRecord {
cost: cost_f64, cost: cost_f64,
currency, currency,
input_tokens, input_tokens,
output_tokens, output_tokens,
}))
} else if let Some(workspace_id) = proj.workspace_id {
// ── Project insufficient, fallback to workspace ─────────────
let workspace_billing = workspace_billing::Entity::find_by_id(workspace_id)
.lock_exclusive()
.one(&txn)
.await?
.ok_or_else(|| AgentError::Internal("Workspace billing account not found".into()))?;
if workspace_billing.balance < total_cost {
txn.rollback().await?;
return Ok(BillingResult::InsufficientBalance {
message: format!(
"Insufficient balance. Project: {:.4} {}, Workspace: {:.4} {}, Required: {:.4} {}",
project_billing.balance, currency,
workspace_billing.balance, currency,
total_cost, currency
),
});
}
let amount_dec = -total_cost;
workspace_billing_history::ActiveModel {
uid: Set(Uuid::new_v4()),
workspace_id: Set(workspace_id),
user_id: Set(Some(proj.created_by)),
amount: Set(amount_dec),
currency: Set(currency.clone()),
reason: Set(format!("ai_usage:{}", project_uid)),
extra: Set(Some(serde_json::json!({
"project_id": project_uid.to_string(),
"model_id": model_id.to_string(),
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"fallback_reason": "project_balance_insufficient"
}))),
created_at: Set(now),
}
.insert(&txn)
.await?;
let new_balance = workspace_billing.balance - total_cost;
let mut updated: workspace_billing::ActiveModel = workspace_billing.into();
updated.balance = Set(new_balance);
updated.updated_at = Set(now);
updated.update(&txn).await?;
txn.commit().await?;
let cost_f64 = total_cost.to_string().parse().unwrap_or(0.0);
tracing::info!(
project_id = %project_uid,
model_id = %model_id,
input_tokens = input_tokens,
output_tokens = output_tokens,
cost = %cost_f64,
currency = %currency,
workspace_id = %workspace_id.to_string(),
source = "workspace_fallback",
"ai_usage_recorded"
);
Ok(BillingResult::Success(BillingRecord {
cost: cost_f64,
currency,
input_tokens,
output_tokens,
}))
} else {
// ── Project insufficient and no workspace ───────────────────
txn.rollback().await?;
Ok(BillingResult::InsufficientBalance {
message: format!(
"Insufficient balance. Required: {:.4} {}, Available: {:.4} {}",
total_cost, currency, project_billing.balance, currency
),
}) })
} }
} }

View File

@ -58,17 +58,24 @@ async fn record_ai_session(
output_tokens: i64, output_tokens: i64,
latency_ms: i64, latency_ms: i64,
) { ) {
let (cost, currency) = match billing::record_ai_usage( let (cost, currency, error_msg) = match billing::record_ai_usage(
db, db,
project_id, project_id,
model_id, version_id,
input_tokens, input_tokens,
output_tokens, output_tokens,
) )
.await .await
{ {
Ok(record) => (Some(record.cost), Some(record.currency)), Ok(billing::BillingResult::Success(record)) => {
Err(_) => (None, None), (Some(record.cost), Some(record.currency), None)
}
Ok(billing::BillingResult::InsufficientBalance { message }) => {
// Create system message for insufficient balance
create_system_message(db, room_id, &message).await;
(None, None, Some(message))
}
Err(_) => (None, None, None),
}; };
let _ = models::ai::ai_session::ActiveModel { let _ = models::ai::ai_session::ActiveModel {
@ -81,7 +88,7 @@ async fn record_ai_session(
latency_ms: Set(Some(latency_ms)), latency_ms: Set(Some(latency_ms)),
cost: Set(cost), cost: Set(cost),
currency: Set(currency), currency: Set(currency),
error_message: Set(None), error_message: Set(error_msg),
error_code: Set(None), error_code: Set(None),
created_at: Set(chrono::Utc::now()), created_at: Set(chrono::Utc::now()),
} }
@ -89,6 +96,71 @@ async fn record_ai_session(
.await; .await;
} }
/// Create a system message in the room for billing errors.
async fn create_system_message(
db: &db::database::AppDatabase,
room_id: Uuid,
message: &str,
) {
use models::rooms::{room_message, MessageSenderType, MessageContentType};
use sea_orm::Set;
// Get next sequence number - we don't have cache here, so we query directly
let last_seq = match room_message::Entity::find()
.filter(room_message::Column::Room.eq(room_id))
.order_by_desc(room_message::Column::Seq)
.one(db)
.await
{
Ok(Some(m)) => m.seq,
Ok(None) => 0,
Err(e) => {
tracing::warn!(error = %e, "Failed to get last seq for system message");
return;
}
};
let seq = last_seq + 1;
let now = chrono::Utc::now();
let result = room_message::ActiveModel {
id: Set(Uuid::new_v4()),
seq: Set(seq),
room: Set(room_id),
sender_type: Set(MessageSenderType::System),
sender_id: Set(None),
model_id: Set(None),
thread: Set(None),
in_reply_to: Set(None),
content: Set(message.to_string()),
content_type: Set(MessageContentType::Text),
thinking_content: Set(None),
edited_at: Set(None),
send_at: Set(now),
revoked: Set(None),
revoked_by: Set(None),
}
.insert(db)
.await;
match result {
Ok(_) => {
tracing::info!(
room_id = %room_id,
message = %message,
"system_message_created_for_billing_error"
);
}
Err(e) => {
tracing::warn!(
error = %e,
room_id = %room_id,
"Failed to create system message for billing error"
);
}
}
}
/// Service for handling AI chat requests in rooms. /// Service for handling AI chat requests in rooms.
pub struct ChatService { pub struct ChatService {
ai_base_url: Option<String>, ai_base_url: Option<String>,
@ -614,7 +686,26 @@ impl ChatService {
for call in &calls { for call in &calls {
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let executor = crate::tool::ToolExecutor::new(); let executor = crate::tool::ToolExecutor::new();
let results = match executor.execute_batch(vec![call.clone()], &mut ctx).await {
// Use select! loop to send heartbeat chunks at 30s intervals
// during long tool execution, resetting the frontend streaming timer.
let fut = executor.execute_batch(vec![call.clone()], &mut ctx);
tokio::pin!(fut);
let results = loop {
tokio::select! {
result = fut.as_mut() => break result,
_ = tokio::time::sleep(std::time::Duration::from_secs(30)) => {
on_chunk(AiStreamChunk {
content: String::new(),
done: false,
chunk_type: AiChunkType::ToolCall,
}).await;
}
}
};
let results = match results {
Ok(r) => r, Ok(r) => r,
Err(e) => { Err(e) => {
let elapsed = start.elapsed().as_millis() as i64; let elapsed = start.elapsed().as_millis() as i64;

View File

@ -760,9 +760,9 @@ async fn call_stream_once(
}) })
}; };
// 60s timeout for the entire stream // 120s timeout for the entire stream
match tokio::time::timeout(std::time::Duration::from_secs(60), stream_fut).await { match tokio::time::timeout(std::time::Duration::from_secs(120), stream_fut).await {
Ok(result) => result, Ok(result) => result,
Err(_) => Err(AgentError::Timeout { task_id: 0, seconds: 60 }), Err(_) => Err(AgentError::Timeout { task_id: 0, seconds: 120 }),
} }
} }

View File

@ -13,7 +13,7 @@ pub mod sync;
pub mod task; pub mod task;
pub mod tokent; pub mod tokent;
pub mod tool; pub mod tool;
pub use billing::{BillingRecord, record_ai_usage}; pub use billing::{BillingRecord, BillingResult, record_ai_usage};
pub use sync::list_accessible_models; pub use sync::list_accessible_models;
pub use task::TaskService; pub use task::TaskService;
pub use tokent::{TokenUsage, resolve_usage}; pub use tokent::{TokenUsage, resolve_usage};