- Split agent crate into client/, model/, agent/ subdirs - Add billing.rs for token usage recording - Add sync.rs for upstream model sync - EmbedService: Qdrant-backed vector memory for semantic search - ChatService: wire EmbedService for memory lookup, passive skill awareness - ReAct loop: streamline with tokio::select! and proper error handling
613 lines
21 KiB
Rust
613 lines
21 KiB
Rust
//! Unified AI client with built-in retry, token tracking, and session recording.
|
|
//!
|
|
//! Uses rig-core as the underlying AI provider library.
|
|
|
|
pub mod types;
|
|
pub use types::{ChatRequestMessage, ToolCall as ClientToolCall};
|
|
|
|
use std::time::Instant;
|
|
use uuid::Uuid;
|
|
|
|
use crate::error::{AgentError, Result};
|
|
|
|
use futures::StreamExt;
|
|
use rig::completion::message::{AssistantContent, Message as RigMessage};
|
|
use rig::completion::{GetTokenUsage, ToolDefinition, CompletionModel};
|
|
use rig::one_or_many::OneOrMany;
|
|
use rig::prelude::CompletionClient;
|
|
use rig::providers::openai;
|
|
|
|
/// AI call metrics — increments metrics crate counters for all AI calls.
|
|
#[derive(Debug, Clone, Default)]
|
|
pub struct AiMetrics;
|
|
|
|
impl AiMetrics {
|
|
pub fn new() -> Self { Self }
|
|
|
|
pub fn record_success(&self, input_tokens: i64, output_tokens: i64, has_function_call: bool) {
|
|
metrics::counter!("ai_calls_total").increment(1);
|
|
metrics::counter!("ai_calls_success").increment(1);
|
|
if input_tokens > 0 {
|
|
metrics::counter!("ai_input_tokens_total").increment(input_tokens as u64);
|
|
}
|
|
if output_tokens > 0 {
|
|
metrics::counter!("ai_output_tokens_total").increment(output_tokens as u64);
|
|
}
|
|
if has_function_call {
|
|
metrics::counter!("ai_function_calls_total").increment(1);
|
|
}
|
|
}
|
|
|
|
pub fn record_failure(&self) {
|
|
metrics::counter!("ai_calls_total").increment(1);
|
|
metrics::counter!("ai_calls_failure").increment(1);
|
|
}
|
|
}
|
|
|
|
/// Configuration for the AI client.
|
|
#[derive(Clone)]
|
|
pub struct AiClientConfig {
|
|
pub api_key: String,
|
|
pub base_url: Option<String>,
|
|
}
|
|
|
|
impl AiClientConfig {
|
|
pub fn new(api_key: String) -> Self {
|
|
Self { api_key, base_url: None }
|
|
}
|
|
|
|
pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
|
|
self.base_url = Some(base_url.into());
|
|
self
|
|
}
|
|
|
|
/// Build a rig OpenAI client from this config.
|
|
pub fn build_rig_client(&self) -> openai::Client {
|
|
let base = self.base_url.clone().unwrap_or_else(|| "https://api.openai.com".to_string());
|
|
openai::Client::builder()
|
|
.api_key(&self.api_key)
|
|
.base_url(&base)
|
|
.build()
|
|
.expect("Failed to build rig OpenAI client")
|
|
}
|
|
}
|
|
|
|
/// Outcome of a completed (non-streaming) AI call together with its usage figures.
#[derive(Debug, Clone)]
pub struct AiCallResponse {
    /// Concatenated text returned by the model.
    pub content: String,
    /// Input-side token count reported for the call.
    pub input_tokens: i64,
    /// Output-side token count reported for the call.
    pub output_tokens: i64,
    /// Wall-clock duration of the successful attempt, in milliseconds.
    pub latency_ms: i64,
    /// Names of the tools the model asked to call in this response.
    pub tool_calls_finished: Vec<String>,
}

