206 lines
5.9 KiB
Rust
206 lines
5.9 KiB
Rust
use chrono::Utc;
|
|
use db::sqlx;
|
|
use uuid::Uuid;
|
|
|
|
use super::types::{
|
|
AgentCostInfo, AgentStepInfo, AgentToolCallInfo, BillingRecord,
|
|
SessionContext,
|
|
};
|
|
use crate::AppService;
|
|
use crate::error::AppError;
|
|
|
|
impl AppService {
|
|
pub(super) async fn persist_user_message(
|
|
&self,
|
|
conversation_id: Uuid,
|
|
user_id: Uuid,
|
|
content: &str,
|
|
) -> Result<Uuid, AppError> {
|
|
let message_id = Uuid::now_v7();
|
|
let now = Utc::now();
|
|
sqlx::query(
|
|
"INSERT INTO agent_message \
|
|
(id, conversation, role, author, content, content_type, status, created_at, updated_at) \
|
|
VALUES ($1, $2, 'user', $3, $4, 'text', 'completed', $5, $5)",
|
|
)
|
|
.bind(message_id)
|
|
.bind(conversation_id)
|
|
.bind(user_id)
|
|
.bind(content)
|
|
.bind(now)
|
|
.execute(self.db.writer())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
Ok(message_id)
|
|
}
|
|
|
|
pub(super) async fn persist_assistant_message(
|
|
&self,
|
|
conversation_id: Uuid,
|
|
_session_id: Uuid,
|
|
content: &str,
|
|
reasoning_content: Option<&str>,
|
|
invocation_id: Uuid,
|
|
) -> Result<Uuid, AppError> {
|
|
let message_id = Uuid::now_v7();
|
|
let now = Utc::now();
|
|
sqlx::query(
|
|
"INSERT INTO agent_message \
|
|
(id, conversation, role, content, content_type, status, \
|
|
model_invocation, reasoning_content, created_at, updated_at) \
|
|
VALUES ($1, $2, 'assistant', $3, 'text', 'completed', $4, $5, $6, $6)",
|
|
)
|
|
.bind(message_id)
|
|
.bind(conversation_id)
|
|
.bind(content)
|
|
.bind(invocation_id)
|
|
.bind(reasoning_content)
|
|
.bind(now)
|
|
.execute(self.db.writer())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
Ok(message_id)
|
|
}
|
|
|
|
pub(super) async fn persist_billing_and_deduct(
|
|
&self,
|
|
ctx: &SessionContext,
|
|
invocation_id: Uuid,
|
|
input_tokens: i64,
|
|
output_tokens: i64,
|
|
) -> Result<Option<AgentCostInfo>, AppError> {
|
|
let cost_result = self
|
|
.agent_calculate_cost(
|
|
ctx.model_version_id,
|
|
input_tokens,
|
|
output_tokens,
|
|
)
|
|
.await?;
|
|
|
|
let (cost, currency) = match cost_result {
|
|
Some((c, cur)) => (c, cur),
|
|
None => {
|
|
let record = BillingRecord {
|
|
invocation_id,
|
|
session_id: ctx.session_id,
|
|
model_version_id: ctx.model_version_id,
|
|
input_tokens,
|
|
output_tokens,
|
|
cached_input_tokens: 0,
|
|
cache_read_tokens: 0,
|
|
cache_write_tokens: 0,
|
|
reasoning_tokens: 0,
|
|
total_tokens: input_tokens.saturating_add(output_tokens),
|
|
cost: None,
|
|
currency: None,
|
|
created_at: Utc::now(),
|
|
};
|
|
self.agent_record_usage(&record).await?;
|
|
return Ok(None);
|
|
}
|
|
};
|
|
|
|
let record = BillingRecord {
|
|
invocation_id,
|
|
session_id: ctx.session_id,
|
|
model_version_id: ctx.model_version_id,
|
|
input_tokens,
|
|
output_tokens,
|
|
cached_input_tokens: 0,
|
|
cache_read_tokens: 0,
|
|
cache_write_tokens: 0,
|
|
reasoning_tokens: 0,
|
|
total_tokens: input_tokens.saturating_add(output_tokens),
|
|
cost: Some(cost),
|
|
currency: Some(currency.clone()),
|
|
created_at: Utc::now(),
|
|
};
|
|
self.agent_record_usage(&record).await?;
|
|
|
|
if let Err(e) = self.agent_deduct_billing(ctx, cost).await {
|
|
tracing::warn!(
|
|
invocation_id = %invocation_id,
|
|
error = %e,
|
|
"agent billing deduction failed"
|
|
);
|
|
}
|
|
|
|
Ok(Some(AgentCostInfo {
|
|
amount: cost.to_string(),
|
|
currency,
|
|
}))
|
|
}
|
|
|
|
pub(super) async fn update_conversation_timestamp(
|
|
&self,
|
|
conversation_id: Uuid,
|
|
) -> Result<(), AppError> {
|
|
let now = Utc::now();
|
|
sqlx::query(
|
|
"UPDATE agent_conversation SET last_message_at = $1, updated_at = $1 WHERE id = $2",
|
|
)
|
|
.bind(now)
|
|
.bind(conversation_id)
|
|
.execute(self.db.writer())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
Ok(())
|
|
}
|
|
|
|
pub(super) async fn update_conversation_title(
|
|
&self,
|
|
conversation_id: Uuid,
|
|
title: &str,
|
|
) -> Result<(), AppError> {
|
|
let now = Utc::now();
|
|
sqlx::query(
|
|
"UPDATE agent_conversation SET title = $1, updated_at = $2 WHERE id = $3",
|
|
)
|
|
.bind(title)
|
|
.bind(now)
|
|
.bind(conversation_id)
|
|
.execute(self.db.writer())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
#[allow(dead_code)]
|
|
pub(super) fn step_info_from_agent(
|
|
step: ai::agent::AgentStep,
|
|
) -> AgentStepInfo {
|
|
AgentStepInfo {
|
|
index: step.index,
|
|
assistant: step.assistant,
|
|
tool_calls: step
|
|
.tool_calls
|
|
.into_iter()
|
|
.map(tool_call_info_from_record)
|
|
.collect(),
|
|
reflection: step.reflection,
|
|
}
|
|
}
|
|
|
|
#[allow(dead_code)]
|
|
pub(super) fn tool_call_info_from_record(
|
|
record: ai::agent::ToolCallRecord,
|
|
) -> AgentToolCallInfo {
|
|
AgentToolCallInfo {
|
|
id: record.id,
|
|
name: record.name,
|
|
arguments: record.arguments,
|
|
output: record.output,
|
|
error: record.error,
|
|
elapsed_ms: record.elapsed_ms,
|
|
}
|
|
}
|
|
|
|
pub(super) fn stream_error(error: &str) -> String {
|
|
let payload = serde_json::json!({
|
|
"type": "error",
|
|
"error": error,
|
|
});
|
|
format!("data: {}\n\n", payload)
|
|
}
|