// gitdataai/libs/agent/chat/react_execution.rs
//
// 234 lines
// 8.0 KiB
// Rust

use futures::StreamExt;
use models::rooms::room_ai;
use rig::agent::{AgentBuilder, MultiTurnStreamItem};
use rig::client::CompletionClient;
use rig::streaming::{StreamedAssistantContent, StreamingPrompt};
use sea_orm::*;
use uuid::Uuid;
use super::AiRequest;
use super::session_recording::record_ai_session;
use crate::client::AiClientConfig;
use crate::error::{AgentError, Result};
use crate::react::types::Action as ReactAction;
use crate::react::{DEFAULT_SYSTEM_PROMPT, ReactStep};
use crate::tool::{RecordingTool, registry::ToolRegistry};
/// Run one AI request through a rig multi-turn agent, forwarding each
/// ReAct-style step to `on_chunk` as it streams in.
///
/// Every tool in `tool_registry` that has a registered handler is wrapped in
/// a `RigToolAdapter` (carrying the request's db/cache/config plus the room,
/// sender, project and model identifiers) and then in a `RecordingTool` bound
/// to a freshly generated session id. The system prompt is
/// `DEFAULT_SYSTEM_PROMPT`, followed by `room_preamble` when one is supplied.
///
/// Stream items are forwarded as `ReactStep`s, each with the next value of a
/// running step counter:
/// - assistant text chunks → `ReactStep::Answer` (also appended to the
///   returned content; every chunk gets its own step number),
/// - non-empty reasoning / reasoning deltas → `ReactStep::Thought`,
/// - tool calls → `ReactStep::Action`,
/// - tool results → `ReactStep::Observation` (text parts only).
///
/// When the stream completes, the session is persisted via
/// `record_ai_session` with the token usage reported by the final response
/// and the elapsed wall-clock time.
///
/// # Arguments
/// * `ai_base_url` — falls back to `https://api.openai.com` when `None`.
/// * `ai_api_key` — falls back to an empty string when `None`.
/// * `room_preamble` — optional room context appended after the default prompt.
/// * `message_producer` — handed to every tool adapter.
///
/// # Returns
/// `(final_content, input_tokens, output_tokens)`. The token counts stay `0`
/// if the stream ends without a final response.
///
/// # Errors
/// Returns `AgentError::OpenAi` on the first stream error. On that path
/// `record_ai_session` is not called.
pub async fn execute_process_react<C, Fut>(
    request: &AiRequest,
    mut on_chunk: C,
    tool_registry: &ToolRegistry,
    ai_base_url: Option<String>,
    ai_api_key: Option<String>,
    room_preamble: Option<&str>,
    message_producer: Option<queue::MessageProducer>,
) -> Result<(String, i64, i64)>
where
    C: FnMut(ReactStep) -> Fut + Send,
    Fut: std::future::Future<Output = ()> + Send,
{
    let base_url = ai_base_url.unwrap_or_else(|| "https://api.openai.com".into());
    let api_key = ai_api_key.unwrap_or_default();
    let client_config = AiClientConfig::new(api_key).with_base_url(base_url);

    let db = request.db.clone();
    let cache = request.cache.clone();
    let cfg = request.config.clone();
    let room_id = request.room.id;
    let sender_uid = request.sender.uid;
    let project_id = request.project.id;
    let ai_model_id = request.model.id;
    let ai_model_name = request.model.name.clone();
    // One shared buffer handed to every tool adapter for this turn.
    // NOTE(review): presumably tracks what tools have already sent during the
    // turn — confirm against RigToolAdapter.
    let sent_in_turn = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));

    // The same id is passed to every RecordingTool and to record_ai_session.
    let session_id = Uuid::now_v7();
    let session_start = std::time::Instant::now();

    // Best-effort lookup: a DB error and a missing row both yield `None`.
    let version_id = room_ai::Entity::find()
        .filter(room_ai::Column::Room.eq(request.room.id))
        .filter(room_ai::Column::Model.eq(request.model.id))
        .one(&request.db)
        .await
        .ok()
        .flatten()
        .and_then(|r| r.version);

    // Definitions without a registered handler are skipped.
    let mut tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>> = Vec::new();
    for def in tool_registry.definitions() {
        let name = def.name.clone();
        if let Some(handler) = tool_registry.get(&name) {
            let adapter = crate::tool::RigToolAdapter::new(
                handler.clone(),
                def.clone(),
                db.clone(),
                cache.clone(),
                cfg.clone(),
                room_id,
                Some(sender_uid),
                project_id,
                message_producer.clone(),
                Some(ai_model_id),
                Some(ai_model_name.clone()),
                sent_in_turn.clone(),
            );
            tools.push(Box::new(RecordingTool::new(
                Box::new(adapter),
                db.clone(),
                session_id,
                sender_uid,
            )));
        }
    }

    let rig_client = client_config.build_rig_client();
    let model = rig_client.completion_model(&request.model.name);
    // General rules first (strong LLM attention), room context appended after
    // so that output-format rules aren't buried behind long room preamble.
    let preamble = match room_preamble {
        Some(rp) => format!("{}\n{}", DEFAULT_SYSTEM_PROMPT, rp),
        None => DEFAULT_SYSTEM_PROMPT.to_string(),
    };
    let agent = AgentBuilder::new(model)
        .preamble(&preamble)
        .tools(tools)
        .default_max_turns(request.max_tool_depth)
        .build();

    let stream = agent
        .stream_prompt(&request.input)
        .with_history(Vec::<rig::completion::Message>::new())
        .multi_turn(request.max_tool_depth)
        .await;
    tokio::pin!(stream);

    let mut step_count = 0usize;
    let mut final_content = String::new();
    let mut total_input_tokens: i64 = 0;
    let mut total_output_tokens: i64 = 0;

    while let Some(item) = stream.next().await {
        match item {
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(text))) => {
                step_count += 1;
                // Append first, then move the owned text into the step (no clone needed).
                final_content.push_str(&text.text);
                on_chunk(ReactStep::Answer {
                    step: step_count,
                    answer: text.text,
                })
                .await;
            }
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
                reasoning,
            ))) => {
                // Only textual reasoning parts are forwarded; other kinds are ignored.
                let reasoning_text: String = reasoning
                    .content
                    .iter()
                    .filter_map(|c| match c {
                        rig::completion::message::ReasoningContent::Text { text, .. } => {
                            Some(text.as_str())
                        }
                        _ => None,
                    })
                    .collect::<Vec<_>>()
                    .join("");
                if !reasoning_text.is_empty() {
                    step_count += 1;
                    on_chunk(ReactStep::Thought {
                        step: step_count,
                        thought: reasoning_text,
                    })
                    .await;
                }
            }
            Ok(MultiTurnStreamItem::StreamAssistantItem(
                StreamedAssistantContent::ReasoningDelta { reasoning, .. },
            )) => {
                if !reasoning.is_empty() {
                    step_count += 1;
                    on_chunk(ReactStep::Thought {
                        step: step_count,
                        thought: reasoning,
                    })
                    .await;
                }
            }
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::ToolCall {
                tool_call,
                ..
            })) => {
                step_count += 1;
                let args = parse_tool_arguments(&tool_call.function.arguments);
                on_chunk(ReactStep::Action {
                    step: step_count,
                    action: ReactAction::new(&tool_call.function.name, args),
                })
                .await;
            }
            Ok(MultiTurnStreamItem::StreamUserItem(
                rig::streaming::StreamedUserContent::ToolResult { tool_result, .. },
            )) => {
                step_count += 1;
                let obs = tool_result_content_to_string(&tool_result.content);
                on_chunk(ReactStep::Observation {
                    step: step_count,
                    observation: obs,
                })
                .await;
            }
            Ok(MultiTurnStreamItem::FinalResponse(resp)) => {
                let usage = resp.usage();
                total_input_tokens = usage.input_tokens as i64;
                total_output_tokens = usage.output_tokens as i64;
            }
            Err(e) => {
                let err_msg = format!("rig agent stream error: {}", e);
                return Err(AgentError::OpenAi(err_msg));
            }
            _ => {}
        }
    }

    let elapsed_ms = session_start.elapsed().as_millis() as i64;
    record_ai_session(
        &request.cache,
        &request.db,
        request.project.id,
        request.sender.uid,
        session_id,
        request.room.id,
        request.model.id,
        version_id.unwrap_or_default(),
        total_input_tokens,
        total_output_tokens,
        elapsed_ms,
    )
    .await;
    Ok((final_content, total_input_tokens, total_output_tokens))
}

