//! Unified AI client with built-in retry, token tracking, and session recording. //! //! Provides a single entry point for all AI calls with: //! - Exponential backoff with jitter (max 3 retries) //! - Retryable error classification (429/500/502/503/504) //! - Token usage tracking (input/output) use async_openai::Client; use async_openai::config::OpenAIConfig; use async_openai::types::chat::{ ChatCompletionRequestMessage, ChatCompletionTool, ChatCompletionToolChoiceOption, ChatCompletionTools, CreateChatCompletionRequest, CreateChatCompletionResponse, }; use std::time::Instant; use crate::error::{AgentError, Result}; /// AI call metrics — increments metrics crate counters for all AI calls. /// These are registered in observability::install_recorder() and exported /// via both the Prometheus /metrics endpoint and the Redis metrics flusher. #[derive(Debug, Clone, Default)] pub struct AiMetrics; impl AiMetrics { pub fn new() -> Self { Self } /// Record a successful AI call with token usage. pub fn record_success(&self, input_tokens: i64, output_tokens: i64, has_function_call: bool) { metrics::counter!("ai_calls_total").increment(1); metrics::counter!("ai_calls_success").increment(1); if input_tokens > 0 { metrics::counter!("ai_input_tokens_total").increment(input_tokens as u64); } if output_tokens > 0 { metrics::counter!("ai_output_tokens_total").increment(output_tokens as u64); } if has_function_call { metrics::counter!("ai_function_calls_total").increment(1); } } /// Record a failed AI call. pub fn record_failure(&self) { metrics::counter!("ai_calls_total").increment(1); metrics::counter!("ai_calls_failure").increment(1); } } /// Configuration for the AI client. 
/// Holds the API key and an optional base-URL override; `build_client` turns it into a client.
// NOTE(review): `Debug` is not derived — presumably to keep `api_key` out of logs; confirm.
#[derive(Clone)]
pub struct AiClientConfig {
    /// API key passed to the OpenAI configuration.
    pub api_key: String,
    /// Optional API base URL; when `None` the library default is used.
    pub base_url: Option<String>,
}

impl AiClientConfig {
    /// Create a configuration with the given API key and no base-URL override.
    pub fn new(api_key: String) -> Self {
        Self {
            api_key,
            base_url: None,
        }
    }

    /// Builder-style setter for the API base URL.
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = Some(base_url.into());
        self
    }

    /// Build a new client from this configuration.
    ///
    /// A fresh client is constructed on every call; nothing is cached here.
    pub fn build_client(&self) -> Client<OpenAIConfig> {
        let mut config = OpenAIConfig::new().with_api_key(&self.api_key);
        if let Some(ref url) = self.base_url {
            config = config.with_api_base(url);
        }
        Client::with_config(config)
    }
}

/// Response from an AI call, including usage statistics.
#[derive(Debug, Clone)]
pub struct AiCallResponse {
    /// Text content of the first choice (empty when absent).
    pub content: String,
    /// Prompt tokens reported by the API (0 when usage is missing).
    pub input_tokens: i64,
    /// Completion tokens reported by the API (0 when usage is missing).
    pub output_tokens: i64,
    /// Wall-clock latency of the successful attempt, in milliseconds.
    pub latency_ms: i64,
}

impl AiCallResponse {
    /// Sum of input and output tokens (saturating, so it can never overflow-panic).
    pub fn total_tokens(&self) -> i64 {
        self.input_tokens.saturating_add(self.output_tokens)
    }
}

/// Internal state for retry tracking.
#[derive(Debug)]
struct RetryState {
    /// Number of retries already performed.
    attempt: u32,
    /// Maximum number of retries allowed.
    max_retries: u32,
    /// Upper bound for a single backoff sleep, in milliseconds.
    max_backoff_ms: u64,
}

impl RetryState {
    /// Start with zero attempts and a 5000 ms backoff cap.
    fn new(max_retries: u32) -> Self {
        Self {
            attempt: 0,
            max_retries,
            max_backoff_ms: 5000,
        }
    }

    /// `true` while the retry budget has not been exhausted.
    fn should_retry(&self) -> bool {
        self.attempt < self.max_retries
    }

