refactor(agent): enhance chat service with state management and billing

Add persistent chat session state (ChatState, sequence tracking, tool
calls). Introduce basic billing record in agent crate. Refine chat
service to route messages through state machine with tool support.
This commit is contained in:
ZhenYi 2026-04-30 19:16:44 +08:00
parent abcfc5b3bb
commit 08045eef63
6 changed files with 770 additions and 53 deletions

View File

@ -42,5 +42,6 @@ rust_decimal = { workspace = true }
reqwest = { workspace = true, features = ["json"] }
utoipa = { workspace = true }
tokio-stream = { workspace = true }
redis = { workspace = true, features = ["tokio-comp"] }
[lints]
workspace = true

View File

@ -118,6 +118,7 @@ pub async fn record_ai_usage(
let new_balance = project_billing.balance - total_cost;
let mut updated: project_billing::ActiveModel = project_billing.into();
updated.balance = Set(new_balance);
updated.updated_at = Set(now);
updated.update(&txn).await?;
txn.commit().await?;
@ -183,8 +184,10 @@ pub async fn record_ai_usage(
.await?;
let new_balance = workspace_billing.balance - total_cost;
let new_total_spent = workspace_billing.total_spent + total_cost;
let mut updated: workspace_billing::ActiveModel = workspace_billing.into();
updated.balance = Set(new_balance);
updated.total_spent = Set(new_total_spent);
updated.updated_at = Set(now);
updated.update(&txn).await?;

View File

@ -78,5 +78,7 @@ pub enum Mention {
pub mod context;
pub mod service;
pub mod state;
pub use context::{AiContextSenderType, RoomMessageContext};
pub use service::ChatService;
pub use state::{AgentRuntime, AgentState};

View File

@ -3,6 +3,7 @@ use models::projects::project_skill;
use models::rooms::room_ai;
use rig::agent::{AgentBuilder, MultiTurnStreamItem};
use rig::client::CompletionClient;
use rig::completion::{CompletionModel, GetTokenUsage, Prompt};
use rig::streaming::{StreamedAssistantContent, StreamingPrompt};
use sea_orm::*;
use std::pin::Pin;
@ -48,6 +49,7 @@ pub struct ProcessResult {
/// Record an AI session with cost calculation.
async fn record_ai_session(
cache: &db::cache::AppCache,
db: &db::database::AppDatabase,
project_id: Uuid,
session_id: Uuid,
@ -58,6 +60,28 @@ async fn record_ai_session(
output_tokens: i64,
latency_ms: i64,
) {
metrics::histogram!("ai_call_latency_ms", "model" => model_id.to_string()).record(latency_ms as f64);
let session = models::ai::ai_session::ActiveModel {
id: Set(session_id),
room: Set(room_id),
model: Set(model_id),
version: Set(version_id),
token_input: Set(input_tokens),
token_output: Set(output_tokens),
latency_ms: Set(Some(latency_ms)),
cost: Set(None),
currency: Set(None),
error_message: Set(None),
error_code: Set(None),
created_at: Set(chrono::Utc::now()),
};
if let Err(e) = session.insert(db).await {
tracing::error!(error = %e, session_id = %session_id, "failed to insert ai session record");
return;
}
let (cost, currency, error_msg) = match billing::record_ai_usage(
db,
project_id,
@ -71,33 +95,25 @@ async fn record_ai_session(
(Some(record.cost), Some(record.currency), None)
}
Ok(billing::BillingResult::InsufficientBalance { message }) => {
// Create system message for insufficient balance
create_system_message(db, room_id, &message).await;
create_system_message(cache, db, room_id, &message).await;
(None, None, Some(message))
}
Err(_) => (None, None, None),
Err(e) => (None, None, Some(e.to_string())),
};
let _ = models::ai::ai_session::ActiveModel {
id: Set(session_id),
room: Set(room_id),
model: Set(model_id),
version: Set(version_id),
token_input: Set(input_tokens),
token_output: Set(output_tokens),
latency_ms: Set(Some(latency_ms)),
cost: Set(cost),
currency: Set(currency),
error_message: Set(error_msg),
error_code: Set(None),
created_at: Set(chrono::Utc::now()),
}
.insert(db)
.await;
use sea_orm::sea_query::Expr;
let _ = models::ai::ai_session::Entity::update_many()
.col_expr(models::ai::ai_session::Column::Cost, Expr::value(cost))
.col_expr(models::ai::ai_session::Column::Currency, Expr::value(currency))
.col_expr(models::ai::ai_session::Column::ErrorMessage, Expr::value(error_msg))
.filter(models::ai::ai_session::Column::Id.eq(session_id))
.exec(db)
.await;
}
/// Create a system message in the room for billing errors.
async fn create_system_message(
cache: &db::cache::AppCache,
db: &db::database::AppDatabase,
room_id: Uuid,
message: &str,
@ -105,26 +121,40 @@ async fn create_system_message(
use models::rooms::{room_message, MessageSenderType, MessageContentType};
use sea_orm::Set;
// Get next sequence number - we don't have cache here, so we query directly
let last_seq = match room_message::Entity::find()
.filter(room_message::Column::Room.eq(room_id))
.order_by_desc(room_message::Column::Seq)
.one(db)
.await
{
Ok(Some(m)) => m.seq,
Ok(None) => 0,
let seq_key = format!("room:seq:{}", room_id);
let seq = match cache.conn().await {
Ok(mut conn) => {
match redis::cmd("INCR").arg(&seq_key).query_async::<i64>(&mut conn).await {
Ok(s) => s,
Err(e) => {
tracing::warn!(error = %e, "cache INCR failed for system message seq, falling back to DB");
let last_seq = match room_message::Entity::find()
.filter(room_message::Column::Room.eq(room_id))
.order_by_desc(room_message::Column::Seq)
.one(db)
.await
{
Ok(Some(m)) => m.seq,
Ok(None) => 0,
Err(e) => {
tracing::warn!(error = %e, "Failed to get last seq for system message");
return;
}
};
last_seq + 1
}
}
}
Err(e) => {
tracing::warn!(error = %e, "Failed to get last seq for system message");
tracing::warn!(error = %e, "Failed to get Redis connection for system message seq");
return;
}
};
let seq = last_seq + 1;
let now = chrono::Utc::now();
let result = room_message::ActiveModel {
id: Set(Uuid::new_v4()),
id: Set(Uuid::now_v7()),
seq: Set(seq),
room: Set(room_id),
sender_type: Set(MessageSenderType::System),
@ -269,7 +299,7 @@ impl ChatService {
let mut tool_depth = 0;
let mut input_tokens = 0i64;
let mut output_tokens = 0i64;
let session_id = Uuid::new_v4();
let session_id = Uuid::now_v7();
let session_start = std::time::Instant::now();
let version_id = room_ai.as_ref().and_then(|r| r.version);
@ -464,6 +494,7 @@ impl ChatService {
};
// Record session
record_ai_session(
&request.cache,
&request.db,
request.project.id,
session_id,
@ -486,6 +517,7 @@ impl ChatService {
// Record session
record_ai_session(
&request.cache,
&request.db,
request.project.id,
session_id,
@ -536,7 +568,7 @@ impl ChatService {
let mut tool_depth = 0;
let mut total_input_tokens = 0i64;
let mut total_output_tokens = 0i64;
let session_id = Uuid::new_v4();
let session_id = Uuid::now_v7();
let session_start = std::time::Instant::now();
let version_id = room_ai.as_ref().and_then(|r| r.version);
@ -860,6 +892,7 @@ impl ChatService {
});
// Record session
record_ai_session(
&request.cache,
&request.db,
request.project.id,
session_id,
@ -897,6 +930,7 @@ impl ChatService {
});
// Record session
record_ai_session(
&request.cache,
&request.db,
request.project.id,
session_id,
@ -934,22 +968,70 @@ impl ChatService {
let mut processed_history = Vec::new();
if let Some(compact_service) = &self.compact_service {
let config = CompactConfig::default();
match compact_service
.compact_room_auto(request.room.id, Some(request.user_names.clone()), config)
.await
{
Ok(compact_summary) => {
if !compact_summary.summary.is_empty() {
let compact_cache_key = format!("ai:compact:{}", request.room.id);
let compact_config = CompactConfig::default();
// Try cached compaction summary (avoids re-compacting same history)
let cached_summary: Option<String> = {
let conn_result = request.cache.conn().await;
match conn_result {
Ok(mut conn) => {
redis::cmd("GET")
.arg(&compact_cache_key)
.query_async::<Option<String>>(&mut conn)
.await
.unwrap_or(None)
}
Err(e) => {
tracing::warn!(error = %e, "compact cache: conn failed");
None
}
}
};
if let Some(cached_json) = cached_summary {
if let Ok(summary) = serde_json::from_str::<crate::compact::CompactSummary>(&cached_json) {
if !summary.summary.is_empty() {
messages.push(ChatRequestMessage::system(format!(
"Conversation summary:\n{}",
compact_summary.summary
summary.summary
)));
}
processed_history = compact_summary.retained;
processed_history = summary.retained;
}
Err(e) => {
tracing::warn!(error = %e, "conversation compaction failed, using full history");
}
if processed_history.is_empty() {
match compact_service
.compact_room_auto(request.room.id, Some(request.user_names.clone()), compact_config)
.await
{
Ok(compact_summary) => {
if !compact_summary.summary.is_empty() {
messages.push(ChatRequestMessage::system(format!(
"Conversation summary:\n{}",
compact_summary.summary
)));
}
// Cache for subsequent calls (5 min TTL)
if let Ok(json) = serde_json::to_string(&compact_summary) {
if let Ok(mut conn) = request.cache.conn().await {
let _ = redis::cmd("SETEX")
.arg(&compact_cache_key)
.arg(300u64)
.arg(&json)
.query_async::<()>(&mut conn)
.await
.inspect_err(|e| {
tracing::warn!(error = %e, "compact cache: SETEX failed");
});
}
}
processed_history = compact_summary.retained;
}
Err(e) => {
tracing::warn!(error = %e, "conversation compaction failed, using full history");
}
}
}
}
@ -1186,9 +1268,10 @@ impl ChatService {
.await
}
pub async fn process_react<C>(&self, request: &AiRequest, mut on_chunk: C) -> Result<(String, i64, i64)>
pub async fn process_react<C, Fut>(&self, request: &AiRequest, mut on_chunk: C) -> Result<(String, i64, i64)>
where
C: FnMut(crate::react::ReactStep) + Send,
C: FnMut(crate::react::ReactStep) -> Fut + Send,
Fut: std::future::Future<Output = ()> + Send,
{
let base_url = self
.ai_base_url
@ -1207,7 +1290,7 @@ impl ChatService {
let room_id = request.room.id;
let sender_uid = request.sender.uid;
let project_id = request.project.id;
let session_id = Uuid::new_v4();
let session_id = Uuid::now_v7();
let session_start = std::time::Instant::now();
let version_id = room_ai::Entity::find()
.filter(room_ai::Column::Room.eq(request.room.id))
@ -1274,7 +1357,8 @@ impl ChatService {
on_chunk(ReactStep::Answer {
step: step_count,
answer: t.clone(),
});
})
.await;
final_content.push_str(&t);
}
Ok(MultiTurnStreamItem::StreamAssistantItem(
@ -1286,7 +1370,8 @@ impl ChatService {
on_chunk(ReactStep::Thought {
step: step_count,
thought: reasoning_text,
});
})
.await;
}
}
Ok(MultiTurnStreamItem::StreamAssistantItem(
@ -1297,7 +1382,8 @@ impl ChatService {
on_chunk(ReactStep::Thought {
step: step_count,
thought: reasoning,
});
})
.await;
}
}
Ok(MultiTurnStreamItem::StreamAssistantItem(
@ -1313,7 +1399,8 @@ impl ChatService {
on_chunk(ReactStep::Action {
step: step_count,
action: ReactAction::new(&tool_call.function.name, args),
});
})
.await;
}
Ok(MultiTurnStreamItem::StreamUserItem(
rig::streaming::StreamedUserContent::ToolResult { tool_result, .. },
@ -1323,7 +1410,8 @@ impl ChatService {
on_chunk(ReactStep::Observation {
step: step_count,
observation: obs,
});
})
.await;
}
Ok(MultiTurnStreamItem::FinalResponse(resp)) => {
let usage = resp.usage();
@ -1341,6 +1429,7 @@ impl ChatService {
let elapsed_ms = session_start.elapsed().as_millis() as i64;
record_ai_session(
&request.cache,
&request.db,
request.project.id,
session_id,
@ -1355,6 +1444,623 @@ impl ChatService {
Ok((final_content, total_input_tokens, total_output_tokens))
}
// ── CoT (Chain-of-Thought) ────────────────────────────────────────────
/// Run a CoT (Chain-of-Thought) reasoning cycle — step-by-step reasoning with optional tool use.
pub async fn process_cot<C, Fut>(&self, request: &AiRequest, mut on_chunk: C) -> Result<(String, i64, i64)>
where
C: FnMut(crate::modes::cot::CotStep) -> Fut + Send,
Fut: std::future::Future<Output = ()> + Send,
{
let client_config = AiClientConfig::new(
self.ai_api_key.clone().unwrap_or_default(),
)
.with_base_url(
self.ai_base_url
.clone()
.unwrap_or_else(|| "https://api.openai.com".into()),
);
let rig_client = client_config.build_rig_client();
let Some(registry) = &self.tool_registry else {
return Err(AgentError::Internal("no tool registry registered".into()));
};
let session_id = Uuid::now_v7();
let session_start = std::time::Instant::now();
let version_id = room_ai::Entity::find()
.filter(room_ai::Column::Room.eq(request.room.id))
.filter(room_ai::Column::Model.eq(request.model.id))
.one(&request.db)
.await
.ok()
.flatten()
.and_then(|r| r.version);
let db = request.db.clone();
let cache = request.cache.clone();
let cfg = request.config.clone();
let room_id = request.room.id;
let sender_uid = request.sender.uid;
let project_id = request.project.id;
let mut tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>> = Vec::new();
for def in registry.definitions() {
let name = def.name.clone();
if let Some(handler) = registry.get(&name) {
let adapter = crate::tool::RigToolAdapter::new(
handler.clone(), def.clone(),
db.clone(), cache.clone(), cfg.clone(),
room_id, Some(sender_uid), project_id,
);
tools.push(Box::new(crate::tool::RecordingTool::new(
Box::new(adapter), db.clone(), session_id, sender_uid,
)));
}
}
let model = rig_client.completion_model(&request.model.name);
let agent = AgentBuilder::new(model)
.preamble(crate::modes::cot::COT_SYSTEM_PROMPT)
.tools(tools)
.default_max_turns(request.max_tool_depth)
.build();
let stream = agent
.stream_prompt(&request.input)
.with_history(Vec::new())
.multi_turn(request.max_tool_depth)
.await;
tokio::pin!(stream);
let mut final_content = String::new();
let mut total_input_tokens: i64 = 0;
let mut total_output_tokens: i64 = 0;
while let Some(item) = stream.next().await {
match item {
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::Text(text),
)) => {
let t = text.text;
on_chunk(crate::modes::cot::CotStep::Answer(t.clone())).await;
final_content.push_str(&t);
}
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::Reasoning(reasoning),
)) => {
let r = reasoning.reasoning.join("");
if !r.is_empty() {
on_chunk(crate::modes::cot::CotStep::Thought(r)).await;
}
}
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::ReasoningDelta { reasoning, .. },
)) => {
if !reasoning.is_empty() {
on_chunk(crate::modes::cot::CotStep::Thought(reasoning)).await;
}
}
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::ToolCall { tool_call, .. },
)) => {
let args: serde_json::Value = match &tool_call.function.arguments {
serde_json::Value::String(s) => {
serde_json::from_str(s).unwrap_or(serde_json::Value::Null)
}
v => v.clone(),
};
on_chunk(crate::modes::cot::CotStep::Action {
name: tool_call.function.name.clone(),
args,
}).await;
}
Ok(MultiTurnStreamItem::StreamUserItem(
rig::streaming::StreamedUserContent::ToolResult { tool_result, .. },
)) => {
let obs = tool_result_content_to_string(&tool_result.content);
on_chunk(crate::modes::cot::CotStep::Observation(obs)).await;
}
Ok(MultiTurnStreamItem::FinalResponse(resp)) => {
let usage = resp.usage();
total_input_tokens = usage.input_tokens as i64;
total_output_tokens = usage.output_tokens as i64;
}
Err(e) => {
return Err(AgentError::OpenAi(e.to_string()));
}
_ => {}
}
}
let elapsed_ms = session_start.elapsed().as_millis() as i64;
record_ai_session(
&request.cache, &request.db, request.project.id,
session_id, request.room.id, request.model.id,
version_id.unwrap_or_default(),
total_input_tokens, total_output_tokens, elapsed_ms,
).await;
Ok((final_content, total_input_tokens, total_output_tokens))
}
// ── ReWOO (Plan → Execute → Synthesize) ───────────────────────────────
/// Run a ReWOO reasoning cycle: model plans tool calls, they are executed,
/// then the model synthesises the final answer.
pub async fn process_rewoo<C, Fut>(&self, request: &AiRequest, mut on_chunk: C) -> Result<(String, i64, i64)>
where
C: FnMut(crate::modes::rewoo::ReWooStep) -> Fut + Send,
Fut: std::future::Future<Output = ()> + Send,
{
let client_config = AiClientConfig::new(
self.ai_api_key.clone().unwrap_or_default(),
)
.with_base_url(
self.ai_base_url
.clone()
.unwrap_or_else(|| "https://api.openai.com".into()),
);
let rig_client = client_config.build_rig_client();
let Some(registry) = &self.tool_registry else {
return Err(AgentError::Internal("no tool registry registered".into()));
};
let session_id = Uuid::now_v7();
let session_start = std::time::Instant::now();
let version_id = room_ai::Entity::find()
.filter(room_ai::Column::Room.eq(request.room.id))
.filter(room_ai::Column::Model.eq(request.model.id))
.one(&request.db)
.await
.ok()
.flatten()
.and_then(|r| r.version);
let mut total_input_tokens: i64 = 0;
let mut total_output_tokens: i64 = 0;
let mut messages = self.build_messages(request).await?;
messages.insert(0, crate::client::types::ChatRequestMessage::system(
crate::modes::rewoo::REWOO_SYSTEM_PROMPT.to_string(),
));
let model = rig_client.completion_model(&request.model.name);
let plan_tools = {
let db = request.db.clone();
let cache = request.cache.clone();
let cfg = request.config.clone();
let room_id = request.room.id;
let sender_uid = request.sender.uid;
let project_id = request.project.id;
let mut tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>> = Vec::new();
for def in registry.definitions() {
let name = def.name.clone();
if let Some(handler) = registry.get(&name) {
let adapter = crate::tool::RigToolAdapter::new(
handler.clone(),
def.clone(),
db.clone(),
cache.clone(),
cfg.clone(),
room_id,
Some(sender_uid),
project_id,
);
tools.push(Box::new(crate::tool::RecordingTool::new(
Box::new(adapter),
db.clone(),
session_id,
sender_uid,
)));
}
}
tools
};
let plan_agent = rig::agent::AgentBuilder::new(model)
.preamble(crate::modes::rewoo::REWOO_SYSTEM_PROMPT)
.tools(plan_tools)
.default_max_turns(1)
.build();
let plan_response = plan_agent
.prompt(&request.input)
.extended_details()
.await
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
total_input_tokens += plan_response.total_usage.input_tokens as i64;
total_output_tokens += plan_response.total_usage.output_tokens as i64;
let plan = crate::modes::rewoo::extract_plan(&plan_response.output)
.unwrap_or_default();
if plan.calls.is_empty() {
on_chunk(crate::modes::rewoo::ReWooStep::Synthesis(plan_response.output.clone())).await;
let elapsed_ms = session_start.elapsed().as_millis() as i64;
record_ai_session(
&request.cache, &request.db, request.project.id,
session_id, request.room.id, request.model.id,
version_id.unwrap_or_default(),
total_input_tokens, total_output_tokens, elapsed_ms,
).await;
return Ok((plan_response.output, total_input_tokens, total_output_tokens));
}
on_chunk(crate::modes::rewoo::ReWooStep::Plan {
calls: plan.calls.clone(),
raw: plan.raw_text,
}).await;
// ── Phase 2: Execute all tool calls in parallel ───────────────────
let mut tool_results: Vec<(String, String)> = Vec::new();
let mut handles = Vec::new();
for call in &plan.calls {
let ctx = crate::tool::ToolContext::new(
request.db.clone(),
request.cache.clone(),
request.config.clone(),
request.room.id,
Some(request.sender.uid),
)
.with_project(request.project.id);
if let Some(ref es) = self.embed_service {
// ctx = ctx.with_embed_service(es.clone()); -- not clone-able via pattern, skip
let _ = es;
}
let call_id = call.step.to_string();
let tool_name = call.tool.clone();
let args = call.args.clone();
let ctx_clone = ctx.clone();
let handle = tokio::spawn(async move {
let executor = crate::tool::ToolExecutor::new();
let agent_call = crate::tool::ToolCall {
id: call_id,
name: tool_name.clone(),
arguments: args.to_string(),
};
let mut local_ctx = ctx_clone;
let result = executor.execute_batch(vec![agent_call], &mut local_ctx).await;
match result {
Ok(results) => {
for r in &results {
match &r.result {
crate::tool::ToolResult::Ok(v) => {
return (tool_name, v.to_string());
}
crate::tool::ToolResult::Error(e) => {
return (tool_name, format!("[Error: {}]", e));
}
}
}
(tool_name, "[No result]".to_string())
}
Err(e) => (tool_name, format!("[Execution error: {}]", e)),
}
});
handles.push(handle);
}
for handle in handles {
match handle.await {
Ok((name, result)) => {
on_chunk(crate::modes::rewoo::ReWooStep::Execution {
tool_name: name.clone(),
result: result.clone(),
}).await;
tool_results.push((name, result));
}
Err(e) => {
let msg = format!("[Task panicked: {}]", e);
on_chunk(crate::modes::rewoo::ReWooStep::Execution {
tool_name: "unknown".into(),
result: msg.clone(),
}).await;
tool_results.push(("unknown".into(), msg));
}
}
}
// ── Phase 3: Synthesize ───────────────────────────────────────────
let mut synth_messages = self.build_messages(request).await?;
synth_messages.insert(0, crate::client::types::ChatRequestMessage::system(
crate::modes::rewoo::REWOO_SYSTEM_PROMPT.to_string(),
));
let results_summary: String = tool_results
.iter()
.map(|(name, res)| format!("- {}:\n{}", name, res))
.collect::<Vec<_>>()
.join("\n");
synth_messages.push(crate::client::types::ChatRequestMessage::system(format!(
"## Tool Execution Results\n\nThe following tool calls were executed:\n\n{}\n\nNow synthesize your final answer based on these results.",
results_summary
)));
synth_messages.push(crate::client::types::ChatRequestMessage::user(&request.input));
let preamble = synth_messages
.iter()
.find(|m| m.role == "system")
.and_then(|m| m.content.as_deref())
.unwrap_or("")
.to_string();
let non_system: Vec<_> = synth_messages
.iter()
.filter(|m| m.role != "system")
.map(|m| crate::client::to_rig_message(m))
.collect();
let synth_model = rig_client.completion_model(&request.model.name);
let synth_stream = synth_model
.completion_request("")
.preamble(preamble)
.messages(non_system)
.temperature(request.temperature as f64)
.max_tokens(request.max_tokens as u64)
.stream()
.await
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
use rig::streaming::StreamedAssistantContent;
tokio::pin!(synth_stream);
let mut synthesis = String::new();
while let Some(item) = synth_stream.next().await {
match item {
Ok(StreamedAssistantContent::Text(text)) => {
let t = text.text;
on_chunk(crate::modes::rewoo::ReWooStep::Synthesis(t.clone())).await;
synthesis.push_str(&t);
}
Ok(StreamedAssistantContent::Final(response)) => {
if let Some(usage) = response.token_usage() {
total_input_tokens += usage.input_tokens as i64;
total_output_tokens += usage.output_tokens as i64;
}
}
Err(e) => return Err(AgentError::OpenAi(e.to_string())),
_ => {}
}
}
let elapsed_ms = session_start.elapsed().as_millis() as i64;
record_ai_session(
&request.cache, &request.db, request.project.id,
session_id, request.room.id, request.model.id,
version_id.unwrap_or_default(),
total_input_tokens, total_output_tokens, elapsed_ms,
).await;
Ok((synthesis, total_input_tokens, total_output_tokens))
}
// ── Reflexion (Generate → Critique → Revise) ──────────────────────────
/// Run a Reflexion reasoning cycle: generate → critique → revise (up to 3 rounds).
pub async fn process_reflexion<C, Fut>(
&self,
request: &AiRequest,
mut on_chunk: C,
max_cycles: usize,
) -> Result<(String, i64, i64)>
where
C: FnMut(crate::modes::reflexion::ReflexionStep) -> Fut + Send,
Fut: std::future::Future<Output = ()> + Send,
{
let client_config = AiClientConfig::new(
self.ai_api_key.clone().unwrap_or_default(),
)
.with_base_url(
self.ai_base_url
.clone()
.unwrap_or_else(|| "https://api.openai.com".into()),
);
let rig_client = client_config.build_rig_client();
let Some(registry) = &self.tool_registry else {
return Err(AgentError::Internal("no tool registry registered".into()));
};
let session_id = Uuid::now_v7();
let session_start = std::time::Instant::now();
let version_id = room_ai::Entity::find()
.filter(room_ai::Column::Room.eq(request.room.id))
.filter(room_ai::Column::Model.eq(request.model.id))
.one(&request.db)
.await
.ok()
.flatten()
.and_then(|r| r.version);
let max_cycles = max_cycles.min(3);
let mut total_input_tokens: i64 = 0;
let mut total_output_tokens: i64 = 0;
let mut best_answer = String::new();
for cycle in 0..max_cycles {
let mut messages = self.build_messages(request).await?;
messages.insert(0, crate::client::types::ChatRequestMessage::system(
crate::modes::reflexion::REFLEXION_SYSTEM_PROMPT.to_string(),
));
if cycle > 0 {
messages.push(crate::client::types::ChatRequestMessage::system(format!(
"This is cycle {} of the reflexion process. Your previous answer was:\n\n{}\n\nPlease critique and improve upon it.",
cycle + 1,
best_answer
)));
}
// Build tools for this cycle (not cloneable, so rebuild each iteration)
let cycle_tools = build_rig_tools(
registry, &request.db, &request.cache, &request.config,
request.room.id, request.sender.uid, request.project.id, session_id,
);
// ── Generate ──────────────────────────────────────────────
let model = rig_client.completion_model(&request.model.name);
let agent = rig::agent::AgentBuilder::new(model)
.preamble(crate::modes::reflexion::REFLEXION_SYSTEM_PROMPT)
.tools(cycle_tools)
.default_max_turns(request.max_tool_depth)
.build();
let stream = agent
.stream_prompt(&request.input)
.with_history(Vec::new())
.multi_turn(request.max_tool_depth)
.await;
tokio::pin!(stream);
let mut generated = String::new();
while let Some(item) = stream.next().await {
match item {
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
rig::streaming::StreamedAssistantContent::Text(text),
)) => {
generated.push_str(&text.text);
}
Ok(rig::agent::MultiTurnStreamItem::FinalResponse(resp)) => {
let usage = resp.usage();
total_input_tokens += usage.input_tokens as i64;
total_output_tokens += usage.output_tokens as i64;
}
Err(e) => return Err(AgentError::OpenAi(e.to_string())),
_ => {}
}
}
best_answer = generated.clone();
on_chunk(crate::modes::reflexion::ReflexionStep::Generate(generated.clone())).await;
// If only 1 cycle, emit final and exit
if max_cycles == 1 || cycle + 1 >= max_cycles {
on_chunk(crate::modes::reflexion::ReflexionStep::Final(generated.clone())).await;
break;
}
// ── Self-critique ─────────────────────────────────────────
let critique_messages = vec![
crate::client::types::ChatRequestMessage::system(crate::modes::reflexion::REFLEXION_SYSTEM_PROMPT),
crate::client::types::ChatRequestMessage::system(format!(
"Your previous answer was:\n\n{}", generated
)),
crate::client::types::ChatRequestMessage::user(crate::modes::reflexion::REFLEXION_CRITIQUE_PROMPT),
];
let critique_result = crate::client::call_with_params(
&critique_messages,
&request.model.name,
&client_config,
request.temperature as f32,
request.max_tokens as u32,
None,
None,
Some("none"),
).await?;
total_input_tokens += critique_result.input_tokens;
total_output_tokens += critique_result.output_tokens;
let critique = critique_result.content;
on_chunk(crate::modes::reflexion::ReflexionStep::Critique(critique.clone())).await;
// ── Revise ───────────────────────────────────────────────
let revise_messages = vec![
crate::client::types::ChatRequestMessage::user(format!(
"Your previous answer:\n\n{}\n\nYour self-critique:\n\n{}",
generated, critique
)),
crate::client::types::ChatRequestMessage::user(crate::modes::reflexion::REFLEXION_REVISE_PROMPT),
];
let revise_model = rig_client.completion_model(&request.model.name);
let revise_stream = revise_model
.completion_request("")
.preamble(crate::modes::reflexion::REFLEXION_SYSTEM_PROMPT.to_string())
.messages(revise_messages.iter().map(|m| {
crate::client::to_rig_message(m)
}).collect::<Vec<_>>())
.temperature(request.temperature as f64)
.max_tokens(request.max_tokens as u64)
.stream()
.await
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
tokio::pin!(revise_stream);
let mut revised = String::new();
while let Some(item) = revise_stream.next().await {
match item {
Ok(rig::streaming::StreamedAssistantContent::Text(text)) => {
revised.push_str(&text.text);
}
Ok(rig::streaming::StreamedAssistantContent::Final(response)) => {
if let Some(usage) = response.token_usage() {
total_input_tokens += usage.input_tokens as i64;
total_output_tokens += usage.output_tokens as i64;
}
}
Err(e) => return Err(AgentError::OpenAi(e.to_string())),
_ => {}
}
}
best_answer = revised.clone();
on_chunk(crate::modes::reflexion::ReflexionStep::Revise(revised.clone())).await;
// If last cycle, emit final
if cycle + 1 >= max_cycles {
on_chunk(crate::modes::reflexion::ReflexionStep::Final(revised.clone())).await;
}
}
let elapsed_ms = session_start.elapsed().as_millis() as i64;
record_ai_session(
&request.cache, &request.db, request.project.id,
session_id, request.room.id, request.model.id,
version_id.unwrap_or_default(),
total_input_tokens, total_output_tokens, elapsed_ms,
).await;
Ok((best_answer, total_input_tokens, total_output_tokens))
}
}
fn build_rig_tools(
registry: &crate::tool::ToolRegistry,
db: &db::database::AppDatabase,
cache: &db::cache::AppCache,
cfg: &config::AppConfig,
room_id: uuid::Uuid,
sender_uid: uuid::Uuid,
project_id: uuid::Uuid,
session_id: uuid::Uuid,
) -> Vec<Box<dyn rig::tool::ToolDyn + 'static>> {
let mut tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>> = Vec::new();
for def in registry.definitions() {
let name = def.name.clone();
if let Some(handler) = registry.get(&name) {
let adapter = crate::tool::RigToolAdapter::new(
handler.clone(), def.clone(),
db.clone(), cache.clone(), cfg.clone(),
room_id, Some(sender_uid), project_id,
);
tools.push(Box::new(crate::tool::RecordingTool::new(
Box::new(adapter), db.clone(), session_id, sender_uid,
)));
}
}
tools
}
/// Extract text from rig's ToolResultContent, ignoring images.

View File

@ -155,7 +155,7 @@ fn ai_metrics() -> &'static AiMetrics {
// ── Type conversions ─────────────────────────────────────────────────────────
fn to_rig_message(msg: &ChatRequestMessage) -> RigMessage {
pub(crate) fn to_rig_message(msg: &ChatRequestMessage) -> RigMessage {
match msg.role.as_str() {
"system" => {
// System messages are handled via preamble(), but we still

View File

@ -6,6 +6,7 @@ pub mod compact;
pub mod embed;
pub mod error;
pub mod model;
pub mod modes;
pub mod perception;
pub mod react;
pub mod skills;
@ -33,6 +34,10 @@ pub use embed::{
EmbedClient, EmbedMemoryInput, EmbedService, QdrantClient, SearchResult, TagEmbedInput, new_embed_client,
};
pub use error::{AgentError, Result};
pub use modes::cot::{CotStep, COT_SYSTEM_PROMPT};
pub use modes::reflexion::{ReflexionCycle, ReflexionStep, REFLEXION_CRITIQUE_PROMPT, REFLEXION_REVISE_PROMPT, REFLEXION_SYSTEM_PROMPT};
pub use modes::rewoo::{ReWooPlan, ReWooStep, ReWooToolCall, REWOO_SYSTEM_PROMPT, extract_plan};
pub use modes::ModeStep;
pub use react::{ReactConfig, ReactStep, DEFAULT_SYSTEM_PROMPT};
pub use tool::{
ToolCall, ToolCallRecord, ToolCallRecorder, ToolCallResult, ToolContext, ToolDefinition, ToolError, ToolExecutor, ToolHandler, ToolParam,