- Split agent crate into client/, model/, agent/ subdirs - Add billing.rs for token usage recording - Add sync.rs for upstream model sync - EmbedService: Qdrant-backed vector memory for semantic search - ChatService: wire EmbedService for memory lookup, passive skill awareness - ReAct loop: streamline with tokio::select! and proper error handling
268 lines
9.5 KiB
Rust
268 lines
9.5 KiB
Rust
//! Agent service using rig's built-in Agent.
|
|
//!
|
|
//! This is a complete implementation that leverages rig's Agent for
|
|
//! multi-turn reasoning, tool execution, streaming, and token tracking.
|
|
|
|
use futures::Stream;
|
|
use futures::StreamExt;
|
|
use rig::{
|
|
agent::{AgentBuilder, MultiTurnStreamItem},
|
|
client::CompletionClient,
|
|
completion::Prompt,
|
|
streaming::{StreamingPrompt, StreamedAssistantContent},
|
|
};
|
|
use tokio_stream::wrappers::ReceiverStream;
|
|
use tokio::sync::mpsc;
|
|
|
|
use crate::client::AiClientConfig;
|
|
use crate::error::AgentError;
|
|
|
|
/// Result of a completed (non-streaming) agent run.
///
/// Values are copied from rig's extended prompt response: `content` from its
/// `output`, and the token counts from its `total_usage`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct AgentResponse {
    /// Final text produced by the agent.
    pub content: String,
    /// Input (prompt-side) token count reported by rig.
    pub input_tokens: u64,
    /// Output (completion-side) token count reported by rig.
    pub output_tokens: u64,
}
|
|
|
|
/// One item yielded by the streaming agent APIs.
///
/// A stream is made of `Text` deltas as they arrive and a `Final` chunk
/// carrying the accumulated text together with token usage.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamChunk {
    /// Incremental text delta from the model.
    Text(String),
    /// Terminal chunk with the aggregated result.
    Final {
        /// Concatenation of every text delta emitted before this chunk.
        content: String,
        /// Input (prompt-side) token count reported by rig.
        input_tokens: u64,
        /// Output (completion-side) token count reported by rig.
        output_tokens: u64,
    },
}
|
|
|
|
/// Service for running agents using rig's built-in Agent.
|
|
///
|
|
/// Provides both simple prompting and full streaming with automatic
|
|
/// tool call handling via rig's native Agent.
|
|
pub struct RigAgentService {
|
|
config: AiClientConfig,
|
|
model_name: String,
|
|
}
|
|
|
|
impl RigAgentService {
|
|
/// Create a new RigAgentService.
|
|
pub fn new(config: AiClientConfig, model_name: impl Into<String>) -> Self {
|
|
Self { config, model_name: model_name.into() }
|
|
}
|
|
|
|
/// Run a single prompt with the agent (single-turn, no tools).
|
|
pub async fn prompt(
|
|
&self,
|
|
system_prompt: &str,
|
|
user_input: &str,
|
|
) -> std::result::Result<AgentResponse, AgentError> {
|
|
let client = self.config.build_rig_client();
|
|
let model = client.completion_model(&self.model_name);
|
|
|
|
let agent = AgentBuilder::new(model)
|
|
.preamble(system_prompt)
|
|
.build();
|
|
|
|
let response = agent
|
|
.prompt(user_input)
|
|
.extended_details()
|
|
.await
|
|
.map_err(|e: rig::completion::PromptError| AgentError::OpenAi(e.to_string()))?;
|
|
|
|
Ok(AgentResponse {
|
|
content: response.output,
|
|
input_tokens: response.total_usage.input_tokens,
|
|
output_tokens: response.total_usage.output_tokens,
|
|
})
|
|
}
|
|
|
|
/// Run a prompt with tools (supports multi-turn via rig's Agent).
|
|
///
|
|
/// The agent will automatically handle tool calls by calling rig's
|
|
/// ToolDyn implementations with proper argument deserialization.
|
|
pub async fn prompt_with_tools(
|
|
&self,
|
|
system_prompt: &str,
|
|
user_input: &str,
|
|
tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>>,
|
|
max_turns: usize,
|
|
) -> std::result::Result<AgentResponse, AgentError> {
|
|
let client = self.config.build_rig_client();
|
|
let model = client.completion_model(&self.model_name);
|
|
|
|
let agent = AgentBuilder::new(model)
|
|
.preamble(system_prompt)
|
|
.tools(tools)
|
|
.default_max_turns(max_turns)
|
|
.build();
|
|
|
|
let response = agent
|
|
.prompt(user_input)
|
|
.max_turns(max_turns)
|
|
.extended_details()
|
|
.await
|
|
.map_err(|e: rig::completion::PromptError| AgentError::OpenAi(e.to_string()))?;
|
|
|
|
Ok(AgentResponse {
|
|
content: response.output,
|
|
input_tokens: response.total_usage.input_tokens,
|
|
output_tokens: response.total_usage.output_tokens,
|
|
})
|
|
}
|
|
|
|
/// Stream a prompt with the agent using rig's native streaming.
|
|
///
|
|
/// This returns a proper async stream that yields text chunks as they arrive
|
|
/// and a final response chunk with aggregated token usage. Tool calls are
|
|
/// handled transparently by rig's Agent.
|
|
pub async fn stream_prompt(
|
|
&self,
|
|
system_prompt: &str,
|
|
user_input: &str,
|
|
) -> std::result::Result<
|
|
impl Stream<Item = std::result::Result<StreamChunk, AgentError>>,
|
|
AgentError,
|
|
> {
|
|
let client = self.config.build_rig_client();
|
|
let model = client.completion_model(&self.model_name);
|
|
|
|
let agent = AgentBuilder::new(model)
|
|
.preamble(system_prompt)
|
|
.build();
|
|
|
|
// stream_prompt().await returns StreamingResult directly (not wrapped in Result)
|
|
// StreamingResult is Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem, StreamingError>>>>
|
|
let stream: rig::agent::StreamingResult<_> = agent
|
|
.stream_prompt(user_input)
|
|
.await;
|
|
|
|
// Bridge the rig stream to our channel-based stream
|
|
let (tx, rx) = mpsc::channel::<std::result::Result<StreamChunk, AgentError>>(100);
|
|
|
|
tokio::spawn(async move {
|
|
let mut final_content = String::new();
|
|
|
|
tokio::pin!(stream);
|
|
|
|
while let Some(item) = stream.next().await {
|
|
match item {
|
|
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
|
StreamedAssistantContent::Text(text),
|
|
)) => {
|
|
let _ = tx.send(Ok(StreamChunk::Text(text.text.clone()))).await;
|
|
final_content.push_str(&text.text);
|
|
}
|
|
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
|
StreamedAssistantContent::ToolCall { tool_call, internal_call_id: _ },
|
|
)) => {
|
|
let args_str = match &tool_call.function.arguments {
|
|
serde_json::Value::String(s) => s.clone(),
|
|
v => serde_json::to_string(v).unwrap_or_default(),
|
|
};
|
|
tracing::info!(
|
|
tool = %tool_call.function.name,
|
|
args = %args_str,
|
|
"rig_agent_streaming_tool_call"
|
|
);
|
|
// Tool calllint — emitted for observability, rig handles execution internally
|
|
}
|
|
Ok(MultiTurnStreamItem::StreamUserItem(
|
|
rig::streaming::StreamedUserContent::ToolResult { tool_result, .. },
|
|
)) => {
|
|
tracing::info!(
|
|
tool_result_id = %tool_result.id,
|
|
"rig_agent_streaming_tool_result"
|
|
);
|
|
}
|
|
Ok(MultiTurnStreamItem::FinalResponse(resp)) => {
|
|
let usage = resp.usage();
|
|
let _ = tx.send(Ok(StreamChunk::Final {
|
|
content: final_content.clone(),
|
|
input_tokens: usage.input_tokens,
|
|
output_tokens: usage.output_tokens,
|
|
})).await;
|
|
}
|
|
Err(e) => {
|
|
let _ = tx.send(Err(AgentError::OpenAi(e.to_string()))).await;
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
});
|
|
|
|
Ok(ReceiverStream::new(rx))
|
|
}
|
|
|
|
/// Stream a prompt with tools using rig's native streaming.
|
|
///
|
|
/// Returns a stream thatproperly handles multi-turn tool calls via rig's Agent
|
|
/// streaming infrastructure.
|
|
pub async fn stream_prompt_with_tools(
|
|
&self,
|
|
system_prompt: &str,
|
|
user_input: &str,
|
|
tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>>,
|
|
max_turns: usize,
|
|
) -> std::result::Result<
|
|
impl Stream<Item = std::result::Result<StreamChunk, AgentError>>,
|
|
AgentError,
|
|
> {
|
|
let client = self.config.build_rig_client();
|
|
let model = client.completion_model(&self.model_name);
|
|
|
|
let agent = AgentBuilder::new(model)
|
|
.preamble(system_prompt)
|
|
.tools(tools)
|
|
.default_max_turns(max_turns)
|
|
.build();
|
|
|
|
let stream = agent
|
|
.stream_prompt(user_input)
|
|
.with_history(Vec::new())
|
|
.multi_turn(max_turns)
|
|
.await;
|
|
|
|
let (tx, rx) = mpsc::channel::<std::result::Result<StreamChunk, AgentError>>(100);
|
|
|
|
tokio::spawn(async move {
|
|
let mut final_content = String::new();
|
|
|
|
tokio::pin!(stream);
|
|
|
|
while let Some(item) = stream.next().await {
|
|
match item {
|
|
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
|
StreamedAssistantContent::Text(text),
|
|
)) => {
|
|
let _ = tx.send(Ok(StreamChunk::Text(text.text.clone()))).await;
|
|
final_content.push_str(&text.text);
|
|
}
|
|
Ok(MultiTurnStreamItem::FinalResponse(resp)) => {
|
|
let usage = resp.usage();
|
|
let _ = tx.send(Ok(StreamChunk::Final {
|
|
content: final_content.clone(),
|
|
input_tokens: usage.input_tokens,
|
|
output_tokens: usage.output_tokens,
|
|
})).await;
|
|
}
|
|
Err(e) => {
|
|
let _ = tx.send(Err(AgentError::OpenAi(e.to_string()))).await;
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
});
|
|
|
|
Ok(ReceiverStream::new(rx))
|
|
}
|
|
|
|
/// Count tokens in text using tiktoken for the configured model.
|
|
pub fn count_tokens(&self, text: &str) -> std::result::Result<usize, AgentError> {
|
|
crate::tokent::count_text(text, &self.model_name)
|
|
.map_err(|e| AgentError::Internal(e.to_string()))
|
|
}
|
|
}
|