//! Unified AI client with built-in retry, token tracking, and session recording. //! //! Uses rig-core as the underlying AI provider library. pub mod types; pub use types::{ChatRequestMessage, ToolCall as ClientToolCall}; use std::pin::Pin; use std::sync::Arc; use std::time::Instant; use uuid::Uuid; use crate::error::{AgentError, Result}; use futures::StreamExt; use rig::completion::message::{AssistantContent, Message as RigMessage}; use rig::completion::{CompletionModel, GetTokenUsage, ToolDefinition}; use rig::one_or_many::OneOrMany; use rig::prelude::CompletionClient; use rig::providers::openai; /// AI call metrics — increments metrics crate counters for all AI calls. #[derive(Debug, Clone, Default)] pub struct AiMetrics; impl AiMetrics { pub fn new() -> Self { Self } pub fn record_success(&self, input_tokens: i64, output_tokens: i64, has_function_call: bool) { metrics::counter!("ai_calls_total").increment(1); metrics::counter!("ai_calls_success").increment(1); if input_tokens > 0 { metrics::counter!("ai_input_tokens_total").increment(input_tokens as u64); } if output_tokens > 0 { metrics::counter!("ai_output_tokens_total").increment(output_tokens as u64); } if has_function_call { metrics::counter!("ai_function_calls_total").increment(1); } } pub fn record_failure(&self) { metrics::counter!("ai_calls_total").increment(1); metrics::counter!("ai_calls_failure").increment(1); } } /// Configuration for the AI client. #[derive(Clone)] pub struct AiClientConfig { pub api_key: String, pub base_url: Option, } impl AiClientConfig { pub fn new(api_key: String) -> Self { Self { api_key, base_url: None, } } pub fn with_base_url(mut self, base_url: impl Into) -> Self { self.base_url = Some(base_url.into()); self } /// Build a rig OpenAI client from this config. 
pub fn build_rig_client(&self) -> openai::Client { let base = self .base_url .clone() .unwrap_or_else(|| "https://api.openai.com".to_string()); openai::Client::builder() .api_key(&self.api_key) .base_url(&base) .build() .expect("Failed to build rig OpenAI client") } } /// Response from an AI call, including usage statistics. #[derive(Debug, Clone)] pub struct AiCallResponse { pub content: String, pub input_tokens: i64, pub output_tokens: i64, pub latency_ms: i64, pub tool_calls: Vec, pub tool_calls_finished: Vec, } impl AiCallResponse { pub fn total_tokens(&self) -> i64 { self.input_tokens + self.output_tokens } } /// Internal state for retry tracking. #[derive(Debug)] struct RetryState { attempt: u32, max_retries: u32, max_backoff_ms: u64, } impl RetryState { fn new(max_retries: u32) -> Self { Self { attempt: 0, max_retries, max_backoff_ms: 5000, } } fn should_retry(&self) -> bool { self.attempt < self.max_retries } fn backoff_duration(&self) -> std::time::Duration { let exp = self.attempt.min(5); let base_ms = 500u64 .saturating_mul(2u64.pow(exp)) .min(self.max_backoff_ms); let max_jitter = (base_ms / 2).max(base_ms); let offset = fastrand_u64(max_jitter + 1).saturating_sub(base_ms / 2); let total = base_ms.saturating_add(offset).min(self.max_backoff_ms); std::time::Duration::from_millis(total) } fn next(&mut self) { self.attempt += 1; } } fn fastrand_u64(n: u64) -> u64 { use std::sync::atomic::{AtomicU64, Ordering}; static STATE: AtomicU64 = AtomicU64::new(0x193_667_6a_5e_7c_57); if n <= 1 { return 0; } let mut current = STATE.load(Ordering::Relaxed); loop { let new_val = current.wrapping_mul(6364136223846793005).wrapping_add(1); match STATE.compare_exchange_weak(current, new_val, Ordering::Relaxed, Ordering::Relaxed) { Ok(_) => return new_val % n, Err(actual) => current = actual, } } } fn is_retryable_error(err: &AgentError) -> bool { let msg = err.to_string(); msg.contains("connection refused") || msg.contains("connection timed out") || msg.contains("network 
error") || msg.contains("dns error") || msg.contains("error sending request") || msg.contains("Http client error") || msg.contains("rate_limit") || msg.contains("rate limit") || msg.contains("429") || msg.contains("500") || msg.contains("502") || msg.contains("503") || msg.contains("504") || msg.contains("internal_server_error") || msg.contains("service_unavailable") || msg.contains("gateway_timeout") || msg.contains("bad_gateway") } static AI_METRICS: std::sync::OnceLock = std::sync::OnceLock::new(); fn ai_metrics() -> &'static AiMetrics { AI_METRICS.get_or_init(AiMetrics::new) } // ── Type conversions ───────────────────────────────────────────────────────── pub(crate) fn to_rig_message(msg: &ChatRequestMessage) -> RigMessage { match msg.role.as_str() { "system" => { // System messages are handled via preamble(), not passed as messages. // We still need to return a valid RigMessage variant. RigMessage::user(msg.content.as_deref().unwrap_or("")) } "user" => RigMessage::user(msg.content.as_deref().unwrap_or("")), "assistant" => { let mut parts: Vec = Vec::new(); if let Some(ref content) = msg.content { if !content.is_empty() { parts.push(AssistantContent::text(content)); } } if let Some(ref tool_calls) = msg.tool_calls { for tc in tool_calls { // GLM may return empty tool call IDs — fall back to a generated UUID. 
let id = if tc.id.is_empty() { Uuid::new_v4().to_string() } else { tc.id.clone() }; parts.push(AssistantContent::tool_call_with_call_id( &id, id.clone(), &tc.function.name, serde_json::from_str(&tc.function.arguments) .unwrap_or(serde_json::Value::Null), )); } } if parts.is_empty() { RigMessage::assistant("") } else if parts.len() == 1 { // Single part — use simpler constructors match parts.pop().unwrap() { AssistantContent::Text(t) => RigMessage::assistant(t.text), ac => RigMessage::Assistant { id: None, content: OneOrMany::one(ac), }, } } else { let content = OneOrMany::many(parts).expect("non-empty parts"); RigMessage::Assistant { id: None, content } } } "tool" | "function" => { let id = msg.tool_call_id.as_deref().unwrap_or("unknown").to_string(); let call_id = msg.tool_call_id.clone().or_else(|| Some(id.clone())); let content = msg.content.as_deref().unwrap_or(""); RigMessage::tool_result_with_call_id(id, call_id, content) } "developer" => { // Developer role maps to user/system in rig RigMessage::user(msg.content.as_deref().unwrap_or("")) } _ => RigMessage::user(msg.content.as_deref().unwrap_or("")), } } fn to_rig_tool_def(tool_json: &serde_json::Value) -> Option { let name = tool_json .get("function") .and_then(|f| f.get("name")) .and_then(|n| n.as_str())? 
.to_string(); let description = tool_json .get("function") .and_then(|f| f.get("description")) .and_then(|d| d.as_str()) .map(|s| s.to_string()) .unwrap_or_default(); let parameters = tool_json .get("function") .and_then(|f| f.get("parameters")) .cloned() .unwrap_or(serde_json::json!({})); Some(ToolDefinition { name, description, parameters, }) } // ── Call helpers ───────────────────────────────────────────────────────────── async fn do_completion( model: &M, messages: &[ChatRequestMessage], temperature: Option, max_tokens: Option, tools: Option<&[serde_json::Value]>, tool_choice: Option<&str>, ) -> Result<(String, u64, u64, Vec, Vec)> where M: CompletionModel, { let preamble = messages .iter() .find(|m| m.role == "system") .and_then(|m| m.content.as_deref()) .unwrap_or("") .to_string(); let non_system: Vec = messages .iter() .filter(|m| m.role != "system") .map(to_rig_message) .collect(); let tool_defs: Vec = tools .map(|ts| ts.iter().filter_map(to_rig_tool_def).collect()) .unwrap_or_default(); let mut builder = model.completion_request(""); if !preamble.is_empty() { builder = builder.preamble(preamble); } if !non_system.is_empty() { builder = builder.messages(non_system); } if let Some(t) = temperature { builder = builder.temperature(t); } if let Some(mt) = max_tokens { builder = builder.max_tokens(mt as u64); } if !tool_defs.is_empty() { builder = builder.tools(tool_defs); } // Only set tool_choice when explicitly provided (mirrors call_stream_once logic) if let Some(tc) = tool_choice { match tc { "none" => { builder = builder.tool_choice(rig::completion::message::ToolChoice::None); } "auto" => { builder = builder.tool_choice(rig::completion::message::ToolChoice::Auto); } s => { builder = builder.tool_choice(rig::completion::message::ToolChoice::Specific { function_names: vec![s.to_string()], }); } } } let response = builder .send() .await .map_err(|e| AgentError::OpenAi(e.to_string()))?; let mut content = String::new(); let mut tool_names: Vec = Vec::new(); 
let mut tool_calls: Vec = Vec::new(); for item in response.choice { match item { AssistantContent::Text(t) => { content.push_str(&t.text); } AssistantContent::ToolCall(tc) => { tool_names.push(tc.function.name.clone()); tool_calls.push(ClientToolCall { id: tc.id, type_: "function".into(), function: types::ToolCallFunction { name: tc.function.name, arguments: serde_json::to_string(&tc.function.arguments) .unwrap_or_else(|_| "{}".to_string()), }, }); } AssistantContent::Reasoning(_) => {} AssistantContent::Image(_) => {} } } let input_tokens = response.usage.input_tokens; let output_tokens = response.usage.output_tokens; Ok((content, input_tokens, output_tokens, tool_calls, tool_names)) } // ── Public API ─────────────────────────────────────────────────────────────── /// Call the AI model with automatic retry (no custom params). pub async fn call_with_retry( messages: &[ChatRequestMessage], model_name: &str, config: &AiClientConfig, max_retries: Option, ) -> Result { let client = config.build_rig_client(); let model = client.completion_model(model_name); let mut state = RetryState::new(max_retries.unwrap_or(3)); loop { let start = Instant::now(); let result = do_completion(&model, messages, None, None, None, None).await; match result { Ok((content, input_tokens, output_tokens, tool_calls, tool_names)) => { let latency_ms = start.elapsed().as_millis() as i64; let has_function_call = !tool_names.is_empty(); ai_metrics().record_success( input_tokens as i64, output_tokens as i64, has_function_call, ); return Ok(AiCallResponse { content, input_tokens: input_tokens as i64, output_tokens: output_tokens as i64, latency_ms, tool_calls, tool_calls_finished: tool_names, }); } Err(ref err) if state.should_retry() && is_retryable_error(err) => { let duration = state.backoff_duration(); tracing::warn!( attempt = state.attempt + 1, max_retries = state.max_retries, backoff_ms = duration.as_millis() as u64, model = %model_name, error = %err, "ai_call_retry" ); 
tokio::time::sleep(duration).await; state.next(); } Err(err) => { ai_metrics().record_failure(); return Err(err); } } } } /// Call with custom parameters (temperature, max_tokens, optional tools, optional tool_choice). pub async fn call_with_params( messages: &[ChatRequestMessage], model_name: &str, config: &AiClientConfig, temperature: f32, max_tokens: u32, max_retries: Option, tools: Option<&[serde_json::Value]>, tool_choice: Option<&str>, ) -> Result { let client = config.build_rig_client(); let model = client.completion_model(model_name); let mut state = RetryState::new(max_retries.unwrap_or(3)); loop { let start = Instant::now(); let result = do_completion( &model, messages, Some(temperature as f64), Some(max_tokens), tools, tool_choice, ) .await; match result { Ok((content, input_tokens, output_tokens, tool_calls, tool_names)) => { let latency_ms = start.elapsed().as_millis() as i64; let has_function_call = !tool_names.is_empty(); ai_metrics().record_success( input_tokens as i64, output_tokens as i64, has_function_call, ); return Ok(AiCallResponse { content, input_tokens: input_tokens as i64, output_tokens: output_tokens as i64, latency_ms, tool_calls, tool_calls_finished: tool_names, }); } Err(ref err) if state.should_retry() && is_retryable_error(err) => { let duration = state.backoff_duration(); tracing::warn!( attempt = state.attempt + 1, max_retries = state.max_retries, backoff_ms = duration.as_millis() as u64, model = %model_name, error = %err, "ai_call_retry" ); tokio::time::sleep(duration).await; state.next(); } Err(err) => { ai_metrics().record_failure(); return Err(err); } } } } /// A tool call extracted from streaming response with accumulated arguments. #[derive(Debug, Clone)] pub struct StreamedToolCall { /// Tool call ID pub id: String, /// Tool function name pub name: String, /// Accumulated JSON arguments string pub arguments: String, } /// Type of chunk in the streaming response, preserving arrival order. 
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamChunkType {
    /// Reasoning / thinking text.
    Thinking,
    /// Regular answer text.
    Answer,
    /// A completed tool call, serialized as JSON.
    ToolCall,
    ToolResult,
}

