feat(agent): add ordered stream chunk collection + retry for HTTP errors
- StreamChunk/StreamChunkType types for preserving arrival order - Chunk collection in call_stream_once and process_stream - Add "error sending request" and "Http client error" to retryable errors - StreamResult includes chunks vector for ordered replay
This commit is contained in:
parent
0b5dc98ce5
commit
b4b5538447
@ -1,4 +1,5 @@
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use models::projects::project_skill;
|
||||
use models::rooms::room_ai;
|
||||
@ -9,7 +10,7 @@ use super::context::RoomMessageContext;
|
||||
use super::{AiChunkType, AiRequest, AiStreamChunk, Mention, StreamCallback};
|
||||
use crate::client::types::{ChatRequestMessage, ToolCall};
|
||||
use crate::client::AiClientConfig;
|
||||
use crate::client::{call_stream, call_with_params};
|
||||
use crate::client::{call_stream, call_with_params, StreamChunk, StreamChunkType, StreamedToolCall};
|
||||
use crate::compact::{CompactConfig, CompactService};
|
||||
use crate::embed::EmbedService;
|
||||
use crate::error::{AgentError, Result};
|
||||
@ -17,6 +18,23 @@ use crate::perception::{PerceptionService, SkillEntry, ToolCallEvent};
|
||||
use crate::react::{ReactAgent, ReactConfig, DEFAULT_SYSTEM_PROMPT};
|
||||
use crate::tool::{ToolCall as AgentToolCall, ToolContext, ToolExecutor, ToolResult, registry::ToolRegistry};
|
||||
|
||||
/// Result from streaming AI response.
|
||||
pub struct StreamResult {
|
||||
pub content: String,
|
||||
pub reasoning_content: String,
|
||||
pub input_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
/// All chunks in arrival order — preserves ReAct multi-cycle ordering.
|
||||
pub chunks: Vec<StreamChunk>,
|
||||
}
|
||||
|
||||
/// Result from non-streaming AI response.
|
||||
pub struct ProcessResult {
|
||||
pub content: String,
|
||||
pub input_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
}
|
||||
|
||||
/// Service for handling AI chat requests in rooms.
|
||||
pub struct ChatService {
|
||||
ai_base_url: Option<String>,
|
||||
@ -97,7 +115,7 @@ impl ChatService {
|
||||
self.tool_registry.as_ref()
|
||||
}
|
||||
|
||||
pub async fn process(&self, request: AiRequest) -> Result<String> {
|
||||
pub async fn process(&self, request: AiRequest) -> Result<ProcessResult> {
|
||||
let tools: Vec<serde_json::Value> = request.tools.clone().unwrap_or_default();
|
||||
let tools_enabled = !tools.is_empty();
|
||||
let max_tool_depth = request.max_tool_depth;
|
||||
@ -120,6 +138,8 @@ impl ChatService {
|
||||
.and_then(|r| r.max_tokens.map(|v| v as u32))
|
||||
.unwrap_or(request.max_tokens as u32);
|
||||
let mut tool_depth = 0;
|
||||
let mut input_tokens = 0i64;
|
||||
let mut output_tokens = 0i64;
|
||||
|
||||
let config = AiClientConfig::new(
|
||||
self.ai_api_key.clone().unwrap_or_default(),
|
||||
@ -140,6 +160,8 @@ impl ChatService {
|
||||
.await?;
|
||||
|
||||
let text = response.content.clone();
|
||||
input_tokens += response.input_tokens;
|
||||
output_tokens += response.output_tokens;
|
||||
|
||||
if tools_enabled && !response.tool_calls_finished.is_empty() {
|
||||
// Build assistant message with tool_calls
|
||||
@ -176,17 +198,30 @@ impl ChatService {
|
||||
})
|
||||
.collect();
|
||||
|
||||
let tool_messages = match self.execute_tool_calls(calls, &request).await {
|
||||
Ok(msgs) => msgs,
|
||||
let tool_messages = {
|
||||
let mut ctx = ToolContext::new(
|
||||
request.db.clone(),
|
||||
request.cache.clone(),
|
||||
request.config.clone(),
|
||||
request.room.id,
|
||||
Some(request.sender.uid),
|
||||
)
|
||||
.with_project(request.project.id);
|
||||
if let Some(ref registry) = self.tool_registry {
|
||||
ctx.registry_mut().merge(registry.clone());
|
||||
}
|
||||
let executor = ToolExecutor::new();
|
||||
match executor.execute_batch(calls, &mut ctx).await {
|
||||
Ok(results) => ToolExecutor::to_tool_messages(&results),
|
||||
Err(e) => {
|
||||
let err_msg = format!("[Tool call failed: {}]", e);
|
||||
// Return error as a single tool result per call
|
||||
response
|
||||
.tool_calls_finished
|
||||
.iter()
|
||||
.map(|_| ChatRequestMessage::tool(Uuid::new_v4().to_string(), &err_msg))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
};
|
||||
messages.extend(tool_messages);
|
||||
|
||||
@ -225,22 +260,26 @@ impl ChatService {
|
||||
|
||||
tool_depth += 1;
|
||||
if tool_depth >= max_tool_depth {
|
||||
if text.is_empty() {
|
||||
return Ok(format!(
|
||||
let content = if text.is_empty() {
|
||||
format!(
|
||||
"[AI reached maximum tool depth ({}) — no final answer produced]",
|
||||
max_tool_depth
|
||||
));
|
||||
}
|
||||
return Ok(text);
|
||||
)
|
||||
} else {
|
||||
text
|
||||
};
|
||||
return Ok(ProcessResult { content, input_tokens, output_tokens });
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
return Ok(text);
|
||||
return Ok(ProcessResult { content: text, input_tokens, output_tokens });
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn process_stream(&self, request: AiRequest, on_chunk: StreamCallback) -> Result<String> {
|
||||
pub async fn process_stream(&self, request: AiRequest, on_chunk: StreamCallback) -> Result<StreamResult> {
|
||||
// Wrap on_chunk in Arc so it can be shared across loop iterations
|
||||
let on_chunk = Arc::new(on_chunk);
|
||||
let tools: Vec<serde_json::Value> = request.tools.clone().unwrap_or_default();
|
||||
let tools_enabled = !tools.is_empty();
|
||||
let max_tool_depth = request.max_tool_depth;
|
||||
@ -270,13 +309,15 @@ impl ChatService {
|
||||
.with_base_url(self.ai_base_url.clone().unwrap_or_else(|| "https://api.openai.com".into()));
|
||||
|
||||
let mut full_content = String::new();
|
||||
let mut has_called_tools = false;
|
||||
let mut all_chunks: Vec<StreamChunk> = Vec::new();
|
||||
// Collect tool calls during streaming, push them incrementally after.
|
||||
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<StreamedToolCall>();
|
||||
|
||||
loop {
|
||||
let chunk_type = if has_called_tools {
|
||||
AiChunkType::Answer
|
||||
} else {
|
||||
AiChunkType::Thinking
|
||||
};
|
||||
let on_chunk_cb = on_chunk.clone();
|
||||
let on_chunk_cb2 = on_chunk_cb.clone();
|
||||
let tx_arc = Arc::new(tx.clone());
|
||||
let tx_arc2 = tx_arc.clone();
|
||||
let response = call_stream(
|
||||
&messages,
|
||||
&model_name,
|
||||
@ -284,18 +325,36 @@ impl ChatService {
|
||||
temperature,
|
||||
max_tokens,
|
||||
if tools_enabled { Some(&tools) } else { None },
|
||||
|delta| {
|
||||
let _ = on_chunk(AiStreamChunk {
|
||||
Arc::new(move |delta| {
|
||||
let fut = on_chunk_cb(AiStreamChunk {
|
||||
content: delta.to_string(),
|
||||
done: false,
|
||||
chunk_type: chunk_type.clone(),
|
||||
chunk_type: AiChunkType::Answer,
|
||||
});
|
||||
},
|
||||
fut
|
||||
}),
|
||||
Arc::new(move |delta| {
|
||||
let fut = on_chunk_cb2(AiStreamChunk {
|
||||
content: delta.to_string(),
|
||||
done: false,
|
||||
chunk_type: AiChunkType::Thinking,
|
||||
});
|
||||
fut
|
||||
}),
|
||||
Arc::new(move |tc: &StreamedToolCall| {
|
||||
let tx = tx_arc2.clone();
|
||||
let tc_owned = tc.clone();
|
||||
Box::pin(async move {
|
||||
let _ = tx.send(tc_owned);
|
||||
}) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Collect chunks from this streaming iteration in order.
|
||||
all_chunks.extend(response.chunks);
|
||||
|
||||
let has_tool_calls = tools_enabled && !response.tool_calls.is_empty();
|
||||
has_called_tools = true;
|
||||
|
||||
if has_tool_calls {
|
||||
// Accumulate the assistant's text before tool calls
|
||||
@ -321,28 +380,34 @@ impl ChatService {
|
||||
Some(tool_calls.clone()),
|
||||
));
|
||||
|
||||
// Stream tool call summary to frontend
|
||||
let call_summary: Vec<String> = response
|
||||
.tool_calls
|
||||
.iter()
|
||||
.map(|tc| {
|
||||
// Truncate long arguments for display
|
||||
// Push each tool call incrementally to frontend.
|
||||
// Use try_recv() — tx is never dropped so recv() would deadlock.
|
||||
loop {
|
||||
match rx.try_recv() {
|
||||
Ok(tc) => {
|
||||
let args_display = if tc.arguments.len() > 100 {
|
||||
format!("{}...", &tc.arguments[..100])
|
||||
} else {
|
||||
tc.arguments.clone()
|
||||
};
|
||||
format!("{}({})", tc.name, args_display)
|
||||
})
|
||||
.collect();
|
||||
let tool_display = format!("🔧 {}({})", tc.name, args_display);
|
||||
on_chunk(AiStreamChunk {
|
||||
content: format!("[Calling tools: {}]", call_summary.join(", ")),
|
||||
content: tool_display.clone(),
|
||||
done: false,
|
||||
chunk_type: AiChunkType::ToolCall,
|
||||
})
|
||||
.await;
|
||||
all_chunks.push(StreamChunk {
|
||||
chunk_type: StreamChunkType::ToolCall,
|
||||
content: tool_display,
|
||||
});
|
||||
}
|
||||
Err(tokio::sync::mpsc::error::TryRecvError::Empty) => break,
|
||||
Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => break,
|
||||
}
|
||||
}
|
||||
|
||||
// Execute tools with full arguments from streaming
|
||||
// Execute tools one at a time, push each result incrementally
|
||||
let calls: Vec<AgentToolCall> = response
|
||||
.tool_calls
|
||||
.iter()
|
||||
@ -353,43 +418,71 @@ impl ChatService {
|
||||
})
|
||||
.collect();
|
||||
|
||||
let tool_messages = match self.execute_tool_calls(calls, &request).await {
|
||||
Ok(msgs) => {
|
||||
let result_summary: Vec<String> = msgs
|
||||
.iter()
|
||||
.map(|m| {
|
||||
let text = m.content.as_deref().unwrap_or("[no content]");
|
||||
if text.len() > 300 {
|
||||
format!("{}...", &text[..300])
|
||||
} else {
|
||||
text.to_string()
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
on_chunk(AiStreamChunk {
|
||||
content: format!("[Tool results: {}]", result_summary.join("; ")),
|
||||
done: false,
|
||||
chunk_type: AiChunkType::ToolResult,
|
||||
})
|
||||
.await;
|
||||
msgs
|
||||
let mut tool_messages = Vec::new();
|
||||
for call in &calls {
|
||||
let ctx = &mut crate::tool::ToolContext::new(
|
||||
request.db.clone(),
|
||||
request.cache.clone(),
|
||||
request.config.clone(),
|
||||
request.room.id,
|
||||
Some(request.sender.uid),
|
||||
);
|
||||
if let Some(ref registry) = self.tool_registry {
|
||||
ctx.registry_mut().merge(registry.clone());
|
||||
}
|
||||
|
||||
let executor = crate::tool::ToolExecutor::new();
|
||||
let results = match executor.execute_batch(vec![call.clone()], ctx).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
let err_text = format!("[Tool call failed: {}]", e);
|
||||
tracing::warn!(tool = %call.name, error = %e, "tool_call_failed");
|
||||
// Do NOT emit tool_result chunks to frontend — show error via tool_call instead
|
||||
let err_display = format!("❌ {} (failed)", call.name);
|
||||
on_chunk(AiStreamChunk {
|
||||
content: err_text.clone(),
|
||||
content: err_display.clone(),
|
||||
done: false,
|
||||
chunk_type: AiChunkType::ToolResult,
|
||||
chunk_type: AiChunkType::ToolCall,
|
||||
})
|
||||
.await;
|
||||
// Return error tool messages
|
||||
response
|
||||
.tool_calls
|
||||
.iter()
|
||||
.map(|tc| ChatRequestMessage::tool(&tc.id, &err_text))
|
||||
.collect()
|
||||
all_chunks.push(StreamChunk {
|
||||
chunk_type: StreamChunkType::ToolCall,
|
||||
content: err_display,
|
||||
});
|
||||
tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
for result in &results {
|
||||
let text = match &result.result {
|
||||
crate::tool::ToolResult::Ok(v) => v.to_string(),
|
||||
crate::tool::ToolResult::Error(msg) => msg.clone(),
|
||||
};
|
||||
let preview = if text.len() > 300 {
|
||||
format!("{}...", &text[..300])
|
||||
} else {
|
||||
text.clone()
|
||||
};
|
||||
tracing::debug!("tool_result: {} — {}", call.name, preview);
|
||||
// Do NOT emit tool_result chunks to frontend — raw output may contain sensitive data.
|
||||
// Log server-side only; frontend sees tool_call status via on_chunk below.
|
||||
}
|
||||
let success_display = format!("✅ {}", call.name);
|
||||
on_chunk(AiStreamChunk {
|
||||
content: success_display.clone(),
|
||||
done: false,
|
||||
chunk_type: AiChunkType::ToolCall,
|
||||
})
|
||||
.await;
|
||||
all_chunks.push(StreamChunk {
|
||||
chunk_type: StreamChunkType::ToolCall,
|
||||
content: success_display,
|
||||
});
|
||||
|
||||
let msgs = crate::tool::ToolExecutor::to_tool_messages(&results);
|
||||
tool_messages.extend(msgs);
|
||||
}
|
||||
messages.extend(tool_messages);
|
||||
|
||||
// Inject passive-detected skills based on tool calls
|
||||
@ -427,60 +520,54 @@ impl ChatService {
|
||||
|
||||
tool_depth += 1;
|
||||
if tool_depth >= max_tool_depth {
|
||||
on_chunk(AiStreamChunk {
|
||||
content: format!(
|
||||
let max_depth_text = format!(
|
||||
"[AI reached maximum tool depth ({}) — no final answer produced]",
|
||||
max_tool_depth
|
||||
),
|
||||
);
|
||||
on_chunk(AiStreamChunk {
|
||||
content: max_depth_text.clone(),
|
||||
done: true,
|
||||
chunk_type: AiChunkType::Answer,
|
||||
})
|
||||
.await;
|
||||
return Ok(full_content);
|
||||
all_chunks.push(StreamChunk {
|
||||
chunk_type: StreamChunkType::Answer,
|
||||
content: max_depth_text,
|
||||
});
|
||||
return Ok(StreamResult {
|
||||
content: full_content,
|
||||
reasoning_content: String::new(),
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
chunks: all_chunks,
|
||||
});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Final answer — accumulate and return
|
||||
full_content.push_str(&response.content);
|
||||
|
||||
on_chunk(AiStreamChunk {
|
||||
content: response.content,
|
||||
content: response.content.clone(),
|
||||
done: true,
|
||||
chunk_type: AiChunkType::Answer,
|
||||
})
|
||||
.await;
|
||||
return Ok(full_content);
|
||||
all_chunks.push(StreamChunk {
|
||||
chunk_type: StreamChunkType::Answer,
|
||||
content: response.content.clone(),
|
||||
});
|
||||
return Ok(StreamResult {
|
||||
content: full_content,
|
||||
reasoning_content: response.reasoning_content,
|
||||
input_tokens: response.input_tokens,
|
||||
output_tokens: response.output_tokens,
|
||||
chunks: all_chunks,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Executes a batch of tool calls and returns the tool result messages.
|
||||
async fn execute_tool_calls(
|
||||
&self,
|
||||
calls: Vec<AgentToolCall>,
|
||||
request: &AiRequest,
|
||||
) -> Result<Vec<ChatRequestMessage>> {
|
||||
let mut ctx = ToolContext::new(
|
||||
request.db.clone(),
|
||||
request.cache.clone(),
|
||||
request.config.clone(),
|
||||
request.room.id,
|
||||
Some(request.sender.uid),
|
||||
)
|
||||
.with_project(request.project.id);
|
||||
|
||||
if let Some(ref registry) = self.tool_registry {
|
||||
ctx.registry_mut().merge(registry.clone());
|
||||
}
|
||||
|
||||
let executor = ToolExecutor::new();
|
||||
let results = executor
|
||||
.execute_batch(calls, &mut ctx)
|
||||
.await
|
||||
.map_err(|e| AgentError::Internal(e.to_string()))?;
|
||||
|
||||
Ok(ToolExecutor::to_tool_messages(&results))
|
||||
}
|
||||
|
||||
async fn build_messages(&self, request: &AiRequest) -> Result<Vec<ChatRequestMessage>> {
|
||||
let mut messages = Vec::new();
|
||||
|
||||
|
||||
@ -5,6 +5,8 @@
|
||||
pub mod types;
|
||||
pub use types::{ChatRequestMessage, ToolCall as ClientToolCall};
|
||||
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use uuid::Uuid;
|
||||
|
||||
@ -130,6 +132,8 @@ fn is_retryable_error(err: &AgentError) -> bool {
|
||||
|| msg.contains("connection timed out")
|
||||
|| msg.contains("network error")
|
||||
|| msg.contains("dns error")
|
||||
|| msg.contains("error sending request")
|
||||
|| msg.contains("Http client error")
|
||||
|| msg.contains("rate_limit")
|
||||
|| msg.contains("rate limit")
|
||||
|| msg.contains("429")
|
||||
@ -451,17 +455,42 @@ pub struct StreamedToolCall {
|
||||
pub arguments: String,
|
||||
}
|
||||
|
||||
/// Type of chunk in the streaming response, preserving arrival order.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum StreamChunkType {
|
||||
Thinking,
|
||||
Answer,
|
||||
ToolCall,
|
||||
}
|
||||
|
||||
/// A single chunk from the streaming response in arrival order.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StreamChunk {
|
||||
pub chunk_type: StreamChunkType,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
/// Streaming result from rig.
|
||||
#[derive(Debug)]
|
||||
pub struct StreamResponse {
|
||||
pub content: String,
|
||||
pub input_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
/// Accumulated reasoning/thinking text from the model.
|
||||
pub reasoning_content: String,
|
||||
/// Full tool calls with accumulated arguments (not just names)
|
||||
pub tool_calls: Vec<StreamedToolCall>,
|
||||
/// All chunks in arrival order — preserves think/answer/tool interleaving.
|
||||
pub chunks: Vec<StreamChunk>,
|
||||
}
|
||||
|
||||
/// Run a streaming chat completion.
|
||||
/// Async callback: takes a string delta and broadcasts it to the WebSocket.
|
||||
/// The returned Future must be awaited by the caller.
|
||||
pub type StreamTextCb = Arc<dyn Fn(&str) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync>;
|
||||
pub type StreamReasoningCb = Arc<dyn Fn(&str) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync>;
|
||||
pub type StreamToolCallCb = Arc<dyn Fn(&StreamedToolCall) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync>;
|
||||
|
||||
/// Run a streaming chat completion with 60s timeout and 5 retries.
|
||||
pub async fn call_stream(
|
||||
messages: &[ChatRequestMessage],
|
||||
model_name: &str,
|
||||
@ -469,7 +498,53 @@ pub async fn call_stream(
|
||||
temperature: f32,
|
||||
max_tokens: u32,
|
||||
tools: Option<&[serde_json::Value]>,
|
||||
mut on_text_delta: impl FnMut(&str),
|
||||
on_text_delta: StreamTextCb,
|
||||
on_reasoning_delta: StreamReasoningCb,
|
||||
on_tool_call: StreamToolCallCb,
|
||||
) -> Result<StreamResponse> {
|
||||
let mut state = RetryState::new(5);
|
||||
|
||||
loop {
|
||||
let result = call_stream_once(
|
||||
messages, model_name, config, temperature, max_tokens, tools,
|
||||
on_text_delta.clone(), on_reasoning_delta.clone(), on_tool_call.clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(response) => return Ok(response),
|
||||
Err(ref err) if state.should_retry() && is_retryable_error(err) => {
|
||||
let duration = state.backoff_duration();
|
||||
tracing::warn!(
|
||||
attempt = state.attempt + 1,
|
||||
max_retries = 5,
|
||||
backoff_ms = duration.as_millis() as u64,
|
||||
model = %model_name,
|
||||
error = %err,
|
||||
"ai_stream_retry"
|
||||
);
|
||||
tokio::time::sleep(duration).await;
|
||||
state.next();
|
||||
}
|
||||
Err(err) => {
|
||||
ai_metrics().record_failure();
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Single attempt of streaming completion with 60s timeout.
|
||||
async fn call_stream_once(
|
||||
messages: &[ChatRequestMessage],
|
||||
model_name: &str,
|
||||
config: &AiClientConfig,
|
||||
temperature: f32,
|
||||
max_tokens: u32,
|
||||
tools: Option<&[serde_json::Value]>,
|
||||
on_text_delta: StreamTextCb,
|
||||
on_reasoning_delta: StreamReasoningCb,
|
||||
on_tool_call: StreamToolCallCb,
|
||||
) -> Result<StreamResponse> {
|
||||
let client = config.build_rig_client();
|
||||
let model = client.completion_model(model_name);
|
||||
@ -506,15 +581,17 @@ pub async fn call_stream(
|
||||
builder = builder.tools(tool_defs);
|
||||
}
|
||||
|
||||
let stream_fut = async {
|
||||
let mut stream = builder
|
||||
.stream()
|
||||
.await
|
||||
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
|
||||
|
||||
let mut content = String::new();
|
||||
let mut reasoning_content = String::new();
|
||||
let mut tool_calls: Vec<StreamedToolCall> = Vec::new();
|
||||
let mut chunks: Vec<StreamChunk> = Vec::new();
|
||||
|
||||
// Track partial tool calls by internal_call_id for argument accumulation
|
||||
use std::collections::HashMap;
|
||||
let mut partial_tool_calls: HashMap<String, StreamedToolCall> = HashMap::new();
|
||||
let mut stream_finished = false;
|
||||
@ -525,34 +602,45 @@ pub async fn call_stream(
|
||||
match item {
|
||||
Ok(StreamedAssistantContent::Text(text)) => {
|
||||
content.push_str(&text.text);
|
||||
on_text_delta(&text.text);
|
||||
on_text_delta(&text.text).await;
|
||||
chunks.push(StreamChunk {
|
||||
chunk_type: StreamChunkType::Answer,
|
||||
content: text.text,
|
||||
});
|
||||
}
|
||||
Ok(StreamedAssistantContent::ToolCall {
|
||||
tool_call,
|
||||
internal_call_id,
|
||||
}) => {
|
||||
// Complete tool call - extract arguments from the JSON Value
|
||||
let arguments = match &tool_call.function.arguments {
|
||||
serde_json::Value::String(s) => s.clone(),
|
||||
other => serde_json::to_string(other).unwrap_or_else(|_| "{}".to_string()),
|
||||
};
|
||||
tool_calls.push(StreamedToolCall {
|
||||
let tc = StreamedToolCall {
|
||||
id: tool_call.id.clone(),
|
||||
name: tool_call.function.name.clone(),
|
||||
arguments,
|
||||
};
|
||||
on_tool_call(&tc).await;
|
||||
chunks.push(StreamChunk {
|
||||
chunk_type: StreamChunkType::ToolCall,
|
||||
content: serde_json::json!({
|
||||
"id": tc.id,
|
||||
"name": tc.name,
|
||||
"arguments": tc.arguments,
|
||||
}).to_string(),
|
||||
});
|
||||
// Remove from partial if it was being accumulated
|
||||
tool_calls.push(tc);
|
||||
partial_tool_calls.remove(&internal_call_id);
|
||||
}
|
||||
Ok(StreamedAssistantContent::ToolCallDelta {
|
||||
id,
|
||||
internal_call_id,
|
||||
content,
|
||||
content: delta_content,
|
||||
}) => {
|
||||
use rig::streaming::ToolCallDeltaContent;
|
||||
match content {
|
||||
match delta_content {
|
||||
ToolCallDeltaContent::Name(name) => {
|
||||
// Start accumulating a new tool call
|
||||
partial_tool_calls.insert(
|
||||
internal_call_id.clone(),
|
||||
StreamedToolCall {
|
||||
@ -563,40 +651,55 @@ pub async fn call_stream(
|
||||
);
|
||||
}
|
||||
ToolCallDeltaContent::Delta(delta) => {
|
||||
// Append to existing partial tool call
|
||||
if let Some(tc) = partial_tool_calls.get_mut(&internal_call_id) {
|
||||
tc.arguments.push_str(&delta);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(StreamedAssistantContent::Reasoning(_)) => {}
|
||||
Ok(StreamedAssistantContent::ReasoningDelta { .. }) => {}
|
||||
Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
|
||||
for part in &reasoning.reasoning {
|
||||
reasoning_content.push_str(part);
|
||||
on_reasoning_delta(part).await;
|
||||
chunks.push(StreamChunk {
|
||||
chunk_type: StreamChunkType::Thinking,
|
||||
content: part.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(StreamedAssistantContent::ReasoningDelta { reasoning, .. }) => {
|
||||
reasoning_content.push_str(&reasoning);
|
||||
on_reasoning_delta(&reasoning).await;
|
||||
chunks.push(StreamChunk {
|
||||
chunk_type: StreamChunkType::Thinking,
|
||||
content: reasoning.clone(),
|
||||
});
|
||||
}
|
||||
Ok(StreamedAssistantContent::Final(response)) => {
|
||||
stream_finished = true;
|
||||
// Flush any remaining partial tool calls
|
||||
for (_, tc) in partial_tool_calls.drain() {
|
||||
tool_calls.push(tc);
|
||||
}
|
||||
if let Some(usage) = response.token_usage() {
|
||||
ai_metrics().record_success(
|
||||
usage.input_tokens as i64,
|
||||
usage.output_tokens as i64,
|
||||
!tool_calls.is_empty(),
|
||||
);
|
||||
let in_toks = usage.input_tokens as i64;
|
||||
let out_toks = usage.output_tokens as i64;
|
||||
ai_metrics().record_success(in_toks, out_toks, !tool_calls.is_empty());
|
||||
return Ok(StreamResponse {
|
||||
content,
|
||||
input_tokens: usage.input_tokens as i64,
|
||||
output_tokens: usage.output_tokens as i64,
|
||||
reasoning_content,
|
||||
input_tokens: in_toks,
|
||||
output_tokens: out_toks,
|
||||
tool_calls,
|
||||
chunks,
|
||||
});
|
||||
}
|
||||
// Usage not available from Final — fall through to flush
|
||||
}
|
||||
Err(e) => return Err(AgentError::OpenAi(e.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
// Flush any remaining partial tool calls (if stream ended without Final)
|
||||
// Flush any remaining partial tool calls (if stream ended without Final or Final had no usage)
|
||||
if !stream_finished {
|
||||
for (_, tc) in partial_tool_calls.drain() {
|
||||
tool_calls.push(tc);
|
||||
@ -605,8 +708,17 @@ pub async fn call_stream(
|
||||
ai_metrics().record_success(0, 0, !tool_calls.is_empty());
|
||||
Ok(StreamResponse {
|
||||
content,
|
||||
reasoning_content,
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
tool_calls,
|
||||
chunks,
|
||||
})
|
||||
};
|
||||
|
||||
// 60s timeout for the entire stream
|
||||
match tokio::time::timeout(std::time::Duration::from_secs(60), stream_fut).await {
|
||||
Ok(result) => result,
|
||||
Err(_) => Err(AgentError::Timeout { task_id: 0, seconds: 60 }),
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user