gitdataai/libs/agent/agent/service.rs
ZhenYi 10c0cc007b refactor(agent): split into submodules and add Qdrant embedding
- Split agent crate into client/, model/, agent/ subdirs
- Add billing.rs for token usage recording
- Add sync.rs for upstream model sync
- EmbedService: Qdrant-backed vector memory for semantic search
- ChatService: wire EmbedService for memory lookup, passive skill awareness
- ReAct loop: streamline with tokio::select! and proper error handling
2026-04-25 20:09:33 +08:00

268 lines
9.5 KiB
Rust

//! Agent service using rig's built-in Agent.
//!
//! This is a complete implementation that leverages rig's Agent for
//! multi-turn reasoning, tool execution, streaming, and token tracking.
use futures::Stream;
use futures::StreamExt;
use rig::{
agent::{AgentBuilder, MultiTurnStreamItem},
client::CompletionClient,
completion::Prompt,
streaming::{StreamingPrompt, StreamedAssistantContent},
};
use tokio_stream::wrappers::ReceiverStream;
use tokio::sync::mpsc;
use crate::client::AiClientConfig;
use crate::error::AgentError;
/// Response from an agent completion (rig's Agent prompt response).
///
/// Plain data carrier: it derives `Clone`, `PartialEq` and `Eq` in addition to
/// `Debug` so callers and tests can copy and compare responses directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentResponse {
    /// Text output produced by the agent.
    pub content: String,
    /// Input tokens reported in the response's total usage.
    pub input_tokens: u64,
    /// Output tokens reported in the response's total usage.
    pub output_tokens: u64,
}
/// Streaming chunk from the agent.
///
/// Plain data carrier: it derives `Clone`, `PartialEq` and `Eq` in addition to
/// `Debug` so consumers and tests can copy and compare chunks directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamChunk {
    /// Text delta from the model.
    Text(String),
    /// Final response with aggregated usage.
    Final {
        /// Concatenation of every text delta emitted before this chunk.
        content: String,
        /// Input tokens reported by the final response's usage.
        input_tokens: u64,
        /// Output tokens reported by the final response's usage.
        output_tokens: u64,
    },
}
/// Service for running agents using rig's built-in Agent.
///
/// Provides both simple prompting and full streaming with automatic
/// tool call handling via rig's native Agent.
///
/// The service only stores configuration; each method builds its own rig
/// client, completion model and agent from these fields when it is called.
pub struct RigAgentService {
/// Client configuration; every method calls `build_rig_client()` on it.
config: AiClientConfig,
/// Model identifier passed to `completion_model` and to token counting.
model_name: String,
}
impl RigAgentService {
/// Build a `RigAgentService` from a client configuration and a model name.
///
/// `model_name` accepts anything convertible into a `String`; the converted
/// value is stored on the service.
pub fn new(config: AiClientConfig, model_name: impl Into<String>) -> Self {
    let model_name = model_name.into();
    Self { config, model_name }
}
/// Send one prompt to an agent that has no tools registered.
///
/// A fresh rig client, completion model and agent are built for the call,
/// with `system_prompt` as the preamble and `user_input` as the prompt.
///
/// # Errors
///
/// Returns `AgentError::OpenAi` carrying the stringified rig `PromptError`
/// when the prompt request fails.
pub async fn prompt(
    &self,
    system_prompt: &str,
    user_input: &str,
) -> std::result::Result<AgentResponse, AgentError> {
    let rig_client = self.config.build_rig_client();
    let completion_model = rig_client.completion_model(&self.model_name);
    let agent = AgentBuilder::new(completion_model)
        .preamble(system_prompt)
        .build();

    // `extended_details()` makes the reply carry total token usage
    // alongside the output text.
    let outcome = agent.prompt(user_input).extended_details().await;
    let reply =
        outcome.map_err(|e: rig::completion::PromptError| AgentError::OpenAi(e.to_string()))?;

    Ok(AgentResponse {
        content: reply.output,
        input_tokens: reply.total_usage.input_tokens,
        output_tokens: reply.total_usage.output_tokens,
    })
}
/// Send a prompt to an agent that has `tools` registered, allowing up to
/// `max_turns` turns.
///
/// Tool invocation is delegated to rig's Agent and the supplied `ToolDyn`
/// implementations. `max_turns` is applied both as the agent's default and
/// on the individual prompt request.
///
/// # Errors
///
/// Returns `AgentError::OpenAi` carrying the stringified rig `PromptError`
/// when the prompt request fails.
pub async fn prompt_with_tools(
    &self,
    system_prompt: &str,
    user_input: &str,
    tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>>,
    max_turns: usize,
) -> std::result::Result<AgentResponse, AgentError> {
    let rig_client = self.config.build_rig_client();
    let completion_model = rig_client.completion_model(&self.model_name);
    let builder = AgentBuilder::new(completion_model)
        .preamble(system_prompt)
        .tools(tools)
        .default_max_turns(max_turns);
    let agent = builder.build();

    // `extended_details()` makes the reply carry total token usage
    // alongside the output text.
    let outcome = agent
        .prompt(user_input)
        .max_turns(max_turns)
        .extended_details()
        .await;
    let reply =
        outcome.map_err(|e: rig::completion::PromptError| AgentError::OpenAi(e.to_string()))?;

    Ok(AgentResponse {
        content: reply.output,
        input_tokens: reply.total_usage.input_tokens,
        output_tokens: reply.total_usage.output_tokens,
    })
}
/// Stream a prompt with the agent using rig's native streaming.
///
/// Returns an async stream that yields `StreamChunk::Text` for each text
/// delta as it arrives, followed by a `StreamChunk::Final` carrying the
/// concatenated text and the token usage of rig's final response. No tools
/// are registered on this agent; tool-call and tool-result events, if rig
/// emits any, are only logged.
///
/// The rig stream is bridged through a bounded channel (capacity 100) by a
/// spawned task. When the consumer drops the returned stream the task stops
/// forwarding and exits, instead of continuing to poll the upstream stream.
///
/// # Errors
///
/// Setup itself does not produce an error in this implementation. Failures
/// reported by rig while streaming are delivered as `Err(AgentError::OpenAi)`
/// items inside the stream.
pub async fn stream_prompt(
    &self,
    system_prompt: &str,
    user_input: &str,
) -> std::result::Result<
    impl Stream<Item = std::result::Result<StreamChunk, AgentError>>,
    AgentError,
> {
    let client = self.config.build_rig_client();
    let model = client.completion_model(&self.model_name);
    let agent = AgentBuilder::new(model)
        .preamble(system_prompt)
        .build();

    // stream_prompt().await returns StreamingResult directly (not wrapped in Result).
    // StreamingResult is Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem, StreamingError>>>>.
    let stream: rig::agent::StreamingResult<_> = agent
        .stream_prompt(user_input)
        .await;

    // Bridge the rig stream to our channel-based stream.
    let (tx, rx) = mpsc::channel::<std::result::Result<StreamChunk, AgentError>>(100);
    tokio::spawn(async move {
        let mut final_content = String::new();
        tokio::pin!(stream);
        while let Some(item) = stream.next().await {
            // Translate the rig item into at most one outgoing chunk.
            let outgoing = match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(
                    StreamedAssistantContent::Text(text),
                )) => {
                    // Accumulate first so the delta can be moved into the
                    // chunk without cloning it.
                    final_content.push_str(&text.text);
                    Some(Ok(StreamChunk::Text(text.text)))
                }
                Ok(MultiTurnStreamItem::StreamAssistantItem(
                    StreamedAssistantContent::ToolCall { tool_call, .. },
                )) => {
                    let args_str = match &tool_call.function.arguments {
                        serde_json::Value::String(s) => s.clone(),
                        v => serde_json::to_string(v).unwrap_or_default(),
                    };
                    // Tool call: logged for observability only; rig handles
                    // execution internally.
                    tracing::info!(
                        tool = %tool_call.function.name,
                        args = %args_str,
                        "rig_agent_streaming_tool_call"
                    );
                    None
                }
                Ok(MultiTurnStreamItem::StreamUserItem(
                    rig::streaming::StreamedUserContent::ToolResult { tool_result, .. },
                )) => {
                    tracing::info!(
                        tool_result_id = %tool_result.id,
                        "rig_agent_streaming_tool_result"
                    );
                    None
                }
                Ok(MultiTurnStreamItem::FinalResponse(resp)) => {
                    let usage = resp.usage();
                    Some(Ok(StreamChunk::Final {
                        content: final_content.clone(),
                        input_tokens: usage.input_tokens,
                        output_tokens: usage.output_tokens,
                    }))
                }
                Err(e) => Some(Err(AgentError::OpenAi(e.to_string()))),
                // Other stream items carry nothing this bridge forwards.
                _ => None,
            };

            if let Some(chunk) = outgoing {
                // `send` only fails once the receiver is gone; nobody is
                // listening any more, so stop driving the upstream stream.
                if tx.send(chunk).await.is_err() {
                    break;
                }
            }
        }
    });
    Ok(ReceiverStream::new(rx))
}
/// Stream a prompt with tools using rig's native streaming.
///
/// Returns a stream thatproperly handles multi-turn tool calls via rig's Agent
/// streaming infrastructure.
pub async fn stream_prompt_with_tools(
&self,
system_prompt: &str,
user_input: &str,
tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>>,
max_turns: usize,
) -> std::result::Result<
impl Stream<Item = std::result::Result<StreamChunk, AgentError>>,
AgentError,
> {
let client = self.config.build_rig_client();
let model = client.completion_model(&self.model_name);
let agent = AgentBuilder::new(model)
.preamble(system_prompt)
.tools(tools)
.default_max_turns(max_turns)
.build();
let stream = agent
.stream_prompt(user_input)
.with_history(Vec::new())
.multi_turn(max_turns)
.await;
let (tx, rx) = mpsc::channel::<std::result::Result<StreamChunk, AgentError>>(100);
tokio::spawn(async move {
let mut final_content = String::new();
tokio::pin!(stream);
while let Some(item) = stream.next().await {
match item {
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::Text(text),
)) => {
let _ = tx.send(Ok(StreamChunk::Text(text.text.clone()))).await;
final_content.push_str(&text.text);
}
Ok(MultiTurnStreamItem::FinalResponse(resp)) => {
let usage = resp.usage();
let _ = tx.send(Ok(StreamChunk::Final {
content: final_content.clone(),
input_tokens: usage.input_tokens,
output_tokens: usage.output_tokens,
})).await;
}
Err(e) => {
let _ = tx.send(Err(AgentError::OpenAi(e.to_string()))).await;
}
_ => {}
}
}
});
Ok(ReceiverStream::new(rx))
}
/// Count the tokens in `text` for the configured model.
///
/// Delegates to `crate::tokent::count_text`, passing the service's model name.
///
/// # Errors
///
/// Returns `AgentError::Internal` carrying the stringified error when the
/// underlying counter fails.
pub fn count_tokens(&self, text: &str) -> std::result::Result<usize, AgentError> {
    match crate::tokent::count_text(text, &self.model_name) {
        Ok(count) => Ok(count),
        Err(err) => Err(AgentError::Internal(err.to_string())),
    }
}
}