328 lines
9.2 KiB
Rust
328 lines
9.2 KiB
Rust
use chrono::Utc;
|
|
use db::sqlx;
|
|
use model::agent::AgentTraceModel;
|
|
use serde_json::{Value, json};
|
|
use uuid::Uuid;
|
|
|
|
use crate::AppService;
|
|
use crate::error::AppError;
|
|
|
|
pub struct TraceContext {
|
|
pub invocation_id: Uuid,
|
|
pub conversation_id: Uuid,
|
|
}
|
|
|
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
|
pub struct TraceReplay {
|
|
pub invocation_id: Uuid,
|
|
pub conversation_id: Uuid,
|
|
pub phases: Vec<TracePhaseRow>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
|
pub struct TracePhaseRow {
|
|
pub id: Uuid,
|
|
pub sequence: i32,
|
|
pub phase: String,
|
|
pub label: String,
|
|
pub content: Option<String>,
|
|
pub tool_calls: Option<Value>,
|
|
pub tool_results: Option<Value>,
|
|
pub input_tokens: Option<i64>,
|
|
pub output_tokens: Option<i64>,
|
|
pub metadata: Option<Value>,
|
|
pub created_at: chrono::DateTime<Utc>,
|
|
}
|
|
|
|
impl AppService {
|
|
pub async fn trace_record(
|
|
&self,
|
|
ctx: &TraceContext,
|
|
sequence: i32,
|
|
phase: &str,
|
|
content: Option<&str>,
|
|
tool_calls: Option<&Value>,
|
|
tool_results: Option<&Value>,
|
|
input_tokens: Option<i64>,
|
|
output_tokens: Option<i64>,
|
|
metadata: Option<&Value>,
|
|
) -> Result<Uuid, AppError> {
|
|
let id = Uuid::now_v7();
|
|
let now = Utc::now();
|
|
sqlx::query(
|
|
"INSERT INTO agent_trace \
|
|
(id, invocation, conversation, sequence, phase, content, \
|
|
tool_calls, tool_results, input_tokens, output_tokens, metadata, created_at) \
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)",
|
|
)
|
|
.bind(id)
|
|
.bind(ctx.invocation_id)
|
|
.bind(ctx.conversation_id)
|
|
.bind(sequence)
|
|
.bind(phase)
|
|
.bind(content)
|
|
.bind(tool_calls)
|
|
.bind(tool_results)
|
|
.bind(input_tokens)
|
|
.bind(output_tokens)
|
|
.bind(metadata)
|
|
.bind(now)
|
|
.execute(self.db.writer())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
Ok(id)
|
|
}
|
|
|
|
pub async fn trace_replay_by_invocation(
|
|
&self,
|
|
invocation_id: Uuid,
|
|
) -> Result<TraceReplay, AppError> {
|
|
let rows = sqlx::query_as::<_, AgentTraceModel>(
|
|
"SELECT id, invocation, conversation, sequence, phase, content, \
|
|
tool_calls, tool_results, input_tokens, output_tokens, metadata, created_at \
|
|
FROM agent_trace WHERE invocation = $1 ORDER BY sequence ASC",
|
|
)
|
|
.bind(invocation_id)
|
|
.fetch_all(self.db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
let conversation_id =
|
|
rows.first().map(|r| r.conversation).unwrap_or(Uuid::nil());
|
|
|
|
Ok(TraceReplay {
|
|
invocation_id,
|
|
conversation_id,
|
|
phases: rows
|
|
.into_iter()
|
|
.map(|r| TracePhaseRow {
|
|
id: r.id,
|
|
sequence: r.sequence,
|
|
phase: r.phase.clone(),
|
|
label: r.phase_label().to_string(),
|
|
content: r.content,
|
|
tool_calls: r.tool_calls,
|
|
tool_results: r.tool_results,
|
|
input_tokens: r.input_tokens,
|
|
output_tokens: r.output_tokens,
|
|
metadata: r.metadata,
|
|
created_at: r.created_at,
|
|
})
|
|
.collect(),
|
|
})
|
|
}
|
|
|
|
pub async fn trace_replay_by_conversation(
|
|
&self,
|
|
conversation_id: Uuid,
|
|
) -> Result<Vec<TraceReplay>, AppError> {
|
|
let rows = sqlx::query_as::<_, AgentTraceModel>(
|
|
"SELECT id, invocation, conversation, sequence, phase, content, \
|
|
tool_calls, tool_results, input_tokens, output_tokens, metadata, created_at \
|
|
FROM agent_trace WHERE conversation = $1 ORDER BY invocation, sequence ASC",
|
|
)
|
|
.bind(conversation_id)
|
|
.fetch_all(self.db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
let mut grouped: std::collections::BTreeMap<
|
|
Uuid,
|
|
Vec<AgentTraceModel>,
|
|
> = std::collections::BTreeMap::new();
|
|
for row in rows {
|
|
grouped.entry(row.invocation).or_default().push(row);
|
|
}
|
|
|
|
Ok(grouped
|
|
.into_iter()
|
|
.map(|(invocation_id, rows)| TraceReplay {
|
|
invocation_id,
|
|
conversation_id,
|
|
phases: rows
|
|
.into_iter()
|
|
.map(|r| TracePhaseRow {
|
|
id: r.id,
|
|
sequence: r.sequence,
|
|
phase: r.phase.clone(),
|
|
label: r.phase_label().to_string(),
|
|
content: r.content,
|
|
tool_calls: r.tool_calls,
|
|
tool_results: r.tool_results,
|
|
input_tokens: r.input_tokens,
|
|
output_tokens: r.output_tokens,
|
|
metadata: r.metadata,
|
|
created_at: r.created_at,
|
|
})
|
|
.collect(),
|
|
})
|
|
.collect())
|
|
}
|
|
}
|
|
|
|
pub struct TraceAccumulator {
|
|
ctx: TraceContext,
|
|
seq: i32,
|
|
think_buf: String,
|
|
answer_buf: String,
|
|
think_tokens: i64,
|
|
answer_tokens: i64,
|
|
svc: AppService,
|
|
}
|
|
|
|
impl TraceAccumulator {
|
|
pub fn new(
|
|
svc: AppService,
|
|
invocation_id: Uuid,
|
|
conversation_id: Uuid,
|
|
) -> Self {
|
|
Self {
|
|
ctx: TraceContext {
|
|
invocation_id,
|
|
conversation_id,
|
|
},
|
|
seq: 0,
|
|
think_buf: String::new(),
|
|
answer_buf: String::new(),
|
|
think_tokens: 0,
|
|
answer_tokens: 0,
|
|
svc,
|
|
}
|
|
}
|
|
|
|
pub async fn feed_thinking(&mut self, chunk: &str) {
|
|
self.think_buf.push_str(chunk);
|
|
self.think_tokens += (chunk.chars().count() as f64 / 2.5).ceil() as i64;
|
|
}
|
|
|
|
pub async fn feed_text(&mut self, chunk: &str) {
|
|
if !self.think_buf.is_empty() {
|
|
self.flush_think().await;
|
|
}
|
|
self.answer_buf.push_str(chunk);
|
|
self.answer_tokens +=
|
|
(chunk.chars().count() as f64 / 2.5).ceil() as i64;
|
|
}
|
|
|
|
pub async fn feed_tool_call(
|
|
&mut self,
|
|
tool_call_id: &str,
|
|
tool_name: &str,
|
|
args: &Value,
|
|
) {
|
|
if !self.answer_buf.is_empty() {
|
|
self.flush_answer().await;
|
|
}
|
|
let _ = self.svc.trace_record(
|
|
&self.ctx, self.seq, "act",
|
|
None,
|
|
Some(&json!({ "tool_call_id": tool_call_id, "name": tool_name, "arguments": args })),
|
|
None,
|
|
None, None, None,
|
|
).await;
|
|
self.seq += 1;
|
|
}
|
|
|
|
pub async fn feed_tool_result(
|
|
&mut self,
|
|
tool_call_id: &str,
|
|
tool_name: &str,
|
|
output: Option<&Value>,
|
|
error: Option<&str>,
|
|
elapsed_ms: i64,
|
|
) {
|
|
let _ = self
|
|
.svc
|
|
.trace_record(
|
|
&self.ctx,
|
|
self.seq,
|
|
"act",
|
|
None,
|
|
None,
|
|
Some(&json!({
|
|
"tool_call_id": tool_call_id,
|
|
"name": tool_name,
|
|
"output": output,
|
|
"error": error,
|
|
"elapsed_ms": elapsed_ms,
|
|
})),
|
|
None,
|
|
None,
|
|
None,
|
|
)
|
|
.await;
|
|
self.seq += 1;
|
|
}
|
|
|
|
pub async fn finish(
|
|
&mut self,
|
|
output: &str,
|
|
input_tokens: i64,
|
|
output_tokens: i64,
|
|
) {
|
|
if !self.think_buf.is_empty() {
|
|
self.flush_think().await;
|
|
}
|
|
if !self.answer_buf.is_empty() {
|
|
self.flush_answer().await;
|
|
}
|
|
let _ = self
|
|
.svc
|
|
.trace_record(
|
|
&self.ctx,
|
|
self.seq,
|
|
"summarize",
|
|
Some(output),
|
|
None,
|
|
None,
|
|
Some(input_tokens),
|
|
Some(output_tokens),
|
|
None,
|
|
)
|
|
.await;
|
|
}
|
|
|
|
async fn flush_think(&mut self) {
|
|
let content = std::mem::take(&mut self.think_buf);
|
|
let tokens = self.think_tokens;
|
|
self.think_tokens = 0;
|
|
let _ = self
|
|
.svc
|
|
.trace_record(
|
|
&self.ctx,
|
|
self.seq,
|
|
"think",
|
|
Some(&content),
|
|
None,
|
|
None,
|
|
Some(tokens),
|
|
None,
|
|
None,
|
|
)
|
|
.await;
|
|
self.seq += 1;
|
|
}
|
|
|
|
async fn flush_answer(&mut self) {
|
|
let content = std::mem::take(&mut self.answer_buf);
|
|
let tokens = self.answer_tokens;
|
|
self.answer_tokens = 0;
|
|
let _ = self
|
|
.svc
|
|
.trace_record(
|
|
&self.ctx,
|
|
self.seq,
|
|
"answer",
|
|
Some(&content),
|
|
None,
|
|
None,
|
|
None,
|
|
Some(tokens),
|
|
None,
|
|
)
|
|
.await;
|
|
self.seq += 1;
|
|
}
|
|
}
|