gitdataai/libs/agent/client/mod.rs
ZhenYi 08045eef63 refactor(agent): enhance chat service with state management and billing
Add persistent chat session state (ChatState, sequence tracking, tool
calls). Introduce basic billing record in agent crate. Refine chat
service to route messages through state machine with tool support.
2026-04-30 19:16:44 +08:00

769 lines
27 KiB
Rust

//! Unified AI client with built-in retry, token tracking, and session recording.
//!
//! Uses rig-core as the underlying AI provider library.
pub mod types;
pub use types::{ChatRequestMessage, ToolCall as ClientToolCall};
use std::pin::Pin;
use std::sync::Arc;
use std::time::Instant;
use uuid::Uuid;
use crate::error::{AgentError, Result};
use futures::StreamExt;
use rig::completion::message::{AssistantContent, Message as RigMessage};
use rig::completion::{GetTokenUsage, ToolDefinition, CompletionModel};
use rig::one_or_many::OneOrMany;
use rig::prelude::CompletionClient;
use rig::providers::openai;
/// Zero-sized handle that reports AI call outcomes to the global `metrics` recorder.
///
/// Every recorded call bumps `ai_calls_total`, together with either
/// `ai_calls_success` or `ai_calls_failure`.
#[derive(Debug, Clone, Default)]
pub struct AiMetrics;

impl AiMetrics {
    /// Creates a metrics handle (the type carries no state).
    pub fn new() -> Self {
        AiMetrics
    }

    /// Records a successful call.
    ///
    /// Token counters are only incremented for strictly positive counts.
    /// `ai_function_calls_total` is incremented when the response contained
    /// at least one tool/function call.
    pub fn record_success(&self, input_tokens: i64, output_tokens: i64, has_function_call: bool) {
        metrics::counter!("ai_calls_total").increment(1);
        metrics::counter!("ai_calls_success").increment(1);

        if input_tokens > 0 {
            metrics::counter!("ai_input_tokens_total").increment(input_tokens.unsigned_abs());
        }
        if output_tokens > 0 {
            metrics::counter!("ai_output_tokens_total").increment(output_tokens.unsigned_abs());
        }
        if has_function_call {
            metrics::counter!("ai_function_calls_total").increment(1);
        }
    }

    /// Records a call that ultimately failed (after any retries).
    pub fn record_failure(&self) {
        metrics::counter!("ai_calls_total").increment(1);
        metrics::counter!("ai_calls_failure").increment(1);
    }
}
/// Configuration for the AI client: credentials and an optional provider endpoint.
#[derive(Clone)]
pub struct AiClientConfig {
    /// API key passed to the rig OpenAI client builder.
    pub api_key: String,
    /// Provider base URL (including any API version prefix such as `/v1`).
    /// `None` selects the public OpenAI endpoint.
    pub base_url: Option<String>,
}

impl AiClientConfig {
    /// Creates a config that targets the default OpenAI endpoint.
    pub fn new(api_key: String) -> Self {
        Self { api_key, base_url: None }
    }

    /// Overrides the provider base URL (e.g. an OpenAI-compatible gateway).
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = Some(base_url.into());
        self
    }

    /// Build a rig OpenAI client from this config.
    ///
    /// # Panics
    /// Panics if the rig client builder fails to build the client.
    pub fn build_rig_client(&self) -> openai::Client {
        // rig joins endpoint paths such as `/chat/completions` directly onto
        // the base URL, so the default must carry the `/v1` version prefix
        // (matching rig's own built-in OpenAI default). Without it, requests
        // go to `https://api.openai.com/chat/completions`.
        const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";

        let base = self.base_url.as_deref().unwrap_or(DEFAULT_BASE_URL);
        openai::Client::builder()
            .api_key(&self.api_key)
            .base_url(base)
            .build()
            .expect("Failed to build rig OpenAI client")
    }
}
/// Result of a successful non-streaming AI call, with usage and timing data.
#[derive(Debug, Clone)]
pub struct AiCallResponse {
    /// Concatenated text content returned by the model.
    pub content: String,
    /// Input (prompt) token count reported for the call.
    pub input_tokens: i64,
    /// Output (completion) token count reported for the call.
    pub output_tokens: i64,
    /// Wall-clock duration of the successful attempt, in milliseconds.
    pub latency_ms: i64,
    /// Names of the tools the model asked to call.
    pub tool_calls_finished: Vec<String>,
}

