gitdataai/libs/agent/client/mod.rs
ZhenYi c6bb72682b fix(agent): 修复扣费链路并实现级联扣费策略
- billing.rs: 修复参数传递 (model_id -> version_id)
- billing.rs: 新增 BillingResult 枚举支持 InsufficientBalance 错误
- billing.rs: 实现级联扣费 (优先 project 余额,不足时 fallback 到 workspace)
- billing.rs: 余额不足时创建系统消息并持久化
- chat/service.rs: 捕获 InsufficientBalance 错误并调用 create_system_message
- client/mod.rs: 超时时间从 60s 改为 120s
2026-04-28 19:59:06 +08:00

769 lines
27 KiB
Rust

//! Unified AI client with built-in retry, token tracking, and session recording.
//!
//! Uses rig-core as the underlying AI provider library.
pub mod types;
pub use types::{ChatRequestMessage, ToolCall as ClientToolCall};
use std::pin::Pin;
use std::sync::Arc;
use std::time::Instant;
use uuid::Uuid;
use crate::error::{AgentError, Result};
use futures::StreamExt;
use rig::completion::message::{AssistantContent, Message as RigMessage};
use rig::completion::{GetTokenUsage, ToolDefinition, CompletionModel};
use rig::one_or_many::OneOrMany;
use rig::prelude::CompletionClient;
use rig::providers::openai;
/// AI call metrics — increments `metrics` crate counters for all AI calls.
///
/// The type carries no state: every method writes straight to the globally installed
/// `metrics` recorder, so any instance behaves exactly like any other.
#[derive(Debug, Clone, Default)]
pub struct AiMetrics;

impl AiMetrics {
    /// Creates a metrics handle (equivalent to `AiMetrics::default()`).
    pub fn new() -> Self {
        AiMetrics
    }

    /// Records one successful AI call.
    ///
    /// - Bumps `ai_calls_total` and `ai_calls_success`.
    /// - Adds the token counts to `ai_input_tokens_total` / `ai_output_tokens_total`, but only
    ///   when they are positive (callers pass 0 when usage is unavailable).
    /// - Bumps `ai_function_calls_total` when the response requested a tool call.
    pub fn record_success(&self, input_tokens: i64, output_tokens: i64, has_function_call: bool) {
        Self::count_call();
        metrics::counter!("ai_calls_success").increment(1);
        if let Some(tokens) = u64::try_from(input_tokens).ok().filter(|&t| t > 0) {
            metrics::counter!("ai_input_tokens_total").increment(tokens);
        }
        if let Some(tokens) = u64::try_from(output_tokens).ok().filter(|&t| t > 0) {
            metrics::counter!("ai_output_tokens_total").increment(tokens);
        }
        if has_function_call {
            metrics::counter!("ai_function_calls_total").increment(1);
        }
    }

    /// Records one failed AI call: bumps `ai_calls_total` and `ai_calls_failure`.
    pub fn record_failure(&self) {
        Self::count_call();
        metrics::counter!("ai_calls_failure").increment(1);
    }

    /// Bumps `ai_calls_total`, which both outcomes contribute to.
    fn count_call() {
        metrics::counter!("ai_calls_total").increment(1);
    }
}
/// Configuration for the AI client: an API key plus an optional OpenAI-compatible endpoint.
///
/// NOTE(review): there is no `Debug` derive — presumably to keep `api_key` out of logs; confirm
/// before adding one.
#[derive(Clone)]
pub struct AiClientConfig {
    /// API key passed to the provider.
    pub api_key: String,
    /// Base URL of an OpenAI-compatible API, including any version segment
    /// (e.g. `https://api.openai.com/v1`). `None` selects the public OpenAI API.
    pub base_url: Option<String>,
}

impl AiClientConfig {
    /// Endpoint used when no `base_url` is configured. It includes `/v1` to match rig's own
    /// OpenAI default: rig appends request paths such as `/chat/completions` directly to the
    /// base URL and does not insert a version segment itself.
    const DEFAULT_BASE_URL: &'static str = "https://api.openai.com/v1";