    /// Calculate backoff duration with full jitter technique.
    /// sleep = random(0, min(cap, base * 2^attempt))
    fn backoff_duration(&self) -> std::time::Duration {
        const BASE_MS: u64 = 500;
        // Clamp the exponent so the shift below can never overflow.
        const MAX_EXPONENT: u32 = 5;

        let exp = self.attempt.min(MAX_EXPONENT);
        // ceiling = 500 * 2^exp, capped at max_backoff_ms
        let ceiling_ms = BASE_MS
            .saturating_mul(1u64 << exp)
            .min(self.max_backoff_ms);
        // Full jitter: random value in [0, ceiling_ms]
        let jitter_ms = fastrand_u64(ceiling_ms.saturating_add(1));
        std::time::Duration::from_millis(jitter_ms)
    }

    /// Advance to the next attempt.
    fn next(&mut self) {
        self.attempt = self.attempt.saturating_add(1);
    }
}

/// Fast pseudo-random u64 using a simple LCG.
/// Good enough for jitter — not for cryptography.
///
/// Returns a value in `[0, n)`; returns 0 when `n <= 1`.
///
/// The generator state is seeded once per process from `RandomState`, so separate
/// processes do not all produce the same jitter sequence. The raw LCG state is passed
/// through a SplitMix64-style finalizer before the modulo, because the low bits of a
/// power-of-two-modulus LCG have very short periods (bit 0 simply alternates).
fn fastrand_u64(n: u64) -> u64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::OnceLock;

    const LCG_MULTIPLIER: u64 = 6364136223846793005;

    static STATE: OnceLock<AtomicU64> = OnceLock::new();

    if n <= 1 {
        return 0;
    }

    // Hashing nothing with a randomly keyed hasher yields a per-process seed using std only.
    let state =
        STATE.get_or_init(|| AtomicU64::new(RandomState::new().build_hasher().finish()));

    // Relaxed ordering is sufficient: the counter guards no other memory, and the CAS
    // loop only needs each caller to observe a distinct state transition.
    let mut current = state.load(Ordering::Relaxed);
    let advanced = loop {
        let candidate = current.wrapping_mul(LCG_MULTIPLIER).wrapping_add(1);
        match state.compare_exchange_weak(
            current,
            candidate,
            Ordering::Relaxed,
            Ordering::Relaxed,
        ) {
            Ok(_) => break candidate,
            Err(actual) => current = actual,
        }
    };

    // Output mixing so that every bit of the result depends on the whole state.
    let mut mixed = advanced;
    mixed ^= mixed >> 30;
    mixed = mixed.wrapping_mul(0xbf58_476d_1ce4_e5b9);
    mixed ^= mixed >> 27;
    mixed = mixed.wrapping_mul(0x94d0_49bb_1331_11eb);
    mixed ^= mixed >> 31;

    // The modulo bias is negligible for jitter purposes.
    mixed % n
}

/// Determine if an error is retryable.
///
/// Transport-level (`Reqwest`) errors are always retried. API errors are retried only
/// when their `code` string is one of the known transient codes; an API error without
/// a `code` is never retried. Every other error variant is treated as permanent.
// NOTE(review): classification is by `code` string, not by HTTP status — confirm the
// upstream API actually populates `code` for 5xx responses.
fn is_retryable_error(err: &async_openai::error::OpenAIError) -> bool {
    use async_openai::error::OpenAIError;

    match err {
        // Network errors (DNS failure, connection refused, timeout) are always retryable
        OpenAIError::Reqwest(_) => true,
        // For API errors, check the error code string (e.g., "rate_limit_exceeded")
        OpenAIError::ApiError(api_err) => matches!(
            api_err.code.as_deref(),
            Some(
                "rate_limit_exceeded"
                    | "internal_server_error"
                    | "service_unavailable"
                    | "gateway_timeout"
                    | "bad_gateway"
            )
        ),
        _ => false,
    }
}

/// Global AI metrics shared across all AI client calls.
static AI_METRICS: std::sync::OnceLock<AiMetrics> = std::sync::OnceLock::new();

/// Lazily initialise and return the process-wide [`AiMetrics`] handle.
fn ai_metrics() -> &'static AiMetrics {
    AI_METRICS.get_or_init(AiMetrics::new)
}

