//! Unified AI client with built-in retry, token tracking, and session recording.
//!
//! Uses rig-core as the underlying AI provider library.
|
|
|
|
pub mod types;
|
|
pub use types::{ChatRequestMessage, ToolCall as ClientToolCall};
|
|
|
|
use std::pin::Pin;
|
|
use std::sync::Arc;
|
|
use std::time::Instant;
|
|
use uuid::Uuid;
|
|
|
|
use crate::error::{AgentError, Result};
|
|
|
|
use futures::StreamExt;
|
|
use rig::completion::message::{AssistantContent, Message as RigMessage};
|
|
use rig::completion::{CompletionModel, GetTokenUsage, ToolDefinition};
|
|
use rig::one_or_many::OneOrMany;
|
|
use rig::prelude::CompletionClient;
|
|
use rig::providers::openai;
|
|
|
|
/// AI call metrics — increments metrics crate counters for all AI calls.
|
|
#[derive(Debug, Clone, Default)]
|
|
pub struct AiMetrics;
|
|
|
|
impl AiMetrics {
|
|
pub fn new() -> Self {
|
|
Self
|
|
}
|
|
|
|
pub fn record_success(&self, input_tokens: i64, output_tokens: i64, has_function_call: bool) {
|
|
metrics::counter!("ai_calls_total").increment(1);
|
|
metrics::counter!("ai_calls_success").increment(1);
|
|
if input_tokens > 0 {
|
|
metrics::counter!("ai_input_tokens_total").increment(input_tokens as u64);
|
|
}
|
|
if output_tokens > 0 {
|
|
metrics::counter!("ai_output_tokens_total").increment(output_tokens as u64);
|
|
}
|
|
if has_function_call {
|
|
metrics::counter!("ai_function_calls_total").increment(1);
|
|
}
|
|
}
|
|
|
|
pub fn record_failure(&self) {
|
|
metrics::counter!("ai_calls_total").increment(1);
|
|
metrics::counter!("ai_calls_failure").increment(1);
|
|
}
|
|
}
|
|
|
|
/// Configuration for the AI client.
|
|
#[derive(Clone)]
|
|
pub struct AiClientConfig {
|
|
pub api_key: String,
|
|
pub base_url: Option<String>,
|
|
}
|
|
|
|
impl AiClientConfig {
|
|
pub fn new(api_key: String) -> Self {
|
|
Self {
|
|
api_key,
|
|
base_url: None,
|
|
}
|
|
}
|
|
|
|
pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
|
|
self.base_url = Some(base_url.into());
|
|
self
|
|
}
|
|
|
|
/// Build a rig OpenAI client from this config.
|
|
pub fn build_rig_client(&self) -> openai::Client {
|
|
let base = self
|
|
.base_url
|
|
.clone()
|
|
.unwrap_or_else(|| "https://api.openai.com".to_string());
|
|
openai::Client::builder()
|
|
.api_key(&self.api_key)
|
|
.base_url(&base)
|
|
.build()
|
|
.expect("Failed to build rig OpenAI client")
|
|
}
|
|
}
|
|
|
|
/// Response from an AI call, including usage statistics.
|
|
#[derive(Debug, Clone)]
|
|
pub struct AiCallResponse {
|
|
pub content: String,
|
|
pub input_tokens: i64,
|
|
pub output_tokens: i64,
|
|
pub latency_ms: i64,
|
|
pub tool_calls: Vec<ClientToolCall>,
|
|
pub tool_calls_finished: Vec<String>,
|
|
}
|
|
|
|
impl AiCallResponse {
|
|
pub fn total_tokens(&self) -> i64 {
|
|
self.input_tokens + self.output_tokens
|
|
}
|
|
}
|
|
|
|
/// Internal state for retry tracking.
#[derive(Debug)]
struct RetryState {
    /// Number of retries already performed (0 before the first retry).
    attempt: u32,
    /// Maximum number of retries allowed.
    max_retries: u32,
    /// Upper bound, in milliseconds, for a single backoff sleep.
    max_backoff_ms: u64,
}

