gitdataai/libs/agent/client.rs
2026-04-14 19:02:01 +08:00

280 lines
8.5 KiB
Rust

//! Unified AI client with built-in retry, token tracking, and session recording.
//!
//! Provides a single entry point for all AI calls with:
//! - Exponential backoff with jitter (max 3 retries)
//! - Retryable error classification (429/500/502/503/504)
//! - Token usage tracking (input/output)
use async_openai::Client;
use async_openai::config::OpenAIConfig;
use async_openai::types::chat::{
ChatCompletionRequestMessage, ChatCompletionTool, ChatCompletionToolChoiceOption,
ChatCompletionTools, CreateChatCompletionRequest, CreateChatCompletionResponse,
};
use std::time::Instant;
use crate::error::{AgentError, Result};
/// Configuration for the AI client.
#[derive(Clone)]
pub struct AiClientConfig {
    pub api_key: String,
    pub base_url: Option<String>,
}
impl AiClientConfig {
    /// Create a configuration targeting the default API endpoint.
    pub fn new(api_key: String) -> Self {
        Self {
            api_key,
            base_url: None,
        }
    }
    /// Builder-style override for the API base URL (e.g. a proxy or
    /// OpenAI-compatible gateway).
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = Some(base_url.into());
        self
    }
    /// Construct an `async_openai` client from this configuration.
    pub fn build_client(&self) -> Client<OpenAIConfig> {
        let base = OpenAIConfig::new().with_api_key(&self.api_key);
        // Only apply the base-URL override when one was configured.
        let config = match self.base_url {
            Some(ref url) => base.with_api_base(url),
            None => base,
        };
        Client::with_config(config)
    }
}
/// Response from an AI call, including usage statistics.
#[derive(Debug, Clone)]
pub struct AiCallResponse {
    /// Text content of the model's first choice (empty if none).
    pub content: String,
    /// Prompt tokens consumed.
    pub input_tokens: i64,
    /// Completion tokens produced.
    pub output_tokens: i64,
    /// Wall-clock latency of the successful call, in milliseconds.
    pub latency_ms: i64,
}
impl AiCallResponse {
    /// Combined prompt + completion token count.
    pub fn total_tokens(&self) -> i64 {
        let AiCallResponse {
            input_tokens,
            output_tokens,
            ..
        } = self;
        input_tokens + output_tokens
    }
}
/// Internal state for retry tracking.
#[derive(Debug)]
struct RetryState {
    // Number of attempts already made (0-based).
    attempt: u32,
    // Maximum number of retries before giving up.
    max_retries: u32,
    // Hard cap on the exponential base, in milliseconds.
    max_backoff_ms: u64,
}
impl RetryState {
    fn new(max_retries: u32) -> Self {
        Self {
            attempt: 0,
            max_retries,
            max_backoff_ms: 5000,
        }
    }
    /// True while more retry attempts remain.
    fn should_retry(&self) -> bool {
        self.attempt < self.max_retries
    }
    /// Calculate backoff duration with the "equal jitter" technique:
    /// keep half of the exponential base, randomize the other half.
    /// (NOTE: this is *equal* jitter, not "full jitter" — full jitter
    /// would draw uniformly from the whole [0, base] range.)
    fn backoff_duration(&self) -> std::time::Duration {
        // Cap the exponent so 2^exp stays small; saturating_mul guards the rest.
        let exp = self.attempt.min(5);
        // base = 500 * 2^exp, capped at max_backoff_ms
        let base_ms = 500u64
            .saturating_mul(2u64.pow(exp))
            .min(self.max_backoff_ms);
        // jitter: random [0, base_ms/2] (fastrand_u64 already returns u64;
        // the previous `as u64` cast was redundant)
        let jitter = fastrand_u64(base_ms / 2 + 1);
        std::time::Duration::from_millis(base_ms / 2 + jitter)
    }
    /// Record that another attempt has been consumed.
    fn next(&mut self) {
        self.attempt += 1;
    }
}
/// Fast pseudo-random u64 in `[0, n)` using a simple LCG.
/// Good enough for jitter — not for cryptography.
///
/// Returns 0 when `n <= 1`. Thread-safe via an atomic CAS loop on the
/// shared generator state.
fn fastrand_u64(n: u64) -> u64 {
    use std::sync::atomic::{AtomicU64, Ordering};
    static STATE: AtomicU64 = AtomicU64::new(0x193_667_6a_5e_7c_57);
    if n <= 1 {
        return 0;
    }
    let mut current = STATE.load(Ordering::Relaxed);
    loop {
        let new_val = current.wrapping_mul(6364136223846793005).wrapping_add(1);
        match STATE.compare_exchange_weak(current, new_val, Ordering::Relaxed, Ordering::Relaxed) {
            // Use the HIGH bits for the result: in a power-of-two-modulus LCG,
            // low-order bit k has period only 2^k, so `new_val % n` would be
            // badly cyclic for small n (the common case here).
            Ok(_) => return (new_val >> 32) % n,
            Err(actual) => current = actual,
        }
    }
}
/// Determine if an error is retryable.
///
/// Transient transport failures are always retried; API errors are retried
/// only when their error-code string corresponds to rate limiting or a
/// server-side failure (the 429/500/502/503/504 family). Everything else
/// (auth failures, invalid requests, parse errors) is treated as permanent.
fn is_retryable_error(err: &async_openai::error::OpenAIError) -> bool {
    use async_openai::error::OpenAIError;
    match err {
        // Network errors (DNS failure, connection refused, timeout) are always retryable
        OpenAIError::Reqwest(_) => true,
        // For API errors, check the error code string (e.g., "rate_limit_exceeded").
        // `is_some_and` replaces the former `map_or(false, …)` (clippy idiom).
        OpenAIError::ApiError(api_err) => api_err.code.as_ref().is_some_and(|code| {
            matches!(
                code.as_str(),
                "rate_limit_exceeded"
                    | "internal_server_error"
                    | "service_unavailable"
                    | "gateway_timeout"
                    | "bad_gateway"
            )
        }),
        _ => false,
    }
}
/// Call the AI model with automatic retry.
///
/// Retries transient failures (see `is_retryable_error`) with jittered
/// exponential backoff; defaults to 3 retries when `max_retries` is `None`.
pub async fn call_with_retry(
    messages: &[ChatCompletionRequestMessage],
    model: &str,
    config: &AiClientConfig,
    max_retries: Option<u32>,
) -> Result<AiCallResponse> {
    let client = config.build_client();
    let mut retry = RetryState::new(max_retries.unwrap_or(3));
    loop {
        let started = Instant::now();
        let request = CreateChatCompletionRequest {
            model: model.to_string(),
            messages: messages.to_vec(),
            ..Default::default()
        };
        match client.chat().create(request).await {
            Ok(response) => {
                let latency_ms = started.elapsed().as_millis() as i64;
                let (input_tokens, output_tokens) = extract_usage(&response);
                return Ok(AiCallResponse {
                    content: extract_content(&response),
                    input_tokens,
                    output_tokens,
                    latency_ms,
                });
            }
            // Retryable failure with budget remaining: back off, then loop.
            Err(err) if retry.should_retry() && is_retryable_error(&err) => {
                let wait = retry.backoff_duration();
                eprintln!(
                    "AI call failed (attempt {}/{}), retrying in {:?}",
                    retry.attempt + 1,
                    retry.max_retries,
                    wait
                );
                tokio::time::sleep(wait).await;
                retry.next();
            }
            // Permanent failure or retry budget exhausted.
            Err(err) => return Err(AgentError::OpenAi(err.to_string())),
        }
    }
}
/// Call with custom parameters (temperature, max_tokens, optional tools).
///
/// Behaves like `call_with_retry` but additionally sets the sampling
/// temperature, a completion-token cap, and (optionally) tool definitions
/// with `tool_choice: auto`.
///
/// An empty tool slice is normalized to `None`: previously `Some(&[])`
/// still serialized an empty `tools` array into the request (while
/// `tool_choice` was correctly omitted), which the API rejects.
pub async fn call_with_params(
    messages: &[ChatCompletionRequestMessage],
    model: &str,
    config: &AiClientConfig,
    temperature: f32,
    max_tokens: u32,
    max_retries: Option<u32>,
    tools: Option<&[ChatCompletionTool]>,
) -> Result<AiCallResponse> {
    let client = config.build_client();
    let mut state = RetryState::new(max_retries.unwrap_or(3));
    // Normalize `Some(&[])` to `None` once, so `tools` and `tool_choice`
    // always agree in the request.
    let active_tools = tools.filter(|ts| !ts.is_empty());
    loop {
        let start = Instant::now();
        let req = CreateChatCompletionRequest {
            model: model.to_string(),
            messages: messages.to_vec(),
            temperature: Some(temperature),
            max_completion_tokens: Some(max_tokens),
            tools: active_tools.map(|ts| {
                ts.iter()
                    .map(|t| ChatCompletionTools::Function(t.clone()))
                    .collect()
            }),
            tool_choice: active_tools.map(|_| {
                ChatCompletionToolChoiceOption::Mode(
                    async_openai::types::chat::ToolChoiceOptions::Auto,
                )
            }),
            ..Default::default()
        };
        let result = client.chat().create(req).await;
        match result {
            Ok(response) => {
                let latency_ms = start.elapsed().as_millis() as i64;
                let (input_tokens, output_tokens) = extract_usage(&response);
                return Ok(AiCallResponse {
                    content: extract_content(&response),
                    input_tokens,
                    output_tokens,
                    latency_ms,
                });
            }
            Err(err) => {
                if state.should_retry() && is_retryable_error(&err) {
                    let duration = state.backoff_duration();
                    eprintln!(
                        "AI call failed (attempt {}/{}), retrying in {:?}",
                        state.attempt + 1,
                        state.max_retries,
                        duration
                    );
                    tokio::time::sleep(duration).await;
                    state.next();
                    continue;
                }
                return Err(AgentError::OpenAi(err.to_string()));
            }
        }
    }
}
/// Extract text content from a chat completion response.
///
/// Returns the first choice's message content, or an empty string when
/// there are no choices or the message has no text content.
fn extract_content(response: &CreateChatCompletionResponse) -> String {
    match response.choices.first() {
        Some(choice) => choice.message.content.clone().unwrap_or_default(),
        None => String::new(),
    }
}
/// Extract usage (input/output tokens) from a response.
///
/// Returns `(prompt_tokens, completion_tokens)` as `i64`, or `(0, 0)`
/// when the response carries no usage block. Values that do not fit in
/// an `i64` fall back to 0.
fn extract_usage(response: &CreateChatCompletionResponse) -> (i64, i64) {
    let Some(usage) = response.usage.as_ref() else {
        return (0, 0);
    };
    (
        i64::try_from(usage.prompt_tokens).unwrap_or(0),
        i64::try_from(usage.completion_tokens).unwrap_or(0),
    )
}