gitdataai/lib/service/agent/persistence.rs

206 lines
5.9 KiB
Rust

use chrono::Utc;
use db::sqlx;
use uuid::Uuid;
use super::types::{
AgentCostInfo, AgentStepInfo, AgentToolCallInfo, BillingRecord,
SessionContext,
};
use crate::AppService;
use crate::error::AppError;
impl AppService {
    /// Inserts a user-authored message into `agent_message`.
    ///
    /// The row is written with role `user`, `author` set to `user_id`,
    /// content type `text` and status `completed`; `created_at` and
    /// `updated_at` share a single timestamp.
    ///
    /// Returns the newly generated (UUIDv7) message id.
    ///
    /// # Errors
    ///
    /// Returns `AppError::DatabaseError` if the insert fails.
    pub(super) async fn persist_user_message(
        &self,
        conversation_id: Uuid,
        user_id: Uuid,
        content: &str,
    ) -> Result<Uuid, AppError> {
        let message_id = Uuid::now_v7();
        let now = Utc::now();
        sqlx::query(
            "INSERT INTO agent_message \
             (id, conversation, role, author, content, content_type, status, created_at, updated_at) \
             VALUES ($1, $2, 'user', $3, $4, 'text', 'completed', $5, $5)",
        )
        .bind(message_id)
        .bind(conversation_id)
        .bind(user_id)
        .bind(content)
        .bind(now)
        .execute(self.db.writer())
        .await
        .map_err(|e| AppError::DatabaseError(e.to_string()))?;
        Ok(message_id)
    }

    /// Inserts an assistant reply into `agent_message`.
    ///
    /// The row is written with role `assistant`, content type `text` and
    /// status `completed`, and is linked to the model invocation that
    /// produced it via `model_invocation`. `reasoning_content` is bound as-is,
    /// so `None` is stored as SQL `NULL`.
    ///
    /// `_session_id` is not used by this implementation.
    ///
    /// Returns the newly generated (UUIDv7) message id.
    ///
    /// # Errors
    ///
    /// Returns `AppError::DatabaseError` if the insert fails.
    pub(super) async fn persist_assistant_message(
        &self,
        conversation_id: Uuid,
        _session_id: Uuid,
        content: &str,
        reasoning_content: Option<&str>,
        invocation_id: Uuid,
    ) -> Result<Uuid, AppError> {
        let message_id = Uuid::now_v7();
        let now = Utc::now();
        sqlx::query(
            "INSERT INTO agent_message \
             (id, conversation, role, content, content_type, status, \
             model_invocation, reasoning_content, created_at, updated_at) \
             VALUES ($1, $2, 'assistant', $3, 'text', 'completed', $4, $5, $6, $6)",
        )
        .bind(message_id)
        .bind(conversation_id)
        .bind(content)
        .bind(invocation_id)
        .bind(reasoning_content)
        .bind(now)
        .execute(self.db.writer())
        .await
        .map_err(|e| AppError::DatabaseError(e.to_string()))?;
        Ok(message_id)
    }

    /// Records token usage for one model invocation and, when a price is
    /// available, deducts the cost from the session's billing.
    ///
    /// A usage record is always written. If `agent_calculate_cost` yields no
    /// price, the record is stored without cost/currency and `Ok(None)` is
    /// returned. Otherwise the record carries the cost, a deduction is
    /// attempted, and the cost is returned as an `AgentCostInfo`.
    ///
    /// A failed deduction is only logged at `warn` level; it does not fail
    /// this call, because the usage record has already been persisted.
    ///
    /// # Errors
    ///
    /// Propagates errors from `agent_calculate_cost` and `agent_record_usage`.
    pub(super) async fn persist_billing_and_deduct(
        &self,
        ctx: &SessionContext,
        invocation_id: Uuid,
        input_tokens: i64,
        output_tokens: i64,
    ) -> Result<Option<AgentCostInfo>, AppError> {
        let priced = self
            .agent_calculate_cost(ctx.model_version_id, input_tokens, output_tokens)
            .await?;

        // One record literal serves both the priced and the unpriced path so
        // the two cannot drift apart; cost/currency are `None` when no price
        // was calculated.
        let record = BillingRecord {
            invocation_id,
            session_id: ctx.session_id,
            model_version_id: ctx.model_version_id,
            input_tokens,
            output_tokens,
            // Cache and reasoning token breakdowns are not passed to this
            // method, so they are recorded as zero.
            cached_input_tokens: 0,
            cache_read_tokens: 0,
            cache_write_tokens: 0,
            reasoning_tokens: 0,
            total_tokens: input_tokens.saturating_add(output_tokens),
            cost: priced.as_ref().map(|(cost, _)| *cost),
            currency: priced.as_ref().map(|(_, currency)| currency.clone()),
            created_at: Utc::now(),
        };
        self.agent_record_usage(&record).await?;

        let (cost, currency) = match priced {
            Some(pair) => pair,
            None => return Ok(None),
        };

        // Best-effort: usage is already recorded, so a deduction failure is
        // surfaced in logs rather than failing the whole request.
        if let Err(e) = self.agent_deduct_billing(ctx, cost).await {
            tracing::warn!(
                invocation_id = %invocation_id,
                error = %e,
                "agent billing deduction failed"
            );
        }

        Ok(Some(AgentCostInfo {
            amount: cost.to_string(),
            currency,
        }))
    }

