gitdataai/libs/agent/client.rs
ZhenYi f7e087e066 fix(agent/service): retry jitter, tool executor ordering, curl SSRF, grep/JSON
- agent/client: full jitter backoff (random(0, base_ms)) instead of equal jitter
- agent/tool/executor: fix buffer_unordered ordering mismatch with
  HashMap-by-index approach for concurrent tool execution
- agent/chat: AiChunkType emit fixes, is_retryable_tool_error refinements,
  process_react uses request.max_tool_depth
- agent/chat/context: fix Function message sender_name field
- file_tools/curl: shared reqwest::Client via OnceLock, manual redirect
  following with per-hop SSRF validation, blocked sensitive headers
- file_tools/grep: fix case-insensitive glob matching, segment consumption
- file_tools/json: bracket notation support, remove .vscodeignore from JSONC
- git_tools: git_diff_stats resolve base/head independently,
  DiffFileOut old_file.path for Deleted, reflog offset_minutes
- git/repo: create_commit read parent tree into index, bare repo init
- project_tools/repos: branch/path validation, .git/ prefix check
- service/agent: tokent integration, billing, pr_summary, code_review fixes
2026-04-25 09:53:31 +08:00

344 lines
11 KiB
Rust

//! Unified AI client with built-in retry, token tracking, and session recording.
//!
//! Provides a single entry point for all AI calls with:
//! - Exponential backoff with full jitter (3 retries by default, configurable per call)
//! - Retryable error classification (transport failures plus rate-limit and
//!   5xx-style API error codes)
//! - Token usage tracking (input/output)
use async_openai::Client;
use async_openai::config::OpenAIConfig;
use async_openai::types::chat::{
ChatCompletionRequestMessage, ChatCompletionTool, ChatCompletionToolChoiceOption,
ChatCompletionTools, CreateChatCompletionRequest, CreateChatCompletionResponse,
};
use std::time::Instant;
use crate::error::{AgentError, Result};
/// Stateless handle that reports AI call statistics through `metrics` crate counters.
///
/// Per the original note on this type, the counters are registered in
/// `observability::install_recorder()` and exported via both the Prometheus
/// `/metrics` endpoint and the Redis metrics flusher.
#[derive(Debug, Clone, Default)]
pub struct AiMetrics;

impl AiMetrics {
    /// Create a metrics handle. The type carries no state, so this is free.
    pub fn new() -> Self {
        AiMetrics
    }

    /// Count one successful AI call together with its token usage.
    ///
    /// * `input_tokens` / `output_tokens` — added to the respective token
    ///   counters only when strictly positive; zero and negative values are
    ///   skipped.
    /// * `has_function_call` — when `true`, `ai_function_calls_total` is bumped.
    pub fn record_success(&self, input_tokens: i64, output_tokens: i64, has_function_call: bool) {
        metrics::counter!("ai_calls_total").increment(1);
        metrics::counter!("ai_calls_success").increment(1);

        // `try_from` rejects negatives; the filter additionally drops zero.
        if let Some(tokens) = u64::try_from(input_tokens).ok().filter(|t| *t > 0) {
            metrics::counter!("ai_input_tokens_total").increment(tokens);
        }
        if let Some(tokens) = u64::try_from(output_tokens).ok().filter(|t| *t > 0) {
            metrics::counter!("ai_output_tokens_total").increment(tokens);
        }

        if has_function_call {
            metrics::counter!("ai_function_calls_total").increment(1);
        }
    }

    /// Count one failed AI call.
    pub fn record_failure(&self) {
        metrics::counter!("ai_calls_total").increment(1);
        metrics::counter!("ai_calls_failure").increment(1);
    }
}
/// Connection settings used to build an OpenAI-compatible client.
#[derive(Clone)]
pub struct AiClientConfig {
    /// API key handed to the client configuration.
    pub api_key: String,
    /// Optional API base URL; when `None` the library default is used.
    pub base_url: Option<String>,
}

impl AiClientConfig {
    /// Create a configuration with the given API key and no base URL override.
    pub fn new(api_key: String) -> Self {
        let base_url = None;
        Self { api_key, base_url }
    }

    /// Builder-style setter: return the configuration with `base_url` set.
    pub fn with_base_url(self, base_url: impl Into<String>) -> Self {
        Self {
            base_url: Some(base_url.into()),
            ..self
        }
    }

