//! Unified AI client with built-in retry, token tracking, and session recording. //! //! Uses rig-core as the underlying AI provider library. pub mod types; pub use types::{ChatRequestMessage, ToolCall as ClientToolCall}; use std::time::Instant; use uuid::Uuid; use crate::error::{AgentError, Result}; use futures::StreamExt; use rig::completion::message::{AssistantContent, Message as RigMessage}; use rig::completion::{GetTokenUsage, ToolDefinition, CompletionModel}; use rig::one_or_many::OneOrMany; use rig::prelude::CompletionClient; use rig::providers::openai; /// AI call metrics — increments metrics crate counters for all AI calls. #[derive(Debug, Clone, Default)] pub struct AiMetrics; impl AiMetrics { pub fn new() -> Self { Self } pub fn record_success(&self, input_tokens: i64, output_tokens: i64, has_function_call: bool) { metrics::counter!("ai_calls_total").increment(1); metrics::counter!("ai_calls_success").increment(1); if input_tokens > 0 { metrics::counter!("ai_input_tokens_total").increment(input_tokens as u64); } if output_tokens > 0 { metrics::counter!("ai_output_tokens_total").increment(output_tokens as u64); } if has_function_call { metrics::counter!("ai_function_calls_total").increment(1); } } pub fn record_failure(&self) { metrics::counter!("ai_calls_total").increment(1); metrics::counter!("ai_calls_failure").increment(1); } } /// Configuration for the AI client. #[derive(Clone)] pub struct AiClientConfig { pub api_key: String, pub base_url: Option, } impl AiClientConfig { pub fn new(api_key: String) -> Self { Self { api_key, base_url: None } } pub fn with_base_url(mut self, base_url: impl Into) -> Self { self.base_url = Some(base_url.into()); self } /// Build a rig OpenAI client from this config. 
pub fn build_rig_client(&self) -> openai::Client { let base = self.base_url.clone().unwrap_or_else(|| "https://api.openai.com".to_string()); openai::Client::builder() .api_key(&self.api_key) .base_url(&base) .build() .expect("Failed to build rig OpenAI client") } } /// Response from an AI call, including usage statistics. #[derive(Debug, Clone)] pub struct AiCallResponse { pub content: String, pub input_tokens: i64, pub output_tokens: i64, pub latency_ms: i64, pub tool_calls_finished: Vec, } impl AiCallResponse { pub fn total_tokens(&self) -> i64 { self.input_tokens + self.output_tokens } } /// Internal state for retry tracking. #[derive(Debug)] struct RetryState { attempt: u32, max_retries: u32, max_backoff_ms: u64, } impl RetryState { fn new(max_retries: u32) -> Self { Self { attempt: 0, max_retries, max_backoff_ms: 5000 } } fn should_retry(&self) -> bool { self.attempt < self.max_retries } fn backoff_duration(&self) -> std::time::Duration { let exp = self.attempt.min(5); let base_ms = 500u64.saturating_mul(2u64.pow(exp)).min(self.max_backoff_ms); let jitter = fastrand_u64(base_ms + 1); std::time::Duration::from_millis(jitter) } fn next(&mut self) { self.attempt += 1; } } fn fastrand_u64(n: u64) -> u64 { use std::sync::atomic::{AtomicU64, Ordering}; static STATE: AtomicU64 = AtomicU64::new(0x193_667_6a_5e_7c_57); if n <= 1 { return 0; } let mut current = STATE.load(Ordering::Relaxed); loop { let new_val = current.wrapping_mul(6364136223846793005).wrapping_add(1); match STATE.compare_exchange_weak(current, new_val, Ordering::Relaxed, Ordering::Relaxed) { Ok(_) => return new_val % n, Err(actual) => current = actual, } } } fn is_retryable_error(err: &AgentError) -> bool { let msg = err.to_string(); msg.contains("connection refused") || msg.contains("connection timed out") || msg.contains("network error") || msg.contains("dns error") || msg.contains("rate_limit") || msg.contains("rate limit") || msg.contains("429") || msg.contains("500") || msg.contains("502") || 
msg.contains("503") || msg.contains("504") || msg.contains("internal_server_error") || msg.contains("service_unavailable") || msg.contains("gateway_timeout") || msg.contains("bad_gateway") } static AI_METRICS: std::sync::OnceLock = std::sync::OnceLock::new(); fn ai_metrics() -> &'static AiMetrics { AI_METRICS.get_or_init(AiMetrics::new) } // ── Type conversions ───────────────────────────────────────────────────────── fn to_rig_message(msg: &ChatRequestMessage) -> RigMessage { match msg.role.as_str() { "system" => { // System messages are handled via preamble(), but we still // need to return something. Return a system message as User for safety. RigMessage::user(msg.content.as_deref().unwrap_or("")) } "user" => { RigMessage::user(msg.content.as_deref().unwrap_or("")) } "assistant" => { let mut parts: Vec = Vec::new(); if let Some(ref content) = msg.content { if !content.is_empty() { parts.push(AssistantContent::text(content)); } } if let Some(ref tool_calls) = msg.tool_calls { for tc in tool_calls { // GLM may return empty tool call IDs — fall back to a generated UUID. 
let id = if tc.id.is_empty() { Uuid::new_v4().to_string() } else { tc.id.clone() }; parts.push(AssistantContent::tool_call_with_call_id( &id, id.clone(), &tc.function.name, serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null), )); } } if parts.is_empty() { RigMessage::assistant("") } else if parts.len() == 1 { // Single part — use simpler constructors match parts.pop().unwrap() { AssistantContent::Text(t) => RigMessage::assistant(t.text), ac => { RigMessage::Assistant { id: None, content: OneOrMany::one(ac), } } } } else { let content = OneOrMany::many(parts).expect("non-empty parts"); RigMessage::Assistant { id: None, content } } } "tool" | "function" => { let id = msg.tool_call_id.as_deref().unwrap_or("unknown").to_string(); let call_id = msg.tool_call_id.clone().or_else(|| Some(id.clone())); let content = msg.content.as_deref().unwrap_or(""); RigMessage::tool_result_with_call_id(id, call_id, content) } "developer" => { // Developer role maps to user/system in rig RigMessage::user(msg.content.as_deref().unwrap_or("")) } _ => RigMessage::user(msg.content.as_deref().unwrap_or("")), } } fn to_rig_tool_def(tool_json: &serde_json::Value) -> Option { let name = tool_json .get("function") .and_then(|f| f.get("name")) .and_then(|n| n.as_str())? 
.to_string(); let description = tool_json .get("function") .and_then(|f| f.get("description")) .and_then(|d| d.as_str()) .map(|s| s.to_string()) .unwrap_or_default(); let parameters = tool_json .get("function") .and_then(|f| f.get("parameters")) .cloned() .unwrap_or(serde_json::json!({})); Some(ToolDefinition { name, description, parameters, }) } // ── Call helpers ───────────────────────────────────────────────────────────── async fn do_completion( model: &M, messages: &[ChatRequestMessage], temperature: Option, max_tokens: Option, tools: Option<&[serde_json::Value]>, tool_choice: Option<&str>, ) -> Result<(String, u64, u64, Vec)> where M: CompletionModel, { let mut history: Vec = messages.iter().map(to_rig_message).collect(); // Extract preamble (first system message) and remove from history let preamble = messages .iter() .find(|m| m.role == "system") .and_then(|m| m.content.as_deref()) .unwrap_or("") .to_string(); history.retain(|m| !matches!(m, RigMessage::User { .. } | RigMessage::Assistant { .. 
})); // For tool_result messages, we need to add them back // Actually, let's keep the approach: filter out system, add others back // The rig completion request uses: preamble (system) + messages (conversation) // For our messages: system → preamble, rest → messages let non_system: Vec = messages .iter() .filter(|m| m.role != "system") .map(to_rig_message) .collect(); let tool_defs: Vec = tools .map(|ts| ts.iter().filter_map(to_rig_tool_def).collect()) .unwrap_or_default(); let tc = match tool_choice { Some("none") => rig::completion::message::ToolChoice::None, Some("auto") | None => rig::completion::message::ToolChoice::Auto, Some(s) => rig::completion::message::ToolChoice::Specific { function_names: vec![s.to_string()], }, }; let mut builder = model.completion_request(""); if !preamble.is_empty() { builder = builder.preamble(preamble); } if !non_system.is_empty() { builder = builder.messages(non_system); } if let Some(t) = temperature { builder = builder.temperature(t); } if let Some(mt) = max_tokens { builder = builder.max_tokens(mt as u64); } if !tool_defs.is_empty() { builder = builder.tools(tool_defs); } builder = builder.tool_choice(tc); let response = builder.send().await.map_err(|e| AgentError::OpenAi(e.to_string()))?; let mut content = String::new(); let mut tool_names: Vec = Vec::new(); for item in response.choice { match item { AssistantContent::Text(t) => { content.push_str(&t.text); } AssistantContent::ToolCall(tc) => { tool_names.push(tc.function.name.clone()); } AssistantContent::Reasoning(_) => {} AssistantContent::Image(_) => {} } } let input_tokens = response.usage.input_tokens; let output_tokens = response.usage.output_tokens; Ok((content, input_tokens, output_tokens, tool_names)) } // ── Public API ─────────────────────────────────────────────────────────────── /// Call the AI model with automatic retry (no custom params). 
///
/// The request uses the provider's default temperature and token limit and sends
/// no tools. Errors classified as transient by `is_retryable_error` are retried up
/// to `max_retries` times (default 3) with jittered exponential backoff. Metrics
/// are recorded once, for the final outcome only.
///
/// # Errors
/// Returns the last error once retries are exhausted. A non-retryable error is
/// returned immediately.
pub async fn call_with_retry(
    messages: &[ChatRequestMessage],
    model_name: &str,
    config: &AiClientConfig,
    max_retries: Option<u32>,
) -> Result<AiCallResponse> {
    complete_with_retry(
        messages,
        model_name,
        config,
        max_retries,
        CompletionOptions::default(),
    )
    .await
}

