feat(agent): add ordered stream chunk collection + retry for HTTP errors

- StreamChunk/StreamChunkType types for preserving arrival order
- Chunk collection in call_stream_once and process_stream
- Add "error sending request" and "Http client error" to retryable errors
- StreamResult includes chunks vector for ordered replay
This commit is contained in:
ZhenYi 2026-04-26 13:10:26 +08:00
parent 0b5dc98ce5
commit b4b5538447
2 changed files with 410 additions and 211 deletions

View File

@ -1,4 +1,5 @@
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use models::projects::project_skill;
use models::rooms::room_ai;
@ -9,7 +10,7 @@ use super::context::RoomMessageContext;
use super::{AiChunkType, AiRequest, AiStreamChunk, Mention, StreamCallback};
use crate::client::types::{ChatRequestMessage, ToolCall};
use crate::client::AiClientConfig;
use crate::client::{call_stream, call_with_params};
use crate::client::{call_stream, call_with_params, StreamChunk, StreamChunkType, StreamedToolCall};
use crate::compact::{CompactConfig, CompactService};
use crate::embed::EmbedService;
use crate::error::{AgentError, Result};
@ -17,6 +18,23 @@ use crate::perception::{PerceptionService, SkillEntry, ToolCallEvent};
use crate::react::{ReactAgent, ReactConfig, DEFAULT_SYSTEM_PROMPT};
use crate::tool::{ToolCall as AgentToolCall, ToolContext, ToolExecutor, ToolResult, registry::ToolRegistry};
/// Result from streaming AI response.
pub struct StreamResult {
pub content: String,
pub reasoning_content: String,
pub input_tokens: i64,
pub output_tokens: i64,
/// All chunks in arrival order — preserves ReAct multi-cycle ordering.
pub chunks: Vec<StreamChunk>,
}
/// Result from non-streaming AI response.
pub struct ProcessResult {
pub content: String,
pub input_tokens: i64,
pub output_tokens: i64,
}
/// Service for handling AI chat requests in rooms.
pub struct ChatService {
ai_base_url: Option<String>,
@ -97,7 +115,7 @@ impl ChatService {
self.tool_registry.as_ref()
}
pub async fn process(&self, request: AiRequest) -> Result<String> {
pub async fn process(&self, request: AiRequest) -> Result<ProcessResult> {
let tools: Vec<serde_json::Value> = request.tools.clone().unwrap_or_default();
let tools_enabled = !tools.is_empty();
let max_tool_depth = request.max_tool_depth;
@ -120,6 +138,8 @@ impl ChatService {
.and_then(|r| r.max_tokens.map(|v| v as u32))
.unwrap_or(request.max_tokens as u32);
let mut tool_depth = 0;
let mut input_tokens = 0i64;
let mut output_tokens = 0i64;
let config = AiClientConfig::new(
self.ai_api_key.clone().unwrap_or_default(),
@ -140,6 +160,8 @@ impl ChatService {
.await?;
let text = response.content.clone();
input_tokens += response.input_tokens;
output_tokens += response.output_tokens;
if tools_enabled && !response.tool_calls_finished.is_empty() {
// Build assistant message with tool_calls
@ -176,17 +198,30 @@ impl ChatService {
})
.collect();
let tool_messages = match self.execute_tool_calls(calls, &request).await {
Ok(msgs) => msgs,
let tool_messages = {
let mut ctx = ToolContext::new(
request.db.clone(),
request.cache.clone(),
request.config.clone(),
request.room.id,
Some(request.sender.uid),
)
.with_project(request.project.id);
if let Some(ref registry) = self.tool_registry {
ctx.registry_mut().merge(registry.clone());
}
let executor = ToolExecutor::new();
match executor.execute_batch(calls, &mut ctx).await {
Ok(results) => ToolExecutor::to_tool_messages(&results),
Err(e) => {
let err_msg = format!("[Tool call failed: {}]", e);
// Return error as a single tool result per call
response
.tool_calls_finished
.iter()
.map(|_| ChatRequestMessage::tool(Uuid::new_v4().to_string(), &err_msg))
.collect()
}
}
};
messages.extend(tool_messages);
@ -225,22 +260,26 @@ impl ChatService {
tool_depth += 1;
if tool_depth >= max_tool_depth {
if text.is_empty() {
return Ok(format!(
let content = if text.is_empty() {
format!(
"[AI reached maximum tool depth ({}) — no final answer produced]",
max_tool_depth
));
}
return Ok(text);
)
} else {
text
};
return Ok(ProcessResult { content, input_tokens, output_tokens });
}
continue;
}
return Ok(text);
return Ok(ProcessResult { content: text, input_tokens, output_tokens });
}
}
pub async fn process_stream(&self, request: AiRequest, on_chunk: StreamCallback) -> Result<String> {
pub async fn process_stream(&self, request: AiRequest, on_chunk: StreamCallback) -> Result<StreamResult> {
// Wrap on_chunk in Arc so it can be shared across loop iterations
let on_chunk = Arc::new(on_chunk);
let tools: Vec<serde_json::Value> = request.tools.clone().unwrap_or_default();
let tools_enabled = !tools.is_empty();
let max_tool_depth = request.max_tool_depth;
@ -270,13 +309,15 @@ impl ChatService {
.with_base_url(self.ai_base_url.clone().unwrap_or_else(|| "https://api.openai.com".into()));
let mut full_content = String::new();
let mut has_called_tools = false;
let mut all_chunks: Vec<StreamChunk> = Vec::new();
// Collect tool calls during streaming, push them incrementally after.
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<StreamedToolCall>();
loop {
let chunk_type = if has_called_tools {
AiChunkType::Answer
} else {
AiChunkType::Thinking
};
let on_chunk_cb = on_chunk.clone();
let on_chunk_cb2 = on_chunk_cb.clone();
let tx_arc = Arc::new(tx.clone());
let tx_arc2 = tx_arc.clone();
let response = call_stream(
&messages,
&model_name,
@ -284,18 +325,36 @@ impl ChatService {
temperature,
max_tokens,
if tools_enabled { Some(&tools) } else { None },
|delta| {
let _ = on_chunk(AiStreamChunk {
Arc::new(move |delta| {
let fut = on_chunk_cb(AiStreamChunk {
content: delta.to_string(),
done: false,
chunk_type: chunk_type.clone(),
chunk_type: AiChunkType::Answer,
});
},
fut
}),
Arc::new(move |delta| {
let fut = on_chunk_cb2(AiStreamChunk {
content: delta.to_string(),
done: false,
chunk_type: AiChunkType::Thinking,
});
fut
}),
Arc::new(move |tc: &StreamedToolCall| {
let tx = tx_arc2.clone();
let tc_owned = tc.clone();
Box::pin(async move {
let _ = tx.send(tc_owned);
}) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
}),
)
.await?;
// Collect chunks from this streaming iteration in order.
all_chunks.extend(response.chunks);
let has_tool_calls = tools_enabled && !response.tool_calls.is_empty();
has_called_tools = true;
if has_tool_calls {
// Accumulate the assistant's text before tool calls
@ -321,28 +380,34 @@ impl ChatService {
Some(tool_calls.clone()),
));
// Stream tool call summary to frontend
let call_summary: Vec<String> = response
.tool_calls
.iter()
.map(|tc| {
// Truncate long arguments for display
// Push each tool call incrementally to frontend.
// Use try_recv() — tx is never dropped so recv() would deadlock.
loop {
match rx.try_recv() {
Ok(tc) => {
let args_display = if tc.arguments.len() > 100 {
format!("{}...", &tc.arguments[..100])
} else {
tc.arguments.clone()
};
format!("{}({})", tc.name, args_display)
})
.collect();
let tool_display = format!("🔧 {}({})", tc.name, args_display);
on_chunk(AiStreamChunk {
content: format!("[Calling tools: {}]", call_summary.join(", ")),
content: tool_display.clone(),
done: false,
chunk_type: AiChunkType::ToolCall,
})
.await;
all_chunks.push(StreamChunk {
chunk_type: StreamChunkType::ToolCall,
content: tool_display,
});
}
Err(tokio::sync::mpsc::error::TryRecvError::Empty) => break,
Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => break,
}
}
// Execute tools with full arguments from streaming
// Execute tools one at a time, push each result incrementally
let calls: Vec<AgentToolCall> = response
.tool_calls
.iter()
@ -353,43 +418,71 @@ impl ChatService {
})
.collect();
let tool_messages = match self.execute_tool_calls(calls, &request).await {
Ok(msgs) => {
let result_summary: Vec<String> = msgs
.iter()
.map(|m| {
let text = m.content.as_deref().unwrap_or("[no content]");
if text.len() > 300 {
format!("{}...", &text[..300])
} else {
text.to_string()
}
})
.collect();
on_chunk(AiStreamChunk {
content: format!("[Tool results: {}]", result_summary.join("; ")),
done: false,
chunk_type: AiChunkType::ToolResult,
})
.await;
msgs
let mut tool_messages = Vec::new();
for call in &calls {
let ctx = &mut crate::tool::ToolContext::new(
request.db.clone(),
request.cache.clone(),
request.config.clone(),
request.room.id,
Some(request.sender.uid),
);
if let Some(ref registry) = self.tool_registry {
ctx.registry_mut().merge(registry.clone());
}
let executor = crate::tool::ToolExecutor::new();
let results = match executor.execute_batch(vec![call.clone()], ctx).await {
Ok(r) => r,
Err(e) => {
let err_text = format!("[Tool call failed: {}]", e);
tracing::warn!(tool = %call.name, error = %e, "tool_call_failed");
// Do NOT emit tool_result chunks to frontend — show error via tool_call instead
let err_display = format!("{} (failed)", call.name);
on_chunk(AiStreamChunk {
content: err_text.clone(),
content: err_display.clone(),
done: false,
chunk_type: AiChunkType::ToolResult,
chunk_type: AiChunkType::ToolCall,
})
.await;
// Return error tool messages
response
.tool_calls
.iter()
.map(|tc| ChatRequestMessage::tool(&tc.id, &err_text))
.collect()
all_chunks.push(StreamChunk {
chunk_type: StreamChunkType::ToolCall,
content: err_display,
});
tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text));
continue;
}
};
for result in &results {
let text = match &result.result {
crate::tool::ToolResult::Ok(v) => v.to_string(),
crate::tool::ToolResult::Error(msg) => msg.clone(),
};
let preview = if text.len() > 300 {
format!("{}...", &text[..300])
} else {
text.clone()
};
tracing::debug!("tool_result: {} — {}", call.name, preview);
// Do NOT emit tool_result chunks to frontend — raw output may contain sensitive data.
// Log server-side only; frontend sees tool_call status via on_chunk below.
}
let success_display = format!("{}", call.name);
on_chunk(AiStreamChunk {
content: success_display.clone(),
done: false,
chunk_type: AiChunkType::ToolCall,
})
.await;
all_chunks.push(StreamChunk {
chunk_type: StreamChunkType::ToolCall,
content: success_display,
});
let msgs = crate::tool::ToolExecutor::to_tool_messages(&results);
tool_messages.extend(msgs);
}
messages.extend(tool_messages);
// Inject passive-detected skills based on tool calls
@ -427,60 +520,54 @@ impl ChatService {
tool_depth += 1;
if tool_depth >= max_tool_depth {
on_chunk(AiStreamChunk {
content: format!(
let max_depth_text = format!(
"[AI reached maximum tool depth ({}) — no final answer produced]",
max_tool_depth
),
);
on_chunk(AiStreamChunk {
content: max_depth_text.clone(),
done: true,
chunk_type: AiChunkType::Answer,
})
.await;
return Ok(full_content);
all_chunks.push(StreamChunk {
chunk_type: StreamChunkType::Answer,
content: max_depth_text,
});
return Ok(StreamResult {
content: full_content,
reasoning_content: String::new(),
input_tokens: 0,
output_tokens: 0,
chunks: all_chunks,
});
}
continue;
}
// Final answer — accumulate and return
full_content.push_str(&response.content);
on_chunk(AiStreamChunk {
content: response.content,
content: response.content.clone(),
done: true,
chunk_type: AiChunkType::Answer,
})
.await;
return Ok(full_content);
all_chunks.push(StreamChunk {
chunk_type: StreamChunkType::Answer,
content: response.content.clone(),
});
return Ok(StreamResult {
content: full_content,
reasoning_content: response.reasoning_content,
input_tokens: response.input_tokens,
output_tokens: response.output_tokens,
chunks: all_chunks,
});
}
}
/// Executes a batch of tool calls and returns the tool result messages.
async fn execute_tool_calls(
&self,
calls: Vec<AgentToolCall>,
request: &AiRequest,
) -> Result<Vec<ChatRequestMessage>> {
let mut ctx = ToolContext::new(
request.db.clone(),
request.cache.clone(),
request.config.clone(),
request.room.id,
Some(request.sender.uid),
)
.with_project(request.project.id);
if let Some(ref registry) = self.tool_registry {
ctx.registry_mut().merge(registry.clone());
}
let executor = ToolExecutor::new();
let results = executor
.execute_batch(calls, &mut ctx)
.await
.map_err(|e| AgentError::Internal(e.to_string()))?;
Ok(ToolExecutor::to_tool_messages(&results))
}
async fn build_messages(&self, request: &AiRequest) -> Result<Vec<ChatRequestMessage>> {
let mut messages = Vec::new();

View File

@ -5,6 +5,8 @@
pub mod types;
pub use types::{ChatRequestMessage, ToolCall as ClientToolCall};
use std::pin::Pin;
use std::sync::Arc;
use std::time::Instant;
use uuid::Uuid;
@ -130,6 +132,8 @@ fn is_retryable_error(err: &AgentError) -> bool {
|| msg.contains("connection timed out")
|| msg.contains("network error")
|| msg.contains("dns error")
|| msg.contains("error sending request")
|| msg.contains("Http client error")
|| msg.contains("rate_limit")
|| msg.contains("rate limit")
|| msg.contains("429")
@ -451,17 +455,42 @@ pub struct StreamedToolCall {
pub arguments: String,
}
/// Type of chunk in the streaming response, preserving arrival order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamChunkType {
Thinking,
Answer,
ToolCall,
}
/// A single chunk from the streaming response in arrival order.
#[derive(Debug, Clone)]
pub struct StreamChunk {
pub chunk_type: StreamChunkType,
pub content: String,
}
/// Streaming result from rig.
#[derive(Debug)]
pub struct StreamResponse {
pub content: String,
pub input_tokens: i64,
pub output_tokens: i64,
/// Accumulated reasoning/thinking text from the model.
pub reasoning_content: String,
/// Full tool calls with accumulated arguments (not just names)
pub tool_calls: Vec<StreamedToolCall>,
/// All chunks in arrival order — preserves think/answer/tool interleaving.
pub chunks: Vec<StreamChunk>,
}
/// Run a streaming chat completion.
/// Async callback: takes a string delta and broadcasts it to the WebSocket.
/// The returned Future must be awaited by the caller.
pub type StreamTextCb = Arc<dyn Fn(&str) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync>;
pub type StreamReasoningCb = Arc<dyn Fn(&str) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync>;
pub type StreamToolCallCb = Arc<dyn Fn(&StreamedToolCall) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync>;
/// Run a streaming chat completion with 60s timeout and 5 retries.
pub async fn call_stream(
messages: &[ChatRequestMessage],
model_name: &str,
@ -469,7 +498,53 @@ pub async fn call_stream(
temperature: f32,
max_tokens: u32,
tools: Option<&[serde_json::Value]>,
mut on_text_delta: impl FnMut(&str),
on_text_delta: StreamTextCb,
on_reasoning_delta: StreamReasoningCb,
on_tool_call: StreamToolCallCb,
) -> Result<StreamResponse> {
let mut state = RetryState::new(5);
loop {
let result = call_stream_once(
messages, model_name, config, temperature, max_tokens, tools,
on_text_delta.clone(), on_reasoning_delta.clone(), on_tool_call.clone(),
)
.await;
match result {
Ok(response) => return Ok(response),
Err(ref err) if state.should_retry() && is_retryable_error(err) => {
let duration = state.backoff_duration();
tracing::warn!(
attempt = state.attempt + 1,
max_retries = 5,
backoff_ms = duration.as_millis() as u64,
model = %model_name,
error = %err,
"ai_stream_retry"
);
tokio::time::sleep(duration).await;
state.next();
}
Err(err) => {
ai_metrics().record_failure();
return Err(err);
}
}
}
}
/// Single attempt of streaming completion with 60s timeout.
async fn call_stream_once(
messages: &[ChatRequestMessage],
model_name: &str,
config: &AiClientConfig,
temperature: f32,
max_tokens: u32,
tools: Option<&[serde_json::Value]>,
on_text_delta: StreamTextCb,
on_reasoning_delta: StreamReasoningCb,
on_tool_call: StreamToolCallCb,
) -> Result<StreamResponse> {
let client = config.build_rig_client();
let model = client.completion_model(model_name);
@ -506,15 +581,17 @@ pub async fn call_stream(
builder = builder.tools(tool_defs);
}
let stream_fut = async {
let mut stream = builder
.stream()
.await
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
let mut content = String::new();
let mut reasoning_content = String::new();
let mut tool_calls: Vec<StreamedToolCall> = Vec::new();
let mut chunks: Vec<StreamChunk> = Vec::new();
// Track partial tool calls by internal_call_id for argument accumulation
use std::collections::HashMap;
let mut partial_tool_calls: HashMap<String, StreamedToolCall> = HashMap::new();
let mut stream_finished = false;
@ -525,34 +602,45 @@ pub async fn call_stream(
match item {
Ok(StreamedAssistantContent::Text(text)) => {
content.push_str(&text.text);
on_text_delta(&text.text);
on_text_delta(&text.text).await;
chunks.push(StreamChunk {
chunk_type: StreamChunkType::Answer,
content: text.text,
});
}
Ok(StreamedAssistantContent::ToolCall {
tool_call,
internal_call_id,
}) => {
// Complete tool call - extract arguments from the JSON Value
let arguments = match &tool_call.function.arguments {
serde_json::Value::String(s) => s.clone(),
other => serde_json::to_string(other).unwrap_or_else(|_| "{}".to_string()),
};
tool_calls.push(StreamedToolCall {
let tc = StreamedToolCall {
id: tool_call.id.clone(),
name: tool_call.function.name.clone(),
arguments,
};
on_tool_call(&tc).await;
chunks.push(StreamChunk {
chunk_type: StreamChunkType::ToolCall,
content: serde_json::json!({
"id": tc.id,
"name": tc.name,
"arguments": tc.arguments,
}).to_string(),
});
// Remove from partial if it was being accumulated
tool_calls.push(tc);
partial_tool_calls.remove(&internal_call_id);
}
Ok(StreamedAssistantContent::ToolCallDelta {
id,
internal_call_id,
content,
content: delta_content,
}) => {
use rig::streaming::ToolCallDeltaContent;
match content {
match delta_content {
ToolCallDeltaContent::Name(name) => {
// Start accumulating a new tool call
partial_tool_calls.insert(
internal_call_id.clone(),
StreamedToolCall {
@ -563,40 +651,55 @@ pub async fn call_stream(
);
}
ToolCallDeltaContent::Delta(delta) => {
// Append to existing partial tool call
if let Some(tc) = partial_tool_calls.get_mut(&internal_call_id) {
tc.arguments.push_str(&delta);
}
}
}
}
Ok(StreamedAssistantContent::Reasoning(_)) => {}
Ok(StreamedAssistantContent::ReasoningDelta { .. }) => {}
Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
for part in &reasoning.reasoning {
reasoning_content.push_str(part);
on_reasoning_delta(part).await;
chunks.push(StreamChunk {
chunk_type: StreamChunkType::Thinking,
content: part.clone(),
});
}
}
Ok(StreamedAssistantContent::ReasoningDelta { reasoning, .. }) => {
reasoning_content.push_str(&reasoning);
on_reasoning_delta(&reasoning).await;
chunks.push(StreamChunk {
chunk_type: StreamChunkType::Thinking,
content: reasoning.clone(),
});
}
Ok(StreamedAssistantContent::Final(response)) => {
stream_finished = true;
// Flush any remaining partial tool calls
for (_, tc) in partial_tool_calls.drain() {
tool_calls.push(tc);
}
if let Some(usage) = response.token_usage() {
ai_metrics().record_success(
usage.input_tokens as i64,
usage.output_tokens as i64,
!tool_calls.is_empty(),
);
let in_toks = usage.input_tokens as i64;
let out_toks = usage.output_tokens as i64;
ai_metrics().record_success(in_toks, out_toks, !tool_calls.is_empty());
return Ok(StreamResponse {
content,
input_tokens: usage.input_tokens as i64,
output_tokens: usage.output_tokens as i64,
reasoning_content,
input_tokens: in_toks,
output_tokens: out_toks,
tool_calls,
chunks,
});
}
// Usage not available from Final — fall through to flush
}
Err(e) => return Err(AgentError::OpenAi(e.to_string())),
}
}
// Flush any remaining partial tool calls (if stream ended without Final)
// Flush any remaining partial tool calls (if stream ended without Final or Final had no usage)
if !stream_finished {
for (_, tc) in partial_tool_calls.drain() {
tool_calls.push(tc);
@ -605,8 +708,17 @@ pub async fn call_stream(
ai_metrics().record_success(0, 0, !tool_calls.is_empty());
Ok(StreamResponse {
content,
reasoning_content,
input_tokens: 0,
output_tokens: 0,
tool_calls,
chunks,
})
};
// 60s timeout for the entire stream
match tokio::time::timeout(std::time::Duration::from_secs(60), stream_fut).await {
Ok(result) => result,
Err(_) => Err(AgentError::Timeout { task_id: 0, seconds: 60 }),
}
}