2088 lines
84 KiB
Rust
2088 lines
84 KiB
Rust
use futures::StreamExt;
|
|
use models::projects::project_skill;
|
|
use models::rooms::room_ai;
|
|
use rig::agent::{AgentBuilder, MultiTurnStreamItem};
|
|
use rig::client::CompletionClient;
|
|
use rig::completion::{CompletionModel, GetTokenUsage, Prompt};
|
|
use rig::streaming::{StreamedAssistantContent, StreamingPrompt};
|
|
use sea_orm::*;
|
|
use std::pin::Pin;
|
|
use std::sync::Arc;
|
|
use uuid::Uuid;
|
|
|
|
use super::context::RoomMessageContext;
|
|
use super::{AiChunkType, AiRequest, AiStreamChunk, Mention, StreamCallback};
|
|
use crate::billing;
|
|
use crate::client::AiClientConfig;
|
|
|
|
use crate::client::types::{ChatRequestMessage, ToolCall};
|
|
use crate::client::{
|
|
StreamChunk, StreamChunkType, StreamedToolCall, call_stream, call_with_params,
|
|
};
|
|
use crate::compact::{CompactConfig, CompactService};
|
|
use crate::embed::EmbedService;
|
|
use crate::error::{AgentError, Result};
|
|
use crate::perception::{PerceptionService, SkillEntry, ToolCallEvent};
|
|
use crate::react::{DEFAULT_SYSTEM_PROMPT, ReactStep};
|
|
use crate::react::types::Action as ReactAction;
|
|
use crate::tool::{
|
|
RecordingTool, ToolCall as AgentToolCall, ToolContext, ToolExecutor,
|
|
registry::ToolRegistry,
|
|
};
|
|
|
|
/// Result from streaming AI response.
|
|
pub struct StreamResult {
|
|
pub content: String,
|
|
pub reasoning_content: String,
|
|
pub input_tokens: i64,
|
|
pub output_tokens: i64,
|
|
/// All chunks in arrival order — preserves ReAct multi-cycle ordering.
|
|
pub chunks: Vec<StreamChunk>,
|
|
}
|
|
|
|
/// Outcome of a non-streaming AI completion, as returned by `ChatService::process`.
///
/// The type is plain data (a `String` and two counters), so it derives the
/// usual value-type traits to make it easy to log, copy and compare in tests.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ProcessResult {
    /// Final assistant text.
    pub content: String,
    /// Input (prompt) tokens consumed.
    pub input_tokens: i64,
    /// Output (completion) tokens produced.
    pub output_tokens: i64,
}
|
|
|
|
/// Record an AI session with cost calculation.
|
|
async fn record_ai_session(
|
|
cache: &db::cache::AppCache,
|
|
db: &db::database::AppDatabase,
|
|
project_id: Uuid,
|
|
session_id: Uuid,
|
|
room_id: Uuid,
|
|
model_id: Uuid,
|
|
version_id: Uuid,
|
|
input_tokens: i64,
|
|
output_tokens: i64,
|
|
latency_ms: i64,
|
|
) {
|
|
metrics::histogram!("ai_call_latency_ms", "model" => model_id.to_string()).record(latency_ms as f64);
|
|
|
|
let session = models::ai::ai_session::ActiveModel {
|
|
id: Set(session_id),
|
|
room: Set(room_id),
|
|
model: Set(model_id),
|
|
version: Set(version_id),
|
|
token_input: Set(input_tokens),
|
|
token_output: Set(output_tokens),
|
|
latency_ms: Set(Some(latency_ms)),
|
|
cost: Set(None),
|
|
currency: Set(None),
|
|
error_message: Set(None),
|
|
error_code: Set(None),
|
|
created_at: Set(chrono::Utc::now()),
|
|
};
|
|
|
|
if let Err(e) = session.insert(db).await {
|
|
tracing::error!(error = %e, session_id = %session_id, "failed to insert ai session record");
|
|
return;
|
|
}
|
|
|
|
let (cost, currency, error_msg) = match billing::record_ai_usage(
|
|
db,
|
|
project_id,
|
|
version_id,
|
|
input_tokens,
|
|
output_tokens,
|
|
)
|
|
.await
|
|
{
|
|
Ok(billing::BillingResult::Success(record)) => {
|
|
(Some(record.cost), Some(record.currency), None)
|
|
}
|
|
Ok(billing::BillingResult::InsufficientBalance { message }) => {
|
|
create_system_message(cache, db, room_id, &message).await;
|
|
(None, None, Some(message))
|
|
}
|
|
Err(e) => (None, None, Some(e.to_string())),
|
|
};
|
|
|
|
use sea_orm::sea_query::Expr;
|
|
let _ = models::ai::ai_session::Entity::update_many()
|
|
.col_expr(models::ai::ai_session::Column::Cost, Expr::value(cost))
|
|
.col_expr(models::ai::ai_session::Column::Currency, Expr::value(currency))
|
|
.col_expr(models::ai::ai_session::Column::ErrorMessage, Expr::value(error_msg))
|
|
.filter(models::ai::ai_session::Column::Id.eq(session_id))
|
|
.exec(db)
|
|
.await;
|
|
}
|
|
|
|
/// Create a system message in the room for billing errors.
|
|
async fn create_system_message(
|
|
cache: &db::cache::AppCache,
|
|
db: &db::database::AppDatabase,
|
|
room_id: Uuid,
|
|
message: &str,
|
|
) {
|
|
use models::rooms::{room_message, MessageSenderType, MessageContentType};
|
|
use sea_orm::Set;
|
|
|
|
let seq_key = format!("room:seq:{}", room_id);
|
|
let seq = match cache.conn().await {
|
|
Ok(mut conn) => {
|
|
match redis::cmd("INCR").arg(&seq_key).query_async::<i64>(&mut conn).await {
|
|
Ok(s) => s,
|
|
Err(e) => {
|
|
tracing::warn!(error = %e, "cache INCR failed for system message seq, falling back to DB");
|
|
let last_seq = match room_message::Entity::find()
|
|
.filter(room_message::Column::Room.eq(room_id))
|
|
.order_by_desc(room_message::Column::Seq)
|
|
.one(db)
|
|
.await
|
|
{
|
|
Ok(Some(m)) => m.seq,
|
|
Ok(None) => 0,
|
|
Err(e) => {
|
|
tracing::warn!(error = %e, "Failed to get last seq for system message");
|
|
return;
|
|
}
|
|
};
|
|
last_seq + 1
|
|
}
|
|
}
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!(error = %e, "Failed to get Redis connection for system message seq");
|
|
return;
|
|
}
|
|
};
|
|
|
|
let now = chrono::Utc::now();
|
|
|
|
let result = room_message::ActiveModel {
|
|
id: Set(Uuid::now_v7()),
|
|
seq: Set(seq),
|
|
room: Set(room_id),
|
|
sender_type: Set(MessageSenderType::System),
|
|
sender_id: Set(None),
|
|
model_id: Set(None),
|
|
thread: Set(None),
|
|
in_reply_to: Set(None),
|
|
content: Set(message.to_string()),
|
|
content_type: Set(MessageContentType::Text),
|
|
thinking_content: Set(None),
|
|
edited_at: Set(None),
|
|
send_at: Set(now),
|
|
revoked: Set(None),
|
|
revoked_by: Set(None),
|
|
}
|
|
.insert(db)
|
|
.await;
|
|
|
|
match result {
|
|
Ok(_) => {
|
|
tracing::info!(
|
|
room_id = %room_id,
|
|
message = %message,
|
|
"system_message_created_for_billing_error"
|
|
);
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!(
|
|
error = %e,
|
|
room_id = %room_id,
|
|
"Failed to create system message for billing error"
|
|
);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Service for handling AI chat requests in rooms.
|
|
pub struct ChatService {
|
|
ai_base_url: Option<String>,
|
|
ai_api_key: Option<String>,
|
|
compact_service: Option<CompactService>,
|
|
embed_service: Option<EmbedService>,
|
|
perception_service: PerceptionService,
|
|
tool_registry: Option<ToolRegistry>,
|
|
}
|
|
|
|
impl ChatService {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
ai_base_url: None,
|
|
ai_api_key: None,
|
|
compact_service: None,
|
|
embed_service: None,
|
|
perception_service: PerceptionService::default(),
|
|
tool_registry: None,
|
|
}
|
|
}
|
|
|
|
/// Take the AI endpoint settings (base URL and API key) from `config`.
///
/// `config` is owned, so its fields are moved into the service rather than
/// cloned.
pub fn with_ai_client_config(mut self, config: AiClientConfig) -> Self {
    self.ai_base_url = config.base_url;
    self.ai_api_key = Some(config.api_key);
    self
}
|
|
|
|
pub fn with_compact_service(mut self, compact_service: CompactService) -> Self {
|
|
self.compact_service = Some(compact_service);
|
|
self
|
|
}
|
|
|
|
pub fn with_embed_service(mut self, embed_service: EmbedService) -> Self {
|
|
self.embed_service = Some(embed_service);
|
|
self
|
|
}
|
|
|
|
pub fn with_perception_service(mut self, perception_service: PerceptionService) -> Self {
|
|
self.perception_service = perception_service;
|
|
self
|
|
}
|
|
|
|
pub fn with_tool_registry(mut self, registry: ToolRegistry) -> Self {
|
|
self.tool_registry = Some(registry);
|
|
self
|
|
}
|
|
|
|
/// Returns all registered tools as JSON tool definitions.
|
|
pub fn tools(&self) -> Vec<serde_json::Value> {
|
|
self.tool_registry
|
|
.as_ref()
|
|
.map(|r| r.to_openai_tools())
|
|
.unwrap_or_default()
|
|
}
|
|
|
|
/// Build a `RigToolSet` from the configured tool registry.
///
/// Lets the same tools be used with `RigAgentService` via rig's native Agent.
/// The supplied context (db, cache, config, room_id, sender_id, project_id)
/// is forwarded to `RigToolSet::from_registry`.
///
/// Returns `None` when no tool registry has been configured.
#[cfg(feature = "rig")]
pub fn rig_toolset(
    &self,
    db: db::database::AppDatabase,
    cache: db::cache::AppCache,
    config: config::AppConfig,
    room_id: uuid::Uuid,
    sender_id: Option<uuid::Uuid>,
    project_id: uuid::Uuid,
) -> Option<crate::RigToolSet> {
    let registry = self.tool_registry.as_ref()?;
    let toolset = crate::RigToolSet::from_registry(
        registry, db, cache, config, room_id, sender_id, project_id,
    );
    Some(toolset)
}
|
|
|
|
/// Get a reference to the underlying ToolRegistry.
|
|
pub fn tool_registry(&self) -> Option<&ToolRegistry> {
|
|
self.tool_registry.as_ref()
|
|
}
|
|
|
|
pub async fn process(&self, request: AiRequest) -> Result<ProcessResult> {
|
|
let tools: Vec<serde_json::Value> = request.tools.clone().unwrap_or_default();
|
|
let tools_enabled = !tools.is_empty();
|
|
let max_tool_depth = request.max_tool_depth;
|
|
|
|
let mut messages = self.build_messages(&request).await?;
|
|
|
|
let room_ai = room_ai::Entity::find()
|
|
.filter(room_ai::Column::Room.eq(request.room.id))
|
|
.filter(room_ai::Column::Model.eq(request.model.id))
|
|
.one(&request.db)
|
|
.await?;
|
|
|
|
let model_name = request.model.name.clone();
|
|
let temperature = room_ai
|
|
.as_ref()
|
|
.and_then(|r| r.temperature.map(|v| v as f32))
|
|
.unwrap_or(request.temperature as f32);
|
|
let max_tokens = room_ai
|
|
.as_ref()
|
|
.and_then(|r| r.max_tokens.map(|v| v as u32))
|
|
.unwrap_or(request.max_tokens as u32);
|
|
let mut tool_depth = 0;
|
|
let mut input_tokens = 0i64;
|
|
let mut output_tokens = 0i64;
|
|
let session_id = Uuid::now_v7();
|
|
let session_start = std::time::Instant::now();
|
|
let version_id = room_ai.as_ref().and_then(|r| r.version);
|
|
|
|
let config = AiClientConfig::new(self.ai_api_key.clone().unwrap_or_default())
|
|
.with_base_url(
|
|
self.ai_base_url
|
|
.clone()
|
|
.unwrap_or_else(|| "https://api.openai.com".into()),
|
|
);
|
|
|
|
loop {
|
|
let response = call_with_params(
|
|
&messages,
|
|
&model_name,
|
|
&config,
|
|
temperature,
|
|
max_tokens,
|
|
None,
|
|
if tools_enabled { Some(&tools) } else { None },
|
|
if tools_enabled { None } else { Some("none") },
|
|
)
|
|
.await?;
|
|
|
|
let text = response.content.clone();
|
|
input_tokens += response.input_tokens;
|
|
output_tokens += response.output_tokens;
|
|
|
|
if tools_enabled && !response.tool_calls_finished.is_empty() {
|
|
// Build assistant message with tool_calls
|
|
let tool_call_messages: Vec<_> = response
|
|
.tool_calls_finished
|
|
.iter()
|
|
.map(|name| {
|
|
// We need ID and arguments — for non-streaming we reconstruct from content
|
|
// The model returns tool_calls in its content; for now we create a placeholder
|
|
// that will be replaced by actual tool results
|
|
ToolCall {
|
|
id: Uuid::new_v4().to_string(),
|
|
type_: "function".into(),
|
|
function: crate::client::types::ToolCallFunction {
|
|
name: name.clone(),
|
|
arguments: "{}".into(),
|
|
},
|
|
}
|
|
})
|
|
.collect();
|
|
|
|
messages.push(ChatRequestMessage::assistant(
|
|
Some(text.clone()),
|
|
Some(tool_call_messages.clone()),
|
|
));
|
|
|
|
// Create ToolCall list for executor (we need real IDs and args)
|
|
// Since we can't get args from streaming, use name matching from the text
|
|
let calls: Vec<AgentToolCall> = tool_call_messages
|
|
.into_iter()
|
|
.map(|tc| AgentToolCall {
|
|
id: tc.id.clone(),
|
|
name: tc.function.name.clone(),
|
|
arguments: tc.function.arguments.clone(),
|
|
})
|
|
.collect();
|
|
|
|
let tool_messages = {
|
|
let mut ctx = ToolContext::new(
|
|
request.db.clone(),
|
|
request.cache.clone(),
|
|
request.config.clone(),
|
|
request.room.id,
|
|
Some(request.sender.uid),
|
|
)
|
|
.with_project(request.project.id);
|
|
if let Some(ref es) = self.embed_service {
|
|
ctx = ctx.with_embed_service(es.clone());
|
|
}
|
|
if let Some(ref registry) = self.tool_registry {
|
|
ctx.registry_mut().merge(registry.clone());
|
|
}
|
|
|
|
let recorder = crate::tool::recorder::ToolCallRecorder::with_session(
|
|
request.db.clone(),
|
|
session_id,
|
|
);
|
|
let start = std::time::Instant::now();
|
|
|
|
let executor = ToolExecutor::new();
|
|
match executor.execute_batch(calls, &mut ctx).await {
|
|
Ok(results) => {
|
|
for (call, result) in
|
|
response.tool_calls_finished.iter().zip(results.iter())
|
|
{
|
|
let elapsed = start.elapsed().as_millis() as i64;
|
|
let is_error =
|
|
matches!(result.result, crate::tool::ToolResult::Error(_));
|
|
let error_msg = match &result.result {
|
|
crate::tool::ToolResult::Error(msg) => Some(msg.clone()),
|
|
_ => None,
|
|
};
|
|
recorder.record(crate::tool::recorder::ToolCallRecord {
|
|
tool_call_id: call.clone(),
|
|
session_id: recorder.session_id(),
|
|
tool_name: call.clone(),
|
|
caller: request.sender.uid,
|
|
arguments: serde_json::Value::Null,
|
|
status: if is_error {
|
|
models::ai::ToolCallStatus::Failed
|
|
} else {
|
|
models::ai::ToolCallStatus::Success
|
|
},
|
|
execution_time_ms: Some(elapsed),
|
|
error_message: error_msg,
|
|
error_stack: None,
|
|
retry_count: 0,
|
|
});
|
|
}
|
|
crate::tool::ToolExecutor::to_tool_messages(&results)
|
|
}
|
|
Err(e) => {
|
|
let elapsed = start.elapsed().as_millis() as i64;
|
|
for call_name in &response.tool_calls_finished {
|
|
recorder.record(crate::tool::recorder::ToolCallRecord {
|
|
tool_call_id: Uuid::new_v4().to_string(),
|
|
session_id: recorder.session_id(),
|
|
tool_name: call_name.clone(),
|
|
caller: request.sender.uid,
|
|
arguments: serde_json::Value::Null,
|
|
status: models::ai::ToolCallStatus::Failed,
|
|
execution_time_ms: Some(elapsed),
|
|
error_message: Some(e.to_string()),
|
|
error_stack: None,
|
|
retry_count: 0,
|
|
});
|
|
}
|
|
|
|
let err_msg = format!("[Tool call failed: {}]", e);
|
|
response
|
|
.tool_calls_finished
|
|
.iter()
|
|
.map(|_| {
|
|
ChatRequestMessage::tool(Uuid::new_v4().to_string(), &err_msg)
|
|
})
|
|
.collect()
|
|
}
|
|
}
|
|
};
|
|
messages.extend(tool_messages);
|
|
|
|
// Inject passive-detected skills based on tool calls
|
|
if let Ok(skills) = project_skill::Entity::find()
|
|
.filter(project_skill::Column::ProjectUuid.eq(request.project.id))
|
|
.filter(project_skill::Column::Enabled.eq(true))
|
|
.all(&request.db)
|
|
.await
|
|
{
|
|
let skill_entries: Vec<SkillEntry> = skills
|
|
.into_iter()
|
|
.map(|s| SkillEntry {
|
|
slug: s.slug,
|
|
name: s.name,
|
|
description: s.description,
|
|
content: s.content,
|
|
})
|
|
.collect();
|
|
let tool_events: Vec<ToolCallEvent> = response
|
|
.tool_calls_finished
|
|
.iter()
|
|
.map(|name| ToolCallEvent {
|
|
tool_name: name.clone(),
|
|
arguments: String::new(),
|
|
})
|
|
.collect();
|
|
for event in &tool_events {
|
|
if let Some(ctx) = self
|
|
.perception_service
|
|
.passive
|
|
.detect(event, &skill_entries)
|
|
{
|
|
messages.push(ctx.to_system_message());
|
|
}
|
|
}
|
|
}
|
|
|
|
tool_depth += 1;
|
|
if tool_depth >= max_tool_depth {
|
|
let content = if text.is_empty() {
|
|
format!(
|
|
"[AI reached maximum tool depth ({}) — no final answer produced]",
|
|
max_tool_depth
|
|
)
|
|
} else {
|
|
text
|
|
};
|
|
// Record session
|
|
record_ai_session(
|
|
&request.cache,
|
|
&request.db,
|
|
request.project.id,
|
|
session_id,
|
|
request.room.id,
|
|
request.model.id,
|
|
version_id.unwrap_or_default(),
|
|
input_tokens,
|
|
output_tokens,
|
|
session_start.elapsed().as_millis() as i64,
|
|
)
|
|
.await;
|
|
return Ok(ProcessResult {
|
|
content,
|
|
input_tokens,
|
|
output_tokens,
|
|
});
|
|
}
|
|
continue;
|
|
}
|
|
|
|
// Record session
|
|
record_ai_session(
|
|
&request.cache,
|
|
&request.db,
|
|
request.project.id,
|
|
session_id,
|
|
request.room.id,
|
|
request.model.id,
|
|
version_id.unwrap_or_default(),
|
|
input_tokens,
|
|
output_tokens,
|
|
session_start.elapsed().as_millis() as i64,
|
|
)
|
|
.await;
|
|
return Ok(ProcessResult {
|
|
content: text,
|
|
input_tokens,
|
|
output_tokens,
|
|
});
|
|
}
|
|
}
|
|
|
|
pub async fn process_stream(
|
|
&self,
|
|
request: AiRequest,
|
|
on_chunk: StreamCallback,
|
|
) -> Result<StreamResult> {
|
|
// Wrap on_chunk in Arc so it can be shared across loop iterations
|
|
let on_chunk = Arc::new(on_chunk);
|
|
let tools: Vec<serde_json::Value> = request.tools.clone().unwrap_or_default();
|
|
let tools_enabled = !tools.is_empty();
|
|
let max_tool_depth = request.max_tool_depth;
|
|
|
|
let mut messages = self.build_messages(&request).await?;
|
|
|
|
let room_ai = room_ai::Entity::find()
|
|
.filter(room_ai::Column::Room.eq(request.room.id))
|
|
.filter(room_ai::Column::Model.eq(request.model.id))
|
|
.one(&request.db)
|
|
.await?;
|
|
|
|
let model_name = request.model.name.clone();
|
|
let temperature = room_ai
|
|
.as_ref()
|
|
.and_then(|r| r.temperature.map(|v| v as f32))
|
|
.unwrap_or(request.temperature as f32);
|
|
let max_tokens = room_ai
|
|
.as_ref()
|
|
.and_then(|r| r.max_tokens.map(|v| v as u32))
|
|
.unwrap_or(request.max_tokens as u32);
|
|
let mut tool_depth = 0;
|
|
let mut total_input_tokens = 0i64;
|
|
let mut total_output_tokens = 0i64;
|
|
let session_id = Uuid::now_v7();
|
|
let session_start = std::time::Instant::now();
|
|
|
|
let version_id = room_ai.as_ref().and_then(|r| r.version);
|
|
|
|
let config = AiClientConfig::new(self.ai_api_key.clone().unwrap_or_default())
|
|
.with_base_url(
|
|
self.ai_base_url
|
|
.clone()
|
|
.unwrap_or_else(|| "https://api.openai.com".into()),
|
|
);
|
|
|
|
let mut full_content = String::new();
|
|
let mut all_chunks: Vec<StreamChunk> = Vec::new();
|
|
// Collect tool calls during streaming, push them incrementally after.
|
|
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<StreamedToolCall>();
|
|
|
|
loop {
|
|
let on_chunk_cb = on_chunk.clone();
|
|
let on_chunk_cb2 = on_chunk_cb.clone();
|
|
let tx_arc = Arc::new(tx.clone());
|
|
let tx_arc2 = tx_arc.clone();
|
|
let response = call_stream(
|
|
&messages,
|
|
&model_name,
|
|
&config,
|
|
temperature,
|
|
max_tokens,
|
|
if tools_enabled { Some(&tools) } else { None },
|
|
None, // tool_choice — auto (let model decide)
|
|
Arc::new(move |delta| {
|
|
let fut = on_chunk_cb(AiStreamChunk {
|
|
content: delta.to_string(),
|
|
done: false,
|
|
chunk_type: AiChunkType::Answer,
|
|
});
|
|
fut
|
|
}),
|
|
Arc::new(move |delta| {
|
|
let fut = on_chunk_cb2(AiStreamChunk {
|
|
content: delta.to_string(),
|
|
done: false,
|
|
chunk_type: AiChunkType::Thinking,
|
|
});
|
|
fut
|
|
}),
|
|
Arc::new(move |tc: &StreamedToolCall| {
|
|
let tx = tx_arc2.clone();
|
|
let tc_owned = tc.clone();
|
|
Box::pin(async move {
|
|
let _ = tx.send(tc_owned);
|
|
}) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
|
|
}),
|
|
)
|
|
.await?;
|
|
|
|
total_input_tokens += response.input_tokens;
|
|
total_output_tokens += response.output_tokens;
|
|
|
|
// Collect chunks from this streaming iteration in order.
|
|
all_chunks.extend(response.chunks);
|
|
|
|
let has_tool_calls = tools_enabled && !response.tool_calls.is_empty();
|
|
|
|
if has_tool_calls {
|
|
// Accumulate the assistant's text before tool calls
|
|
full_content.push_str(&response.content);
|
|
full_content.push('\n');
|
|
|
|
// Build assistant message with tool_calls from streaming response
|
|
let tool_calls: Vec<ToolCall> = response
|
|
.tool_calls
|
|
.iter()
|
|
.map(|tc| ToolCall {
|
|
id: tc.id.clone(),
|
|
type_: "function".into(),
|
|
function: crate::client::types::ToolCallFunction {
|
|
name: tc.name.clone(),
|
|
arguments: tc.arguments.clone(),
|
|
},
|
|
})
|
|
.collect();
|
|
|
|
messages.push(ChatRequestMessage::assistant(
|
|
Some(response.content.clone()),
|
|
Some(tool_calls.clone()),
|
|
));
|
|
|
|
// Push each tool call incrementally to frontend.
|
|
// Use try_recv() — tx is never dropped so recv() would deadlock.
|
|
loop {
|
|
match rx.try_recv() {
|
|
Ok(tc) => {
|
|
let args_display = if tc.arguments.len() > 100 {
|
|
let end = tc.arguments.char_indices().map(|(i, _)| i).take_while(|&i| i <= 100).last().unwrap_or(100);
|
|
format!("{}...", &tc.arguments[..end])
|
|
} else {
|
|
tc.arguments.clone()
|
|
};
|
|
let tool_display = format!("🔧 {}({})", tc.name, args_display);
|
|
on_chunk(AiStreamChunk {
|
|
content: tool_display.clone(),
|
|
done: false,
|
|
chunk_type: AiChunkType::ToolCall,
|
|
})
|
|
.await;
|
|
all_chunks.push(StreamChunk {
|
|
chunk_type: StreamChunkType::ToolCall,
|
|
content: tool_display,
|
|
});
|
|
}
|
|
Err(tokio::sync::mpsc::error::TryRecvError::Empty) => break,
|
|
Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => break,
|
|
}
|
|
}
|
|
|
|
// Execute tools one at a time, push each result incrementally
|
|
let calls: Vec<AgentToolCall> = response
|
|
.tool_calls
|
|
.iter()
|
|
.map(|tc| AgentToolCall {
|
|
id: tc.id.clone(),
|
|
name: tc.name.clone(),
|
|
arguments: tc.arguments.clone(),
|
|
})
|
|
.collect();
|
|
|
|
let mut tool_messages = Vec::new();
|
|
let mut ctx = crate::tool::ToolContext::new(
|
|
request.db.clone(),
|
|
request.cache.clone(),
|
|
request.config.clone(),
|
|
request.room.id,
|
|
Some(request.sender.uid),
|
|
)
|
|
.with_project(request.project.id);
|
|
if let Some(ref es) = self.embed_service {
|
|
ctx = ctx.with_embed_service(es.clone());
|
|
}
|
|
if let Some(ref registry) = self.tool_registry {
|
|
ctx.registry_mut().merge(registry.clone());
|
|
}
|
|
|
|
let recorder = crate::tool::recorder::ToolCallRecorder::with_session(
|
|
request.db.clone(),
|
|
session_id,
|
|
);
|
|
|
|
for call in &calls {
|
|
let start = std::time::Instant::now();
|
|
|
|
// Spawn tool execution in a separate task to avoid blocking the
|
|
// tokio worker thread (git2 operations are synchronous).
|
|
// This allows the heartbeat timer to fire independently.
|
|
let call_clone = call.clone();
|
|
let mut ctx_clone = ctx.clone();
|
|
let (result_tx, mut result_rx) = tokio::sync::oneshot::channel();
|
|
tokio::spawn(async move {
|
|
let executor = crate::tool::ToolExecutor::new();
|
|
let res = executor.execute_batch(vec![call_clone], &mut ctx_clone).await;
|
|
let _ = result_tx.send(res);
|
|
});
|
|
|
|
// Send heartbeats every 10s until tool execution completes
|
|
let heartbeat_dur = std::time::Duration::from_secs(10);
|
|
let results = loop {
|
|
tokio::select! {
|
|
res = &mut result_rx => {
|
|
match res {
|
|
Ok(inner) => break inner,
|
|
Err(_) => break Err(crate::tool::ToolError::ExecutionError("tool task cancelled".into())),
|
|
}
|
|
},
|
|
_ = tokio::time::sleep(heartbeat_dur) => {
|
|
on_chunk(AiStreamChunk {
|
|
content: String::new(),
|
|
done: false,
|
|
chunk_type: AiChunkType::ToolCall,
|
|
}).await;
|
|
}
|
|
}
|
|
};
|
|
|
|
let results = match results {
|
|
Ok(r) => r,
|
|
Err(e) => {
|
|
let elapsed = start.elapsed().as_millis() as i64;
|
|
recorder.record(crate::tool::recorder::ToolCallRecord {
|
|
tool_call_id: call.id.clone(),
|
|
session_id: recorder.session_id(),
|
|
tool_name: call.name.clone(),
|
|
caller: request.sender.uid,
|
|
arguments: call.arguments_json().unwrap_or_default(),
|
|
status: models::ai::ToolCallStatus::Failed,
|
|
execution_time_ms: Some(elapsed),
|
|
error_message: Some(e.to_string()),
|
|
error_stack: None,
|
|
retry_count: 0,
|
|
});
|
|
|
|
let err_text = format!("[Tool call failed: {}]", e);
|
|
tracing::warn!(tool = %call.name, args = %call.arguments, error = %e, "tool_call_failed");
|
|
let err_display = format!("❌ {} (failed)", call.name);
|
|
on_chunk(AiStreamChunk {
|
|
content: err_display.clone(),
|
|
done: false,
|
|
chunk_type: AiChunkType::ToolResult,
|
|
})
|
|
.await;
|
|
all_chunks.push(StreamChunk {
|
|
chunk_type: StreamChunkType::ToolCall,
|
|
content: err_display,
|
|
});
|
|
tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text));
|
|
continue;
|
|
}
|
|
};
|
|
|
|
for result in &results {
|
|
let text = match &result.result {
|
|
crate::tool::ToolResult::Ok(v) => v.to_string(),
|
|
crate::tool::ToolResult::Error(msg) => msg.clone(),
|
|
};
|
|
let preview = if text.len() > 300 {
|
|
let end = text
|
|
.char_indices()
|
|
.map(|(i, _)| i)
|
|
.take_while(|&i| i <= 300)
|
|
.last()
|
|
.unwrap_or(300);
|
|
format!("{}...", &text[..end])
|
|
} else {
|
|
text.clone()
|
|
};
|
|
tracing::debug!("tool_result: {} — {}", call.name, preview);
|
|
|
|
let elapsed = start.elapsed().as_millis() as i64;
|
|
let is_error = matches!(result.result, crate::tool::ToolResult::Error(_));
|
|
let error_msg = match &result.result {
|
|
crate::tool::ToolResult::Error(msg) => Some(msg.clone()),
|
|
_ => None,
|
|
};
|
|
recorder.record(crate::tool::recorder::ToolCallRecord {
|
|
tool_call_id: call.id.clone(),
|
|
session_id: recorder.session_id(),
|
|
tool_name: call.name.clone(),
|
|
caller: request.sender.uid,
|
|
arguments: call.arguments_json().unwrap_or_default(),
|
|
status: if is_error {
|
|
models::ai::ToolCallStatus::Failed
|
|
} else {
|
|
models::ai::ToolCallStatus::Success
|
|
},
|
|
execution_time_ms: Some(elapsed),
|
|
error_message: error_msg,
|
|
error_stack: None,
|
|
retry_count: 0,
|
|
});
|
|
// Do NOT emit tool_result chunks to frontend — raw output may contain sensitive data.
|
|
// Log server-side only; frontend sees tool_call status via on_chunk below.
|
|
}
|
|
let success_display = format!("✅ {}", call.name);
|
|
on_chunk(AiStreamChunk {
|
|
content: success_display.clone(),
|
|
done: false,
|
|
chunk_type: AiChunkType::ToolResult,
|
|
})
|
|
.await;
|
|
all_chunks.push(StreamChunk {
|
|
chunk_type: StreamChunkType::ToolCall,
|
|
content: success_display,
|
|
});
|
|
|
|
let msgs = crate::tool::ToolExecutor::to_tool_messages(&results);
|
|
tool_messages.extend(msgs);
|
|
}
|
|
messages.extend(tool_messages);
|
|
|
|
// Inject passive-detected skills based on tool calls
|
|
if let Ok(skills) = project_skill::Entity::find()
|
|
.filter(project_skill::Column::ProjectUuid.eq(request.project.id))
|
|
.filter(project_skill::Column::Enabled.eq(true))
|
|
.all(&request.db)
|
|
.await
|
|
{
|
|
let skill_entries: Vec<SkillEntry> = skills
|
|
.into_iter()
|
|
.map(|s| SkillEntry {
|
|
slug: s.slug,
|
|
name: s.name,
|
|
description: s.description,
|
|
content: s.content,
|
|
})
|
|
.collect();
|
|
let tool_events: Vec<ToolCallEvent> = response
|
|
.tool_calls
|
|
.iter()
|
|
.map(|tc| ToolCallEvent {
|
|
tool_name: tc.name.clone(),
|
|
arguments: tc.arguments.clone(),
|
|
})
|
|
.collect();
|
|
for event in &tool_events {
|
|
if let Some(ctx) = self
|
|
.perception_service
|
|
.passive
|
|
.detect(event, &skill_entries)
|
|
{
|
|
messages.push(ctx.to_system_message());
|
|
}
|
|
}
|
|
}
|
|
|
|
tool_depth += 1;
|
|
if tool_depth >= max_tool_depth {
|
|
let max_depth_text = format!(
|
|
"[AI reached maximum tool depth ({}) — no final answer produced]",
|
|
max_tool_depth
|
|
);
|
|
on_chunk(AiStreamChunk {
|
|
content: max_depth_text.clone(),
|
|
done: true,
|
|
chunk_type: AiChunkType::Answer,
|
|
})
|
|
.await;
|
|
all_chunks.push(StreamChunk {
|
|
chunk_type: StreamChunkType::Answer,
|
|
content: max_depth_text,
|
|
});
|
|
// Record session
|
|
record_ai_session(
|
|
&request.cache,
|
|
&request.db,
|
|
request.project.id,
|
|
session_id,
|
|
request.room.id,
|
|
request.model.id,
|
|
version_id.unwrap_or_default(),
|
|
total_input_tokens,
|
|
total_output_tokens,
|
|
session_start.elapsed().as_millis() as i64,
|
|
)
|
|
.await;
|
|
return Ok(StreamResult {
|
|
content: full_content,
|
|
reasoning_content: String::new(),
|
|
input_tokens: 0,
|
|
output_tokens: 0,
|
|
chunks: all_chunks,
|
|
});
|
|
}
|
|
continue;
|
|
}
|
|
|
|
// Final answer — accumulate and return
|
|
full_content.push_str(&response.content);
|
|
|
|
on_chunk(AiStreamChunk {
|
|
content: response.content.clone(),
|
|
done: true,
|
|
chunk_type: AiChunkType::Answer,
|
|
})
|
|
.await;
|
|
all_chunks.push(StreamChunk {
|
|
chunk_type: StreamChunkType::Answer,
|
|
content: response.content.clone(),
|
|
});
|
|
// Record session
|
|
record_ai_session(
|
|
&request.cache,
|
|
&request.db,
|
|
request.project.id,
|
|
session_id,
|
|
request.room.id,
|
|
request.model.id,
|
|
version_id.unwrap_or_default(),
|
|
total_input_tokens,
|
|
total_output_tokens,
|
|
session_start.elapsed().as_millis() as i64,
|
|
)
|
|
.await;
|
|
return Ok(StreamResult {
|
|
content: full_content,
|
|
reasoning_content: response.reasoning_content,
|
|
input_tokens: response.input_tokens,
|
|
output_tokens: response.output_tokens,
|
|
chunks: all_chunks,
|
|
});
|
|
}
|
|
}
|
|
|
|
async fn build_messages(&self, request: &AiRequest) -> Result<Vec<ChatRequestMessage>> {
|
|
let mut messages = Vec::new();
|
|
|
|
// Core reasoning instruction — prioritize analysis before answering.
|
|
messages.push(ChatRequestMessage::system(
|
|
"When receiving a question or problem, follow this reasoning process:\n\
|
|
1. ANALYZE: Break down the question. Identify what is being asked, what context is available, and what information is missing.\n\
|
|
2. GATHER: Use available tools (repository search, file reading, etc.) to collect relevant information before answering.\n\
|
|
3. REASON: Synthesize the gathered information. Consider edge cases and potential issues.\n\
|
|
4. ANSWER: Provide a clear, actionable answer based on your analysis.\n\
|
|
\n\
|
|
Do NOT guess or assume when tools can provide concrete answers. Always verify claims against actual code or documentation.".to_string()
|
|
));
|
|
|
|
let mut processed_history = Vec::new();
|
|
if let Some(compact_service) = &self.compact_service {
|
|
let compact_cache_key = format!("ai:compact:{}", request.room.id);
|
|
let compact_config = CompactConfig::default();
|
|
|
|
// Try cached compaction summary (avoids re-compacting same history)
|
|
let cached_summary: Option<String> = {
|
|
let conn_result = request.cache.conn().await;
|
|
match conn_result {
|
|
Ok(mut conn) => {
|
|
redis::cmd("GET")
|
|
.arg(&compact_cache_key)
|
|
.query_async::<Option<String>>(&mut conn)
|
|
.await
|
|
.unwrap_or(None)
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!(error = %e, "compact cache: conn failed");
|
|
None
|
|
}
|
|
}
|
|
};
|
|
|
|
if let Some(cached_json) = cached_summary {
|
|
if let Ok(summary) = serde_json::from_str::<crate::compact::CompactSummary>(&cached_json) {
|
|
if !summary.summary.is_empty() {
|
|
messages.push(ChatRequestMessage::system(format!(
|
|
"Conversation summary:\n{}",
|
|
summary.summary
|
|
)));
|
|
}
|
|
processed_history = summary.retained;
|
|
}
|
|
}
|
|
|
|
if processed_history.is_empty() {
|
|
match compact_service
|
|
.compact_room_auto(request.room.id, Some(request.user_names.clone()), compact_config)
|
|
.await
|
|
{
|
|
Ok(compact_summary) => {
|
|
if !compact_summary.summary.is_empty() {
|
|
messages.push(ChatRequestMessage::system(format!(
|
|
"Conversation summary:\n{}",
|
|
compact_summary.summary
|
|
)));
|
|
}
|
|
// Cache for subsequent calls (5 min TTL)
|
|
if let Ok(json) = serde_json::to_string(&compact_summary) {
|
|
if let Ok(mut conn) = request.cache.conn().await {
|
|
let _ = redis::cmd("SETEX")
|
|
.arg(&compact_cache_key)
|
|
.arg(300u64)
|
|
.arg(&json)
|
|
.query_async::<()>(&mut conn)
|
|
.await
|
|
.inspect_err(|e| {
|
|
tracing::warn!(error = %e, "compact cache: SETEX failed");
|
|
});
|
|
}
|
|
}
|
|
processed_history = compact_summary.retained;
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!(error = %e, "conversation compaction failed, using full history");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if !processed_history.is_empty() {
|
|
for msg_summary in processed_history {
|
|
let ctx = RoomMessageContext::from(msg_summary);
|
|
messages.push(ctx.to_message());
|
|
}
|
|
} else {
|
|
for msg in &request.history {
|
|
let ctx = RoomMessageContext::from_model_with_names(msg, &request.user_names);
|
|
messages.push(ctx.to_message());
|
|
}
|
|
}
|
|
|
|
for mention in &request.mention {
|
|
match mention {
|
|
Mention::Repo(repo) => {
|
|
// Inject repo details into system prompt so AI knows the repo context
|
|
let mut parts = vec![
|
|
format!("Name: {}", repo.repo_name),
|
|
format!("ID: {}", repo.id),
|
|
];
|
|
if let Some(ref desc) = repo.description {
|
|
parts.push(format!("Description: {}", desc));
|
|
}
|
|
parts.push(format!("Default branch: {}", repo.default_branch));
|
|
parts.push(format!(
|
|
"Private: {}",
|
|
if repo.is_private { "yes" } else { "no" }
|
|
));
|
|
parts.push(format!("Created: {}", repo.created_at.format("%Y-%m-%d")));
|
|
messages.push(ChatRequestMessage::system(format!(
|
|
"Mentioned repository:\n{}",
|
|
parts.join("\n")
|
|
)));
|
|
|
|
// Vector search for related issues and repos (enhancement, optional)
|
|
if let Some(embed_service) = &self.embed_service {
|
|
let query = format!(
|
|
"{} {}",
|
|
repo.repo_name,
|
|
repo.description.as_deref().unwrap_or_default()
|
|
);
|
|
if let Ok(issues) = embed_service.search_issues(&query, 5).await {
|
|
if !issues.is_empty() {
|
|
let context = format!(
|
|
"Related issues for repo {}:\n{}",
|
|
repo.repo_name,
|
|
issues
|
|
.iter()
|
|
.map(|i| format!("- {}", i.payload.text))
|
|
.collect::<Vec<_>>()
|
|
.join("\n")
|
|
);
|
|
messages.push(ChatRequestMessage::system(context));
|
|
}
|
|
}
|
|
if let Ok(repos) = embed_service.search_repos(&query, 3).await {
|
|
if !repos.is_empty() {
|
|
let context = format!(
|
|
"Similar repositories:\n{}",
|
|
repos
|
|
.iter()
|
|
.map(|r| format!("- {}", r.payload.text))
|
|
.collect::<Vec<_>>()
|
|
.join("\n")
|
|
);
|
|
messages.push(ChatRequestMessage::system(context));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Mention::User(user) => {
|
|
let mut profile_parts = vec![format!("Username: {}", user.username)];
|
|
if let Some(ref display_name) = user.display_name {
|
|
profile_parts.push(format!("Display name: {}", display_name));
|
|
}
|
|
if let Some(ref org) = user.organization {
|
|
profile_parts.push(format!("Organization: {}", org));
|
|
}
|
|
if let Some(ref website) = user.website_url {
|
|
profile_parts.push(format!("Website: {}", website));
|
|
}
|
|
messages.push(ChatRequestMessage::system(format!(
|
|
"Mentioned user profile:\n{}",
|
|
profile_parts.join("\n")
|
|
)));
|
|
}
|
|
}
|
|
}
|
|
|
|
let skill_contexts = self.build_skill_context(request).await;
|
|
for ctx in skill_contexts {
|
|
messages.push(ctx.to_system_message());
|
|
}
|
|
|
|
let memories = self.build_memory_context(request).await;
|
|
for mem in memories {
|
|
messages.push(mem.to_system_message());
|
|
}
|
|
|
|
messages.push(ChatRequestMessage::system(format!(
|
|
"Current Project:\n{}\nDescription: {}\nPublic: {}",
|
|
request.project.display_name,
|
|
request.project.description.as_deref().unwrap_or("(none)"),
|
|
if request.project.is_public {
|
|
"yes"
|
|
} else {
|
|
"no"
|
|
}
|
|
)));
|
|
|
|
let mut sender_parts = vec![format!("**Sender:** {}", request.sender.username)];
|
|
if let Some(ref display_name) = request.sender.display_name {
|
|
sender_parts.push(display_name.clone());
|
|
}
|
|
if let Some(ref org) = request.sender.organization {
|
|
sender_parts.push(format!("({})", org));
|
|
}
|
|
let sender_display = sender_parts.join(" ");
|
|
messages.push(ChatRequestMessage::system(format!(
|
|
"The person sending the next message:\n{}",
|
|
sender_display
|
|
)));
|
|
|
|
messages.push(ChatRequestMessage::user(&request.input));
|
|
|
|
Ok(messages)
|
|
}
|
|
|
|
async fn build_skill_context(
|
|
&self,
|
|
request: &AiRequest,
|
|
) -> Vec<crate::perception::SkillContext> {
|
|
// Load database skills for this project
|
|
let db_skills: Vec<SkillEntry> = match project_skill::Entity::find()
|
|
.filter(project_skill::Column::ProjectUuid.eq(request.project.id))
|
|
.filter(project_skill::Column::Enabled.eq(true))
|
|
.all(&request.db)
|
|
.await
|
|
{
|
|
Ok(models) => models
|
|
.into_iter()
|
|
.map(|s| SkillEntry {
|
|
slug: s.slug,
|
|
name: s.name,
|
|
description: s.description,
|
|
content: s.content,
|
|
})
|
|
.collect(),
|
|
Err(_) => Vec::new(),
|
|
};
|
|
|
|
// Load built-in skills and merge with database skills
|
|
// Built-in skills override database skills with the same slug
|
|
let mut all_skills: Vec<SkillEntry> = db_skills;
|
|
for built_in in crate::skills::all_skills() {
|
|
// Skip if a database skill with the same slug already exists
|
|
if !all_skills.iter().any(|s| s.slug == built_in.slug) {
|
|
all_skills.push(SkillEntry {
|
|
slug: built_in.slug.to_string(),
|
|
name: built_in.name.to_string(),
|
|
description: Some(built_in.description.to_string()),
|
|
content: built_in.content.clone(),
|
|
});
|
|
}
|
|
}
|
|
|
|
if all_skills.is_empty() {
|
|
return Vec::new();
|
|
}
|
|
|
|
let history_texts: Vec<String> = request
|
|
.history
|
|
.iter()
|
|
.rev()
|
|
.take(10)
|
|
.map(|msg| msg.content.clone())
|
|
.collect();
|
|
|
|
let tool_events: Vec<ToolCallEvent> = Vec::new();
|
|
let keyword_skills = self
|
|
.perception_service
|
|
.inject_skills(&request.input, &history_texts, &tool_events, &all_skills)
|
|
.await;
|
|
|
|
let mut vector_skills = Vec::new();
|
|
if let Some(embed_service) = &self.embed_service {
|
|
let awareness = crate::perception::VectorActiveAwareness::default();
|
|
vector_skills = awareness
|
|
.detect(
|
|
embed_service,
|
|
&request.input,
|
|
&request.project.id.to_string(),
|
|
)
|
|
.await;
|
|
}
|
|
|
|
let mut seen = std::collections::HashSet::new();
|
|
let mut result = Vec::new();
|
|
for ctx in vector_skills {
|
|
if seen.insert(ctx.label.clone()) {
|
|
result.push(ctx);
|
|
}
|
|
}
|
|
for ctx in keyword_skills {
|
|
if seen.insert(ctx.label.clone()) {
|
|
result.push(ctx);
|
|
}
|
|
}
|
|
|
|
result
|
|
}
|
|
|
|
async fn build_memory_context(
|
|
&self,
|
|
request: &AiRequest,
|
|
) -> Vec<crate::perception::vector::MemoryContext> {
|
|
let embed_service = match &self.embed_service {
|
|
Some(s) => s,
|
|
None => return Vec::new(),
|
|
};
|
|
|
|
let awareness = crate::perception::VectorPassiveAwareness::default();
|
|
awareness
|
|
.detect(
|
|
embed_service,
|
|
&request.input,
|
|
&request.project.display_name,
|
|
&request.room.id.to_string(),
|
|
)
|
|
.await
|
|
}
|
|
|
|
pub async fn process_react<C, Fut>(&self, request: &AiRequest, mut on_chunk: C) -> Result<(String, i64, i64)>
|
|
where
|
|
C: FnMut(crate::react::ReactStep) -> Fut + Send,
|
|
Fut: std::future::Future<Output = ()> + Send,
|
|
{
|
|
let base_url = self
|
|
.ai_base_url
|
|
.clone()
|
|
.unwrap_or_else(|| "https://api.openai.com".into());
|
|
let api_key = self.ai_api_key.clone().unwrap_or_default();
|
|
let client_config = AiClientConfig::new(api_key).with_base_url(base_url);
|
|
|
|
let Some(registry) = &self.tool_registry else {
|
|
return Err(AgentError::Internal("no tool registry registered".into()));
|
|
};
|
|
|
|
let db = request.db.clone();
|
|
let cache = request.cache.clone();
|
|
let cfg = request.config.clone();
|
|
let room_id = request.room.id;
|
|
let sender_uid = request.sender.uid;
|
|
let project_id = request.project.id;
|
|
let session_id = Uuid::now_v7();
|
|
let session_start = std::time::Instant::now();
|
|
let version_id = room_ai::Entity::find()
|
|
.filter(room_ai::Column::Room.eq(request.room.id))
|
|
.filter(room_ai::Column::Model.eq(request.model.id))
|
|
.one(&request.db)
|
|
.await
|
|
.ok()
|
|
.flatten()
|
|
.and_then(|r| r.version);
|
|
|
|
// Build rig tools with recording wrapper directly from registry
|
|
let mut tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>> = Vec::new();
|
|
for def in registry.definitions() {
|
|
let name = def.name.clone();
|
|
if let Some(handler) = registry.get(&name) {
|
|
let adapter = crate::tool::RigToolAdapter::new(
|
|
handler.clone(),
|
|
def.clone(),
|
|
db.clone(),
|
|
cache.clone(),
|
|
cfg.clone(),
|
|
room_id,
|
|
Some(sender_uid),
|
|
project_id,
|
|
);
|
|
tools.push(Box::new(RecordingTool::new(
|
|
Box::new(adapter),
|
|
db.clone(),
|
|
session_id,
|
|
sender_uid,
|
|
)));
|
|
}
|
|
}
|
|
|
|
// Build rig agent (handles multi-turn tool calls natively)
|
|
let rig_client = client_config.build_rig_client();
|
|
let model = rig_client.completion_model(&request.model.name);
|
|
let agent = AgentBuilder::new(model)
|
|
.preamble(DEFAULT_SYSTEM_PROMPT)
|
|
.tools(tools)
|
|
.default_max_turns(request.max_tool_depth)
|
|
.build();
|
|
|
|
let stream = agent
|
|
.stream_prompt(&request.input)
|
|
.with_history(Vec::new())
|
|
.multi_turn(request.max_tool_depth)
|
|
.await;
|
|
|
|
tokio::pin!(stream);
|
|
|
|
let mut step_count = 0usize;
|
|
let mut final_content = String::new();
|
|
let mut total_input_tokens: i64 = 0;
|
|
let mut total_output_tokens: i64 = 0;
|
|
|
|
while let Some(item) = stream.next().await {
|
|
match item {
|
|
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
|
StreamedAssistantContent::Text(text),
|
|
)) => {
|
|
step_count += 1;
|
|
let t = text.text;
|
|
on_chunk(ReactStep::Answer {
|
|
step: step_count,
|
|
answer: t.clone(),
|
|
})
|
|
.await;
|
|
final_content.push_str(&t);
|
|
}
|
|
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
|
StreamedAssistantContent::Reasoning(reasoning),
|
|
)) => {
|
|
let reasoning_text = reasoning.reasoning.join("");
|
|
if !reasoning_text.is_empty() {
|
|
step_count += 1;
|
|
on_chunk(ReactStep::Thought {
|
|
step: step_count,
|
|
thought: reasoning_text,
|
|
})
|
|
.await;
|
|
}
|
|
}
|
|
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
|
StreamedAssistantContent::ReasoningDelta { reasoning, .. },
|
|
)) => {
|
|
if !reasoning.is_empty() {
|
|
step_count += 1;
|
|
on_chunk(ReactStep::Thought {
|
|
step: step_count,
|
|
thought: reasoning,
|
|
})
|
|
.await;
|
|
}
|
|
}
|
|
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
|
StreamedAssistantContent::ToolCall { tool_call, .. },
|
|
)) => {
|
|
step_count += 1;
|
|
let args: serde_json::Value = match &tool_call.function.arguments {
|
|
serde_json::Value::String(s) => {
|
|
serde_json::from_str(s).unwrap_or(serde_json::Value::Null)
|
|
}
|
|
v => v.clone(),
|
|
};
|
|
on_chunk(ReactStep::Action {
|
|
step: step_count,
|
|
action: ReactAction::new(&tool_call.function.name, args),
|
|
})
|
|
.await;
|
|
}
|
|
Ok(MultiTurnStreamItem::StreamUserItem(
|
|
rig::streaming::StreamedUserContent::ToolResult { tool_result, .. },
|
|
)) => {
|
|
step_count += 1;
|
|
let obs = tool_result_content_to_string(&tool_result.content);
|
|
on_chunk(ReactStep::Observation {
|
|
step: step_count,
|
|
observation: obs,
|
|
})
|
|
.await;
|
|
}
|
|
Ok(MultiTurnStreamItem::FinalResponse(resp)) => {
|
|
let usage = resp.usage();
|
|
total_input_tokens = usage.input_tokens as i64;
|
|
total_output_tokens = usage.output_tokens as i64;
|
|
// Text was already streamed incrementally via Answer events.
|
|
}
|
|
Err(e) => {
|
|
let err_msg = format!("rig agent stream error: {}", e);
|
|
return Err(AgentError::OpenAi(err_msg));
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
let elapsed_ms = session_start.elapsed().as_millis() as i64;
|
|
record_ai_session(
|
|
&request.cache,
|
|
&request.db,
|
|
request.project.id,
|
|
session_id,
|
|
request.room.id,
|
|
request.model.id,
|
|
version_id.unwrap_or_default(),
|
|
total_input_tokens,
|
|
total_output_tokens,
|
|
elapsed_ms,
|
|
)
|
|
.await;
|
|
|
|
Ok((final_content, total_input_tokens, total_output_tokens))
|
|
}
|
|
|
|
// ── CoT (Chain-of-Thought) ────────────────────────────────────────────
|
|
|
|
/// Run a CoT (Chain-of-Thought) reasoning cycle — step-by-step reasoning with optional tool use.
|
|
pub async fn process_cot<C, Fut>(&self, request: &AiRequest, mut on_chunk: C) -> Result<(String, i64, i64)>
|
|
where
|
|
C: FnMut(crate::modes::cot::CotStep) -> Fut + Send,
|
|
Fut: std::future::Future<Output = ()> + Send,
|
|
{
|
|
let client_config = AiClientConfig::new(
|
|
self.ai_api_key.clone().unwrap_or_default(),
|
|
)
|
|
.with_base_url(
|
|
self.ai_base_url
|
|
.clone()
|
|
.unwrap_or_else(|| "https://api.openai.com".into()),
|
|
);
|
|
let rig_client = client_config.build_rig_client();
|
|
|
|
let Some(registry) = &self.tool_registry else {
|
|
return Err(AgentError::Internal("no tool registry registered".into()));
|
|
};
|
|
|
|
let session_id = Uuid::now_v7();
|
|
let session_start = std::time::Instant::now();
|
|
let version_id = room_ai::Entity::find()
|
|
.filter(room_ai::Column::Room.eq(request.room.id))
|
|
.filter(room_ai::Column::Model.eq(request.model.id))
|
|
.one(&request.db)
|
|
.await
|
|
.ok()
|
|
.flatten()
|
|
.and_then(|r| r.version);
|
|
|
|
let db = request.db.clone();
|
|
let cache = request.cache.clone();
|
|
let cfg = request.config.clone();
|
|
let room_id = request.room.id;
|
|
let sender_uid = request.sender.uid;
|
|
let project_id = request.project.id;
|
|
|
|
let mut tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>> = Vec::new();
|
|
for def in registry.definitions() {
|
|
let name = def.name.clone();
|
|
if let Some(handler) = registry.get(&name) {
|
|
let adapter = crate::tool::RigToolAdapter::new(
|
|
handler.clone(), def.clone(),
|
|
db.clone(), cache.clone(), cfg.clone(),
|
|
room_id, Some(sender_uid), project_id,
|
|
);
|
|
tools.push(Box::new(crate::tool::RecordingTool::new(
|
|
Box::new(adapter), db.clone(), session_id, sender_uid,
|
|
)));
|
|
}
|
|
}
|
|
|
|
let model = rig_client.completion_model(&request.model.name);
|
|
let agent = AgentBuilder::new(model)
|
|
.preamble(crate::modes::cot::COT_SYSTEM_PROMPT)
|
|
.tools(tools)
|
|
.default_max_turns(request.max_tool_depth)
|
|
.build();
|
|
|
|
let stream = agent
|
|
.stream_prompt(&request.input)
|
|
.with_history(Vec::new())
|
|
.multi_turn(request.max_tool_depth)
|
|
.await;
|
|
|
|
tokio::pin!(stream);
|
|
|
|
let mut final_content = String::new();
|
|
let mut total_input_tokens: i64 = 0;
|
|
let mut total_output_tokens: i64 = 0;
|
|
|
|
while let Some(item) = stream.next().await {
|
|
match item {
|
|
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
|
StreamedAssistantContent::Text(text),
|
|
)) => {
|
|
let t = text.text;
|
|
on_chunk(crate::modes::cot::CotStep::Answer(t.clone())).await;
|
|
final_content.push_str(&t);
|
|
}
|
|
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
|
StreamedAssistantContent::Reasoning(reasoning),
|
|
)) => {
|
|
let r = reasoning.reasoning.join("");
|
|
if !r.is_empty() {
|
|
on_chunk(crate::modes::cot::CotStep::Thought(r)).await;
|
|
}
|
|
}
|
|
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
|
StreamedAssistantContent::ReasoningDelta { reasoning, .. },
|
|
)) => {
|
|
if !reasoning.is_empty() {
|
|
on_chunk(crate::modes::cot::CotStep::Thought(reasoning)).await;
|
|
}
|
|
}
|
|
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
|
StreamedAssistantContent::ToolCall { tool_call, .. },
|
|
)) => {
|
|
let args: serde_json::Value = match &tool_call.function.arguments {
|
|
serde_json::Value::String(s) => {
|
|
serde_json::from_str(s).unwrap_or(serde_json::Value::Null)
|
|
}
|
|
v => v.clone(),
|
|
};
|
|
on_chunk(crate::modes::cot::CotStep::Action {
|
|
name: tool_call.function.name.clone(),
|
|
args,
|
|
}).await;
|
|
}
|
|
Ok(MultiTurnStreamItem::StreamUserItem(
|
|
rig::streaming::StreamedUserContent::ToolResult { tool_result, .. },
|
|
)) => {
|
|
let obs = tool_result_content_to_string(&tool_result.content);
|
|
on_chunk(crate::modes::cot::CotStep::Observation(obs)).await;
|
|
}
|
|
Ok(MultiTurnStreamItem::FinalResponse(resp)) => {
|
|
let usage = resp.usage();
|
|
total_input_tokens = usage.input_tokens as i64;
|
|
total_output_tokens = usage.output_tokens as i64;
|
|
}
|
|
Err(e) => {
|
|
return Err(AgentError::OpenAi(e.to_string()));
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
let elapsed_ms = session_start.elapsed().as_millis() as i64;
|
|
record_ai_session(
|
|
&request.cache, &request.db, request.project.id,
|
|
session_id, request.room.id, request.model.id,
|
|
version_id.unwrap_or_default(),
|
|
total_input_tokens, total_output_tokens, elapsed_ms,
|
|
).await;
|
|
|
|
Ok((final_content, total_input_tokens, total_output_tokens))
|
|
}
|
|
|
|
// ── ReWOO (Plan → Execute → Synthesize) ───────────────────────────────
|
|
|
|
/// Run a ReWOO reasoning cycle: model plans tool calls, they are executed,
|
|
/// then the model synthesises the final answer.
|
|
pub async fn process_rewoo<C, Fut>(&self, request: &AiRequest, mut on_chunk: C) -> Result<(String, i64, i64)>
|
|
where
|
|
C: FnMut(crate::modes::rewoo::ReWooStep) -> Fut + Send,
|
|
Fut: std::future::Future<Output = ()> + Send,
|
|
{
|
|
let client_config = AiClientConfig::new(
|
|
self.ai_api_key.clone().unwrap_or_default(),
|
|
)
|
|
.with_base_url(
|
|
self.ai_base_url
|
|
.clone()
|
|
.unwrap_or_else(|| "https://api.openai.com".into()),
|
|
);
|
|
let rig_client = client_config.build_rig_client();
|
|
|
|
let Some(registry) = &self.tool_registry else {
|
|
return Err(AgentError::Internal("no tool registry registered".into()));
|
|
};
|
|
|
|
let session_id = Uuid::now_v7();
|
|
let session_start = std::time::Instant::now();
|
|
let version_id = room_ai::Entity::find()
|
|
.filter(room_ai::Column::Room.eq(request.room.id))
|
|
.filter(room_ai::Column::Model.eq(request.model.id))
|
|
.one(&request.db)
|
|
.await
|
|
.ok()
|
|
.flatten()
|
|
.and_then(|r| r.version);
|
|
|
|
let mut total_input_tokens: i64 = 0;
|
|
let mut total_output_tokens: i64 = 0;
|
|
|
|
let mut messages = self.build_messages(request).await?;
|
|
messages.insert(0, crate::client::types::ChatRequestMessage::system(
|
|
crate::modes::rewoo::REWOO_SYSTEM_PROMPT.to_string(),
|
|
));
|
|
let model = rig_client.completion_model(&request.model.name);
|
|
|
|
let plan_tools = {
|
|
let db = request.db.clone();
|
|
let cache = request.cache.clone();
|
|
let cfg = request.config.clone();
|
|
let room_id = request.room.id;
|
|
let sender_uid = request.sender.uid;
|
|
let project_id = request.project.id;
|
|
|
|
let mut tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>> = Vec::new();
|
|
for def in registry.definitions() {
|
|
let name = def.name.clone();
|
|
if let Some(handler) = registry.get(&name) {
|
|
let adapter = crate::tool::RigToolAdapter::new(
|
|
handler.clone(),
|
|
def.clone(),
|
|
db.clone(),
|
|
cache.clone(),
|
|
cfg.clone(),
|
|
room_id,
|
|
Some(sender_uid),
|
|
project_id,
|
|
);
|
|
tools.push(Box::new(crate::tool::RecordingTool::new(
|
|
Box::new(adapter),
|
|
db.clone(),
|
|
session_id,
|
|
sender_uid,
|
|
)));
|
|
}
|
|
}
|
|
tools
|
|
};
|
|
|
|
let plan_agent = rig::agent::AgentBuilder::new(model)
|
|
.preamble(crate::modes::rewoo::REWOO_SYSTEM_PROMPT)
|
|
.tools(plan_tools)
|
|
.default_max_turns(1)
|
|
.build();
|
|
|
|
let plan_response = plan_agent
|
|
.prompt(&request.input)
|
|
.extended_details()
|
|
.await
|
|
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
|
|
|
|
total_input_tokens += plan_response.total_usage.input_tokens as i64;
|
|
total_output_tokens += plan_response.total_usage.output_tokens as i64;
|
|
|
|
let plan = crate::modes::rewoo::extract_plan(&plan_response.output)
|
|
.unwrap_or_default();
|
|
|
|
if plan.calls.is_empty() {
|
|
on_chunk(crate::modes::rewoo::ReWooStep::Synthesis(plan_response.output.clone())).await;
|
|
let elapsed_ms = session_start.elapsed().as_millis() as i64;
|
|
record_ai_session(
|
|
&request.cache, &request.db, request.project.id,
|
|
session_id, request.room.id, request.model.id,
|
|
version_id.unwrap_or_default(),
|
|
total_input_tokens, total_output_tokens, elapsed_ms,
|
|
).await;
|
|
return Ok((plan_response.output, total_input_tokens, total_output_tokens));
|
|
}
|
|
|
|
on_chunk(crate::modes::rewoo::ReWooStep::Plan {
|
|
calls: plan.calls.clone(),
|
|
raw: plan.raw_text,
|
|
}).await;
|
|
|
|
// ── Phase 2: Execute all tool calls in parallel ───────────────────
|
|
let mut tool_results: Vec<(String, String)> = Vec::new();
|
|
let mut handles = Vec::new();
|
|
|
|
for call in &plan.calls {
|
|
let ctx = crate::tool::ToolContext::new(
|
|
request.db.clone(),
|
|
request.cache.clone(),
|
|
request.config.clone(),
|
|
request.room.id,
|
|
Some(request.sender.uid),
|
|
)
|
|
.with_project(request.project.id);
|
|
if let Some(ref es) = self.embed_service {
|
|
// ctx = ctx.with_embed_service(es.clone()); -- not clone-able via pattern, skip
|
|
let _ = es;
|
|
}
|
|
|
|
let call_id = call.step.to_string();
|
|
let tool_name = call.tool.clone();
|
|
let args = call.args.clone();
|
|
let ctx_clone = ctx.clone();
|
|
|
|
let handle = tokio::spawn(async move {
|
|
let executor = crate::tool::ToolExecutor::new();
|
|
let agent_call = crate::tool::ToolCall {
|
|
id: call_id,
|
|
name: tool_name.clone(),
|
|
arguments: args.to_string(),
|
|
};
|
|
let mut local_ctx = ctx_clone;
|
|
let result = executor.execute_batch(vec![agent_call], &mut local_ctx).await;
|
|
match result {
|
|
Ok(results) => {
|
|
for r in &results {
|
|
match &r.result {
|
|
crate::tool::ToolResult::Ok(v) => {
|
|
return (tool_name, v.to_string());
|
|
}
|
|
crate::tool::ToolResult::Error(e) => {
|
|
return (tool_name, format!("[Error: {}]", e));
|
|
}
|
|
}
|
|
}
|
|
(tool_name, "[No result]".to_string())
|
|
}
|
|
Err(e) => (tool_name, format!("[Execution error: {}]", e)),
|
|
}
|
|
});
|
|
handles.push(handle);
|
|
}
|
|
|
|
for handle in handles {
|
|
match handle.await {
|
|
Ok((name, result)) => {
|
|
on_chunk(crate::modes::rewoo::ReWooStep::Execution {
|
|
tool_name: name.clone(),
|
|
result: result.clone(),
|
|
}).await;
|
|
tool_results.push((name, result));
|
|
}
|
|
Err(e) => {
|
|
let msg = format!("[Task panicked: {}]", e);
|
|
on_chunk(crate::modes::rewoo::ReWooStep::Execution {
|
|
tool_name: "unknown".into(),
|
|
result: msg.clone(),
|
|
}).await;
|
|
tool_results.push(("unknown".into(), msg));
|
|
}
|
|
}
|
|
}
|
|
|
|
// ── Phase 3: Synthesize ───────────────────────────────────────────
|
|
let mut synth_messages = self.build_messages(request).await?;
|
|
synth_messages.insert(0, crate::client::types::ChatRequestMessage::system(
|
|
crate::modes::rewoo::REWOO_SYSTEM_PROMPT.to_string(),
|
|
));
|
|
|
|
let results_summary: String = tool_results
|
|
.iter()
|
|
.map(|(name, res)| format!("- {}:\n{}", name, res))
|
|
.collect::<Vec<_>>()
|
|
.join("\n");
|
|
|
|
synth_messages.push(crate::client::types::ChatRequestMessage::system(format!(
|
|
"## Tool Execution Results\n\nThe following tool calls were executed:\n\n{}\n\nNow synthesize your final answer based on these results.",
|
|
results_summary
|
|
)));
|
|
synth_messages.push(crate::client::types::ChatRequestMessage::user(&request.input));
|
|
|
|
let preamble = synth_messages
|
|
.iter()
|
|
.find(|m| m.role == "system")
|
|
.and_then(|m| m.content.as_deref())
|
|
.unwrap_or("")
|
|
.to_string();
|
|
let non_system: Vec<_> = synth_messages
|
|
.iter()
|
|
.filter(|m| m.role != "system")
|
|
.map(|m| crate::client::to_rig_message(m))
|
|
.collect();
|
|
|
|
let synth_model = rig_client.completion_model(&request.model.name);
|
|
let synth_stream = synth_model
|
|
.completion_request("")
|
|
.preamble(preamble)
|
|
.messages(non_system)
|
|
.temperature(request.temperature as f64)
|
|
.max_tokens(request.max_tokens as u64)
|
|
.stream()
|
|
.await
|
|
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
|
|
|
|
use rig::streaming::StreamedAssistantContent;
|
|
tokio::pin!(synth_stream);
|
|
|
|
let mut synthesis = String::new();
|
|
while let Some(item) = synth_stream.next().await {
|
|
match item {
|
|
Ok(StreamedAssistantContent::Text(text)) => {
|
|
let t = text.text;
|
|
on_chunk(crate::modes::rewoo::ReWooStep::Synthesis(t.clone())).await;
|
|
synthesis.push_str(&t);
|
|
}
|
|
Ok(StreamedAssistantContent::Final(response)) => {
|
|
if let Some(usage) = response.token_usage() {
|
|
total_input_tokens += usage.input_tokens as i64;
|
|
total_output_tokens += usage.output_tokens as i64;
|
|
}
|
|
}
|
|
Err(e) => return Err(AgentError::OpenAi(e.to_string())),
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
let elapsed_ms = session_start.elapsed().as_millis() as i64;
|
|
record_ai_session(
|
|
&request.cache, &request.db, request.project.id,
|
|
session_id, request.room.id, request.model.id,
|
|
version_id.unwrap_or_default(),
|
|
total_input_tokens, total_output_tokens, elapsed_ms,
|
|
).await;
|
|
|
|
Ok((synthesis, total_input_tokens, total_output_tokens))
|
|
}
|
|
|
|
// ── Reflexion (Generate → Critique → Revise) ──────────────────────────
|
|
|
|
/// Run a Reflexion reasoning cycle: generate → critique → revise (up to 3 rounds).
|
|
pub async fn process_reflexion<C, Fut>(
|
|
&self,
|
|
request: &AiRequest,
|
|
mut on_chunk: C,
|
|
max_cycles: usize,
|
|
) -> Result<(String, i64, i64)>
|
|
where
|
|
C: FnMut(crate::modes::reflexion::ReflexionStep) -> Fut + Send,
|
|
Fut: std::future::Future<Output = ()> + Send,
|
|
{
|
|
let client_config = AiClientConfig::new(
|
|
self.ai_api_key.clone().unwrap_or_default(),
|
|
)
|
|
.with_base_url(
|
|
self.ai_base_url
|
|
.clone()
|
|
.unwrap_or_else(|| "https://api.openai.com".into()),
|
|
);
|
|
let rig_client = client_config.build_rig_client();
|
|
let Some(registry) = &self.tool_registry else {
|
|
return Err(AgentError::Internal("no tool registry registered".into()));
|
|
};
|
|
|
|
let session_id = Uuid::now_v7();
|
|
let session_start = std::time::Instant::now();
|
|
let version_id = room_ai::Entity::find()
|
|
.filter(room_ai::Column::Room.eq(request.room.id))
|
|
.filter(room_ai::Column::Model.eq(request.model.id))
|
|
.one(&request.db)
|
|
.await
|
|
.ok()
|
|
.flatten()
|
|
.and_then(|r| r.version);
|
|
|
|
let max_cycles = max_cycles.min(3);
|
|
|
|
let mut total_input_tokens: i64 = 0;
|
|
let mut total_output_tokens: i64 = 0;
|
|
let mut best_answer = String::new();
|
|
|
|
for cycle in 0..max_cycles {
|
|
let mut messages = self.build_messages(request).await?;
|
|
messages.insert(0, crate::client::types::ChatRequestMessage::system(
|
|
crate::modes::reflexion::REFLEXION_SYSTEM_PROMPT.to_string(),
|
|
));
|
|
|
|
if cycle > 0 {
|
|
messages.push(crate::client::types::ChatRequestMessage::system(format!(
|
|
"This is cycle {} of the reflexion process. Your previous answer was:\n\n{}\n\nPlease critique and improve upon it.",
|
|
cycle + 1,
|
|
best_answer
|
|
)));
|
|
}
|
|
|
|
// Build tools for this cycle (not cloneable, so rebuild each iteration)
|
|
let cycle_tools = build_rig_tools(
|
|
registry, &request.db, &request.cache, &request.config,
|
|
request.room.id, request.sender.uid, request.project.id, session_id,
|
|
);
|
|
|
|
// ── Generate ──────────────────────────────────────────────
|
|
let model = rig_client.completion_model(&request.model.name);
|
|
let agent = rig::agent::AgentBuilder::new(model)
|
|
.preamble(crate::modes::reflexion::REFLEXION_SYSTEM_PROMPT)
|
|
.tools(cycle_tools)
|
|
.default_max_turns(request.max_tool_depth)
|
|
.build();
|
|
|
|
let stream = agent
|
|
.stream_prompt(&request.input)
|
|
.with_history(Vec::new())
|
|
.multi_turn(request.max_tool_depth)
|
|
.await;
|
|
|
|
tokio::pin!(stream);
|
|
let mut generated = String::new();
|
|
|
|
while let Some(item) = stream.next().await {
|
|
match item {
|
|
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
|
|
rig::streaming::StreamedAssistantContent::Text(text),
|
|
)) => {
|
|
generated.push_str(&text.text);
|
|
}
|
|
Ok(rig::agent::MultiTurnStreamItem::FinalResponse(resp)) => {
|
|
let usage = resp.usage();
|
|
total_input_tokens += usage.input_tokens as i64;
|
|
total_output_tokens += usage.output_tokens as i64;
|
|
}
|
|
Err(e) => return Err(AgentError::OpenAi(e.to_string())),
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
best_answer = generated.clone();
|
|
on_chunk(crate::modes::reflexion::ReflexionStep::Generate(generated.clone())).await;
|
|
|
|
// If only 1 cycle, emit final and exit
|
|
if max_cycles == 1 || cycle + 1 >= max_cycles {
|
|
on_chunk(crate::modes::reflexion::ReflexionStep::Final(generated.clone())).await;
|
|
break;
|
|
}
|
|
|
|
// ── Self-critique ─────────────────────────────────────────
|
|
let critique_messages = vec![
|
|
crate::client::types::ChatRequestMessage::system(crate::modes::reflexion::REFLEXION_SYSTEM_PROMPT),
|
|
crate::client::types::ChatRequestMessage::system(format!(
|
|
"Your previous answer was:\n\n{}", generated
|
|
)),
|
|
crate::client::types::ChatRequestMessage::user(crate::modes::reflexion::REFLEXION_CRITIQUE_PROMPT),
|
|
];
|
|
|
|
let critique_result = crate::client::call_with_params(
|
|
&critique_messages,
|
|
&request.model.name,
|
|
&client_config,
|
|
request.temperature as f32,
|
|
request.max_tokens as u32,
|
|
None,
|
|
None,
|
|
Some("none"),
|
|
).await?;
|
|
|
|
total_input_tokens += critique_result.input_tokens;
|
|
total_output_tokens += critique_result.output_tokens;
|
|
let critique = critique_result.content;
|
|
on_chunk(crate::modes::reflexion::ReflexionStep::Critique(critique.clone())).await;
|
|
|
|
// ── Revise ───────────────────────────────────────────────
|
|
let revise_messages = vec![
|
|
crate::client::types::ChatRequestMessage::user(format!(
|
|
"Your previous answer:\n\n{}\n\nYour self-critique:\n\n{}",
|
|
generated, critique
|
|
)),
|
|
crate::client::types::ChatRequestMessage::user(crate::modes::reflexion::REFLEXION_REVISE_PROMPT),
|
|
];
|
|
|
|
let revise_model = rig_client.completion_model(&request.model.name);
|
|
let revise_stream = revise_model
|
|
.completion_request("")
|
|
.preamble(crate::modes::reflexion::REFLEXION_SYSTEM_PROMPT.to_string())
|
|
.messages(revise_messages.iter().map(|m| {
|
|
crate::client::to_rig_message(m)
|
|
}).collect::<Vec<_>>())
|
|
.temperature(request.temperature as f64)
|
|
.max_tokens(request.max_tokens as u64)
|
|
.stream()
|
|
.await
|
|
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
|
|
|
|
tokio::pin!(revise_stream);
|
|
let mut revised = String::new();
|
|
|
|
while let Some(item) = revise_stream.next().await {
|
|
match item {
|
|
Ok(rig::streaming::StreamedAssistantContent::Text(text)) => {
|
|
revised.push_str(&text.text);
|
|
}
|
|
Ok(rig::streaming::StreamedAssistantContent::Final(response)) => {
|
|
if let Some(usage) = response.token_usage() {
|
|
total_input_tokens += usage.input_tokens as i64;
|
|
total_output_tokens += usage.output_tokens as i64;
|
|
}
|
|
}
|
|
Err(e) => return Err(AgentError::OpenAi(e.to_string())),
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
best_answer = revised.clone();
|
|
on_chunk(crate::modes::reflexion::ReflexionStep::Revise(revised.clone())).await;
|
|
|
|
// If last cycle, emit final
|
|
if cycle + 1 >= max_cycles {
|
|
on_chunk(crate::modes::reflexion::ReflexionStep::Final(revised.clone())).await;
|
|
}
|
|
}
|
|
|
|
let elapsed_ms = session_start.elapsed().as_millis() as i64;
|
|
record_ai_session(
|
|
&request.cache, &request.db, request.project.id,
|
|
session_id, request.room.id, request.model.id,
|
|
version_id.unwrap_or_default(),
|
|
total_input_tokens, total_output_tokens, elapsed_ms,
|
|
).await;
|
|
|
|
Ok((best_answer, total_input_tokens, total_output_tokens))
|
|
}
|
|
}
|
|
|
|
fn build_rig_tools(
|
|
registry: &crate::tool::ToolRegistry,
|
|
db: &db::database::AppDatabase,
|
|
cache: &db::cache::AppCache,
|
|
cfg: &config::AppConfig,
|
|
room_id: uuid::Uuid,
|
|
sender_uid: uuid::Uuid,
|
|
project_id: uuid::Uuid,
|
|
session_id: uuid::Uuid,
|
|
) -> Vec<Box<dyn rig::tool::ToolDyn + 'static>> {
|
|
let mut tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>> = Vec::new();
|
|
for def in registry.definitions() {
|
|
let name = def.name.clone();
|
|
if let Some(handler) = registry.get(&name) {
|
|
let adapter = crate::tool::RigToolAdapter::new(
|
|
handler.clone(), def.clone(),
|
|
db.clone(), cache.clone(), cfg.clone(),
|
|
room_id, Some(sender_uid), project_id,
|
|
);
|
|
tools.push(Box::new(crate::tool::RecordingTool::new(
|
|
Box::new(adapter), db.clone(), session_id, sender_uid,
|
|
)));
|
|
}
|
|
}
|
|
tools
|
|
}
|
|
|
|
/// Extract text from rig's ToolResultContent, ignoring images.
|
|
fn tool_result_content_to_string(content: &rig::one_or_many::OneOrMany<rig::completion::message::ToolResultContent>) -> String {
|
|
use rig::completion::message::ToolResultContent;
|
|
content
|
|
.iter()
|
|
.filter_map(|item| {
|
|
if let ToolResultContent::Text(t) = item {
|
|
Some(t.text.clone())
|
|
} else {
|
|
None
|
|
}
|
|
})
|
|
.collect::<Vec<_>>()
|
|
.join("\n")
|
|
}
|