/// Optional request settings forwarded verbatim to `do_completion`.
#[derive(Clone, Copy, Default)]
struct CompletionOptions<'a> {
    temperature: Option<f64>,
    max_tokens: Option<u32>,
    tools: Option<&'a [serde_json::Value]>,
    tool_choice: Option<&'a str>,
}

/// Shared retry loop behind [`call_with_retry`] and [`call_with_params`].
///
/// Latency is measured for the successful attempt only. Retried failures are
/// logged as `ai_call_retry` warnings but are not counted in metrics.
async fn complete_with_retry(
    messages: &[ChatRequestMessage],
    model_name: &str,
    config: &AiClientConfig,
    max_retries: Option<u32>,
    options: CompletionOptions<'_>,
) -> Result<AiCallResponse> {
    let client = config.build_rig_client();
    let model = client.completion_model(model_name);
    let mut state = RetryState::new(max_retries.unwrap_or(3));

    loop {
        let started = Instant::now();
        let outcome = do_completion(
            &model,
            messages,
            options.temperature,
            options.max_tokens,
            options.tools,
            options.tool_choice,
        )
        .await;

        match outcome {
            Ok((content, input_tokens, output_tokens, tool_names)) => {
                let latency_ms = started.elapsed().as_millis() as i64;
                ai_metrics().record_success(
                    input_tokens as i64,
                    output_tokens as i64,
                    !tool_names.is_empty(),
                );
                return Ok(AiCallResponse {
                    content,
                    input_tokens: input_tokens as i64,
                    output_tokens: output_tokens as i64,
                    latency_ms,
                    tool_calls_finished: tool_names,
                });
            }
            Err(err) if state.should_retry() && is_retryable_error(&err) => {
                let delay = state.backoff_duration();
                tracing::warn!(
                    attempt = state.attempt + 1,
                    max_retries = state.max_retries,
                    backoff_ms = delay.as_millis() as u64,
                    model = %model_name,
                    error = %err,
                    "ai_call_retry"
                );
                tokio::time::sleep(delay).await;
                state.next();
            }
            Err(err) => {
                ai_metrics().record_failure();
                return Err(err);
            }
        }
    }
}