    /// Construct a new `async_openai` client from these settings.
    ///
    /// A fresh client is built on every call; nothing is cached here.
    pub fn build_client(&self) -> Client<OpenAIConfig> {
        let keyed = OpenAIConfig::new().with_api_key(&self.api_key);
        let config = match &self.base_url {
            Some(url) => keyed.with_api_base(url),
            None => keyed,
        };
        Client::with_config(config)
    }
}
/// Outcome of a completed AI call: the reply text plus usage and timing figures.
#[derive(Debug, Clone)]
pub struct AiCallResponse {
    /// Text content of the model's reply.
    pub content: String,
    /// Number of input (prompt) tokens.
    pub input_tokens: i64,
    /// Number of output (completion) tokens.
    pub output_tokens: i64,
    /// Request latency in milliseconds.
    pub latency_ms: i64,
}

impl AiCallResponse {
    /// Sum of input and output tokens.
    ///
    /// The addition saturates at the `i64` bounds instead of overflowing, so
    /// extreme field values (the fields are public and caller-settable) cannot
    /// panic in debug builds or silently wrap in release builds.
    pub fn total_tokens(&self) -> i64 {
        self.input_tokens.saturating_add(self.output_tokens)
    }
}
/// Bookkeeping for a single retry loop: retries used so far and the limits in force.
#[derive(Debug)]
struct RetryState {
    /// Retries already performed.
    attempt: u32,
    /// Maximum number of retries permitted.
    max_retries: u32,
    /// Upper bound, in milliseconds, for any single backoff sleep.
    max_backoff_ms: u64,
}

impl RetryState {
    /// Start with zero attempts used, the given retry budget and a 5000 ms cap.
    fn new(max_retries: u32) -> Self {
        Self {
            max_backoff_ms: 5000,
            max_retries,
            attempt: 0,
        }
    }

    /// `true` while the retry budget has not been exhausted.
    fn should_retry(&self) -> bool {
        self.max_retries > self.attempt
    }

    /// Backoff before the next attempt, using the full-jitter scheme:
    /// `sleep = random(0, min(cap, 500ms * 2^attempt))`, bounds inclusive.
    fn backoff_duration(&self) -> std::time::Duration {
        // The exponent is clamped so the doubling cannot grow without bound.
        let doublings = self.attempt.min(5);
        let ceiling_ms = 500u64
            .saturating_mul(1u64 << doublings)
            .min(self.max_backoff_ms);
        // `fastrand_u64(n)` yields a value below `n`, hence the `+ 1` to make
        // the ceiling itself reachable.
        let sleep_ms = fastrand_u64(ceiling_ms + 1);
        std::time::Duration::from_millis(sleep_ms)
    }

    /// Mark one more retry as consumed.
    fn next(&mut self) {
        self.attempt += 1;
    }
}
/// Return a pseudo-random value in `[0, n)`; `n <= 1` always yields `0`.
///
/// Good enough for retry jitter — **not** for cryptography.
///
/// The generator is a 64-bit LCG whose state is shared process-wide, with two
/// properties that matter for jitter:
///
/// * The state is seeded lazily from [`std::collections::hash_map::RandomState`],
///   so separate processes do not replay the same jitter sequence (a fixed
///   seed would make every instance back off in lock-step).
/// * The raw LCG state is passed through the SplitMix64 finalizer before the
///   modulo. The low bits of a power-of-two-modulus LCG have very short
///   periods (bit 0 simply alternates), so reducing the raw state directly
///   gives visibly patterned results for small `n`.
///
/// The final `% n` carries the usual modulo bias, which is negligible for the
/// small bounds used for backoff.
fn fastrand_u64(n: u64) -> u64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::sync::OnceLock;
    use std::sync::atomic::{AtomicU64, Ordering};

    static STATE: OnceLock<AtomicU64> = OnceLock::new();

    if n <= 1 {
        return 0;
    }

    let state = STATE.get_or_init(|| {
        // `RandomState` carries per-process random keys; hashing nothing still
        // produces a key-dependent value, which is all the seed needs.
        let seed = RandomState::new().build_hasher().finish();
        AtomicU64::new(seed ^ 0x193_667_6a_5e_7c_57)
    });

    // Advance the shared LCG state. Relaxed ordering is sufficient: the value
    // only feeds jitter and publishes no other memory.
    let mut current = state.load(Ordering::Relaxed);
    let raw = loop {
        let next = current.wrapping_mul(6364136223846793005).wrapping_add(1);
        match state.compare_exchange_weak(current, next, Ordering::Relaxed, Ordering::Relaxed) {
            Ok(_) => break next,
            Err(actual) => current = actual,
        }
    };

    // SplitMix64 output mix: spreads the well-behaved high bits into the low ones.
    let mut z = raw;
    z = (z ^ (z >> 30)).wrapping_mul(0xbf58_476d_1ce4_e5b9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94d0_49bb_1331_11eb);
    z ^= z >> 31;

    z % n
}
/// Decide whether a failed request is worth retrying.
///
/// * Transport-level errors (`OpenAIError::Reqwest`) are always retried.
/// * API errors are retried only when they carry one of the known transient
///   error codes; an API error without a code is not retried.
/// * Every other error kind is treated as permanent.
fn is_retryable_error(err: &async_openai::error::OpenAIError) -> bool {
    use async_openai::error::OpenAIError;

    /// API error codes that indicate a transient condition.
    const RETRYABLE_CODES: [&str; 5] = [
        "rate_limit_exceeded",
        "internal_server_error",
        "service_unavailable",
        "gateway_timeout",
        "bad_gateway",
    ];

    match err {
        OpenAIError::Reqwest(_) => true,
        OpenAIError::ApiError(api_err) => match api_err.code.as_ref() {
            Some(code) => RETRYABLE_CODES.contains(&code.as_str()),
            None => false,
        },
        _ => false,
    }
}
/// Process-wide [`AiMetrics`] handle, created lazily on first access.
static AI_METRICS: std::sync::OnceLock<AiMetrics> = std::sync::OnceLock::new();

