// gitdataai/libs/agent/agent/rig_tool.rs
//
// 235 lines · 7.9 KiB · Rust

use futures::Stream;
use futures::StreamExt;
use rig::{
agent::{AgentBuilder, MultiTurnStreamItem},
client::CompletionClient,
completion::Prompt,
streaming::{StreamedAssistantContent, StreamingPrompt},
};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use crate::client::AiClientConfig;
use crate::error::AgentError;
/// Result of a completed (non-streaming) agent prompt.
///
/// Holds the text produced by the agent together with the token usage that
/// was reported for the request.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct AgentResponse {
    /// Text returned by the agent.
    pub content: String,
    /// Input token count reported for the request.
    pub input_tokens: u64,
    /// Output token count reported for the request.
    pub output_tokens: u64,
}
/// One item produced by the streaming prompt APIs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamChunk {
    /// An incremental piece of assistant text.
    Text(String),
    /// Terminal item carrying the accumulated text and the reported usage.
    Final {
        /// Concatenation of every text piece forwarded before this item.
        content: String,
        /// Input token count reported with the final response.
        input_tokens: u64,
        /// Output token count reported with the final response.
        output_tokens: u64,
    },
}
/// Service that runs prompts against a rig agent for one configured model.
///
/// Every request method builds a fresh rig client from `config` and asks it
/// for the completion model named `model_name`.
pub struct RigAgentService {
/// Client configuration; `build_rig_client()` is called on it per request.
config: AiClientConfig,
/// Model identifier passed to `completion_model` and to token counting.
model_name: String,
}
impl RigAgentService {
/// Creates a service bound to `config` and the given model name.
///
/// `model_name` accepts any value convertible into a `String`.
pub fn new(config: AiClientConfig, model_name: impl Into<String>) -> Self {
    let model_name = model_name.into();
    Self { config, model_name }
}
/// Sends one prompt to the model and waits for the complete reply.
///
/// `system_prompt` is installed as the agent preamble and `user_input` is
/// the prompt text. The reply text and its token usage are returned.
///
/// # Errors
///
/// A `rig::completion::PromptError` is returned as [`AgentError::OpenAi`]
/// carrying the error's display text.
pub async fn prompt(
    &self,
    system_prompt: &str,
    user_input: &str,
) -> std::result::Result<AgentResponse, AgentError> {
    let rig_client = self.config.build_rig_client();
    let agent = AgentBuilder::new(rig_client.completion_model(&self.model_name))
        .preamble(system_prompt)
        .build();

    let as_agent_error =
        |err: rig::completion::PromptError| AgentError::OpenAi(err.to_string());
    let details = agent
        .prompt(user_input)
        .extended_details()
        .await
        .map_err(as_agent_error)?;

    let input_tokens = details.usage.input_tokens;
    let output_tokens = details.usage.output_tokens;
    Ok(AgentResponse {
        content: details.output,
        input_tokens,
        output_tokens,
    })
}
/// Sends one prompt to a tool-enabled agent and waits for the final reply.
///
/// `tools` are registered on the agent and `max_turns` is applied both as
/// the agent's default turn limit and as the limit for this request.
///
/// # Errors
///
/// A `rig::completion::PromptError` is returned as [`AgentError::OpenAi`]
/// carrying the error's display text.
pub async fn prompt_with_tools(
    &self,
    system_prompt: &str,
    user_input: &str,
    tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>>,
    max_turns: usize,
) -> std::result::Result<AgentResponse, AgentError> {
    let rig_client = self.config.build_rig_client();
    let agent = AgentBuilder::new(rig_client.completion_model(&self.model_name))
        .preamble(system_prompt)
        .tools(tools)
        .default_max_turns(max_turns)
        .build();

    let as_agent_error =
        |err: rig::completion::PromptError| AgentError::OpenAi(err.to_string());
    let details = agent
        .prompt(user_input)
        .max_turns(max_turns)
        .extended_details()
        .await
        .map_err(as_agent_error)?;

    let input_tokens = details.usage.input_tokens;
    let output_tokens = details.usage.output_tokens;
    Ok(AgentResponse {
        content: details.output,
        input_tokens,
        output_tokens,
    })
}
/// Streams the model's reply to a single prompt.
///
/// A spawned task drives the rig stream and forwards items over a bounded
/// channel (capacity 100):
///
/// - each assistant text delta is sent as [`StreamChunk::Text`];
/// - tool calls and tool results are only logged through `tracing`;
/// - the final response is sent as [`StreamChunk::Final`] with the
///   concatenated text and the reported token usage;
/// - stream errors are sent as `Err(AgentError::OpenAi(..))` items.
///
/// The task stops, and drops the underlying rig stream, as soon as a send
/// fails because the returned stream was dropped, so an abandoned consumer
/// does not keep the request running.
///
/// # Errors
///
/// The outer `Result` is always `Ok` in this implementation; failures are
/// delivered as `Err` items inside the returned stream.
pub async fn stream_prompt(
    &self,
    system_prompt: &str,
    user_input: &str,
) -> std::result::Result<
    impl Stream<Item = std::result::Result<StreamChunk, AgentError>>,
    AgentError,
> {
    let client = self.config.build_rig_client();
    let model = client.completion_model(&self.model_name);
    let agent = AgentBuilder::new(model).preamble(system_prompt).build();
    let stream: rig::agent::StreamingResult<_> = agent.stream_prompt(user_input).await;
    let (tx, rx) = mpsc::channel::<std::result::Result<StreamChunk, AgentError>>(100);
    tokio::spawn(async move {
        let mut final_content = String::new();
        tokio::pin!(stream);
        while let Some(item) = stream.next().await {
            // Translate the rig item first; only some items produce output.
            let outgoing = match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(
                    StreamedAssistantContent::Text(text),
                )) => {
                    // Accumulate before moving the delta into the chunk so
                    // no extra copy of the text is needed.
                    final_content.push_str(&text.text);
                    Some(Ok(StreamChunk::Text(text.text)))
                }
                Ok(MultiTurnStreamItem::StreamAssistantItem(
                    StreamedAssistantContent::ToolCall {
                        tool_call,
                        internal_call_id: _,
                    },
                )) => {
                    let args_str = match &tool_call.function.arguments {
                        serde_json::Value::String(s) => s.clone(),
                        v => serde_json::to_string(v).unwrap_or_default(),
                    };
                    tracing::info!(
                        tool = %tool_call.function.name,
                        args = %args_str,
                        "rig_agent_streaming_tool_call"
                    );
                    None
                }
                Ok(MultiTurnStreamItem::StreamUserItem(
                    rig::streaming::StreamedUserContent::ToolResult { tool_result, .. },
                )) => {
                    tracing::info!(
                        tool_result_id = %tool_result.id,
                        "rig_agent_streaming_tool_result"
                    );
                    None
                }
                Ok(MultiTurnStreamItem::FinalResponse(resp)) => {
                    let usage = resp.usage();
                    Some(Ok(StreamChunk::Final {
                        content: final_content.clone(),
                        input_tokens: usage.input_tokens,
                        output_tokens: usage.output_tokens,
                    }))
                }
                Err(e) => Some(Err(AgentError::OpenAi(e.to_string()))),
                _ => None,
            };

            if let Some(chunk) = outgoing {
                // A failed send means the receiver is gone; stop consuming
                // the model stream instead of running it to completion.
                if tx.send(chunk).await.is_err() {
                    break;
                }
            }
        }
    });
    Ok(ReceiverStream::new(rx))
}
pub async fn stream_prompt_with_tools(
&self,
system_prompt: &str,
user_input: &str,
tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>>,
max_turns: usize,
) -> std::result::Result<
impl Stream<Item = std::result::Result<StreamChunk, AgentError>>,
AgentError,
> {
let client = self.config.build_rig_client();
let model = client.completion_model(&self.model_name);
let agent = AgentBuilder::new(model)
.preamble(system_prompt)
.tools(tools)
.default_max_turns(max_turns)
.build();
let stream = agent
.stream_prompt(user_input)
.with_history(Vec::<rig::completion::Message>::new())
.multi_turn(max_turns)
.await;
let (tx, rx) = mpsc::channel::<Result<StreamChunk, AgentError>>(100);
tokio::spawn(async move {
let mut final_content = String::new();
tokio::pin!(stream);
while let Some(item) = stream.next().await {
match item {
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::Text(text),
)) => {
let _ = tx.send(Ok(StreamChunk::Text(text.text.clone()))).await;
final_content.push_str(&text.text);
}
Ok(MultiTurnStreamItem::FinalResponse(resp)) => {
let usage = resp.usage();
let _ = tx
.send(Ok(StreamChunk::Final {
content: final_content.clone(),
input_tokens: usage.input_tokens,
output_tokens: usage.output_tokens,
}))
.await;
}
Err(e) => {
let _ = tx.send(Err(AgentError::OpenAi(e.to_string()))).await;
}
_ => {}
}
}
});
Ok(ReceiverStream::new(rx))
}
/// Returns the token count of `text` for this service's model.
///
/// Delegates to `crate::tokent::count_text`, passing the text and the
/// configured model name.
///
/// # Errors
///
/// A failure from `count_text` is returned as [`AgentError::Internal`]
/// carrying the error's display text.
pub fn count_tokens(&self, text: &str) -> Result<usize, AgentError> {
    match crate::tokent::count_text(text, &self.model_name) {
        Ok(token_count) => Ok(token_count),
        Err(cause) => Err(AgentError::Internal(cause.to_string())),
    }
}
}