impl AiCallResponse {
    /// Input plus output tokens.
    pub fn total_tokens(&self) -> i64 {
        let Self { input_tokens, output_tokens, .. } = *self;
        input_tokens + output_tokens
    }
}
/// Internal state for retry tracking.
///
/// Counts the retries already performed and computes a capped exponential
/// backoff with full jitter for the next wait.
#[derive(Debug)]
struct RetryState {
    /// Retries already performed (0 before the first retry).
    attempt: u32,
    /// Maximum number of retries allowed.
    max_retries: u32,
    /// Upper bound for a single backoff, in milliseconds.
    max_backoff_ms: u64,
}

impl RetryState {
    /// Starts tracking with no retries performed and a 5000ms backoff cap.
    fn new(max_retries: u32) -> Self {
        Self { attempt: 0, max_retries, max_backoff_ms: 5000 }
    }

    /// Whether another retry is permitted.
    fn should_retry(&self) -> bool {
        self.attempt < self.max_retries
    }

    /// Backoff before the next retry: a random duration in
    /// `[0, min(500ms * 2^attempt, max_backoff_ms)]` ("full jitter").
    /// The exponent is capped at 5 so `2^exp` stays small.
    fn backoff_duration(&self) -> std::time::Duration {
        let exp = self.attempt.min(5);
        let base_ms = 500u64.saturating_mul(2u64.pow(exp)).min(self.max_backoff_ms);
        let jitter = fastrand_u64(base_ms.saturating_add(1));
        std::time::Duration::from_millis(jitter)
    }

    /// Records that a retry is being made.
    fn next(&mut self) {
        self.attempt += 1;
    }
}

/// Returns a pseudo-random value in `[0, n)`, or 0 when `n <= 1`.
///
/// Not cryptographic; used only for retry jitter. Each call hashes a counter
/// with a fresh, randomly keyed `RandomState`. The previous fixed-seed LCG
/// produced the same sequence in every process and after every restart.
/// That made concurrent clients back off in lockstep, defeating the jitter.
fn fastrand_u64(n: u64) -> u64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::sync::atomic::{AtomicU64, Ordering};

    // Distinguishes successive calls even if two hashers share keys.
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    if n <= 1 {
        return 0;
    }
    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u64(COUNTER.fetch_add(1, Ordering::Relaxed));
    hasher.finish() % n
}
/// Decide whether a failed AI call looks transient and is worth retrying.
///
/// The check is heuristic and runs on the error's `Display` text.
///
/// - Phrases are compared case-insensitively. Transport errors such as the
///   OS message "Connection refused" are typically capitalised, and the old
///   case-sensitive checks missed them.
/// - HTTP status codes (429, 500, 502, 503, 504) only count when they are not
///   part of a longer digit run. For example, "max_tokens 1500" is no longer
///   mistaken for a 500 response.
fn is_retryable_error(err: &AgentError) -> bool {
    const RETRYABLE_PHRASES: &[&str] = &[
        "connection refused",
        "connection reset",
        "connection timed out",
        "network error",
        "dns error",
        "error sending request",
        "http client error",
        "rate_limit",
        "rate limit",
        "internal_server_error",
        "service_unavailable",
        "gateway_timeout",
        "bad_gateway",
    ];
    const RETRYABLE_STATUS_CODES: &[&str] = &["429", "500", "502", "503", "504"];

    let msg = err.to_string().to_ascii_lowercase();
    let bytes = msg.as_bytes();
    let is_digit_at = |i: usize| bytes.get(i).is_some_and(u8::is_ascii_digit);
    // True if `code` occurs with no ASCII digit directly before or after it.
    let contains_status = |code: &str| {
        msg.match_indices(code).any(|(start, matched)| {
            let end = start + matched.len();
            !(start > 0 && is_digit_at(start - 1)) && !is_digit_at(end)
        })
    };

    RETRYABLE_PHRASES.iter().any(|&phrase| msg.contains(phrase))
        || RETRYABLE_STATUS_CODES.iter().any(|&code| contains_status(code))
}
/// Lazily initialised, process-wide [`AiMetrics`] handle.
static AI_METRICS: std::sync::OnceLock<AiMetrics> = std::sync::OnceLock::new();

