// gitdataai/lib/service/agent/billing.rs
//
// 328 lines
// 10 KiB
// Rust

use chrono::Utc;
use db::sqlx::{self, types::Decimal};
use rust_decimal::Decimal as RustDecimal;
use uuid::Uuid;
use super::types::{BillingRecord, BillingTarget, SessionContext};
use crate::AppService;
use crate::error::AppError;
impl AppService {
/// Computes the cost of one model call from its token counts.
///
/// The per-million-token input and output prices for `model_version_id` are
/// obtained from `agent_resolve_pricing`. Each token count is multiplied by
/// its price and divided by one million; the two parts are then summed.
///
/// # Returns
/// * `Ok(Some((cost, "USD")))` when both prices are available.
/// * `Ok(None)` when either the input price or the output price is missing.
///
/// # Errors
/// Propagates any error returned by `agent_resolve_pricing`.
pub async fn agent_calculate_cost(
    &self,
    model_version_id: Uuid,
    input_tokens: i64,
    output_tokens: i64,
) -> Result<Option<(RustDecimal, String)>, AppError> {
    let prices = self.agent_resolve_pricing(model_version_id).await?;

    // Pricing is only usable when both directions have a configured price.
    let (price_in, price_out) = match prices {
        (Some(input), Some(output)) => (input, output),
        _ => return Ok(None),
    };

    let tokens_per_unit = RustDecimal::from(1_000_000u64);
    let input_cost = RustDecimal::from(input_tokens) * price_in / tokens_per_unit;
    let output_cost = RustDecimal::from(output_tokens) * price_out / tokens_per_unit;

    Ok(Some((input_cost + output_cost, "USD".to_string())))
}
/// Charges `cost` to the account the session bills against.
///
/// `ctx.billing_target` selects the account: `User` deducts from the balance
/// of `ctx.user_id`, `Workspace` from the balance of `ctx.workspace_id`.
///
/// # Errors
/// * `AppError::BadRequest` when the id required by the billing target is `None`.
/// * Any error returned by the underlying balance deduction.
pub async fn agent_deduct_billing(
    &self,
    ctx: &SessionContext,
    cost: RustDecimal,
) -> Result<(), AppError> {
    match ctx.billing_target {
        BillingTarget::User => match ctx.user_id {
            Some(user_id) => self.deduct_user_balance(user_id, cost).await,
            None => Err(AppError::BadRequest(
                "user billing target requires user_id".to_string(),
            )),
        },
        BillingTarget::Workspace => match ctx.workspace_id {
            Some(wk_id) => self.deduct_workspace_balance(wk_id, cost).await,
            None => Err(AppError::BadRequest(
                "workspace billing target requires workspace_id".to_string(),
            )),
        },
    }
}
/// Persists one token-usage row for a model invocation.
///
/// Inserts the token counters, cost, currency and timestamp carried by
/// `record` into `agent_token_usage` through the writer connection.
///
/// # Errors
/// Returns `AppError::DatabaseError` (carrying the driver's message) when the
/// insert fails.
pub async fn agent_record_usage(
    &self,
    record: &BillingRecord,
) -> Result<(), AppError> {
    const INSERT_USAGE: &str = "INSERT INTO agent_token_usage \
        (invocation, session, model_version, \
        input_tokens, output_tokens, cached_input_tokens, \
        cache_read_tokens, cache_write_tokens, reasoning_tokens, \
        total_tokens, cost, currency, created_at) \
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)";

    // Convert the optional cost into the decimal type bound to the query.
    let cost_decimal: Option<Decimal> = record.cost.map(Into::into);

    // Bind order must follow the column list above ($1..$13).
    let outcome = sqlx::query(INSERT_USAGE)
        .bind(record.invocation_id)
        .bind(record.session_id)
        .bind(record.model_version_id)
        .bind(record.input_tokens)
        .bind(record.output_tokens)
        .bind(record.cached_input_tokens)
        .bind(record.cache_read_tokens)
        .bind(record.cache_write_tokens)
        .bind(record.reasoning_tokens)
        .bind(record.total_tokens)
        .bind(&cost_decimal)
        .bind(&record.currency)
        .bind(record.created_at)
        .execute(self.db.writer())
        .await;

    match outcome {
        Ok(_) => Ok(()),
        Err(e) => Err(AppError::DatabaseError(e.to_string())),
    }
}
/// Persists one row in `agent_model_invocation` describing a model call.
///
/// `status` and the optional `error` text are stored as given. Both
/// `started_at` and `finished_at` are set to the same instant: the current
/// UTC time when this method runs.
/// NOTE(review): the row therefore carries no duration — confirm this is
/// intended.
///
/// # Errors
/// Returns `AppError::DatabaseError` (carrying the driver's message) when the
/// insert fails.
pub async fn agent_record_invocation(
    &self,
    invocation_id: Uuid,
    session_id: Uuid,
    conversation_id: Option<Uuid>,
    message_id: Option<Uuid>,
    model_version_id: Uuid,
    status: &str,
    error: Option<&str>,
) -> Result<(), AppError> {
    const INSERT_INVOCATION: &str = "INSERT INTO agent_model_invocation \
        (id, session, conversation, message, model_version, status, error, \
        started_at, finished_at) \
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)";

    let started_at = Utc::now();
    let finished_at = started_at;

    // Bind order must follow the column list above ($1..$9).
    let outcome = sqlx::query(INSERT_INVOCATION)
        .bind(invocation_id)
        .bind(session_id)
        .bind(conversation_id)
        .bind(message_id)
        .bind(model_version_id)
        .bind(status)
        .bind(error)
        .bind(started_at)
        .bind(finished_at)
        .execute(self.db.writer())
        .await;

    match outcome {
        Ok(_) => Ok(()),
        Err(e) => Err(AppError::DatabaseError(e.to_string())),
    }
}
/// Persists one row in `agent_tool_call_log` describing a tool call.
///
/// All textual fields (`tool_call_id`, `tool_name`, `arguments`, `result`,
/// `error`, `status`) and `latency_ms` are stored as given. Both
/// `started_at` and `finished_at` are set to the same instant: the current
/// UTC time when this method runs.
/// NOTE(review): the timestamps are identical even when `latency_ms` is
/// supplied — confirm whether `started_at` should be derived from it.
///
/// # Errors
/// Returns `AppError::DatabaseError` (carrying the driver's message) when the
/// insert fails.
pub async fn agent_record_tool_call(
    &self,
    invocation_id: Uuid,
    session_id: Uuid,
    conversation_id: Option<Uuid>,
    message_id: Option<Uuid>,
    tool_call_id: &str,
    tool_name: &str,
    arguments: Option<&str>,
    result: Option<&str>,
    error: Option<&str>,
    status: &str,
    latency_ms: Option<i64>,
) -> Result<(), AppError> {
    const INSERT_TOOL_CALL: &str = "INSERT INTO agent_tool_call_log \
        (invocation, session, conversation, message, tool_call_id, \
        tool_name, arguments, result, error, status, \
        started_at, finished_at, latency_ms) \
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)";

    let started_at = Utc::now();
    let finished_at = started_at;

    // Bind order must follow the column list above ($1..$13).
    let outcome = sqlx::query(INSERT_TOOL_CALL)
        .bind(invocation_id)
        .bind(session_id)
        .bind(conversation_id)
        .bind(message_id)
        .bind(tool_call_id)
        .bind(tool_name)
        .bind(arguments)
        .bind(result)
        .bind(error)
        .bind(status)
        .bind(started_at)
        .bind(finished_at)
        .bind(latency_ms)
        .execute(self.db.writer())
        .await;

    match outcome {
        Ok(_) => Ok(()),
        Err(e) => Err(AppError::DatabaseError(e.to_string())),
    }
}
}
impl AppService {
/// Deducts `cost` from a user's balance, retrying on transaction errors.
///
/// Makes up to `MAX_RETRIES` attempts via `try_deduct_user_balance`. Only
/// `AppError::TxnError` triggers a retry; the wait before the next attempt
/// doubles each time (10 ms, then 20 ms). Any other error, and a `TxnError`
/// on the final attempt, is returned to the caller unchanged.
async fn deduct_user_balance(
    &self,
    user_id: Uuid,
    cost: RustDecimal,
) -> Result<(), AppError> {
    const MAX_RETRIES: u32 = 3;

    let mut attempt: u32 = 0;
    loop {
        let outcome = self.try_deduct_user_balance(user_id, cost).await;
        let retries_left = attempt + 1 < MAX_RETRIES;

        match outcome {
            // Retryable failure with attempts remaining: fall through to the backoff.
            Err(AppError::TxnError) if retries_left => {}
            // Success, a non-retryable error, or the last attempt: hand it back.
            finished => return finished,
        }

        let backoff_ms = 10 * 2u64.pow(attempt);
        tokio::time::sleep(tokio::time::Duration::from_millis(backoff_ms)).await;
        attempt += 1;
    }
}
/// Runs one attempt at deducting `cost` from a user's balance inside a
/// single transaction.
///
/// Steps:
/// 1. Lock the user's `user_billing` row with `SELECT … FOR UPDATE`.
/// 2. If no row exists yet, create it with the default starting balance of 20.
///    `FOR UPDATE` cannot lock a row that does not exist, so two concurrent
///    first-time deductions can both reach the insert. The insert therefore
///    uses `ON CONFLICT DO NOTHING`; when it is skipped because another
///    transaction created the row first, the row is re-read and locked so the
///    deduction applies to the real balance.
/// 3. Fail with `BadRequest` when the balance is smaller than `cost`; the
///    transaction is then dropped without being committed.
/// 4. Write the reduced balance and commit.
///
/// NOTE(review): a negative `cost` would increase the balance — callers are
/// presumably expected to pass a non-negative amount; confirm.
///
/// # Errors
/// * `AppError::TxnError` — beginning or committing the transaction failed, or
///   the row could not be found after a skipped insert. The caller
///   (`deduct_user_balance`) retries on this variant.
/// * `AppError::DatabaseError` — a query failed.
/// * `AppError::BadRequest("insufficient balance")` — balance is below `cost`.
async fn try_deduct_user_balance(
    &self,
    user_id: Uuid,
    cost: RustDecimal,
) -> Result<(), AppError> {
    const SELECT_BALANCE_FOR_UPDATE: &str =
        "SELECT balance FROM user_billing WHERE \"user\" = $1 FOR UPDATE";

    let mut txn = self.db.begin().await.map_err(|_| AppError::TxnError)?;

    let locked: Option<RustDecimal> = sqlx::query_scalar(SELECT_BALANCE_FOR_UPDATE)
        .bind(user_id)
        .fetch_optional(&mut **txn.inner_mut())
        .await
        .map_err(|e| AppError::DatabaseError(e.to_string()))?;

    let current = match locked {
        Some(balance) => balance,
        None => {
            let default_balance = RustDecimal::from(20);
            let now = Utc::now();
            let inserted_rows = sqlx::query(
                "INSERT INTO user_billing (\"user\", balance, created_at, updated_at) \
                 VALUES ($1, $2, $3, $3) ON CONFLICT DO NOTHING",
            )
            .bind(user_id)
            .bind(&default_balance)
            .bind(now)
            .execute(&mut **txn.inner_mut())
            .await
            .map_err(|e| AppError::DatabaseError(e.to_string()))?
            .rows_affected();

            if inserted_rows > 0 {
                default_balance
            } else {
                // A concurrent transaction created the row first: lock and
                // read the row it committed instead of assuming the default.
                let raced: Option<RustDecimal> =
                    sqlx::query_scalar(SELECT_BALANCE_FOR_UPDATE)
                        .bind(user_id)
                        .fetch_optional(&mut **txn.inner_mut())
                        .await
                        .map_err(|e| AppError::DatabaseError(e.to_string()))?;
                // Still missing means the state changed under us; let the
                // caller's retry loop start a fresh transaction.
                raced.ok_or(AppError::TxnError)?
            }
        }
    };

    if current < cost {
        return Err(AppError::BadRequest(
            "insufficient balance".to_string(),
        ));
    }

    let new_balance = current - cost;
    sqlx::query(
        "UPDATE user_billing SET balance = $1, updated_at = $2 WHERE \"user\" = $3",
    )
    .bind(&new_balance)
    .bind(Utc::now())
    .bind(user_id)
    .execute(&mut **txn.inner_mut())
    .await
    .map_err(|e| AppError::DatabaseError(e.to_string()))?;

    txn.commit().await.map_err(|_| AppError::TxnError)?;
    Ok(())
}
/// Deducts `cost` from a workspace's balance, retrying on transaction errors.
///
/// Makes up to `MAX_RETRIES` attempts via `try_deduct_workspace_balance`.
/// Only `AppError::TxnError` triggers a retry; the wait before the next
/// attempt doubles each time (10 ms, then 20 ms). Any other error, and a
/// `TxnError` on the final attempt, is returned to the caller unchanged.
async fn deduct_workspace_balance(
    &self,
    wk_id: Uuid,
    cost: RustDecimal,
) -> Result<(), AppError> {
    const MAX_RETRIES: u32 = 3;

    let mut attempt: u32 = 0;
    loop {
        let outcome = self.try_deduct_workspace_balance(wk_id, cost).await;
        let retries_left = attempt + 1 < MAX_RETRIES;

        match outcome {
            // Retryable failure with attempts remaining: fall through to the backoff.
            Err(AppError::TxnError) if retries_left => {}
            // Success, a non-retryable error, or the last attempt: hand it back.
            finished => return finished,
        }

        let backoff_ms = 10 * 2u64.pow(attempt);
        tokio::time::sleep(tokio::time::Duration::from_millis(backoff_ms)).await;
        attempt += 1;
    }
}
/// Runs one attempt at deducting `cost` from a workspace's balance inside a
/// single transaction.
///
/// Steps:
/// 1. Lock the workspace's `wk_billing` row with `SELECT … FOR UPDATE`.
/// 2. If no row exists yet, create it with the default starting balance of 20.
///    `FOR UPDATE` cannot lock a row that does not exist, so two concurrent
///    first-time deductions can both reach the insert. The insert therefore
///    uses `ON CONFLICT DO NOTHING`; when it is skipped because another
///    transaction created the row first, the row is re-read and locked so the
///    deduction applies to the real balance.
/// 3. Fail with `BadRequest` when the balance is smaller than `cost`; the
///    transaction is then dropped without being committed.
/// 4. Write the reduced balance and commit.
///
/// NOTE(review): a negative `cost` would increase the balance — callers are
/// presumably expected to pass a non-negative amount; confirm.
///
/// # Errors
/// * `AppError::TxnError` — beginning or committing the transaction failed, or
///   the row could not be found after a skipped insert. The caller
///   (`deduct_workspace_balance`) retries on this variant.
/// * `AppError::DatabaseError` — a query failed.
/// * `AppError::BadRequest("insufficient workspace balance")` — balance is
///   below `cost`.
async fn try_deduct_workspace_balance(
    &self,
    wk_id: Uuid,
    cost: RustDecimal,
) -> Result<(), AppError> {
    const SELECT_BALANCE_FOR_UPDATE: &str =
        "SELECT balance FROM wk_billing WHERE wk = $1 FOR UPDATE";

    let mut txn = self.db.begin().await.map_err(|_| AppError::TxnError)?;

    let locked: Option<RustDecimal> = sqlx::query_scalar(SELECT_BALANCE_FOR_UPDATE)
        .bind(wk_id)
        .fetch_optional(&mut **txn.inner_mut())
        .await
        .map_err(|e| AppError::DatabaseError(e.to_string()))?;

    let current = match locked {
        Some(balance) => balance,
        None => {
            let default_balance = RustDecimal::from(20);
            let now = Utc::now();
            let inserted_rows = sqlx::query(
                "INSERT INTO wk_billing (wk, balance, updated_at) \
                 VALUES ($1, $2, $3) ON CONFLICT DO NOTHING",
            )
            .bind(wk_id)
            .bind(&default_balance)
            .bind(now)
            .execute(&mut **txn.inner_mut())
            .await
            .map_err(|e| AppError::DatabaseError(e.to_string()))?
            .rows_affected();

            if inserted_rows > 0 {
                default_balance
            } else {
                // A concurrent transaction created the row first: lock and
                // read the row it committed instead of assuming the default.
                let raced: Option<RustDecimal> =
                    sqlx::query_scalar(SELECT_BALANCE_FOR_UPDATE)
                        .bind(wk_id)
                        .fetch_optional(&mut **txn.inner_mut())
                        .await
                        .map_err(|e| AppError::DatabaseError(e.to_string()))?;
                // Still missing means the state changed under us; let the
                // caller's retry loop start a fresh transaction.
                raced.ok_or(AppError::TxnError)?
            }
        }
    };

    if current < cost {
        return Err(AppError::BadRequest(
            "insufficient workspace balance".to_string(),
        ));
    }

    let new_balance = current - cost;
    sqlx::query(
        "UPDATE wk_billing SET balance = $1, updated_at = $2 WHERE wk = $3",
    )
    .bind(&new_balance)
    .bind(Utc::now())
    .bind(wk_id)
    .execute(&mut **txn.inner_mut())
    .await
    .map_err(|e| AppError::DatabaseError(e.to_string()))?;

    txn.commit().await.map_err(|_| AppError::TxnError)?;
    Ok(())
}
}