- agent/client: full jitter backoff (random(0, base_ms)) instead of equal jitter - agent/tool/executor: fix buffer_unordered ordering mismatch with HashMap-by-index approach for concurrent tool execution - agent/chat: AiChunkType emit fixes, is_retryable_tool_error refinements, process_react uses request.max_tool_depth - agent/chat/context: fix Function message sender_name field - file_tools/curl: shared reqwest::Client via OnceLock, manual redirect following with per-hop SSRF validation, blocked sensitive headers - file_tools/grep: fix case-insensitive glob matching, segment consumption - file_tools/json: bracket notation support, remove .vscodeignore from JSONC - git_tools: git_diff_stats resolve base/head independently, DiffFileOut old_file.path for Deleted, reflog offset_minutes - git/repo: create_commit read parent tree into index, bare repo init - project_tools/repos: branch/path validation, .git/ prefix check - service/agent: tokent integration, billing, pr_summary, code_review fixes
344 lines
11 KiB
Rust
344 lines
11 KiB
Rust
//! Unified AI client with built-in retry, token tracking, and session recording.
//!
//! Provides a single entry point for all AI calls with:
//! - Exponential backoff with full jitter (3 retries by default, configurable per call)
//! - Retryable error classification: network failures, plus API errors whose error
//!   code string names a rate-limit or server-side failure (429/500/502/503/504)
//! - Token usage tracking (input/output)
|
|
|
|
use async_openai::Client;
|
|
use async_openai::config::OpenAIConfig;
|
|
use async_openai::types::chat::{
|
|
ChatCompletionRequestMessage, ChatCompletionTool, ChatCompletionToolChoiceOption,
|
|
ChatCompletionTools, CreateChatCompletionRequest, CreateChatCompletionResponse,
|
|
};
|
|
use std::time::Instant;
|
|
|
|
use crate::error::{AgentError, Result};
|
|
|
|
/// AI call metrics — increments metrics crate counters for all AI calls.
|
|
/// These are registered in observability::install_recorder() and exported
|
|
/// via both the Prometheus /metrics endpoint and the Redis metrics flusher.
|
|
#[derive(Debug, Clone, Default)]
|
|
pub struct AiMetrics;
|
|
|
|
impl AiMetrics {
|
|
pub fn new() -> Self {
|
|
Self
|
|
}
|
|
|
|
/// Record a successful AI call with token usage.
|
|
pub fn record_success(&self, input_tokens: i64, output_tokens: i64, has_function_call: bool) {
|
|
metrics::counter!("ai_calls_total").increment(1);
|
|
metrics::counter!("ai_calls_success").increment(1);
|
|
if input_tokens > 0 {
|
|
metrics::counter!("ai_input_tokens_total").increment(input_tokens as u64);
|
|
}
|
|
if output_tokens > 0 {
|
|
metrics::counter!("ai_output_tokens_total").increment(output_tokens as u64);
|
|
}
|
|
if has_function_call {
|
|
metrics::counter!("ai_function_calls_total").increment(1);
|
|
}
|
|
}
|
|
|
|
/// Record a failed AI call.
|
|
pub fn record_failure(&self) {
|
|
metrics::counter!("ai_calls_total").increment(1);
|
|
metrics::counter!("ai_calls_failure").increment(1);
|
|
}
|
|
}
|
|
|
|
/// Configuration for the AI client.
|
|
#[derive(Clone)]
|
|
pub struct AiClientConfig {
|
|
pub api_key: String,
|
|
pub base_url: Option<String>,
|
|
}
|
|
|
|
impl AiClientConfig {
|
|
pub fn new(api_key: String) -> Self {
|
|
Self {
|
|
api_key,
|
|
base_url: None,
|
|
}
|
|
}
|
|
|
|
pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
|
|
self.base_url = Some(base_url.into());
|
|
self
|
|
}
|
|
|
|
pub fn build_client(&self) -> Client<OpenAIConfig> {
|
|
let mut config = OpenAIConfig::new().with_api_key(&self.api_key);
|
|
if let Some(ref url) = self.base_url {
|
|
config = config.with_api_base(url);
|
|
}
|
|
Client::with_config(config)
|
|
}
|
|
}
|
|
|
|
/// Response from an AI call, including usage statistics.
#[derive(Debug, Clone)]
pub struct AiCallResponse {
    /// Text content of the reply (empty when the response carried none).
    pub content: String,
    /// Prompt-side token count reported for the call.
    pub input_tokens: i64,
    /// Completion-side token count reported for the call.
    pub output_tokens: i64,
    /// Wall-clock latency of the call in milliseconds.
    pub latency_ms: i64,
}