impl RetryState {
    /// Creates a tracker allowing `max_retries` retries with a 5-second backoff cap.
    fn new(max_retries: u32) -> Self {
        Self {
            attempt: 0,
            max_retries,
            max_backoff_ms: 5000,
        }
    }

    /// Returns `true` while the retry budget is not exhausted.
    fn should_retry(&self) -> bool {
        self.attempt < self.max_retries
    }

    /// Computes the sleep before the next retry.
    ///
    /// The base delay doubles per attempt starting at 500 ms (the exponent is
    /// clamped to 5) and is capped at `max_backoff_ms`. The returned delay is
    /// drawn uniformly from `[base / 2, base + base / 2]` and capped again, so
    /// every retry is jittered. The previous arithmetic subtracted on an
    /// unsigned value with `saturating_sub`, which collapsed roughly half of
    /// all draws to "no jitter" and produced no jitter at all at the cap.
    fn backoff_duration(&self) -> std::time::Duration {
        let exp = self.attempt.min(5);
        let base_ms = 500u64
            .saturating_mul(1u64 << exp)
            .min(self.max_backoff_ms);
        let floor_ms = base_ms / 2;
        // Uniform in [0, base_ms]; added to the floor this centres the delay on base_ms.
        let jitter_ms = fastrand_u64(base_ms.saturating_add(1));
        let total_ms = floor_ms
            .saturating_add(jitter_ms)
            .min(self.max_backoff_ms);
        std::time::Duration::from_millis(total_ms)
    }

    /// Marks one more retry as consumed.
    fn next(&mut self) {
        self.attempt = self.attempt.saturating_add(1);
    }
}

