- Use AgentBuilder for native tool-calling with stream_prompt() - Add RecordingTool wrapper preserving retry + DB recording - Fix tool_choice bug in do_completion (same as call_stream_once) - Add seq field to RoomMessageStreamChunkEvent for strict ordering - Map streaming events: Text→Answer, Reasoning→Thought, ToolCall→Action - Only final event has done=true, removed premature stream ending - Store __chunks__ JSON in thinking_content for ordered replay
132 lines
4.2 KiB
Rust
132 lines
4.2 KiB
Rust
//! Batch tool call recorder — persists tool call records to the `ai_tool_call` table.
//!
//! Records are sent over an unbounded mpsc channel to a background flush loop that
//! batch-inserts them, reducing DB pressure compared to one insert per tool call.
//!
//! Flush triggers:
//! - The buffer reaches `BATCH_SIZE` (50 records).
//! - `FLUSH_INTERVAL` (5 s) elapses while the buffer is non-empty.
//! - All senders are dropped (remaining records are flushed on channel close).
|
|
|
|
use std::time::Duration;
|
|
|
|
use db::database::AppDatabase;
|
|
use models::ai::ai_tool_call;
|
|
use models::ai::ToolCallStatus;
|
|
use sea_orm::*;
|
|
use tokio::sync::mpsc;
|
|
use uuid::Uuid;
|
|
|
|
/// Number of buffered records that triggers an immediate batch insert.
const BATCH_SIZE: usize = 50;

/// Period of the timer that flushes a partially filled, non-empty buffer.
const FLUSH_INTERVAL: Duration = Duration::from_secs(5);
|
|
|
|
/// A single tool call record to be persisted.
|
|
#[derive(Debug, Clone)]
|
|
pub struct ToolCallRecord {
|
|
pub tool_call_id: String,
|
|
pub session_id: Uuid,
|
|
pub tool_name: String,
|
|
pub caller: Uuid,
|
|
pub arguments: serde_json::Value,
|
|
pub status: ToolCallStatus,
|
|
pub execution_time_ms: Option<i64>,
|
|
pub error_message: Option<String>,
|
|
pub error_stack: Option<String>,
|
|
pub retry_count: i32,
|
|
}
|
|
|
|
/// Channel-based batched recorder. Cheap to clone — all clones share the same sender.
|
|
#[derive(Clone)]
|
|
pub struct ToolCallRecorder {
|
|
tx: mpsc::UnboundedSender<ToolCallRecord>,
|
|
session_id: Uuid,
|
|
}
|
|
|
|
impl ToolCallRecorder {
|
|
/// Create a new recorder with an auto-generated session ID
|
|
/// and spawn a background flush loop.
|
|
pub fn new(db: AppDatabase) -> Self {
|
|
Self::with_session(db, Uuid::new_v4())
|
|
}
|
|
|
|
/// Create a new recorder with a specific session ID
|
|
/// (so tool call records can be linked to an `AiSession`).
|
|
pub fn with_session(db: AppDatabase, session_id: Uuid) -> Self {
|
|
let (tx, rx) = mpsc::unbounded_channel();
|
|
tokio::spawn(flush_loop(db, rx));
|
|
Self { tx, session_id }
|
|
}
|
|
|
|
/// The session ID shared by all tool calls recorded through this instance.
|
|
pub fn session_id(&self) -> Uuid {
|
|
self.session_id
|
|
}
|
|
|
|
/// Enqueue a tool call record for batch persistence.
|
|
pub fn record(&self, record: ToolCallRecord) {
|
|
let _ = self.tx.send(record);
|
|
}
|
|
}
|
|
|
|
async fn flush_loop(db: AppDatabase, mut rx: mpsc::UnboundedReceiver<ToolCallRecord>) {
|
|
let mut buffer = Vec::with_capacity(BATCH_SIZE);
|
|
let mut ticker = tokio::time::interval(FLUSH_INTERVAL);
|
|
ticker.tick().await; // skip first immediate tick
|
|
|
|
loop {
|
|
tokio::select! {
|
|
Some(record) = rx.recv() => {
|
|
buffer.push(record);
|
|
if buffer.len() >= BATCH_SIZE {
|
|
flush(&db, &mut buffer).await;
|
|
}
|
|
}
|
|
_ = ticker.tick() => {
|
|
if !buffer.is_empty() {
|
|
flush(&db, &mut buffer).await;
|
|
}
|
|
}
|
|
else => {
|
|
// Channel closed — flush remaining and exit
|
|
if !buffer.is_empty() {
|
|
flush(&db, &mut buffer).await;
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn flush(db: &AppDatabase, buffer: &mut Vec<ToolCallRecord>) {
|
|
let now = chrono::Utc::now();
|
|
let models: Vec<ai_tool_call::ActiveModel> = buffer
|
|
.iter()
|
|
.map(|r| {
|
|
let status = r.status.to_string();
|
|
ai_tool_call::ActiveModel {
|
|
tool_call_id: Set(r.tool_call_id.clone()),
|
|
session: Set(r.session_id),
|
|
tool_name: Set(r.tool_name.clone()),
|
|
caller: Set(r.caller),
|
|
arguments: Set(r.arguments.clone()),
|
|
result: Set(serde_json::Value::Null),
|
|
status: Set(status),
|
|
execution_time_ms: Set(r.execution_time_ms),
|
|
error_message: Set(r.error_message.clone()),
|
|
error_stack: Set(r.error_stack.clone()),
|
|
retry_count: Set(r.retry_count),
|
|
created_at: Set(now),
|
|
completed_at: Set(Some(now)),
|
|
updated_at: Set(now),
|
|
}
|
|
})
|
|
.collect();
|
|
|
|
let count = models.len();
|
|
if let Err(e) = ai_tool_call::Entity::insert_many(models).exec(db).await {
|
|
tracing::warn!(error = %e, count, "failed_to_flush_tool_call_records");
|
|
}
|
|
|
|
buffer.clear();
|
|
}
|