gitdataai/libs/agent/tool/recorder.rs
ZhenYi 5b3a6700be refactor(agent): replace custom ReAct loop with rig::agent::Agent
- Use AgentBuilder for native tool-calling with stream_prompt()
- Add RecordingTool wrapper preserving retry + DB recording
- Fix tool_choice bug in do_completion (same as call_stream_once)
- Add seq field to RoomMessageStreamChunkEvent for strict ordering
- Map streaming events: Text→Answer, Reasoning→Thought, ToolCall→Action
- Only final event has done=true, removed premature stream ending
- Store __chunks__ JSON in thinking_content for ordered replay
2026-04-28 09:42:36 +08:00

132 lines
4.2 KiB
Rust

//! Batched tool-call recorder — persists tool-call records to the `ai_tool_call` table.
//!
//! Records are sent over an unbounded mpsc channel to a background flush loop that
//! batch-inserts them, reducing DB pressure compared with one insert per record.
//!
//! Flush triggers:
//! - The buffer reaches `BATCH_SIZE` (50 records)
//! - `FLUSH_INTERVAL` (5 s) elapses while the buffer is non-empty
//! - Every sender is dropped (remaining records are flushed when the channel closes)
use std::time::Duration;
use db::database::AppDatabase;
use models::ai::ai_tool_call;
use models::ai::ToolCallStatus;
use sea_orm::*;
use tokio::sync::mpsc;
use uuid::Uuid;
/// Period of the background ticker; a non-empty buffer is written out on every tick.
const FLUSH_INTERVAL: Duration = Duration::from_millis(5_000);
/// Buffered-record count that triggers an immediate flush (also the buffer's initial capacity).
const BATCH_SIZE: usize = 50;
/// A single tool call record to be persisted.
///
/// When a batch is flushed, each field is written to the `ai_tool_call` column of the same
/// name. The exceptions are `session_id`, which goes to the `session` column, and `status`,
/// which is stored in its `to_string()` form. There is no result field here, so the row's
/// `result` column is written as JSON `null`. The `created_at` / `completed_at` /
/// `updated_at` timestamps are taken at flush time, not when the call happened.
#[derive(Debug, Clone)]
pub struct ToolCallRecord {
    /// Tool call ID; stored in `tool_call_id`.
    pub tool_call_id: String,
    /// Session the call belongs to; stored in the `session` column. Callers typically pass
    /// `ToolCallRecorder::session_id()` here — the recorder does not fill it in.
    pub session_id: Uuid,
    /// Name of the invoked tool; stored in `tool_name`.
    pub tool_name: String,
    /// ID of the caller; stored in `caller`. NOTE(review): presumably a user/agent ID — confirm.
    pub caller: Uuid,
    /// Arguments the tool was invoked with; stored in `arguments`.
    pub arguments: serde_json::Value,
    /// Outcome status; persisted as its `to_string()` representation.
    pub status: ToolCallStatus,
    /// Execution time (milliseconds, per the field name), if measured; stored in `execution_time_ms`.
    pub execution_time_ms: Option<i64>,
    /// Error message on failure; stored in `error_message`.
    pub error_message: Option<String>,
    /// Error stack/trace on failure; stored in `error_stack`.
    pub error_stack: Option<String>,
    /// Retry count; stored in `retry_count`. NOTE(review): whether the first attempt counts is up to the caller — confirm.
    pub retry_count: i32,
}
/// Channel-based batched recorder. Cheap to clone — all clones share the same sender.
#[derive(Clone)]
pub struct ToolCallRecorder {
tx: mpsc::UnboundedSender<ToolCallRecord>,
session_id: Uuid,
}
impl ToolCallRecorder {
/// Create a new recorder with an auto-generated session ID
/// and spawn a background flush loop.
pub fn new(db: AppDatabase) -> Self {
Self::with_session(db, Uuid::new_v4())
}
/// Create a new recorder with a specific session ID
/// (so tool call records can be linked to an `AiSession`).
pub fn with_session(db: AppDatabase, session_id: Uuid) -> Self {
let (tx, rx) = mpsc::unbounded_channel();
tokio::spawn(flush_loop(db, rx));
Self { tx, session_id }
}
/// The session ID shared by all tool calls recorded through this instance.
pub fn session_id(&self) -> Uuid {
self.session_id
}
/// Enqueue a tool call record for batch persistence.
pub fn record(&self, record: ToolCallRecord) {
let _ = self.tx.send(record);
}
}
async fn flush_loop(db: AppDatabase, mut rx: mpsc::UnboundedReceiver<ToolCallRecord>) {
let mut buffer = Vec::with_capacity(BATCH_SIZE);
let mut ticker = tokio::time::interval(FLUSH_INTERVAL);
ticker.tick().await; // skip first immediate tick
loop {
tokio::select! {
Some(record) = rx.recv() => {
buffer.push(record);
if buffer.len() >= BATCH_SIZE {
flush(&db, &mut buffer).await;
}
}
_ = ticker.tick() => {
if !buffer.is_empty() {
flush(&db, &mut buffer).await;
}
}
else => {
// Channel closed — flush remaining and exit
if !buffer.is_empty() {
flush(&db, &mut buffer).await;
}
break;
}
}
}
}
/// Insert every buffered record into `ai_tool_call` with a single `insert_many`, leaving
/// `buffer` empty (its capacity is kept for reuse).
///
/// Records are moved out of the buffer rather than cloned field by field. On insert failure
/// a warning is logged and the batch is dropped: persistence is best-effort and is not
/// retried here. All three timestamps (`created_at`, `completed_at`, `updated_at`) receive
/// the flush time. `result` is written as JSON `null` because `ToolCallRecord` carries no
/// result.
async fn flush(db: &AppDatabase, buffer: &mut Vec<ToolCallRecord>) {
    // Never issue an INSERT with zero rows.
    if buffer.is_empty() {
        return;
    }

    let now = chrono::Utc::now();
    let models: Vec<ai_tool_call::ActiveModel> = buffer
        .drain(..)
        .map(|r| ai_tool_call::ActiveModel {
            tool_call_id: Set(r.tool_call_id),
            session: Set(r.session_id),
            tool_name: Set(r.tool_name),
            caller: Set(r.caller),
            arguments: Set(r.arguments),
            result: Set(serde_json::Value::Null),
            status: Set(r.status.to_string()),
            execution_time_ms: Set(r.execution_time_ms),
            error_message: Set(r.error_message),
            error_stack: Set(r.error_stack),
            retry_count: Set(r.retry_count),
            created_at: Set(now),
            completed_at: Set(Some(now)),
            updated_at: Set(now),
        })
        .collect();

    let count = models.len();
    if let Err(e) = ai_tool_call::Entity::insert_many(models).exec(db).await {
        tracing::warn!(error = %e, count, "failed_to_flush_tool_call_records");
    }
}