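//! Recording and replay of agent execution traces stored in the `agent_trace` table.
//!
//! Source: `gitdataai/lib/service/agent/trace.rs` (328 lines, 9.2 KiB, Rust).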

use chrono::Utc;
use db::sqlx;
use model::agent::AgentTraceModel;
use serde_json::{Value, json};
use uuid::Uuid;
use crate::AppService;
use crate::error::AppError;
/// Identifies the invocation and conversation that trace rows are written for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TraceContext {
    /// Agent invocation the rows belong to (stored in the `invocation` column).
    pub invocation_id: Uuid,
    /// Conversation containing the invocation (stored in the `conversation` column).
    pub conversation_id: Uuid,
}
/// Replayable trace of a single agent invocation.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TraceReplay {
    /// Invocation whose rows are listed in `phases`.
    pub invocation_id: Uuid,
    /// Conversation the invocation belongs to. `trace_replay_by_invocation` takes it from
    /// the first loaded row and falls back to `Uuid::nil()` when no rows exist.
    pub conversation_id: Uuid,
    /// Trace rows of the invocation, ordered by ascending `sequence`.
    pub phases: Vec<TracePhaseRow>,
}
/// One persisted `agent_trace` row, prepared for replay.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TracePhaseRow {
    /// Primary key of the trace row.
    pub id: Uuid,
    /// Position of the row within its invocation; replay queries sort by it.
    pub sequence: i32,
    /// Phase name as stored (`TraceAccumulator` writes `think`, `answer`, `act`, `summarize`).
    pub phase: String,
    /// Label produced by `AgentTraceModel::phase_label` for this row.
    pub label: String,
    /// Text payload (thinking, answer or summary text), if any.
    pub content: Option<String>,
    /// JSON describing a tool invocation, if this row records one.
    pub tool_calls: Option<Value>,
    /// JSON describing a tool outcome, if this row records one.
    pub tool_results: Option<Value>,
    /// Input token count stored with the row, if any. Rows written by `TraceAccumulator`
    /// for buffered text carry character-based estimates, not tokenizer counts.
    pub input_tokens: Option<i64>,
    /// Output token count stored with the row, if any (see `input_tokens` about estimates).
    pub output_tokens: Option<i64>,
    /// Optional free-form JSON metadata.
    pub metadata: Option<Value>,
    /// UTC time at which the row was recorded.
    pub created_at: chrono::DateTime<Utc>,
}
impl AppService {
pub async fn trace_record(
&self,
ctx: &TraceContext,
sequence: i32,
phase: &str,
content: Option<&str>,
tool_calls: Option<&Value>,
tool_results: Option<&Value>,
input_tokens: Option<i64>,
output_tokens: Option<i64>,
metadata: Option<&Value>,
) -> Result<Uuid, AppError> {
let id = Uuid::now_v7();
let now = Utc::now();
sqlx::query(
"INSERT INTO agent_trace \
(id, invocation, conversation, sequence, phase, content, \
tool_calls, tool_results, input_tokens, output_tokens, metadata, created_at) \
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)",
)
.bind(id)
.bind(ctx.invocation_id)
.bind(ctx.conversation_id)
.bind(sequence)
.bind(phase)
.bind(content)
.bind(tool_calls)
.bind(tool_results)
.bind(input_tokens)
.bind(output_tokens)
.bind(metadata)
.bind(now)
.execute(self.db.writer())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
Ok(id)
}
pub async fn trace_replay_by_invocation(
&self,
invocation_id: Uuid,
) -> Result<TraceReplay, AppError> {
let rows = sqlx::query_as::<_, AgentTraceModel>(
"SELECT id, invocation, conversation, sequence, phase, content, \
tool_calls, tool_results, input_tokens, output_tokens, metadata, created_at \
FROM agent_trace WHERE invocation = $1 ORDER BY sequence ASC",
)
.bind(invocation_id)
.fetch_all(self.db.reader())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
let conversation_id =
rows.first().map(|r| r.conversation).unwrap_or(Uuid::nil());
Ok(TraceReplay {
invocation_id,
conversation_id,
phases: rows
.into_iter()
.map(|r| TracePhaseRow {
id: r.id,
sequence: r.sequence,
phase: r.phase.clone(),
label: r.phase_label().to_string(),
content: r.content,
tool_calls: r.tool_calls,
tool_results: r.tool_results,
input_tokens: r.input_tokens,
output_tokens: r.output_tokens,
metadata: r.metadata,
created_at: r.created_at,
})
.collect(),
})
}
pub async fn trace_replay_by_conversation(
&self,
conversation_id: Uuid,
) -> Result<Vec<TraceReplay>, AppError> {
let rows = sqlx::query_as::<_, AgentTraceModel>(
"SELECT id, invocation, conversation, sequence, phase, content, \
tool_calls, tool_results, input_tokens, output_tokens, metadata, created_at \
FROM agent_trace WHERE conversation = $1 ORDER BY invocation, sequence ASC",
)
.bind(conversation_id)
.fetch_all(self.db.reader())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
let mut grouped: std::collections::BTreeMap<
Uuid,
Vec<AgentTraceModel>,
> = std::collections::BTreeMap::new();
for row in rows {
grouped.entry(row.invocation).or_default().push(row);
}
Ok(grouped
.into_iter()
.map(|(invocation_id, rows)| TraceReplay {
invocation_id,
conversation_id,
phases: rows
.into_iter()
.map(|r| TracePhaseRow {
id: r.id,
sequence: r.sequence,
phase: r.phase.clone(),
label: r.phase_label().to_string(),
content: r.content,
tool_calls: r.tool_calls,
tool_results: r.tool_results,
input_tokens: r.input_tokens,
output_tokens: r.output_tokens,
metadata: r.metadata,
created_at: r.created_at,
})
.collect(),
})
.collect())
}
}
/// Buffers streamed model output for one invocation and persists it as ordered trace rows.
///
/// Row kinds:
/// - Consecutive thinking chunks are coalesced into one `think` row.
/// - Consecutive answer chunks are coalesced into one `answer` row.
/// - Tool calls and tool results are written immediately as `act` rows.
/// - [`TraceAccumulator::finish`] writes a final `summarize` row.
///
/// Pending text is flushed whenever the stream switches to a different kind of event. As a
/// result, at most one text buffer is non-empty at a time. Each row gets the next value of
/// an internal sequence counter, so sorting by `sequence` reproduces the feed order.
///
/// Persistence is best-effort: errors returned by [`AppService::trace_record`] are
/// discarded so that tracing never aborts the agent run.
pub struct TraceAccumulator {
    ctx: TraceContext,
    /// Sequence number assigned to the next persisted row.
    seq: i32,
    /// Thinking text received since the last `think` row was written.
    think_buf: String,
    /// Answer text received since the last `answer` row was written.
    answer_buf: String,
    /// Estimated token count of `think_buf`.
    think_tokens: i64,
    /// Estimated token count of `answer_buf`.
    answer_tokens: i64,
    svc: AppService,
}

impl TraceAccumulator {
    /// Characters per token assumed by the rough token estimate.
    const CHARS_PER_TOKEN: f64 = 2.5;

    /// Creates an accumulator that writes rows for `invocation_id` within
    /// `conversation_id`, starting at sequence 0.
    pub fn new(svc: AppService, invocation_id: Uuid, conversation_id: Uuid) -> Self {
        Self {
            ctx: TraceContext {
                invocation_id,
                conversation_id,
            },
            seq: 0,
            think_buf: String::new(),
            answer_buf: String::new(),
            think_tokens: 0,
            answer_tokens: 0,
            svc,
        }
    }

    /// Appends a chunk of model thinking output.
    ///
    /// Buffered answer text is persisted first. Otherwise an answer emitted before this
    /// thought would be merged with later answer text and sequenced after the thought.
    /// Empty chunks are ignored so they cannot split a buffered answer.
    pub async fn feed_thinking(&mut self, chunk: &str) {
        if chunk.is_empty() {
            return;
        }
        self.flush_answer().await;
        self.think_buf.push_str(chunk);
        self.think_tokens += Self::estimate_tokens(chunk);
    }

    /// Appends a chunk of answer text, persisting any buffered thinking first.
    ///
    /// Empty chunks are ignored; buffered thinking is still written before the next event.
    pub async fn feed_text(&mut self, chunk: &str) {
        if chunk.is_empty() {
            return;
        }
        self.flush_think().await;
        self.answer_buf.push_str(chunk);
        self.answer_tokens += Self::estimate_tokens(chunk);
    }

    /// Persists a tool invocation as an `act` row with `tool_calls` set.
    ///
    /// Buffered thinking and answer text are written first, so they keep their place in
    /// the sequence ahead of the call.
    pub async fn feed_tool_call(&mut self, tool_call_id: &str, tool_name: &str, args: &Value) {
        self.flush_pending().await;
        let call = json!({ "tool_call_id": tool_call_id, "name": tool_name, "arguments": args });
        self.record("act", None, Some(&call), None, None, None).await;
    }

    /// Persists a tool outcome as an `act` row with `tool_results` set.
    ///
    /// `output` and `error` are stored as JSON `null` when absent. Any buffered text is
    /// written first.
    pub async fn feed_tool_result(
        &mut self,
        tool_call_id: &str,
        tool_name: &str,
        output: Option<&Value>,
        error: Option<&str>,
        elapsed_ms: i64,
    ) {
        self.flush_pending().await;
        let result = json!({
            "tool_call_id": tool_call_id,
            "name": tool_name,
            "output": output,
            "error": error,
            "elapsed_ms": elapsed_ms,
        });
        self.record("act", None, None, Some(&result), None, None).await;
    }

    /// Persists any buffered text, then writes the final `summarize` row.
    ///
    /// The row stores `output` along with the caller-supplied token counts.
    pub async fn finish(&mut self, output: &str, input_tokens: i64, output_tokens: i64) {
        self.flush_pending().await;
        self.record(
            "summarize",
            Some(output),
            None,
            None,
            Some(input_tokens),
            Some(output_tokens),
        )
        .await;
    }

    /// Writes whichever text buffers are non-empty (thinking first, then answer).
    async fn flush_pending(&mut self) {
        self.flush_think().await;
        self.flush_answer().await;
    }

    /// Writes buffered thinking as a `think` row and resets its buffer. No-op when empty.
    async fn flush_think(&mut self) {
        if self.think_buf.is_empty() {
            return;
        }
        let content = std::mem::take(&mut self.think_buf);
        let tokens = std::mem::take(&mut self.think_tokens);
        // NOTE(review): the thinking estimate goes to the `input_tokens` column, while
        // answer estimates go to `output_tokens`. Thinking is model output, so confirm
        // this column choice is intentional before relying on it in analytics.
        self.record("think", Some(content.as_str()), None, None, Some(tokens), None)
            .await;
    }

    /// Writes buffered answer text as an `answer` row and resets its buffer. No-op when empty.
    async fn flush_answer(&mut self) {
        if self.answer_buf.is_empty() {
            return;
        }
        let content = std::mem::take(&mut self.answer_buf);
        let tokens = std::mem::take(&mut self.answer_tokens);
        self.record("answer", Some(content.as_str()), None, None, None, Some(tokens))
            .await;
    }

    /// Writes one trace row at the current sequence number and advances the counter.
    ///
    /// The counter advances even when the write fails, so later rows still get distinct
    /// sequence numbers.
    async fn record(
        &mut self,
        phase: &str,
        content: Option<&str>,
        tool_calls: Option<&Value>,
        tool_results: Option<&Value>,
        input_tokens: Option<i64>,
        output_tokens: Option<i64>,
    ) {
        // Best-effort: a failed trace write must not interrupt the agent run.
        let _ = self
            .svc
            .trace_record(
                &self.ctx,
                self.seq,
                phase,
                content,
                tool_calls,
                tool_results,
                input_tokens,
                output_tokens,
                None,
            )
            .await;
        self.seq += 1;
    }

    /// Rough token estimate: character count divided by [`Self::CHARS_PER_TOKEN`], rounded up.
    fn estimate_tokens(text: &str) -> i64 {
        (text.chars().count() as f64 / Self::CHARS_PER_TOKEN).ceil() as i64
    }
}