gitdataai/libs/agent/chat/streaming_execution.rs
ZhenYi 14f6e1e500 feat(core): initialize project with access control and AI integration
- Add gitignore and prettier configuration files for project scaffolding
- Implement room access control service with project member verification
- Create user access key management with CRUD operations and activity logging
- Add accordion UI component for frontend expandable sections
- Implement room AI configuration with list, upsert, and delete operations
- Add AI event types for agent join/leave/status change tracking
- Create streaming AI processing services for mode and react patterns
- Build room AI service with model detection and idempotency handling
- Integrate chat service orchestration for AI message processing
- Add typing indicators and stream cancellation for AI interactions
- Implement mention parsing and context extraction for AI agents
2026-05-03 06:04:31 +08:00

249 lines
14 KiB
Rust

use models::projects::project_skill;
use models::rooms::room_ai;
use sea_orm::{EntityTrait, ColumnTrait, QueryFilter};
use std::pin::Pin;
use std::sync::Arc;
use uuid::Uuid;
use super::service::StreamResult;
use super::{AiChunkType, AiRequest, AiStreamChunk, StreamCallback};
use crate::client::AiClientConfig;
use crate::client::types::{ChatRequestMessage, ToolCall};
use crate::client::{StreamChunk, StreamChunkType, StreamedToolCall, call_stream};
use crate::error::Result;
use crate::perception::{SkillEntry, ToolCallEvent};
use crate::tool::{ToolCall as AgentToolCall, ToolContext, ToolExecutor};
use super::message_builder::MessageBuilder;
use super::session_recording::record_ai_session;
/// Reference-counted form of the caller's stream callback.
///
/// Wrapping it in an `Arc` lets the same callback be cloned into several
/// closures (answer deltas, thinking deltas, tool notifications). Each call
/// returns a boxed `Send` future that the caller awaits.
type SharedCallback = Arc<
    dyn Fn(AiStreamChunk) -> Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>>
        + Send
        + Sync,