/// Returns a pseudo-random value in `[0, n)`; always `0` when `n <= 1`.
///
/// A process-wide 64-bit LCG supplies the state. The state is passed through
/// a bit-mixing finalizer before the modulo because the low bits of a
/// power-of-two-modulus LCG have very short periods (the lowest bit simply
/// alternates), which would otherwise leak into `value % n`.
///
/// Not suitable for anything security-sensitive; it is only used for jitter.
fn fastrand_u64(n: u64) -> u64 {
    use std::sync::atomic::{AtomicU64, Ordering};

    static STATE: AtomicU64 = AtomicU64::new(0x193_667_6a_5e_7c_57);
    const MULTIPLIER: u64 = 6364136223846793005;

    if n <= 1 {
        return 0;
    }

    let step = |state: u64| state.wrapping_mul(MULTIPLIER).wrapping_add(1);

    // The update closure never returns `None`, so both arms carry the previous state.
    let previous = match STATE.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |s| Some(step(s)))
    {
        Ok(value) | Err(value) => value,
    };

    // 64-bit finalizer (xor-shift / multiply) to spread entropy into the low bits.
    let mut mixed = step(previous);
    mixed ^= mixed >> 33;
    mixed = mixed.wrapping_mul(0xff51_afd7_ed55_8ccd);
    mixed ^= mixed >> 33;
    mixed = mixed.wrapping_mul(0xc4ce_b9fe_1a85_ec53);
    mixed ^= mixed >> 33;

    mixed % n
}
|
|
|
|
fn is_retryable_error(err: &AgentError) -> bool {
|
|
let msg = err.to_string();
|
|
msg.contains("connection refused")
|
|
|| msg.contains("connection timed out")
|
|
|| msg.contains("network error")
|
|
|| msg.contains("dns error")
|
|
|| msg.contains("error sending request")
|
|
|| msg.contains("Http client error")
|
|
|| msg.contains("rate_limit")
|
|
|| msg.contains("rate limit")
|
|
|| msg.contains("429")
|
|
|| msg.contains("500")
|
|
|| msg.contains("502")
|
|
|| msg.contains("503")
|
|
|| msg.contains("504")
|
|
|| msg.contains("internal_server_error")
|
|
|| msg.contains("service_unavailable")
|
|
|| msg.contains("gateway_timeout")
|
|
|| msg.contains("bad_gateway")
|
|
}
|
|
|
|
static AI_METRICS: std::sync::OnceLock<AiMetrics> = std::sync::OnceLock::new();
|
|
|
|
fn ai_metrics() -> &'static AiMetrics {
|
|
AI_METRICS.get_or_init(AiMetrics::new)
|
|
}
|
|
|
|
// ── Type conversions ─────────────────────────────────────────────────────────
|
|
|
|
pub(crate) fn to_rig_message(msg: &ChatRequestMessage) -> RigMessage {
|
|
match msg.role.as_str() {
|
|
"system" => {
|
|
// System messages are handled via preamble(), not passed as messages.
|
|
// We still need to return a valid RigMessage variant.
|
|
RigMessage::user(msg.content.as_deref().unwrap_or(""))
|
|
}
|
|
"user" => RigMessage::user(msg.content.as_deref().unwrap_or("")),
|
|
"assistant" => {
|
|
let mut parts: Vec<AssistantContent> = Vec::new();
|
|
if let Some(ref content) = msg.content {
|
|
if !content.is_empty() {
|
|
parts.push(AssistantContent::text(content));
|
|
}
|
|
}
|
|
if let Some(ref tool_calls) = msg.tool_calls {
|
|
for tc in tool_calls {
|
|
// GLM may return empty tool call IDs — fall back to a generated UUID.
|
|
let id = if tc.id.is_empty() {
|
|
Uuid::new_v4().to_string()
|
|
} else {
|
|
tc.id.clone()
|
|
};
|
|
parts.push(AssistantContent::tool_call_with_call_id(
|
|
&id,
|
|
id.clone(),
|
|
&tc.function.name,
|
|
serde_json::from_str(&tc.function.arguments)
|
|
.unwrap_or(serde_json::Value::Null),
|
|
));
|
|
}
|
|
}
|
|
if parts.is_empty() {
|
|
RigMessage::assistant("")
|
|
} else if parts.len() == 1 {
|
|
// Single part — use simpler constructors
|
|
match parts.pop().unwrap() {
|
|
AssistantContent::Text(t) => RigMessage::assistant(t.text),
|
|
ac => RigMessage::Assistant {
|
|
id: None,
|
|
content: OneOrMany::one(ac),
|
|
},
|
|
}
|
|
} else {
|
|
let content = OneOrMany::many(parts).expect("non-empty parts");
|
|
RigMessage::Assistant { id: None, content }
|
|
}
|
|
}
|
|
"tool" | "function" => {
|
|
let id = msg.tool_call_id.as_deref().unwrap_or("unknown").to_string();
|
|
let call_id = msg.tool_call_id.clone().or_else(|| Some(id.clone()));
|
|
let content = msg.content.as_deref().unwrap_or("");
|
|
RigMessage::tool_result_with_call_id(id, call_id, content)
|
|
}
|
|
"developer" => {
|
|
// Developer role maps to user/system in rig
|
|
RigMessage::user(msg.content.as_deref().unwrap_or(""))
|
|
}
|
|
_ => RigMessage::user(msg.content.as_deref().unwrap_or("")),
|
|
}
|
|
}
|
|
|
|
fn to_rig_tool_def(tool_json: &serde_json::Value) -> Option<ToolDefinition> {
|
|
let name = tool_json
|
|
.get("function")
|
|
.and_then(|f| f.get("name"))
|
|
.and_then(|n| n.as_str())?
|
|
.to_string();
|
|
|
|
let description = tool_json
|
|
.get("function")
|
|
.and_then(|f| f.get("description"))
|
|
.and_then(|d| d.as_str())
|
|
.map(|s| s.to_string())
|
|
.unwrap_or_default();
|
|
|
|
let parameters = tool_json
|
|
.get("function")
|
|
.and_then(|f| f.get("parameters"))
|
|
.cloned()
|
|
.unwrap_or(serde_json::json!({}));
|
|
|
|
Some(ToolDefinition {
|
|
name,
|
|
description,
|
|
parameters,
|
|
})
|
|
}
|
|
|
|
// ── Call helpers ─────────────────────────────────────────────────────────────
|
|
|
|
async fn do_completion<M>(
|
|
model: &M,
|
|
messages: &[ChatRequestMessage],
|
|
temperature: Option<f64>,
|
|
max_tokens: Option<u32>,
|
|
tools: Option<&[serde_json::Value]>,
|
|
tool_choice: Option<&str>,
|
|
) -> Result<(String, u64, u64, Vec<ClientToolCall>, Vec<String>)>
|
|
where
|
|
M: CompletionModel<Client = openai::Client>,
|
|
{
|
|
let preamble = messages
|
|
.iter()
|
|
.find(|m| m.role == "system")
|
|
.and_then(|m| m.content.as_deref())
|
|
.unwrap_or("")
|
|
.to_string();
|
|
|
|
let non_system: Vec<RigMessage> = messages
|
|
.iter()
|
|
.filter(|m| m.role != "system")
|
|
.map(to_rig_message)
|
|
.collect();
|
|
|
|
let tool_defs: Vec<ToolDefinition> = tools
|
|
.map(|ts| ts.iter().filter_map(to_rig_tool_def).collect())
|
|
.unwrap_or_default();
|
|
|
|
let mut builder = model.completion_request("");
|
|
|
|
if !preamble.is_empty() {
|
|
builder = builder.preamble(preamble);
|
|
}
|
|
|
|
if !non_system.is_empty() {
|
|
builder = builder.messages(non_system);
|
|
}
|
|
|
|
if let Some(t) = temperature {
|
|
builder = builder.temperature(t);
|
|
}
|
|
|
|
if let Some(mt) = max_tokens {
|
|
builder = builder.max_tokens(mt as u64);
|
|
}
|
|
|
|
if !tool_defs.is_empty() {
|
|
builder = builder.tools(tool_defs);
|
|
}
|
|
|
|
// Only set tool_choice when explicitly provided (mirrors call_stream_once logic)
|
|
if let Some(tc) = tool_choice {
|
|
match tc {
|
|
"none" => {
|
|
builder = builder.tool_choice(rig::completion::message::ToolChoice::None);
|
|
}
|
|
"auto" => {
|
|
builder = builder.tool_choice(rig::completion::message::ToolChoice::Auto);
|
|
}
|
|
s => {
|
|
builder = builder.tool_choice(rig::completion::message::ToolChoice::Specific {
|
|
function_names: vec![s.to_string()],
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
let response = builder
|
|
.send()
|
|
.await
|
|
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
|
|
|
|
let mut content = String::new();
|
|
let mut tool_names: Vec<String> = Vec::new();
|
|
let mut tool_calls: Vec<ClientToolCall> = Vec::new();
|
|
for item in response.choice {
|
|
match item {
|
|
AssistantContent::Text(t) => {
|
|
content.push_str(&t.text);
|
|
}
|
|
AssistantContent::ToolCall(tc) => {
|
|
tool_names.push(tc.function.name.clone());
|
|
tool_calls.push(ClientToolCall {
|
|
id: tc.id,
|
|
type_: "function".into(),
|
|
function: types::ToolCallFunction {
|
|
name: tc.function.name,
|
|
arguments: serde_json::to_string(&tc.function.arguments)
|
|
.unwrap_or_else(|_| "{}".to_string()),
|
|
},
|
|
});
|
|
}
|
|
AssistantContent::Reasoning(_) => {}
|
|
AssistantContent::Image(_) => {}
|
|
}
|
|
}
|
|
|
|
let input_tokens = response.usage.input_tokens;
|
|
let output_tokens = response.usage.output_tokens;
|
|
|
|
Ok((content, input_tokens, output_tokens, tool_calls, tool_names))
|
|
}
|
|
|
|
// ── Public API ───────────────────────────────────────────────────────────────
|
|
|
|
/// Call the AI model with automatic retry (no custom params).
|
|
pub async fn call_with_retry(
|
|
messages: &[ChatRequestMessage],
|
|
model_name: &str,
|
|
config: &AiClientConfig,
|
|
max_retries: Option<u32>,
|
|
) -> Result<AiCallResponse> {
|
|
let client = config.build_rig_client();
|
|
let model = client.completion_model(model_name);
|
|
let mut state = RetryState::new(max_retries.unwrap_or(3));
|
|
|
|
loop {
|
|
let start = Instant::now();
|
|
|
|
let result = do_completion(&model, messages, None, None, None, None).await;
|
|
|
|
match result {
|
|
Ok((content, input_tokens, output_tokens, tool_calls, tool_names)) => {
|
|
let latency_ms = start.elapsed().as_millis() as i64;
|
|
let has_function_call = !tool_names.is_empty();
|
|
ai_metrics().record_success(
|
|
input_tokens as i64,
|
|
output_tokens as i64,
|
|
has_function_call,
|
|
);
|
|
return Ok(AiCallResponse {
|
|
content,
|
|
input_tokens: input_tokens as i64,
|
|
output_tokens: output_tokens as i64,
|
|
latency_ms,
|
|
tool_calls,
|
|
tool_calls_finished: tool_names,
|
|
});
|
|
}
|
|
Err(ref err) if state.should_retry() && is_retryable_error(err) => {
|
|
let duration = state.backoff_duration();
|
|
tracing::warn!(
|
|
attempt = state.attempt + 1,
|
|
max_retries = state.max_retries,
|
|
backoff_ms = duration.as_millis() as u64,
|
|
model = %model_name,
|
|
error = %err,
|
|
"ai_call_retry"
|
|
);
|
|
tokio::time::sleep(duration).await;
|
|
state.next();
|
|
}
|
|
Err(err) => {
|
|
ai_metrics().record_failure();
|
|
return Err(err);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Call with custom parameters (temperature, max_tokens, optional tools, optional tool_choice).
|
|
pub async fn call_with_params(
|
|
messages: &[ChatRequestMessage],
|
|
model_name: &str,
|
|
config: &AiClientConfig,
|
|
temperature: f32,
|
|
max_tokens: u32,
|
|
max_retries: Option<u32>,
|
|
tools: Option<&[serde_json::Value]>,
|
|
tool_choice: Option<&str>,
|
|
) -> Result<AiCallResponse> {
|
|
let client = config.build_rig_client();
|
|
let model = client.completion_model(model_name);
|
|
let mut state = RetryState::new(max_retries.unwrap_or(3));
|
|
|
|
loop {
|
|
let start = Instant::now();
|
|
|
|
let result = do_completion(
|
|
&model,
|
|
messages,
|
|
Some(temperature as f64),
|
|
Some(max_tokens),
|
|
tools,
|
|
tool_choice,
|
|
)
|
|
.await;
|
|
|
|
match result {
|
|
Ok((content, input_tokens, output_tokens, tool_calls, tool_names)) => {
|
|
let latency_ms = start.elapsed().as_millis() as i64;
|
|
let has_function_call = !tool_names.is_empty();
|
|
ai_metrics().record_success(
|
|
input_tokens as i64,
|
|
output_tokens as i64,
|
|
has_function_call,
|
|
);
|
|
return Ok(AiCallResponse {
|
|
content,
|
|
input_tokens: input_tokens as i64,
|
|
output_tokens: output_tokens as i64,
|
|
latency_ms,
|
|
tool_calls,
|
|
tool_calls_finished: tool_names,
|
|
});
|
|
}
|
|
Err(ref err) if state.should_retry() && is_retryable_error(err) => {
|
|
let duration = state.backoff_duration();
|
|
tracing::warn!(
|
|
attempt = state.attempt + 1,
|
|
max_retries = state.max_retries,
|
|
backoff_ms = duration.as_millis() as u64,
|
|
model = %model_name,
|
|
error = %err,
|
|
"ai_call_retry"
|
|
);
|
|
tokio::time::sleep(duration).await;
|
|
state.next();
|
|
}
|
|
Err(err) => {
|
|
ai_metrics().record_failure();
|
|
return Err(err);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Tool call reconstructed from a streaming response.
#[derive(Debug, Clone)]
pub struct StreamedToolCall {
    /// Identifier of the tool call.
    pub id: String,
    /// Name of the function the model wants to invoke.
    pub name: String,
    /// JSON arguments as a string, accumulated over the stream.
    pub arguments: String,
}
|
|
|
|
/// Category of a streamed chunk; chunks are stored in arrival order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamChunkType {
    /// Reasoning ("thinking") text from the model.
    Thinking,
    /// Answer text from the model.
    Answer,
    /// A tool call emitted by the model.
    ToolCall,
    /// The result of a tool call.
    ToolResult,
}
|
|
|
|
/// A single chunk from the streaming response in arrival order.
|
|
#[derive(Debug, Clone)]
|
|
pub struct StreamChunk {
|
|
pub chunk_type: StreamChunkType,
|
|
pub content: String,
|
|
}
|
|
|
|
/// Streaming result from rig.
|
|
#[derive(Debug)]
|
|
pub struct StreamResponse {
|
|
pub content: String,
|
|
pub input_tokens: i64,
|
|
pub output_tokens: i64,
|
|
/// Accumulated reasoning/thinking text from the model.
|
|
pub reasoning_content: String,
|
|
/// Full tool calls with accumulated arguments (not just names)
|
|
pub tool_calls: Vec<StreamedToolCall>,
|
|
/// All chunks in arrival order — preserves think/answer/tool interleaving.
|
|
pub chunks: Vec<StreamChunk>,
|
|
}
|
|
|
|
/// Async callback: takes a string delta and broadcasts it to the WebSocket.
|
|
/// The returned Future must be awaited by the caller.
|
|
pub type StreamTextCb =
|
|
Arc<dyn Fn(&str) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync>;
|
|
pub type StreamReasoningCb =
|
|
Arc<dyn Fn(&str) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync>;
|
|
pub type StreamToolCallCb = Arc<
|
|
dyn Fn(&StreamedToolCall) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>>
|
|
+ Send
|
|
+ Sync,
|
|
>;
|
|
|
|
/// Run a streaming chat completion with 60s timeout and 5 retries.
|
|
pub async fn call_stream(
|
|
messages: &[ChatRequestMessage],
|
|
model_name: &str,
|
|
config: &AiClientConfig,
|
|
temperature: f32,
|
|
max_tokens: u32,
|
|
tools: Option<&[serde_json::Value]>,
|
|
tool_choice: Option<&str>,
|
|
on_text_delta: StreamTextCb,
|
|
on_reasoning_delta: StreamReasoningCb,
|
|
on_tool_call: StreamToolCallCb,
|
|
) -> Result<StreamResponse> {
|
|
let mut state = RetryState::new(5);
|
|
|
|
loop {
|
|
let result = call_stream_once(
|
|
messages,
|
|
model_name,
|
|
config,
|
|
temperature,
|
|
max_tokens,
|
|
tools,
|
|
tool_choice,
|
|
on_text_delta.clone(),
|
|
on_reasoning_delta.clone(),
|
|
on_tool_call.clone(),
|
|
)
|
|
.await;
|
|
|
|
match result {
|
|
Ok(response) => return Ok(response),
|
|
Err(ref err) if state.should_retry() && is_retryable_error(err) => {
|
|
let duration = state.backoff_duration();
|
|
tracing::warn!(
|
|
attempt = state.attempt + 1,
|
|
max_retries = 5,
|
|
backoff_ms = duration.as_millis() as u64,
|
|
model = %model_name,
|
|
error = %err,
|
|
"ai_stream_retry"
|
|
);
|
|
tokio::time::sleep(duration).await;
|
|
state.next();
|
|
}
|
|
Err(err) => {
|
|
ai_metrics().record_failure();
|
|
return Err(err);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Single attempt of streaming completion with 60s timeout.
|
|
async fn call_stream_once(
|
|
messages: &[ChatRequestMessage],
|
|
model_name: &str,
|
|
config: &AiClientConfig,
|
|
temperature: f32,
|
|
max_tokens: u32,
|
|
tools: Option<&[serde_json::Value]>,
|
|
tool_choice: Option<&str>,
|
|
on_text_delta: StreamTextCb,
|
|
on_reasoning_delta: StreamReasoningCb,
|
|
on_tool_call: StreamToolCallCb,
|
|
) -> Result<StreamResponse> {
|
|
let client = config.build_rig_client();
|
|
let model = client.completion_model(model_name);
|
|
|
|
let preamble = messages
|
|
.iter()
|
|
.find(|m| m.role == "system")
|
|
.and_then(|m| m.content.as_deref())
|
|
.unwrap_or("")
|
|
.to_string();
|
|
|
|
let non_system: Vec<RigMessage> = messages
|
|
.iter()
|
|
.filter(|m| m.role != "system")
|
|
.map(to_rig_message)
|
|
.collect();
|
|
|
|
let tool_defs: Vec<ToolDefinition> = tools
|
|
.map(|ts| ts.iter().filter_map(to_rig_tool_def).collect())
|
|
.unwrap_or_default();
|
|
|
|
let mut builder = model
|
|
.completion_request("")
|
|
.temperature(temperature as f64)
|
|
.max_tokens(max_tokens as u64);
|
|
|
|
if !preamble.is_empty() {
|
|
builder = builder.preamble(preamble);
|
|
}
|
|
if !non_system.is_empty() {
|
|
builder = builder.messages(non_system);
|
|
}
|
|
if !tool_defs.is_empty() {
|
|
builder = builder.tools(tool_defs);
|
|
}
|
|
|
|
if let Some(tc) = tool_choice {
|
|
match tc {
|
|
"none" => {
|
|
builder = builder.tool_choice(rig::completion::message::ToolChoice::None);
|
|
}
|
|
"auto" => {
|
|
builder = builder.tool_choice(rig::completion::message::ToolChoice::Auto);
|
|
}
|
|
s => {
|
|
builder = builder.tool_choice(rig::completion::message::ToolChoice::Specific {
|
|
function_names: vec![s.to_string()],
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
let stream_fut = async {
|
|
let mut stream = builder
|
|
.stream()
|
|
.await
|
|
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
|
|
|
|
let mut content = String::new();
|
|
let mut reasoning_content = String::new();
|
|
let mut tool_calls: Vec<StreamedToolCall> = Vec::new();
|
|
let mut chunks: Vec<StreamChunk> = Vec::new();
|
|
|
|
// Some models (e.g. GLM) ignore tool_choice="none" and still emit tool_calls.
|
|
// Filter them out so they don't cause spurious tool execution attempts.
|
|
let skip_tool_calls = tool_choice == Some("none");
|
|
|
|
use std::collections::HashMap;
|
|
let mut partial_tool_calls: HashMap<String, StreamedToolCall> = HashMap::new();
|
|
let mut stream_finished = false;
|
|
|
|
use rig::streaming::StreamedAssistantContent;
|
|
|
|
while let Some(item) = stream.next().await {
|
|
match item {
|
|
Ok(StreamedAssistantContent::Text(text)) => {
|
|
content.push_str(&text.text);
|
|
on_text_delta(&text.text).await;
|
|
chunks.push(StreamChunk {
|
|
chunk_type: StreamChunkType::Answer,
|
|
content: text.text,
|
|
});
|
|
}
|
|
Ok(StreamedAssistantContent::ToolCall {
|
|
tool_call,
|
|
internal_call_id,
|
|
}) => {
|
|
if skip_tool_calls {
|
|
partial_tool_calls.remove(&internal_call_id);
|
|
continue;
|
|
}
|
|
let arguments = match &tool_call.function.arguments {
|
|
serde_json::Value::String(s) => s.clone(),
|
|
other => serde_json::to_string(other).unwrap_or_else(|_| "{}".to_string()),
|
|
};
|
|
let tc = StreamedToolCall {
|
|
id: tool_call.id.clone(),
|
|
name: tool_call.function.name.clone(),
|
|
arguments,
|
|
};
|
|
on_tool_call(&tc).await;
|
|
chunks.push(StreamChunk {
|
|
chunk_type: StreamChunkType::ToolCall,
|
|
content: serde_json::json!({
|
|
"id": tc.id,
|
|
"name": tc.name,
|
|
"arguments": tc.arguments,
|
|
})
|
|
.to_string(),
|
|
});
|
|
tool_calls.push(tc);
|
|
partial_tool_calls.remove(&internal_call_id);
|
|
}
|
|
Ok(StreamedAssistantContent::ToolCallDelta {
|
|
id,
|
|
internal_call_id,
|
|
content: delta_content,
|
|
}) => {
|
|
if skip_tool_calls {
|
|
continue;
|
|
}
|
|
use rig::streaming::ToolCallDeltaContent;
|
|
match delta_content {
|
|
ToolCallDeltaContent::Name(name) => {
|
|
partial_tool_calls.insert(
|
|
internal_call_id.clone(),
|
|
StreamedToolCall {
|
|
id: id.clone(),
|
|
name,
|
|
arguments: String::new(),
|
|
},
|
|
);
|
|
}
|
|
ToolCallDeltaContent::Delta(delta) => {
|
|
if let Some(tc) = partial_tool_calls.get_mut(&internal_call_id) {
|
|
tc.arguments.push_str(&delta);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
|
|
for part in &reasoning.content {
|
|
if let rig::completion::message::ReasoningContent::Text { text, .. } = part
|
|
{
|
|
reasoning_content.push_str(text);
|
|
on_reasoning_delta(text).await;
|
|
chunks.push(StreamChunk {
|
|
chunk_type: StreamChunkType::Thinking,
|
|
content: text.clone(),
|
|
});
|
|
}
|
|
}
|
|
}
|
|
Ok(StreamedAssistantContent::ReasoningDelta { reasoning, .. }) => {
|
|
reasoning_content.push_str(&reasoning);
|
|
on_reasoning_delta(&reasoning).await;
|
|
chunks.push(StreamChunk {
|
|
chunk_type: StreamChunkType::Thinking,
|
|
content: reasoning.clone(),
|
|
});
|
|
}
|
|
Ok(StreamedAssistantContent::Final(response)) => {
|
|
stream_finished = true;
|
|
if !skip_tool_calls {
|
|
for (_, tc) in partial_tool_calls.drain() {
|
|
tool_calls.push(tc);
|
|
}
|
|
} else {
|
|
partial_tool_calls.drain();
|
|
}
|
|
if let Some(usage) = response.token_usage() {
|
|
let in_toks = usage.input_tokens as i64;
|
|
let out_toks = usage.output_tokens as i64;
|
|
ai_metrics().record_success(in_toks, out_toks, !tool_calls.is_empty());
|
|
return Ok(StreamResponse {
|
|
content,
|
|
reasoning_content,
|
|
input_tokens: in_toks,
|
|
output_tokens: out_toks,
|
|
tool_calls,
|
|
chunks,
|
|
});
|
|
}
|
|
// Usage not available from Final — fall through to flush
|
|
}
|
|
Err(e) => return Err(AgentError::OpenAi(e.to_string())),
|
|
}
|
|
}
|
|
|
|
// Flush any remaining partial tool calls (if stream ended without Final or Final had no usage)
|
|
if !stream_finished && !skip_tool_calls {
|
|
for (_, tc) in partial_tool_calls.drain() {
|
|
tool_calls.push(tc);
|
|
}
|
|
}
|
|
ai_metrics().record_success(0, 0, !tool_calls.is_empty());
|
|
Ok(StreamResponse {
|
|
content,
|
|
reasoning_content,
|
|
input_tokens: 0,
|
|
output_tokens: 0,
|
|
tool_calls,
|
|
chunks,
|
|
})
|
|
};
|
|
|
|
// 120s timeout for the entire stream
|
|
match tokio::time::timeout(std::time::Duration::from_secs(120), stream_fut).await {
|
|
Ok(result) => result,
|
|
Err(_) => Err(AgentError::Timeout {
|
|
task_id: 0,
|
|
seconds: 120,
|
|
}),
|
|
}
|
|
}
|