/// Returns the shared [`AiMetrics`] instance, creating it on first use.
fn ai_metrics() -> &'static AiMetrics {
    AI_METRICS.get_or_init(AiMetrics::default)
}
// ── Type conversions ─────────────────────────────────────────────────────────
pub(crate) fn to_rig_message(msg: &ChatRequestMessage) -> RigMessage {
match msg.role.as_str() {
"system" => {
// System messages are handled via preamble(), but we still
// need to return something. Return a system message as User for safety.
RigMessage::user(msg.content.as_deref().unwrap_or(""))
}
"user" => {
RigMessage::user(msg.content.as_deref().unwrap_or(""))
}
"assistant" => {
let mut parts: Vec<AssistantContent> = Vec::new();
if let Some(ref content) = msg.content {
if !content.is_empty() {
parts.push(AssistantContent::text(content));
}
}
if let Some(ref tool_calls) = msg.tool_calls {
for tc in tool_calls {
// GLM may return empty tool call IDs — fall back to a generated UUID.
let id = if tc.id.is_empty() {
Uuid::new_v4().to_string()
} else {
tc.id.clone()
};
parts.push(AssistantContent::tool_call_with_call_id(
&id,
id.clone(),
&tc.function.name,
serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null),
));
}
}
if parts.is_empty() {
RigMessage::assistant("")
} else if parts.len() == 1 {
// Single part — use simpler constructors
match parts.pop().unwrap() {
AssistantContent::Text(t) => RigMessage::assistant(t.text),
ac => {
RigMessage::Assistant {
id: None,
content: OneOrMany::one(ac),
}
}
}
} else {
let content = OneOrMany::many(parts).expect("non-empty parts");
RigMessage::Assistant { id: None, content }
}
}
"tool" | "function" => {
let id = msg.tool_call_id.as_deref().unwrap_or("unknown").to_string();
let call_id = msg.tool_call_id.clone().or_else(|| Some(id.clone()));
let content = msg.content.as_deref().unwrap_or("");
RigMessage::tool_result_with_call_id(id, call_id, content)
}
"developer" => {
// Developer role maps to user/system in rig
RigMessage::user(msg.content.as_deref().unwrap_or(""))
}
_ => RigMessage::user(msg.content.as_deref().unwrap_or("")),
}
}
/// Convert an OpenAI-style tool JSON (`{"function": {"name", "description", "parameters"}}`)
/// into a rig [`ToolDefinition`].
///
/// Returns `None` when `function.name` is missing or not a string. A missing
/// description becomes `""`, and missing parameters become `{}`.
fn to_rig_tool_def(tool_json: &serde_json::Value) -> Option<ToolDefinition> {
    let function = tool_json.get("function");
    let field = |key: &str| function.and_then(|f| f.get(key));

    let name = field("name")?.as_str()?.to_owned();
    let description = match field("description").and_then(serde_json::Value::as_str) {
        Some(text) => text.to_owned(),
        None => String::new(),
    };
    let parameters = field("parameters")
        .cloned()
        .unwrap_or_else(|| serde_json::json!({}));

    Some(ToolDefinition {
        name,
        description,
        parameters,
    })
}
// ── Call helpers ─────────────────────────────────────────────────────────────
/// Send one non-streaming chat completion request and collect the result.
///
/// Request construction:
/// - Every `system` message with non-empty content is joined into the rig
///   preamble, separated by blank lines.
/// - The remaining messages are converted with [`to_rig_message`]. The last
///   one becomes the rig prompt and the earlier ones become chat history.
/// - `temperature`, `max_tokens` and `tools` are only set when provided.
///   `tools` are OpenAI-style JSON function definitions; entries without a
///   `function.name` are skipped.
/// - `tool_choice` is only set when provided. `"none"`, `"auto"` and
///   `"required"` map to the matching rig variants; any other string is
///   treated as the name of a specific function to force.
///
/// Returns `(text content, input tokens, output tokens, names of the tools
/// the model called)`. Reasoning and image parts of the response are ignored.
///
/// # Errors
/// Returns [`AgentError::OpenAi`] with the provider error text if the request fails.
async fn do_completion<M>(
    model: &M,
    messages: &[ChatRequestMessage],
    temperature: Option<f64>,
    max_tokens: Option<u32>,
    tools: Option<&[serde_json::Value]>,
    tool_choice: Option<&str>,
) -> Result<(String, u64, u64, Vec<String>)>
where
    M: CompletionModel<Client = openai::Client>,
{
    // Keep every system message, not just the first one.
    let preamble = messages
        .iter()
        .filter(|m| m.role == "system")
        .filter_map(|m| m.content.as_deref())
        .filter(|c| !c.is_empty())
        .collect::<Vec<_>>()
        .join("\n\n");
    let mut conversation: Vec<RigMessage> = messages
        .iter()
        .filter(|m| m.role != "system")
        .map(to_rig_message)
        .collect();
    let tool_defs: Vec<ToolDefinition> = tools
        .map(|ts| ts.iter().filter_map(to_rig_tool_def).collect())
        .unwrap_or_default();

    // rig appends the prompt after the chat history. Passing the last
    // conversation message as the prompt keeps the original order; an empty
    // "" prompt would add a trailing blank user turn.
    let mut builder = match conversation.pop() {
        Some(last) => model.completion_request(last),
        None => model.completion_request(""),
    };
    if !preamble.is_empty() {
        builder = builder.preamble(preamble);
    }
    if !conversation.is_empty() {
        builder = builder.messages(conversation);
    }
    if let Some(t) = temperature {
        builder = builder.temperature(t);
    }
    if let Some(mt) = max_tokens {
        builder = builder.max_tokens(u64::from(mt));
    }
    if !tool_defs.is_empty() {
        builder = builder.tools(tool_defs);
    }
    // Only set tool_choice when explicitly provided (mirrors call_stream_once).
    if let Some(tc) = tool_choice {
        use rig::completion::message::ToolChoice;
        let choice = match tc {
            "none" => ToolChoice::None,
            "auto" => ToolChoice::Auto,
            // Must not fall through to `Specific`: that would force a
            // function literally named "required".
            "required" => ToolChoice::Required,
            name => ToolChoice::Specific {
                function_names: vec![name.to_string()],
            },
        };
        builder = builder.tool_choice(choice);
    }

    let response = builder.send().await.map_err(|e| AgentError::OpenAi(e.to_string()))?;

    let mut content = String::new();
    let mut tool_names: Vec<String> = Vec::new();
    for item in response.choice {
        match item {
            AssistantContent::Text(t) => content.push_str(&t.text),
            AssistantContent::ToolCall(tc) => tool_names.push(tc.function.name),
            AssistantContent::Reasoning(_) | AssistantContent::Image(_) => {}
        }
    }

    let input_tokens = response.usage.input_tokens;
    let output_tokens = response.usage.output_tokens;
    Ok((content, input_tokens, output_tokens, tool_names))
}
// ── Public API ───────────────────────────────────────────────────────────────
/// Call the AI model with automatic retry, using the provider's default sampling parameters.
///
/// Up to `max_retries` retries (default 3) are made with jittered
/// exponential backoff, and only for errors `is_retryable_error` accepts.
/// A success records token usage in the AI metrics; a final failure records
/// one failure. `latency_ms` covers only the successful attempt.
pub async fn call_with_retry(
    messages: &[ChatRequestMessage],
    model_name: &str,
    config: &AiClientConfig,
    max_retries: Option<u32>,
) -> Result<AiCallResponse> {
    let client = config.build_rig_client();
    let model = client.completion_model(model_name);
    let mut retry = RetryState::new(max_retries.unwrap_or(3));

    loop {
        let attempt_started = Instant::now();
        let outcome = do_completion(&model, messages, None, None, None, None).await;

        let err = match outcome {
            Ok((content, input_tokens, output_tokens, tool_names)) => {
                let latency_ms = attempt_started.elapsed().as_millis() as i64;
                ai_metrics().record_success(
                    input_tokens as i64,
                    output_tokens as i64,
                    !tool_names.is_empty(),
                );
                return Ok(AiCallResponse {
                    content,
                    input_tokens: input_tokens as i64,
                    output_tokens: output_tokens as i64,
                    latency_ms,
                    tool_calls_finished: tool_names,
                });
            }
            Err(err) => err,
        };

        // Give up on non-transient errors or once the retry budget is spent.
        if !(retry.should_retry() && is_retryable_error(&err)) {
            ai_metrics().record_failure();
            return Err(err);
        }

        let backoff = retry.backoff_duration();
        tracing::warn!(
            attempt = retry.attempt + 1,
            max_retries = retry.max_retries,
            backoff_ms = backoff.as_millis() as u64,
            model = %model_name,
            error = %err,
            "ai_call_retry"
        );
        tokio::time::sleep(backoff).await;
        retry.next();
    }
}
/// Call the AI model with explicit parameters and automatic retry.
///
/// Sends `temperature` and `max_tokens`, plus the optional OpenAI-style
/// `tools` definitions and `tool_choice`. Up to `max_retries` retries
/// (default 3) are made with jittered exponential backoff, and only for
/// errors `is_retryable_error` accepts. `latency_ms` covers only the
/// successful attempt.
pub async fn call_with_params(
    messages: &[ChatRequestMessage],
    model_name: &str,
    config: &AiClientConfig,
    temperature: f32,
    max_tokens: u32,
    max_retries: Option<u32>,
    tools: Option<&[serde_json::Value]>,
    tool_choice: Option<&str>,
) -> Result<AiCallResponse> {
    let client = config.build_rig_client();
    let model = client.completion_model(model_name);
    let mut retry = RetryState::new(max_retries.unwrap_or(3));

    loop {
        let attempt_started = Instant::now();
        let outcome = do_completion(
            &model,
            messages,
            Some(f64::from(temperature)),
            Some(max_tokens),
            tools,
            tool_choice,
        )
        .await;

        let err = match outcome {
            Ok((content, input_tokens, output_tokens, tool_names)) => {
                let latency_ms = attempt_started.elapsed().as_millis() as i64;
                ai_metrics().record_success(
                    input_tokens as i64,
                    output_tokens as i64,
                    !tool_names.is_empty(),
                );
                return Ok(AiCallResponse {
                    content,
                    input_tokens: input_tokens as i64,
                    output_tokens: output_tokens as i64,
                    latency_ms,
                    tool_calls_finished: tool_names,
                });
            }
            Err(err) => err,
        };

        // Give up on non-transient errors or once the retry budget is spent.
        if !(retry.should_retry() && is_retryable_error(&err)) {
            ai_metrics().record_failure();
            return Err(err);
        }

        let backoff = retry.backoff_duration();
        tracing::warn!(
            attempt = retry.attempt + 1,
            max_retries = retry.max_retries,
            backoff_ms = backoff.as_millis() as u64,
            model = %model_name,
            error = %err,
            "ai_call_retry"
        );
        tokio::time::sleep(backoff).await;
        retry.next();
    }
}
/// A tool call assembled from a streaming response.
#[derive(Clone, Debug)]
pub struct StreamedToolCall {
    /// Tool call ID as reported in the stream.
    pub id: String,
    /// Name of the function the model wants to invoke.
    pub name: String,
    /// JSON arguments as a string, accumulated across stream deltas.
    pub arguments: String,
}