/// Borrow the shared [`AiMetrics`] handle, initialising it on the first call.
fn ai_metrics() -> &'static AiMetrics {
    // `AiMetrics` is a unit struct, so its derived `Default` builds the same
    // value as `AiMetrics::new`.
    AI_METRICS.get_or_init(AiMetrics::default)
}
/// Send a chat completion request, retrying transient failures automatically.
///
/// * `messages` — conversation sent to the model; cloned for each attempt.
/// * `model` — model identifier placed in the request.
/// * `config` — credentials / endpoint used to build the client.
/// * `max_retries` — retry budget; `None` means 3.
///
/// On success the shared metrics record one successful call and the response
/// carries the first choice's text, the reported token usage and the latency
/// of the attempt that succeeded (earlier failed attempts and backoff sleeps
/// are not included).
///
/// # Errors
///
/// Returns `AgentError::OpenAi` with the error's display text when the error
/// is not retryable or the retry budget is exhausted; one failed call is
/// recorded in the metrics in that case.
pub async fn call_with_retry(
    messages: &[ChatCompletionRequestMessage],
    model: &str,
    config: &AiClientConfig,
    max_retries: Option<u32>,
) -> Result<AiCallResponse> {
    use async_openai::types::chat::FinishReason;

    let client = config.build_client();
    let mut state = RetryState::new(max_retries.unwrap_or(3));

    loop {
        let start = Instant::now();
        let req = CreateChatCompletionRequest {
            model: model.to_string(),
            messages: messages.to_vec(),
            ..Default::default()
        };

        let err = match client.chat().create(req).await {
            Ok(response) => {
                let latency_ms = start.elapsed().as_millis() as i64;
                let (input_tokens, output_tokens) = extract_usage(&response);

                // A tool call is inferred from the first choice's finish reason.
                let has_function_call = response
                    .choices
                    .first()
                    .and_then(|choice| choice.finish_reason.as_ref())
                    .is_some_and(|reason| matches!(reason, FinishReason::ToolCalls));

                ai_metrics().record_success(input_tokens, output_tokens, has_function_call);

                let content = extract_content(&response);
                return Ok(AiCallResponse {
                    content,
                    input_tokens,
                    output_tokens,
                    latency_ms,
                });
            }
            Err(err) => err,
        };

        // Give up on permanent errors or once the budget is spent.
        if !(state.should_retry() && is_retryable_error(&err)) {
            ai_metrics().record_failure();
            return Err(AgentError::OpenAi(err.to_string()));
        }

        let duration = state.backoff_duration();
        tracing::warn!(
            attempt = state.attempt + 1,
            max_retries = state.max_retries,
            backoff_ms = duration.as_millis() as u64,
            model = %model,
            error = %err,
            "ai_call_retry"
        );
        tokio::time::sleep(duration).await;
        state.next();
    }
}
/// Call with custom parameters (temperature, max_tokens, optional tools, optional tool_choice).
///
/// * `messages` — conversation sent to the model; cloned for each attempt.
/// * `model` — model identifier placed in the request.
/// * `config` — credentials / endpoint used to build the client.
/// * `temperature` — sampling temperature, always sent.
/// * `max_tokens` — sent as `max_completion_tokens`.
/// * `max_retries` — retry budget; `None` means 3.
/// * `tools` — tool definitions offered to the model. `None` and an empty
///   slice are treated the same: no `tools` field is sent, because an empty
///   `tools` array is expected to be rejected by the API.
/// * `tool_choice` — forwarded only when tools are actually sent, since
///   `tool_choice` without `tools` is expected to be refused. When it is
///   `None` and tools are present, the default is `Auto`. Pass
///   `Some(ChatCompletionToolChoiceOption::None)` to force the model to
///   respond with text only (e.g. when you want JSON-in-text for ReAct
///   parsing).
///
/// On success the shared metrics record one successful call and the response
/// carries the first choice's text, the reported token usage and the latency
/// of the attempt that succeeded.
///
/// # Errors
///
/// Returns `AgentError::OpenAi` with the error's display text when the error
/// is not retryable or the retry budget is exhausted; one failed call is
/// recorded in the metrics in that case.
pub async fn call_with_params(
    messages: &[ChatCompletionRequestMessage],
    model: &str,
    config: &AiClientConfig,
    temperature: f32,
    max_tokens: u32,
    max_retries: Option<u32>,
    tools: Option<&[ChatCompletionTool]>,
    tool_choice: Option<ChatCompletionToolChoiceOption>,
) -> Result<AiCallResponse> {
    use async_openai::types::chat::FinishReason;

    let client = config.build_client();
    let mut state = RetryState::new(max_retries.unwrap_or(3));

    // Normalise once, outside the retry loop: an empty tool list means "no
    // tools", and a tool choice is only meaningful alongside tools.
    let tools = tools.filter(|ts| !ts.is_empty());
    let tool_choice = if tools.is_some() { tool_choice } else { None };

    loop {
        let start = Instant::now();
        let req = CreateChatCompletionRequest {
            model: model.to_string(),
            messages: messages.to_vec(),
            temperature: Some(temperature),
            max_completion_tokens: Some(max_tokens),
            tools: tools.map(|ts| {
                ts.iter()
                    .map(|t| ChatCompletionTools::Function(t.clone()))
                    .collect()
            }),
            tool_choice: tool_choice.clone(),
            ..Default::default()
        };

        let err = match client.chat().create(req).await {
            Ok(response) => {
                let latency_ms = start.elapsed().as_millis() as i64;
                let (input_tokens, output_tokens) = extract_usage(&response);

                // A tool call is inferred from the first choice's finish reason.
                let has_function_call = response
                    .choices
                    .first()
                    .and_then(|choice| choice.finish_reason.as_ref())
                    .is_some_and(|reason| matches!(reason, FinishReason::ToolCalls));

                ai_metrics().record_success(input_tokens, output_tokens, has_function_call);

                return Ok(AiCallResponse {
                    content: extract_content(&response),
                    input_tokens,
                    output_tokens,
                    latency_ms,
                });
            }
            Err(err) => err,
        };

        // Give up on permanent errors or once the budget is spent.
        if !(state.should_retry() && is_retryable_error(&err)) {
            ai_metrics().record_failure();
            return Err(AgentError::OpenAi(err.to_string()));
        }

        let duration = state.backoff_duration();
        tracing::warn!(
            attempt = state.attempt + 1,
            max_retries = state.max_retries,
            backoff_ms = duration.as_millis() as u64,
            model = %model,
            error = %err,
            "ai_call_retry"
        );
        tokio::time::sleep(duration).await;
        state.next();
    }
}
/// Pull the text of the first choice out of a chat completion response.
///
/// Returns an empty string when the response has no choices or the first
/// choice's message has no content.
fn extract_content(response: &CreateChatCompletionResponse) -> String {
    match response.choices.first() {
        Some(choice) => choice.message.content.clone().unwrap_or_default(),
        None => String::new(),
    }
}
/// Read `(input_tokens, output_tokens)` from a response's usage section.
///
/// Yields `(0, 0)` when the response carries no usage data; a count that does
/// not fit in `i64` is reported as `0`.
fn extract_usage(response: &CreateChatCompletionResponse) -> (i64, i64) {
    let Some(usage) = response.usage.as_ref() else {
        return (0, 0);
    };

    let prompt = i64::try_from(usage.prompt_tokens).unwrap_or(0);
    let completion = i64::try_from(usage.completion_tokens).unwrap_or(0);
    (prompt, completion)
}