gitdataai/libs/agent/chat/react_execution.rs

149 lines
6.5 KiB
Rust

use futures::StreamExt;
use models::rooms::room_ai;
use rig::agent::{AgentBuilder, MultiTurnStreamItem};
use rig::client::CompletionClient;
use rig::streaming::{StreamedAssistantContent, StreamingPrompt};
use sea_orm::*;
use uuid::Uuid;
use super::AiRequest;
use crate::client::AiClientConfig;
use crate::error::{AgentError, Result};
use crate::react::{DEFAULT_SYSTEM_PROMPT, ReactStep};
use crate::react::types::Action as ReactAction;
use crate::tool::{RecordingTool, registry::ToolRegistry};
use super::session_recording::record_ai_session;
/// Runs one ReAct-style agent turn for `request` by streaming a rig multi-turn prompt,
/// forwarding each streamed step to `on_chunk`, and recording the AI session on success.
///
/// # Parameters
/// - `request`: source of the DB/cache/config handles, room, sender, project, model,
///   the user `input`, and `max_tool_depth` (used both as the agent's default max turns
///   and as the multi-turn limit of the stream).
/// - `on_chunk`: async callback awaited once per emitted [`ReactStep`] (answer text,
///   thought, tool action, tool observation), in stream order.
/// - `tool_registry`: every definition that has a registered handler is exposed to the
///   agent as a tool. Each tool is wrapped in `RecordingTool`, tagged with this call's
///   session id and the sender uid.
/// - `ai_base_url`: defaults to `"https://api.openai.com"` when `None`.
/// - `ai_api_key`: defaults to an empty string when `None`.
/// - `room_preamble`: when present, it is prepended (plus a newline) to `DEFAULT_SYSTEM_PROMPT`.
/// - `message_producer`: cloned into every tool adapter.
///
/// # Returns
/// `(final_content, input_tokens, output_tokens)`.
/// - `final_content` is the concatenation of all streamed text chunks, with newlines preserved.
/// - The token counts come from the last `FinalResponse` item seen; they stay 0 if none arrives.
///
/// # Errors
/// Returns `AgentError::OpenAi` on the first stream error. On that path
/// `record_ai_session` is not called.
pub async fn execute_process_react<C, Fut>(
request: &AiRequest, mut on_chunk: C,
tool_registry: &ToolRegistry,
ai_base_url: Option<String>, ai_api_key: Option<String>,
room_preamble: Option<&str>,
message_producer: Option<queue::MessageProducer>,
) -> Result<(String, i64, i64)>
where
C: FnMut(ReactStep) -> Fut + Send,
Fut: std::future::Future<Output = ()> + Send,
{
// Fall back to the public OpenAI endpoint and an empty key when none are configured.
let base_url = ai_base_url.unwrap_or_else(|| "https://api.openai.com".into());
let api_key = ai_api_key.unwrap_or_default();
let client_config = AiClientConfig::new(api_key).with_base_url(base_url);
// Owned copies of request data, handed to each tool adapter built below.
let db = request.db.clone();
let cache = request.cache.clone();
let cfg = request.config.clone();
let room_id = request.room.id;
let sender_uid = request.sender.uid;
let project_id = request.project.id;
let ai_model_id = request.model.id;
let ai_model_name = request.model.name.clone();
// One shared vector; every tool adapter receives a clone of this Arc.
// NOTE(review): presumably tracks messages already sent during this turn — confirm in RigToolAdapter.
let sent_in_turn = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
// Time-ordered (v7) id shared by the RecordingTool wrappers and record_ai_session.
let session_id = Uuid::now_v7();
// The timer starts before the DB lookup and tool setup, so elapsed_ms includes them.
let session_start = std::time::Instant::now();
// Version of the room/model AI binding.
// A DB error and a missing row both yield None, which becomes the default value when recorded below.
let version_id = room_ai::Entity::find()
.filter(room_ai::Column::Room.eq(request.room.id))
.filter(room_ai::Column::Model.eq(request.model.id))
.one(&request.db).await.ok().flatten().and_then(|r| r.version);
// Wrap every registered tool for rig; definitions without a handler are silently skipped.
let mut tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>> = Vec::new();
for def in tool_registry.definitions() {
let name = def.name.clone();
if let Some(handler) = tool_registry.get(&name) {
let adapter = crate::tool::RigToolAdapter::new(
handler.clone(), def.clone(), db.clone(), cache.clone(), cfg.clone(),
room_id, Some(sender_uid), project_id, message_producer.clone(),
Some(ai_model_id), Some(ai_model_name.clone()),
sent_in_turn.clone(),
);
// RecordingTool wraps the adapter with the DB handle, session id and sender uid.
tools.push(Box::new(RecordingTool::new(Box::new(adapter), db.clone(), session_id, sender_uid)));
}
}
let rig_client = client_config.build_rig_client();
let model = rig_client.completion_model(&request.model.name);
// The room-specific preamble (if any) comes first, followed by the default ReAct system prompt.
let preamble = match room_preamble {
Some(rp) => format!("{}\n{}", rp, DEFAULT_SYSTEM_PROMPT),
None => DEFAULT_SYSTEM_PROMPT.to_string(),
};
let agent = AgentBuilder::new(model)
.preamble(&preamble)
.tools(tools)
.default_max_turns(request.max_tool_depth)
.build();
// Start from an empty history: only the current input is sent.
let stream = agent.stream_prompt(&request.input)
.with_history(Vec::new())
.multi_turn(request.max_tool_depth)
.await;
tokio::pin!(stream);
// step_count numbers the ReactSteps emitted to on_chunk, starting at 1.
let mut step_count = 0usize;
let mut final_content = String::new();
let mut total_input_tokens: i64 = 0;
let mut total_output_tokens: i64 = 0;
while let Some(item) = stream.next().await {
match item {
Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(text))) => {
// Every text chunk (even an empty one) becomes an Answer step.
step_count += 1;
let t = text.text;
// The step gets the chunk with '\n' removed, while final_content keeps the original text.
// NOTE(review): confirm the newline stripping is intended for the streamed UI.
let cleaned = t.replace('\n', "");
on_chunk(ReactStep::Answer { step: step_count, answer: cleaned }).await;
final_content.push_str(&t);
}
Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(reasoning))) => {
// Full reasoning block: its parts are concatenated; empty reasoning emits no step.
let reasoning_text = reasoning.reasoning.join("");
if !reasoning_text.is_empty() {
step_count += 1;
on_chunk(ReactStep::Thought { step: step_count, thought: reasoning_text }).await;
}
}
Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::ReasoningDelta { reasoning, .. })) => {
// Incremental reasoning: each non-empty delta is its own Thought step.
if !reasoning.is_empty() {
step_count += 1;
on_chunk(ReactStep::Thought { step: step_count, thought: reasoning }).await;
}
}
Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::ToolCall { tool_call, .. })) => {
step_count += 1;
// Arguments may arrive as a JSON-encoded string.
// They are decoded here; a string that fails to parse becomes Null.
let args: serde_json::Value = match &tool_call.function.arguments {
serde_json::Value::String(s) => serde_json::from_str(s).unwrap_or(serde_json::Value::Null),
v => v.clone(),
};
on_chunk(ReactStep::Action { step: step_count, action: ReactAction::new(&tool_call.function.name, args) }).await;
}
Ok(MultiTurnStreamItem::StreamUserItem(rig::streaming::StreamedUserContent::ToolResult { tool_result, .. })) => {
// Tool output is reported as an Observation containing only its text parts.
step_count += 1;
let obs = tool_result_content_to_string(&tool_result.content);
on_chunk(ReactStep::Observation { step: step_count, observation: obs }).await;
}
Ok(MultiTurnStreamItem::FinalResponse(resp)) => {
// Usage is assigned, not accumulated: the last FinalResponse wins.
let usage = resp.usage();
total_input_tokens = usage.input_tokens as i64;
total_output_tokens = usage.output_tokens as i64;
}
Err(e) => {
// Abort on the first stream error; the session is not recorded on this path.
let err_msg = format!("rig agent stream error: {}", e);
return Err(AgentError::OpenAi(err_msg));
}
// Any other stream item is ignored.
_ => {}
}
}
// Reached only when the stream completed without error.
let elapsed_ms = session_start.elapsed().as_millis() as i64;
record_ai_session(&request.cache, &request.db, request.project.id, request.sender.uid, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), total_input_tokens, total_output_tokens, elapsed_ms).await;
Ok((final_content, total_input_tokens, total_output_tokens))
}
/// Flattens a tool result into plain text.
///
/// Only `ToolResultContent::Text` parts are kept; every other part (e.g. images) is
/// skipped. The kept parts appear in their original order, with exactly one `'\n'`
/// between consecutive parts. A result with no text parts yields an empty string.
fn tool_result_content_to_string(content: &rig::one_or_many::OneOrMany<rig::completion::message::ToolResultContent>) -> String {
    use rig::completion::message::ToolResultContent;

    let mut flattened = String::new();
    let mut is_first_part = true;
    for part in content.iter() {
        if let ToolResultContent::Text(text_part) = part {
            // Separator goes *between* parts, never before the first one.
            if !is_first_part {
                flattened.push('\n');
            }
            flattened.push_str(&text_part.text);
            is_first_part = false;
        }
    }
    flattened
}