/// Normalise streamed tool-call arguments into a JSON value for the `Action` step.
///
/// Arguments may arrive either as a JSON value or as a JSON-encoded string.
/// - A blank string becomes `Value::Null` (no arguments).
/// - A valid JSON string is decoded.
/// - A string that is not valid JSON is kept verbatim as `Value::String`, so
///   the raw arguments the model sent are not silently discarded.
/// - Any non-string value is returned as a clone.
fn parse_tool_arguments(arguments: &serde_json::Value) -> serde_json::Value {
    match arguments {
        serde_json::Value::String(raw) if raw.trim().is_empty() => serde_json::Value::Null,
        serde_json::Value::String(raw) => serde_json::from_str(raw)
            .unwrap_or_else(|_| serde_json::Value::String(raw.clone())),
        other => other.clone(),
    }
}
/// Flatten a rig tool result into plain text.
///
/// Each text part becomes one line, joined with `\n` in the original order.
/// Non-text parts (e.g. images) are skipped and contribute nothing, not even
/// a separator. An empty text part still occupies its own line.
fn tool_result_content_to_string(
    content: &rig::one_or_many::OneOrMany<rig::completion::message::ToolResultContent>,
) -> String {
    use rig::completion::message::ToolResultContent;

    let mut joined = String::new();
    let mut is_first = true;
    for part in content.iter() {
        if let ToolResultContent::Text(piece) = part {
            // Separator goes between consecutive text parts, matching `join("\n")`.
            if !is_first {
                joined.push('\n');
            }
            joined.push_str(&piece.text);
            is_first = false;
        }
    }
    joined
}