//! Unified AI client with built-in retry, token tracking, and session recording.
//!
//! Provides a single entry point for all AI calls with:
//! - Exponential backoff with jitter (max 3 retries by default)
//! - Retryable error classification (rate limits, transient server and gateway errors)
//! - Token usage tracking (input/output)
use std::time::Instant;

use async_openai::Client;
use async_openai::config::OpenAIConfig;
use async_openai::types::chat::{
    ChatCompletionRequestMessage, ChatCompletionTool, ChatCompletionToolChoiceOption,
    ChatCompletionTools, CreateChatCompletionRequest, CreateChatCompletionResponse,
};

use crate::error::{AgentError, Result};
/// Configuration for the AI client.
#[derive(Clone)]
pub struct AiClientConfig {
    /// API key sent with every request.
    pub api_key: String,
    /// Optional override for the API base URL (e.g. a proxy or an
    /// OpenAI-compatible endpoint); `None` keeps the client default.
    pub base_url: Option<String>,
}
impl AiClientConfig {
|
|
pub fn new(api_key: String) -> Self {
|
|
Self {
|
|
api_key,
|
|
base_url: None,
|
|
}
|
|
}
|
|
|
|
pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
|
|
self.base_url = Some(base_url.into());
|
|
self
|
|
}
|
|
|
|
pub fn build_client(&self) -> Client<OpenAIConfig> {
|
|
let mut config = OpenAIConfig::new().with_api_key(&self.api_key);
|
|
if let Some(ref url) = self.base_url {
|
|
config = config.with_api_base(url);
|
|
}
|
|
Client::with_config(config)
|
|
}
|
|
}
|
|
|
|
/// Response from an AI call, including usage statistics.
#[derive(Debug, Clone)]
pub struct AiCallResponse {
    /// Text content returned by the model (empty string when the response
    /// carried no content).
    pub content: String,
    /// Prompt tokens as reported by the API's usage block (0 if absent).
    pub input_tokens: i64,
    /// Completion tokens as reported by the API's usage block (0 if absent).
    pub output_tokens: i64,
    /// Wall-clock latency of the successful request, in milliseconds.
    pub latency_ms: i64,
}

