feat(agent): add ordered stream chunk collection + retry for HTTP errors

- StreamChunk/StreamChunkType types for preserving arrival order
- Chunk collection in call_stream_once and process_stream
- Add "error sending request" and "Http client error" to retryable errors
- StreamResult includes chunks vector for ordered replay
This commit is contained in:
ZhenYi 2026-04-26 13:10:26 +08:00
parent 0b5dc98ce5
commit b4b5538447
2 changed files with 410 additions and 211 deletions

View File

@ -1,4 +1,5 @@
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use models::projects::project_skill; use models::projects::project_skill;
use models::rooms::room_ai; use models::rooms::room_ai;
@ -9,7 +10,7 @@ use super::context::RoomMessageContext;
use super::{AiChunkType, AiRequest, AiStreamChunk, Mention, StreamCallback}; use super::{AiChunkType, AiRequest, AiStreamChunk, Mention, StreamCallback};
use crate::client::types::{ChatRequestMessage, ToolCall}; use crate::client::types::{ChatRequestMessage, ToolCall};
use crate::client::AiClientConfig; use crate::client::AiClientConfig;
use crate::client::{call_stream, call_with_params}; use crate::client::{call_stream, call_with_params, StreamChunk, StreamChunkType, StreamedToolCall};
use crate::compact::{CompactConfig, CompactService}; use crate::compact::{CompactConfig, CompactService};
use crate::embed::EmbedService; use crate::embed::EmbedService;
use crate::error::{AgentError, Result}; use crate::error::{AgentError, Result};
@ -17,6 +18,23 @@ use crate::perception::{PerceptionService, SkillEntry, ToolCallEvent};
use crate::react::{ReactAgent, ReactConfig, DEFAULT_SYSTEM_PROMPT}; use crate::react::{ReactAgent, ReactConfig, DEFAULT_SYSTEM_PROMPT};
use crate::tool::{ToolCall as AgentToolCall, ToolContext, ToolExecutor, ToolResult, registry::ToolRegistry}; use crate::tool::{ToolCall as AgentToolCall, ToolContext, ToolExecutor, ToolResult, registry::ToolRegistry};
/// Result from streaming AI response.
pub struct StreamResult {
pub content: String,
pub reasoning_content: String,
pub input_tokens: i64,
pub output_tokens: i64,
/// All chunks in arrival order — preserves ReAct multi-cycle ordering.
pub chunks: Vec<StreamChunk>,
}
/// Result from non-streaming AI response.
pub struct ProcessResult {
pub content: String,
pub input_tokens: i64,
pub output_tokens: i64,
}
/// Service for handling AI chat requests in rooms. /// Service for handling AI chat requests in rooms.
pub struct ChatService { pub struct ChatService {
ai_base_url: Option<String>, ai_base_url: Option<String>,
@ -97,7 +115,7 @@ impl ChatService {
self.tool_registry.as_ref() self.tool_registry.as_ref()
} }
pub async fn process(&self, request: AiRequest) -> Result<String> { pub async fn process(&self, request: AiRequest) -> Result<ProcessResult> {
let tools: Vec<serde_json::Value> = request.tools.clone().unwrap_or_default(); let tools: Vec<serde_json::Value> = request.tools.clone().unwrap_or_default();
let tools_enabled = !tools.is_empty(); let tools_enabled = !tools.is_empty();
let max_tool_depth = request.max_tool_depth; let max_tool_depth = request.max_tool_depth;
@ -120,6 +138,8 @@ impl ChatService {
.and_then(|r| r.max_tokens.map(|v| v as u32)) .and_then(|r| r.max_tokens.map(|v| v as u32))
.unwrap_or(request.max_tokens as u32); .unwrap_or(request.max_tokens as u32);
let mut tool_depth = 0; let mut tool_depth = 0;
let mut input_tokens = 0i64;
let mut output_tokens = 0i64;
let config = AiClientConfig::new( let config = AiClientConfig::new(
self.ai_api_key.clone().unwrap_or_default(), self.ai_api_key.clone().unwrap_or_default(),
@ -140,6 +160,8 @@ impl ChatService {
.await?; .await?;
let text = response.content.clone(); let text = response.content.clone();
input_tokens += response.input_tokens;
output_tokens += response.output_tokens;
if tools_enabled && !response.tool_calls_finished.is_empty() { if tools_enabled && !response.tool_calls_finished.is_empty() {
// Build assistant message with tool_calls // Build assistant message with tool_calls
@ -176,17 +198,30 @@ impl ChatService {
}) })
.collect(); .collect();
let tool_messages = match self.execute_tool_calls(calls, &request).await { let tool_messages = {
Ok(msgs) => msgs, let mut ctx = ToolContext::new(
request.db.clone(),
request.cache.clone(),
request.config.clone(),
request.room.id,
Some(request.sender.uid),
)
.with_project(request.project.id);
if let Some(ref registry) = self.tool_registry {
ctx.registry_mut().merge(registry.clone());
}
let executor = ToolExecutor::new();
match executor.execute_batch(calls, &mut ctx).await {
Ok(results) => ToolExecutor::to_tool_messages(&results),
Err(e) => { Err(e) => {
let err_msg = format!("[Tool call failed: {}]", e); let err_msg = format!("[Tool call failed: {}]", e);
// Return error as a single tool result per call
response response
.tool_calls_finished .tool_calls_finished
.iter() .iter()
.map(|_| ChatRequestMessage::tool(Uuid::new_v4().to_string(), &err_msg)) .map(|_| ChatRequestMessage::tool(Uuid::new_v4().to_string(), &err_msg))
.collect() .collect()
} }
}
}; };
messages.extend(tool_messages); messages.extend(tool_messages);
@ -225,22 +260,26 @@ impl ChatService {
tool_depth += 1; tool_depth += 1;
if tool_depth >= max_tool_depth { if tool_depth >= max_tool_depth {
if text.is_empty() { let content = if text.is_empty() {
return Ok(format!( format!(
"[AI reached maximum tool depth ({}) — no final answer produced]", "[AI reached maximum tool depth ({}) — no final answer produced]",
max_tool_depth max_tool_depth
)); )
} } else {
return Ok(text); text
};
return Ok(ProcessResult { content, input_tokens, output_tokens });
} }
continue; continue;
} }
return Ok(text); return Ok(ProcessResult { content: text, input_tokens, output_tokens });
} }
} }
pub async fn process_stream(&self, request: AiRequest, on_chunk: StreamCallback) -> Result<String> { pub async fn process_stream(&self, request: AiRequest, on_chunk: StreamCallback) -> Result<StreamResult> {
// Wrap on_chunk in Arc so it can be shared across loop iterations
let on_chunk = Arc::new(on_chunk);
let tools: Vec<serde_json::Value> = request.tools.clone().unwrap_or_default(); let tools: Vec<serde_json::Value> = request.tools.clone().unwrap_or_default();
let tools_enabled = !tools.is_empty(); let tools_enabled = !tools.is_empty();
let max_tool_depth = request.max_tool_depth; let max_tool_depth = request.max_tool_depth;
@ -270,13 +309,15 @@ impl ChatService {
.with_base_url(self.ai_base_url.clone().unwrap_or_else(|| "https://api.openai.com".into())); .with_base_url(self.ai_base_url.clone().unwrap_or_else(|| "https://api.openai.com".into()));
let mut full_content = String::new(); let mut full_content = String::new();
let mut has_called_tools = false; let mut all_chunks: Vec<StreamChunk> = Vec::new();
// Collect tool calls during streaming, push them incrementally after.
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<StreamedToolCall>();
loop { loop {
let chunk_type = if has_called_tools { let on_chunk_cb = on_chunk.clone();
AiChunkType::Answer let on_chunk_cb2 = on_chunk_cb.clone();
} else { let tx_arc = Arc::new(tx.clone());
AiChunkType::Thinking let tx_arc2 = tx_arc.clone();
};
let response = call_stream( let response = call_stream(
&messages, &messages,
&model_name, &model_name,
@ -284,18 +325,36 @@ impl ChatService {
temperature, temperature,
max_tokens, max_tokens,
if tools_enabled { Some(&tools) } else { None }, if tools_enabled { Some(&tools) } else { None },
|delta| { Arc::new(move |delta| {
let _ = on_chunk(AiStreamChunk { let fut = on_chunk_cb(AiStreamChunk {
content: delta.to_string(), content: delta.to_string(),
done: false, done: false,
chunk_type: chunk_type.clone(), chunk_type: AiChunkType::Answer,
}); });
}, fut
}),
Arc::new(move |delta| {
let fut = on_chunk_cb2(AiStreamChunk {
content: delta.to_string(),
done: false,
chunk_type: AiChunkType::Thinking,
});
fut
}),
Arc::new(move |tc: &StreamedToolCall| {
let tx = tx_arc2.clone();
let tc_owned = tc.clone();
Box::pin(async move {
let _ = tx.send(tc_owned);
}) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
}),
) )
.await?; .await?;
// Collect chunks from this streaming iteration in order.
all_chunks.extend(response.chunks);
let has_tool_calls = tools_enabled && !response.tool_calls.is_empty(); let has_tool_calls = tools_enabled && !response.tool_calls.is_empty();
has_called_tools = true;
if has_tool_calls { if has_tool_calls {
// Accumulate the assistant's text before tool calls // Accumulate the assistant's text before tool calls
@ -321,28 +380,34 @@ impl ChatService {
Some(tool_calls.clone()), Some(tool_calls.clone()),
)); ));
// Stream tool call summary to frontend // Push each tool call incrementally to frontend.
let call_summary: Vec<String> = response // Use try_recv() — tx is never dropped so recv() would deadlock.
.tool_calls loop {
.iter() match rx.try_recv() {
.map(|tc| { Ok(tc) => {
// Truncate long arguments for display
let args_display = if tc.arguments.len() > 100 { let args_display = if tc.arguments.len() > 100 {
format!("{}...", &tc.arguments[..100]) format!("{}...", &tc.arguments[..100])
} else { } else {
tc.arguments.clone() tc.arguments.clone()
}; };
format!("{}({})", tc.name, args_display) let tool_display = format!("🔧 {}({})", tc.name, args_display);
})
.collect();
on_chunk(AiStreamChunk { on_chunk(AiStreamChunk {
content: format!("[Calling tools: {}]", call_summary.join(", ")), content: tool_display.clone(),
done: false, done: false,
chunk_type: AiChunkType::ToolCall, chunk_type: AiChunkType::ToolCall,
}) })
.await; .await;
all_chunks.push(StreamChunk {
chunk_type: StreamChunkType::ToolCall,
content: tool_display,
});
}
Err(tokio::sync::mpsc::error::TryRecvError::Empty) => break,
Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => break,
}
}
// Execute tools with full arguments from streaming // Execute tools one at a time, push each result incrementally
let calls: Vec<AgentToolCall> = response let calls: Vec<AgentToolCall> = response
.tool_calls .tool_calls
.iter() .iter()
@ -353,43 +418,71 @@ impl ChatService {
}) })
.collect(); .collect();
let tool_messages = match self.execute_tool_calls(calls, &request).await { let mut tool_messages = Vec::new();
Ok(msgs) => { for call in &calls {
let result_summary: Vec<String> = msgs let ctx = &mut crate::tool::ToolContext::new(
.iter() request.db.clone(),
.map(|m| { request.cache.clone(),
let text = m.content.as_deref().unwrap_or("[no content]"); request.config.clone(),
if text.len() > 300 { request.room.id,
format!("{}...", &text[..300]) Some(request.sender.uid),
} else { );
text.to_string() if let Some(ref registry) = self.tool_registry {
} ctx.registry_mut().merge(registry.clone());
})
.collect();
on_chunk(AiStreamChunk {
content: format!("[Tool results: {}]", result_summary.join("; ")),
done: false,
chunk_type: AiChunkType::ToolResult,
})
.await;
msgs
} }
let executor = crate::tool::ToolExecutor::new();
let results = match executor.execute_batch(vec![call.clone()], ctx).await {
Ok(r) => r,
Err(e) => { Err(e) => {
let err_text = format!("[Tool call failed: {}]", e); let err_text = format!("[Tool call failed: {}]", e);
tracing::warn!(tool = %call.name, error = %e, "tool_call_failed");
// Do NOT emit tool_result chunks to frontend — show error via tool_call instead
let err_display = format!("{} (failed)", call.name);
on_chunk(AiStreamChunk { on_chunk(AiStreamChunk {
content: err_text.clone(), content: err_display.clone(),
done: false, done: false,
chunk_type: AiChunkType::ToolResult, chunk_type: AiChunkType::ToolCall,
}) })
.await; .await;
// Return error tool messages all_chunks.push(StreamChunk {
response chunk_type: StreamChunkType::ToolCall,
.tool_calls content: err_display,
.iter() });
.map(|tc| ChatRequestMessage::tool(&tc.id, &err_text)) tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text));
.collect() continue;
} }
}; };
for result in &results {
let text = match &result.result {
crate::tool::ToolResult::Ok(v) => v.to_string(),
crate::tool::ToolResult::Error(msg) => msg.clone(),
};
let preview = if text.len() > 300 {
format!("{}...", &text[..300])
} else {
text.clone()
};
tracing::debug!("tool_result: {} — {}", call.name, preview);
// Do NOT emit tool_result chunks to frontend — raw output may contain sensitive data.
// Log server-side only; frontend sees tool_call status via on_chunk below.
}
let success_display = format!("{}", call.name);
on_chunk(AiStreamChunk {
content: success_display.clone(),
done: false,
chunk_type: AiChunkType::ToolCall,
})
.await;
all_chunks.push(StreamChunk {
chunk_type: StreamChunkType::ToolCall,
content: success_display,
});
let msgs = crate::tool::ToolExecutor::to_tool_messages(&results);
tool_messages.extend(msgs);
}
messages.extend(tool_messages); messages.extend(tool_messages);
// Inject passive-detected skills based on tool calls // Inject passive-detected skills based on tool calls
@ -427,60 +520,54 @@ impl ChatService {
tool_depth += 1; tool_depth += 1;
if tool_depth >= max_tool_depth { if tool_depth >= max_tool_depth {
on_chunk(AiStreamChunk { let max_depth_text = format!(
content: format!(
"[AI reached maximum tool depth ({}) — no final answer produced]", "[AI reached maximum tool depth ({}) — no final answer produced]",
max_tool_depth max_tool_depth
), );
on_chunk(AiStreamChunk {
content: max_depth_text.clone(),
done: true, done: true,
chunk_type: AiChunkType::Answer, chunk_type: AiChunkType::Answer,
}) })
.await; .await;
return Ok(full_content); all_chunks.push(StreamChunk {
chunk_type: StreamChunkType::Answer,
content: max_depth_text,
});
return Ok(StreamResult {
content: full_content,
reasoning_content: String::new(),
input_tokens: 0,
output_tokens: 0,
chunks: all_chunks,
});
} }
continue; continue;
} }
// Final answer — accumulate and return // Final answer — accumulate and return
full_content.push_str(&response.content); full_content.push_str(&response.content);
on_chunk(AiStreamChunk { on_chunk(AiStreamChunk {
content: response.content, content: response.content.clone(),
done: true, done: true,
chunk_type: AiChunkType::Answer, chunk_type: AiChunkType::Answer,
}) })
.await; .await;
return Ok(full_content); all_chunks.push(StreamChunk {
chunk_type: StreamChunkType::Answer,
content: response.content.clone(),
});
return Ok(StreamResult {
content: full_content,
reasoning_content: response.reasoning_content,
input_tokens: response.input_tokens,
output_tokens: response.output_tokens,
chunks: all_chunks,
});
} }
} }
/// Executes a batch of tool calls and returns the tool result messages.
async fn execute_tool_calls(
&self,
calls: Vec<AgentToolCall>,
request: &AiRequest,
) -> Result<Vec<ChatRequestMessage>> {
let mut ctx = ToolContext::new(
request.db.clone(),
request.cache.clone(),
request.config.clone(),
request.room.id,
Some(request.sender.uid),
)
.with_project(request.project.id);
if let Some(ref registry) = self.tool_registry {
ctx.registry_mut().merge(registry.clone());
}
let executor = ToolExecutor::new();
let results = executor
.execute_batch(calls, &mut ctx)
.await
.map_err(|e| AgentError::Internal(e.to_string()))?;
Ok(ToolExecutor::to_tool_messages(&results))
}
async fn build_messages(&self, request: &AiRequest) -> Result<Vec<ChatRequestMessage>> { async fn build_messages(&self, request: &AiRequest) -> Result<Vec<ChatRequestMessage>> {
let mut messages = Vec::new(); let mut messages = Vec::new();

View File

@ -5,6 +5,8 @@
pub mod types; pub mod types;
pub use types::{ChatRequestMessage, ToolCall as ClientToolCall}; pub use types::{ChatRequestMessage, ToolCall as ClientToolCall};
use std::pin::Pin;
use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use uuid::Uuid; use uuid::Uuid;
@ -130,6 +132,8 @@ fn is_retryable_error(err: &AgentError) -> bool {
|| msg.contains("connection timed out") || msg.contains("connection timed out")
|| msg.contains("network error") || msg.contains("network error")
|| msg.contains("dns error") || msg.contains("dns error")
|| msg.contains("error sending request")
|| msg.contains("Http client error")
|| msg.contains("rate_limit") || msg.contains("rate_limit")
|| msg.contains("rate limit") || msg.contains("rate limit")
|| msg.contains("429") || msg.contains("429")
@ -451,17 +455,42 @@ pub struct StreamedToolCall {
pub arguments: String, pub arguments: String,
} }
/// Type of chunk in the streaming response, preserving arrival order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamChunkType {
Thinking,
Answer,
ToolCall,
}
/// A single chunk from the streaming response in arrival order.
#[derive(Debug, Clone)]
pub struct StreamChunk {
pub chunk_type: StreamChunkType,
pub content: String,
}
/// Streaming result from rig. /// Streaming result from rig.
#[derive(Debug)] #[derive(Debug)]
pub struct StreamResponse { pub struct StreamResponse {
pub content: String, pub content: String,
pub input_tokens: i64, pub input_tokens: i64,
pub output_tokens: i64, pub output_tokens: i64,
/// Accumulated reasoning/thinking text from the model.
pub reasoning_content: String,
/// Full tool calls with accumulated arguments (not just names) /// Full tool calls with accumulated arguments (not just names)
pub tool_calls: Vec<StreamedToolCall>, pub tool_calls: Vec<StreamedToolCall>,
/// All chunks in arrival order — preserves think/answer/tool interleaving.
pub chunks: Vec<StreamChunk>,
} }
/// Run a streaming chat completion. /// Async callback: takes a string delta and broadcasts it to the WebSocket.
/// The returned Future must be awaited by the caller.
pub type StreamTextCb = Arc<dyn Fn(&str) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync>;
pub type StreamReasoningCb = Arc<dyn Fn(&str) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync>;
pub type StreamToolCallCb = Arc<dyn Fn(&StreamedToolCall) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync>;
/// Run a streaming chat completion with 60s timeout and 5 retries.
pub async fn call_stream( pub async fn call_stream(
messages: &[ChatRequestMessage], messages: &[ChatRequestMessage],
model_name: &str, model_name: &str,
@ -469,7 +498,53 @@ pub async fn call_stream(
temperature: f32, temperature: f32,
max_tokens: u32, max_tokens: u32,
tools: Option<&[serde_json::Value]>, tools: Option<&[serde_json::Value]>,
mut on_text_delta: impl FnMut(&str), on_text_delta: StreamTextCb,
on_reasoning_delta: StreamReasoningCb,
on_tool_call: StreamToolCallCb,
) -> Result<StreamResponse> {
let mut state = RetryState::new(5);
loop {
let result = call_stream_once(
messages, model_name, config, temperature, max_tokens, tools,
on_text_delta.clone(), on_reasoning_delta.clone(), on_tool_call.clone(),
)
.await;
match result {
Ok(response) => return Ok(response),
Err(ref err) if state.should_retry() && is_retryable_error(err) => {
let duration = state.backoff_duration();
tracing::warn!(
attempt = state.attempt + 1,
max_retries = 5,
backoff_ms = duration.as_millis() as u64,
model = %model_name,
error = %err,
"ai_stream_retry"
);
tokio::time::sleep(duration).await;
state.next();
}
Err(err) => {
ai_metrics().record_failure();
return Err(err);
}
}
}
}
/// Single attempt of streaming completion with 60s timeout.
async fn call_stream_once(
messages: &[ChatRequestMessage],
model_name: &str,
config: &AiClientConfig,
temperature: f32,
max_tokens: u32,
tools: Option<&[serde_json::Value]>,
on_text_delta: StreamTextCb,
on_reasoning_delta: StreamReasoningCb,
on_tool_call: StreamToolCallCb,
) -> Result<StreamResponse> { ) -> Result<StreamResponse> {
let client = config.build_rig_client(); let client = config.build_rig_client();
let model = client.completion_model(model_name); let model = client.completion_model(model_name);
@ -506,15 +581,17 @@ pub async fn call_stream(
builder = builder.tools(tool_defs); builder = builder.tools(tool_defs);
} }
let stream_fut = async {
let mut stream = builder let mut stream = builder
.stream() .stream()
.await .await
.map_err(|e| AgentError::OpenAi(e.to_string()))?; .map_err(|e| AgentError::OpenAi(e.to_string()))?;
let mut content = String::new(); let mut content = String::new();
let mut reasoning_content = String::new();
let mut tool_calls: Vec<StreamedToolCall> = Vec::new(); let mut tool_calls: Vec<StreamedToolCall> = Vec::new();
let mut chunks: Vec<StreamChunk> = Vec::new();
// Track partial tool calls by internal_call_id for argument accumulation
use std::collections::HashMap; use std::collections::HashMap;
let mut partial_tool_calls: HashMap<String, StreamedToolCall> = HashMap::new(); let mut partial_tool_calls: HashMap<String, StreamedToolCall> = HashMap::new();
let mut stream_finished = false; let mut stream_finished = false;
@ -525,34 +602,45 @@ pub async fn call_stream(
match item { match item {
Ok(StreamedAssistantContent::Text(text)) => { Ok(StreamedAssistantContent::Text(text)) => {
content.push_str(&text.text); content.push_str(&text.text);
on_text_delta(&text.text); on_text_delta(&text.text).await;
chunks.push(StreamChunk {
chunk_type: StreamChunkType::Answer,
content: text.text,
});
} }
Ok(StreamedAssistantContent::ToolCall { Ok(StreamedAssistantContent::ToolCall {
tool_call, tool_call,
internal_call_id, internal_call_id,
}) => { }) => {
// Complete tool call - extract arguments from the JSON Value
let arguments = match &tool_call.function.arguments { let arguments = match &tool_call.function.arguments {
serde_json::Value::String(s) => s.clone(), serde_json::Value::String(s) => s.clone(),
other => serde_json::to_string(other).unwrap_or_else(|_| "{}".to_string()), other => serde_json::to_string(other).unwrap_or_else(|_| "{}".to_string()),
}; };
tool_calls.push(StreamedToolCall { let tc = StreamedToolCall {
id: tool_call.id.clone(), id: tool_call.id.clone(),
name: tool_call.function.name.clone(), name: tool_call.function.name.clone(),
arguments, arguments,
};
on_tool_call(&tc).await;
chunks.push(StreamChunk {
chunk_type: StreamChunkType::ToolCall,
content: serde_json::json!({
"id": tc.id,
"name": tc.name,
"arguments": tc.arguments,
}).to_string(),
}); });
// Remove from partial if it was being accumulated tool_calls.push(tc);
partial_tool_calls.remove(&internal_call_id); partial_tool_calls.remove(&internal_call_id);
} }
Ok(StreamedAssistantContent::ToolCallDelta { Ok(StreamedAssistantContent::ToolCallDelta {
id, id,
internal_call_id, internal_call_id,
content, content: delta_content,
}) => { }) => {
use rig::streaming::ToolCallDeltaContent; use rig::streaming::ToolCallDeltaContent;
match content { match delta_content {
ToolCallDeltaContent::Name(name) => { ToolCallDeltaContent::Name(name) => {
// Start accumulating a new tool call
partial_tool_calls.insert( partial_tool_calls.insert(
internal_call_id.clone(), internal_call_id.clone(),
StreamedToolCall { StreamedToolCall {
@ -563,40 +651,55 @@ pub async fn call_stream(
); );
} }
ToolCallDeltaContent::Delta(delta) => { ToolCallDeltaContent::Delta(delta) => {
// Append to existing partial tool call
if let Some(tc) = partial_tool_calls.get_mut(&internal_call_id) { if let Some(tc) = partial_tool_calls.get_mut(&internal_call_id) {
tc.arguments.push_str(&delta); tc.arguments.push_str(&delta);
} }
} }
} }
} }
Ok(StreamedAssistantContent::Reasoning(_)) => {} Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
Ok(StreamedAssistantContent::ReasoningDelta { .. }) => {} for part in &reasoning.reasoning {
reasoning_content.push_str(part);
on_reasoning_delta(part).await;
chunks.push(StreamChunk {
chunk_type: StreamChunkType::Thinking,
content: part.clone(),
});
}
}
Ok(StreamedAssistantContent::ReasoningDelta { reasoning, .. }) => {
reasoning_content.push_str(&reasoning);
on_reasoning_delta(&reasoning).await;
chunks.push(StreamChunk {
chunk_type: StreamChunkType::Thinking,
content: reasoning.clone(),
});
}
Ok(StreamedAssistantContent::Final(response)) => { Ok(StreamedAssistantContent::Final(response)) => {
stream_finished = true; stream_finished = true;
// Flush any remaining partial tool calls
for (_, tc) in partial_tool_calls.drain() { for (_, tc) in partial_tool_calls.drain() {
tool_calls.push(tc); tool_calls.push(tc);
} }
if let Some(usage) = response.token_usage() { if let Some(usage) = response.token_usage() {
ai_metrics().record_success( let in_toks = usage.input_tokens as i64;
usage.input_tokens as i64, let out_toks = usage.output_tokens as i64;
usage.output_tokens as i64, ai_metrics().record_success(in_toks, out_toks, !tool_calls.is_empty());
!tool_calls.is_empty(),
);
return Ok(StreamResponse { return Ok(StreamResponse {
content, content,
input_tokens: usage.input_tokens as i64, reasoning_content,
output_tokens: usage.output_tokens as i64, input_tokens: in_toks,
output_tokens: out_toks,
tool_calls, tool_calls,
chunks,
}); });
} }
// Usage not available from Final — fall through to flush
} }
Err(e) => return Err(AgentError::OpenAi(e.to_string())), Err(e) => return Err(AgentError::OpenAi(e.to_string())),
} }
} }
// Flush any remaining partial tool calls (if stream ended without Final) // Flush any remaining partial tool calls (if stream ended without Final or Final had no usage)
if !stream_finished { if !stream_finished {
for (_, tc) in partial_tool_calls.drain() { for (_, tc) in partial_tool_calls.drain() {
tool_calls.push(tc); tool_calls.push(tc);
@ -605,8 +708,17 @@ pub async fn call_stream(
ai_metrics().record_success(0, 0, !tool_calls.is_empty()); ai_metrics().record_success(0, 0, !tool_calls.is_empty());
Ok(StreamResponse { Ok(StreamResponse {
content, content,
reasoning_content,
input_tokens: 0, input_tokens: 0,
output_tokens: 0, output_tokens: 0,
tool_calls, tool_calls,
chunks,
}) })
};
// 60s timeout for the entire stream
match tokio::time::timeout(std::time::Duration::from_secs(60), stream_fut).await {
Ok(result) => result,
Err(_) => Err(AgentError::Timeout { task_id: 0, seconds: 60 }),
}
} }