impl AiCallResponse {
    /// Sum of input and output tokens.
    pub fn total_tokens(&self) -> i64 {
        let Self {
            input_tokens,
            output_tokens,
            ..
        } = self;
        input_tokens + output_tokens
    }
}
|
|
|
|
/// Internal state for retry tracking.
|
|
#[derive(Debug)]
|
|
struct RetryState {
|
|
attempt: u32,
|
|
max_retries: u32,
|
|
max_backoff_ms: u64,
|
|
}
|
|
|
|
impl RetryState {
|
|
fn new(max_retries: u32) -> Self {
|
|
Self { attempt: 0, max_retries, max_backoff_ms: 5000 }
|
|
}
|
|
fn should_retry(&self) -> bool { self.attempt < self.max_retries }
|
|
fn backoff_duration(&self) -> std::time::Duration {
|
|
let exp = self.attempt.min(5);
|
|
let base_ms = 500u64.saturating_mul(2u64.pow(exp)).min(self.max_backoff_ms);
|
|
let jitter = fastrand_u64(base_ms + 1);
|
|
std::time::Duration::from_millis(jitter)
|
|
}
|
|
fn next(&mut self) { self.attempt += 1; }
|
|
}
|
|
|
|
/// Returns a pseudo-random value in `0..n`, or 0 when `n <= 1`.
///
/// Advances a process-wide linear congruential generator stored in an atomic, so
/// concurrent callers each consume a distinct state. Not suitable for anything
/// security-sensitive; it is only used to spread retry delays.
fn fastrand_u64(n: u64) -> u64 {
    use std::sync::atomic::{AtomicU64, Ordering};

    static STATE: AtomicU64 = AtomicU64::new(0x193_667_6a_5e_7c_57);
    const MULTIPLIER: u64 = 6364136223846793005;

    // Degenerate ranges return early without touching the shared state.
    if n <= 1 {
        return 0;
    }

    let advance = |state: u64| state.wrapping_mul(MULTIPLIER).wrapping_add(1);

    // `fetch_update` retries the compare-and-swap internally and hands back the
    // state that was replaced; the value stored is `advance(previous)`.
    let previous = match STATE.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |state| {
        Some(advance(state))
    }) {
        Ok(previous) | Err(previous) => previous,
    };

    advance(previous) % n
}
|
|
|
|
fn is_retryable_error(err: &AgentError) -> bool {
|
|
let msg = err.to_string();
|
|
msg.contains("connection refused")
|
|
|| msg.contains("connection timed out")
|
|
|| msg.contains("network error")
|
|
|| msg.contains("dns error")
|
|
|| msg.contains("rate_limit")
|
|
|| msg.contains("rate limit")
|
|
|| msg.contains("429")
|
|
|| msg.contains("500")
|
|
|| msg.contains("502")
|
|
|| msg.contains("503")
|
|
|| msg.contains("504")
|
|
|| msg.contains("internal_server_error")
|
|
|| msg.contains("service_unavailable")
|
|
|| msg.contains("gateway_timeout")
|
|
|| msg.contains("bad_gateway")
|
|
}
|
|
|
|
static AI_METRICS: std::sync::OnceLock<AiMetrics> = std::sync::OnceLock::new();
|
|
|
|
fn ai_metrics() -> &'static AiMetrics {
|
|
AI_METRICS.get_or_init(AiMetrics::new)
|
|
}
|
|
|
|
// ── Type conversions ─────────────────────────────────────────────────────────
|
|
|
|
fn to_rig_message(msg: &ChatRequestMessage) -> RigMessage {
|
|
match msg.role.as_str() {
|
|
"system" => {
|
|
// System messages are handled via preamble(), but we still
|
|
// need to return something. Return a system message as User for safety.
|
|
RigMessage::user(msg.content.as_deref().unwrap_or(""))
|
|
}
|
|
"user" => {
|
|
RigMessage::user(msg.content.as_deref().unwrap_or(""))
|
|
}
|
|
"assistant" => {
|
|
let mut parts: Vec<AssistantContent> = Vec::new();
|
|
if let Some(ref content) = msg.content {
|
|
if !content.is_empty() {
|
|
parts.push(AssistantContent::text(content));
|
|
}
|
|
}
|
|
if let Some(ref tool_calls) = msg.tool_calls {
|
|
for tc in tool_calls {
|
|
// GLM may return empty tool call IDs — fall back to a generated UUID.
|
|
let id = if tc.id.is_empty() {
|
|
Uuid::new_v4().to_string()
|
|
} else {
|
|
tc.id.clone()
|
|
};
|
|
parts.push(AssistantContent::tool_call_with_call_id(
|
|
&id,
|
|
id.clone(),
|
|
&tc.function.name,
|
|
serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null),
|
|
));
|
|
}
|
|
}
|
|
if parts.is_empty() {
|
|
RigMessage::assistant("")
|
|
} else if parts.len() == 1 {
|
|
// Single part — use simpler constructors
|
|
match parts.pop().unwrap() {
|
|
AssistantContent::Text(t) => RigMessage::assistant(t.text),
|
|
ac => {
|
|
RigMessage::Assistant {
|
|
id: None,
|
|
content: OneOrMany::one(ac),
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
let content = OneOrMany::many(parts).expect("non-empty parts");
|
|
RigMessage::Assistant { id: None, content }
|
|
}
|
|
}
|
|
"tool" | "function" => {
|
|
let id = msg.tool_call_id.as_deref().unwrap_or("unknown").to_string();
|
|
let call_id = msg.tool_call_id.clone().or_else(|| Some(id.clone()));
|
|
let content = msg.content.as_deref().unwrap_or("");
|
|
RigMessage::tool_result_with_call_id(id, call_id, content)
|
|
}
|
|
"developer" => {
|
|
// Developer role maps to user/system in rig
|
|
RigMessage::user(msg.content.as_deref().unwrap_or(""))
|
|
}
|
|
_ => RigMessage::user(msg.content.as_deref().unwrap_or("")),
|
|
}
|
|
}
|
|
|
|
fn to_rig_tool_def(tool_json: &serde_json::Value) -> Option<ToolDefinition> {
|
|
let name = tool_json
|
|
.get("function")
|
|
.and_then(|f| f.get("name"))
|
|
.and_then(|n| n.as_str())?
|
|
.to_string();
|
|
|
|
let description = tool_json
|
|
.get("function")
|
|
.and_then(|f| f.get("description"))
|
|
.and_then(|d| d.as_str())
|
|
.map(|s| s.to_string())
|
|
.unwrap_or_default();
|
|
|
|
let parameters = tool_json
|
|
.get("function")
|
|
.and_then(|f| f.get("parameters"))
|
|
.cloned()
|
|
.unwrap_or(serde_json::json!({}));
|
|
|
|
Some(ToolDefinition {
|
|
name,
|
|
description,
|
|
parameters,
|
|
})
|
|
}
|
|
|
|
|
|
// ── Call helpers ─────────────────────────────────────────────────────────────
|
|
|
|
async fn do_completion<M>(
|
|
model: &M,
|
|
messages: &[ChatRequestMessage],
|
|
temperature: Option<f64>,
|
|
max_tokens: Option<u32>,
|
|
tools: Option<&[serde_json::Value]>,
|
|
tool_choice: Option<&str>,
|
|
) -> Result<(String, u64, u64, Vec<String>)>
|
|
where
|
|
M: CompletionModel<Client = openai::Client>,
|
|
{
|
|
let mut history: Vec<RigMessage> = messages.iter().map(to_rig_message).collect();
|
|
|
|
// Extract preamble (first system message) and remove from history
|
|
let preamble = messages
|
|
.iter()
|
|
.find(|m| m.role == "system")
|
|
.and_then(|m| m.content.as_deref())
|
|
.unwrap_or("")
|
|
.to_string();
|
|
|
|
history.retain(|m| !matches!(m, RigMessage::User { .. } | RigMessage::Assistant { .. }));
|
|
|
|
// For tool_result messages, we need to add them back
|
|
// Actually, let's keep the approach: filter out system, add others back
|
|
// The rig completion request uses: preamble (system) + messages (conversation)
|
|
// For our messages: system → preamble, rest → messages
|
|
let non_system: Vec<RigMessage> = messages
|
|
.iter()
|
|
.filter(|m| m.role != "system")
|
|
.map(to_rig_message)
|
|
.collect();
|
|
|
|
let tool_defs: Vec<ToolDefinition> = tools
|
|
.map(|ts| ts.iter().filter_map(to_rig_tool_def).collect())
|
|
.unwrap_or_default();
|
|
|
|
let tc = match tool_choice {
|
|
Some("none") => rig::completion::message::ToolChoice::None,
|
|
Some("auto") | None => rig::completion::message::ToolChoice::Auto,
|
|
Some(s) => rig::completion::message::ToolChoice::Specific {
|
|
function_names: vec![s.to_string()],
|
|
},
|
|
};
|
|
|
|
let mut builder = model.completion_request("");
|
|
|
|
if !preamble.is_empty() {
|
|
builder = builder.preamble(preamble);
|
|
}
|
|
|
|
if !non_system.is_empty() {
|
|
builder = builder.messages(non_system);
|
|
}
|
|
|
|
if let Some(t) = temperature {
|
|
builder = builder.temperature(t);
|
|
}
|
|
|
|
if let Some(mt) = max_tokens {
|
|
builder = builder.max_tokens(mt as u64);
|
|
}
|
|
|
|
if !tool_defs.is_empty() {
|
|
builder = builder.tools(tool_defs);
|
|
}
|
|
|
|
builder = builder.tool_choice(tc);
|
|
|
|
let response = builder.send().await.map_err(|e| AgentError::OpenAi(e.to_string()))?;
|
|
|
|
let mut content = String::new();
|
|
let mut tool_names: Vec<String> = Vec::new();
|
|
for item in response.choice {
|
|
match item {
|
|
AssistantContent::Text(t) => {
|
|
content.push_str(&t.text);
|
|
}
|
|
AssistantContent::ToolCall(tc) => {
|
|
tool_names.push(tc.function.name.clone());
|
|
}
|
|
AssistantContent::Reasoning(_) => {}
|
|
AssistantContent::Image(_) => {}
|
|
}
|
|
}
|
|
|
|
let input_tokens = response.usage.input_tokens;
|
|
let output_tokens = response.usage.output_tokens;
|
|
|
|
Ok((content, input_tokens, output_tokens, tool_names))
|
|
}
|
|
|
|
// ── Public API ───────────────────────────────────────────────────────────────
|
|
|
|
/// Call the AI model with automatic retry (no custom params).
|
|
pub async fn call_with_retry(
|
|
messages: &[ChatRequestMessage],
|
|
model_name: &str,
|
|
config: &AiClientConfig,
|
|
max_retries: Option<u32>,
|
|
) -> Result<AiCallResponse> {
|
|
let client = config.build_rig_client();
|
|
let model = client.completion_model(model_name);
|
|
let mut state = RetryState::new(max_retries.unwrap_or(3));
|
|
|
|
loop {
|
|
let start = Instant::now();
|
|
|
|
let result = do_completion(&model, messages, None, None, None, None).await;
|
|
|
|
match result {
|
|
Ok((content, input_tokens, output_tokens, tool_names)) => {
|
|
let latency_ms = start.elapsed().as_millis() as i64;
|
|
let has_function_call = !tool_names.is_empty();
|
|
ai_metrics().record_success(input_tokens as i64, output_tokens as i64, has_function_call);
|
|
return Ok(AiCallResponse { content, input_tokens: input_tokens as i64, output_tokens: output_tokens as i64, latency_ms, tool_calls_finished: tool_names });
|
|
}
|
|
Err(ref err) if state.should_retry() && is_retryable_error(err) => {
|
|
let duration = state.backoff_duration();
|
|
tracing::warn!(
|
|
attempt = state.attempt + 1,
|
|
max_retries = state.max_retries,
|
|
backoff_ms = duration.as_millis() as u64,
|
|
model = %model_name,
|
|
error = %err,
|
|
"ai_call_retry"
|
|
);
|
|
tokio::time::sleep(duration).await;
|
|
state.next();
|
|
}
|
|
Err(err) => {
|
|
ai_metrics().record_failure();
|
|
return Err(err);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Call with custom parameters (temperature, max_tokens, optional tools, optional tool_choice).
|
|
pub async fn call_with_params(
|
|
messages: &[ChatRequestMessage],
|
|
model_name: &str,
|
|
config: &AiClientConfig,
|
|
temperature: f32,
|
|
max_tokens: u32,
|
|
max_retries: Option<u32>,
|
|
tools: Option<&[serde_json::Value]>,
|
|
tool_choice: Option<&str>,
|
|
) -> Result<AiCallResponse> {
|
|
let client = config.build_rig_client();
|
|
let model = client.completion_model(model_name);
|
|
let mut state = RetryState::new(max_retries.unwrap_or(3));
|
|
|
|
loop {
|
|
let start = Instant::now();
|
|
|
|
let result = do_completion(
|
|
&model,
|
|
messages,
|
|
Some(temperature as f64),
|
|
Some(max_tokens),
|
|
tools,
|
|
tool_choice,
|
|
)
|
|
.await;
|
|
|
|
match result {
|
|
Ok((content, input_tokens, output_tokens, tool_names)) => {
|
|
let latency_ms = start.elapsed().as_millis() as i64;
|
|
let has_function_call = !tool_names.is_empty();
|
|
ai_metrics().record_success(input_tokens as i64, output_tokens as i64, has_function_call);
|
|
return Ok(AiCallResponse { content, input_tokens: input_tokens as i64, output_tokens: output_tokens as i64, latency_ms, tool_calls_finished: tool_names });
|
|
}
|
|
Err(ref err) if state.should_retry() && is_retryable_error(err) => {
|
|
let duration = state.backoff_duration();
|
|
tracing::warn!(
|
|
attempt = state.attempt + 1,
|
|
max_retries = state.max_retries,
|
|
backoff_ms = duration.as_millis() as u64,
|
|
model = %model_name,
|
|
error = %err,
|
|
"ai_call_retry"
|
|
);
|
|
tokio::time::sleep(duration).await;
|
|
state.next();
|
|
}
|
|
Err(err) => {
|
|
ai_metrics().record_failure();
|
|
return Err(err);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// One tool invocation collected from a streaming response.
///
/// Produced either from a complete tool call event or by stitching together
/// name and argument deltas.
#[derive(Debug, Clone)]
pub struct StreamedToolCall {
    /// Identifier of the tool call as reported by the stream.
    pub id: String,
    /// Name of the function the model wants to call.
    pub name: String,
    /// Arguments as a JSON string, accumulated across deltas when streamed in pieces.
    pub arguments: String,
}
|
|
|
|
/// Streaming result from rig.
|
|
#[derive(Debug)]
|
|
pub struct StreamResponse {
|
|
pub content: String,
|
|
pub input_tokens: i64,
|
|
pub output_tokens: i64,
|
|
/// Full tool calls with accumulated arguments (not just names)
|
|
pub tool_calls: Vec<StreamedToolCall>,
|
|
}
|
|
|
|
/// Run a streaming chat completion.
|
|
pub async fn call_stream(
|
|
messages: &[ChatRequestMessage],
|
|
model_name: &str,
|
|
config: &AiClientConfig,
|
|
temperature: f32,
|
|
max_tokens: u32,
|
|
tools: Option<&[serde_json::Value]>,
|
|
mut on_text_delta: impl FnMut(&str),
|
|
) -> Result<StreamResponse> {
|
|
let client = config.build_rig_client();
|
|
let model = client.completion_model(model_name);
|
|
|
|
let preamble = messages
|
|
.iter()
|
|
.find(|m| m.role == "system")
|
|
.and_then(|m| m.content.as_deref())
|
|
.unwrap_or("")
|
|
.to_string();
|
|
|
|
let non_system: Vec<RigMessage> = messages
|
|
.iter()
|
|
.filter(|m| m.role != "system")
|
|
.map(to_rig_message)
|
|
.collect();
|
|
|
|
let tool_defs: Vec<ToolDefinition> = tools
|
|
.map(|ts| ts.iter().filter_map(to_rig_tool_def).collect())
|
|
.unwrap_or_default();
|
|
|
|
let mut builder = model
|
|
.completion_request("")
|
|
.temperature(temperature as f64)
|
|
.max_tokens(max_tokens as u64);
|
|
|
|
if !preamble.is_empty() {
|
|
builder = builder.preamble(preamble);
|
|
}
|
|
if !non_system.is_empty() {
|
|
builder = builder.messages(non_system);
|
|
}
|
|
if !tool_defs.is_empty() {
|
|
builder = builder.tools(tool_defs);
|
|
}
|
|
|
|
let mut stream = builder
|
|
.stream()
|
|
.await
|
|
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
|
|
|
|
let mut content = String::new();
|
|
let mut tool_calls: Vec<StreamedToolCall> = Vec::new();
|
|
|
|
// Track partial tool calls by internal_call_id for argument accumulation
|
|
use std::collections::HashMap;
|
|
let mut partial_tool_calls: HashMap<String, StreamedToolCall> = HashMap::new();
|
|
let mut stream_finished = false;
|
|
|
|
use rig::streaming::StreamedAssistantContent;
|
|
|
|
while let Some(item) = stream.next().await {
|
|
match item {
|
|
Ok(StreamedAssistantContent::Text(text)) => {
|
|
content.push_str(&text.text);
|
|
on_text_delta(&text.text);
|
|
}
|
|
Ok(StreamedAssistantContent::ToolCall {
|
|
tool_call,
|
|
internal_call_id,
|
|
}) => {
|
|
// Complete tool call - extract arguments from the JSON Value
|
|
let arguments = match &tool_call.function.arguments {
|
|
serde_json::Value::String(s) => s.clone(),
|
|
other => serde_json::to_string(other).unwrap_or_else(|_| "{}".to_string()),
|
|
};
|
|
tool_calls.push(StreamedToolCall {
|
|
id: tool_call.id.clone(),
|
|
name: tool_call.function.name.clone(),
|
|
arguments,
|
|
});
|
|
// Remove from partial if it was being accumulated
|
|
partial_tool_calls.remove(&internal_call_id);
|
|
}
|
|
Ok(StreamedAssistantContent::ToolCallDelta {
|
|
id,
|
|
internal_call_id,
|
|
content,
|
|
}) => {
|
|
use rig::streaming::ToolCallDeltaContent;
|
|
match content {
|
|
ToolCallDeltaContent::Name(name) => {
|
|
// Start accumulating a new tool call
|
|
partial_tool_calls.insert(
|
|
internal_call_id.clone(),
|
|
StreamedToolCall {
|
|
id: id.clone(),
|
|
name,
|
|
arguments: String::new(),
|
|
},
|
|
);
|
|
}
|
|
ToolCallDeltaContent::Delta(delta) => {
|
|
// Append to existing partial tool call
|
|
if let Some(tc) = partial_tool_calls.get_mut(&internal_call_id) {
|
|
tc.arguments.push_str(&delta);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Ok(StreamedAssistantContent::Reasoning(_)) => {}
|
|
Ok(StreamedAssistantContent::ReasoningDelta { .. }) => {}
|
|
Ok(StreamedAssistantContent::Final(response)) => {
|
|
stream_finished = true;
|
|
// Flush any remaining partial tool calls
|
|
for (_, tc) in partial_tool_calls.drain() {
|
|
tool_calls.push(tc);
|
|
}
|
|
if let Some(usage) = response.token_usage() {
|
|
ai_metrics().record_success(
|
|
usage.input_tokens as i64,
|
|
usage.output_tokens as i64,
|
|
!tool_calls.is_empty(),
|
|
);
|
|
return Ok(StreamResponse {
|
|
content,
|
|
input_tokens: usage.input_tokens as i64,
|
|
output_tokens: usage.output_tokens as i64,
|
|
tool_calls,
|
|
});
|
|
}
|
|
}
|
|
Err(e) => return Err(AgentError::OpenAi(e.to_string())),
|
|
}
|
|
}
|
|
|
|
// Flush any remaining partial tool calls (if stream ended without Final)
|
|
if !stream_finished {
|
|
for (_, tc) in partial_tool_calls.drain() {
|
|
tool_calls.push(tc);
|
|
}
|
|
}
|
|
ai_metrics().record_success(0, 0, !tool_calls.is_empty());
|
|
Ok(StreamResponse {
|
|
content,
|
|
input_tokens: 0,
|
|
output_tokens: 0,
|
|
tool_calls,
|
|
})
|
|
}
|