use futures::StreamExt;
use models::rooms::room_ai;
use rig::agent::{AgentBuilder, MultiTurnStreamItem};
use rig::client::CompletionClient;
use rig::streaming::{StreamedAssistantContent, StreamingPrompt};
use sea_orm::*;
use uuid::Uuid;

use super::AiRequest;
use super::session_recording::record_ai_session;
use crate::client::AiClientConfig;
use crate::error::{AgentError, Result};
use crate::react::types::Action as ReactAction;
use crate::react::{DEFAULT_SYSTEM_PROMPT, ReactStep};
use crate::tool::{RecordingTool, registry::ToolRegistry};

/// Runs one ReAct-style agent call for `request` on a rig agent with streaming,
/// multi-turn tool use, forwarding each step to `on_chunk` as it arrives.
///
/// Flow:
/// 1. Builds an `AiClientConfig` from `ai_base_url` (defaults to
///    `"https://api.openai.com"` when `None`) and `ai_api_key` (defaults to the
///    type's `Default` value when `None`).
/// 2. Looks up the `room_ai` row matching this room and model to obtain the
///    `version` passed to `record_ai_session`; query errors and missing rows are
///    ignored.
/// 3. Wraps every tool definition in `tool_registry` that has a registered
///    handler in a `RigToolAdapter`, itself wrapped in a `RecordingTool` bound to
///    this call's `session_id`.
/// 4. Builds the agent with `DEFAULT_SYSTEM_PROMPT` followed (on the next line)
///    by `room_preamble` when present, then streams `request.input` with an empty
///    history and up to `request.max_tool_depth` turns.
/// 5. Maps stream items to `ReactStep`s: text chunks -> `Answer`, reasoning ->
///    `Thought`, tool calls -> `Action`, tool results -> `Observation`. Each
///    emitted step carries the next value of a counter starting at 1.
/// 6. After the stream ends, calls `record_ai_session` with the token usage and
///    the elapsed wall-clock time.
///
/// # Parameters
/// - `request`: supplies db, cache, config, room, sender, project, model, the
///   prompt input and `max_tool_depth`.
/// - `on_chunk`: called and awaited once per emitted step, in stream order.
/// - `tool_registry`: source of the tool definitions/handlers exposed to the model.
/// - `ai_base_url` / `ai_api_key`: optional endpoint and key overrides.
/// - `room_preamble`: optional room-specific text appended after the default prompt.
/// - `message_producer`: cloned into every tool adapter.
///
/// # Returns
/// `(final_content, input_tokens, output_tokens)`. `final_content` is the
/// concatenation of every streamed text chunk. The token counts come from the
/// final response's usage and stay 0 if no final response is received.
///
/// # Errors
/// Returns `AgentError::OpenAi` on the first stream error. That return happens
/// before `record_ai_session`, so no session record is written for a failed run.
pub async fn execute_process_react(
    request: &AiRequest,
    mut on_chunk: C,
    tool_registry: &ToolRegistry,
    ai_base_url: Option,
    ai_api_key: Option,
    room_preamble: Option<&str>,
    message_producer: Option,
) -> Result<(String, i64, i64)>
where
    C: FnMut(ReactStep) -> Fut + Send,
    Fut: std::future::Future + Send,
{
    let base_url = ai_base_url.unwrap_or_else(|| "https://api.openai.com".into());
    let api_key = ai_api_key.unwrap_or_default();
    let client_config = AiClientConfig::new(api_key).with_base_url(base_url);

    // Owned copies of request data that every tool adapter needs.
    let db = request.db.clone();
    let cache = request.cache.clone();
    let cfg = request.config.clone();
    let room_id = request.room.id;
    let sender_uid = request.sender.uid;
    let project_id = request.project.id;
    let ai_model_id = request.model.id;
    let ai_model_name = request.model.name.clone();
    // One shared, mutex-guarded list handed to every adapter; presumably it
    // tracks what was sent during this turn — TODO confirm against RigToolAdapter.
    let sent_in_turn = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
    // Time-ordered (v7) id shared by every RecordingTool and the session record.
    let session_id = Uuid::now_v7();
    let session_start = std::time::Instant::now();

    // Best-effort lookup: a DB error and a missing row both give `None`, which
    // is recorded as the default value at the end of this function.
    let version_id = room_ai::Entity::find()
        .filter(room_ai::Column::Room.eq(request.room.id))
        .filter(room_ai::Column::Model.eq(request.model.id))
        .one(&request.db)
        .await
        .ok()
        .flatten()
        .and_then(|r| r.version);

    let mut tools: Vec> = Vec::new();
    for def in tool_registry.definitions() {
        let name = def.name.clone();
        // Definitions without a registered handler are silently skipped.
        if let Some(handler) = tool_registry.get(&name) {
            let adapter = crate::tool::RigToolAdapter::new(
                handler.clone(),
                def.clone(),
                db.clone(),
                cache.clone(),
                cfg.clone(),
                room_id,
                Some(sender_uid),
                project_id,
                message_producer.clone(),
                Some(ai_model_id),
                Some(ai_model_name.clone()),
                sent_in_turn.clone(),
            );
            tools.push(Box::new(RecordingTool::new(
                Box::new(adapter),
                db.clone(),
                session_id,
                sender_uid,
            )));
        }
    }

    let rig_client = client_config.build_rig_client();
    let model = rig_client.completion_model(&request.model.name);
    // General rules first (strong LLM attention), room context appended after
    // so that output-format rules aren't buried behind long room preamble.
    let preamble = match room_preamble {
        Some(rp) => format!("{}\n{}", DEFAULT_SYSTEM_PROMPT, rp),
        None => DEFAULT_SYSTEM_PROMPT.to_string(),
    };
    // `request.max_tool_depth` bounds both the agent's default turn count and
    // this streaming call's multi-turn limit.
    let agent = AgentBuilder::new(model)
        .preamble(&preamble)
        .tools(tools)
        .default_max_turns(request.max_tool_depth)
        .build();

    // No prior conversation history is passed to the agent.
    let stream = agent
        .stream_prompt(&request.input)
        .with_history(Vec::::new())
        .multi_turn(request.max_tool_depth)
        .await;
    // Pin the stream on the stack so `StreamExt::next` can poll it.
    tokio::pin!(stream);

    let mut step_count = 0usize;
    let mut final_content = String::new();
    let mut total_input_tokens: i64 = 0;
    let mut total_output_tokens: i64 = 0;

    while let Some(item) = stream.next().await {
        match item {
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(text))) => {
                // Each streamed text chunk becomes its own `Answer` step and is
                // also appended to the returned final content.
                step_count += 1;
                let t = text.text;
                on_chunk(ReactStep::Answer {
                    step: step_count,
                    answer: t.clone(),
                })
                .await;
                final_content.push_str(&t);
            }
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
                reasoning,
            ))) => {
                // Keep only the textual parts of a complete reasoning block; other
                // `ReasoningContent` variants are dropped.
                let reasoning_text: String = reasoning
                    .content
                    .iter()
                    .filter_map(|c| match c {
                        rig::completion::message::ReasoningContent::Text { text, .. } => {
                            Some(text.as_str())
                        }
                        _ => None,
                    })
                    .collect::>()
                    .join("");
                // Empty reasoning emits nothing and does not advance the counter.
                if !reasoning_text.is_empty() {
                    step_count += 1;
                    on_chunk(ReactStep::Thought {
                        step: step_count,
                        thought: reasoning_text,
                    })
                    .await;
                }
            }
            Ok(MultiTurnStreamItem::StreamAssistantItem(
                StreamedAssistantContent::ReasoningDelta { reasoning, .. },
            )) => {
                // Incremental reasoning is forwarded as-is; empty deltas are skipped.
                if !reasoning.is_empty() {
                    step_count += 1;
                    on_chunk(ReactStep::Thought {
                        step: step_count,
                        thought: reasoning,
                    })
                    .await;
                }
            }
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::ToolCall {
                tool_call,
                ..
            })) => {
                step_count += 1;
                // Arguments may arrive as a JSON-encoded string or as a JSON value.
                // Strings are parsed; a string that is not valid JSON becomes `Null`.
                // NOTE(review): the raw argument text is discarded on parse failure —
                // consider keeping it for diagnostics.
                let args: serde_json::Value = match &tool_call.function.arguments {
                    serde_json::Value::String(s) => {
                        serde_json::from_str(s).unwrap_or(serde_json::Value::Null)
                    }
                    v => v.clone(),
                };
                on_chunk(ReactStep::Action {
                    step: step_count,
                    action: ReactAction::new(&tool_call.function.name, args),
                })
                .await;
            }
            Ok(MultiTurnStreamItem::StreamUserItem(
                rig::streaming::StreamedUserContent::ToolResult { tool_result, .. },
            )) => {
                // Only the text parts of the tool result are forwarded.
                step_count += 1;
                let obs = tool_result_content_to_string(&tool_result.content);
                on_chunk(ReactStep::Observation {
                    step: step_count,
                    observation: obs,
                })
                .await;
            }
            Ok(MultiTurnStreamItem::FinalResponse(resp)) => {
                // Assigned, not accumulated: presumably rig reports usage aggregated
                // over all turns here — TODO confirm against the rig version in use.
                let usage = resp.usage();
                total_input_tokens = usage.input_tokens as i64;
                total_output_tokens = usage.output_tokens as i64;
            }
            Err(e) => {
                // Abort on the first stream error; the session below is not recorded.
                let err_msg = format!("rig agent stream error: {}", e);
                return Err(AgentError::OpenAi(err_msg));
            }
            // Any other stream item is ignored.
            _ => {}
        }
    }

    // Wall-clock duration of the whole call, in milliseconds.
    let elapsed_ms = session_start.elapsed().as_millis() as i64;
    record_ai_session(
        &request.cache,
        &request.db,
        request.project.id,
        request.sender.uid,
        session_id,
        request.room.id,
        request.model.id,
        version_id.unwrap_or_default(),
        total_input_tokens,
        total_output_tokens,
        elapsed_ms,
    )
    .await;

    Ok((final_content, total_input_tokens, total_output_tokens))
}

/// Extracts the text items of rig's `ToolResultContent` and joins them with
/// newlines.
///
/// Every non-text item (e.g. images) is skipped. Returns an empty string when
/// no text items are present.
fn tool_result_content_to_string(
    content: &rig::one_or_many::OneOrMany,
) -> String {
    use rig::completion::message::ToolResultContent;
    content
        .iter()
        .filter_map(|item| {
            if let ToolResultContent::Text(t) = item {
                Some(t.text.clone())
            } else {
                None
            }
        })
        .collect::>()
        .join("\n")
}