gitdataai/libs/agent/chat/service.rs

1254 lines
49 KiB
Rust

use futures::StreamExt;
use models::projects::project_skill;
use models::rooms::room_ai;
use rig::agent::{AgentBuilder, MultiTurnStreamItem};
use rig::client::CompletionClient;
use rig::streaming::{StreamedAssistantContent, StreamingPrompt};
use sea_orm::*;
use std::pin::Pin;
use std::sync::Arc;
use uuid::Uuid;
use super::context::RoomMessageContext;
use super::{AiChunkType, AiRequest, AiStreamChunk, Mention, StreamCallback};
use crate::billing;
use crate::client::AiClientConfig;
use crate::client::types::{ChatRequestMessage, ToolCall};
use crate::client::{
StreamChunk, StreamChunkType, StreamedToolCall, call_stream, call_with_params,
};
use crate::compact::{CompactConfig, CompactService};
use crate::embed::EmbedService;
use crate::error::{AgentError, Result};
use crate::perception::{PerceptionService, SkillEntry, ToolCallEvent};
use crate::react::{DEFAULT_SYSTEM_PROMPT, ReactStep};
use crate::react::types::Action as ReactAction;
use crate::tool::{
RecordingTool, ToolCall as AgentToolCall, ToolContext, ToolExecutor,
registry::ToolRegistry,
};
/// Result from streaming AI response.
///
/// Returned by `ChatService::process_stream` once the model has produced its
/// final answer or the tool-depth limit has been reached.
pub struct StreamResult {
/// Answer text accumulated over the whole exchange.
pub content: String,
/// Reasoning ("thinking") text reported by the model; may be empty.
pub reasoning_content: String,
/// Prompt-side token count reported for the request.
pub input_tokens: i64,
/// Completion-side token count reported for the request.
pub output_tokens: i64,
/// All chunks in arrival order — preserves ReAct multi-cycle ordering.
pub chunks: Vec<StreamChunk>,
}
/// Result from non-streaming AI response.
///
/// Token counts are accumulated over every completion round that was needed
/// to produce `content` (one round per tool-call cycle plus the final answer).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProcessResult {
    /// Final answer text returned to the caller.
    pub content: String,
    /// Total prompt tokens consumed.
    pub input_tokens: i64,
    /// Total completion tokens produced.
    pub output_tokens: i64,
}
/// Record an AI session with cost calculation.
async fn record_ai_session(
db: &db::database::AppDatabase,
project_id: Uuid,
session_id: Uuid,
room_id: Uuid,
model_id: Uuid,
version_id: Uuid,
input_tokens: i64,
output_tokens: i64,
latency_ms: i64,
) {
let (cost, currency) = match billing::record_ai_usage(
db,
project_id,
model_id,
input_tokens,
output_tokens,
)
.await
{
Ok(record) => (Some(record.cost), Some(record.currency)),
Err(_) => (None, None),
};
let _ = models::ai::ai_session::ActiveModel {
id: Set(session_id),
room: Set(room_id),
model: Set(model_id),
version: Set(version_id),
token_input: Set(input_tokens),
token_output: Set(output_tokens),
latency_ms: Set(Some(latency_ms)),
cost: Set(cost),
currency: Set(currency),
error_message: Set(None),
error_code: Set(None),
created_at: Set(chrono::Utc::now()),
}
.insert(db)
.await;
}
/// Service for handling AI chat requests in rooms.
pub struct ChatService {
ai_base_url: Option<String>,
ai_api_key: Option<String>,
compact_service: Option<CompactService>,
embed_service: Option<EmbedService>,
perception_service: PerceptionService,
tool_registry: Option<ToolRegistry>,
}
impl ChatService {
pub fn new() -> Self {
Self {
ai_base_url: None,
ai_api_key: None,
compact_service: None,
embed_service: None,
perception_service: PerceptionService::default(),
tool_registry: None,
}
}
/// Use the base URL and API key from `config` for all completion calls.
///
/// `config` is taken by value, so its fields are moved into the service
/// rather than cloned.
pub fn with_ai_client_config(mut self, config: AiClientConfig) -> Self {
    self.ai_base_url = config.base_url;
    self.ai_api_key = Some(config.api_key);
    self
}
pub fn with_compact_service(mut self, compact_service: CompactService) -> Self {
self.compact_service = Some(compact_service);
self
}
pub fn with_embed_service(mut self, embed_service: EmbedService) -> Self {
self.embed_service = Some(embed_service);
self
}
pub fn with_perception_service(mut self, perception_service: PerceptionService) -> Self {
self.perception_service = perception_service;
self
}
pub fn with_tool_registry(mut self, registry: ToolRegistry) -> Self {
self.tool_registry = Some(registry);
self
}
/// Returns all registered tools as JSON tool definitions.
///
/// Yields an empty list when no tool registry has been configured.
pub fn tools(&self) -> Vec<serde_json::Value> {
    match &self.tool_registry {
        Some(registry) => registry.to_openai_tools(),
        None => Vec::new(),
    }
}
/// Build a RigToolSet from the registered tool registry.
///
/// This makes the same tools usable with `RigAgentService` through rig's
/// native Agent. The supplied context (db, cache, config, room_id, sender_id,
/// project_id) is handed to `RigToolSet::from_registry` together with the
/// registry.
///
/// Returns `None` when no tool registry has been configured.
#[cfg(feature = "rig")]
pub fn rig_toolset(
    &self,
    db: db::database::AppDatabase,
    cache: db::cache::AppCache,
    config: config::AppConfig,
    room_id: uuid::Uuid,
    sender_id: Option<uuid::Uuid>,
    project_id: uuid::Uuid,
) -> Option<crate::RigToolSet> {
    let registry = self.tool_registry.as_ref()?;
    Some(crate::RigToolSet::from_registry(
        registry, db, cache, config, room_id, sender_id, project_id,
    ))
}
/// Borrow the configured [`ToolRegistry`], or `None` if no registry was set.
pub fn tool_registry(&self) -> Option<&ToolRegistry> {
    Option::as_ref(&self.tool_registry)
}
pub async fn process(&self, request: AiRequest) -> Result<ProcessResult> {
let tools: Vec<serde_json::Value> = request.tools.clone().unwrap_or_default();
let tools_enabled = !tools.is_empty();
let max_tool_depth = request.max_tool_depth;
let mut messages = self.build_messages(&request).await?;
let room_ai = room_ai::Entity::find()
.filter(room_ai::Column::Room.eq(request.room.id))
.filter(room_ai::Column::Model.eq(request.model.id))
.one(&request.db)
.await?;
let model_name = request.model.name.clone();
let temperature = room_ai
.as_ref()
.and_then(|r| r.temperature.map(|v| v as f32))
.unwrap_or(request.temperature as f32);
let max_tokens = room_ai
.as_ref()
.and_then(|r| r.max_tokens.map(|v| v as u32))
.unwrap_or(request.max_tokens as u32);
let mut tool_depth = 0;
let mut input_tokens = 0i64;
let mut output_tokens = 0i64;
let session_id = Uuid::new_v4();
let session_start = std::time::Instant::now();
let version_id = room_ai.as_ref().and_then(|r| r.version);
let config = AiClientConfig::new(self.ai_api_key.clone().unwrap_or_default())
.with_base_url(
self.ai_base_url
.clone()
.unwrap_or_else(|| "https://api.openai.com".into()),
);
loop {
let response = call_with_params(
&messages,
&model_name,
&config,
temperature,
max_tokens,
None,
if tools_enabled { Some(&tools) } else { None },
if tools_enabled { None } else { Some("none") },
)
.await?;
let text = response.content.clone();
input_tokens += response.input_tokens;
output_tokens += response.output_tokens;
if tools_enabled && !response.tool_calls_finished.is_empty() {
// Build assistant message with tool_calls
let tool_call_messages: Vec<_> = response
.tool_calls_finished
.iter()
.map(|name| {
// We need ID and arguments — for non-streaming we reconstruct from content
// The model returns tool_calls in its content; for now we create a placeholder
// that will be replaced by actual tool results
ToolCall {
id: Uuid::new_v4().to_string(),
type_: "function".into(),
function: crate::client::types::ToolCallFunction {
name: name.clone(),
arguments: "{}".into(),
},
}
})
.collect();
messages.push(ChatRequestMessage::assistant(
Some(text.clone()),
Some(tool_call_messages.clone()),
));
// Create ToolCall list for executor (we need real IDs and args)
// Since we can't get args from streaming, use name matching from the text
let calls: Vec<AgentToolCall> = tool_call_messages
.into_iter()
.map(|tc| AgentToolCall {
id: tc.id.clone(),
name: tc.function.name.clone(),
arguments: tc.function.arguments.clone(),
})
.collect();
let tool_messages = {
let mut ctx = ToolContext::new(
request.db.clone(),
request.cache.clone(),
request.config.clone(),
request.room.id,
Some(request.sender.uid),
)
.with_project(request.project.id);
if let Some(ref registry) = self.tool_registry {
ctx.registry_mut().merge(registry.clone());
}
let recorder = crate::tool::recorder::ToolCallRecorder::with_session(
request.db.clone(),
session_id,
);
let start = std::time::Instant::now();
let executor = ToolExecutor::new();
match executor.execute_batch(calls, &mut ctx).await {
Ok(results) => {
for (call, result) in
response.tool_calls_finished.iter().zip(results.iter())
{
let elapsed = start.elapsed().as_millis() as i64;
let is_error =
matches!(result.result, crate::tool::ToolResult::Error(_));
let error_msg = match &result.result {
crate::tool::ToolResult::Error(msg) => Some(msg.clone()),
_ => None,
};
recorder.record(crate::tool::recorder::ToolCallRecord {
tool_call_id: call.clone(),
session_id: recorder.session_id(),
tool_name: call.clone(),
caller: request.sender.uid,
arguments: serde_json::Value::Null,
status: if is_error {
models::ai::ToolCallStatus::Failed
} else {
models::ai::ToolCallStatus::Success
},
execution_time_ms: Some(elapsed),
error_message: error_msg,
error_stack: None,
retry_count: 0,
});
}
crate::tool::ToolExecutor::to_tool_messages(&results)
}
Err(e) => {
let elapsed = start.elapsed().as_millis() as i64;
for call_name in &response.tool_calls_finished {
recorder.record(crate::tool::recorder::ToolCallRecord {
tool_call_id: Uuid::new_v4().to_string(),
session_id: recorder.session_id(),
tool_name: call_name.clone(),
caller: request.sender.uid,
arguments: serde_json::Value::Null,
status: models::ai::ToolCallStatus::Failed,
execution_time_ms: Some(elapsed),
error_message: Some(e.to_string()),
error_stack: None,
retry_count: 0,
});
}
let err_msg = format!("[Tool call failed: {}]", e);
response
.tool_calls_finished
.iter()
.map(|_| {
ChatRequestMessage::tool(Uuid::new_v4().to_string(), &err_msg)
})
.collect()
}
}
};
messages.extend(tool_messages);
// Inject passive-detected skills based on tool calls
if let Ok(skills) = project_skill::Entity::find()
.filter(project_skill::Column::ProjectUuid.eq(request.project.id))
.filter(project_skill::Column::Enabled.eq(true))
.all(&request.db)
.await
{
let skill_entries: Vec<SkillEntry> = skills
.into_iter()
.map(|s| SkillEntry {
slug: s.slug,
name: s.name,
description: s.description,
content: s.content,
})
.collect();
let tool_events: Vec<ToolCallEvent> = response
.tool_calls_finished
.iter()
.map(|name| ToolCallEvent {
tool_name: name.clone(),
arguments: String::new(),
})
.collect();
for event in &tool_events {
if let Some(ctx) = self
.perception_service
.passive
.detect(event, &skill_entries)
{
messages.push(ctx.to_system_message());
}
}
}
tool_depth += 1;
if tool_depth >= max_tool_depth {
let content = if text.is_empty() {
format!(
"[AI reached maximum tool depth ({}) — no final answer produced]",
max_tool_depth
)
} else {
text
};
// Record session
record_ai_session(
&request.db,
request.project.id,
session_id,
request.room.id,
request.model.id,
version_id.unwrap_or_default(),
input_tokens,
output_tokens,
session_start.elapsed().as_millis() as i64,
)
.await;
return Ok(ProcessResult {
content,
input_tokens,
output_tokens,
});
}
continue;
}
// Record session
record_ai_session(
&request.db,
request.project.id,
session_id,
request.room.id,
request.model.id,
version_id.unwrap_or_default(),
input_tokens,
output_tokens,
session_start.elapsed().as_millis() as i64,
)
.await;
return Ok(ProcessResult {
content: text,
input_tokens,
output_tokens,
});
}
}
/// Run a streaming completion for `request`, resolving tool calls in a loop.
///
/// Answer and thinking deltas are forwarded to `on_chunk` as they arrive. When
/// the model requests tools, each call is announced to `on_chunk`, executed one
/// at a time, and its result is appended to the transcript before the model is
/// called again. Raw tool output is never sent to `on_chunk`; only the tool
/// name (with a `(failed)` suffix on executor errors) is.
///
/// The loop ends when the model answers without tool calls or after
/// `max_tool_depth` tool rounds, in which case a placeholder answer chunk is
/// emitted. The session is recorded before returning.
///
/// The returned token counts are the totals over all completion rounds, which
/// matches both the recorded session and the non-streaming `process` path.
/// `reasoning_content` is that of the final round (empty when the depth limit
/// was hit).
///
/// # Errors
/// Propagates failures from building the messages, from the `room_ai` lookup
/// and from the streaming completion call.
pub async fn process_stream(
    &self,
    request: AiRequest,
    on_chunk: StreamCallback,
) -> Result<StreamResult> {
    /// Longest prefix of `s` that is at most `max_bytes` bytes long and ends on
    /// a UTF-8 character boundary, so slicing can never panic mid-character.
    fn truncate_at_char_boundary(s: &str, max_bytes: usize) -> &str {
        if s.len() <= max_bytes {
            return s;
        }
        let mut end = max_bytes;
        // Index 0 is always a boundary, so this terminates.
        while !s.is_char_boundary(end) {
            end -= 1;
        }
        &s[..end]
    }

    // Wrap on_chunk in Arc so it can be shared across loop iterations
    let on_chunk = Arc::new(on_chunk);
    let tools: Vec<serde_json::Value> = request.tools.clone().unwrap_or_default();
    let tools_enabled = !tools.is_empty();
    let max_tool_depth = request.max_tool_depth;
    let mut messages = self.build_messages(&request).await?;
    let room_ai = room_ai::Entity::find()
        .filter(room_ai::Column::Room.eq(request.room.id))
        .filter(room_ai::Column::Model.eq(request.model.id))
        .one(&request.db)
        .await?;
    let model_name = request.model.name.clone();
    let temperature = room_ai
        .as_ref()
        .and_then(|r| r.temperature.map(|v| v as f32))
        .unwrap_or(request.temperature as f32);
    let max_tokens = room_ai
        .as_ref()
        .and_then(|r| r.max_tokens.map(|v| v as u32))
        .unwrap_or(request.max_tokens as u32);
    let mut tool_depth = 0;
    let mut total_input_tokens = 0i64;
    let mut total_output_tokens = 0i64;
    let session_id = Uuid::new_v4();
    let session_start = std::time::Instant::now();
    let version_id = room_ai.as_ref().and_then(|r| r.version);
    let config = AiClientConfig::new(self.ai_api_key.clone().unwrap_or_default())
        .with_base_url(
            self.ai_base_url
                .clone()
                .unwrap_or_else(|| "https://api.openai.com".into()),
        );
    let mut full_content = String::new();
    let mut all_chunks: Vec<StreamChunk> = Vec::new();
    // Collect tool calls during streaming, push them incrementally after.
    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<StreamedToolCall>();

    // The loop yields the reasoning text that goes into the result.
    let reasoning_content = loop {
        let on_answer = on_chunk.clone();
        let on_thinking = on_chunk.clone();
        let tool_tx = tx.clone();
        let response = call_stream(
            &messages,
            &model_name,
            &config,
            temperature,
            max_tokens,
            if tools_enabled { Some(&tools) } else { None },
            None, // tool_choice — auto (let model decide)
            Arc::new(move |delta| {
                on_answer(AiStreamChunk {
                    content: delta.to_string(),
                    done: false,
                    chunk_type: AiChunkType::Answer,
                })
            }),
            Arc::new(move |delta| {
                on_thinking(AiStreamChunk {
                    content: delta.to_string(),
                    done: false,
                    chunk_type: AiChunkType::Thinking,
                })
            }),
            Arc::new(move |tc: &StreamedToolCall| {
                let tx = tool_tx.clone();
                let tc_owned = tc.clone();
                Box::pin(async move {
                    let _ = tx.send(tc_owned);
                }) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
            }),
        )
        .await?;
        total_input_tokens += response.input_tokens;
        total_output_tokens += response.output_tokens;
        // Collect chunks from this streaming iteration in order.
        all_chunks.extend(response.chunks);

        let has_tool_calls = tools_enabled && !response.tool_calls.is_empty();
        if !has_tool_calls {
            // Final answer — accumulate and finish
            full_content.push_str(&response.content);
            on_chunk(AiStreamChunk {
                content: response.content.clone(),
                done: true,
                chunk_type: AiChunkType::Answer,
            })
            .await;
            all_chunks.push(StreamChunk {
                chunk_type: StreamChunkType::Answer,
                content: response.content.clone(),
            });
            break response.reasoning_content;
        }

        // Accumulate the assistant's text before tool calls
        full_content.push_str(&response.content);
        full_content.push('\n');
        // Build assistant message with tool_calls from streaming response
        let tool_calls: Vec<ToolCall> = response
            .tool_calls
            .iter()
            .map(|tc| ToolCall {
                id: tc.id.clone(),
                type_: "function".into(),
                function: crate::client::types::ToolCallFunction {
                    name: tc.name.clone(),
                    arguments: tc.arguments.clone(),
                },
            })
            .collect();
        messages.push(ChatRequestMessage::assistant(
            Some(response.content.clone()),
            Some(tool_calls.clone()),
        ));

        // Push each tool call incrementally to frontend.
        // Use try_recv() — tx is never dropped so recv() would deadlock.
        while let Ok(tc) = rx.try_recv() {
            let args_display = if tc.arguments.len() > 100 {
                format!("{}...", truncate_at_char_boundary(&tc.arguments, 100))
            } else {
                tc.arguments.clone()
            };
            let tool_display = format!("🔧 {}({})", tc.name, args_display);
            on_chunk(AiStreamChunk {
                content: tool_display.clone(),
                done: false,
                chunk_type: AiChunkType::ToolCall,
            })
            .await;
            all_chunks.push(StreamChunk {
                chunk_type: StreamChunkType::ToolCall,
                content: tool_display,
            });
        }

        // Execute tools one at a time, push each result incrementally
        let calls: Vec<AgentToolCall> = response
            .tool_calls
            .iter()
            .map(|tc| AgentToolCall {
                id: tc.id.clone(),
                name: tc.name.clone(),
                arguments: tc.arguments.clone(),
            })
            .collect();
        let mut tool_messages = Vec::new();
        let mut ctx = crate::tool::ToolContext::new(
            request.db.clone(),
            request.cache.clone(),
            request.config.clone(),
            request.room.id,
            Some(request.sender.uid),
        )
        .with_project(request.project.id);
        if let Some(ref registry) = self.tool_registry {
            ctx.registry_mut().merge(registry.clone());
        }
        let recorder = crate::tool::recorder::ToolCallRecorder::with_session(
            request.db.clone(),
            session_id,
        );
        for call in &calls {
            let start = std::time::Instant::now();
            let executor = crate::tool::ToolExecutor::new();
            let results = match executor.execute_batch(vec![call.clone()], &mut ctx).await {
                Ok(r) => r,
                Err(e) => {
                    let elapsed = start.elapsed().as_millis() as i64;
                    recorder.record(crate::tool::recorder::ToolCallRecord {
                        tool_call_id: call.id.clone(),
                        session_id: recorder.session_id(),
                        tool_name: call.name.clone(),
                        caller: request.sender.uid,
                        arguments: call.arguments_json().unwrap_or_default(),
                        status: models::ai::ToolCallStatus::Failed,
                        execution_time_ms: Some(elapsed),
                        error_message: Some(e.to_string()),
                        error_stack: None,
                        retry_count: 0,
                    });
                    let err_text = format!("[Tool call failed: {}]", e);
                    tracing::warn!(tool = %call.name, args = %call.arguments, error = %e, "tool_call_failed");
                    let err_display = format!("{} (failed)", call.name);
                    on_chunk(AiStreamChunk {
                        content: err_display.clone(),
                        done: false,
                        chunk_type: AiChunkType::ToolResult,
                    })
                    .await;
                    all_chunks.push(StreamChunk {
                        chunk_type: StreamChunkType::ToolCall,
                        content: err_display,
                    });
                    tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text));
                    continue;
                }
            };
            for result in &results {
                let text = match &result.result {
                    crate::tool::ToolResult::Ok(v) => v.to_string(),
                    crate::tool::ToolResult::Error(msg) => msg.clone(),
                };
                let preview = if text.len() > 300 {
                    format!("{}...", truncate_at_char_boundary(&text, 300))
                } else {
                    text.clone()
                };
                tracing::debug!("tool_result: {} — {}", call.name, preview);
                let elapsed = start.elapsed().as_millis() as i64;
                let error_msg = match &result.result {
                    crate::tool::ToolResult::Error(msg) => Some(msg.clone()),
                    _ => None,
                };
                let status = if error_msg.is_some() {
                    models::ai::ToolCallStatus::Failed
                } else {
                    models::ai::ToolCallStatus::Success
                };
                recorder.record(crate::tool::recorder::ToolCallRecord {
                    tool_call_id: call.id.clone(),
                    session_id: recorder.session_id(),
                    tool_name: call.name.clone(),
                    caller: request.sender.uid,
                    arguments: call.arguments_json().unwrap_or_default(),
                    status,
                    execution_time_ms: Some(elapsed),
                    error_message: error_msg,
                    error_stack: None,
                    retry_count: 0,
                });
                // Do NOT emit tool_result chunks to frontend — raw output may contain sensitive data.
                // Log server-side only; frontend sees tool_call status via on_chunk below.
            }
            let success_display = call.name.to_string();
            on_chunk(AiStreamChunk {
                content: success_display.clone(),
                done: false,
                chunk_type: AiChunkType::ToolResult,
            })
            .await;
            all_chunks.push(StreamChunk {
                chunk_type: StreamChunkType::ToolCall,
                content: success_display,
            });
            let msgs = crate::tool::ToolExecutor::to_tool_messages(&results);
            tool_messages.extend(msgs);
        }
        messages.extend(tool_messages);