    /// Sets `last_message_at` and `updated_at` of a conversation to the
    /// current time.
    ///
    /// No error is returned if no row matches `conversation_id`; the
    /// number of affected rows is not checked.
    ///
    /// # Errors
    ///
    /// Returns `AppError::DatabaseError` if the update fails.
    pub(super) async fn update_conversation_timestamp(
        &self,
        conversation_id: Uuid,
    ) -> Result<(), AppError> {
        let now = Utc::now();
        sqlx::query(
            "UPDATE agent_conversation SET last_message_at = $1, updated_at = $1 WHERE id = $2",
        )
        .bind(now)
        .bind(conversation_id)
        .execute(self.db.writer())
        .await
        .map_err(|e| AppError::DatabaseError(e.to_string()))?;
        Ok(())
    }

    /// Replaces a conversation's title and bumps its `updated_at`.
    ///
    /// No error is returned if no row matches `conversation_id`; the
    /// number of affected rows is not checked.
    ///
    /// # Errors
    ///
    /// Returns `AppError::DatabaseError` if the update fails.
    pub(super) async fn update_conversation_title(
        &self,
        conversation_id: Uuid,
        title: &str,
    ) -> Result<(), AppError> {
        let now = Utc::now();
        sqlx::query(
            "UPDATE agent_conversation SET title = $1, updated_at = $2 WHERE id = $3",
        )
        .bind(title)
        .bind(now)
        .bind(conversation_id)
        .execute(self.db.writer())
        .await
        .map_err(|e| AppError::DatabaseError(e.to_string()))?;
        Ok(())
    }
}
// NOTE(review): `dead_code` is allowed because nothing visible here calls this
// converter; confirm whether it is still needed.
#[allow(dead_code)]
/// Converts an `ai::agent::AgentStep` into the service-level `AgentStepInfo`.
///
/// Each tool call in the step is mapped through `tool_call_info_from_record`;
/// the index, assistant output and reflection are carried over unchanged.
pub(super) fn step_info_from_agent(
    step: ai::agent::AgentStep,
) -> AgentStepInfo {
    let ai::agent::AgentStep {
        index,
        assistant,
        tool_calls,
        reflection,
        ..
    } = step;

    let converted_calls = tool_calls
        .into_iter()
        .map(tool_call_info_from_record)
        .collect();

    AgentStepInfo {
        index,
        assistant,
        tool_calls: converted_calls,
        reflection,
    }
}
// NOTE(review): `dead_code` is allowed outside of its use by
// `step_info_from_agent`; confirm whether the allowance is still required.
#[allow(dead_code)]
/// Converts an `ai::agent::ToolCallRecord` into the service-level
/// `AgentToolCallInfo`, moving every field across unchanged.
pub(super) fn tool_call_info_from_record(
    record: ai::agent::ToolCallRecord,
) -> AgentToolCallInfo {
    let ai::agent::ToolCallRecord {
        id,
        name,
        arguments,
        output,
        error,
        elapsed_ms,
        ..
    } = record;

    AgentToolCallInfo {
        id,
        name,
        arguments,
        output,
        error,
        elapsed_ms,
    }
}
/// Formats `error` as a Server-Sent Events `data:` frame.
///
/// The frame body is the JSON object built from the `type: "error"` and
/// `error` keys, followed by the blank line that terminates an SSE event.
/// The message is JSON-encoded, so embedded newlines are escaped and cannot
/// break the frame.
pub(super) fn stream_error(error: &str) -> String {
    let body = serde_json::json!({
        "type": "error",
        "error": error,
    })
    .to_string();

    let mut frame = String::from("data: ");
    frame.push_str(&body);
    frame.push_str("\n\n");
    frame
}