    /// Creates a config for `api_key` that targets the public OpenAI API.
    pub fn new(api_key: String) -> Self {
        Self { api_key, base_url: None }
    }

    /// Returns this config pointed at `base_url` instead of the default endpoint.
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = Some(base_url.into());
        self
    }

    /// Build a rig OpenAI client from this config.
    ///
    /// # Panics
    /// Panics with "Failed to build rig OpenAI client" if rig's client builder returns an error.
    pub fn build_rig_client(&self) -> openai::Client {
        let base = self
            .base_url
            .clone()
            .unwrap_or_else(|| Self::DEFAULT_BASE_URL.to_string());
        openai::Client::builder()
            .api_key(&self.api_key)
            .base_url(&base)
            .build()
            .expect("Failed to build rig OpenAI client")
    }
}
/// Outcome of a non-streaming AI call, together with its usage statistics.
#[derive(Debug, Clone)]
pub struct AiCallResponse {
    /// Concatenated text returned by the model.
    pub content: String,
    /// Prompt tokens reported by the provider.
    pub input_tokens: i64,
    /// Completion tokens reported by the provider.
    pub output_tokens: i64,
    /// Wall-clock duration of the successful attempt, in milliseconds.
    pub latency_ms: i64,
    /// Names of the tools the model asked to call.
    pub tool_calls_finished: Vec<String>,
}

impl AiCallResponse {
    /// Prompt plus completion tokens.
    pub fn total_tokens(&self) -> i64 {
        let Self { input_tokens, output_tokens, .. } = self;
        *input_tokens + *output_tokens
    }
}
/// Bookkeeping for a retry loop: retries performed so far and the backoff policy.
#[derive(Debug)]
struct RetryState {
    /// Retries already performed (0 before the first retry).
    attempt: u32,
    /// Maximum number of retries allowed.
    max_retries: u32,
    /// Upper bound for a single backoff, in milliseconds.
    max_backoff_ms: u64,
}

impl RetryState {
    /// Starts a fresh retry budget of `max_retries`, capping each backoff at 5 seconds.
    fn new(max_retries: u32) -> Self {
        RetryState {
            max_retries,
            attempt: 0,
            max_backoff_ms: 5000,
        }
    }

    /// True while the retry budget is not yet exhausted.
    fn should_retry(&self) -> bool {
        self.max_retries > self.attempt
    }

    /// Backoff before the next retry.
    ///
    /// The ceiling is 500ms * 2^attempt (exponent capped at 5), clamped to `max_backoff_ms`.
    /// The actual delay is a random value in `[0, ceiling]` ("full jitter").
    fn backoff_duration(&self) -> std::time::Duration {
        let ceiling_ms = 500u64
            .saturating_mul(1u64 << self.attempt.min(5))
            .min(self.max_backoff_ms);
        std::time::Duration::from_millis(fastrand_u64(ceiling_ms + 1))
    }

    /// Consumes one retry from the budget.
    fn next(&mut self) {
        self.attempt += 1;
    }
}
/// Returns a pseudo-random value in `[0, n)`, or 0 when `n <= 1`.
///
/// It is only meant for retry jitter: cheap and lock-free, and NOT suitable for anything
/// security-sensitive.
///
/// The generator is SplitMix64 over a shared atomic counter. The starting point is seeded
/// once per process from `std`'s randomly keyed `RandomState`. Separate processes (e.g.
/// replicas restarted together) therefore do not walk the same jitter sequence in lockstep,
/// which a fixed compile-time seed would cause, defeating the purpose of jitter.
fn fastrand_u64(n: u64) -> u64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::OnceLock;

    static SEED: OnceLock<u64> = OnceLock::new();
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    if n <= 1 {
        return 0;
    }
    let seed = *SEED.get_or_init(|| RandomState::new().build_hasher().finish());
    // fetch_add hands every caller a distinct step without a CAS retry loop.
    let step = COUNTER.fetch_add(1, Ordering::Relaxed);

    // SplitMix64: golden-ratio increment followed by a bijective avalanche mix.
    let mut z = seed.wrapping_add(step.wrapping_add(1).wrapping_mul(0x9E37_79B9_7F4A_7C15));
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^= z >> 31;

    // Map into [0, n) with a widening multiply. This uses the well-mixed high bits, rather
    // than the low bits `% n` would select. The result is < n because z < 2^64.
    ((u128::from(z) * u128::from(n)) >> 64) as u64
}
/// Heuristically decides whether a failed AI call is transient and worth retrying.
///
/// Provider failures reach this point as rendered strings (e.g.
/// `AgentError::OpenAi(e.to_string())`), so classification is text-based and
/// case-insensitive. An error is retryable when either holds:
/// - a transport, throttling or server-side phrase appears in the message (connection refused,
///   rate limit, service_unavailable, ...);
/// - an HTTP status of 429, 500, 502, 503 or 504 appears as a standalone number.
///
/// Case-insensitivity matters because std renders refused connections as
/// "Connection refused (os error 111)".
///
/// Status codes are matched as whole digit runs. Otherwise unrelated numbers that merely
/// contain them (token counts like `5000`, ids like `15023`) would make a permanent error
/// look retryable.
fn is_retryable_error(err: &AgentError) -> bool {
    is_retryable_message(&err.to_string())
}