/// Call with custom parameters (temperature, max_tokens, optional tools, optional tool_choice).
///
/// Retry, backoff, and metrics behave exactly as in [`call_with_retry`].
/// `tools` and `tool_choice` are passed through to the completion request unchanged.
///
/// # Errors
/// Same as [`call_with_retry`].
#[allow(clippy::too_many_arguments)] // public signature kept for existing callers
pub async fn call_with_params(
    messages: &[ChatRequestMessage],
    model_name: &str,
    config: &AiClientConfig,
    temperature: f32,
    max_tokens: u32,
    max_retries: Option<u32>,
    tools: Option<&[serde_json::Value]>,
    tool_choice: Option<&str>,
) -> Result<AiCallResponse> {
    let options = CompletionOptions {
        temperature: Some(f64::from(temperature)),
        max_tokens: Some(max_tokens),
        tools,
        tool_choice,
    };
    complete_with_retry(messages, model_name, config, max_retries, options).await
}

/// A tool call extracted from streaming response with accumulated arguments.
#[derive(Debug, Clone)]
pub struct StreamedToolCall {
    /// Tool call ID
    pub id: String,
    /// Tool function name
    pub name: String,
    /// Accumulated JSON arguments string
    pub arguments: String,
}

/// Streaming result from rig.
#[derive(Debug, Clone)]
pub struct StreamResponse {
    /// Concatenated text deltas.
    pub content: String,
    /// Prompt tokens from the final chunk (0 if the provider reported none).
    pub input_tokens: i64,
    /// Completion tokens from the final chunk (0 if the provider reported none).
    pub output_tokens: i64,
    /// Full tool calls with accumulated arguments (not just names)
    pub tool_calls: Vec<StreamedToolCall>,
}