/// A single chunk from the streaming response in arrival order.
#[derive(Debug, Clone)]
pub struct StreamChunk {
    pub chunk_type: StreamChunkType,
    pub content: String,
}

/// Streaming result from rig.
#[derive(Debug)]
pub struct StreamResponse {
    /// Accumulated answer text.
    pub content: String,
    pub input_tokens: i64,
    pub output_tokens: i64,
    /// Accumulated reasoning/thinking text from the model.
    pub reasoning_content: String,
    /// Full tool calls with accumulated arguments (not just names)
    pub tool_calls: Vec<StreamedToolCall>,
    /// All chunks in arrival order — preserves think/answer/tool interleaving.
    pub chunks: Vec<StreamChunk>,
}

// NOTE(review): the generic arguments of the three aliases below were
// reconstructed from their call sites (`cb(&text).await;`) — confirm the
// argument and `Output` types against version control.

/// Async callback: takes a string delta and broadcasts it to the WebSocket.
/// The returned Future must be awaited by the caller.
pub type StreamTextCb =
    Arc<dyn Fn(&str) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync>;

/// Async callback invoked with each reasoning/thinking delta.
/// The returned Future must be awaited by the caller.
pub type StreamReasoningCb =
    Arc<dyn Fn(&str) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync>;

