// Change summary:
// - Extract the `push_unique_skill_context` method into `MessageBuilder`.
// - Merge built-in skills with DB skills in passive injection.
// - Simplify the code structure for both streaming and non-streaming execution.
293 lines
10 KiB
Rust
293 lines
10 KiB
Rust
use models::projects::project_skill;
|
|
use models::rooms::room_ai;
|
|
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
|
|
use uuid::Uuid;
|
|
|
|
use super::AiRequest;
|
|
use super::message_builder::MessageBuilder;
|
|
use super::service::ProcessResult;
|
|
use super::session_recording::record_ai_session;
|
|
use crate::client::AiClientConfig;
|
|
use crate::client::types::ChatRequestMessage;
|
|
use crate::error::Result;
|
|
use crate::perception::{SkillEntry, ToolCallEvent};
|
|
use crate::tool::{ToolCall as AgentToolCall, ToolContext, ToolExecutor};
|
|
|
|
pub async fn execute_process(
|
|
request: AiRequest,
|
|
message_builder: &MessageBuilder,
|
|
tool_registry: &Option<crate::tool::registry::ToolRegistry>,
|
|
ai_base_url: Option<String>,
|
|
ai_api_key: Option<String>,
|
|
) -> Result<ProcessResult> {
|
|
let tools: Vec<serde_json::Value> = request.tools.clone().unwrap_or_default();
|
|
let tools_enabled = !tools.is_empty();
|
|
let max_tool_depth = request.max_tool_depth;
|
|
|
|
let mut messages = message_builder.build_messages(&request).await?;
|
|
|
|
let room_ai_config = room_ai::Entity::find()
|
|
.filter(room_ai::Column::Room.eq(request.room.id))
|
|
.filter(room_ai::Column::Model.eq(request.model.id))
|
|
.one(&request.db)
|
|
.await?;
|
|
|
|
let model_name = request.model.name.clone();
|
|
let profile = request.execution_profile.as_ref();
|
|
let temperature = profile
|
|
.and_then(|p| p.temperature.map(|v| v as f32))
|
|
.or_else(|| {
|
|
room_ai_config
|
|
.as_ref()
|
|
.and_then(|r| r.temperature.map(|v| v as f32))
|
|
})
|
|
.unwrap_or(request.temperature as f32);
|
|
let max_tokens = profile
|
|
.and_then(|p| p.max_tokens.map(|v| v as u32))
|
|
.or_else(|| {
|
|
room_ai_config
|
|
.as_ref()
|
|
.and_then(|r| r.max_tokens.map(|v| v as u32))
|
|
})
|
|
.unwrap_or(request.max_tokens as u32);
|
|
let mut tool_depth = 0;
|
|
let mut input_tokens = 0i64;
|
|
let mut output_tokens = 0i64;
|
|
let session_id = Uuid::now_v7();
|
|
let session_start = std::time::Instant::now();
|
|
let version_id = room_ai_config.as_ref().and_then(|r| r.version);
|
|
|
|
let config = AiClientConfig::new(ai_api_key.unwrap_or_default())
|
|
.with_base_url(ai_base_url.unwrap_or_else(|| "https://api.openai.com".into()));
|
|
|
|
loop {
|
|
let response = crate::client::call_with_params(
|
|
&messages,
|
|
&model_name,
|
|
&config,
|
|
temperature,
|
|
max_tokens,
|
|
None,
|
|
if tools_enabled { Some(&tools) } else { None },
|
|
if tools_enabled { None } else { Some("none") },
|
|
)
|
|
.await?;
|
|
|
|
let text = response.content.clone();
|
|
input_tokens += response.input_tokens;
|
|
output_tokens += response.output_tokens;
|
|
|
|
if tools_enabled && !response.tool_calls.is_empty() {
|
|
messages.push(ChatRequestMessage::assistant(
|
|
Some(text.clone()),
|
|
Some(response.tool_calls.clone()),
|
|
));
|
|
|
|
let calls: Vec<AgentToolCall> = response
|
|
.tool_calls
|
|
.iter()
|
|
.map(|tc| AgentToolCall {
|
|
id: tc.id.clone(),
|
|
name: tc.function.name.clone(),
|
|
arguments: tc.function.arguments.clone(),
|
|
})
|
|
.collect();
|
|
let tool_names: Vec<String> = calls.iter().map(|call| call.name.clone()).collect();
|
|
|
|
let tool_messages =
|
|
execute_tools(&request, &calls, session_id, tool_registry, message_builder).await;
|
|
messages.extend(tool_messages);
|
|
inject_passive_skills(&request, message_builder, &tool_names, &mut messages).await;
|
|
|
|
tool_depth += 1;
|
|
if tool_depth >= max_tool_depth {
|
|
let content = if text.is_empty() {
|
|
format!(
|
|
"[AI reached maximum tool depth ({}) — no final answer produced]",
|
|
max_tool_depth
|
|
)
|
|
} else {
|
|
text
|
|
};
|
|
record_ai_session(
|
|
&request.cache,
|
|
&request.db,
|
|
request.project.id,
|
|
request.sender.uid,
|
|
session_id,
|
|
request.room.id,
|
|
request.model.id,
|
|
version_id.unwrap_or_default(),
|
|
input_tokens,
|
|
output_tokens,
|
|
session_start.elapsed().as_millis() as i64,
|
|
)
|
|
.await;
|
|
return Ok(ProcessResult {
|
|
content,
|
|
input_tokens,
|
|
output_tokens,
|
|
});
|
|
}
|
|
continue;
|
|
}
|
|
|
|
record_ai_session(
|
|
&request.cache,
|
|
&request.db,
|
|
request.project.id,
|
|
request.sender.uid,
|
|
session_id,
|
|
request.room.id,
|
|
request.model.id,
|
|
version_id.unwrap_or_default(),
|
|
input_tokens,
|
|
output_tokens,
|
|
session_start.elapsed().as_millis() as i64,
|
|
)
|
|
.await;
|
|
return Ok(ProcessResult {
|
|
content: text,
|
|
input_tokens,
|
|
output_tokens,
|
|
});
|
|
}
|
|
}
|
|
|
|
async fn execute_tools(
|
|
request: &AiRequest,
|
|
calls: &[AgentToolCall],
|
|
session_id: Uuid,
|
|
tool_registry: &Option<crate::tool::registry::ToolRegistry>,
|
|
message_builder: &MessageBuilder,
|
|
) -> Vec<ChatRequestMessage> {
|
|
let mut ctx = ToolContext::new(
|
|
request.db.clone(),
|
|
request.cache.clone(),
|
|
request.config.clone(),
|
|
request.room.id,
|
|
Some(request.sender.uid),
|
|
)
|
|
.with_project(request.project.id);
|
|
if let Some(ref es) = message_builder.embed_service {
|
|
ctx = ctx.with_embed_service(es.clone());
|
|
}
|
|
if let Some(registry) = tool_registry {
|
|
ctx.registry_mut().merge(registry.clone());
|
|
}
|
|
|
|
let recorder =
|
|
crate::tool::recorder::ToolCallRecorder::with_session(request.db.clone(), session_id);
|
|
let start = std::time::Instant::now();
|
|
let executor = ToolExecutor::new();
|
|
match executor.execute_batch(calls.to_vec(), &mut ctx).await {
|
|
Ok(results) => {
|
|
for (call, result) in calls.iter().zip(results.iter()) {
|
|
let elapsed = start.elapsed().as_millis() as i64;
|
|
let is_error = matches!(result.result, crate::tool::ToolResult::Error(_));
|
|
let error_msg = match &result.result {
|
|
crate::tool::ToolResult::Error(msg) => Some(msg.clone()),
|
|
_ => None,
|
|
};
|
|
recorder.record(crate::tool::recorder::ToolCallRecord {
|
|
tool_call_id: Uuid::new_v4().to_string(),
|
|
session_id: recorder.session_id(),
|
|
tool_name: call.name.clone(),
|
|
caller: request.sender.uid,
|
|
arguments: call
|
|
.arguments_json()
|
|
.unwrap_or_else(|_| serde_json::Value::Null),
|
|
status: if is_error {
|
|
models::ai::ToolCallStatus::Failed
|
|
} else {
|
|
models::ai::ToolCallStatus::Success
|
|
},
|
|
execution_time_ms: Some(elapsed),
|
|
error_message: error_msg,
|
|
error_stack: None,
|
|
retry_count: 0,
|
|
});
|
|
}
|
|
crate::tool::ToolExecutor::to_tool_messages(&results)
|
|
}
|
|
Err(e) => {
|
|
let elapsed = start.elapsed().as_millis() as i64;
|
|
for call in calls {
|
|
recorder.record(crate::tool::recorder::ToolCallRecord {
|
|
tool_call_id: Uuid::new_v4().to_string(),
|
|
session_id: recorder.session_id(),
|
|
tool_name: call.name.clone(),
|
|
caller: request.sender.uid,
|
|
arguments: call
|
|
.arguments_json()
|
|
.unwrap_or_else(|_| serde_json::Value::Null),
|
|
status: models::ai::ToolCallStatus::Failed,
|
|
execution_time_ms: Some(elapsed),
|
|
error_message: Some(e.to_string()),
|
|
error_stack: None,
|
|
retry_count: 0,
|
|
});
|
|
}
|
|
let err_msg = format!("[Tool call failed: {}]", e);
|
|
calls
|
|
.iter()
|
|
.map(|_| ChatRequestMessage::tool(Uuid::new_v4().to_string(), &err_msg))
|
|
.collect()
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn inject_passive_skills(
|
|
request: &AiRequest,
|
|
message_builder: &MessageBuilder,
|
|
tool_names: &[String],
|
|
messages: &mut Vec<ChatRequestMessage>,
|
|
) {
|
|
if let Ok(skills) = project_skill::Entity::find()
|
|
.filter(project_skill::Column::ProjectUuid.eq(request.project.id))
|
|
.filter(project_skill::Column::Enabled.eq(true))
|
|
.all(&request.db)
|
|
.await
|
|
{
|
|
let mut skill_entries: Vec<SkillEntry> = skills
|
|
.into_iter()
|
|
.map(|s| SkillEntry {
|
|
slug: s.slug,
|
|
name: s.name,
|
|
description: s.description,
|
|
content: s.content,
|
|
})
|
|
.collect();
|
|
for built_in in crate::skills::all_skills() {
|
|
if !skill_entries.iter().any(|s| s.slug == built_in.slug) {
|
|
skill_entries.push(SkillEntry {
|
|
slug: built_in.slug.to_string(),
|
|
name: built_in.name.to_string(),
|
|
description: Some(built_in.description.to_string()),
|
|
content: built_in.content.clone(),
|
|
});
|
|
}
|
|
}
|
|
let tool_events: Vec<ToolCallEvent> = tool_names
|
|
.iter()
|
|
.map(|name| ToolCallEvent {
|
|
tool_name: name.clone(),
|
|
arguments: String::new(),
|
|
})
|
|
.collect();
|
|
let mut contexts = Vec::new();
|
|
for event in &tool_events {
|
|
if let Some(ctx) = message_builder
|
|
.perception_service
|
|
.passive
|
|
.detect(event, &skill_entries)
|
|
{
|
|
MessageBuilder::push_unique_skill_context(&mut contexts, ctx);
|
|
}
|
|
}
|
|
for ctx in contexts {
|
|
messages.push(ctx.to_system_message());
|
|
}
|
|
}
|
|
}
|