impl AiCallResponse {
    /// Total tokens consumed by the call (input + output).
    pub fn total_tokens(&self) -> i64 {
        self.input_tokens + self.output_tokens
    }
}
|
/// Internal state for retry tracking.
#[derive(Debug)]
struct RetryState {
    /// Number of retries consumed so far (0 before the first retry).
    attempt: u32,
    /// Maximum number of retries before giving up.
    max_retries: u32,
    /// Upper bound on the exponential backoff base, in milliseconds.
    max_backoff_ms: u64,
}
|
impl RetryState {
|
|
fn new(max_retries: u32) -> Self {
|
|
Self {
|
|
attempt: 0,
|
|
max_retries,
|
|
max_backoff_ms: 5000,
|
|
}
|
|
}
|
|
|
|
fn should_retry(&self) -> bool {
|
|
self.attempt < self.max_retries
|
|
}
|
|
|
|
/// Calculate backoff duration with "full jitter" technique.
|
|
fn backoff_duration(&self) -> std::time::Duration {
|
|
let exp = self.attempt.min(5);
|
|
// base = 500 * 2^exp, capped at max_backoff_ms
|
|
let base_ms = 500u64
|
|
.saturating_mul(2u64.pow(exp))
|
|
.min(self.max_backoff_ms);
|
|
// jitter: random [0, base_ms/2]
|
|
let jitter = (fastrand_u64(base_ms / 2 + 1)) as u64;
|
|
std::time::Duration::from_millis(base_ms / 2 + jitter)
|
|
}
|
|
|
|
fn next(&mut self) {
|
|
self.attempt += 1;
|
|
}
|
|
}
|
|
|
|
/// Fast pseudo-random value in `[0, n)` using a global LCG.
/// Good enough for jitter — not for cryptography.
///
/// Returns 0 when `n <= 1`. Thread-safe: the shared state is advanced with
/// a compare-exchange loop so concurrent callers never consume the same
/// state value twice.
fn fastrand_u64(n: u64) -> u64 {
    use std::sync::atomic::{AtomicU64, Ordering};
    static STATE: AtomicU64 = AtomicU64::new(0x193_667_6a_5e_7c_57);
    if n <= 1 {
        return 0;
    }
    let mut current = STATE.load(Ordering::Relaxed);
    loop {
        // LCG step with Knuth's MMIX multiplier.
        let new_val = current.wrapping_mul(6364136223846793005).wrapping_add(1);
        match STATE.compare_exchange_weak(current, new_val, Ordering::Relaxed, Ordering::Relaxed) {
            // Map to [0, n) with the 128-bit multiply-shift trick, which is
            // driven by the HIGH bits of the LCG output. The low-order bits
            // of a power-of-two-modulus LCG have tiny periods (bit 0 simply
            // alternates every step), so the previous `new_val % n` made
            // small or even `n` nearly deterministic instead of random.
            Ok(_) => return ((new_val as u128 * n as u128) >> 64) as u64,
            Err(actual) => current = actual,
        }
    }
}
|
/// Determine if an error is retryable.
|
|
fn is_retryable_error(err: &async_openai::error::OpenAIError) -> bool {
|
|
use async_openai::error::OpenAIError;
|
|
match err {
|
|
// Network errors (DNS failure, connection refused, timeout) are always retryable
|
|
OpenAIError::Reqwest(_) => true,
|
|
// For API errors, check the error code string (e.g., "rate_limit_exceeded")
|
|
OpenAIError::ApiError(api_err) => api_err.code.as_ref().map_or(false, |code| {
|
|
matches!(
|
|
code.as_str(),
|
|
"rate_limit_exceeded"
|
|
| "internal_server_error"
|
|
| "service_unavailable"
|
|
| "gateway_timeout"
|
|
| "bad_gateway"
|
|
)
|
|
}),
|
|
_ => false,
|
|
}
|
|
}
|
|
|
|
/// Call the AI model with automatic retry.
|
|
pub async fn call_with_retry(
|
|
messages: &[ChatCompletionRequestMessage],
|
|
model: &str,
|
|
config: &AiClientConfig,
|
|
max_retries: Option<u32>,
|
|
) -> Result<AiCallResponse> {
|
|
let client = config.build_client();
|
|
let mut state = RetryState::new(max_retries.unwrap_or(3));
|
|
|
|
loop {
|
|
let start = Instant::now();
|
|
|
|
let req = CreateChatCompletionRequest {
|
|
model: model.to_string(),
|
|
messages: messages.to_vec(),
|
|
..Default::default()
|
|
};
|
|
|
|
let result = client.chat().create(req).await;
|
|
|
|
match result {
|
|
Ok(response) => {
|
|
let latency_ms = start.elapsed().as_millis() as i64;
|
|
let (input_tokens, output_tokens) = extract_usage(&response);
|
|
|
|
return Ok(AiCallResponse {
|
|
content: extract_content(&response),
|
|
input_tokens,
|
|
output_tokens,
|
|
latency_ms,
|
|
});
|
|
}
|
|
Err(err) => {
|
|
if state.should_retry() && is_retryable_error(&err) {
|
|
let duration = state.backoff_duration();
|
|
eprintln!(
|
|
"AI call failed (attempt {}/{}), retrying in {:?}",
|
|
state.attempt + 1,
|
|
state.max_retries,
|
|
duration
|
|
);
|
|
tokio::time::sleep(duration).await;
|
|
state.next();
|
|
continue;
|
|
}
|
|
return Err(AgentError::OpenAi(err.to_string()));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Call with custom parameters (temperature, max_tokens, optional tools).
|
|
pub async fn call_with_params(
|
|
messages: &[ChatCompletionRequestMessage],
|
|
model: &str,
|
|
config: &AiClientConfig,
|
|
temperature: f32,
|
|
max_tokens: u32,
|
|
max_retries: Option<u32>,
|
|
tools: Option<&[ChatCompletionTool]>,
|
|
) -> Result<AiCallResponse> {
|
|
let client = config.build_client();
|
|
let mut state = RetryState::new(max_retries.unwrap_or(3));
|
|
|
|
loop {
|
|
let start = Instant::now();
|
|
|
|
let req = CreateChatCompletionRequest {
|
|
model: model.to_string(),
|
|
messages: messages.to_vec(),
|
|
temperature: Some(temperature),
|
|
max_completion_tokens: Some(max_tokens),
|
|
tools: tools.map(|ts| {
|
|
ts.iter()
|
|
.map(|t| ChatCompletionTools::Function(t.clone()))
|
|
.collect()
|
|
}),
|
|
tool_choice: tools.filter(|ts| !ts.is_empty()).map(|_| {
|
|
ChatCompletionToolChoiceOption::Mode(
|
|
async_openai::types::chat::ToolChoiceOptions::Auto,
|
|
)
|
|
}),
|
|
..Default::default()
|
|
};
|
|
|
|
let result = client.chat().create(req).await;
|
|
|
|
match result {
|
|
Ok(response) => {
|
|
let latency_ms = start.elapsed().as_millis() as i64;
|
|
let (input_tokens, output_tokens) = extract_usage(&response);
|
|
|
|
return Ok(AiCallResponse {
|
|
content: extract_content(&response),
|
|
input_tokens,
|
|
output_tokens,
|
|
latency_ms,
|
|
});
|
|
}
|
|
Err(err) => {
|
|
if state.should_retry() && is_retryable_error(&err) {
|
|
let duration = state.backoff_duration();
|
|
eprintln!(
|
|
"AI call failed (attempt {}/{}), retrying in {:?}",
|
|
state.attempt + 1,
|
|
state.max_retries,
|
|
duration
|
|
);
|
|
tokio::time::sleep(duration).await;
|
|
state.next();
|
|
continue;
|
|
}
|
|
return Err(AgentError::OpenAi(err.to_string()));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Extract text content from a chat completion response.
|
|
fn extract_content(response: &CreateChatCompletionResponse) -> String {
|
|
response
|
|
.choices
|
|
.first()
|
|
.and_then(|c| c.message.content.clone())
|
|
.unwrap_or_default()
|
|
}
|
|
|
|
/// Extract usage (input/output tokens) from a response.
|
|
fn extract_usage(response: &CreateChatCompletionResponse) -> (i64, i64) {
|
|
response
|
|
.usage
|
|
.as_ref()
|
|
.map(|u| {
|
|
(
|
|
i64::try_from(u.prompt_tokens).unwrap_or(0),
|
|
i64::try_from(u.completion_tokens).unwrap_or(0),
|
|
)
|
|
})
|
|
.unwrap_or((0, 0))
|
|
}
|