/// Async callback invoked with each completed tool call.
/// The returned Future must be awaited by the caller.
pub type StreamToolCallCb = Arc<
    dyn Fn(&StreamedToolCall) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>>
        + Send
        + Sync,
>;

/// Run a streaming chat completion with a 120s per-attempt timeout and up to 5 retries.
///
/// Each attempt is delegated to `call_stream_once`. When an attempt fails with
/// a retryable error and the retry budget is not exhausted, the call sleeps for
/// the back-off duration and starts a fresh attempt.
///
/// Note that a retried attempt starts the stream from scratch: deltas already
/// delivered to the callbacks by the failed attempt are not retracted, so
/// callbacks may observe the same text more than once.
///
/// # Errors
///
/// Returns the last error once it is non-retryable or the retry budget is
/// spent; a failure metric is recorded at that point.
pub async fn call_stream(
    messages: &[ChatRequestMessage],
    model_name: &str,
    config: &AiClientConfig,
    temperature: f32,
    max_tokens: u32,
    tools: Option<&[serde_json::Value]>,
    tool_choice: Option<&str>,
    on_text_delta: StreamTextCb,
    on_reasoning_delta: StreamReasoningCb,
    on_tool_call: StreamToolCallCb,
) -> Result<StreamResponse> {
    /// Retry budget for streaming calls; also reported in the retry log line.
    const MAX_STREAM_RETRIES: u32 = 5;

    let mut state = RetryState::new(MAX_STREAM_RETRIES);

    loop {
        let attempt = call_stream_once(
            messages,
            model_name,
            config,
            temperature,
            max_tokens,
            tools,
            tool_choice,
            Arc::clone(&on_text_delta),
            Arc::clone(&on_reasoning_delta),
            Arc::clone(&on_tool_call),
        )
        .await;

        match attempt {
            Ok(response) => return Ok(response),
            Err(err) if state.should_retry() && is_retryable_error(&err) => {
                let duration = state.backoff_duration();
                tracing::warn!(
                    attempt = state.attempt + 1,
                    max_retries = MAX_STREAM_RETRIES,
                    backoff_ms = duration.as_millis() as u64,
                    model = %model_name,
                    error = %err,
                    "ai_stream_retry"
                );
                tokio::time::sleep(duration).await;
                state.next();
            }
            Err(err) => {
                ai_metrics().record_failure();
                return Err(err);
            }
        }
    }
}