/// Text classifier behind `is_retryable_error`; see that function for the rules.
fn is_retryable_message(msg: &str) -> bool {
    const TRANSIENT_PHRASES: &[&str] = &[
        "connection refused",
        "connection timed out",
        "network error",
        "dns error",
        "error sending request",
        "http client error",
        "rate_limit",
        "rate limit",
        "internal_server_error",
        "service_unavailable",
        "gateway_timeout",
        "bad_gateway",
    ];
    const TRANSIENT_STATUS_CODES: &[&str] = &["429", "500", "502", "503", "504"];

    let lower = msg.to_ascii_lowercase();
    TRANSIENT_PHRASES.iter().any(|phrase| lower.contains(phrase))
        || lower
            .split(|c: char| !c.is_ascii_digit())
            .any(|digits| TRANSIENT_STATUS_CODES.contains(&digits))
}
/// Process-wide metrics handle, created lazily on first use.
static AI_METRICS: std::sync::OnceLock<AiMetrics> = std::sync::OnceLock::new();

/// Returns the shared `AiMetrics` handle used by every call path in this module.
fn ai_metrics() -> &'static AiMetrics {
    AI_METRICS.get_or_init(AiMetrics::default)
}
// ── Type conversions ─────────────────────────────────────────────────────────
fn to_rig_message(msg: &ChatRequestMessage) -> RigMessage {
match msg.role.as_str() {
"system" => {
// System messages are handled via preamble(), but we still
// need to return something. Return a system message as User for safety.
RigMessage::user(msg.content.as_deref().unwrap_or(""))
}
"user" => {
RigMessage::user(msg.content.as_deref().unwrap_or(""))
}
"assistant" => {
let mut parts: Vec<AssistantContent> = Vec::new();
if let Some(ref content) = msg.content {
if !content.is_empty() {
parts.push(AssistantContent::text(content));
}
}
if let Some(ref tool_calls) = msg.tool_calls {
for tc in tool_calls {
// GLM may return empty tool call IDs — fall back to a generated UUID.
let id = if tc.id.is_empty() {
Uuid::new_v4().to_string()
} else {
tc.id.clone()
};
parts.push(AssistantContent::tool_call_with_call_id(
&id,
id.clone(),
&tc.function.name,
serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null),
));
}
}
if parts.is_empty() {
RigMessage::assistant("")
} else if parts.len() == 1 {
// Single part — use simpler constructors
match parts.pop().unwrap() {
AssistantContent::Text(t) => RigMessage::assistant(t.text),
ac => {
RigMessage::Assistant {
id: None,
content: OneOrMany::one(ac),
}
}
}
} else {
let content = OneOrMany::many(parts).expect("non-empty parts");
RigMessage::Assistant { id: None, content }
}
}
"tool" | "function" => {
let id = msg.tool_call_id.as_deref().unwrap_or("unknown").to_string();
let call_id = msg.tool_call_id.clone().or_else(|| Some(id.clone()));
let content = msg.content.as_deref().unwrap_or("");
RigMessage::tool_result_with_call_id(id, call_id, content)
}
"developer" => {
// Developer role maps to user/system in rig
RigMessage::user(msg.content.as_deref().unwrap_or(""))
}
_ => RigMessage::user(msg.content.as_deref().unwrap_or("")),
}
}
/// Converts an OpenAI-style tool JSON (`{"function": {"name", "description", "parameters"}}`)
/// into a rig `ToolDefinition`.
///
/// Returns `None` when there is no `function` object or it has no string `name`. A missing
/// or non-string `description` becomes `""`; missing `parameters` become `{}`.
fn to_rig_tool_def(tool_json: &serde_json::Value) -> Option<ToolDefinition> {
    let function = tool_json.get("function")?;
    let name = function.get("name")?.as_str()?.to_owned();
    let description = function
        .get("description")
        .and_then(serde_json::Value::as_str)
        .map(str::to_owned)
        .unwrap_or_default();
    let parameters = function
        .get("parameters")
        .cloned()
        .unwrap_or_else(|| serde_json::json!({}));
    Some(ToolDefinition {
        name,
        description,
        parameters,
    })
}
// ── Call helpers ─────────────────────────────────────────────────────────────
/// Sends one non-streaming completion request and flattens the reply.
///
/// How the request is built:
/// - The non-empty content of every system message is joined (blank-line separated) into the
///   preamble, so no system instruction is silently dropped.
/// - All other messages are converted with `to_rig_message` and sent in order. The last one
///   is passed as rig's `prompt` and the rest as chat history.
/// - `temperature`, `max_tokens` and non-empty `tools` are set only when provided.
/// - `tool_choice` maps `"none"`/`"auto"`/`"required"` to the matching rig choice; any other
///   string means "call this specific function".
///
/// Returns `(text, input_tokens, output_tokens, tool_names)`. `text` concatenates the
/// reply's text parts and `tool_names` lists the names of requested tool calls; reasoning
/// and image parts are ignored.
///
/// # Errors
/// `AgentError::OpenAi` carrying the provider error message if the request fails.
async fn do_completion<M>(
    model: &M,
    messages: &[ChatRequestMessage],
    temperature: Option<f64>,
    max_tokens: Option<u32>,
    tools: Option<&[serde_json::Value]>,
    tool_choice: Option<&str>,
) -> Result<(String, u64, u64, Vec<String>)>
where
    M: CompletionModel<Client = openai::Client>,
{
    use rig::completion::message::ToolChoice;

    let preamble = messages
        .iter()
        .filter(|m| m.role == "system")
        .filter_map(|m| m.content.as_deref())
        .filter(|text| !text.is_empty())
        .collect::<Vec<_>>()
        .join("\n\n");
    let mut history: Vec<RigMessage> = messages
        .iter()
        .filter(|m| m.role != "system")
        .map(to_rig_message)
        .collect();
    // rig's request builder always appends its `prompt` argument after the chat history, so a
    // placeholder "" prompt would add an empty user turn to every request. Use the last real
    // message as the prompt instead ("" only when there are no non-system messages at all).
    let prompt = history.pop().unwrap_or_else(|| RigMessage::user(""));
    let tool_defs: Vec<ToolDefinition> = tools
        .map(|ts| ts.iter().filter_map(to_rig_tool_def).collect())
        .unwrap_or_default();

    let mut builder = model.completion_request(prompt);
    if !preamble.is_empty() {
        builder = builder.preamble(preamble);
    }
    if !history.is_empty() {
        builder = builder.messages(history);
    }
    if let Some(t) = temperature {
        builder = builder.temperature(t);
    }
    if let Some(mt) = max_tokens {
        builder = builder.max_tokens(u64::from(mt));
    }
    if !tool_defs.is_empty() {
        builder = builder.tools(tool_defs);
    }
    // Only set tool_choice when explicitly provided (mirrors call_stream_once logic).
    if let Some(choice) = tool_choice {
        builder = builder.tool_choice(match choice {
            "none" => ToolChoice::None,
            "auto" => ToolChoice::Auto,
            // OpenAI-style "required" is a mode, not a function name.
            "required" => ToolChoice::Required,
            function_name => ToolChoice::Specific {
                function_names: vec![function_name.to_string()],
            },
        });
    }

    let response = builder
        .send()
        .await
        .map_err(|e| AgentError::OpenAi(e.to_string()))?;

    let mut content = String::new();
    let mut tool_names: Vec<String> = Vec::new();
    for item in response.choice {
        match item {
            AssistantContent::Text(t) => content.push_str(&t.text),
            AssistantContent::ToolCall(tc) => tool_names.push(tc.function.name.clone()),
            AssistantContent::Reasoning(_) => {}
            AssistantContent::Image(_) => {}
        }
    }
    let input_tokens = response.usage.input_tokens;
    let output_tokens = response.usage.output_tokens;
    Ok((content, input_tokens, output_tokens, tool_names))
}
// ── Public API ───────────────────────────────────────────────────────────────
/// Call the AI model with automatic retry (no custom params).
///
/// Sends `messages` to `model_name` with provider-default sampling and no tools. Errors that
/// `is_retryable_error` accepts are retried with jittered exponential backoff, up to
/// `max_retries` times (default 3). Each retry logs an `ai_call_retry` warning.
///
/// Exactly one success or failure metric is recorded per call. `latency_ms` measures only
/// the successful attempt.
///
/// # Errors
/// The last error, once it is not retryable or the retry budget is spent.
pub async fn call_with_retry(
    messages: &[ChatRequestMessage],
    model_name: &str,
    config: &AiClientConfig,
    max_retries: Option<u32>,
) -> Result<AiCallResponse> {
    let rig_client = config.build_rig_client();
    let completion_model = rig_client.completion_model(model_name);
    let mut retry = RetryState::new(max_retries.unwrap_or(3));
    loop {
        let started_at = Instant::now();
        let outcome = do_completion(&completion_model, messages, None, None, None, None).await;
        let err = match outcome {
            Ok((content, input_tokens, output_tokens, tool_calls_finished)) => {
                let latency_ms = started_at.elapsed().as_millis() as i64;
                let (input_tokens, output_tokens) = (input_tokens as i64, output_tokens as i64);
                ai_metrics().record_success(
                    input_tokens,
                    output_tokens,
                    !tool_calls_finished.is_empty(),
                );
                return Ok(AiCallResponse {
                    content,
                    input_tokens,
                    output_tokens,
                    latency_ms,
                    tool_calls_finished,
                });
            }
            Err(err) => err,
        };
        if !(retry.should_retry() && is_retryable_error(&err)) {
            ai_metrics().record_failure();
            return Err(err);
        }
        let backoff = retry.backoff_duration();
        tracing::warn!(
            attempt = retry.attempt + 1,
            max_retries = retry.max_retries,
            backoff_ms = backoff.as_millis() as u64,
            model = %model_name,
            error = %err,
            "ai_call_retry"
        );
        tokio::time::sleep(backoff).await;
        retry.next();
    }
}
/// Call with custom parameters (temperature, max_tokens, optional tools, optional tool_choice).
///
/// `temperature` (widened to `f64`) and `max_tokens` are always sent; `tools` and
/// `tool_choice` are forwarded as given. Errors that `is_retryable_error` accepts are retried
/// with jittered exponential backoff, up to `max_retries` times (default 3). Each retry logs
/// an `ai_call_retry` warning.
///
/// Exactly one success or failure metric is recorded per call.
///
/// # Errors
/// The last error, once it is not retryable or the retry budget is spent.
pub async fn call_with_params(
    messages: &[ChatRequestMessage],
    model_name: &str,
    config: &AiClientConfig,
    temperature: f32,
    max_tokens: u32,
    max_retries: Option<u32>,
    tools: Option<&[serde_json::Value]>,
    tool_choice: Option<&str>,
) -> Result<AiCallResponse> {
    let rig_client = config.build_rig_client();
    let completion_model = rig_client.completion_model(model_name);
    let mut retry = RetryState::new(max_retries.unwrap_or(3));
    let temperature = Some(temperature as f64);
    let max_tokens = Some(max_tokens);
    loop {
        let started_at = Instant::now();
        let outcome = do_completion(
            &completion_model,
            messages,
            temperature,
            max_tokens,
            tools,
            tool_choice,
        )
        .await;
        let err = match outcome {
            Ok((content, input_tokens, output_tokens, tool_calls_finished)) => {
                let latency_ms = started_at.elapsed().as_millis() as i64;
                let (input_tokens, output_tokens) = (input_tokens as i64, output_tokens as i64);
                ai_metrics().record_success(
                    input_tokens,
                    output_tokens,
                    !tool_calls_finished.is_empty(),
                );
                return Ok(AiCallResponse {
                    content,
                    input_tokens,
                    output_tokens,
                    latency_ms,
                    tool_calls_finished,
                });
            }
            Err(err) => err,
        };
        if !(retry.should_retry() && is_retryable_error(&err)) {
            ai_metrics().record_failure();
            return Err(err);
        }
        let backoff = retry.backoff_duration();
        tracing::warn!(
            attempt = retry.attempt + 1,
            max_retries = retry.max_retries,
            backoff_ms = backoff.as_millis() as u64,
            model = %model_name,
            error = %err,
            "ai_call_retry"
        );
        tokio::time::sleep(backoff).await;
        retry.next();
    }
}
/// Tool call assembled from a streaming response, with its arguments fully accumulated.
#[derive(Debug, Clone)]
pub struct StreamedToolCall {
    /// Provider-assigned tool call id.
    pub id: String,
    /// Name of the function the model wants to call.
    pub name: String,
    /// Raw JSON arguments, concatenated from the streamed deltas.
    pub arguments: String,
}
/// Kind of a streamed chunk; used to keep reasoning, answer text and tool calls in order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamChunkType {
    /// Reasoning ("thinking") text.
    Thinking,
    /// Answer text shown to the user.
    Answer,
    /// A completed tool call (content is a JSON summary).
    ToolCall,
}
/// One piece of a streaming response, recorded in the order it arrived.
#[derive(Debug, Clone)]
pub struct StreamChunk {
    /// What kind of content this chunk holds.
    pub chunk_type: StreamChunkType,
    /// The chunk's text (for tool calls, a JSON summary string).
    pub content: String,
}
/// Aggregated result of one streaming completion via rig.
#[derive(Debug)]
pub struct StreamResponse {
    /// All answer text, concatenated.
    pub content: String,
    /// Prompt tokens (0 when the stream did not report usage).
    pub input_tokens: i64,
    /// Completion tokens (0 when the stream did not report usage).
    pub output_tokens: i64,
    /// All reasoning/thinking text from the model, concatenated.
    pub reasoning_content: String,
    /// Complete tool calls, including their accumulated arguments.
    pub tool_calls: Vec<StreamedToolCall>,
    /// Every chunk in arrival order, preserving how thinking, answer and tool calls interleave.
    pub chunks: Vec<StreamChunk>,
}
/// Async callback receiving each streamed answer-text delta.
///
/// The returned future must be awaited; the streaming loop awaits it before reading the next
/// chunk. NOTE(review): the previous description said the delta is broadcast to the
/// WebSocket — that is up to the caller-supplied closure.
pub type StreamTextCb =
    Arc<dyn Fn(&str) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync>;
