//! Unified AI client with built-in retry, token tracking, and session recording. //! //! Provides a single entry point for all AI calls with: //! - Exponential backoff with jitter (max 3 retries) //! - Retryable error classification (429/500/502/503/504) //! - Token usage tracking (input/output) use async_openai::Client; use async_openai::config::OpenAIConfig; use async_openai::types::chat::{ ChatCompletionRequestMessage, ChatCompletionTool, ChatCompletionToolChoiceOption, ChatCompletionTools, CreateChatCompletionRequest, CreateChatCompletionResponse, }; use std::time::Instant; use crate::error::{AgentError, Result}; /// Configuration for the AI client. #[derive(Clone)] pub struct AiClientConfig { pub api_key: String, pub base_url: Option, } impl AiClientConfig { pub fn new(api_key: String) -> Self { Self { api_key, base_url: None, } } pub fn with_base_url(mut self, base_url: impl Into) -> Self { self.base_url = Some(base_url.into()); self } pub fn build_client(&self) -> Client { let mut config = OpenAIConfig::new().with_api_key(&self.api_key); if let Some(ref url) = self.base_url { config = config.with_api_base(url); } Client::with_config(config) } } /// Response from an AI call, including usage statistics. #[derive(Debug, Clone)] pub struct AiCallResponse { pub content: String, pub input_tokens: i64, pub output_tokens: i64, pub latency_ms: i64, } impl AiCallResponse { pub fn total_tokens(&self) -> i64 { self.input_tokens + self.output_tokens } } /// Internal state for retry tracking. #[derive(Debug)] struct RetryState { attempt: u32, max_retries: u32, max_backoff_ms: u64, } impl RetryState { fn new(max_retries: u32) -> Self { Self { attempt: 0, max_retries, max_backoff_ms: 5000, } } fn should_retry(&self) -> bool { self.attempt < self.max_retries } /// Calculate backoff duration with "full jitter" technique. 
fn backoff_duration(&self) -> std::time::Duration { let exp = self.attempt.min(5); // base = 500 * 2^exp, capped at max_backoff_ms let base_ms = 500u64 .saturating_mul(2u64.pow(exp)) .min(self.max_backoff_ms); // jitter: random [0, base_ms/2] let jitter = (fastrand_u64(base_ms / 2 + 1)) as u64; std::time::Duration::from_millis(base_ms / 2 + jitter) } fn next(&mut self) { self.attempt += 1; } } /// Fast pseudo-random u64 using a simple LCG. /// Good enough for jitter — not for cryptography. fn fastrand_u64(n: u64) -> u64 { use std::sync::atomic::{AtomicU64, Ordering}; static STATE: AtomicU64 = AtomicU64::new(0x193_667_6a_5e_7c_57); if n <= 1 { return 0; } let mut current = STATE.load(Ordering::Relaxed); loop { let new_val = current.wrapping_mul(6364136223846793005).wrapping_add(1); match STATE.compare_exchange_weak(current, new_val, Ordering::Relaxed, Ordering::Relaxed) { Ok(_) => return new_val % n, Err(actual) => current = actual, } } } /// Determine if an error is retryable. fn is_retryable_error(err: &async_openai::error::OpenAIError) -> bool { use async_openai::error::OpenAIError; match err { // Network errors (DNS failure, connection refused, timeout) are always retryable OpenAIError::Reqwest(_) => true, // For API errors, check the error code string (e.g., "rate_limit_exceeded") OpenAIError::ApiError(api_err) => api_err.code.as_ref().map_or(false, |code| { matches!( code.as_str(), "rate_limit_exceeded" | "internal_server_error" | "service_unavailable" | "gateway_timeout" | "bad_gateway" ) }), _ => false, } } /// Call the AI model with automatic retry. 
pub async fn call_with_retry(
    messages: &[ChatCompletionRequestMessage],
    model: &str,
    config: &AiClientConfig,
    max_retries: Option<u32>,
) -> Result<AiCallResponse> {
    // `max_retries` defaults to 3 when the caller passes `None`.
    let client = config.build_client();
    let mut state = RetryState::new(max_retries.unwrap_or(3));
    loop {
        let start = Instant::now();
        // Rebuilt every attempt because `create` consumes the request.
        let req = CreateChatCompletionRequest {
            model: model.to_string(),
            messages: messages.to_vec(),
            ..Default::default()
        };
        match client.chat().create(req).await {
            Ok(response) => {
                let latency_ms = start.elapsed().as_millis() as i64;
                let (input_tokens, output_tokens) = extract_usage(&response);
                return Ok(AiCallResponse {
                    content: extract_content(&response),
                    input_tokens,
                    output_tokens,
                    latency_ms,
                });
            }
            Err(err) => {
                // Retry only transient failures, and only while budget remains;
                // anything else is surfaced to the caller immediately.
                if state.should_retry() && is_retryable_error(&err) {
                    let duration = state.backoff_duration();
                    eprintln!(
                        "AI call failed (attempt {}/{}), retrying in {:?}",
                        state.attempt + 1,
                        state.max_retries,
                        duration
                    );
                    tokio::time::sleep(duration).await;
                    state.next();
                    continue;
                }
                return Err(AgentError::OpenAi(err.to_string()));
            }
        }
    }
}

/// Call with custom parameters (temperature, max_tokens, optional tools).
pub async fn call_with_params( messages: &[ChatCompletionRequestMessage], model: &str, config: &AiClientConfig, temperature: f32, max_tokens: u32, max_retries: Option, tools: Option<&[ChatCompletionTool]>, ) -> Result { let client = config.build_client(); let mut state = RetryState::new(max_retries.unwrap_or(3)); loop { let start = Instant::now(); let req = CreateChatCompletionRequest { model: model.to_string(), messages: messages.to_vec(), temperature: Some(temperature), max_completion_tokens: Some(max_tokens), tools: tools.map(|ts| { ts.iter() .map(|t| ChatCompletionTools::Function(t.clone())) .collect() }), tool_choice: tools.filter(|ts| !ts.is_empty()).map(|_| { ChatCompletionToolChoiceOption::Mode( async_openai::types::chat::ToolChoiceOptions::Auto, ) }), ..Default::default() }; let result = client.chat().create(req).await; match result { Ok(response) => { let latency_ms = start.elapsed().as_millis() as i64; let (input_tokens, output_tokens) = extract_usage(&response); return Ok(AiCallResponse { content: extract_content(&response), input_tokens, output_tokens, latency_ms, }); } Err(err) => { if state.should_retry() && is_retryable_error(&err) { let duration = state.backoff_duration(); eprintln!( "AI call failed (attempt {}/{}), retrying in {:?}", state.attempt + 1, state.max_retries, duration ); tokio::time::sleep(duration).await; state.next(); continue; } return Err(AgentError::OpenAi(err.to_string())); } } } } /// Extract text content from a chat completion response. fn extract_content(response: &CreateChatCompletionResponse) -> String { response .choices .first() .and_then(|c| c.message.content.clone()) .unwrap_or_default() } /// Extract usage (input/output tokens) from a response. fn extract_usage(response: &CreateChatCompletionResponse) -> (i64, i64) { response .usage .as_ref() .map(|u| { ( i64::try_from(u.prompt_tokens).unwrap_or(0), i64::try_from(u.completion_tokens).unwrap_or(0), ) }) .unwrap_or((0, 0)) }