/// Single attempt of streaming completion with a 120s timeout.
async fn call_stream_once( messages: &[ChatRequestMessage], model_name: &str, config: &AiClientConfig, temperature: f32, max_tokens: u32, tools: Option<&[serde_json::Value]>, tool_choice: Option<&str>, on_text_delta: StreamTextCb, on_reasoning_delta: StreamReasoningCb, on_tool_call: StreamToolCallCb, ) -> Result { let client = config.build_rig_client(); let model = client.completion_model(model_name); let preamble = messages .iter() .find(|m| m.role == "system") .and_then(|m| m.content.as_deref()) .unwrap_or("") .to_string(); let non_system: Vec = messages .iter() .filter(|m| m.role != "system") .map(to_rig_message) .collect(); let tool_defs: Vec = tools .map(|ts| ts.iter().filter_map(to_rig_tool_def).collect()) .unwrap_or_default(); let mut builder = model .completion_request("") .temperature(temperature as f64) .max_tokens(max_tokens as u64); if !preamble.is_empty() { builder = builder.preamble(preamble); } if !non_system.is_empty() { builder = builder.messages(non_system); } if !tool_defs.is_empty() { builder = builder.tools(tool_defs); } if let Some(tc) = tool_choice { match tc { "none" => { builder = builder.tool_choice(rig::completion::message::ToolChoice::None); } "auto" => { builder = builder.tool_choice(rig::completion::message::ToolChoice::Auto); } s => { builder = builder.tool_choice(rig::completion::message::ToolChoice::Specific { function_names: vec![s.to_string()], }); } } } let stream_fut = async { let mut stream = builder .stream() .await .map_err(|e| AgentError::OpenAi(e.to_string()))?; let mut content = String::new(); let mut reasoning_content = String::new(); let mut tool_calls: Vec = Vec::new(); let mut chunks: Vec = Vec::new(); // Some models (e.g. GLM) ignore tool_choice="none" and still emit tool_calls. // Filter them out so they don't cause spurious tool execution attempts. 
let skip_tool_calls = tool_choice == Some("none"); use std::collections::HashMap; let mut partial_tool_calls: HashMap = HashMap::new(); let mut stream_finished = false; use rig::streaming::StreamedAssistantContent; while let Some(item) = stream.next().await { match item { Ok(StreamedAssistantContent::Text(text)) => { content.push_str(&text.text); on_text_delta(&text.text).await; chunks.push(StreamChunk { chunk_type: StreamChunkType::Answer, content: text.text, }); } Ok(StreamedAssistantContent::ToolCall { tool_call, internal_call_id, }) => { if skip_tool_calls { partial_tool_calls.remove(&internal_call_id); continue; } let arguments = match &tool_call.function.arguments { serde_json::Value::String(s) => s.clone(), other => serde_json::to_string(other).unwrap_or_else(|_| "{}".to_string()), }; let tc = StreamedToolCall { id: tool_call.id.clone(), name: tool_call.function.name.clone(), arguments, }; on_tool_call(&tc).await; chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: serde_json::json!({ "id": tc.id, "name": tc.name, "arguments": tc.arguments, }) .to_string(), }); tool_calls.push(tc); partial_tool_calls.remove(&internal_call_id); } Ok(StreamedAssistantContent::ToolCallDelta { id, internal_call_id, content: delta_content, }) => { if skip_tool_calls { continue; } use rig::streaming::ToolCallDeltaContent; match delta_content { ToolCallDeltaContent::Name(name) => { partial_tool_calls.insert( internal_call_id.clone(), StreamedToolCall { id: id.clone(), name, arguments: String::new(), }, ); } ToolCallDeltaContent::Delta(delta) => { if let Some(tc) = partial_tool_calls.get_mut(&internal_call_id) { tc.arguments.push_str(&delta); } } } } Ok(StreamedAssistantContent::Reasoning(reasoning)) => { for part in &reasoning.content { if let rig::completion::message::ReasoningContent::Text { text, .. 
} = part { reasoning_content.push_str(text); on_reasoning_delta(text).await; chunks.push(StreamChunk { chunk_type: StreamChunkType::Thinking, content: text.clone(), }); } } } Ok(StreamedAssistantContent::ReasoningDelta { reasoning, .. }) => { reasoning_content.push_str(&reasoning); on_reasoning_delta(&reasoning).await; chunks.push(StreamChunk { chunk_type: StreamChunkType::Thinking, content: reasoning.clone(), }); } Ok(StreamedAssistantContent::Final(response)) => { stream_finished = true; if !skip_tool_calls { for (_, tc) in partial_tool_calls.drain() { tool_calls.push(tc); } } else { partial_tool_calls.drain(); } if let Some(usage) = response.token_usage() { let in_toks = usage.input_tokens as i64; let out_toks = usage.output_tokens as i64; ai_metrics().record_success(in_toks, out_toks, !tool_calls.is_empty()); return Ok(StreamResponse { content, reasoning_content, input_tokens: in_toks, output_tokens: out_toks, tool_calls, chunks, }); } // Usage not available from Final — fall through to flush } Err(e) => return Err(AgentError::OpenAi(e.to_string())), } } // Flush any remaining partial tool calls (if stream ended without Final or Final had no usage) if !stream_finished && !skip_tool_calls { for (_, tc) in partial_tool_calls.drain() { tool_calls.push(tc); } } ai_metrics().record_success(0, 0, !tool_calls.is_empty()); Ok(StreamResponse { content, reasoning_content, input_tokens: 0, output_tokens: 0, tool_calls, chunks, }) }; // 120s timeout for the entire stream match tokio::time::timeout(std::time::Duration::from_secs(120), stream_fut).await { Ok(result) => result, Err(_) => Err(AgentError::Timeout { task_id: 0, seconds: 120, }), } }