328 lines
10 KiB
Rust
328 lines
10 KiB
Rust
use chrono::Utc;
|
|
use db::sqlx::{self, types::Decimal};
|
|
use rust_decimal::Decimal as RustDecimal;
|
|
use uuid::Uuid;
|
|
|
|
use super::types::{BillingRecord, BillingTarget, SessionContext};
|
|
use crate::AppService;
|
|
use crate::error::AppError;
|
|
|
|
impl AppService {
|
|
pub async fn agent_calculate_cost(
|
|
&self,
|
|
model_version_id: Uuid,
|
|
input_tokens: i64,
|
|
output_tokens: i64,
|
|
) -> Result<Option<(RustDecimal, String)>, AppError> {
|
|
let (input_price_per_million, output_price_per_million) =
|
|
self.agent_resolve_pricing(model_version_id).await?;
|
|
|
|
let input_price = match input_price_per_million {
|
|
Some(p) => p,
|
|
None => return Ok(None),
|
|
};
|
|
|
|
let output_price = match output_price_per_million {
|
|
Some(p) => p,
|
|
None => return Ok(None),
|
|
};
|
|
|
|
let million = RustDecimal::from(1_000_000u64);
|
|
let input_decimal = RustDecimal::from(input_tokens);
|
|
let output_decimal = RustDecimal::from(output_tokens);
|
|
|
|
let cost = (input_decimal * input_price / million)
|
|
+ (output_decimal * output_price / million);
|
|
Ok(Some((cost, "USD".to_string())))
|
|
}
|
|
pub async fn agent_deduct_billing(
|
|
&self,
|
|
ctx: &SessionContext,
|
|
cost: RustDecimal,
|
|
) -> Result<(), AppError> {
|
|
match ctx.billing_target {
|
|
BillingTarget::User => {
|
|
let user_id = ctx.user_id.ok_or_else(|| {
|
|
AppError::BadRequest(
|
|
"user billing target requires user_id".to_string(),
|
|
)
|
|
})?;
|
|
self.deduct_user_balance(user_id, cost).await
|
|
}
|
|
BillingTarget::Workspace => {
|
|
let wk_id = ctx.workspace_id.ok_or_else(|| {
|
|
AppError::BadRequest(
|
|
"workspace billing target requires workspace_id"
|
|
.to_string(),
|
|
)
|
|
})?;
|
|
self.deduct_workspace_balance(wk_id, cost).await
|
|
}
|
|
}
|
|
}
|
|
pub async fn agent_record_usage(
|
|
&self,
|
|
record: &BillingRecord,
|
|
) -> Result<(), AppError> {
|
|
let cost_decimal: Option<Decimal> = record.cost.map(|c| c.into());
|
|
|
|
sqlx::query(
|
|
"INSERT INTO agent_token_usage \
|
|
(invocation, session, model_version, \
|
|
input_tokens, output_tokens, cached_input_tokens, \
|
|
cache_read_tokens, cache_write_tokens, reasoning_tokens, \
|
|
total_tokens, cost, currency, created_at) \
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)",
|
|
)
|
|
.bind(record.invocation_id)
|
|
.bind(record.session_id)
|
|
.bind(record.model_version_id)
|
|
.bind(record.input_tokens)
|
|
.bind(record.output_tokens)
|
|
.bind(record.cached_input_tokens)
|
|
.bind(record.cache_read_tokens)
|
|
.bind(record.cache_write_tokens)
|
|
.bind(record.reasoning_tokens)
|
|
.bind(record.total_tokens)
|
|
.bind(&cost_decimal)
|
|
.bind(&record.currency)
|
|
.bind(record.created_at)
|
|
.execute(self.db.writer())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
Ok(())
|
|
}
|
|
pub async fn agent_record_invocation(
|
|
&self,
|
|
invocation_id: Uuid,
|
|
session_id: Uuid,
|
|
conversation_id: Option<Uuid>,
|
|
message_id: Option<Uuid>,
|
|
model_version_id: Uuid,
|
|
status: &str,
|
|
error: Option<&str>,
|
|
) -> Result<(), AppError> {
|
|
let now = Utc::now();
|
|
sqlx::query(
|
|
"INSERT INTO agent_model_invocation \
|
|
(id, session, conversation, message, model_version, status, error, \
|
|
started_at, finished_at) \
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)",
|
|
)
|
|
.bind(invocation_id)
|
|
.bind(session_id)
|
|
.bind(conversation_id)
|
|
.bind(message_id)
|
|
.bind(model_version_id)
|
|
.bind(status)
|
|
.bind(error)
|
|
.bind(now)
|
|
.bind(now)
|
|
.execute(self.db.writer())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
Ok(())
|
|
}
|
|
pub async fn agent_record_tool_call(
|
|
&self,
|
|
invocation_id: Uuid,
|
|
session_id: Uuid,
|
|
conversation_id: Option<Uuid>,
|
|
message_id: Option<Uuid>,
|
|
tool_call_id: &str,
|
|
tool_name: &str,
|
|
arguments: Option<&str>,
|
|
result: Option<&str>,
|
|
error: Option<&str>,
|
|
status: &str,
|
|
latency_ms: Option<i64>,
|
|
) -> Result<(), AppError> {
|
|
let now = Utc::now();
|
|
sqlx::query(
|
|
"INSERT INTO agent_tool_call_log \
|
|
(invocation, session, conversation, message, tool_call_id, \
|
|
tool_name, arguments, result, error, status, \
|
|
started_at, finished_at, latency_ms) \
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)",
|
|
)
|
|
.bind(invocation_id)
|
|
.bind(session_id)
|
|
.bind(conversation_id)
|
|
.bind(message_id)
|
|
.bind(tool_call_id)
|
|
.bind(tool_name)
|
|
.bind(arguments)
|
|
.bind(result)
|
|
.bind(error)
|
|
.bind(status)
|
|
.bind(now)
|
|
.bind(now)
|
|
.bind(latency_ms)
|
|
.execute(self.db.writer())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
impl AppService {
|
|
async fn deduct_user_balance(
|
|
&self,
|
|
user_id: Uuid,
|
|
cost: RustDecimal,
|
|
) -> Result<(), AppError> {
|
|
const MAX_RETRIES: u32 = 3;
|
|
for attempt in 0..MAX_RETRIES {
|
|
match self.try_deduct_user_balance(user_id, cost).await {
|
|
Ok(()) => return Ok(()),
|
|
Err(AppError::TxnError) if attempt < MAX_RETRIES - 1 => {
|
|
let backoff_ms = 10 * (1 << attempt);
|
|
tokio::time::sleep(tokio::time::Duration::from_millis(
|
|
backoff_ms,
|
|
))
|
|
.await;
|
|
continue;
|
|
}
|
|
Err(e) => return Err(e),
|
|
}
|
|
}
|
|
Err(AppError::TxnError)
|
|
}
|
|
|
|
async fn try_deduct_user_balance(
|
|
&self,
|
|
user_id: Uuid,
|
|
cost: RustDecimal,
|
|
) -> Result<(), AppError> {
|
|
let mut txn = self.db.begin().await.map_err(|_| AppError::TxnError)?;
|
|
|
|
let current: Option<RustDecimal> = sqlx::query_scalar(
|
|
"SELECT balance FROM user_billing WHERE \"user\" = $1 FOR UPDATE",
|
|
)
|
|
.bind(user_id)
|
|
.fetch_optional(&mut **txn.inner_mut())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
let current = match current {
|
|
Some(balance) => balance,
|
|
None => {
|
|
let default_balance = RustDecimal::from(20);
|
|
let now = Utc::now();
|
|
sqlx::query(
|
|
"INSERT INTO user_billing (\"user\", balance, created_at, updated_at) \
|
|
VALUES ($1, $2, $3, $3)",
|
|
)
|
|
.bind(user_id)
|
|
.bind(&default_balance)
|
|
.bind(now)
|
|
.execute(&mut **txn.inner_mut())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
default_balance
|
|
}
|
|
};
|
|
|
|
if current < cost {
|
|
return Err(AppError::BadRequest(
|
|
"insufficient balance".to_string(),
|
|
));
|
|
}
|
|
|
|
let new_balance = current - cost;
|
|
|
|
sqlx::query(
|
|
"UPDATE user_billing SET balance = $1, updated_at = $2 WHERE \"user\" = $3",
|
|
)
|
|
.bind(&new_balance)
|
|
.bind(Utc::now())
|
|
.bind(user_id)
|
|
.execute(&mut **txn.inner_mut())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
txn.commit().await.map_err(|_| AppError::TxnError)?;
|
|
Ok(())
|
|
}
|
|
|
|
async fn deduct_workspace_balance(
|
|
&self,
|
|
wk_id: Uuid,
|
|
cost: RustDecimal,
|
|
) -> Result<(), AppError> {
|
|
const MAX_RETRIES: u32 = 3;
|
|
for attempt in 0..MAX_RETRIES {
|
|
match self.try_deduct_workspace_balance(wk_id, cost).await {
|
|
Ok(()) => return Ok(()),
|
|
Err(AppError::TxnError) if attempt < MAX_RETRIES - 1 => {
|
|
let backoff_ms = 10 * (1 << attempt);
|
|
tokio::time::sleep(tokio::time::Duration::from_millis(
|
|
backoff_ms,
|
|
))
|
|
.await;
|
|
continue;
|
|
}
|
|
Err(e) => return Err(e),
|
|
}
|
|
}
|
|
Err(AppError::TxnError)
|
|
}
|
|
|
|
async fn try_deduct_workspace_balance(
|
|
&self,
|
|
wk_id: Uuid,
|
|
cost: RustDecimal,
|
|
) -> Result<(), AppError> {
|
|
let mut txn = self.db.begin().await.map_err(|_| AppError::TxnError)?;
|
|
|
|
let current: Option<RustDecimal> = sqlx::query_scalar(
|
|
"SELECT balance FROM wk_billing WHERE wk = $1 FOR UPDATE",
|
|
)
|
|
.bind(wk_id)
|
|
.fetch_optional(&mut **txn.inner_mut())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
let current = match current {
|
|
Some(balance) => balance,
|
|
None => {
|
|
let default_balance = RustDecimal::from(20);
|
|
let now = Utc::now();
|
|
sqlx::query(
|
|
"INSERT INTO wk_billing (wk, balance, updated_at) \
|
|
VALUES ($1, $2, $3)",
|
|
)
|
|
.bind(wk_id)
|
|
.bind(&default_balance)
|
|
.bind(now)
|
|
.execute(&mut **txn.inner_mut())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
default_balance
|
|
}
|
|
};
|
|
|
|
if current < cost {
|
|
return Err(AppError::BadRequest(
|
|
"insufficient workspace balance".to_string(),
|
|
));
|
|
}
|
|
|
|
let new_balance = current - cost;
|
|
|
|
sqlx::query(
|
|
"UPDATE wk_billing SET balance = $1, updated_at = $2 WHERE wk = $3",
|
|
)
|
|
.bind(&new_balance)
|
|
.bind(Utc::now())
|
|
.bind(wk_id)
|
|
.execute(&mut **txn.inner_mut())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
txn.commit().await.map_err(|_| AppError::TxnError)?;
|
|
Ok(())
|
|
}
|
|
}
|