gitdataai/libs/agent/chat/streaming_execution.rs

273 lines
14 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use models::projects::project_skill;
use models::rooms::room_ai;
use sea_orm::{EntityTrait, ColumnTrait, QueryFilter};
use std::pin::Pin;
use std::sync::Arc;
use uuid::Uuid;
use super::service::StreamResult;
use super::{AiChunkType, AiRequest, AiStreamChunk, StreamCallback};
use crate::client::AiClientConfig;
use crate::client::types::{ChatRequestMessage, ToolCall};
use crate::client::{StreamChunk, StreamChunkType, StreamedToolCall, call_stream};
use crate::error::Result;
use crate::perception::{SkillEntry, ToolCallEvent};
use crate::tool::{ToolCall as AgentToolCall, ToolContext, ToolExecutor};
use super::message_builder::MessageBuilder;
use super::session_recording::record_ai_session;
/// Reference-counted, thread-safe form of the per-chunk stream callback.
///
/// Calling it with an [`AiStreamChunk`] yields a boxed `Send` future that the
/// caller awaits; the `Arc` lets the same callback be cloned into the several
/// closures handed to the streaming client.
type SharedCallback = Arc<dyn Fn(AiStreamChunk) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync>;
/// Runs one streamed AI exchange, looping through tool-call rounds until the
/// model produces an answer without tool calls or the tool-depth limit is hit.
///
/// Each round streams answer/thinking deltas to `on_chunk`, forwards streamed
/// tool-call notifications through an in-memory channel, executes the
/// requested tools, appends their results (plus any passive-skill context) to
/// the conversation and calls the model again.
///
/// * `request` – the chat request (room, model, sender, tools, limits, db…).
/// * `on_chunk` – callback invoked for every chunk surfaced to the caller.
/// * `message_builder` – builds the initial message list for the request.
/// * `tool_registry` – optional extra registry merged into the tool context.
/// * `ai_base_url` / `ai_api_key` – client settings; the base URL defaults to
///   `https://api.openai.com` and the key to an empty string when absent.
///
/// Returns the [`StreamResult`] for the exchange. Token counts are the sums
/// over every model round, on both the normal and the max-depth exit.
///
/// # Errors
/// Propagates failures from building messages, the `room_ai` lookup and the
/// streaming model call.
pub async fn execute_process_stream(
    request: AiRequest,
    on_chunk: StreamCallback,
    message_builder: &MessageBuilder,
    tool_registry: &Option<crate::tool::registry::ToolRegistry>,
    ai_base_url: Option<String>,
    ai_api_key: Option<String>,
) -> Result<StreamResult> {
    let on_chunk: SharedCallback = Arc::from(on_chunk);
    let tools: Vec<serde_json::Value> = request.tools.clone().unwrap_or_default();
    let tools_enabled = !tools.is_empty();
    let max_tool_depth = request.max_tool_depth;
    let mut messages = message_builder.build_messages(&request).await?;

    // A per-room row for this model, when present, overrides the request's
    // sampling parameters.
    let room_ai_config = room_ai::Entity::find()
        .filter(room_ai::Column::Room.eq(request.room.id))
        .filter(room_ai::Column::Model.eq(request.model.id))
        .one(&request.db)
        .await?;
    let model_name = request.model.name.clone();
    let temperature = room_ai_config
        .as_ref()
        .and_then(|r| r.temperature.map(|v| v as f32))
        .unwrap_or(request.temperature as f32);
    let max_tokens = room_ai_config
        .as_ref()
        .and_then(|r| r.max_tokens.map(|v| v as u32))
        .unwrap_or(request.max_tokens as u32);

    let mut tool_depth = 0;
    let mut total_input_tokens = 0i64;
    let mut total_output_tokens = 0i64;
    let session_id = Uuid::now_v7();
    let session_start = std::time::Instant::now();
    let version_id = room_ai_config.as_ref().and_then(|r| r.version);
    let config = AiClientConfig::new(ai_api_key.unwrap_or_default())
        .with_base_url(ai_base_url.unwrap_or_else(|| "https://api.openai.com".into()));
    let mut full_content = String::new();
    let mut all_chunks: Vec<StreamChunk> = Vec::new();

    // Streamed tool-call notifications are queued here by the client callback
    // and drained after each round completes.
    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<StreamedToolCall>();

    loop {
        let on_chunk_cb = on_chunk.clone();
        let on_chunk_cb2 = on_chunk.clone();
        // The sender is itself cheaply cloneable; no extra `Arc` is needed.
        let tool_call_tx = tx.clone();
        let response = call_stream(
            &messages,
            &model_name,
            &config,
            temperature,
            max_tokens,
            if tools_enabled { Some(&tools) } else { None },
            None,
            Arc::new(move |delta| {
                let content = delta.to_string();
                on_chunk_cb(AiStreamChunk { content, done: false, chunk_type: AiChunkType::Answer })
            }),
            Arc::new(move |delta| {
                on_chunk_cb2(AiStreamChunk { content: delta.to_string(), done: false, chunk_type: AiChunkType::Thinking })
            }),
            Arc::new(move |tc: &StreamedToolCall| {
                let tx = tool_call_tx.clone();
                let tc_owned = tc.clone();
                Box::pin(async move { let _ = tx.send(tc_owned); }) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
            }),
        )
        .await?;

        total_input_tokens += response.input_tokens;
        total_output_tokens += response.output_tokens;
        all_chunks.extend(response.chunks.clone());

        let has_tool_calls = tools_enabled && !response.tool_calls.is_empty();
        if !has_tool_calls {
            return handle_final_answer(response, all_chunks, &request, session_id, version_id, total_input_tokens, total_output_tokens, session_start).await;
        }

        full_content.push_str(&response.content);

        // Echo the assistant turn (with its tool calls) back into the
        // conversation before appending the tool results.
        let tool_calls: Vec<ToolCall> = response
            .tool_calls
            .iter()
            .map(|tc| ToolCall {
                id: tc.id.clone(),
                type_: "function".into(),
                function: crate::client::types::ToolCallFunction { name: tc.name.clone(), arguments: tc.arguments.clone() },
            })
            .collect();
        messages.push(ChatRequestMessage::assistant(Some(response.content.clone()), Some(tool_calls.clone())));

        drain_tool_call_notifications(&mut rx, &on_chunk, &mut all_chunks).await;

        let calls: Vec<AgentToolCall> = response
            .tool_calls
            .iter()
            .map(|tc| AgentToolCall { id: tc.id.clone(), name: tc.name.clone(), arguments: tc.arguments.clone() })
            .collect();
        let tool_messages = execute_streaming_tools(
            &request, &calls, session_id, &on_chunk, &mut all_chunks,
            tool_registry, message_builder,
        )
        .await;
        messages.extend(tool_messages);
        inject_passive_skills_stream(&request, message_builder, &response.tool_calls, &mut messages).await;

        tool_depth += 1;
        if tool_depth >= max_tool_depth {
            let max_depth_text = format!("[AI reached maximum tool depth ({}) — no final answer produced]", max_tool_depth);
            on_chunk(AiStreamChunk { content: max_depth_text.clone(), done: true, chunk_type: AiChunkType::Answer }).await;
            all_chunks.push(StreamChunk { chunk_type: StreamChunkType::Answer, content: max_depth_text });
            record_ai_session(&request.cache, &request.db, request.project.id, request.sender.uid, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), total_input_tokens, total_output_tokens, session_start.elapsed().as_millis() as i64).await;
            // Report the tokens actually consumed, matching both what was just
            // recorded for the session and what the normal exit path returns.
            return Ok(StreamResult {
                content: full_content,
                reasoning_content: String::new(),
                input_tokens: total_input_tokens,
                output_tokens: total_output_tokens,
                chunks: all_chunks,
            });
        }
    }
}
/// Flushes every tool-call notification currently queued on `rx`.
///
/// For each queued call a `🔧 name(args)` line is sent to `on_chunk` as a
/// `ToolCall` chunk and appended to `all_chunks`. Arguments longer than 100
/// bytes are cut at a character boundary and suffixed with `...`. The function
/// never waits for new notifications: it returns as soon as the queue is empty
/// or the channel is disconnected.
async fn drain_tool_call_notifications(
    rx: &mut tokio::sync::mpsc::UnboundedReceiver<StreamedToolCall>,
    on_chunk: &SharedCallback,
    all_chunks: &mut Vec<StreamChunk>,
) {
    // Any `try_recv` error (empty or disconnected) ends the drain.
    while let Ok(pending) = rx.try_recv() {
        let shown_args = if pending.arguments.len() <= 100 {
            pending.arguments.clone()
        } else {
            // Largest char boundary at or before byte 100, so the slice below
            // cannot split a multi-byte character.
            let cut = (0..=100usize)
                .rev()
                .find(|&pos| pending.arguments.is_char_boundary(pos))
                .unwrap_or(100);
            format!("{}...", &pending.arguments[..cut])
        };
        let line = format!("🔧 {}({})", pending.name, shown_args);
        on_chunk(AiStreamChunk {
            content: line.clone(),
            done: false,
            chunk_type: AiChunkType::ToolCall,
        })
        .await;
        all_chunks.push(StreamChunk {
            chunk_type: StreamChunkType::ToolCall,
            content: line,
        });
    }
}
async fn execute_streaming_tools(
request: &AiRequest, calls: &[AgentToolCall], session_id: Uuid,
on_chunk: &SharedCallback,
all_chunks: &mut Vec<StreamChunk>,
tool_registry: &Option<crate::tool::registry::ToolRegistry>,
message_builder: &MessageBuilder,
) -> Vec<ChatRequestMessage> {
let mut tool_messages = Vec::new();
let mut ctx = ToolContext::new(request.db.clone(), request.cache.clone(), request.config.clone(), request.room.id, Some(request.sender.uid)).with_project(request.project.id);
if let Some(es) = &message_builder.embed_service { ctx = ctx.with_embed_service(es.clone()); }
if let Some(registry) = tool_registry { ctx.registry_mut().merge(registry.clone()); }
let recorder = crate::tool::recorder::ToolCallRecorder::with_session(request.db.clone(), session_id);
let mut join_set = tokio::task::JoinSet::new();
for call in calls {
let call_clone = call.clone();
let mut ctx_clone = ctx.clone();
let sender_uid = request.sender.uid;
let recorder_clone = recorder.clone();
join_set.spawn(async move {
let start = std::time::Instant::now();
let executor = ToolExecutor::new();
let res = executor.execute_batch(vec![call_clone.clone()], &mut ctx_clone).await;
(call_clone, res, start.elapsed(), sender_uid, recorder_clone)
});
}
let heartbeat_dur = std::time::Duration::from_secs(10);
while !join_set.is_empty() {
tokio::select! {
Some(res) = join_set.join_next() => {
if let Ok((call, results, elapsed, sender_uid, recorder)) = res {
match results {
Ok(results) => {
for result in &results {
let text = match &result.result { crate::tool::ToolResult::Ok(v) => v.to_string(), crate::tool::ToolResult::Error(msg) => msg.clone() };
let preview = if text.len() > 300 {
let end = text.char_indices().map(|(i, _)| i).take_while(|&i| i <= 300).last().unwrap_or(300);
format!("{}...", &text[..end])
} else { text.clone() };
tracing::debug!("tool_result: {} — {}", call.name, preview);
let is_error = matches!(result.result, crate::tool::ToolResult::Error(_));
let error_msg = match &result.result { crate::tool::ToolResult::Error(msg) => Some(msg.clone()), _ => None };
recorder.record(crate::tool::recorder::ToolCallRecord {
tool_call_id: call.id.clone(),
session_id: recorder.session_id(),
tool_name: call.name.clone(),
caller: sender_uid,
arguments: call.arguments_json().unwrap_or_default(),
status: if is_error { models::ai::ToolCallStatus::Failed } else { models::ai::ToolCallStatus::Success },
execution_time_ms: Some(elapsed.as_millis() as i64),
error_message: error_msg,
error_stack: None,
retry_count: 0
});
}
let success_display = format!("{}", call.name);
on_chunk(AiStreamChunk { content: success_display.clone(), done: false, chunk_type: AiChunkType::ToolResult }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: success_display });
let msgs = ToolExecutor::to_tool_messages(&results);
tool_messages.extend(msgs);
}
Err(e) => {
recorder.record(crate::tool::recorder::ToolCallRecord {
tool_call_id: call.id.clone(),
session_id: recorder.session_id(),
tool_name: call.name.clone(),
caller: sender_uid,
arguments: call.arguments_json().unwrap_or_default(),
status: models::ai::ToolCallStatus::Failed,
execution_time_ms: Some(elapsed.as_millis() as i64),
error_message: Some(e.to_string()),
error_stack: None,
retry_count: 0
});
let err_text = format!("[Tool call failed: {}]", e);
tracing::warn!(tool = %call.name, args = %call.arguments, error = %e, "tool_call_failed");
let err_display = format!("{} (failed)", call.name);
on_chunk(AiStreamChunk { content: err_display.clone(), done: false, chunk_type: AiChunkType::ToolResult }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: err_display });
tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text));
}
}
}
},
_ = tokio::time::sleep(heartbeat_dur) => {
on_chunk(AiStreamChunk { content: String::new(), done: false, chunk_type: AiChunkType::ToolCall }).await;
}
}
}
tool_messages
}
/// Finishes a stream whose last model response carried no tool calls.
///
/// Records the AI session (accumulated token totals and elapsed wall time)
/// and returns a [`StreamResult`] built from the final response's content and
/// reasoning content, the accumulated token totals and `all_chunks`.
async fn handle_final_answer(
    response: crate::client::StreamResponse,
    all_chunks: Vec<StreamChunk>,
    request: &AiRequest,
    session_id: Uuid,
    version_id: Option<Uuid>,
    total_input_tokens: i64,
    total_output_tokens: i64,
    session_start: std::time::Instant,
) -> Result<StreamResult> {
    // The full answer text is deliberately NOT appended to `all_chunks`: the
    // caller has already accumulated the incremental deltas, which add up to
    // the same text. Appending it again would make merge_consecutive_blocks
    // yield the content twice and persist a duplicate to the DB.
    let elapsed_ms = session_start.elapsed().as_millis() as i64;
    record_ai_session(
        &request.cache,
        &request.db,
        request.project.id,
        request.sender.uid,
        session_id,
        request.room.id,
        request.model.id,
        version_id.unwrap_or_default(),
        total_input_tokens,
        total_output_tokens,
        elapsed_ms,
    )
    .await;

    Ok(StreamResult {
        content: response.content,
        reasoning_content: response.reasoning_content,
        input_tokens: total_input_tokens,
        output_tokens: total_output_tokens,
        chunks: all_chunks,
    })
}
async fn inject_passive_skills_stream(
request: &AiRequest, message_builder: &MessageBuilder,
tool_calls: &[StreamedToolCall], messages: &mut Vec<ChatRequestMessage>,
) {
if let Ok(skills) = project_skill::Entity::find()
.filter(project_skill::Column::ProjectUuid.eq(request.project.id))
.filter(project_skill::Column::Enabled.eq(true)).all(&request.db).await {
let skill_entries: Vec<SkillEntry> = skills.into_iter().map(|s| SkillEntry { slug: s.slug, name: s.name, description: s.description, content: s.content }).collect();
let tool_events: Vec<ToolCallEvent> = tool_calls.iter().map(|tc| ToolCallEvent { tool_name: tc.name.clone(), arguments: tc.arguments.clone() }).collect();
for event in &tool_events {
if let Some(ctx) = message_builder.perception_service.passive.detect(event, &skill_entries) {
messages.push(ctx.to_system_message());
}
}
}
}