gitdataai/libs/agent/chat/service.rs
ZhenYi 08045eef63 refactor(agent): enhance chat service with state management and billing
Add persistent chat session state (ChatState, sequence tracking, tool
calls). Introduce basic billing record in agent crate. Refine chat
service to route messages through state machine with tool support.
2026-04-30 19:16:44 +08:00

2081 lines
84 KiB
Rust

use futures::StreamExt;
use models::projects::project_skill;
use models::rooms::room_ai;
use rig::agent::{AgentBuilder, MultiTurnStreamItem};
use rig::client::CompletionClient;
use rig::completion::{CompletionModel, GetTokenUsage, Prompt};
use rig::streaming::{StreamedAssistantContent, StreamingPrompt};
use sea_orm::*;
use std::pin::Pin;
use std::sync::Arc;
use uuid::Uuid;
use super::context::RoomMessageContext;
use super::{AiChunkType, AiRequest, AiStreamChunk, Mention, StreamCallback};
use crate::billing;
use crate::client::AiClientConfig;
use crate::client::types::{ChatRequestMessage, ToolCall};
use crate::client::{
StreamChunk, StreamChunkType, StreamedToolCall, call_stream, call_with_params,
};
use crate::compact::{CompactConfig, CompactService};
use crate::embed::EmbedService;
use crate::error::{AgentError, Result};
use crate::perception::{PerceptionService, SkillEntry, ToolCallEvent};
use crate::react::{DEFAULT_SYSTEM_PROMPT, ReactStep};
use crate::react::types::Action as ReactAction;
use crate::tool::{
RecordingTool, ToolCall as AgentToolCall, ToolContext, ToolExecutor,
registry::ToolRegistry,
};
/// Result of a streaming AI request, as returned by `ChatService::process_stream`.
pub struct StreamResult {
    /// Assistant text accumulated over the request: the text emitted before each tool
    /// round (each followed by a newline) and then the final answer, when one was produced.
    pub content: String,
    /// Reasoning ("thinking") text reported for the final model round; may be empty.
    pub reasoning_content: String,
    /// Prompt token count reported for the request.
    /// NOTE(review): see `process_stream` for which model rounds are included.
    pub input_tokens: i64,
    /// Completion token count reported for the request.
    /// NOTE(review): see `process_stream` for which model rounds are included.
    pub output_tokens: i64,
    /// All chunks in arrival order — preserves ReAct multi-cycle ordering.
    pub chunks: Vec<StreamChunk>,
}
/// Result of a non-streaming AI request, as returned by `ChatService::process`.
///
/// All fields are plain std types, so the common traits are derived to make the value
/// loggable (`Debug`), copyable and comparable in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProcessResult {
    /// Final assistant text, or a fixed notice when the tool-depth limit was reached
    /// without any text.
    pub content: String,
    /// Prompt tokens summed over every model round of the request.
    pub input_tokens: i64,
    /// Completion tokens summed over every model round of the request.
    pub output_tokens: i64,
}
/// Record an AI session with cost calculation.
async fn record_ai_session(
cache: &db::cache::AppCache,
db: &db::database::AppDatabase,
project_id: Uuid,
session_id: Uuid,
room_id: Uuid,
model_id: Uuid,
version_id: Uuid,
input_tokens: i64,
output_tokens: i64,
latency_ms: i64,
) {
metrics::histogram!("ai_call_latency_ms", "model" => model_id.to_string()).record(latency_ms as f64);
let session = models::ai::ai_session::ActiveModel {
id: Set(session_id),
room: Set(room_id),
model: Set(model_id),
version: Set(version_id),
token_input: Set(input_tokens),
token_output: Set(output_tokens),
latency_ms: Set(Some(latency_ms)),
cost: Set(None),
currency: Set(None),
error_message: Set(None),
error_code: Set(None),
created_at: Set(chrono::Utc::now()),
};
if let Err(e) = session.insert(db).await {
tracing::error!(error = %e, session_id = %session_id, "failed to insert ai session record");
return;
}
let (cost, currency, error_msg) = match billing::record_ai_usage(
db,
project_id,
version_id,
input_tokens,
output_tokens,
)
.await
{
Ok(billing::BillingResult::Success(record)) => {
(Some(record.cost), Some(record.currency), None)
}
Ok(billing::BillingResult::InsufficientBalance { message }) => {
create_system_message(cache, db, room_id, &message).await;
(None, None, Some(message))
}
Err(e) => (None, None, Some(e.to_string())),
};
use sea_orm::sea_query::Expr;
let _ = models::ai::ai_session::Entity::update_many()
.col_expr(models::ai::ai_session::Column::Cost, Expr::value(cost))
.col_expr(models::ai::ai_session::Column::Currency, Expr::value(currency))
.col_expr(models::ai::ai_session::Column::ErrorMessage, Expr::value(error_msg))
.filter(models::ai::ai_session::Column::Id.eq(session_id))
.exec(db)
.await;
}
/// Create a system message in the room for billing errors.
async fn create_system_message(
cache: &db::cache::AppCache,
db: &db::database::AppDatabase,
room_id: Uuid,
message: &str,
) {
use models::rooms::{room_message, MessageSenderType, MessageContentType};
use sea_orm::Set;
let seq_key = format!("room:seq:{}", room_id);
let seq = match cache.conn().await {
Ok(mut conn) => {
match redis::cmd("INCR").arg(&seq_key).query_async::<i64>(&mut conn).await {
Ok(s) => s,
Err(e) => {
tracing::warn!(error = %e, "cache INCR failed for system message seq, falling back to DB");
let last_seq = match room_message::Entity::find()
.filter(room_message::Column::Room.eq(room_id))
.order_by_desc(room_message::Column::Seq)
.one(db)
.await
{
Ok(Some(m)) => m.seq,
Ok(None) => 0,
Err(e) => {
tracing::warn!(error = %e, "Failed to get last seq for system message");
return;
}
};
last_seq + 1
}
}
}
Err(e) => {
tracing::warn!(error = %e, "Failed to get Redis connection for system message seq");
return;
}
};
let now = chrono::Utc::now();
let result = room_message::ActiveModel {
id: Set(Uuid::now_v7()),
seq: Set(seq),
room: Set(room_id),
sender_type: Set(MessageSenderType::System),
sender_id: Set(None),
model_id: Set(None),
thread: Set(None),
in_reply_to: Set(None),
content: Set(message.to_string()),
content_type: Set(MessageContentType::Text),
thinking_content: Set(None),
edited_at: Set(None),
send_at: Set(now),
revoked: Set(None),
revoked_by: Set(None),
}
.insert(db)
.await;
match result {
Ok(_) => {
tracing::info!(
room_id = %room_id,
message = %message,
"system_message_created_for_billing_error"
);
}
Err(e) => {
tracing::warn!(
error = %e,
room_id = %room_id,
"Failed to create system message for billing error"
);
}
}
}
/// Service for handling AI chat requests in rooms.
///
/// Built with `ChatService::new()` plus the `with_*` builder methods; every optional
/// component below starts out as `None`.
pub struct ChatService {
    /// Base URL of the OpenAI-compatible endpoint; the request paths fall back to
    /// `https://api.openai.com` when this is `None`.
    ai_base_url: Option<String>,
    /// API key for the endpoint; an empty key is used when this is `None`.
    ai_api_key: Option<String>,
    /// Optional conversation compaction used by `build_messages` to summarise history.
    compact_service: Option<CompactService>,
    /// Optional embedding / vector-search service used for mention enrichment, skill and
    /// memory detection, and handed to tool contexts.
    embed_service: Option<EmbedService>,
    /// Skill perception (keyword-based injection and passive tool-call detection).
    perception_service: PerceptionService,
    /// Registered tools; merged into every tool-execution context and exposed through
    /// `tools()` / `tool_registry()`.
    tool_registry: Option<ToolRegistry>,
}
impl ChatService {
/// Create a service with no AI endpoint, compaction, embedding or tool registry
/// configured, and a default `PerceptionService`.
///
/// Use the `with_*` builder methods to plug in the optional components.
pub fn new() -> Self {
    let perception_service = PerceptionService::default();
    Self {
        perception_service,
        tool_registry: None,
        embed_service: None,
        compact_service: None,
        ai_api_key: None,
        ai_base_url: None,
    }
}
/// Configure the AI endpoint from `config`.
///
/// Only `config.base_url` (which may be `None`) and `config.api_key` are read; they are
/// stored on the service and used to build a client config per request.
pub fn with_ai_client_config(mut self, config: AiClientConfig) -> Self {
    // `config` is taken by value, so move its fields instead of cloning them.
    self.ai_base_url = config.base_url;
    self.ai_api_key = Some(config.api_key);
    self
}
/// Enable conversation compaction; replaces any previously configured service.
pub fn with_compact_service(mut self, compact_service: CompactService) -> Self {
    self.compact_service.replace(compact_service);
    self
}
/// Enable embedding / vector search; replaces any previously configured service.
pub fn with_embed_service(mut self, embed_service: EmbedService) -> Self {
    self.embed_service.replace(embed_service);
    self
}
/// Swap in a custom `PerceptionService` in place of the default one.
pub fn with_perception_service(mut self, perception_service: PerceptionService) -> Self {
    let slot = &mut self.perception_service;
    *slot = perception_service;
    self
}
/// Attach the tool registry; replaces any previously configured registry.
pub fn with_tool_registry(mut self, registry: ToolRegistry) -> Self {
    self.tool_registry.replace(registry);
    self
}
/// Return every registered tool as an OpenAI-style JSON tool definition.
///
/// Yields an empty list when no tool registry has been attached.
pub fn tools(&self) -> Vec<serde_json::Value> {
    match &self.tool_registry {
        Some(registry) => registry.to_openai_tools(),
        None => Vec::new(),
    }
}
/// Build a `RigToolSet` from the registered tool registry.
///
/// This lets the same tools be used with `RigAgentService` through rig's native Agent.
/// The per-request context (db, cache, config, room, sender, project) is handed to
/// `RigToolSet::from_registry` when the set is created.
///
/// Returns `None` when no tool registry has been attached.
#[cfg(feature = "rig")]
pub fn rig_toolset(
    &self,
    db: db::database::AppDatabase,
    cache: db::cache::AppCache,
    config: config::AppConfig,
    room_id: uuid::Uuid,
    sender_id: Option<uuid::Uuid>,
    project_id: uuid::Uuid,
) -> Option<crate::RigToolSet> {
    let registry = self.tool_registry.as_ref()?;
    let toolset = crate::RigToolSet::from_registry(
        registry, db, cache, config, room_id, sender_id, project_id,
    );
    Some(toolset)
}
/// Borrow the attached `ToolRegistry`, if any.
pub fn tool_registry(&self) -> Option<&ToolRegistry> {
    let registry = &self.tool_registry;
    registry.as_ref()
}
pub async fn process(&self, request: AiRequest) -> Result<ProcessResult> {
let tools: Vec<serde_json::Value> = request.tools.clone().unwrap_or_default();
let tools_enabled = !tools.is_empty();
let max_tool_depth = request.max_tool_depth;
let mut messages = self.build_messages(&request).await?;
let room_ai = room_ai::Entity::find()
.filter(room_ai::Column::Room.eq(request.room.id))
.filter(room_ai::Column::Model.eq(request.model.id))
.one(&request.db)
.await?;
let model_name = request.model.name.clone();
let temperature = room_ai
.as_ref()
.and_then(|r| r.temperature.map(|v| v as f32))
.unwrap_or(request.temperature as f32);
let max_tokens = room_ai
.as_ref()
.and_then(|r| r.max_tokens.map(|v| v as u32))
.unwrap_or(request.max_tokens as u32);
let mut tool_depth = 0;
let mut input_tokens = 0i64;
let mut output_tokens = 0i64;
let session_id = Uuid::now_v7();
let session_start = std::time::Instant::now();
let version_id = room_ai.as_ref().and_then(|r| r.version);
let config = AiClientConfig::new(self.ai_api_key.clone().unwrap_or_default())
.with_base_url(
self.ai_base_url
.clone()
.unwrap_or_else(|| "https://api.openai.com".into()),
);
loop {
let response = call_with_params(
&messages,
&model_name,
&config,
temperature,
max_tokens,
None,
if tools_enabled { Some(&tools) } else { None },
if tools_enabled { None } else { Some("none") },
)
.await?;
let text = response.content.clone();
input_tokens += response.input_tokens;
output_tokens += response.output_tokens;
if tools_enabled && !response.tool_calls_finished.is_empty() {
// Build assistant message with tool_calls
let tool_call_messages: Vec<_> = response
.tool_calls_finished
.iter()
.map(|name| {
// We need ID and arguments — for non-streaming we reconstruct from content
// The model returns tool_calls in its content; for now we create a placeholder
// that will be replaced by actual tool results
ToolCall {
id: Uuid::new_v4().to_string(),
type_: "function".into(),
function: crate::client::types::ToolCallFunction {
name: name.clone(),
arguments: "{}".into(),
},
}
})
.collect();
messages.push(ChatRequestMessage::assistant(
Some(text.clone()),
Some(tool_call_messages.clone()),
));
// Create ToolCall list for executor (we need real IDs and args)
// Since we can't get args from streaming, use name matching from the text
let calls: Vec<AgentToolCall> = tool_call_messages
.into_iter()
.map(|tc| AgentToolCall {
id: tc.id.clone(),
name: tc.function.name.clone(),
arguments: tc.function.arguments.clone(),
})
.collect();
let tool_messages = {
let mut ctx = ToolContext::new(
request.db.clone(),
request.cache.clone(),
request.config.clone(),
request.room.id,
Some(request.sender.uid),
)
.with_project(request.project.id);
if let Some(ref es) = self.embed_service {
ctx = ctx.with_embed_service(es.clone());
}
if let Some(ref registry) = self.tool_registry {
ctx.registry_mut().merge(registry.clone());
}
let recorder = crate::tool::recorder::ToolCallRecorder::with_session(
request.db.clone(),
session_id,
);
let start = std::time::Instant::now();
let executor = ToolExecutor::new();
match executor.execute_batch(calls, &mut ctx).await {
Ok(results) => {
for (call, result) in
response.tool_calls_finished.iter().zip(results.iter())
{
let elapsed = start.elapsed().as_millis() as i64;
let is_error =
matches!(result.result, crate::tool::ToolResult::Error(_));
let error_msg = match &result.result {
crate::tool::ToolResult::Error(msg) => Some(msg.clone()),
_ => None,
};
recorder.record(crate::tool::recorder::ToolCallRecord {
tool_call_id: call.clone(),
session_id: recorder.session_id(),
tool_name: call.clone(),
caller: request.sender.uid,
arguments: serde_json::Value::Null,
status: if is_error {
models::ai::ToolCallStatus::Failed
} else {
models::ai::ToolCallStatus::Success
},
execution_time_ms: Some(elapsed),
error_message: error_msg,
error_stack: None,
retry_count: 0,
});
}
crate::tool::ToolExecutor::to_tool_messages(&results)
}
Err(e) => {
let elapsed = start.elapsed().as_millis() as i64;
for call_name in &response.tool_calls_finished {
recorder.record(crate::tool::recorder::ToolCallRecord {
tool_call_id: Uuid::new_v4().to_string(),
session_id: recorder.session_id(),
tool_name: call_name.clone(),
caller: request.sender.uid,
arguments: serde_json::Value::Null,
status: models::ai::ToolCallStatus::Failed,
execution_time_ms: Some(elapsed),
error_message: Some(e.to_string()),
error_stack: None,
retry_count: 0,
});
}
let err_msg = format!("[Tool call failed: {}]", e);
response
.tool_calls_finished
.iter()
.map(|_| {
ChatRequestMessage::tool(Uuid::new_v4().to_string(), &err_msg)
})
.collect()
}
}
};
messages.extend(tool_messages);
// Inject passive-detected skills based on tool calls
if let Ok(skills) = project_skill::Entity::find()
.filter(project_skill::Column::ProjectUuid.eq(request.project.id))
.filter(project_skill::Column::Enabled.eq(true))
.all(&request.db)
.await
{
let skill_entries: Vec<SkillEntry> = skills
.into_iter()
.map(|s| SkillEntry {
slug: s.slug,
name: s.name,
description: s.description,
content: s.content,
})
.collect();
let tool_events: Vec<ToolCallEvent> = response
.tool_calls_finished
.iter()
.map(|name| ToolCallEvent {
tool_name: name.clone(),
arguments: String::new(),
})
.collect();
for event in &tool_events {
if let Some(ctx) = self
.perception_service
.passive
.detect(event, &skill_entries)
{
messages.push(ctx.to_system_message());
}
}
}
tool_depth += 1;
if tool_depth >= max_tool_depth {
let content = if text.is_empty() {
format!(
"[AI reached maximum tool depth ({}) — no final answer produced]",
max_tool_depth
)
} else {
text
};
// Record session
record_ai_session(
&request.cache,
&request.db,
request.project.id,
session_id,
request.room.id,
request.model.id,
version_id.unwrap_or_default(),
input_tokens,
output_tokens,
session_start.elapsed().as_millis() as i64,
)
.await;
return Ok(ProcessResult {
content,
input_tokens,
output_tokens,
});
}
continue;
}
// Record session
record_ai_session(
&request.cache,
&request.db,
request.project.id,
session_id,
request.room.id,
request.model.id,
version_id.unwrap_or_default(),
input_tokens,
output_tokens,
session_start.elapsed().as_millis() as i64,
)
.await;
return Ok(ProcessResult {
content: text,
input_tokens,
output_tokens,
});
}
}
pub async fn process_stream(
&self,
request: AiRequest,
on_chunk: StreamCallback,
) -> Result<StreamResult> {
// Wrap on_chunk in Arc so it can be shared across loop iterations
let on_chunk = Arc::new(on_chunk);
let tools: Vec<serde_json::Value> = request.tools.clone().unwrap_or_default();
let tools_enabled = !tools.is_empty();
let max_tool_depth = request.max_tool_depth;
let mut messages = self.build_messages(&request).await?;
let room_ai = room_ai::Entity::find()
.filter(room_ai::Column::Room.eq(request.room.id))
.filter(room_ai::Column::Model.eq(request.model.id))
.one(&request.db)
.await?;
let model_name = request.model.name.clone();
let temperature = room_ai
.as_ref()
.and_then(|r| r.temperature.map(|v| v as f32))
.unwrap_or(request.temperature as f32);
let max_tokens = room_ai
.as_ref()
.and_then(|r| r.max_tokens.map(|v| v as u32))
.unwrap_or(request.max_tokens as u32);
let mut tool_depth = 0;
let mut total_input_tokens = 0i64;
let mut total_output_tokens = 0i64;
let session_id = Uuid::now_v7();
let session_start = std::time::Instant::now();
let version_id = room_ai.as_ref().and_then(|r| r.version);
let config = AiClientConfig::new(self.ai_api_key.clone().unwrap_or_default())
.with_base_url(
self.ai_base_url
.clone()
.unwrap_or_else(|| "https://api.openai.com".into()),
);
let mut full_content = String::new();
let mut all_chunks: Vec<StreamChunk> = Vec::new();
// Collect tool calls during streaming, push them incrementally after.
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<StreamedToolCall>();
loop {
let on_chunk_cb = on_chunk.clone();
let on_chunk_cb2 = on_chunk_cb.clone();
let tx_arc = Arc::new(tx.clone());
let tx_arc2 = tx_arc.clone();
let response = call_stream(
&messages,
&model_name,
&config,
temperature,
max_tokens,
if tools_enabled { Some(&tools) } else { None },
None, // tool_choice — auto (let model decide)
Arc::new(move |delta| {
let fut = on_chunk_cb(AiStreamChunk {
content: delta.to_string(),
done: false,
chunk_type: AiChunkType::Answer,
});
fut
}),
Arc::new(move |delta| {
let fut = on_chunk_cb2(AiStreamChunk {
content: delta.to_string(),
done: false,
chunk_type: AiChunkType::Thinking,
});
fut
}),
Arc::new(move |tc: &StreamedToolCall| {
let tx = tx_arc2.clone();
let tc_owned = tc.clone();
Box::pin(async move {
let _ = tx.send(tc_owned);
}) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
}),
)
.await?;
total_input_tokens += response.input_tokens;
total_output_tokens += response.output_tokens;
// Collect chunks from this streaming iteration in order.
all_chunks.extend(response.chunks);
let has_tool_calls = tools_enabled && !response.tool_calls.is_empty();
if has_tool_calls {
// Accumulate the assistant's text before tool calls
full_content.push_str(&response.content);
full_content.push('\n');
// Build assistant message with tool_calls from streaming response
let tool_calls: Vec<ToolCall> = response
.tool_calls
.iter()
.map(|tc| ToolCall {
id: tc.id.clone(),
type_: "function".into(),
function: crate::client::types::ToolCallFunction {
name: tc.name.clone(),
arguments: tc.arguments.clone(),
},
})
.collect();
messages.push(ChatRequestMessage::assistant(
Some(response.content.clone()),
Some(tool_calls.clone()),
));
// Push each tool call incrementally to frontend.
// Use try_recv() — tx is never dropped so recv() would deadlock.
loop {
match rx.try_recv() {
Ok(tc) => {
let args_display = if tc.arguments.len() > 100 {
format!("{}...", &tc.arguments[..100])
} else {
tc.arguments.clone()
};
let tool_display = format!("🔧 {}({})", tc.name, args_display);
on_chunk(AiStreamChunk {
content: tool_display.clone(),
done: false,
chunk_type: AiChunkType::ToolCall,
})
.await;
all_chunks.push(StreamChunk {
chunk_type: StreamChunkType::ToolCall,
content: tool_display,
});
}
Err(tokio::sync::mpsc::error::TryRecvError::Empty) => break,
Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => break,
}
}
// Execute tools one at a time, push each result incrementally
let calls: Vec<AgentToolCall> = response
.tool_calls
.iter()
.map(|tc| AgentToolCall {
id: tc.id.clone(),
name: tc.name.clone(),
arguments: tc.arguments.clone(),
})
.collect();
let mut tool_messages = Vec::new();
let mut ctx = crate::tool::ToolContext::new(
request.db.clone(),
request.cache.clone(),
request.config.clone(),
request.room.id,
Some(request.sender.uid),
)
.with_project(request.project.id);
if let Some(ref es) = self.embed_service {
ctx = ctx.with_embed_service(es.clone());
}
if let Some(ref registry) = self.tool_registry {
ctx.registry_mut().merge(registry.clone());
}
let recorder = crate::tool::recorder::ToolCallRecorder::with_session(
request.db.clone(),
session_id,
);
for call in &calls {
let start = std::time::Instant::now();
// Spawn tool execution in a separate task to avoid blocking the
// tokio worker thread (git2 operations are synchronous).
// This allows the heartbeat timer to fire independently.
let call_clone = call.clone();
let mut ctx_clone = ctx.clone();
let (result_tx, mut result_rx) = tokio::sync::oneshot::channel();
tokio::spawn(async move {
let executor = crate::tool::ToolExecutor::new();
let res = executor.execute_batch(vec![call_clone], &mut ctx_clone).await;
let _ = result_tx.send(res);
});
// Send heartbeats every 10s until tool execution completes
let heartbeat_dur = std::time::Duration::from_secs(10);
let results = loop {
tokio::select! {
res = &mut result_rx => {
match res {
Ok(inner) => break inner,
Err(_) => break Err(crate::tool::ToolError::ExecutionError("tool task cancelled".into())),
}
},
_ = tokio::time::sleep(heartbeat_dur) => {
on_chunk(AiStreamChunk {
content: String::new(),
done: false,
chunk_type: AiChunkType::ToolCall,
}).await;
}
}
};
let results = match results {
Ok(r) => r,
Err(e) => {
let elapsed = start.elapsed().as_millis() as i64;
recorder.record(crate::tool::recorder::ToolCallRecord {
tool_call_id: call.id.clone(),
session_id: recorder.session_id(),
tool_name: call.name.clone(),
caller: request.sender.uid,
arguments: call.arguments_json().unwrap_or_default(),
status: models::ai::ToolCallStatus::Failed,
execution_time_ms: Some(elapsed),
error_message: Some(e.to_string()),
error_stack: None,
retry_count: 0,
});
let err_text = format!("[Tool call failed: {}]", e);
tracing::warn!(tool = %call.name, args = %call.arguments, error = %e, "tool_call_failed");
let err_display = format!("{} (failed)", call.name);
on_chunk(AiStreamChunk {
content: err_display.clone(),
done: false,
chunk_type: AiChunkType::ToolResult,
})
.await;
all_chunks.push(StreamChunk {
chunk_type: StreamChunkType::ToolCall,
content: err_display,
});
tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text));
continue;
}
};
for result in &results {
let text = match &result.result {
crate::tool::ToolResult::Ok(v) => v.to_string(),
crate::tool::ToolResult::Error(msg) => msg.clone(),
};
let preview = if text.len() > 300 {
format!("{}...", &text[..300])
} else {
text.clone()
};
tracing::debug!("tool_result: {} — {}", call.name, preview);
let elapsed = start.elapsed().as_millis() as i64;
let is_error = matches!(result.result, crate::tool::ToolResult::Error(_));
let error_msg = match &result.result {
crate::tool::ToolResult::Error(msg) => Some(msg.clone()),
_ => None,
};
recorder.record(crate::tool::recorder::ToolCallRecord {
tool_call_id: call.id.clone(),
session_id: recorder.session_id(),
tool_name: call.name.clone(),
caller: request.sender.uid,
arguments: call.arguments_json().unwrap_or_default(),
status: if is_error {
models::ai::ToolCallStatus::Failed
} else {
models::ai::ToolCallStatus::Success
},
execution_time_ms: Some(elapsed),
error_message: error_msg,
error_stack: None,
retry_count: 0,
});
// Do NOT emit tool_result chunks to frontend — raw output may contain sensitive data.
// Log server-side only; frontend sees tool_call status via on_chunk below.
}
let success_display = format!("{}", call.name);
on_chunk(AiStreamChunk {
content: success_display.clone(),
done: false,
chunk_type: AiChunkType::ToolResult,
})
.await;
all_chunks.push(StreamChunk {
chunk_type: StreamChunkType::ToolCall,
content: success_display,
});
let msgs = crate::tool::ToolExecutor::to_tool_messages(&results);
tool_messages.extend(msgs);
}
messages.extend(tool_messages);
// Inject passive-detected skills based on tool calls
if let Ok(skills) = project_skill::Entity::find()
.filter(project_skill::Column::ProjectUuid.eq(request.project.id))
.filter(project_skill::Column::Enabled.eq(true))
.all(&request.db)
.await
{
let skill_entries: Vec<SkillEntry> = skills
.into_iter()
.map(|s| SkillEntry {
slug: s.slug,
name: s.name,
description: s.description,
content: s.content,
})
.collect();
let tool_events: Vec<ToolCallEvent> = response
.tool_calls
.iter()
.map(|tc| ToolCallEvent {
tool_name: tc.name.clone(),
arguments: tc.arguments.clone(),
})
.collect();
for event in &tool_events {
if let Some(ctx) = self
.perception_service
.passive
.detect(event, &skill_entries)
{
messages.push(ctx.to_system_message());
}
}
}
tool_depth += 1;
if tool_depth >= max_tool_depth {
let max_depth_text = format!(
"[AI reached maximum tool depth ({}) — no final answer produced]",
max_tool_depth
);
on_chunk(AiStreamChunk {
content: max_depth_text.clone(),
done: true,
chunk_type: AiChunkType::Answer,
})
.await;
all_chunks.push(StreamChunk {
chunk_type: StreamChunkType::Answer,
content: max_depth_text,
});
// Record session
record_ai_session(
&request.cache,
&request.db,
request.project.id,
session_id,
request.room.id,
request.model.id,
version_id.unwrap_or_default(),
total_input_tokens,
total_output_tokens,
session_start.elapsed().as_millis() as i64,
)
.await;
return Ok(StreamResult {
content: full_content,
reasoning_content: String::new(),
input_tokens: 0,
output_tokens: 0,
chunks: all_chunks,
});
}
continue;
}
// Final answer — accumulate and return
full_content.push_str(&response.content);
on_chunk(AiStreamChunk {
content: response.content.clone(),
done: true,
chunk_type: AiChunkType::Answer,
})
.await;
all_chunks.push(StreamChunk {
chunk_type: StreamChunkType::Answer,
content: response.content.clone(),
});
// Record session
record_ai_session(
&request.cache,
&request.db,
request.project.id,
session_id,
request.room.id,
request.model.id,
version_id.unwrap_or_default(),
total_input_tokens,
total_output_tokens,
session_start.elapsed().as_millis() as i64,
)
.await;
return Ok(StreamResult {
content: full_content,
reasoning_content: response.reasoning_content,
input_tokens: response.input_tokens,
output_tokens: response.output_tokens,
chunks: all_chunks,
});
}
}
async fn build_messages(&self, request: &AiRequest) -> Result<Vec<ChatRequestMessage>> {
let mut messages = Vec::new();
// Core reasoning instruction — prioritize analysis before answering.
messages.push(ChatRequestMessage::system(
"When receiving a question or problem, follow this reasoning process:\n\
1. ANALYZE: Break down the question. Identify what is being asked, what context is available, and what information is missing.\n\
2. GATHER: Use available tools (repository search, file reading, etc.) to collect relevant information before answering.\n\
3. REASON: Synthesize the gathered information. Consider edge cases and potential issues.\n\
4. ANSWER: Provide a clear, actionable answer based on your analysis.\n\
\n\
Do NOT guess or assume when tools can provide concrete answers. Always verify claims against actual code or documentation.".to_string()
));
let mut processed_history = Vec::new();
if let Some(compact_service) = &self.compact_service {
let compact_cache_key = format!("ai:compact:{}", request.room.id);
let compact_config = CompactConfig::default();
// Try cached compaction summary (avoids re-compacting same history)
let cached_summary: Option<String> = {
let conn_result = request.cache.conn().await;
match conn_result {
Ok(mut conn) => {
redis::cmd("GET")
.arg(&compact_cache_key)
.query_async::<Option<String>>(&mut conn)
.await
.unwrap_or(None)
}
Err(e) => {
tracing::warn!(error = %e, "compact cache: conn failed");
None
}
}
};
if let Some(cached_json) = cached_summary {
if let Ok(summary) = serde_json::from_str::<crate::compact::CompactSummary>(&cached_json) {
if !summary.summary.is_empty() {
messages.push(ChatRequestMessage::system(format!(
"Conversation summary:\n{}",
summary.summary
)));
}
processed_history = summary.retained;
}
}
if processed_history.is_empty() {
match compact_service
.compact_room_auto(request.room.id, Some(request.user_names.clone()), compact_config)
.await
{
Ok(compact_summary) => {
if !compact_summary.summary.is_empty() {
messages.push(ChatRequestMessage::system(format!(
"Conversation summary:\n{}",
compact_summary.summary
)));
}
// Cache for subsequent calls (5 min TTL)
if let Ok(json) = serde_json::to_string(&compact_summary) {
if let Ok(mut conn) = request.cache.conn().await {
let _ = redis::cmd("SETEX")
.arg(&compact_cache_key)
.arg(300u64)
.arg(&json)
.query_async::<()>(&mut conn)
.await
.inspect_err(|e| {
tracing::warn!(error = %e, "compact cache: SETEX failed");
});
}
}
processed_history = compact_summary.retained;
}
Err(e) => {
tracing::warn!(error = %e, "conversation compaction failed, using full history");
}
}
}
}
if !processed_history.is_empty() {
for msg_summary in processed_history {
let ctx = RoomMessageContext::from(msg_summary);
messages.push(ctx.to_message());
}
} else {
for msg in &request.history {
let ctx = RoomMessageContext::from_model_with_names(msg, &request.user_names);
messages.push(ctx.to_message());
}
}
for mention in &request.mention {
match mention {
Mention::Repo(repo) => {
// Inject repo details into system prompt so AI knows the repo context
let mut parts = vec![
format!("Name: {}", repo.repo_name),
format!("ID: {}", repo.id),
];
if let Some(ref desc) = repo.description {
parts.push(format!("Description: {}", desc));
}
parts.push(format!("Default branch: {}", repo.default_branch));
parts.push(format!(
"Private: {}",
if repo.is_private { "yes" } else { "no" }
));
parts.push(format!("Created: {}", repo.created_at.format("%Y-%m-%d")));
messages.push(ChatRequestMessage::system(format!(
"Mentioned repository:\n{}",
parts.join("\n")
)));
// Vector search for related issues and repos (enhancement, optional)
if let Some(embed_service) = &self.embed_service {
let query = format!(
"{} {}",
repo.repo_name,
repo.description.as_deref().unwrap_or_default()
);
if let Ok(issues) = embed_service.search_issues(&query, 5).await {
if !issues.is_empty() {
let context = format!(
"Related issues for repo {}:\n{}",
repo.repo_name,
issues
.iter()
.map(|i| format!("- {}", i.payload.text))
.collect::<Vec<_>>()
.join("\n")
);
messages.push(ChatRequestMessage::system(context));
}
}
if let Ok(repos) = embed_service.search_repos(&query, 3).await {
if !repos.is_empty() {
let context = format!(
"Similar repositories:\n{}",
repos
.iter()
.map(|r| format!("- {}", r.payload.text))
.collect::<Vec<_>>()
.join("\n")
);
messages.push(ChatRequestMessage::system(context));
}
}
}
}
Mention::User(user) => {
let mut profile_parts = vec![format!("Username: {}", user.username)];
if let Some(ref display_name) = user.display_name {
profile_parts.push(format!("Display name: {}", display_name));
}
if let Some(ref org) = user.organization {
profile_parts.push(format!("Organization: {}", org));
}
if let Some(ref website) = user.website_url {
profile_parts.push(format!("Website: {}", website));
}
messages.push(ChatRequestMessage::system(format!(
"Mentioned user profile:\n{}",
profile_parts.join("\n")
)));
}
}
}
let skill_contexts = self.build_skill_context(request).await;
for ctx in skill_contexts {
messages.push(ctx.to_system_message());
}
let memories = self.build_memory_context(request).await;
for mem in memories {
messages.push(mem.to_system_message());
}
messages.push(ChatRequestMessage::system(format!(
"Current Project:\n{}\nDescription: {}\nPublic: {}",
request.project.display_name,
request.project.description.as_deref().unwrap_or("(none)"),
if request.project.is_public {
"yes"
} else {
"no"
}
)));
let mut sender_parts = vec![format!("**Sender:** {}", request.sender.username)];
if let Some(ref display_name) = request.sender.display_name {
sender_parts.push(display_name.clone());
}
if let Some(ref org) = request.sender.organization {
sender_parts.push(format!("({})", org));
}
let sender_display = sender_parts.join(" ");
messages.push(ChatRequestMessage::system(format!(
"The person sending the next message:\n{}",
sender_display
)));
messages.push(ChatRequestMessage::user(&request.input));
Ok(messages)
}
/// Collect the skill contexts to inject into the prompt for `request`.
///
/// Candidate skills are the project's enabled database skills plus the crate's
/// built-in skills (`crate::skills::all_skills`). When slugs collide, the
/// database skill is kept and the built-in one is skipped. A failed database
/// query is treated as "no database skills", so built-in skills still apply.
///
/// Two selectors then run:
/// - `perception_service.inject_skills`, given the current input, up to the 10
///   most recent history messages (newest first) and the candidate skills;
/// - `VectorActiveAwareness::detect`, only when an embedding service is
///   configured, keyed by the project id.
///
/// The combined result is deduplicated by `label`, with vector matches placed
/// ahead of keyword matches. Returns an empty list when there are no candidate
/// skills at all.
async fn build_skill_context(
    &self,
    request: &AiRequest,
) -> Vec<crate::perception::SkillContext> {
    // Load the project's enabled database skills (best-effort).
    let mut all_skills: Vec<SkillEntry> = match project_skill::Entity::find()
        .filter(project_skill::Column::ProjectUuid.eq(request.project.id))
        .filter(project_skill::Column::Enabled.eq(true))
        .all(&request.db)
        .await
    {
        Ok(models) => models
            .into_iter()
            .map(|s| SkillEntry {
                slug: s.slug,
                name: s.name,
                description: s.description,
                content: s.content,
            })
            .collect(),
        // NOTE(review): the query error is discarded; consider logging it.
        Err(_) => Vec::new(),
    };
    // Merge in built-in skills. Database skills take precedence: a built-in
    // skill is added only when no skill with the same slug is present yet
    // (this also drops later built-ins that repeat an earlier built-in slug).
    let mut known_slugs: std::collections::HashSet<String> =
        all_skills.iter().map(|s| s.slug.clone()).collect();
    for built_in in crate::skills::all_skills() {
        if known_slugs.insert(built_in.slug.to_string()) {
            all_skills.push(SkillEntry {
                slug: built_in.slug.to_string(),
                name: built_in.name.to_string(),
                description: Some(built_in.description.to_string()),
                content: built_in.content.clone(),
            });
        }
    }
    if all_skills.is_empty() {
        return Vec::new();
    }
    // Up to the last 10 history messages, newest first.
    let history_texts: Vec<String> = request
        .history
        .iter()
        .rev()
        .take(10)
        .map(|msg| msg.content.clone())
        .collect();
    let tool_events: Vec<ToolCallEvent> = Vec::new();
    let keyword_skills = self
        .perception_service
        .inject_skills(&request.input, &history_texts, &tool_events, &all_skills)
        .await;
    let mut vector_skills = Vec::new();
    if let Some(embed_service) = &self.embed_service {
        let awareness = crate::perception::VectorActiveAwareness::default();
        vector_skills = awareness
            .detect(
                embed_service,
                &request.input,
                &request.project.id.to_string(),
            )
            .await;
    }
    // Deduplicate by label; vector matches win over keyword matches.
    let mut seen = std::collections::HashSet::new();
    let mut result = Vec::new();
    for ctx in vector_skills {
        if seen.insert(ctx.label.clone()) {
            result.push(ctx);
        }
    }
    for ctx in keyword_skills {
        if seen.insert(ctx.label.clone()) {
            result.push(ctx);
        }
    }
    result
}
/// Retrieve memory snippets relevant to the current input.
///
/// Returns an empty list when no embedding service is configured. Otherwise it
/// delegates to `VectorPassiveAwareness::detect` with the input text, the
/// project's display name and the room id.
///
/// NOTE(review): `build_skill_context` keys vector lookups by project *id*,
/// while this uses the project *display name*; confirm that is intended.
async fn build_memory_context(
    &self,
    request: &AiRequest,
) -> Vec<crate::perception::vector::MemoryContext> {
    let Some(embedder) = &self.embed_service else {
        return Vec::new();
    };
    let passive = crate::perception::VectorPassiveAwareness::default();
    let room_key = request.room.id.to_string();
    passive
        .detect(
            embedder,
            &request.input,
            &request.project.display_name,
            &room_key,
        )
        .await
}
pub async fn process_react<C, Fut>(&self, request: &AiRequest, mut on_chunk: C) -> Result<(String, i64, i64)>
where
C: FnMut(crate::react::ReactStep) -> Fut + Send,
Fut: std::future::Future<Output = ()> + Send,
{
let base_url = self
.ai_base_url
.clone()
.unwrap_or_else(|| "https://api.openai.com".into());
let api_key = self.ai_api_key.clone().unwrap_or_default();
let client_config = AiClientConfig::new(api_key).with_base_url(base_url);
let Some(registry) = &self.tool_registry else {
return Err(AgentError::Internal("no tool registry registered".into()));
};
let db = request.db.clone();
let cache = request.cache.clone();
let cfg = request.config.clone();
let room_id = request.room.id;
let sender_uid = request.sender.uid;
let project_id = request.project.id;
let session_id = Uuid::now_v7();
let session_start = std::time::Instant::now();
let version_id = room_ai::Entity::find()
.filter(room_ai::Column::Room.eq(request.room.id))
.filter(room_ai::Column::Model.eq(request.model.id))
.one(&request.db)
.await
.ok()
.flatten()
.and_then(|r| r.version);
// Build rig tools with recording wrapper directly from registry
let mut tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>> = Vec::new();
for def in registry.definitions() {
let name = def.name.clone();
if let Some(handler) = registry.get(&name) {
let adapter = crate::tool::RigToolAdapter::new(
handler.clone(),
def.clone(),
db.clone(),
cache.clone(),
cfg.clone(),
room_id,
Some(sender_uid),
project_id,
);
tools.push(Box::new(RecordingTool::new(
Box::new(adapter),
db.clone(),
session_id,
sender_uid,
)));
}
}
// Build rig agent (handles multi-turn tool calls natively)
let rig_client = client_config.build_rig_client();
let model = rig_client.completion_model(&request.model.name);
let agent = AgentBuilder::new(model)
.preamble(DEFAULT_SYSTEM_PROMPT)
.tools(tools)
.default_max_turns(request.max_tool_depth)
.build();
let stream = agent
.stream_prompt(&request.input)
.with_history(Vec::new())
.multi_turn(request.max_tool_depth)
.await;
tokio::pin!(stream);
let mut step_count = 0usize;
let mut final_content = String::new();
let mut total_input_tokens: i64 = 0;
let mut total_output_tokens: i64 = 0;
while let Some(item) = stream.next().await {
match item {
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::Text(text),
)) => {
step_count += 1;
let t = text.text;
on_chunk(ReactStep::Answer {
step: step_count,
answer: t.clone(),
})
.await;
final_content.push_str(&t);
}
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::Reasoning(reasoning),
)) => {
let reasoning_text = reasoning.reasoning.join("");
if !reasoning_text.is_empty() {
step_count += 1;
on_chunk(ReactStep::Thought {
step: step_count,
thought: reasoning_text,
})
.await;
}
}
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::ReasoningDelta { reasoning, .. },
)) => {
if !reasoning.is_empty() {
step_count += 1;
on_chunk(ReactStep::Thought {
step: step_count,
thought: reasoning,
})
.await;
}
}
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::ToolCall { tool_call, .. },
)) => {
step_count += 1;
let args: serde_json::Value = match &tool_call.function.arguments {
serde_json::Value::String(s) => {
serde_json::from_str(s).unwrap_or(serde_json::Value::Null)
}
v => v.clone(),
};
on_chunk(ReactStep::Action {
step: step_count,
action: ReactAction::new(&tool_call.function.name, args),
})
.await;
}
Ok(MultiTurnStreamItem::StreamUserItem(
rig::streaming::StreamedUserContent::ToolResult { tool_result, .. },
)) => {
step_count += 1;
let obs = tool_result_content_to_string(&tool_result.content);
on_chunk(ReactStep::Observation {
step: step_count,
observation: obs,
})
.await;
}
Ok(MultiTurnStreamItem::FinalResponse(resp)) => {
let usage = resp.usage();
total_input_tokens = usage.input_tokens as i64;
total_output_tokens = usage.output_tokens as i64;
// Text was already streamed incrementally via Answer events.
}
Err(e) => {
let err_msg = format!("rig agent stream error: {}", e);
return Err(AgentError::OpenAi(err_msg));
}
_ => {}
}
}
let elapsed_ms = session_start.elapsed().as_millis() as i64;
record_ai_session(
&request.cache,
&request.db,
request.project.id,
session_id,
request.room.id,
request.model.id,
version_id.unwrap_or_default(),
total_input_tokens,
total_output_tokens,
elapsed_ms,
)
.await;
Ok((final_content, total_input_tokens, total_output_tokens))
}
// ── CoT (Chain-of-Thought) ────────────────────────────────────────────
/// Run a CoT (Chain-of-Thought) reasoning cycle — step-by-step reasoning with optional tool use.
pub async fn process_cot<C, Fut>(&self, request: &AiRequest, mut on_chunk: C) -> Result<(String, i64, i64)>
where
C: FnMut(crate::modes::cot::CotStep) -> Fut + Send,
Fut: std::future::Future<Output = ()> + Send,
{
let client_config = AiClientConfig::new(
self.ai_api_key.clone().unwrap_or_default(),
)
.with_base_url(
self.ai_base_url
.clone()
.unwrap_or_else(|| "https://api.openai.com".into()),
);
let rig_client = client_config.build_rig_client();
let Some(registry) = &self.tool_registry else {
return Err(AgentError::Internal("no tool registry registered".into()));
};
let session_id = Uuid::now_v7();
let session_start = std::time::Instant::now();
let version_id = room_ai::Entity::find()
.filter(room_ai::Column::Room.eq(request.room.id))
.filter(room_ai::Column::Model.eq(request.model.id))
.one(&request.db)
.await
.ok()
.flatten()
.and_then(|r| r.version);
let db = request.db.clone();
let cache = request.cache.clone();
let cfg = request.config.clone();
let room_id = request.room.id;
let sender_uid = request.sender.uid;
let project_id = request.project.id;
let mut tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>> = Vec::new();
for def in registry.definitions() {
let name = def.name.clone();
if let Some(handler) = registry.get(&name) {
let adapter = crate::tool::RigToolAdapter::new(
handler.clone(), def.clone(),
db.clone(), cache.clone(), cfg.clone(),
room_id, Some(sender_uid), project_id,
);
tools.push(Box::new(crate::tool::RecordingTool::new(
Box::new(adapter), db.clone(), session_id, sender_uid,
)));
}
}
let model = rig_client.completion_model(&request.model.name);
let agent = AgentBuilder::new(model)
.preamble(crate::modes::cot::COT_SYSTEM_PROMPT)
.tools(tools)
.default_max_turns(request.max_tool_depth)
.build();
let stream = agent
.stream_prompt(&request.input)
.with_history(Vec::new())
.multi_turn(request.max_tool_depth)
.await;
tokio::pin!(stream);
let mut final_content = String::new();
let mut total_input_tokens: i64 = 0;
let mut total_output_tokens: i64 = 0;
while let Some(item) = stream.next().await {
match item {
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::Text(text),
)) => {
let t = text.text;
on_chunk(crate::modes::cot::CotStep::Answer(t.clone())).await;
final_content.push_str(&t);
}
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::Reasoning(reasoning),
)) => {
let r = reasoning.reasoning.join("");
if !r.is_empty() {
on_chunk(crate::modes::cot::CotStep::Thought(r)).await;
}
}
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::ReasoningDelta { reasoning, .. },
)) => {
if !reasoning.is_empty() {
on_chunk(crate::modes::cot::CotStep::Thought(reasoning)).await;
}
}
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::ToolCall { tool_call, .. },
)) => {
let args: serde_json::Value = match &tool_call.function.arguments {
serde_json::Value::String(s) => {
serde_json::from_str(s).unwrap_or(serde_json::Value::Null)
}
v => v.clone(),
};
on_chunk(crate::modes::cot::CotStep::Action {
name: tool_call.function.name.clone(),
args,
}).await;
}
Ok(MultiTurnStreamItem::StreamUserItem(
rig::streaming::StreamedUserContent::ToolResult { tool_result, .. },
)) => {
let obs = tool_result_content_to_string(&tool_result.content);
on_chunk(crate::modes::cot::CotStep::Observation(obs)).await;
}
Ok(MultiTurnStreamItem::FinalResponse(resp)) => {
let usage = resp.usage();
total_input_tokens = usage.input_tokens as i64;
total_output_tokens = usage.output_tokens as i64;
}
Err(e) => {
return Err(AgentError::OpenAi(e.to_string()));
}
_ => {}
}
}
let elapsed_ms = session_start.elapsed().as_millis() as i64;
record_ai_session(
&request.cache, &request.db, request.project.id,
session_id, request.room.id, request.model.id,
version_id.unwrap_or_default(),
total_input_tokens, total_output_tokens, elapsed_ms,
).await;
Ok((final_content, total_input_tokens, total_output_tokens))
}
// ── ReWOO (Plan → Execute → Synthesize) ───────────────────────────────
/// Run a ReWOO reasoning cycle: model plans tool calls, they are executed,
/// then the model synthesises the final answer.
pub async fn process_rewoo<C, Fut>(&self, request: &AiRequest, mut on_chunk: C) -> Result<(String, i64, i64)>
where
C: FnMut(crate::modes::rewoo::ReWooStep) -> Fut + Send,
Fut: std::future::Future<Output = ()> + Send,
{
let client_config = AiClientConfig::new(
self.ai_api_key.clone().unwrap_or_default(),
)
.with_base_url(
self.ai_base_url
.clone()
.unwrap_or_else(|| "https://api.openai.com".into()),
);
let rig_client = client_config.build_rig_client();
let Some(registry) = &self.tool_registry else {
return Err(AgentError::Internal("no tool registry registered".into()));
};
let session_id = Uuid::now_v7();
let session_start = std::time::Instant::now();
let version_id = room_ai::Entity::find()
.filter(room_ai::Column::Room.eq(request.room.id))
.filter(room_ai::Column::Model.eq(request.model.id))
.one(&request.db)
.await
.ok()
.flatten()
.and_then(|r| r.version);
let mut total_input_tokens: i64 = 0;
let mut total_output_tokens: i64 = 0;
let mut messages = self.build_messages(request).await?;
messages.insert(0, crate::client::types::ChatRequestMessage::system(
crate::modes::rewoo::REWOO_SYSTEM_PROMPT.to_string(),
));
let model = rig_client.completion_model(&request.model.name);
let plan_tools = {
let db = request.db.clone();
let cache = request.cache.clone();
let cfg = request.config.clone();
let room_id = request.room.id;
let sender_uid = request.sender.uid;
let project_id = request.project.id;
let mut tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>> = Vec::new();
for def in registry.definitions() {
let name = def.name.clone();
if let Some(handler) = registry.get(&name) {
let adapter = crate::tool::RigToolAdapter::new(
handler.clone(),
def.clone(),
db.clone(),
cache.clone(),
cfg.clone(),
room_id,
Some(sender_uid),
project_id,
);
tools.push(Box::new(crate::tool::RecordingTool::new(
Box::new(adapter),
db.clone(),
session_id,
sender_uid,
)));
}
}
tools
};
let plan_agent = rig::agent::AgentBuilder::new(model)
.preamble(crate::modes::rewoo::REWOO_SYSTEM_PROMPT)
.tools(plan_tools)
.default_max_turns(1)
.build();
let plan_response = plan_agent
.prompt(&request.input)
.extended_details()
.await
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
total_input_tokens += plan_response.total_usage.input_tokens as i64;
total_output_tokens += plan_response.total_usage.output_tokens as i64;
let plan = crate::modes::rewoo::extract_plan(&plan_response.output)
.unwrap_or_default();
if plan.calls.is_empty() {
on_chunk(crate::modes::rewoo::ReWooStep::Synthesis(plan_response.output.clone())).await;
let elapsed_ms = session_start.elapsed().as_millis() as i64;
record_ai_session(
&request.cache, &request.db, request.project.id,
session_id, request.room.id, request.model.id,
version_id.unwrap_or_default(),
total_input_tokens, total_output_tokens, elapsed_ms,
).await;
return Ok((plan_response.output, total_input_tokens, total_output_tokens));
}
on_chunk(crate::modes::rewoo::ReWooStep::Plan {
calls: plan.calls.clone(),
raw: plan.raw_text,
}).await;
// ── Phase 2: Execute all tool calls in parallel ───────────────────
let mut tool_results: Vec<(String, String)> = Vec::new();
let mut handles = Vec::new();
for call in &plan.calls {
let ctx = crate::tool::ToolContext::new(
request.db.clone(),
request.cache.clone(),
request.config.clone(),
request.room.id,
Some(request.sender.uid),
)
.with_project(request.project.id);
if let Some(ref es) = self.embed_service {
// ctx = ctx.with_embed_service(es.clone()); -- not clone-able via pattern, skip
let _ = es;
}
let call_id = call.step.to_string();
let tool_name = call.tool.clone();
let args = call.args.clone();
let ctx_clone = ctx.clone();
let handle = tokio::spawn(async move {
let executor = crate::tool::ToolExecutor::new();
let agent_call = crate::tool::ToolCall {
id: call_id,
name: tool_name.clone(),
arguments: args.to_string(),
};
let mut local_ctx = ctx_clone;
let result = executor.execute_batch(vec![agent_call], &mut local_ctx).await;
match result {
Ok(results) => {
for r in &results {
match &r.result {
crate::tool::ToolResult::Ok(v) => {
return (tool_name, v.to_string());
}
crate::tool::ToolResult::Error(e) => {
return (tool_name, format!("[Error: {}]", e));
}
}
}
(tool_name, "[No result]".to_string())
}
Err(e) => (tool_name, format!("[Execution error: {}]", e)),
}
});
handles.push(handle);
}
for handle in handles {
match handle.await {
Ok((name, result)) => {
on_chunk(crate::modes::rewoo::ReWooStep::Execution {
tool_name: name.clone(),
result: result.clone(),
}).await;
tool_results.push((name, result));
}
Err(e) => {
let msg = format!("[Task panicked: {}]", e);
on_chunk(crate::modes::rewoo::ReWooStep::Execution {
tool_name: "unknown".into(),
result: msg.clone(),
}).await;
tool_results.push(("unknown".into(), msg));
}
}
}
// ── Phase 3: Synthesize ───────────────────────────────────────────
let mut synth_messages = self.build_messages(request).await?;
synth_messages.insert(0, crate::client::types::ChatRequestMessage::system(
crate::modes::rewoo::REWOO_SYSTEM_PROMPT.to_string(),
));
let results_summary: String = tool_results
.iter()
.map(|(name, res)| format!("- {}:\n{}", name, res))
.collect::<Vec<_>>()
.join("\n");
synth_messages.push(crate::client::types::ChatRequestMessage::system(format!(
"## Tool Execution Results\n\nThe following tool calls were executed:\n\n{}\n\nNow synthesize your final answer based on these results.",
results_summary
)));
synth_messages.push(crate::client::types::ChatRequestMessage::user(&request.input));
let preamble = synth_messages
.iter()
.find(|m| m.role == "system")
.and_then(|m| m.content.as_deref())
.unwrap_or("")
.to_string();
let non_system: Vec<_> = synth_messages
.iter()
.filter(|m| m.role != "system")
.map(|m| crate::client::to_rig_message(m))
.collect();
let synth_model = rig_client.completion_model(&request.model.name);
let synth_stream = synth_model
.completion_request("")
.preamble(preamble)
.messages(non_system)
.temperature(request.temperature as f64)
.max_tokens(request.max_tokens as u64)
.stream()
.await
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
use rig::streaming::StreamedAssistantContent;
tokio::pin!(synth_stream);
let mut synthesis = String::new();
while let Some(item) = synth_stream.next().await {
match item {
Ok(StreamedAssistantContent::Text(text)) => {
let t = text.text;
on_chunk(crate::modes::rewoo::ReWooStep::Synthesis(t.clone())).await;
synthesis.push_str(&t);
}
Ok(StreamedAssistantContent::Final(response)) => {
if let Some(usage) = response.token_usage() {
total_input_tokens += usage.input_tokens as i64;
total_output_tokens += usage.output_tokens as i64;
}
}
Err(e) => return Err(AgentError::OpenAi(e.to_string())),
_ => {}
}
}
let elapsed_ms = session_start.elapsed().as_millis() as i64;
record_ai_session(
&request.cache, &request.db, request.project.id,
session_id, request.room.id, request.model.id,
version_id.unwrap_or_default(),
total_input_tokens, total_output_tokens, elapsed_ms,
).await;
Ok((synthesis, total_input_tokens, total_output_tokens))
}
// ── Reflexion (Generate → Critique → Revise) ──────────────────────────
/// Run a Reflexion reasoning cycle: generate → critique → revise (up to 3 rounds).
pub async fn process_reflexion<C, Fut>(
&self,
request: &AiRequest,
mut on_chunk: C,
max_cycles: usize,
) -> Result<(String, i64, i64)>
where
C: FnMut(crate::modes::reflexion::ReflexionStep) -> Fut + Send,
Fut: std::future::Future<Output = ()> + Send,
{
let client_config = AiClientConfig::new(
self.ai_api_key.clone().unwrap_or_default(),
)
.with_base_url(
self.ai_base_url
.clone()
.unwrap_or_else(|| "https://api.openai.com".into()),
);
let rig_client = client_config.build_rig_client();
let Some(registry) = &self.tool_registry else {
return Err(AgentError::Internal("no tool registry registered".into()));
};
let session_id = Uuid::now_v7();
let session_start = std::time::Instant::now();
let version_id = room_ai::Entity::find()
.filter(room_ai::Column::Room.eq(request.room.id))
.filter(room_ai::Column::Model.eq(request.model.id))
.one(&request.db)
.await
.ok()
.flatten()
.and_then(|r| r.version);
let max_cycles = max_cycles.min(3);
let mut total_input_tokens: i64 = 0;
let mut total_output_tokens: i64 = 0;
let mut best_answer = String::new();
for cycle in 0..max_cycles {
let mut messages = self.build_messages(request).await?;
messages.insert(0, crate::client::types::ChatRequestMessage::system(
crate::modes::reflexion::REFLEXION_SYSTEM_PROMPT.to_string(),
));
if cycle > 0 {
messages.push(crate::client::types::ChatRequestMessage::system(format!(
"This is cycle {} of the reflexion process. Your previous answer was:\n\n{}\n\nPlease critique and improve upon it.",
cycle + 1,
best_answer
)));
}
// Build tools for this cycle (not cloneable, so rebuild each iteration)
let cycle_tools = build_rig_tools(
registry, &request.db, &request.cache, &request.config,
request.room.id, request.sender.uid, request.project.id, session_id,
);
// ── Generate ──────────────────────────────────────────────
let model = rig_client.completion_model(&request.model.name);
let agent = rig::agent::AgentBuilder::new(model)
.preamble(crate::modes::reflexion::REFLEXION_SYSTEM_PROMPT)
.tools(cycle_tools)
.default_max_turns(request.max_tool_depth)
.build();
let stream = agent
.stream_prompt(&request.input)
.with_history(Vec::new())
.multi_turn(request.max_tool_depth)
.await;
tokio::pin!(stream);
let mut generated = String::new();
while let Some(item) = stream.next().await {
match item {
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
rig::streaming::StreamedAssistantContent::Text(text),
)) => {
generated.push_str(&text.text);
}
Ok(rig::agent::MultiTurnStreamItem::FinalResponse(resp)) => {
let usage = resp.usage();
total_input_tokens += usage.input_tokens as i64;
total_output_tokens += usage.output_tokens as i64;
}
Err(e) => return Err(AgentError::OpenAi(e.to_string())),
_ => {}
}
}
best_answer = generated.clone();
on_chunk(crate::modes::reflexion::ReflexionStep::Generate(generated.clone())).await;
// If only 1 cycle, emit final and exit
if max_cycles == 1 || cycle + 1 >= max_cycles {
on_chunk(crate::modes::reflexion::ReflexionStep::Final(generated.clone())).await;
break;
}
// ── Self-critique ─────────────────────────────────────────
let critique_messages = vec![
crate::client::types::ChatRequestMessage::system(crate::modes::reflexion::REFLEXION_SYSTEM_PROMPT),
crate::client::types::ChatRequestMessage::system(format!(
"Your previous answer was:\n\n{}", generated
)),
crate::client::types::ChatRequestMessage::user(crate::modes::reflexion::REFLEXION_CRITIQUE_PROMPT),
];
let critique_result = crate::client::call_with_params(
&critique_messages,
&request.model.name,
&client_config,
request.temperature as f32,
request.max_tokens as u32,
None,
None,
Some("none"),
).await?;
total_input_tokens += critique_result.input_tokens;
total_output_tokens += critique_result.output_tokens;
let critique = critique_result.content;
on_chunk(crate::modes::reflexion::ReflexionStep::Critique(critique.clone())).await;
// ── Revise ───────────────────────────────────────────────
let revise_messages = vec![
crate::client::types::ChatRequestMessage::user(format!(
"Your previous answer:\n\n{}\n\nYour self-critique:\n\n{}",
generated, critique
)),
crate::client::types::ChatRequestMessage::user(crate::modes::reflexion::REFLEXION_REVISE_PROMPT),
];
let revise_model = rig_client.completion_model(&request.model.name);
let revise_stream = revise_model
.completion_request("")
.preamble(crate::modes::reflexion::REFLEXION_SYSTEM_PROMPT.to_string())
.messages(revise_messages.iter().map(|m| {
crate::client::to_rig_message(m)
}).collect::<Vec<_>>())
.temperature(request.temperature as f64)
.max_tokens(request.max_tokens as u64)
.stream()
.await
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
tokio::pin!(revise_stream);
let mut revised = String::new();
while let Some(item) = revise_stream.next().await {
match item {
Ok(rig::streaming::StreamedAssistantContent::Text(text)) => {
revised.push_str(&text.text);
}
Ok(rig::streaming::StreamedAssistantContent::Final(response)) => {
if let Some(usage) = response.token_usage() {
total_input_tokens += usage.input_tokens as i64;
total_output_tokens += usage.output_tokens as i64;
}
}
Err(e) => return Err(AgentError::OpenAi(e.to_string())),
_ => {}
}
}
best_answer = revised.clone();
on_chunk(crate::modes::reflexion::ReflexionStep::Revise(revised.clone())).await;
// If last cycle, emit final
if cycle + 1 >= max_cycles {
on_chunk(crate::modes::reflexion::ReflexionStep::Final(revised.clone())).await;
}
}
let elapsed_ms = session_start.elapsed().as_millis() as i64;
record_ai_session(
&request.cache, &request.db, request.project.id,
session_id, request.room.id, request.model.id,
version_id.unwrap_or_default(),
total_input_tokens, total_output_tokens, elapsed_ms,
).await;
Ok((best_answer, total_input_tokens, total_output_tokens))
}
}
/// Wrap every tool in `registry` so a rig agent can call it.
///
/// For each registered definition that has a matching handler, a
/// `RigToolAdapter` is built with the given database, cache, config, room,
/// sender and project. It is then wrapped in a `RecordingTool` constructed with
/// `db`, `session_id` and `sender_uid`. Definitions without a handler are
/// skipped. The output follows the order of `registry.definitions()`.
fn build_rig_tools(
    registry: &crate::tool::ToolRegistry,
    db: &db::database::AppDatabase,
    cache: &db::cache::AppCache,
    cfg: &config::AppConfig,
    room_id: uuid::Uuid,
    sender_uid: uuid::Uuid,
    project_id: uuid::Uuid,
    session_id: uuid::Uuid,
) -> Vec<Box<dyn rig::tool::ToolDyn + 'static>> {
    registry
        .definitions()
        .into_iter()
        .filter_map(|def| {
            let handler = registry.get(&def.name)?;
            let adapter = crate::tool::RigToolAdapter::new(
                handler.clone(),
                def.clone(),
                db.clone(),
                cache.clone(),
                cfg.clone(),
                room_id,
                Some(sender_uid),
                project_id,
            );
            let recorded = crate::tool::RecordingTool::new(
                Box::new(adapter),
                db.clone(),
                session_id,
                sender_uid,
            );
            Some(Box::new(recorded) as Box<dyn rig::tool::ToolDyn + 'static>)
        })
        .collect()
}
/// Flatten a rig tool result into plain text.
///
/// Only `ToolResultContent::Text` parts are kept; every other part (such as
/// images) is skipped. The kept parts are joined with newlines, and an empty
/// string is returned when there are none.
fn tool_result_content_to_string(content: &rig::one_or_many::OneOrMany<rig::completion::message::ToolResultContent>) -> String {
    use rig::completion::message::ToolResultContent;
    let mut pieces: Vec<String> = Vec::new();
    for part in content.iter() {
        if let ToolResultContent::Text(t) = part {
            pieces.push(t.text.clone());
        }
    }
    pieces.join("\n")
}