        // Inject passive-detected skills based on tool calls
        if let Ok(skills) = project_skill::Entity::find()
            .filter(project_skill::Column::ProjectUuid.eq(request.project.id))
            .filter(project_skill::Column::Enabled.eq(true))
            .all(&request.db)
            .await
        {
            let skill_entries: Vec<SkillEntry> = skills
                .into_iter()
                .map(|s| SkillEntry {
                    slug: s.slug,
                    name: s.name,
                    description: s.description,
                    content: s.content,
                })
                .collect();
            let tool_events: Vec<ToolCallEvent> = response
                .tool_calls
                .iter()
                .map(|tc| ToolCallEvent {
                    tool_name: tc.name.clone(),
                    arguments: tc.arguments.clone(),
                })
                .collect();
            for event in &tool_events {
                if let Some(ctx) = self
                    .perception_service
                    .passive
                    .detect(event, &skill_entries)
                {
                    messages.push(ctx.to_system_message());
                }
            }
        }

        tool_depth += 1;
        if tool_depth >= max_tool_depth {
            let max_depth_text = format!(
                "[AI reached maximum tool depth ({}) — no final answer produced]",
                max_tool_depth
            );
            on_chunk(AiStreamChunk {
                content: max_depth_text.clone(),
                done: true,
                chunk_type: AiChunkType::Answer,
            })
            .await;
            all_chunks.push(StreamChunk {
                chunk_type: StreamChunkType::Answer,
                content: max_depth_text,
            });
            break String::new();
        }
    };

    // Record session
    record_ai_session(
        &request.db,
        request.project.id,
        session_id,
        request.room.id,
        request.model.id,
        version_id.unwrap_or_default(),
        total_input_tokens,
        total_output_tokens,
        session_start.elapsed().as_millis() as i64,
    )
    .await;
    Ok(StreamResult {
        content: full_content,
        reasoning_content,
        input_tokens: total_input_tokens,
        output_tokens: total_output_tokens,
        chunks: all_chunks,
    })
}
/// Assemble the full prompt for `request`.
///
/// Messages are appended in this order: an optional conversation summary, the
/// conversation history (the compacted remainder if compaction produced one,
/// otherwise the request's own history), context for every mention, skill
/// context, memory context, a description of the current project, a
/// description of the sender, and finally the user's input.
///
/// A compaction failure is logged and the request's history is used instead.
async fn build_messages(&self, request: &AiRequest) -> Result<Vec<ChatRequestMessage>> {
    let mut messages = Vec::new();

    // --- Conversation summary + history -----------------------------------
    let mut retained = Vec::new();
    if let Some(compactor) = &self.compact_service {
        let compact_config = CompactConfig::default();
        let outcome = compactor
            .compact_room_auto(
                request.room.id,
                Some(request.user_names.clone()),
                compact_config,
            )
            .await;
        match outcome {
            Err(e) => {
                tracing::warn!(error = %e, "conversation compaction failed, using full history");
            }
            Ok(compacted) => {
                if !compacted.summary.is_empty() {
                    messages.push(ChatRequestMessage::system(format!(
                        "Conversation summary:\n{}",
                        compacted.summary
                    )));
                }
                retained = compacted.retained;
            }
        }
    }
    if retained.is_empty() {
        // Nothing retained by compaction (or no compaction): use the raw history.
        messages.extend(request.history.iter().map(|msg| {
            RoomMessageContext::from_model_with_names(msg, &request.user_names).to_message()
        }));
    } else {
        messages.extend(
            retained
                .into_iter()
                .map(|summary| RoomMessageContext::from(summary).to_message()),
        );
    }

    // --- Mentions ---------------------------------------------------------
    for mention in &request.mention {
        match mention {
            Mention::User(user) => {
                let mut lines: Vec<String> = Vec::new();
                lines.push(format!("Username: {}", user.username));
                if let Some(display_name) = &user.display_name {
                    lines.push(format!("Display name: {}", display_name));
                }
                if let Some(org) = &user.organization {
                    lines.push(format!("Organization: {}", org));
                }
                if let Some(website) = &user.website_url {
                    lines.push(format!("Website: {}", website));
                }
                messages.push(ChatRequestMessage::system(format!(
                    "Mentioned user profile:\n{}",
                    lines.join("\n")
                )));
            }
            Mention::Repo(repo) => {
                // Describe the repository so the model knows what was mentioned.
                let mut lines: Vec<String> = Vec::new();
                lines.push(format!("Name: {}", repo.repo_name));
                lines.push(format!("ID: {}", repo.id));
                if let Some(desc) = &repo.description {
                    lines.push(format!("Description: {}", desc));
                }
                lines.push(format!("Default branch: {}", repo.default_branch));
                let private_flag = if repo.is_private { "yes" } else { "no" };
                lines.push(format!("Private: {}", private_flag));
                lines.push(format!("Created: {}", repo.created_at.format("%Y-%m-%d")));
                messages.push(ChatRequestMessage::system(format!(
                    "Mentioned repository:\n{}",
                    lines.join("\n")
                )));

                // Optional enrichment: related issues and similar repositories
                // found through the embedding index. Search errors are ignored.
                let Some(embedder) = &self.embed_service else {
                    continue;
                };
                let query = format!(
                    "{} {}",
                    repo.repo_name,
                    repo.description.as_deref().unwrap_or_default()
                );
                match embedder.search_issues(&query, 5).await {
                    Ok(issues) if !issues.is_empty() => {
                        let bullets = issues
                            .iter()
                            .map(|i| format!("- {}", i.payload.text))
                            .collect::<Vec<_>>()
                            .join("\n");
                        messages.push(ChatRequestMessage::system(format!(
                            "Related issues for repo {}:\n{}",
                            repo.repo_name, bullets
                        )));
                    }
                    _ => {}
                }
                match embedder.search_repos(&query, 3).await {
                    Ok(repos) if !repos.is_empty() => {
                        let bullets = repos
                            .iter()
                            .map(|r| format!("- {}", r.payload.text))
                            .collect::<Vec<_>>()
                            .join("\n");
                        messages.push(ChatRequestMessage::system(format!(
                            "Similar repositories:\n{}",
                            bullets
                        )));
                    }
                    _ => {}
                }
            }
        }
    }

    // --- Skills and memories ----------------------------------------------
    let skill_contexts = self.build_skill_context(request).await;
    messages.extend(
        skill_contexts
            .into_iter()
            .map(|skill| skill.to_system_message()),
    );
    let memories = self.build_memory_context(request).await;
    messages.extend(memories.into_iter().map(|memory| memory.to_system_message()));

    // --- Project ----------------------------------------------------------
    let project = &request.project;
    let public_flag = if project.is_public { "yes" } else { "no" };
    messages.push(ChatRequestMessage::system(format!(
        "Current Project:\n{}\nDescription: {}\nPublic: {}",
        project.display_name,
        project.description.as_deref().unwrap_or("(none)"),
        public_flag
    )));

    // --- Sender -----------------------------------------------------------
    let sender = &request.sender;
    let mut sender_display = format!("**Sender:** {}", sender.username);
    if let Some(display_name) = &sender.display_name {
        sender_display.push(' ');
        sender_display.push_str(display_name);
    }
    if let Some(org) = &sender.organization {
        sender_display.push_str(&format!(" ({})", org));
    }
    messages.push(ChatRequestMessage::system(format!(
        "The person sending the next message:\n{}",
        sender_display
    )));

    // --- The actual user input goes last ----------------------------------
    messages.push(ChatRequestMessage::user(&request.input));
    Ok(messages)
}
async fn build_skill_context(
&self,
request: &AiRequest,
) -> Vec<crate::perception::SkillContext> {
// Load database skills for this project
let db_skills: Vec<SkillEntry> = match project_skill::Entity::find()
.filter(project_skill::Column::ProjectUuid.eq(request.project.id))
.filter(project_skill::Column::Enabled.eq(true))
.all(&request.db)
.await
{
Ok(models) => models
.into_iter()
.map(|s| SkillEntry {
slug: s.slug,
name: s.name,
description: s.description,
content: s.content,
})
.collect(),
Err(_) => Vec::new(),
};
// Load built-in skills and merge with database skills
// Built-in skills override database skills with the same slug
let mut all_skills: Vec<SkillEntry> = db_skills;
for built_in in crate::skills::all_skills() {
// Skip if a database skill with the same slug already exists
if !all_skills.iter().any(|s| s.slug == built_in.slug) {
all_skills.push(SkillEntry {
slug: built_in.slug.to_string(),
name: built_in.name.to_string(),
description: Some(built_in.description.to_string()),
content: built_in.content.clone(),
});
}
}
if all_skills.is_empty() {
return Vec::new();
}
let history_texts: Vec<String> = request
.history
.iter()
.rev()
.take(10)
.map(|msg| msg.content.clone())
.collect();
let tool_events: Vec<ToolCallEvent> = Vec::new();
let keyword_skills = self
.perception_service
.inject_skills(&request.input, &history_texts, &tool_events, &all_skills)
.await;
let mut vector_skills = Vec::new();
if let Some(embed_service) = &self.embed_service {
let awareness = crate::perception::VectorActiveAwareness::default();
vector_skills = awareness
.detect(
embed_service,
&request.input,
&request.project.id.to_string(),
)
.await;
}
let mut seen = std::collections::HashSet::new();
let mut result = Vec::new();
for ctx in vector_skills {
if seen.insert(ctx.label.clone()) {
result.push(ctx);
}
}
for ctx in keyword_skills {
if seen.insert(ctx.label.clone()) {
result.push(ctx);
}
}
result
}
/// Look up memory context for `request` through the embed service.
///
/// Returns an empty list when no embed service is configured.
async fn build_memory_context(
    &self,
    request: &AiRequest,
) -> Vec<crate::perception::vector::MemoryContext> {
    let Some(embed_service) = &self.embed_service else {
        return Vec::new();
    };
    let room_key = request.room.id.to_string();
    let awareness = crate::perception::VectorPassiveAwareness::default();
    awareness
        .detect(
            embed_service,
            &request.input,
            &request.project.display_name,
            &room_key,
        )
        .await
}
/// Run `request` through a rig agent with multi-turn tool calling, reporting
/// every streamed item to `on_chunk` as a `ReactStep`.
///
/// Text becomes `ReactStep::Answer`, reasoning becomes `ReactStep::Thought`,
/// tool calls become `ReactStep::Action` and tool results become
/// `ReactStep::Observation`; each reported step gets the next step number.
/// Returns the concatenation of all streamed answer text.
///
/// # Errors
/// Returns `AgentError::Internal` when no tool registry is configured and
/// `AgentError::OpenAi` when the agent stream yields an error. On a stream
/// error the function returns before the session is recorded.
pub async fn process_react<C>(&self, request: &AiRequest, mut on_chunk: C) -> Result<String>
where
C: FnMut(crate::react::ReactStep) + Send,
{
// Client configuration: same fallbacks as the other entry points.
let base_url = self
.ai_base_url
.clone()
.unwrap_or_else(|| "https://api.openai.com".into());
let api_key = self.ai_api_key.clone().unwrap_or_default();
let client_config = AiClientConfig::new(api_key).with_base_url(base_url);
// Tools are mandatory for this entry point.
let Some(registry) = &self.tool_registry else {
return Err(AgentError::Internal("no tool registry registered".into()));
};
let db = request.db.clone();
let cache = request.cache.clone();
let cfg = request.config.clone();
let room_id = request.room.id;
let sender_uid = request.sender.uid;
let project_id = request.project.id;
let session_id = Uuid::new_v4();
let session_start = std::time::Instant::now();
// Version configured for this room/model pair; a lookup error is treated
// the same as "no row" and yields `None`.
let version_id = room_ai::Entity::find()
.filter(room_ai::Column::Room.eq(request.room.id))
.filter(room_ai::Column::Model.eq(request.model.id))
.one(&request.db)
.await
.ok()
.flatten()
.and_then(|r| r.version);
// Build rig tools with recording wrapper directly from registry
let mut tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>> = Vec::new();
for def in registry.definitions() {
let name = def.name.clone();
// Definitions without a registered handler are skipped.
if let Some(handler) = registry.get(&name) {
let adapter = crate::tool::RigToolAdapter::new(
handler.clone(),
def.clone(),
db.clone(),
cache.clone(),
cfg.clone(),
room_id,
Some(sender_uid),
project_id,
);
tools.push(Box::new(RecordingTool::new(
Box::new(adapter),
db.clone(),
session_id,
sender_uid,
)));
}
}
// Build rig agent (handles multi-turn tool calls natively)
let rig_client = client_config.build_rig_client();
let model = rig_client.completion_model(&request.model.name);
let agent = AgentBuilder::new(model)
.preamble(DEFAULT_SYSTEM_PROMPT)
.tools(tools)
.default_max_turns(request.max_tool_depth)
.build();
// Only the current input is sent: the history passed here is empty, so
// `request.history` is not forwarded to the agent.
let stream = agent
.stream_prompt(&request.input)
.with_history(Vec::new())
.multi_turn(request.max_tool_depth)
.await;
tokio::pin!(stream);
let mut step_count = 0usize;
let mut final_content = String::new();
let mut total_input_tokens: i64 = 0;
let mut total_output_tokens: i64 = 0;
while let Some(item) = stream.next().await {
match item {
// Answer text: report it and append it to the returned content.
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::Text(text),
)) => {
step_count += 1;
let t = text.text;
on_chunk(ReactStep::Answer {
step: step_count,
answer: t.clone(),
});
final_content.push_str(&t);
}
// Complete reasoning block: joined into one thought; empty ones are dropped.
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::Reasoning(reasoning),
)) => {
let reasoning_text = reasoning.reasoning.join("");
if !reasoning_text.is_empty() {
step_count += 1;
on_chunk(ReactStep::Thought {
step: step_count,
thought: reasoning_text,
});
}
}
// Incremental reasoning: each non-empty delta is reported as its own thought.
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::ReasoningDelta { reasoning, .. },
)) => {
if !reasoning.is_empty() {
step_count += 1;
on_chunk(ReactStep::Thought {
step: step_count,
thought: reasoning,
});
}
}
// Tool call: arguments delivered as a JSON string are parsed; if parsing
// fails the action is reported with `Null` arguments.
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::ToolCall { tool_call, .. },
)) => {
step_count += 1;
let args: serde_json::Value = match &tool_call.function.arguments {
serde_json::Value::String(s) => {
serde_json::from_str(s).unwrap_or(serde_json::Value::Null)
}
v => v.clone(),
};
on_chunk(ReactStep::Action {
step: step_count,
action: ReactAction::new(&tool_call.function.name, args),
});
}
// Tool result: only its text parts are reported as the observation.
Ok(MultiTurnStreamItem::StreamUserItem(
rig::streaming::StreamedUserContent::ToolResult { tool_result, .. },
)) => {
step_count += 1;
let obs = tool_result_content_to_string(&tool_result.content);
on_chunk(ReactStep::Observation {
step: step_count,
observation: obs,
});
}
// End of the run: take the usage figures from the final response
// (assigned, not added).
Ok(MultiTurnStreamItem::FinalResponse(resp)) => {
let usage = resp.usage();
total_input_tokens = usage.input_tokens as i64;
total_output_tokens = usage.output_tokens as i64;
// Text was already streamed incrementally via Answer events.
}
// Stream failure aborts immediately; the session below is not recorded.
Err(e) => {
let err_msg = format!("rig agent stream error: {}", e);
return Err(AgentError::OpenAi(err_msg));
}
// Any other stream item is ignored.
_ => {}
}
}
let elapsed_ms = session_start.elapsed().as_millis() as i64;
record_ai_session(
&request.db,
request.project.id,
session_id,
request.room.id,
request.model.id,
version_id.unwrap_or_default(),
total_input_tokens,
total_output_tokens,
elapsed_ms,
)
.await;
Ok(final_content)
}
}
/// Flatten a rig tool result into plain text.
///
/// Every text part is kept, in order, and the parts are joined with newlines;
/// non-text parts (such as images) are skipped.
fn tool_result_content_to_string(content: &rig::one_or_many::OneOrMany<rig::completion::message::ToolResultContent>) -> String {
    use rig::completion::message::ToolResultContent;

    let mut text_parts: Vec<String> = Vec::new();
    for part in content.iter() {
        if let ToolResultContent::Text(t) = part {
            text_parts.push(t.text.clone());
        }
    }
    text_parts.join("\n")
}