>;
pub async fn execute_process_stream(
request: AiRequest, on_chunk: StreamCallback,
message_builder: &MessageBuilder,
tool_registry: &Option<crate::tool::registry::ToolRegistry>,
ai_base_url: Option<String>, ai_api_key: Option<String>,
) -> Result<StreamResult> {
let on_chunk: SharedCallback = Arc::from(on_chunk);
let tools: Vec<serde_json::Value> = request.tools.clone().unwrap_or_default();
let tools_enabled = !tools.is_empty();
let max_tool_depth = request.max_tool_depth;
let mut messages = message_builder.build_messages(&request).await?;
let room_ai_config = room_ai::Entity::find()
.filter(room_ai::Column::Room.eq(request.room.id))
.filter(room_ai::Column::Model.eq(request.model.id))
.one(&request.db).await?;
let model_name = request.model.name.clone();
let temperature = room_ai_config.as_ref().and_then(|r| r.temperature.map(|v| v as f32)).unwrap_or(request.temperature as f32);
let max_tokens = room_ai_config.as_ref().and_then(|r| r.max_tokens.map(|v| v as u32)).unwrap_or(request.max_tokens as u32);
let mut tool_depth = 0;
let mut total_input_tokens = 0i64;
let mut total_output_tokens = 0i64;
let session_id = Uuid::now_v7();
let session_start = std::time::Instant::now();
let version_id = room_ai_config.as_ref().and_then(|r| r.version);
let config = AiClientConfig::new(ai_api_key.unwrap_or_default())
.with_base_url(ai_base_url.unwrap_or_else(|| "https://api.openai.com".into()));
let mut full_content = String::new();
let mut all_chunks: Vec<StreamChunk> = Vec::new();
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<StreamedToolCall>();
loop {
let on_chunk_cb = on_chunk.clone();
let on_chunk_cb2 = on_chunk.clone();
let tx_arc = Arc::new(tx.clone());
let tx_arc2 = tx_arc.clone();
let response = call_stream(
&messages, &model_name, &config, temperature, max_tokens,
if tools_enabled { Some(&tools) } else { None }, None,
Arc::new(move |delta| {
let fut = on_chunk_cb(AiStreamChunk { content: delta.to_string(), done: false, chunk_type: AiChunkType::Answer });
fut
}),
Arc::new(move |delta| {
let fut = on_chunk_cb2(AiStreamChunk { content: delta.to_string(), done: false, chunk_type: AiChunkType::Thinking });
fut
}),
Arc::new(move |tc: &StreamedToolCall| {
let tx = tx_arc2.clone();
let tc_owned = tc.clone();
Box::pin(async move { let _ = tx.send(tc_owned); }) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
}),
).await?;
total_input_tokens += response.input_tokens;
total_output_tokens += response.output_tokens;
all_chunks.extend(response.chunks.clone());
let has_tool_calls = tools_enabled && !response.tool_calls.is_empty();
if !has_tool_calls {
return handle_final_answer(response, full_content, on_chunk, all_chunks, &request, session_id, version_id, total_input_tokens, total_output_tokens, session_start).await;
}
full_content.push_str(&response.content);
full_content.push('\n');
let tool_calls: Vec<ToolCall> = response.tool_calls.iter().map(|tc| ToolCall {
id: tc.id.clone(), type_: "function".into(),
function: crate::client::types::ToolCallFunction { name: tc.name.clone(), arguments: tc.arguments.clone() },
}).collect();
messages.push(ChatRequestMessage::assistant(Some(response.content.clone()), Some(tool_calls.clone())));
drain_tool_call_notifications(&mut rx, &on_chunk, &mut all_chunks).await;
let calls: Vec<AgentToolCall> = response.tool_calls.iter().map(|tc| AgentToolCall {
id: tc.id.clone(), name: tc.name.clone(), arguments: tc.arguments.clone(),
}).collect();
let tool_messages = execute_streaming_tools(
&request, &calls, session_id, &on_chunk, &mut all_chunks,
tool_registry, message_builder,
).await;
messages.extend(tool_messages);
inject_passive_skills_stream(&request, message_builder, &response.tool_calls, &mut messages).await;
tool_depth += 1;
if tool_depth >= max_tool_depth {
let max_depth_text = format!("[AI reached maximum tool depth ({}) — no final answer produced]", max_tool_depth);
on_chunk(AiStreamChunk { content: max_depth_text.clone(), done: true, chunk_type: AiChunkType::Answer }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::Answer, content: max_depth_text });
record_ai_session(&request.cache, &request.db, request.project.id, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), total_input_tokens, total_output_tokens, session_start.elapsed().as_millis() as i64).await;
return Ok(StreamResult { content: full_content, reasoning_content: String::new(), input_tokens: 0, output_tokens: 0, chunks: all_chunks });
}
}
}
/// Forwards every tool-call notification currently queued in `rx` to
/// `on_chunk` and records each one in `all_chunks` as a `ToolCall` chunk.
///
/// Never waits: it stops as soon as the channel is empty or closed.
/// Arguments longer than 100 bytes are cut at the last char boundary at or
/// before byte 100 and suffixed with `...`.
async fn drain_tool_call_notifications(
    rx: &mut tokio::sync::mpsc::UnboundedReceiver<StreamedToolCall>,
    on_chunk: &SharedCallback,
    all_chunks: &mut Vec<StreamChunk>,
) {
    const ARGS_LIMIT: usize = 100;

    while let Ok(pending) = rx.try_recv() {
        let shown_args = if pending.arguments.len() <= ARGS_LIMIT {
            pending.arguments.clone()
        } else {
            // Byte 0 is always a boundary, so the search always succeeds.
            let cut = (0..=ARGS_LIMIT)
                .rev()
                .find(|&idx| pending.arguments.is_char_boundary(idx))
                .unwrap_or(0);
            format!("{}...", &pending.arguments[..cut])
        };
        let label = format!("🔧 {}({})", pending.name, shown_args);
        on_chunk(AiStreamChunk {
            content: label.clone(),
            done: false,
            chunk_type: AiChunkType::ToolCall,
        })
        .await;
        all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: label });
    }
}
async fn execute_streaming_tools(
request: &AiRequest, calls: &[AgentToolCall], session_id: Uuid,
on_chunk: &SharedCallback,
all_chunks: &mut Vec<StreamChunk>,
tool_registry: &Option<crate::tool::registry::ToolRegistry>,
message_builder: &MessageBuilder,
) -> Vec<ChatRequestMessage> {
let mut tool_messages = Vec::new();
let mut ctx = ToolContext::new(request.db.clone(), request.cache.clone(), request.config.clone(), request.room.id, Some(request.sender.uid)).with_project(request.project.id);
if let Some(es) = &message_builder.embed_service { ctx = ctx.with_embed_service(es.clone()); }
if let Some(registry) = tool_registry { ctx.registry_mut().merge(registry.clone()); }
let recorder = crate::tool::recorder::ToolCallRecorder::with_session(request.db.clone(), session_id);
for call in calls {
let start = std::time::Instant::now();
let call_clone = call.clone();
let mut ctx_clone = ctx.clone();
let (result_tx, mut result_rx) = tokio::sync::oneshot::channel();
tokio::spawn(async move {
let executor = ToolExecutor::new();
let res = executor.execute_batch(vec![call_clone], &mut ctx_clone).await;
let _ = result_tx.send(res);
});
let heartbeat_dur = std::time::Duration::from_secs(10);
let results = loop {
tokio::select! {
res = &mut result_rx => {
match res { Ok(inner) => break inner, Err(_) => break Err(crate::tool::ToolError::ExecutionError("tool task cancelled".into())), }
},
_ = tokio::time::sleep(heartbeat_dur) => {
on_chunk(AiStreamChunk { content: String::new(), done: false, chunk_type: AiChunkType::ToolCall }).await;
}
}
};
match results {
Ok(results) => {
for result in &results {
let text = match &result.result { crate::tool::ToolResult::Ok(v) => v.to_string(), crate::tool::ToolResult::Error(msg) => msg.clone() };
let preview = if text.len() > 300 {
let end = text.char_indices().map(|(i, _)| i).take_while(|&i| i <= 300).last().unwrap_or(300);
format!("{}...", &text[..end])
} else { text.clone() };
tracing::debug!("tool_result: {} — {}", call.name, preview);
let elapsed = start.elapsed().as_millis() as i64;
let is_error = matches!(result.result, crate::tool::ToolResult::Error(_));
let error_msg = match &result.result { crate::tool::ToolResult::Error(msg) => Some(msg.clone()), _ => None };
recorder.record(crate::tool::recorder::ToolCallRecord { tool_call_id: call.id.clone(), session_id: recorder.session_id(), tool_name: call.name.clone(), caller: request.sender.uid, arguments: call.arguments_json().unwrap_or_default(), status: if is_error { models::ai::ToolCallStatus::Failed } else { models::ai::ToolCallStatus::Success }, execution_time_ms: Some(elapsed), error_message: error_msg, error_stack: None, retry_count: 0 });
}
let success_display = format!("{}", call.name);
on_chunk(AiStreamChunk { content: success_display.clone(), done: false, chunk_type: AiChunkType::ToolResult }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: success_display });
let msgs = ToolExecutor::to_tool_messages(&results);
tool_messages.extend(msgs);
}
Err(e) => {
let elapsed = start.elapsed().as_millis() as i64;
recorder.record(crate::tool::recorder::ToolCallRecord { tool_call_id: call.id.clone(), session_id: recorder.session_id(), tool_name: call.name.clone(), caller: request.sender.uid, arguments: call.arguments_json().unwrap_or_default(), status: models::ai::ToolCallStatus::Failed, execution_time_ms: Some(elapsed), error_message: Some(e.to_string()), error_stack: None, retry_count: 0 });
let err_text = format!("[Tool call failed: {}]", e);
tracing::warn!(tool = %call.name, args = %call.arguments, error = %e, "tool_call_failed");
let err_display = format!("{} (failed)", call.name);
on_chunk(AiStreamChunk { content: err_display.clone(), done: false, chunk_type: AiChunkType::ToolResult }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: err_display });
tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text));
}
}
}
tool_messages
}
/// Finishes a stream whose last round returned no tool calls.
///
/// Steps, in order:
/// 1. Appends the round's answer to the content accumulated from earlier tool
///    rounds.
/// 2. Sends that answer to `on_chunk` as the terminal (`done: true`) chunk and
///    records it in `all_chunks`.
/// 3. Records the AI session with the accumulated token totals and elapsed
///    time.
///
/// The returned [`StreamResult`] carries the last round's reasoning content
/// and token counts.
async fn handle_final_answer(
    response: crate::client::StreamResponse,
    mut full_content: String,
    on_chunk: SharedCallback,
    mut all_chunks: Vec<StreamChunk>,
    request: &AiRequest,
    session_id: Uuid,
    version_id: Option<Uuid>,
    total_input_tokens: i64,
    total_output_tokens: i64,
    session_start: std::time::Instant,
) -> Result<StreamResult> {
    let crate::client::StreamResponse {
        content: answer,
        reasoning_content,
        input_tokens,
        output_tokens,
        ..
    } = response;

    full_content.push_str(&answer);
    on_chunk(AiStreamChunk {
        content: answer.clone(),
        done: true,
        chunk_type: AiChunkType::Answer,
    })
    .await;
    all_chunks.push(StreamChunk { chunk_type: StreamChunkType::Answer, content: answer });

    let duration_ms = session_start.elapsed().as_millis() as i64;
    record_ai_session(
        &request.cache,
        &request.db,
        request.project.id,
        session_id,
        request.room.id,
        request.model.id,
        version_id.unwrap_or_default(),
        total_input_tokens,
        total_output_tokens,
        duration_ms,
    )
    .await;

    Ok(StreamResult {
        content: full_content,
        reasoning_content,
        input_tokens,
        output_tokens,
        chunks: all_chunks,
    })
}
async fn inject_passive_skills_stream(
request: &AiRequest, message_builder: &MessageBuilder,
tool_calls: &[StreamedToolCall], messages: &mut Vec<ChatRequestMessage>,
) {
if let Ok(skills) = project_skill::Entity::find()
.filter(project_skill::Column::ProjectUuid.eq(request.project.id))
.filter(project_skill::Column::Enabled.eq(true)).all(&request.db).await {
let skill_entries: Vec<SkillEntry> = skills.into_iter().map(|s| SkillEntry { slug: s.slug, name: s.name, description: s.description, content: s.content }).collect();
let tool_events: Vec<ToolCallEvent> = tool_calls.iter().map(|tc| ToolCallEvent { tool_name: tc.name.clone(), arguments: tc.arguments.clone() }).collect();
for event in &tool_events {
if let Some(ctx) = message_builder.perception_service.passive.detect(event, &skill_entries) {
messages.push(ctx.to_system_message());
}
}
}
}