/// Run a streaming chat completion.
pub async fn call_stream( messages: &[ChatRequestMessage], model_name: &str, config: &AiClientConfig, temperature: f32, max_tokens: u32, tools: Option<&[serde_json::Value]>, mut on_text_delta: impl FnMut(&str), ) -> Result { let client = config.build_rig_client(); let model = client.completion_model(model_name); let preamble = messages .iter() .find(|m| m.role == "system") .and_then(|m| m.content.as_deref()) .unwrap_or("") .to_string(); let non_system: Vec = messages .iter() .filter(|m| m.role != "system") .map(to_rig_message) .collect(); let tool_defs: Vec = tools .map(|ts| ts.iter().filter_map(to_rig_tool_def).collect()) .unwrap_or_default(); let mut builder = model .completion_request("") .temperature(temperature as f64) .max_tokens(max_tokens as u64); if !preamble.is_empty() { builder = builder.preamble(preamble); } if !non_system.is_empty() { builder = builder.messages(non_system); } if !tool_defs.is_empty() { builder = builder.tools(tool_defs); } let mut stream = builder .stream() .await .map_err(|e| AgentError::OpenAi(e.to_string()))?; let mut content = String::new(); let mut tool_calls: Vec = Vec::new(); // Track partial tool calls by internal_call_id for argument accumulation use std::collections::HashMap; let mut partial_tool_calls: HashMap = HashMap::new(); let mut stream_finished = false; use rig::streaming::StreamedAssistantContent; while let Some(item) = stream.next().await { match item { Ok(StreamedAssistantContent::Text(text)) => { content.push_str(&text.text); on_text_delta(&text.text); } Ok(StreamedAssistantContent::ToolCall { tool_call, internal_call_id, }) => { // Complete tool call - extract arguments from the JSON Value let arguments = match &tool_call.function.arguments { serde_json::Value::String(s) => s.clone(), other => serde_json::to_string(other).unwrap_or_else(|_| "{}".to_string()), }; tool_calls.push(StreamedToolCall { id: tool_call.id.clone(), name: tool_call.function.name.clone(), arguments, }); // Remove from partial if it was being 
accumulated partial_tool_calls.remove(&internal_call_id); } Ok(StreamedAssistantContent::ToolCallDelta { id, internal_call_id, content, }) => { use rig::streaming::ToolCallDeltaContent; match content { ToolCallDeltaContent::Name(name) => { // Start accumulating a new tool call partial_tool_calls.insert( internal_call_id.clone(), StreamedToolCall { id: id.clone(), name, arguments: String::new(), }, ); } ToolCallDeltaContent::Delta(delta) => { // Append to existing partial tool call if let Some(tc) = partial_tool_calls.get_mut(&internal_call_id) { tc.arguments.push_str(&delta); } } } } Ok(StreamedAssistantContent::Reasoning(_)) => {} Ok(StreamedAssistantContent::ReasoningDelta { .. }) => {} Ok(StreamedAssistantContent::Final(response)) => { stream_finished = true; // Flush any remaining partial tool calls for (_, tc) in partial_tool_calls.drain() { tool_calls.push(tc); } if let Some(usage) = response.token_usage() { ai_metrics().record_success( usage.input_tokens as i64, usage.output_tokens as i64, !tool_calls.is_empty(), ); return Ok(StreamResponse { content, input_tokens: usage.input_tokens as i64, output_tokens: usage.output_tokens as i64, tool_calls, }); } } Err(e) => return Err(AgentError::OpenAi(e.to_string())), } } // Flush any remaining partial tool calls (if stream ended without Final) if !stream_finished { for (_, tc) in partial_tool_calls.drain() { tool_calls.push(tc); } } ai_metrics().record_success(0, 0, !tool_calls.is_empty()); Ok(StreamResponse { content, input_tokens: 0, output_tokens: 0, tool_calls, }) }