149 lines
6.5 KiB
Rust
149 lines
6.5 KiB
Rust
use futures::StreamExt;
|
|
use models::rooms::room_ai;
|
|
use rig::agent::{AgentBuilder, MultiTurnStreamItem};
|
|
use rig::client::CompletionClient;
|
|
use rig::streaming::{StreamedAssistantContent, StreamingPrompt};
|
|
use sea_orm::*;
|
|
use uuid::Uuid;
|
|
|
|
use super::AiRequest;
|
|
use crate::client::AiClientConfig;
|
|
use crate::error::{AgentError, Result};
|
|
use crate::react::{DEFAULT_SYSTEM_PROMPT, ReactStep};
|
|
use crate::react::types::Action as ReactAction;
|
|
use crate::tool::{RecordingTool, registry::ToolRegistry};
|
|
use super::session_recording::record_ai_session;
|
|
|
|
pub async fn execute_process_react<C, Fut>(
|
|
request: &AiRequest, mut on_chunk: C,
|
|
tool_registry: &ToolRegistry,
|
|
ai_base_url: Option<String>, ai_api_key: Option<String>,
|
|
room_preamble: Option<&str>,
|
|
message_producer: Option<queue::MessageProducer>,
|
|
) -> Result<(String, i64, i64)>
|
|
where
|
|
C: FnMut(ReactStep) -> Fut + Send,
|
|
Fut: std::future::Future<Output = ()> + Send,
|
|
{
|
|
let base_url = ai_base_url.unwrap_or_else(|| "https://api.openai.com".into());
|
|
let api_key = ai_api_key.unwrap_or_default();
|
|
let client_config = AiClientConfig::new(api_key).with_base_url(base_url);
|
|
|
|
let db = request.db.clone();
|
|
let cache = request.cache.clone();
|
|
let cfg = request.config.clone();
|
|
let room_id = request.room.id;
|
|
let sender_uid = request.sender.uid;
|
|
let project_id = request.project.id;
|
|
let ai_model_id = request.model.id;
|
|
let ai_model_name = request.model.name.clone();
|
|
let sent_in_turn = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
|
|
let session_id = Uuid::now_v7();
|
|
let session_start = std::time::Instant::now();
|
|
let version_id = room_ai::Entity::find()
|
|
.filter(room_ai::Column::Room.eq(request.room.id))
|
|
.filter(room_ai::Column::Model.eq(request.model.id))
|
|
.one(&request.db).await.ok().flatten().and_then(|r| r.version);
|
|
|
|
let mut tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>> = Vec::new();
|
|
for def in tool_registry.definitions() {
|
|
let name = def.name.clone();
|
|
if let Some(handler) = tool_registry.get(&name) {
|
|
let adapter = crate::tool::RigToolAdapter::new(
|
|
handler.clone(), def.clone(), db.clone(), cache.clone(), cfg.clone(),
|
|
room_id, Some(sender_uid), project_id, message_producer.clone(),
|
|
Some(ai_model_id), Some(ai_model_name.clone()),
|
|
sent_in_turn.clone(),
|
|
);
|
|
tools.push(Box::new(RecordingTool::new(Box::new(adapter), db.clone(), session_id, sender_uid)));
|
|
}
|
|
}
|
|
|
|
let rig_client = client_config.build_rig_client();
|
|
let model = rig_client.completion_model(&request.model.name);
|
|
|
|
let preamble = match room_preamble {
|
|
Some(rp) => format!("{}\n{}", rp, DEFAULT_SYSTEM_PROMPT),
|
|
None => DEFAULT_SYSTEM_PROMPT.to_string(),
|
|
};
|
|
|
|
let agent = AgentBuilder::new(model)
|
|
.preamble(&preamble)
|
|
.tools(tools)
|
|
.default_max_turns(request.max_tool_depth)
|
|
.build();
|
|
|
|
let stream = agent.stream_prompt(&request.input)
|
|
.with_history(Vec::new())
|
|
.multi_turn(request.max_tool_depth)
|
|
.await;
|
|
|
|
tokio::pin!(stream);
|
|
|
|
let mut step_count = 0usize;
|
|
let mut final_content = String::new();
|
|
let mut total_input_tokens: i64 = 0;
|
|
let mut total_output_tokens: i64 = 0;
|
|
|
|
while let Some(item) = stream.next().await {
|
|
match item {
|
|
Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(text))) => {
|
|
step_count += 1;
|
|
let t = text.text;
|
|
let cleaned = t.replace('\n', "");
|
|
on_chunk(ReactStep::Answer { step: step_count, answer: cleaned }).await;
|
|
final_content.push_str(&t);
|
|
}
|
|
Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(reasoning))) => {
|
|
let reasoning_text = reasoning.reasoning.join("");
|
|
if !reasoning_text.is_empty() {
|
|
step_count += 1;
|
|
on_chunk(ReactStep::Thought { step: step_count, thought: reasoning_text }).await;
|
|
}
|
|
}
|
|
Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::ReasoningDelta { reasoning, .. })) => {
|
|
if !reasoning.is_empty() {
|
|
step_count += 1;
|
|
on_chunk(ReactStep::Thought { step: step_count, thought: reasoning }).await;
|
|
}
|
|
}
|
|
Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::ToolCall { tool_call, .. })) => {
|
|
step_count += 1;
|
|
let args: serde_json::Value = match &tool_call.function.arguments {
|
|
serde_json::Value::String(s) => serde_json::from_str(s).unwrap_or(serde_json::Value::Null),
|
|
v => v.clone(),
|
|
};
|
|
on_chunk(ReactStep::Action { step: step_count, action: ReactAction::new(&tool_call.function.name, args) }).await;
|
|
}
|
|
Ok(MultiTurnStreamItem::StreamUserItem(rig::streaming::StreamedUserContent::ToolResult { tool_result, .. })) => {
|
|
step_count += 1;
|
|
let obs = tool_result_content_to_string(&tool_result.content);
|
|
on_chunk(ReactStep::Observation { step: step_count, observation: obs }).await;
|
|
}
|
|
Ok(MultiTurnStreamItem::FinalResponse(resp)) => {
|
|
let usage = resp.usage();
|
|
total_input_tokens = usage.input_tokens as i64;
|
|
total_output_tokens = usage.output_tokens as i64;
|
|
}
|
|
Err(e) => {
|
|
let err_msg = format!("rig agent stream error: {}", e);
|
|
return Err(AgentError::OpenAi(err_msg));
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
let elapsed_ms = session_start.elapsed().as_millis() as i64;
|
|
record_ai_session(&request.cache, &request.db, request.project.id, request.sender.uid, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), total_input_tokens, total_output_tokens, elapsed_ms).await;
|
|
|
|
Ok((final_content, total_input_tokens, total_output_tokens))
|
|
}
|
|
|
|
/// Extract text from rig's ToolResultContent, ignoring images.
|
|
fn tool_result_content_to_string(content: &rig::one_or_many::OneOrMany<rig::completion::message::ToolResultContent>) -> String {
|
|
use rig::completion::message::ToolResultContent;
|
|
content.iter().filter_map(|item| {
|
|
if let ToolResultContent::Text(t) = item { Some(t.text.clone()) } else { None }
|
|
}).collect::<Vec<_>>().join("\n")
|
|
}
|