diff --git a/libs/agent/Cargo.toml b/libs/agent/Cargo.toml index 8949838..3a10de2 100644 --- a/libs/agent/Cargo.toml +++ b/libs/agent/Cargo.toml @@ -42,5 +42,6 @@ rust_decimal = { workspace = true } reqwest = { workspace = true, features = ["json"] } utoipa = { workspace = true } tokio-stream = { workspace = true } +redis = { workspace = true, features = ["tokio-comp"] } [lints] workspace = true diff --git a/libs/agent/billing.rs b/libs/agent/billing.rs index 8886fe7..3d7dcd3 100644 --- a/libs/agent/billing.rs +++ b/libs/agent/billing.rs @@ -118,6 +118,7 @@ pub async fn record_ai_usage( let new_balance = project_billing.balance - total_cost; let mut updated: project_billing::ActiveModel = project_billing.into(); updated.balance = Set(new_balance); + updated.updated_at = Set(now); updated.update(&txn).await?; txn.commit().await?; @@ -183,8 +184,10 @@ pub async fn record_ai_usage( .await?; let new_balance = workspace_billing.balance - total_cost; + let new_total_spent = workspace_billing.total_spent + total_cost; let mut updated: workspace_billing::ActiveModel = workspace_billing.into(); updated.balance = Set(new_balance); + updated.total_spent = Set(new_total_spent); updated.updated_at = Set(now); updated.update(&txn).await?; diff --git a/libs/agent/chat/mod.rs b/libs/agent/chat/mod.rs index ee89864..0c0bc2a 100644 --- a/libs/agent/chat/mod.rs +++ b/libs/agent/chat/mod.rs @@ -78,5 +78,7 @@ pub enum Mention { pub mod context; pub mod service; +pub mod state; pub use context::{AiContextSenderType, RoomMessageContext}; pub use service::ChatService; +pub use state::{AgentRuntime, AgentState}; diff --git a/libs/agent/chat/service.rs b/libs/agent/chat/service.rs index de48f11..3316fa7 100644 --- a/libs/agent/chat/service.rs +++ b/libs/agent/chat/service.rs @@ -3,6 +3,7 @@ use models::projects::project_skill; use models::rooms::room_ai; use rig::agent::{AgentBuilder, MultiTurnStreamItem}; use rig::client::CompletionClient; +use rig::completion::{CompletionModel, GetTokenUsage, Prompt}; use rig::streaming::{StreamedAssistantContent, StreamingPrompt}; use sea_orm::*; use std::pin::Pin; @@ -48,6 +49,7 @@ pub struct ProcessResult { /// Record an AI session with cost calculation. async fn record_ai_session( + cache: &db::cache::AppCache, db: &db::database::AppDatabase, project_id: Uuid, session_id: Uuid, @@ -58,6 +60,28 @@ async fn record_ai_session( output_tokens: i64, latency_ms: i64, ) { + metrics::histogram!("ai_call_latency_ms", "model" => model_id.to_string()).record(latency_ms as f64); + + let session = models::ai::ai_session::ActiveModel { + id: Set(session_id), + room: Set(room_id), + model: Set(model_id), + version: Set(version_id), + token_input: Set(input_tokens), + token_output: Set(output_tokens), + latency_ms: Set(Some(latency_ms)), + cost: Set(None), + currency: Set(None), + error_message: Set(None), + error_code: Set(None), + created_at: Set(chrono::Utc::now()), + }; + + if let Err(e) = session.insert(db).await { + tracing::error!(error = %e, session_id = %session_id, "failed to insert ai session record"); + return; + } + let (cost, currency, error_msg) = match billing::record_ai_usage( db, project_id, @@ -71,33 +95,25 @@ async fn record_ai_session( (Some(record.cost), Some(record.currency), None) } Ok(billing::BillingResult::InsufficientBalance { message }) => { - // Create system message for insufficient balance - create_system_message(db, room_id, &message).await; + create_system_message(cache, db, room_id, &message).await; (None, None, Some(message)) } - Err(_) => (None, None, None), + Err(e) => (None, None, Some(e.to_string())), }; - let _ = models::ai::ai_session::ActiveModel { - id: Set(session_id), - room: Set(room_id), - model: Set(model_id), - version: Set(version_id), - token_input: Set(input_tokens), - token_output: Set(output_tokens), - latency_ms: Set(Some(latency_ms)), - cost: Set(cost), - currency: Set(currency), - error_message: Set(error_msg), - error_code: Set(None), - created_at: Set(chrono::Utc::now()), - } - .insert(db) - .await; + use sea_orm::sea_query::Expr; + let _ = models::ai::ai_session::Entity::update_many() + .col_expr(models::ai::ai_session::Column::Cost, Expr::value(cost)) + .col_expr(models::ai::ai_session::Column::Currency, Expr::value(currency)) + .col_expr(models::ai::ai_session::Column::ErrorMessage, Expr::value(error_msg)) + .filter(models::ai::ai_session::Column::Id.eq(session_id)) + .exec(db) + .await; } /// Create a system message in the room for billing errors. async fn create_system_message( + cache: &db::cache::AppCache, db: &db::database::AppDatabase, room_id: Uuid, message: &str, @@ -105,26 +121,40 @@ async fn create_system_message( use models::rooms::{room_message, MessageSenderType, MessageContentType}; use sea_orm::Set; - // Get next sequence number - we don't have cache here, so we query directly - let last_seq = match room_message::Entity::find() - .filter(room_message::Column::Room.eq(room_id)) - .order_by_desc(room_message::Column::Seq) - .one(db) - .await - { - Ok(Some(m)) => m.seq, - Ok(None) => 0, + let seq_key = format!("room:seq:{}", room_id); + let seq = match cache.conn().await { + Ok(mut conn) => { + match redis::cmd("INCR").arg(&seq_key).query_async::(&mut conn).await { + Ok(s) => s, + Err(e) => { + tracing::warn!(error = %e, "cache INCR failed for system message seq, falling back to DB"); + let last_seq = match room_message::Entity::find() + .filter(room_message::Column::Room.eq(room_id)) + .order_by_desc(room_message::Column::Seq) + .one(db) + .await + { + Ok(Some(m)) => m.seq, + Ok(None) => 0, + Err(e) => { + tracing::warn!(error = %e, "Failed to get last seq for system message"); + return; + } + }; + last_seq + 1 + } + } + } Err(e) => { - tracing::warn!(error = %e, "Failed to get last seq for system message"); + tracing::warn!(error = %e, "Failed to get Redis connection for system message seq"); return; } }; - let seq = last_seq + 1; let now = chrono::Utc::now(); let result = room_message::ActiveModel { - id: Set(Uuid::new_v4()), + id: Set(Uuid::now_v7()), seq: Set(seq), room: Set(room_id), sender_type: Set(MessageSenderType::System), @@ -269,7 +299,7 @@ impl ChatService { let mut tool_depth = 0; let mut input_tokens = 0i64; let mut output_tokens = 0i64; - let session_id = Uuid::new_v4(); + let session_id = Uuid::now_v7(); let session_start = std::time::Instant::now(); let version_id = room_ai.as_ref().and_then(|r| r.version); @@ -464,6 +494,7 @@ impl ChatService { }; // Record session record_ai_session( + &request.cache, &request.db, request.project.id, session_id, @@ -486,6 +517,7 @@ impl ChatService { // Record session record_ai_session( + &request.cache, &request.db, request.project.id, session_id, @@ -536,7 +568,7 @@ impl ChatService { let mut tool_depth = 0; let mut total_input_tokens = 0i64; let mut total_output_tokens = 0i64; - let session_id = Uuid::new_v4(); + let session_id = Uuid::now_v7(); let session_start = std::time::Instant::now(); let version_id = room_ai.as_ref().and_then(|r| r.version); @@ -860,6 +892,7 @@ impl ChatService { }); // Record session record_ai_session( + &request.cache, &request.db, request.project.id, session_id, @@ -897,6 +930,7 @@ impl ChatService { }); // Record session record_ai_session( + &request.cache, &request.db, request.project.id, session_id, @@ -934,22 +968,70 @@ impl ChatService { let mut processed_history = Vec::new(); if let Some(compact_service) = &self.compact_service { - let config = CompactConfig::default(); - match compact_service - .compact_room_auto(request.room.id, Some(request.user_names.clone()), config) - .await - { - Ok(compact_summary) => { - if !compact_summary.summary.is_empty() { + let compact_cache_key = format!("ai:compact:{}", request.room.id); + let compact_config = CompactConfig::default(); + + // Try cached compaction summary (avoids re-compacting same history) + let cached_summary: Option = { + let conn_result = request.cache.conn().await; + match conn_result { + Ok(mut conn) => { + redis::cmd("GET") + .arg(&compact_cache_key) + .query_async::>(&mut conn) + .await + .unwrap_or(None) + } + Err(e) => { + tracing::warn!(error = %e, "compact cache: conn failed"); + None + } + } + }; + + if let Some(cached_json) = cached_summary { + if let Ok(summary) = serde_json::from_str::(&cached_json) { + if !summary.summary.is_empty() { messages.push(ChatRequestMessage::system(format!( "Conversation summary:\n{}", - compact_summary.summary + summary.summary ))); } - processed_history = compact_summary.retained; + processed_history = summary.retained; } - Err(e) => { - tracing::warn!(error = %e, "conversation compaction failed, using full history"); + } + + if processed_history.is_empty() { + match compact_service + .compact_room_auto(request.room.id, Some(request.user_names.clone()), compact_config) + .await + { + Ok(compact_summary) => { + if !compact_summary.summary.is_empty() { + messages.push(ChatRequestMessage::system(format!( + "Conversation summary:\n{}", + compact_summary.summary + ))); + } + // Cache for subsequent calls (5 min TTL) + if let Ok(json) = serde_json::to_string(&compact_summary) { + if let Ok(mut conn) = request.cache.conn().await { + let _ = redis::cmd("SETEX") + .arg(&compact_cache_key) + .arg(300u64) + .arg(&json) + .query_async::<()>(&mut conn) + .await + .inspect_err(|e| { + tracing::warn!(error = %e, "compact cache: SETEX failed"); + }); + } + } + processed_history = compact_summary.retained; + } + Err(e) => { + tracing::warn!(error = %e, "conversation compaction failed, using full history"); + } } } } @@ -1186,9 +1268,10 @@ impl ChatService { .await } - pub async fn process_react(&self, request: &AiRequest, mut on_chunk: C) -> Result<(String, i64, i64)> + pub async fn process_react(&self, request: &AiRequest, mut on_chunk: C) -> Result<(String, i64, i64)> where - C: FnMut(crate::react::ReactStep) + Send, + C: FnMut(crate::react::ReactStep) -> Fut + Send, + Fut: std::future::Future + Send, { let base_url = self .ai_base_url @@ -1207,7 +1290,7 @@ impl ChatService { let room_id = request.room.id; let sender_uid = request.sender.uid; let project_id = request.project.id; - let session_id = Uuid::new_v4(); + let session_id = Uuid::now_v7(); let session_start = std::time::Instant::now(); let version_id = room_ai::Entity::find() .filter(room_ai::Column::Room.eq(request.room.id)) @@ -1274,7 +1357,8 @@ impl ChatService { on_chunk(ReactStep::Answer { step: step_count, answer: t.clone(), - }); + }) + .await; final_content.push_str(&t); } Ok(MultiTurnStreamItem::StreamAssistantItem( @@ -1286,7 +1370,8 @@ impl ChatService { on_chunk(ReactStep::Thought { step: step_count, thought: reasoning_text, - }); + }) + .await; } } Ok(MultiTurnStreamItem::StreamAssistantItem( @@ -1297,7 +1382,8 @@ impl ChatService { on_chunk(ReactStep::Thought { step: step_count, thought: reasoning, - }); + }) + .await; } } Ok(MultiTurnStreamItem::StreamAssistantItem( @@ -1313,7 +1399,8 @@ impl ChatService { on_chunk(ReactStep::Action { step: step_count, action: ReactAction::new(&tool_call.function.name, args), - }); + }) + .await; } Ok(MultiTurnStreamItem::StreamUserItem( rig::streaming::StreamedUserContent::ToolResult { tool_result, .. }, @@ -1323,7 +1410,8 @@ impl ChatService { on_chunk(ReactStep::Observation { step: step_count, observation: obs, - }); + }) + .await; } Ok(MultiTurnStreamItem::FinalResponse(resp)) => { let usage = resp.usage(); @@ -1341,6 +1429,7 @@ impl ChatService { let elapsed_ms = session_start.elapsed().as_millis() as i64; record_ai_session( + &request.cache, &request.db, request.project.id, session_id, @@ -1355,6 +1444,623 @@ impl ChatService { Ok((final_content, total_input_tokens, total_output_tokens)) } + + // ── CoT (Chain-of-Thought) ──────────────────────────────────────────── + + /// Run a CoT (Chain-of-Thought) reasoning cycle — step-by-step reasoning with optional tool use. + pub async fn process_cot(&self, request: &AiRequest, mut on_chunk: C) -> Result<(String, i64, i64)> + where + C: FnMut(crate::modes::cot::CotStep) -> Fut + Send, + Fut: std::future::Future + Send, + { + let client_config = AiClientConfig::new( + self.ai_api_key.clone().unwrap_or_default(), + ) + .with_base_url( + self.ai_base_url + .clone() + .unwrap_or_else(|| "https://api.openai.com".into()), + ); + let rig_client = client_config.build_rig_client(); + + let Some(registry) = &self.tool_registry else { + return Err(AgentError::Internal("no tool registry registered".into())); + }; + + let session_id = Uuid::now_v7(); + let session_start = std::time::Instant::now(); + let version_id = room_ai::Entity::find() + .filter(room_ai::Column::Room.eq(request.room.id)) + .filter(room_ai::Column::Model.eq(request.model.id)) + .one(&request.db) + .await + .ok() + .flatten() + .and_then(|r| r.version); + + let db = request.db.clone(); + let cache = request.cache.clone(); + let cfg = request.config.clone(); + let room_id = request.room.id; + let sender_uid = request.sender.uid; + let project_id = request.project.id; + + let mut tools: Vec> = Vec::new(); + for def in registry.definitions() { + let name = def.name.clone(); + if let Some(handler) = registry.get(&name) { + let adapter = crate::tool::RigToolAdapter::new( + handler.clone(), def.clone(), + db.clone(), cache.clone(), cfg.clone(), + room_id, Some(sender_uid), project_id, + ); + tools.push(Box::new(crate::tool::RecordingTool::new( + Box::new(adapter), db.clone(), session_id, sender_uid, + ))); + } + } + + let model = rig_client.completion_model(&request.model.name); + let agent = AgentBuilder::new(model) + .preamble(crate::modes::cot::COT_SYSTEM_PROMPT) + .tools(tools) + .default_max_turns(request.max_tool_depth) + .build(); + + let stream = agent + .stream_prompt(&request.input) + .with_history(Vec::new()) + .multi_turn(request.max_tool_depth) + .await; + + tokio::pin!(stream); + + let mut final_content = String::new(); + let mut total_input_tokens: i64 = 0; + let mut total_output_tokens: i64 = 0; + + while let Some(item) = stream.next().await { + match item { + Ok(MultiTurnStreamItem::StreamAssistantItem( + StreamedAssistantContent::Text(text), + )) => { + let t = text.text; + on_chunk(crate::modes::cot::CotStep::Answer(t.clone())).await; + final_content.push_str(&t); + } + Ok(MultiTurnStreamItem::StreamAssistantItem( + StreamedAssistantContent::Reasoning(reasoning), + )) => { + let r = reasoning.reasoning.join(""); + if !r.is_empty() { + on_chunk(crate::modes::cot::CotStep::Thought(r)).await; + } + } + Ok(MultiTurnStreamItem::StreamAssistantItem( + StreamedAssistantContent::ReasoningDelta { reasoning, .. }, + )) => { + if !reasoning.is_empty() { + on_chunk(crate::modes::cot::CotStep::Thought(reasoning)).await; + } + } + Ok(MultiTurnStreamItem::StreamAssistantItem( + StreamedAssistantContent::ToolCall { tool_call, .. }, + )) => { + let args: serde_json::Value = match &tool_call.function.arguments { + serde_json::Value::String(s) => { + serde_json::from_str(s).unwrap_or(serde_json::Value::Null) + } + v => v.clone(), + }; + on_chunk(crate::modes::cot::CotStep::Action { + name: tool_call.function.name.clone(), + args, + }).await; + } + Ok(MultiTurnStreamItem::StreamUserItem( + rig::streaming::StreamedUserContent::ToolResult { tool_result, .. }, + )) => { + let obs = tool_result_content_to_string(&tool_result.content); + on_chunk(crate::modes::cot::CotStep::Observation(obs)).await; + } + Ok(MultiTurnStreamItem::FinalResponse(resp)) => { + let usage = resp.usage(); + total_input_tokens = usage.input_tokens as i64; + total_output_tokens = usage.output_tokens as i64; + } + Err(e) => { + return Err(AgentError::OpenAi(e.to_string())); + } + _ => {} + } + } + + let elapsed_ms = session_start.elapsed().as_millis() as i64; + record_ai_session( + &request.cache, &request.db, request.project.id, + session_id, request.room.id, request.model.id, + version_id.unwrap_or_default(), + total_input_tokens, total_output_tokens, elapsed_ms, + ).await; + + Ok((final_content, total_input_tokens, total_output_tokens)) + } + + // ── ReWOO (Plan → Execute → Synthesize) ─────────────────────────────── + + /// Run a ReWOO reasoning cycle: model plans tool calls, they are executed, + /// then the model synthesises the final answer. + pub async fn process_rewoo(&self, request: &AiRequest, mut on_chunk: C) -> Result<(String, i64, i64)> + where + C: FnMut(crate::modes::rewoo::ReWooStep) -> Fut + Send, + Fut: std::future::Future + Send, + { + let client_config = AiClientConfig::new( + self.ai_api_key.clone().unwrap_or_default(), + ) + .with_base_url( + self.ai_base_url + .clone() + .unwrap_or_else(|| "https://api.openai.com".into()), + ); + let rig_client = client_config.build_rig_client(); + + let Some(registry) = &self.tool_registry else { + return Err(AgentError::Internal("no tool registry registered".into())); + }; + + let session_id = Uuid::now_v7(); + let session_start = std::time::Instant::now(); + let version_id = room_ai::Entity::find() + .filter(room_ai::Column::Room.eq(request.room.id)) + .filter(room_ai::Column::Model.eq(request.model.id)) + .one(&request.db) + .await + .ok() + .flatten() + .and_then(|r| r.version); + + let mut total_input_tokens: i64 = 0; + let mut total_output_tokens: i64 = 0; + + let mut messages = self.build_messages(request).await?; + messages.insert(0, crate::client::types::ChatRequestMessage::system( + crate::modes::rewoo::REWOO_SYSTEM_PROMPT.to_string(), + )); + let model = rig_client.completion_model(&request.model.name); + + let plan_tools = { + let db = request.db.clone(); + let cache = request.cache.clone(); + let cfg = request.config.clone(); + let room_id = request.room.id; + let sender_uid = request.sender.uid; + let project_id = request.project.id; + + let mut tools: Vec> = Vec::new(); + for def in registry.definitions() { + let name = def.name.clone(); + if let Some(handler) = registry.get(&name) { + let adapter = crate::tool::RigToolAdapter::new( + handler.clone(), + def.clone(), + db.clone(), + cache.clone(), + cfg.clone(), + room_id, + Some(sender_uid), + project_id, + ); + tools.push(Box::new(crate::tool::RecordingTool::new( + Box::new(adapter), + db.clone(), + session_id, + sender_uid, + ))); + } + } + tools + }; + + let plan_agent = rig::agent::AgentBuilder::new(model) + .preamble(crate::modes::rewoo::REWOO_SYSTEM_PROMPT) + .tools(plan_tools) + .default_max_turns(1) + .build(); + + let plan_response = plan_agent + .prompt(&request.input) + .extended_details() + .await + .map_err(|e| AgentError::OpenAi(e.to_string()))?; + + total_input_tokens += plan_response.total_usage.input_tokens as i64; + total_output_tokens += plan_response.total_usage.output_tokens as i64; + + let plan = crate::modes::rewoo::extract_plan(&plan_response.output) + .unwrap_or_default(); + + if plan.calls.is_empty() { + on_chunk(crate::modes::rewoo::ReWooStep::Synthesis(plan_response.output.clone())).await; + let elapsed_ms = session_start.elapsed().as_millis() as i64; + record_ai_session( + &request.cache, &request.db, request.project.id, + session_id, request.room.id, request.model.id, + version_id.unwrap_or_default(), + total_input_tokens, total_output_tokens, elapsed_ms, + ).await; + return Ok((plan_response.output, total_input_tokens, total_output_tokens)); + } + + on_chunk(crate::modes::rewoo::ReWooStep::Plan { + calls: plan.calls.clone(), + raw: plan.raw_text, + }).await; + + // ── Phase 2: Execute all tool calls in parallel ─────────────────── + let mut tool_results: Vec<(String, String)> = Vec::new(); + let mut handles = Vec::new(); + + for call in &plan.calls { + let ctx = crate::tool::ToolContext::new( + request.db.clone(), + request.cache.clone(), + request.config.clone(), + request.room.id, + Some(request.sender.uid), + ) + .with_project(request.project.id); + if let Some(ref es) = self.embed_service { + // ctx = ctx.with_embed_service(es.clone()); -- not clone-able via pattern, skip + let _ = es; + } + + let call_id = call.step.to_string(); + let tool_name = call.tool.clone(); + let args = call.args.clone(); + let ctx_clone = ctx.clone(); + + let handle = tokio::spawn(async move { + let executor = crate::tool::ToolExecutor::new(); + let agent_call = crate::tool::ToolCall { + id: call_id, + name: tool_name.clone(), + arguments: args.to_string(), + }; + let mut local_ctx = ctx_clone; + let result = executor.execute_batch(vec![agent_call], &mut local_ctx).await; + match result { + Ok(results) => { + for r in &results { + match &r.result { + crate::tool::ToolResult::Ok(v) => { + return (tool_name, v.to_string()); + } + crate::tool::ToolResult::Error(e) => { + return (tool_name, format!("[Error: {}]", e)); + } + } + } + (tool_name, "[No result]".to_string()) + } + Err(e) => (tool_name, format!("[Execution error: {}]", e)), + } + }); + handles.push(handle); + } + + for handle in handles { + match handle.await { + Ok((name, result)) => { + on_chunk(crate::modes::rewoo::ReWooStep::Execution { + tool_name: name.clone(), + result: result.clone(), + }).await; + tool_results.push((name, result)); + } + Err(e) => { + let msg = format!("[Task panicked: {}]", e); + on_chunk(crate::modes::rewoo::ReWooStep::Execution { + tool_name: "unknown".into(), + result: msg.clone(), + }).await; + tool_results.push(("unknown".into(), msg)); + } + } + } + + // ── Phase 3: Synthesize ─────────────────────────────────────────── + let mut synth_messages = self.build_messages(request).await?; + synth_messages.insert(0, crate::client::types::ChatRequestMessage::system( + crate::modes::rewoo::REWOO_SYSTEM_PROMPT.to_string(), + )); + + let results_summary: String = tool_results + .iter() + .map(|(name, res)| format!("- {}:\n{}", name, res)) + .collect::>() + .join("\n"); + + synth_messages.push(crate::client::types::ChatRequestMessage::system(format!( + "## Tool Execution Results\n\nThe following tool calls were executed:\n\n{}\n\nNow synthesize your final answer based on these results.", + results_summary + ))); + synth_messages.push(crate::client::types::ChatRequestMessage::user(&request.input)); + + let preamble = synth_messages + .iter() + .find(|m| m.role == "system") + .and_then(|m| m.content.as_deref()) + .unwrap_or("") + .to_string(); + let non_system: Vec<_> = synth_messages + .iter() + .filter(|m| m.role != "system") + .map(|m| crate::client::to_rig_message(m)) + .collect(); + + let synth_model = rig_client.completion_model(&request.model.name); + let synth_stream = synth_model + .completion_request("") + .preamble(preamble) + .messages(non_system) + .temperature(request.temperature as f64) + .max_tokens(request.max_tokens as u64) + .stream() + .await + .map_err(|e| AgentError::OpenAi(e.to_string()))?; + + use rig::streaming::StreamedAssistantContent; + tokio::pin!(synth_stream); + + let mut synthesis = String::new(); + while let Some(item) = synth_stream.next().await { + match item { + Ok(StreamedAssistantContent::Text(text)) => { + let t = text.text; + on_chunk(crate::modes::rewoo::ReWooStep::Synthesis(t.clone())).await; + synthesis.push_str(&t); + } + Ok(StreamedAssistantContent::Final(response)) => { + if let Some(usage) = response.token_usage() { + total_input_tokens += usage.input_tokens as i64; + total_output_tokens += usage.output_tokens as i64; + } + } + Err(e) => return Err(AgentError::OpenAi(e.to_string())), + _ => {} + } + } + + let elapsed_ms = session_start.elapsed().as_millis() as i64; + record_ai_session( + &request.cache, &request.db, request.project.id, + session_id, request.room.id, request.model.id, + version_id.unwrap_or_default(), + total_input_tokens, total_output_tokens, elapsed_ms, + ).await; + + Ok((synthesis, total_input_tokens, total_output_tokens)) + } + + // ── Reflexion (Generate → Critique → Revise) ────────────────────────── + + /// Run a Reflexion reasoning cycle: generate → critique → revise (up to 3 rounds). + pub async fn process_reflexion( + &self, + request: &AiRequest, + mut on_chunk: C, + max_cycles: usize, + ) -> Result<(String, i64, i64)> + where + C: FnMut(crate::modes::reflexion::ReflexionStep) -> Fut + Send, + Fut: std::future::Future + Send, + { + let client_config = AiClientConfig::new( + self.ai_api_key.clone().unwrap_or_default(), + ) + .with_base_url( + self.ai_base_url + .clone() + .unwrap_or_else(|| "https://api.openai.com".into()), + ); + let rig_client = client_config.build_rig_client(); + let Some(registry) = &self.tool_registry else { + return Err(AgentError::Internal("no tool registry registered".into())); + }; + + let session_id = Uuid::now_v7(); + let session_start = std::time::Instant::now(); + let version_id = room_ai::Entity::find() + .filter(room_ai::Column::Room.eq(request.room.id)) + .filter(room_ai::Column::Model.eq(request.model.id)) + .one(&request.db) + .await + .ok() + .flatten() + .and_then(|r| r.version); + + let max_cycles = max_cycles.min(3); + + let mut total_input_tokens: i64 = 0; + let mut total_output_tokens: i64 = 0; + let mut best_answer = String::new(); + + for cycle in 0..max_cycles { + let mut messages = self.build_messages(request).await?; + messages.insert(0, crate::client::types::ChatRequestMessage::system( + crate::modes::reflexion::REFLEXION_SYSTEM_PROMPT.to_string(), + )); + + if cycle > 0 { + messages.push(crate::client::types::ChatRequestMessage::system(format!( + "This is cycle {} of the reflexion process. Your previous answer was:\n\n{}\n\nPlease critique and improve upon it.", + cycle + 1, + best_answer + ))); + } + + // Build tools for this cycle (not cloneable, so rebuild each iteration) + let cycle_tools = build_rig_tools( + registry, &request.db, &request.cache, &request.config, + request.room.id, request.sender.uid, request.project.id, session_id, + ); + + // ── Generate ────────────────────────────────────────────── + let model = rig_client.completion_model(&request.model.name); + let agent = rig::agent::AgentBuilder::new(model) + .preamble(crate::modes::reflexion::REFLEXION_SYSTEM_PROMPT) + .tools(cycle_tools) + .default_max_turns(request.max_tool_depth) + .build(); + + let stream = agent + .stream_prompt(&request.input) + .with_history(Vec::new()) + .multi_turn(request.max_tool_depth) + .await; + + tokio::pin!(stream); + let mut generated = String::new(); + + while let Some(item) = stream.next().await { + match item { + Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem( + rig::streaming::StreamedAssistantContent::Text(text), + )) => { + generated.push_str(&text.text); + } + Ok(rig::agent::MultiTurnStreamItem::FinalResponse(resp)) => { + let usage = resp.usage(); + total_input_tokens += usage.input_tokens as i64; + total_output_tokens += usage.output_tokens as i64; + } + Err(e) => return Err(AgentError::OpenAi(e.to_string())), + _ => {} + } + } + + best_answer = generated.clone(); + on_chunk(crate::modes::reflexion::ReflexionStep::Generate(generated.clone())).await; + + // If only 1 cycle, emit final and exit + if max_cycles == 1 || cycle + 1 >= max_cycles { + on_chunk(crate::modes::reflexion::ReflexionStep::Final(generated.clone())).await; + break; + } + + // ── Self-critique ───────────────────────────────────────── + let critique_messages = vec![ + crate::client::types::ChatRequestMessage::system(crate::modes::reflexion::REFLEXION_SYSTEM_PROMPT), + crate::client::types::ChatRequestMessage::system(format!( + "Your previous answer was:\n\n{}", generated + )), + crate::client::types::ChatRequestMessage::user(crate::modes::reflexion::REFLEXION_CRITIQUE_PROMPT), + ]; + + let critique_result = crate::client::call_with_params( + &critique_messages, + &request.model.name, + &client_config, + request.temperature as f32, + request.max_tokens as u32, + None, + None, + Some("none"), + ).await?; + + total_input_tokens += critique_result.input_tokens; + total_output_tokens += critique_result.output_tokens; + let critique = critique_result.content; + on_chunk(crate::modes::reflexion::ReflexionStep::Critique(critique.clone())).await; + + // ── Revise ─────────────────────────────────────────────── + let revise_messages = vec![ + crate::client::types::ChatRequestMessage::user(format!( + "Your previous answer:\n\n{}\n\nYour self-critique:\n\n{}", + generated, critique + )), + crate::client::types::ChatRequestMessage::user(crate::modes::reflexion::REFLEXION_REVISE_PROMPT), + ]; + + let revise_model = rig_client.completion_model(&request.model.name); + let revise_stream = revise_model + .completion_request("") + .preamble(crate::modes::reflexion::REFLEXION_SYSTEM_PROMPT.to_string()) + .messages(revise_messages.iter().map(|m| { + crate::client::to_rig_message(m) + }).collect::>()) + .temperature(request.temperature as f64) + .max_tokens(request.max_tokens as u64) + .stream() + .await + .map_err(|e| AgentError::OpenAi(e.to_string()))?; + + tokio::pin!(revise_stream); + let mut revised = String::new(); + + while let Some(item) = revise_stream.next().await { + match item { + Ok(rig::streaming::StreamedAssistantContent::Text(text)) => { + revised.push_str(&text.text); + } + Ok(rig::streaming::StreamedAssistantContent::Final(response)) => { + if let Some(usage) = response.token_usage() { + total_input_tokens += usage.input_tokens as i64; + total_output_tokens += usage.output_tokens as i64; + } + } + Err(e) => return Err(AgentError::OpenAi(e.to_string())), + _ => {} + } + } + + best_answer = revised.clone(); + on_chunk(crate::modes::reflexion::ReflexionStep::Revise(revised.clone())).await; + + // If last cycle, emit final + if cycle + 1 >= max_cycles { + on_chunk(crate::modes::reflexion::ReflexionStep::Final(revised.clone())).await; + } + } + + let elapsed_ms = session_start.elapsed().as_millis() as i64; + record_ai_session( + &request.cache, &request.db, request.project.id, + session_id, request.room.id, request.model.id, + version_id.unwrap_or_default(), + total_input_tokens, total_output_tokens, elapsed_ms, + ).await; + + Ok((best_answer, total_input_tokens, total_output_tokens)) + } +} + +fn build_rig_tools( + registry: &crate::tool::ToolRegistry, + db: &db::database::AppDatabase, + cache: &db::cache::AppCache, + cfg: &config::AppConfig, + room_id: uuid::Uuid, + sender_uid: uuid::Uuid, + project_id: uuid::Uuid, + session_id: uuid::Uuid, +) -> Vec> { + let mut tools: Vec> = Vec::new(); + for def in registry.definitions() { + let name = def.name.clone(); + if let Some(handler) = registry.get(&name) { + let adapter = crate::tool::RigToolAdapter::new( + handler.clone(), def.clone(), + db.clone(), cache.clone(), cfg.clone(), + room_id, Some(sender_uid), project_id, + ); + tools.push(Box::new(crate::tool::RecordingTool::new( + Box::new(adapter), db.clone(), session_id, sender_uid, + ))); + } + } + tools } /// Extract text from rig's ToolResultContent, ignoring images. diff --git a/libs/agent/client/mod.rs b/libs/agent/client/mod.rs index 13bff36..02849ce 100644 --- a/libs/agent/client/mod.rs +++ b/libs/agent/client/mod.rs @@ -155,7 +155,7 @@ fn ai_metrics() -> &'static AiMetrics { // ── Type conversions ───────────────────────────────────────────────────────── -fn to_rig_message(msg: &ChatRequestMessage) -> RigMessage { +pub(crate) fn to_rig_message(msg: &ChatRequestMessage) -> RigMessage { match msg.role.as_str() { "system" => { // System messages are handled via preamble(), but we still diff --git a/libs/agent/lib.rs b/libs/agent/lib.rs index eeb4172..26b6d83 100644 --- a/libs/agent/lib.rs +++ b/libs/agent/lib.rs @@ -6,6 +6,7 @@ pub mod compact; pub mod embed; pub mod error; pub mod model; +pub mod modes; pub mod perception; pub mod react; pub mod skills; @@ -33,6 +34,10 @@ pub use embed::{ EmbedClient, EmbedMemoryInput, EmbedService, QdrantClient, SearchResult, TagEmbedInput, new_embed_client, }; pub use error::{AgentError, Result}; +pub use modes::cot::{CotStep, COT_SYSTEM_PROMPT}; +pub use modes::reflexion::{ReflexionCycle, ReflexionStep, REFLEXION_CRITIQUE_PROMPT, REFLEXION_REVISE_PROMPT, REFLEXION_SYSTEM_PROMPT}; +pub use modes::rewoo::{ReWooPlan, ReWooStep, ReWooToolCall, REWOO_SYSTEM_PROMPT, extract_plan}; +pub use modes::ModeStep; pub use react::{ReactConfig, ReactStep, DEFAULT_SYSTEM_PROMPT}; pub use tool::{ ToolCall, ToolCallRecord, ToolCallRecorder, ToolCallResult, ToolContext, ToolDefinition, ToolError, ToolExecutor, ToolHandler, ToolParam,