/// Async callback receiving each streamed reasoning ("thinking") delta; same contract as
/// `StreamTextCb`.
pub type StreamReasoningCb =
    Arc<dyn Fn(&str) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync>;
/// Async callback invoked for each tool call the stream reports as complete.
pub type StreamToolCallCb = Arc<
    dyn Fn(&StreamedToolCall) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>>
        + Send
        + Sync,
>;
/// Run a streaming chat completion with a 120s per-attempt timeout and up to 5 retries.
///
/// Each attempt is delegated to `call_stream_once`, which enforces the 120-second timeout on
/// the whole attempt and records success metrics itself.
///
/// Errors that `is_retryable_error` accepts are retried with jittered exponential backoff,
/// each retry logged as an `ai_stream_retry` warning. Any other error, or the error after the
/// last retry, is recorded as a failure metric and returned.
///
/// NOTE(review): the callbacks fire as data arrives in *every* attempt. If an attempt fails
/// mid-stream with a retryable error, the next attempt starts from scratch and emits its
/// deltas again, so consumers may see repeated text. The returned `StreamResponse` reflects
/// only the successful attempt.
///
/// # Errors
/// The last attempt's error (`AgentError::OpenAi`, `AgentError::Timeout`, ...).
pub async fn call_stream(
    messages: &[ChatRequestMessage],
    model_name: &str,
    config: &AiClientConfig,
    temperature: f32,
    max_tokens: u32,
    tools: Option<&[serde_json::Value]>,
    tool_choice: Option<&str>,
    on_text_delta: StreamTextCb,
    on_reasoning_delta: StreamReasoningCb,
    on_tool_call: StreamToolCallCb,
) -> Result<StreamResponse> {
    // Retry budget for streaming calls (the first attempt is not counted).
    const MAX_STREAM_RETRIES: u32 = 5;
    let mut state = RetryState::new(MAX_STREAM_RETRIES);
    loop {
        let result = call_stream_once(
            messages,
            model_name,
            config,
            temperature,
            max_tokens,
            tools,
            tool_choice,
            on_text_delta.clone(),
            on_reasoning_delta.clone(),
            on_tool_call.clone(),
        )
        .await;
        match result {
            Ok(response) => return Ok(response),
            Err(ref err) if state.should_retry() && is_retryable_error(err) => {
                let duration = state.backoff_duration();
                tracing::warn!(
                    attempt = state.attempt + 1,
                    max_retries = state.max_retries,
                    backoff_ms = duration.as_millis() as u64,
                    model = %model_name,
                    error = %err,
                    "ai_stream_retry"
                );
                tokio::time::sleep(duration).await;
                state.next();
            }
            Err(err) => {
                ai_metrics().record_failure();
                return Err(err);
            }
        }
    }
}
/// Single attempt of streaming completion with 60s timeout.
async fn call_stream_once(
messages: &[ChatRequestMessage],
model_name: &str,
config: &AiClientConfig,
temperature: f32,
max_tokens: u32,
tools: Option<&[serde_json::Value]>,
tool_choice: Option<&str>,
on_text_delta: StreamTextCb,
on_reasoning_delta: StreamReasoningCb,
on_tool_call: StreamToolCallCb,
) -> Result<StreamResponse> {
let client = config.build_rig_client();
let model = client.completion_model(model_name);
let preamble = messages
.iter()
.find(|m| m.role == "system")
.and_then(|m| m.content.as_deref())
.unwrap_or("")
.to_string();
let non_system: Vec<RigMessage> = messages
.iter()
.filter(|m| m.role != "system")
.map(to_rig_message)
.collect();
let tool_defs: Vec<ToolDefinition> = tools
.map(|ts| ts.iter().filter_map(to_rig_tool_def).collect())
.unwrap_or_default();
let mut builder = model
.completion_request("")
.temperature(temperature as f64)
.max_tokens(max_tokens as u64);
if !preamble.is_empty() {
builder = builder.preamble(preamble);
}
if !non_system.is_empty() {
builder = builder.messages(non_system);
}
if !tool_defs.is_empty() {
builder = builder.tools(tool_defs);
}
if let Some(tc) = tool_choice {
match tc {
"none" => {
builder = builder.tool_choice(rig::completion::message::ToolChoice::None);
}
"auto" => {
builder = builder.tool_choice(rig::completion::message::ToolChoice::Auto);
}
s => {
builder = builder.tool_choice(
rig::completion::message::ToolChoice::Specific {
function_names: vec![s.to_string()],
},
);
}
}
}
let stream_fut = async {
let mut stream = builder
.stream()
.await
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
let mut content = String::new();
let mut reasoning_content = String::new();
let mut tool_calls: Vec<StreamedToolCall> = Vec::new();
let mut chunks: Vec<StreamChunk> = Vec::new();
// Some models (e.g. GLM) ignore tool_choice="none" and still emit tool_calls.
// Filter them out so they don't cause spurious tool execution attempts.
let skip_tool_calls = tool_choice == Some("none");
use std::collections::HashMap;
let mut partial_tool_calls: HashMap<String, StreamedToolCall> = HashMap::new();
let mut stream_finished = false;
use rig::streaming::StreamedAssistantContent;
while let Some(item) = stream.next().await {
match item {
Ok(StreamedAssistantContent::Text(text)) => {
content.push_str(&text.text);
on_text_delta(&text.text).await;
chunks.push(StreamChunk {
chunk_type: StreamChunkType::Answer,
content: text.text,
});
}
Ok(StreamedAssistantContent::ToolCall {
tool_call,
internal_call_id,
}) => {
if skip_tool_calls {
partial_tool_calls.remove(&internal_call_id);
continue;
}
let arguments = match &tool_call.function.arguments {
serde_json::Value::String(s) => s.clone(),
other => serde_json::to_string(other).unwrap_or_else(|_| "{}".to_string()),
};
let tc = StreamedToolCall {
id: tool_call.id.clone(),
name: tool_call.function.name.clone(),
arguments,
};
on_tool_call(&tc).await;
chunks.push(StreamChunk {
chunk_type: StreamChunkType::ToolCall,
content: serde_json::json!({
"id": tc.id,
"name": tc.name,
"arguments": tc.arguments,
}).to_string(),
});
tool_calls.push(tc);
partial_tool_calls.remove(&internal_call_id);
}
Ok(StreamedAssistantContent::ToolCallDelta {
id,
internal_call_id,
content: delta_content,
}) => {
if skip_tool_calls {
continue;
}
use rig::streaming::ToolCallDeltaContent;
match delta_content {
ToolCallDeltaContent::Name(name) => {
partial_tool_calls.insert(
internal_call_id.clone(),
StreamedToolCall {
id: id.clone(),
name,
arguments: String::new(),
},
);
}
ToolCallDeltaContent::Delta(delta) => {
if let Some(tc) = partial_tool_calls.get_mut(&internal_call_id) {
tc.arguments.push_str(&delta);
}
}
}
}
Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
for part in &reasoning.reasoning {
reasoning_content.push_str(part);
on_reasoning_delta(part).await;
chunks.push(StreamChunk {
chunk_type: StreamChunkType::Thinking,
content: part.clone(),
});
}
}
Ok(StreamedAssistantContent::ReasoningDelta { reasoning, .. }) => {
reasoning_content.push_str(&reasoning);
on_reasoning_delta(&reasoning).await;
chunks.push(StreamChunk {
chunk_type: StreamChunkType::Thinking,
content: reasoning.clone(),
});
}
Ok(StreamedAssistantContent::Final(response)) => {
stream_finished = true;
if !skip_tool_calls {
for (_, tc) in partial_tool_calls.drain() {
tool_calls.push(tc);
}
} else {
partial_tool_calls.drain();
}
if let Some(usage) = response.token_usage() {
let in_toks = usage.input_tokens as i64;
let out_toks = usage.output_tokens as i64;
ai_metrics().record_success(in_toks, out_toks, !tool_calls.is_empty());
return Ok(StreamResponse {
content,
reasoning_content,
input_tokens: in_toks,
output_tokens: out_toks,
tool_calls,
chunks,
});
}
// Usage not available from Final — fall through to flush
}
Err(e) => return Err(AgentError::OpenAi(e.to_string())),
}
}
// Flush any remaining partial tool calls (if stream ended without Final or Final had no usage)
if !stream_finished && !skip_tool_calls {
for (_, tc) in partial_tool_calls.drain() {
tool_calls.push(tc);
}
}
ai_metrics().record_success(0, 0, !tool_calls.is_empty());
Ok(StreamResponse {
content,
reasoning_content,
input_tokens: 0,
output_tokens: 0,
tool_calls,
chunks,
})
};
// 120s timeout for the entire stream
match tokio::time::timeout(std::time::Duration::from_secs(120), stream_fut).await {
Ok(result) => result,
Err(_) => Err(AgentError::Timeout { task_id: 0, seconds: 120 }),
}
}