/// Kind of content carried by a [`StreamChunk`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum StreamChunkType {
    /// Reasoning / "thinking" text.
    Thinking,
    /// Answer text shown to the user.
    Answer,
    /// A completed tool call, serialized as JSON.
    ToolCall,
}

/// One piece of streamed output, kept in the order it arrived.
#[derive(Clone, Debug)]
pub struct StreamChunk {
    /// What kind of content `content` holds.
    pub chunk_type: StreamChunkType,
    /// The chunk payload.
    pub content: String,
}

/// Everything collected from one streaming completion.
#[derive(Debug)]
pub struct StreamResponse {
    /// Concatenated answer text.
    pub content: String,
    /// Input token count, or 0 when the provider did not report usage.
    pub input_tokens: i64,
    /// Output token count, or 0 when the provider did not report usage.
    pub output_tokens: i64,
    /// Concatenated reasoning/thinking text from the model.
    pub reasoning_content: String,
    /// Tool calls with their full accumulated arguments.
    pub tool_calls: Vec<StreamedToolCall>,
    /// Every chunk in arrival order, preserving think/answer/tool interleaving.
    pub chunks: Vec<StreamChunk>,
}
/// Boxed `Send` future returned by every streaming callback; the caller awaits it.
type StreamCbFuture = Pin<Box<dyn std::future::Future<Output = ()> + Send>>;