impl AiCallResponse {
    /// Returns the sum of input and output tokens.
    ///
    /// The addition saturates at the `i64` bounds instead of panicking (debug
    /// builds) or wrapping (release builds) on implausibly large counts.
    #[must_use]
    pub fn total_tokens(&self) -> i64 {
        self.input_tokens.saturating_add(self.output_tokens)
    }
}
|
|
|
|
/// Internal state for retry tracking.
|
|
#[derive(Debug)]
|
|
struct RetryState {
|
|
attempt: u32,
|
|
max_retries: u32,
|
|
max_backoff_ms: u64,
|
|
}
|
|
|
|
impl RetryState {
|
|
fn new(max_retries: u32) -> Self {
|
|
Self {
|
|
attempt: 0,
|
|
max_retries,
|
|
max_backoff_ms: 5000,
|
|
}
|
|
}
|
|
|
|
fn should_retry(&self) -> bool {
|
|
self.attempt < self.max_retries
|
|
}
|
|
|
|
/// Calculate backoff duration with full jitter technique.
|
|
/// sleep = random(0, min(cap, base * 2^attempt))
|
|
fn backoff_duration(&self) -> std::time::Duration {
|
|
let exp = self.attempt.min(5);
|
|
// base = 500 * 2^exp, capped at max_backoff_ms
|
|
let base_ms = 500u64
|
|
.saturating_mul(2u64.pow(exp))
|
|
.min(self.max_backoff_ms);
|
|
// Full jitter: random value in [0, base_ms]
|
|
let jitter = fastrand_u64(base_ms + 1) as u64;
|
|
std::time::Duration::from_millis(jitter)
|
|
}
|
|
|
|
fn next(&mut self) {
|
|
self.attempt += 1;
|
|
}
|
|
}
|
|
|
|
/// Returns a pseudo-random value in `[0, n)`; returns `0` when `n <= 1`.
///
/// Good enough for jitter — not for cryptography.
///
/// A process-wide LCG state is advanced atomically on every call. Two details
/// matter for jitter quality:
/// - The state is seeded lazily from `RandomState`, so separate processes do
///   not replay the same sequence and end up retrying in lock-step.
/// - The raw LCG state is passed through a SplitMix64-style finaliser before
///   the modulo. The low bits of a power-of-two-modulus LCG have very short
///   periods (bit 0 simply alternates), so `state % n` alone is badly
///   distributed for small `n`.
///
/// The final `% n` has a negligible modulo bias for the small ranges used here.
fn fastrand_u64(n: u64) -> u64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::OnceLock;

    const LCG_MULTIPLIER: u64 = 6364136223846793005;
    const SEED_SALT: u64 = 0x193_667_6a_5e_7c_57;

    static STATE: OnceLock<AtomicU64> = OnceLock::new();

    if n <= 1 {
        return 0;
    }

    let state = STATE.get_or_init(|| {
        // RandomState carries per-process random keys; hashing nothing still
        // yields a key-dependent value, which serves as the seed.
        let entropy = RandomState::new().build_hasher().finish();
        AtomicU64::new(entropy ^ SEED_SALT)
    });

    let step = |s: u64| s.wrapping_mul(LCG_MULTIPLIER).wrapping_add(1);

    // The closure always returns `Some`, so `fetch_update` cannot fail; both
    // arms carry the previous state. Relaxed ordering suffices because the
    // value guards no other memory.
    let previous = match state.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |s| Some(step(s))) {
        Ok(prev) | Err(prev) => prev,
    };

    // SplitMix64 finaliser: spreads the high-quality upper bits over the word.
    let mut mixed = step(previous);
    mixed = (mixed ^ (mixed >> 30)).wrapping_mul(0xbf58_476d_1ce4_e5b9);
    mixed = (mixed ^ (mixed >> 27)).wrapping_mul(0x94d0_49bb_1331_11eb);
    mixed ^= mixed >> 31;

    mixed % n
}
|
|
|
|
/// Determine if an error is retryable.
|
|
fn is_retryable_error(err: &async_openai::error::OpenAIError) -> bool {
|
|
use async_openai::error::OpenAIError;
|
|
match err {
|
|
// Network errors (DNS failure, connection refused, timeout) are always retryable
|
|
OpenAIError::Reqwest(_) => true,
|
|
// For API errors, check the error code string (e.g., "rate_limit_exceeded")
|
|
OpenAIError::ApiError(api_err) => api_err.code.as_ref().map_or(false, |code| {
|
|
matches!(
|
|
code.as_str(),
|
|
"rate_limit_exceeded"
|
|
| "internal_server_error"
|
|
| "service_unavailable"
|
|
| "gateway_timeout"
|
|
| "bad_gateway"
|
|
)
|
|
}),
|
|
_ => false,
|
|
}
|
|
}
|
|
|
|
/// Global AI metrics shared across all AI client calls.
|
|
static AI_METRICS: std::sync::OnceLock<AiMetrics> = std::sync::OnceLock::new();
|
|
|
|
fn ai_metrics() -> &'static AiMetrics {
|
|
AI_METRICS.get_or_init(AiMetrics::new)
|
|
}
|
|
|
|
/// Call the AI model with automatic retry.
|
|
pub async fn call_with_retry(
|
|
messages: &[ChatCompletionRequestMessage],
|
|
model: &str,
|
|
config: &AiClientConfig,
|
|
max_retries: Option<u32>,
|
|
) -> Result<AiCallResponse> {
|
|
let client = config.build_client();
|
|
let mut state = RetryState::new(max_retries.unwrap_or(3));
|
|
|
|
loop {
|
|
let start = Instant::now();
|
|
|
|
let req = CreateChatCompletionRequest {
|
|
model: model.to_string(),
|
|
messages: messages.to_vec(),
|
|
..Default::default()
|
|
};
|
|
|
|
let result = client.chat().create(req).await;
|
|
|
|
match result {
|
|
Ok(response) => {
|
|
let latency_ms = start.elapsed().as_millis() as i64;
|
|
let (input_tokens, output_tokens) = extract_usage(&response);
|
|
|
|
// Check if response contains a tool call
|
|
let has_function_call = response
|
|
.choices
|
|
.first()
|
|
.and_then(|c| c.finish_reason.as_ref())
|
|
.map_or(false, |r| *r == async_openai::types::chat::FinishReason::ToolCalls);
|
|
ai_metrics().record_success(input_tokens, output_tokens, has_function_call);
|
|
|
|
return Ok(AiCallResponse {
|
|
content: extract_content(&response),
|
|
input_tokens,
|
|
output_tokens,
|
|
latency_ms,
|
|
});
|
|
}
|
|
Err(err) => {
|
|
if state.should_retry() && is_retryable_error(&err) {
|
|
let duration = state.backoff_duration();
|
|
tracing::warn!(
|
|
attempt = state.attempt + 1,
|
|
max_retries = state.max_retries,
|
|
backoff_ms = duration.as_millis() as u64,
|
|
model = %model,
|
|
error = %err.to_string(),
|
|
"ai_call_retry"
|
|
);
|
|
tokio::time::sleep(duration).await;
|
|
state.next();
|
|
continue;
|
|
}
|
|
ai_metrics().record_failure();
|
|
return Err(AgentError::OpenAi(err.to_string()));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Call with custom parameters (temperature, max_tokens, optional tools, optional tool_choice).
|
|
///
|
|
/// When `tool_choice` is `None` and tools are present, the default is `Auto`.
|
|
/// Pass `Some(ChatCompletionToolChoiceOption::None)` to force the model to respond
|
|
/// with text only (e.g. when you want JSON-in-text for ReAct parsing).
|
|
pub async fn call_with_params(
|
|
messages: &[ChatCompletionRequestMessage],
|
|
model: &str,
|
|
config: &AiClientConfig,
|
|
temperature: f32,
|
|
max_tokens: u32,
|
|
max_retries: Option<u32>,
|
|
tools: Option<&[ChatCompletionTool]>,
|
|
tool_choice: Option<ChatCompletionToolChoiceOption>,
|
|
) -> Result<AiCallResponse> {
|
|
let client = config.build_client();
|
|
let mut state = RetryState::new(max_retries.unwrap_or(3));
|
|
|
|
loop {
|
|
let start = Instant::now();
|
|
|
|
let req = CreateChatCompletionRequest {
|
|
model: model.to_string(),
|
|
messages: messages.to_vec(),
|
|
temperature: Some(temperature),
|
|
max_completion_tokens: Some(max_tokens),
|
|
tools: tools.map(|ts| {
|
|
ts.iter()
|
|
.map(|t| ChatCompletionTools::Function(t.clone()))
|
|
.collect()
|
|
}),
|
|
tool_choice: tool_choice.clone(),
|
|
..Default::default()
|
|
};
|
|
|
|
let result = client.chat().create(req).await;
|
|
|
|
match result {
|
|
Ok(response) => {
|
|
let latency_ms = start.elapsed().as_millis() as i64;
|
|
let (input_tokens, output_tokens) = extract_usage(&response);
|
|
|
|
// Check if response contains a tool call
|
|
let has_function_call = response
|
|
.choices
|
|
.first()
|
|
.and_then(|c| c.finish_reason.as_ref())
|
|
.map_or(false, |r| *r == async_openai::types::chat::FinishReason::ToolCalls);
|
|
ai_metrics().record_success(input_tokens, output_tokens, has_function_call);
|
|
|
|
return Ok(AiCallResponse {
|
|
content: extract_content(&response),
|
|
input_tokens,
|
|
output_tokens,
|
|
latency_ms,
|
|
});
|
|
}
|
|
Err(err) => {
|
|
if state.should_retry() && is_retryable_error(&err) {
|
|
let duration = state.backoff_duration();
|
|
tracing::warn!(
|
|
attempt = state.attempt + 1,
|
|
max_retries = state.max_retries,
|
|
backoff_ms = duration.as_millis() as u64,
|
|
model = %model,
|
|
error = %err.to_string(),
|
|
"ai_call_retry"
|
|
);
|
|
tokio::time::sleep(duration).await;
|
|
state.next();
|
|
continue;
|
|
}
|
|
ai_metrics().record_failure();
|
|
return Err(AgentError::OpenAi(err.to_string()));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Extract text content from a chat completion response.
|
|
fn extract_content(response: &CreateChatCompletionResponse) -> String {
|
|
response
|
|
.choices
|
|
.first()
|
|
.and_then(|c| c.message.content.clone())
|
|
.unwrap_or_default()
|
|
}
|
|
|
|
/// Extract usage (input/output tokens) from a response.
|
|
fn extract_usage(response: &CreateChatCompletionResponse) -> (i64, i64) {
|
|
response
|
|
.usage
|
|
.as_ref()
|
|
.map(|u| {
|
|
(
|
|
i64::try_from(u.prompt_tokens).unwrap_or(0),
|
|
i64::try_from(u.completion_tokens).unwrap_or(0),
|
|
)
|
|
})
|
|
.unwrap_or((0, 0))
|
|
}
|