// gitdataai/libs/agent/chat/nonstreaming_execution.rs
//
// 291 lines
// 10 KiB
// Rust

use models::projects::project_skill;
use models::rooms::room_ai;
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
use uuid::Uuid;
use super::AiRequest;
use super::message_builder::MessageBuilder;
use super::service::ProcessResult;
use super::session_recording::record_ai_session;
use crate::client::AiClientConfig;
use crate::client::types::{ChatRequestMessage, ToolCall};
use crate::error::Result;
use crate::perception::{SkillEntry, ToolCallEvent};
use crate::tool::{ToolCall as AgentToolCall, ToolContext, ToolExecutor};
pub async fn execute_process(
request: AiRequest,
message_builder: &MessageBuilder,
tool_registry: &Option<crate::tool::registry::ToolRegistry>,
ai_base_url: Option<String>,
ai_api_key: Option<String>,
) -> Result<ProcessResult> {
let tools: Vec<serde_json::Value> = request.tools.clone().unwrap_or_default();
let tools_enabled = !tools.is_empty();
let max_tool_depth = request.max_tool_depth;
let mut messages = message_builder.build_messages(&request).await?;
let room_ai_config = room_ai::Entity::find()
.filter(room_ai::Column::Room.eq(request.room.id))
.filter(room_ai::Column::Model.eq(request.model.id))
.one(&request.db)
.await?;
let model_name = request.model.name.clone();
let temperature = room_ai_config
.as_ref()
.and_then(|r| r.temperature.map(|v| v as f32))
.unwrap_or(request.temperature as f32);
let max_tokens = room_ai_config
.as_ref()
.and_then(|r| r.max_tokens.map(|v| v as u32))
.unwrap_or(request.max_tokens as u32);
let mut tool_depth = 0;
let mut input_tokens = 0i64;
let mut output_tokens = 0i64;
let session_id = Uuid::now_v7();
let session_start = std::time::Instant::now();
let version_id = room_ai_config.as_ref().and_then(|r| r.version);
let config = AiClientConfig::new(ai_api_key.unwrap_or_default())
.with_base_url(ai_base_url.unwrap_or_else(|| "https://api.openai.com".into()));
loop {
let response = crate::client::call_with_params(
&messages,
&model_name,
&config,
temperature,
max_tokens,
None,
if tools_enabled { Some(&tools) } else { None },
if tools_enabled { None } else { Some("none") },
)
.await?;
let text = response.content.clone();
input_tokens += response.input_tokens;
output_tokens += response.output_tokens;
if tools_enabled && !response.tool_calls_finished.is_empty() {
let tool_call_messages: Vec<_> = response
.tool_calls_finished
.iter()
.map(|name| ToolCall {
id: Uuid::new_v4().to_string(),
type_: "function".into(),
function: crate::client::types::ToolCallFunction {
name: name.clone(),
arguments: "{}".into(),
},
})
.collect();
messages.push(ChatRequestMessage::assistant(
Some(text.clone()),
Some(tool_call_messages.clone()),
));
let calls: Vec<AgentToolCall> = tool_call_messages
.into_iter()
.map(|tc| AgentToolCall {
id: tc.id.clone(),
name: tc.function.name.clone(),
arguments: tc.function.arguments.clone(),
})
.collect();
let tool_messages = execute_tools(
&request,
&calls,
session_id,
&response.tool_calls_finished,
tool_registry,
message_builder,
)
.await;
messages.extend(tool_messages);
inject_passive_skills(
&request,
message_builder,
&response.tool_calls_finished,
&mut messages,
)
.await;
tool_depth += 1;
if tool_depth >= max_tool_depth {
let content = if text.is_empty() {
format!(
"[AI reached maximum tool depth ({}) — no final answer produced]",
max_tool_depth
)
} else {
text
};
record_ai_session(
&request.cache,
&request.db,
request.project.id,
request.sender.uid,
session_id,
request.room.id,
request.model.id,
version_id.unwrap_or_default(),
input_tokens,
output_tokens,
session_start.elapsed().as_millis() as i64,
)
.await;
return Ok(ProcessResult {
content,
input_tokens,
output_tokens,
});
}
continue;
}
record_ai_session(
&request.cache,
&request.db,
request.project.id,
request.sender.uid,
session_id,
request.room.id,
request.model.id,
version_id.unwrap_or_default(),
input_tokens,
output_tokens,
session_start.elapsed().as_millis() as i64,
)
.await;
return Ok(ProcessResult {
content: text,
input_tokens,
output_tokens,
});
}
}
/// Executes one batch of tool calls and converts the outcome into tool-role messages.
///
/// A fresh `ToolContext` is built for the request's room, sender and project. It
/// carries the message builder's embed service when one is set, and the entries
/// of `tool_registry` when given. Each call is recorded through a
/// `ToolCallRecorder` bound to `session_id`.
///
/// `tool_names` is expected to be parallel to `calls` (same length and order);
/// it supplies the names written to the call records.
///
/// On success the messages come from `ToolExecutor::to_tool_messages`. If the
/// batch fails as a whole, every call is recorded as failed. One error message
/// is then returned per entry in `calls`, each tagged with that call's `id` so it
/// pairs with the assistant tool call that requested it.
async fn execute_tools(
    request: &AiRequest,
    calls: &[AgentToolCall],
    session_id: Uuid,
    tool_names: &[String],
    tool_registry: &Option<crate::tool::registry::ToolRegistry>,
    message_builder: &MessageBuilder,
) -> Vec<ChatRequestMessage> {
    let mut ctx = ToolContext::new(
        request.db.clone(),
        request.cache.clone(),
        request.config.clone(),
        request.room.id,
        Some(request.sender.uid),
    )
    .with_project(request.project.id);
    if let Some(ref es) = message_builder.embed_service {
        ctx = ctx.with_embed_service(es.clone());
    }
    if let Some(registry) = tool_registry {
        ctx.registry_mut().merge(registry.clone());
    }

    let recorder =
        crate::tool::recorder::ToolCallRecorder::with_session(request.db.clone(), session_id);
    let start = std::time::Instant::now();
    let executor = ToolExecutor::new();
    let batch = executor.execute_batch(calls.to_vec(), &mut ctx).await;
    // The calls run as one batch, so only the total batch duration is known; it
    // is attributed to every call in the batch.
    let elapsed = start.elapsed().as_millis() as i64;

    match batch {
        Ok(results) => {
            for (tool_name, result) in tool_names.iter().zip(&results) {
                let error_message = match &result.result {
                    crate::tool::ToolResult::Error(msg) => Some(msg.clone()),
                    _ => None,
                };
                let status = if error_message.is_some() {
                    models::ai::ToolCallStatus::Failed
                } else {
                    models::ai::ToolCallStatus::Success
                };
                recorder.record(crate::tool::recorder::ToolCallRecord {
                    tool_call_id: Uuid::new_v4().to_string(),
                    session_id: recorder.session_id(),
                    tool_name: tool_name.clone(),
                    caller: request.sender.uid,
                    arguments: serde_json::Value::Null,
                    status,
                    execution_time_ms: Some(elapsed),
                    error_message,
                    error_stack: None,
                    retry_count: 0,
                });
            }
            ToolExecutor::to_tool_messages(&results)
        }
        Err(e) => {
            let reason = e.to_string();
            for tool_name in tool_names {
                recorder.record(crate::tool::recorder::ToolCallRecord {
                    tool_call_id: Uuid::new_v4().to_string(),
                    session_id: recorder.session_id(),
                    tool_name: tool_name.clone(),
                    caller: request.sender.uid,
                    arguments: serde_json::Value::Null,
                    status: models::ai::ToolCallStatus::Failed,
                    execution_time_ms: Some(elapsed),
                    error_message: Some(reason.clone()),
                    error_stack: None,
                    retry_count: 0,
                });
            }
            let err_msg = format!("[Tool call failed: {}]", e);
            // Reuse each call's id: a tool message must reference the id of the
            // assistant tool call it answers, so a fresh random id cannot be paired.
            calls
                .iter()
                .map(|call| ChatRequestMessage::tool(call.id.clone(), &err_msg))
                .collect()
        }
    }
}
/// Appends system messages for project skills passively triggered by `tool_names`.
///
/// Loads the enabled skills of the request's project and builds one
/// `ToolCallEvent` (with empty arguments) per tool name. Each event is passed to
/// the message builder's passive detector, and every context it returns is
/// pushed onto `messages` as a system message, in `tool_names` order.
///
/// This is best-effort: if the skill query fails, nothing is appended.
async fn inject_passive_skills(
    request: &AiRequest,
    message_builder: &MessageBuilder,
    tool_names: &[String],
    messages: &mut Vec<ChatRequestMessage>,
) {
    let Ok(enabled_skills) = project_skill::Entity::find()
        .filter(project_skill::Column::ProjectUuid.eq(request.project.id))
        .filter(project_skill::Column::Enabled.eq(true))
        .all(&request.db)
        .await
    else {
        // Skill injection is optional context; a failed lookup is not fatal.
        return;
    };

    let catalogue: Vec<SkillEntry> = enabled_skills
        .into_iter()
        .map(|skill| SkillEntry {
            slug: skill.slug,
            name: skill.name,
            description: skill.description,
            content: skill.content,
        })
        .collect();

    let detector = &message_builder.perception_service.passive;
    for tool_name in tool_names {
        let event = ToolCallEvent {
            tool_name: tool_name.clone(),
            arguments: String::new(),
        };
        if let Some(skill_ctx) = detector.detect(&event, &catalogue) {
            messages.push(skill_ctx.to_system_message());
        }
    }
}