/// Async callback invoked with each answer-text delta (e.g. to broadcast it
/// to a WebSocket). The returned future must be awaited by the caller.
pub type StreamTextCb = Arc<dyn Fn(&str) -> StreamCbFuture + Send + Sync>;
/// Async callback invoked with each reasoning/thinking delta.
pub type StreamReasoningCb = Arc<dyn Fn(&str) -> StreamCbFuture + Send + Sync>;
/// Async callback invoked with each completed tool call.
pub type StreamToolCallCb = Arc<dyn Fn(&StreamedToolCall) -> StreamCbFuture + Send + Sync>;
/// Run a streaming chat completion with 60s timeout and 5 retries.
pub async fn call_stream(
messages: &[ChatRequestMessage],
model_name: &str,
config: &AiClientConfig,
temperature: f32,
max_tokens: u32,
tools: Option<&[serde_json::Value]>,
tool_choice: Option<&str>,
on_text_delta: StreamTextCb,
on_reasoning_delta: StreamReasoningCb,
on_tool_call: StreamToolCallCb,
) -> Result<StreamResponse> {
let mut state = RetryState::new(5);
loop {
let result = call_stream_once(
messages, model_name, config, temperature, max_tokens, tools, tool_choice,
on_text_delta.clone(), on_reasoning_delta.clone(), on_tool_call.clone(),
)
.await;
match result {
Ok(response) => return Ok(response),
Err(ref err) if state.should_retry() && is_retryable_error(err) => {
let duration = state.backoff_duration();
tracing::warn!(
attempt = state.attempt + 1,
max_retries = 5,
backoff_ms = duration.as_millis() as u64,
model = %model_name,
error = %err,
"ai_stream_retry"
);
tokio::time::sleep(duration).await;
state.next();
}
Err(err) => {
ai_metrics().record_failure();
return Err(err);
}
}
}
}
/// Single attempt of streaming completion with 60s timeout.
async fn call_stream_once(
messages: &[ChatRequestMessage],
model_name: &str,
config: &AiClientConfig,
temperature: f32,
max_tokens: u32,
tools: Option<&[serde_json::Value]>,
tool_choice: Option<&str>,
on_text_delta: StreamTextCb,
on_reasoning_delta: StreamReasoningCb,
on_tool_call: StreamToolCallCb,
) -> Result<StreamResponse> {
let client = config.build_rig_client();
let model = client.completion_model(model_name);
let preamble = messages
.iter()
.find(|m| m.role == "system")
.and_then(|m| m.content.as_deref())
.unwrap_or("")
.to_string();
let non_system: Vec<RigMessage> = messages
.iter()
.filter(|m| m.role != "system")
.map(to_rig_message)
.collect();
let tool_defs: Vec<ToolDefinition> = tools
.map(|ts| ts.iter().filter_map(to_rig_tool_def).collect())
.unwrap_or_default();
let mut builder = model
.completion_request("")
.temperature(temperature as f64)
.max_tokens(max_tokens as u64);
if !preamble.is_empty() {
builder = builder.preamble(preamble);
}
if !non_system.is_empty() {
builder = builder.messages(non_system);
}
if !tool_defs.is_empty() {
builder = builder.tools(tool_defs);
}
if let Some(tc) = tool_choice {
match tc {
"none" => {
builder = builder.tool_choice(rig::completion::message::ToolChoice::None);
}
"auto" => {
builder = builder.tool_choice(rig::completion::message::ToolChoice::Auto);
}
s => {
builder = builder.tool_choice(
rig::completion::message::ToolChoice::Specific {
function_names: vec![s.to_string()],
},
);
}
}
}
let stream_fut = async {
let mut stream = builder
.stream()
.await
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
let mut content = String::new();
let mut reasoning_content = String::new();
let mut tool_calls: Vec<StreamedToolCall> = Vec::new();
let mut chunks: Vec<StreamChunk> = Vec::new();
// Some models (e.g. GLM) ignore tool_choice="none" and still emit tool_calls.
// Filter them out so they don't cause spurious tool execution attempts.
let skip_tool_calls = tool_choice == Some("none");
use std::collections::HashMap;
let mut partial_tool_calls: HashMap<String, StreamedToolCall> = HashMap::new();
let mut stream_finished = false;
use rig::streaming::StreamedAssistantContent;
while let Some(item) = stream.next().await {
match item {
Ok(StreamedAssistantContent::Text(text)) => {
content.push_str(&text.text);
on_text_delta(&text.text).await;
chunks.push(StreamChunk {
chunk_type: StreamChunkType::Answer,
content: text.text,
});
}
Ok(StreamedAssistantContent::ToolCall {
tool_call,
internal_call_id,
}) => {
if skip_tool_calls {
partial_tool_calls.remove(&internal_call_id);
continue;
}
let arguments = match &tool_call.function.arguments {
serde_json::Value::String(s) => s.clone(),
other => serde_json::to_string(other).unwrap_or_else(|_| "{}".to_string()),
};
let tc = StreamedToolCall {
id: tool_call.id.clone(),
name: tool_call.function.name.clone(),
arguments,
};
on_tool_call(&tc).await;
chunks.push(StreamChunk {
chunk_type: StreamChunkType::ToolCall,
content: serde_json::json!({
"id": tc.id,
"name": tc.name,
"arguments": tc.arguments,
}).to_string(),
});
tool_calls.push(tc);
partial_tool_calls.remove(&internal_call_id);
}
Ok(StreamedAssistantContent::ToolCallDelta {
id,
internal_call_id,
content: delta_content,
}) => {
if skip_tool_calls {
continue;
}
use rig::streaming::ToolCallDeltaContent;
match delta_content {
ToolCallDeltaContent::Name(name) => {
partial_tool_calls.insert(
internal_call_id.clone(),
StreamedToolCall {
id: id.clone(),
name,
arguments: String::new(),
},
);
}
ToolCallDeltaContent::Delta(delta) => {
if let Some(tc) = partial_tool_calls.get_mut(&internal_call_id) {
tc.arguments.push_str(&delta);
}
}
}
}
Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
for part in &reasoning.reasoning {
reasoning_content.push_str(part);
on_reasoning_delta(part).await;
chunks.push(StreamChunk {
chunk_type: StreamChunkType::Thinking,
content: part.clone(),
});
}
}
Ok(StreamedAssistantContent::ReasoningDelta { reasoning, .. }) => {
reasoning_content.push_str(&reasoning);
on_reasoning_delta(&reasoning).await;
chunks.push(StreamChunk {
chunk_type: StreamChunkType::Thinking,
content: reasoning.clone(),
});
}
Ok(StreamedAssistantContent::Final(response)) => {
stream_finished = true;
if !skip_tool_calls {
for (_, tc) in partial_tool_calls.drain() {
tool_calls.push(tc);
}
} else {
partial_tool_calls.drain();
}
if let Some(usage) = response.token_usage() {
let in_toks = usage.input_tokens as i64;
let out_toks = usage.output_tokens as i64;
ai_metrics().record_success(in_toks, out_toks, !tool_calls.is_empty());
return Ok(StreamResponse {
content,
reasoning_content,
input_tokens: in_toks,
output_tokens: out_toks,
tool_calls,
chunks,
});
}
// Usage not available from Final — fall through to flush
}
Err(e) => return Err(AgentError::OpenAi(e.to_string())),
}
}
// Flush any remaining partial tool calls (if stream ended without Final or Final had no usage)
if !stream_finished && !skip_tool_calls {
for (_, tc) in partial_tool_calls.drain() {
tool_calls.push(tc);
}
}
ai_metrics().record_success(0, 0, !tool_calls.is_empty());
Ok(StreamResponse {
content,
reasoning_content,
input_tokens: 0,
output_tokens: 0,
tool_calls,
chunks,
})
};
// 120s timeout for the entire stream
match tokio::time::timeout(std::time::Duration::from_secs(120), stream_fut).await {
Ok(result) => result,
Err(_) => Err(AgentError::Timeout { task_id: 0, seconds: 120 }),
}
}