/// Call the AI model with automatic retry.
pub async fn call_with_retry( messages: &[ChatCompletionRequestMessage], model: &str, config: &AiClientConfig, max_retries: Option, ) -> Result { let client = config.build_client(); let mut state = RetryState::new(max_retries.unwrap_or(3)); loop { let start = Instant::now(); let req = CreateChatCompletionRequest { model: model.to_string(), messages: messages.to_vec(), ..Default::default() }; let result = client.chat().create(req).await; match result { Ok(response) => { let latency_ms = start.elapsed().as_millis() as i64; let (input_tokens, output_tokens) = extract_usage(&response); // Check if response contains a tool call let has_function_call = response .choices .first() .and_then(|c| c.finish_reason.as_ref()) .map_or(false, |r| *r == async_openai::types::chat::FinishReason::ToolCalls); ai_metrics().record_success(input_tokens, output_tokens, has_function_call); return Ok(AiCallResponse { content: extract_content(&response), input_tokens, output_tokens, latency_ms, }); } Err(err) => { if state.should_retry() && is_retryable_error(&err) { let duration = state.backoff_duration(); tracing::warn!( attempt = state.attempt + 1, max_retries = state.max_retries, backoff_ms = duration.as_millis() as u64, model = %model, error = %err.to_string(), "ai_call_retry" ); tokio::time::sleep(duration).await; state.next(); continue; } ai_metrics().record_failure(); return Err(AgentError::OpenAi(err.to_string())); } } } } /// Call with custom parameters (temperature, max_tokens, optional tools, optional tool_choice). /// /// When `tool_choice` is `None` and tools are present, the default is `Auto`. /// Pass `Some(ChatCompletionToolChoiceOption::None)` to force the model to respond /// with text only (e.g. when you want JSON-in-text for ReAct parsing). 
pub async fn call_with_params( messages: &[ChatCompletionRequestMessage], model: &str, config: &AiClientConfig, temperature: f32, max_tokens: u32, max_retries: Option, tools: Option<&[ChatCompletionTool]>, tool_choice: Option, ) -> Result { let client = config.build_client(); let mut state = RetryState::new(max_retries.unwrap_or(3)); loop { let start = Instant::now(); let req = CreateChatCompletionRequest { model: model.to_string(), messages: messages.to_vec(), temperature: Some(temperature), max_completion_tokens: Some(max_tokens), tools: tools.map(|ts| { ts.iter() .map(|t| ChatCompletionTools::Function(t.clone())) .collect() }), tool_choice: tool_choice.clone(), ..Default::default() }; let result = client.chat().create(req).await; match result { Ok(response) => { let latency_ms = start.elapsed().as_millis() as i64; let (input_tokens, output_tokens) = extract_usage(&response); // Check if response contains a tool call let has_function_call = response .choices .first() .and_then(|c| c.finish_reason.as_ref()) .map_or(false, |r| *r == async_openai::types::chat::FinishReason::ToolCalls); ai_metrics().record_success(input_tokens, output_tokens, has_function_call); return Ok(AiCallResponse { content: extract_content(&response), input_tokens, output_tokens, latency_ms, }); } Err(err) => { if state.should_retry() && is_retryable_error(&err) { let duration = state.backoff_duration(); tracing::warn!( attempt = state.attempt + 1, max_retries = state.max_retries, backoff_ms = duration.as_millis() as u64, model = %model, error = %err.to_string(), "ai_call_retry" ); tokio::time::sleep(duration).await; state.next(); continue; } ai_metrics().record_failure(); return Err(AgentError::OpenAi(err.to_string())); } } } } /// Extract text content from a chat completion response. 
fn extract_content(response: &CreateChatCompletionResponse) -> String { response .choices .first() .and_then(|c| c.message.content.clone()) .unwrap_or_default() } /// Extract usage (input/output tokens) from a response. fn extract_usage(response: &CreateChatCompletionResponse) -> (i64, i64) { response .usage .as_ref() .map(|u| { ( i64::try_from(u.prompt_tokens).unwrap_or(0), i64::try_from(u.completion_tokens).unwrap_or(0), ) }) .unwrap_or((0, 0)) }