feat(agent): add ordered stream chunk collection + retry for HTTP errors

- StreamChunk/StreamChunkType types for preserving arrival order
- Chunk collection in call_stream_once and process_stream
- Add "error sending request" and "Http client error" to retryable errors
- StreamResult includes chunks vector for ordered replay
This commit is contained in:
ZhenYi 2026-04-26 13:10:26 +08:00
parent 0b5dc98ce5
commit b4b5538447
2 changed files with 410 additions and 211 deletions

View File

@ -1,4 +1,5 @@
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use models::projects::project_skill; use models::projects::project_skill;
use models::rooms::room_ai; use models::rooms::room_ai;
@ -9,7 +10,7 @@ use super::context::RoomMessageContext;
use super::{AiChunkType, AiRequest, AiStreamChunk, Mention, StreamCallback}; use super::{AiChunkType, AiRequest, AiStreamChunk, Mention, StreamCallback};
use crate::client::types::{ChatRequestMessage, ToolCall}; use crate::client::types::{ChatRequestMessage, ToolCall};
use crate::client::AiClientConfig; use crate::client::AiClientConfig;
use crate::client::{call_stream, call_with_params}; use crate::client::{call_stream, call_with_params, StreamChunk, StreamChunkType, StreamedToolCall};
use crate::compact::{CompactConfig, CompactService}; use crate::compact::{CompactConfig, CompactService};
use crate::embed::EmbedService; use crate::embed::EmbedService;
use crate::error::{AgentError, Result}; use crate::error::{AgentError, Result};
@ -17,6 +18,23 @@ use crate::perception::{PerceptionService, SkillEntry, ToolCallEvent};
use crate::react::{ReactAgent, ReactConfig, DEFAULT_SYSTEM_PROMPT}; use crate::react::{ReactAgent, ReactConfig, DEFAULT_SYSTEM_PROMPT};
use crate::tool::{ToolCall as AgentToolCall, ToolContext, ToolExecutor, ToolResult, registry::ToolRegistry}; use crate::tool::{ToolCall as AgentToolCall, ToolContext, ToolExecutor, ToolResult, registry::ToolRegistry};
/// Result from streaming AI response.
pub struct StreamResult {
pub content: String,
pub reasoning_content: String,
pub input_tokens: i64,
pub output_tokens: i64,
/// All chunks in arrival order — preserves ReAct multi-cycle ordering.
pub chunks: Vec<StreamChunk>,
}
/// Result from non-streaming AI response.
pub struct ProcessResult {
pub content: String,
pub input_tokens: i64,
pub output_tokens: i64,
}
/// Service for handling AI chat requests in rooms. /// Service for handling AI chat requests in rooms.
pub struct ChatService { pub struct ChatService {
ai_base_url: Option<String>, ai_base_url: Option<String>,
@ -97,7 +115,7 @@ impl ChatService {
self.tool_registry.as_ref() self.tool_registry.as_ref()
} }
pub async fn process(&self, request: AiRequest) -> Result<String> { pub async fn process(&self, request: AiRequest) -> Result<ProcessResult> {
let tools: Vec<serde_json::Value> = request.tools.clone().unwrap_or_default(); let tools: Vec<serde_json::Value> = request.tools.clone().unwrap_or_default();
let tools_enabled = !tools.is_empty(); let tools_enabled = !tools.is_empty();
let max_tool_depth = request.max_tool_depth; let max_tool_depth = request.max_tool_depth;
@ -120,6 +138,8 @@ impl ChatService {
.and_then(|r| r.max_tokens.map(|v| v as u32)) .and_then(|r| r.max_tokens.map(|v| v as u32))
.unwrap_or(request.max_tokens as u32); .unwrap_or(request.max_tokens as u32);
let mut tool_depth = 0; let mut tool_depth = 0;
let mut input_tokens = 0i64;
let mut output_tokens = 0i64;
let config = AiClientConfig::new( let config = AiClientConfig::new(
self.ai_api_key.clone().unwrap_or_default(), self.ai_api_key.clone().unwrap_or_default(),
@ -140,6 +160,8 @@ impl ChatService {
.await?; .await?;
let text = response.content.clone(); let text = response.content.clone();
input_tokens += response.input_tokens;
output_tokens += response.output_tokens;
if tools_enabled && !response.tool_calls_finished.is_empty() { if tools_enabled && !response.tool_calls_finished.is_empty() {
// Build assistant message with tool_calls // Build assistant message with tool_calls
@ -176,16 +198,29 @@ impl ChatService {
}) })
.collect(); .collect();
let tool_messages = match self.execute_tool_calls(calls, &request).await { let tool_messages = {
Ok(msgs) => msgs, let mut ctx = ToolContext::new(
Err(e) => { request.db.clone(),
let err_msg = format!("[Tool call failed: {}]", e); request.cache.clone(),
// Return error as a single tool result per call request.config.clone(),
response request.room.id,
.tool_calls_finished Some(request.sender.uid),
.iter() )
.map(|_| ChatRequestMessage::tool(Uuid::new_v4().to_string(), &err_msg)) .with_project(request.project.id);
.collect() if let Some(ref registry) = self.tool_registry {
ctx.registry_mut().merge(registry.clone());
}
let executor = ToolExecutor::new();
match executor.execute_batch(calls, &mut ctx).await {
Ok(results) => ToolExecutor::to_tool_messages(&results),
Err(e) => {
let err_msg = format!("[Tool call failed: {}]", e);
response
.tool_calls_finished
.iter()
.map(|_| ChatRequestMessage::tool(Uuid::new_v4().to_string(), &err_msg))
.collect()
}
} }
}; };
messages.extend(tool_messages); messages.extend(tool_messages);
@ -225,22 +260,26 @@ impl ChatService {
tool_depth += 1; tool_depth += 1;
if tool_depth >= max_tool_depth { if tool_depth >= max_tool_depth {
if text.is_empty() { let content = if text.is_empty() {
return Ok(format!( format!(
"[AI reached maximum tool depth ({}) — no final answer produced]", "[AI reached maximum tool depth ({}) — no final answer produced]",
max_tool_depth max_tool_depth
)); )
} } else {
return Ok(text); text
};
return Ok(ProcessResult { content, input_tokens, output_tokens });
} }
continue; continue;
} }
return Ok(text); return Ok(ProcessResult { content: text, input_tokens, output_tokens });
} }
} }
pub async fn process_stream(&self, request: AiRequest, on_chunk: StreamCallback) -> Result<String> { pub async fn process_stream(&self, request: AiRequest, on_chunk: StreamCallback) -> Result<StreamResult> {
// Wrap on_chunk in Arc so it can be shared across loop iterations
let on_chunk = Arc::new(on_chunk);
let tools: Vec<serde_json::Value> = request.tools.clone().unwrap_or_default(); let tools: Vec<serde_json::Value> = request.tools.clone().unwrap_or_default();
let tools_enabled = !tools.is_empty(); let tools_enabled = !tools.is_empty();
let max_tool_depth = request.max_tool_depth; let max_tool_depth = request.max_tool_depth;
@ -270,13 +309,15 @@ impl ChatService {
.with_base_url(self.ai_base_url.clone().unwrap_or_else(|| "https://api.openai.com".into())); .with_base_url(self.ai_base_url.clone().unwrap_or_else(|| "https://api.openai.com".into()));
let mut full_content = String::new(); let mut full_content = String::new();
let mut has_called_tools = false; let mut all_chunks: Vec<StreamChunk> = Vec::new();
// Collect tool calls during streaming, push them incrementally after.
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<StreamedToolCall>();
loop { loop {
let chunk_type = if has_called_tools { let on_chunk_cb = on_chunk.clone();
AiChunkType::Answer let on_chunk_cb2 = on_chunk_cb.clone();
} else { let tx_arc = Arc::new(tx.clone());
AiChunkType::Thinking let tx_arc2 = tx_arc.clone();
};
let response = call_stream( let response = call_stream(
&messages, &messages,
&model_name, &model_name,
@ -284,18 +325,36 @@ impl ChatService {
temperature, temperature,
max_tokens, max_tokens,
if tools_enabled { Some(&tools) } else { None }, if tools_enabled { Some(&tools) } else { None },
|delta| { Arc::new(move |delta| {
let _ = on_chunk(AiStreamChunk { let fut = on_chunk_cb(AiStreamChunk {
content: delta.to_string(), content: delta.to_string(),
done: false, done: false,
chunk_type: chunk_type.clone(), chunk_type: AiChunkType::Answer,
}); });
}, fut
}),
Arc::new(move |delta| {
let fut = on_chunk_cb2(AiStreamChunk {
content: delta.to_string(),
done: false,
chunk_type: AiChunkType::Thinking,
});
fut
}),
Arc::new(move |tc: &StreamedToolCall| {
let tx = tx_arc2.clone();
let tc_owned = tc.clone();
Box::pin(async move {
let _ = tx.send(tc_owned);
}) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
}),
) )
.await?; .await?;
// Collect chunks from this streaming iteration in order.
all_chunks.extend(response.chunks);
let has_tool_calls = tools_enabled && !response.tool_calls.is_empty(); let has_tool_calls = tools_enabled && !response.tool_calls.is_empty();
has_called_tools = true;
if has_tool_calls { if has_tool_calls {
// Accumulate the assistant's text before tool calls // Accumulate the assistant's text before tool calls
@ -321,28 +380,34 @@ impl ChatService {
Some(tool_calls.clone()), Some(tool_calls.clone()),
)); ));
// Stream tool call summary to frontend // Push each tool call incrementally to frontend.
let call_summary: Vec<String> = response // Use try_recv() — tx is never dropped so recv() would deadlock.
.tool_calls loop {
.iter() match rx.try_recv() {
.map(|tc| { Ok(tc) => {
// Truncate long arguments for display let args_display = if tc.arguments.len() > 100 {
let args_display = if tc.arguments.len() > 100 { format!("{}...", &tc.arguments[..100])
format!("{}...", &tc.arguments[..100]) } else {
} else { tc.arguments.clone()
tc.arguments.clone() };
}; let tool_display = format!("🔧 {}({})", tc.name, args_display);
format!("{}({})", tc.name, args_display) on_chunk(AiStreamChunk {
}) content: tool_display.clone(),
.collect(); done: false,
on_chunk(AiStreamChunk { chunk_type: AiChunkType::ToolCall,
content: format!("[Calling tools: {}]", call_summary.join(", ")), })
done: false, .await;
chunk_type: AiChunkType::ToolCall, all_chunks.push(StreamChunk {
}) chunk_type: StreamChunkType::ToolCall,
.await; content: tool_display,
});
}
Err(tokio::sync::mpsc::error::TryRecvError::Empty) => break,
Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => break,
}
}
// Execute tools with full arguments from streaming // Execute tools one at a time, push each result incrementally
let calls: Vec<AgentToolCall> = response let calls: Vec<AgentToolCall> = response
.tool_calls .tool_calls
.iter() .iter()
@ -353,43 +418,71 @@ impl ChatService {
}) })
.collect(); .collect();
let tool_messages = match self.execute_tool_calls(calls, &request).await { let mut tool_messages = Vec::new();
Ok(msgs) => { for call in &calls {
let result_summary: Vec<String> = msgs let ctx = &mut crate::tool::ToolContext::new(
.iter() request.db.clone(),
.map(|m| { request.cache.clone(),
let text = m.content.as_deref().unwrap_or("[no content]"); request.config.clone(),
if text.len() > 300 { request.room.id,
format!("{}...", &text[..300]) Some(request.sender.uid),
} else { );
text.to_string() if let Some(ref registry) = self.tool_registry {
} ctx.registry_mut().merge(registry.clone());
}
let executor = crate::tool::ToolExecutor::new();
let results = match executor.execute_batch(vec![call.clone()], ctx).await {
Ok(r) => r,
Err(e) => {
let err_text = format!("[Tool call failed: {}]", e);
tracing::warn!(tool = %call.name, error = %e, "tool_call_failed");
// Do NOT emit tool_result chunks to frontend — show error via tool_call instead
let err_display = format!("{} (failed)", call.name);
on_chunk(AiStreamChunk {
content: err_display.clone(),
done: false,
chunk_type: AiChunkType::ToolCall,
}) })
.collect(); .await;
on_chunk(AiStreamChunk { all_chunks.push(StreamChunk {
content: format!("[Tool results: {}]", result_summary.join("; ")), chunk_type: StreamChunkType::ToolCall,
done: false, content: err_display,
chunk_type: AiChunkType::ToolResult, });
}) tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text));
.await; continue;
msgs }
};
for result in &results {
let text = match &result.result {
crate::tool::ToolResult::Ok(v) => v.to_string(),
crate::tool::ToolResult::Error(msg) => msg.clone(),
};
let preview = if text.len() > 300 {
format!("{}...", &text[..300])
} else {
text.clone()
};
tracing::debug!("tool_result: {} — {}", call.name, preview);
// Do NOT emit tool_result chunks to frontend — raw output may contain sensitive data.
// Log server-side only; frontend sees tool_call status via on_chunk below.
} }
Err(e) => { let success_display = format!("{}", call.name);
let err_text = format!("[Tool call failed: {}]", e); on_chunk(AiStreamChunk {
on_chunk(AiStreamChunk { content: success_display.clone(),
content: err_text.clone(), done: false,
done: false, chunk_type: AiChunkType::ToolCall,
chunk_type: AiChunkType::ToolResult, })
}) .await;
.await; all_chunks.push(StreamChunk {
// Return error tool messages chunk_type: StreamChunkType::ToolCall,
response content: success_display,
.tool_calls });
.iter()
.map(|tc| ChatRequestMessage::tool(&tc.id, &err_text)) let msgs = crate::tool::ToolExecutor::to_tool_messages(&results);
.collect() tool_messages.extend(msgs);
} }
};
messages.extend(tool_messages); messages.extend(tool_messages);
// Inject passive-detected skills based on tool calls // Inject passive-detected skills based on tool calls
@ -427,60 +520,54 @@ impl ChatService {
tool_depth += 1; tool_depth += 1;
if tool_depth >= max_tool_depth { if tool_depth >= max_tool_depth {
let max_depth_text = format!(
"[AI reached maximum tool depth ({}) — no final answer produced]",
max_tool_depth
);
on_chunk(AiStreamChunk { on_chunk(AiStreamChunk {
content: format!( content: max_depth_text.clone(),
"[AI reached maximum tool depth ({}) — no final answer produced]",
max_tool_depth
),
done: true, done: true,
chunk_type: AiChunkType::Answer, chunk_type: AiChunkType::Answer,
}) })
.await; .await;
return Ok(full_content); all_chunks.push(StreamChunk {
chunk_type: StreamChunkType::Answer,
content: max_depth_text,
});
return Ok(StreamResult {
content: full_content,
reasoning_content: String::new(),
input_tokens: 0,
output_tokens: 0,
chunks: all_chunks,
});
} }
continue; continue;
} }
// Final answer — accumulate and return // Final answer — accumulate and return
full_content.push_str(&response.content); full_content.push_str(&response.content);
on_chunk(AiStreamChunk { on_chunk(AiStreamChunk {
content: response.content, content: response.content.clone(),
done: true, done: true,
chunk_type: AiChunkType::Answer, chunk_type: AiChunkType::Answer,
}) })
.await; .await;
return Ok(full_content); all_chunks.push(StreamChunk {
chunk_type: StreamChunkType::Answer,
content: response.content.clone(),
});
return Ok(StreamResult {
content: full_content,
reasoning_content: response.reasoning_content,
input_tokens: response.input_tokens,
output_tokens: response.output_tokens,
chunks: all_chunks,
});
} }
} }
/// Executes a batch of tool calls and returns the tool result messages.
async fn execute_tool_calls(
&self,
calls: Vec<AgentToolCall>,
request: &AiRequest,
) -> Result<Vec<ChatRequestMessage>> {
let mut ctx = ToolContext::new(
request.db.clone(),
request.cache.clone(),
request.config.clone(),
request.room.id,
Some(request.sender.uid),
)
.with_project(request.project.id);
if let Some(ref registry) = self.tool_registry {
ctx.registry_mut().merge(registry.clone());
}
let executor = ToolExecutor::new();
let results = executor
.execute_batch(calls, &mut ctx)
.await
.map_err(|e| AgentError::Internal(e.to_string()))?;
Ok(ToolExecutor::to_tool_messages(&results))
}
async fn build_messages(&self, request: &AiRequest) -> Result<Vec<ChatRequestMessage>> { async fn build_messages(&self, request: &AiRequest) -> Result<Vec<ChatRequestMessage>> {
let mut messages = Vec::new(); let mut messages = Vec::new();

View File

@ -5,6 +5,8 @@
pub mod types; pub mod types;
pub use types::{ChatRequestMessage, ToolCall as ClientToolCall}; pub use types::{ChatRequestMessage, ToolCall as ClientToolCall};
use std::pin::Pin;
use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use uuid::Uuid; use uuid::Uuid;
@ -130,6 +132,8 @@ fn is_retryable_error(err: &AgentError) -> bool {
|| msg.contains("connection timed out") || msg.contains("connection timed out")
|| msg.contains("network error") || msg.contains("network error")
|| msg.contains("dns error") || msg.contains("dns error")
|| msg.contains("error sending request")
|| msg.contains("Http client error")
|| msg.contains("rate_limit") || msg.contains("rate_limit")
|| msg.contains("rate limit") || msg.contains("rate limit")
|| msg.contains("429") || msg.contains("429")
@ -451,17 +455,42 @@ pub struct StreamedToolCall {
pub arguments: String, pub arguments: String,
} }
/// Type of chunk in the streaming response, preserving arrival order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamChunkType {
Thinking,
Answer,
ToolCall,
}
/// A single chunk from the streaming response in arrival order.
#[derive(Debug, Clone)]
pub struct StreamChunk {
pub chunk_type: StreamChunkType,
pub content: String,
}
/// Streaming result from rig. /// Streaming result from rig.
#[derive(Debug)] #[derive(Debug)]
pub struct StreamResponse { pub struct StreamResponse {
pub content: String, pub content: String,
pub input_tokens: i64, pub input_tokens: i64,
pub output_tokens: i64, pub output_tokens: i64,
/// Accumulated reasoning/thinking text from the model.
pub reasoning_content: String,
/// Full tool calls with accumulated arguments (not just names) /// Full tool calls with accumulated arguments (not just names)
pub tool_calls: Vec<StreamedToolCall>, pub tool_calls: Vec<StreamedToolCall>,
/// All chunks in arrival order — preserves think/answer/tool interleaving.
pub chunks: Vec<StreamChunk>,
} }
/// Run a streaming chat completion. /// Async callback: takes a string delta and broadcasts it to the WebSocket.
/// The returned Future must be awaited by the caller.
pub type StreamTextCb = Arc<dyn Fn(&str) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync>;
pub type StreamReasoningCb = Arc<dyn Fn(&str) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync>;
pub type StreamToolCallCb = Arc<dyn Fn(&StreamedToolCall) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync>;
/// Run a streaming chat completion with 60s timeout and 5 retries.
pub async fn call_stream( pub async fn call_stream(
messages: &[ChatRequestMessage], messages: &[ChatRequestMessage],
model_name: &str, model_name: &str,
@ -469,7 +498,53 @@ pub async fn call_stream(
temperature: f32, temperature: f32,
max_tokens: u32, max_tokens: u32,
tools: Option<&[serde_json::Value]>, tools: Option<&[serde_json::Value]>,
mut on_text_delta: impl FnMut(&str), on_text_delta: StreamTextCb,
on_reasoning_delta: StreamReasoningCb,
on_tool_call: StreamToolCallCb,
) -> Result<StreamResponse> {
let mut state = RetryState::new(5);
loop {
let result = call_stream_once(
messages, model_name, config, temperature, max_tokens, tools,
on_text_delta.clone(), on_reasoning_delta.clone(), on_tool_call.clone(),
)
.await;
match result {
Ok(response) => return Ok(response),
Err(ref err) if state.should_retry() && is_retryable_error(err) => {
let duration = state.backoff_duration();
tracing::warn!(
attempt = state.attempt + 1,
max_retries = 5,
backoff_ms = duration.as_millis() as u64,
model = %model_name,
error = %err,
"ai_stream_retry"
);
tokio::time::sleep(duration).await;
state.next();
}
Err(err) => {
ai_metrics().record_failure();
return Err(err);
}
}
}
}
/// Single attempt of streaming completion with 60s timeout.
async fn call_stream_once(
messages: &[ChatRequestMessage],
model_name: &str,
config: &AiClientConfig,
temperature: f32,
max_tokens: u32,
tools: Option<&[serde_json::Value]>,
on_text_delta: StreamTextCb,
on_reasoning_delta: StreamReasoningCb,
on_tool_call: StreamToolCallCb,
) -> Result<StreamResponse> { ) -> Result<StreamResponse> {
let client = config.build_rig_client(); let client = config.build_rig_client();
let model = client.completion_model(model_name); let model = client.completion_model(model_name);
@ -506,107 +581,144 @@ pub async fn call_stream(
builder = builder.tools(tool_defs); builder = builder.tools(tool_defs);
} }
let mut stream = builder let stream_fut = async {
.stream() let mut stream = builder
.await .stream()
.map_err(|e| AgentError::OpenAi(e.to_string()))?; .await
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
let mut content = String::new(); let mut content = String::new();
let mut tool_calls: Vec<StreamedToolCall> = Vec::new(); let mut reasoning_content = String::new();
let mut tool_calls: Vec<StreamedToolCall> = Vec::new();
let mut chunks: Vec<StreamChunk> = Vec::new();
// Track partial tool calls by internal_call_id for argument accumulation use std::collections::HashMap;
use std::collections::HashMap; let mut partial_tool_calls: HashMap<String, StreamedToolCall> = HashMap::new();
let mut partial_tool_calls: HashMap<String, StreamedToolCall> = HashMap::new(); let mut stream_finished = false;
let mut stream_finished = false;
use rig::streaming::StreamedAssistantContent; use rig::streaming::StreamedAssistantContent;
while let Some(item) = stream.next().await { while let Some(item) = stream.next().await {
match item { match item {
Ok(StreamedAssistantContent::Text(text)) => { Ok(StreamedAssistantContent::Text(text)) => {
content.push_str(&text.text); content.push_str(&text.text);
on_text_delta(&text.text); on_text_delta(&text.text).await;
} chunks.push(StreamChunk {
Ok(StreamedAssistantContent::ToolCall { chunk_type: StreamChunkType::Answer,
tool_call, content: text.text,
internal_call_id, });
}) => { }
// Complete tool call - extract arguments from the JSON Value Ok(StreamedAssistantContent::ToolCall {
let arguments = match &tool_call.function.arguments { tool_call,
serde_json::Value::String(s) => s.clone(), internal_call_id,
other => serde_json::to_string(other).unwrap_or_else(|_| "{}".to_string()), }) => {
}; let arguments = match &tool_call.function.arguments {
tool_calls.push(StreamedToolCall { serde_json::Value::String(s) => s.clone(),
id: tool_call.id.clone(), other => serde_json::to_string(other).unwrap_or_else(|_| "{}".to_string()),
name: tool_call.function.name.clone(), };
arguments, let tc = StreamedToolCall {
}); id: tool_call.id.clone(),
// Remove from partial if it was being accumulated name: tool_call.function.name.clone(),
partial_tool_calls.remove(&internal_call_id); arguments,
} };
Ok(StreamedAssistantContent::ToolCallDelta { on_tool_call(&tc).await;
id, chunks.push(StreamChunk {
internal_call_id, chunk_type: StreamChunkType::ToolCall,
content, content: serde_json::json!({
}) => { "id": tc.id,
use rig::streaming::ToolCallDeltaContent; "name": tc.name,
match content { "arguments": tc.arguments,
ToolCallDeltaContent::Name(name) => { }).to_string(),
// Start accumulating a new tool call });
partial_tool_calls.insert( tool_calls.push(tc);
internal_call_id.clone(), partial_tool_calls.remove(&internal_call_id);
StreamedToolCall { }
id: id.clone(), Ok(StreamedAssistantContent::ToolCallDelta {
name, id,
arguments: String::new(), internal_call_id,
}, content: delta_content,
); }) => {
} use rig::streaming::ToolCallDeltaContent;
ToolCallDeltaContent::Delta(delta) => { match delta_content {
// Append to existing partial tool call ToolCallDeltaContent::Name(name) => {
if let Some(tc) = partial_tool_calls.get_mut(&internal_call_id) { partial_tool_calls.insert(
tc.arguments.push_str(&delta); internal_call_id.clone(),
StreamedToolCall {
id: id.clone(),
name,
arguments: String::new(),
},
);
}
ToolCallDeltaContent::Delta(delta) => {
if let Some(tc) = partial_tool_calls.get_mut(&internal_call_id) {
tc.arguments.push_str(&delta);
}
} }
} }
} }
} Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
Ok(StreamedAssistantContent::Reasoning(_)) => {} for part in &reasoning.reasoning {
Ok(StreamedAssistantContent::ReasoningDelta { .. }) => {} reasoning_content.push_str(part);
Ok(StreamedAssistantContent::Final(response)) => { on_reasoning_delta(part).await;
stream_finished = true; chunks.push(StreamChunk {
// Flush any remaining partial tool calls chunk_type: StreamChunkType::Thinking,
for (_, tc) in partial_tool_calls.drain() { content: part.clone(),
tool_calls.push(tc); });
}
} }
if let Some(usage) = response.token_usage() { Ok(StreamedAssistantContent::ReasoningDelta { reasoning, .. }) => {
ai_metrics().record_success( reasoning_content.push_str(&reasoning);
usage.input_tokens as i64, on_reasoning_delta(&reasoning).await;
usage.output_tokens as i64, chunks.push(StreamChunk {
!tool_calls.is_empty(), chunk_type: StreamChunkType::Thinking,
); content: reasoning.clone(),
return Ok(StreamResponse {
content,
input_tokens: usage.input_tokens as i64,
output_tokens: usage.output_tokens as i64,
tool_calls,
}); });
} }
Ok(StreamedAssistantContent::Final(response)) => {
stream_finished = true;
for (_, tc) in partial_tool_calls.drain() {
tool_calls.push(tc);
}
if let Some(usage) = response.token_usage() {
let in_toks = usage.input_tokens as i64;
let out_toks = usage.output_tokens as i64;
ai_metrics().record_success(in_toks, out_toks, !tool_calls.is_empty());
return Ok(StreamResponse {
content,
reasoning_content,
input_tokens: in_toks,
output_tokens: out_toks,
tool_calls,
chunks,
});
}
// Usage not available from Final — fall through to flush
}
Err(e) => return Err(AgentError::OpenAi(e.to_string())),
} }
Err(e) => return Err(AgentError::OpenAi(e.to_string())),
} }
}
// Flush any remaining partial tool calls (if stream ended without Final) // Flush any remaining partial tool calls (if stream ended without Final or Final had no usage)
if !stream_finished { if !stream_finished {
for (_, tc) in partial_tool_calls.drain() { for (_, tc) in partial_tool_calls.drain() {
tool_calls.push(tc); tool_calls.push(tc);
}
} }
ai_metrics().record_success(0, 0, !tool_calls.is_empty());
Ok(StreamResponse {
content,
reasoning_content,
input_tokens: 0,
output_tokens: 0,
tool_calls,
chunks,
})
};
// 60s timeout for the entire stream
match tokio::time::timeout(std::time::Duration::from_secs(60), stream_fut).await {
Ok(result) => result,
Err(_) => Err(AgentError::Timeout { task_id: 0, seconds: 60 }),
} }
ai_metrics().record_success(0, 0, !tool_calls.is_empty());
Ok(StreamResponse {
content,
input_tokens: 0,
output_tokens: 0,
tool_calls,
})
} }