gitdataai/lib/ai/agent/error_classifier.rs

255 lines
7.9 KiB
Rust

use std::time::Duration;
use crate::error::AiError;
/// Outcome of classifying an AI provider error, used to choose between
/// retrying, switching to a fallback model, compacting context, or giving up.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ErrorCategory {
    // ---- Variants carrying the triggering message as `reason` ----
    /// Transient failure (e.g. rate limiting or a generic 5xx); retry with backoff.
    Retryable { reason: String },
    /// Provider is overloaded or at capacity; retry with a longer delay.
    Overloaded { reason: String },
    /// Authentication or quota failure; move to a fallback model.
    FallbackModel { reason: String },
    /// Request exceeds the model's context window; compact before retrying.
    ContextWindowExceeded { reason: String },
    /// Unrecoverable failure; do not retry.
    Fatal { reason: String },

    // ---- Unit variants ----
    /// The request timed out.
    Timeout,
    /// The caller cancelled the request.
    Cancelled,
    /// This run's token budget has been used up.
    TokenBudgetExceeded,
}
/// Retry schedule for one error category, as produced by `retry_policy_for`.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Attempt limit for this policy; `retry_policy_for` sets `0` for
    /// categories that should not be retried.
    pub max_attempts: usize,
    /// Nominal delay before a retry; see [`RetryPolicy::delay_for_attempt`].
    pub base_delay: Duration,
    /// Spread each delay deterministically within ±25% of its nominal value.
    pub jitter: bool,
    /// Double the delay on each attempt, capped at 64× `base_delay`.
    pub exponential: bool,
    /// Whether this category calls for switching to a fallback model.
    pub switch_to_fallback: bool,
}

impl RetryPolicy {
    /// Compute how long to wait before retry number `attempt`.
    ///
    /// * With `exponential`, the nominal delay is `base_delay * 2^attempt`,
    ///   with the exponent capped at 6 (i.e. at most 64× `base_delay`).
    ///   Otherwise the nominal delay is `base_delay`.
    /// * With `jitter`, the delay is moved within ±25% of the nominal value.
    ///   The offset is derived deterministically from `attempt` (no RNG), so
    ///   a given policy always returns the same delay for a given attempt;
    ///   for `attempt == 0` it is always the lower bound.
    /// * The result is never shorter than 100 ms.
    ///
    /// All arithmetic saturates at `u64::MAX` milliseconds, so a very large
    /// `base_delay` yields a very long delay instead of overflowing.
    pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
        const MAX_EXPONENT: usize = 6;
        const MIN_DELAY_MS: u64 = 100;

        // `as_millis` is u128: clamp before narrowing instead of truncating.
        let base_ms = self.base_delay.as_millis().min(u128::from(u64::MAX)) as u64;

        let nominal_ms = if self.exponential {
            base_ms.saturating_mul(1u64 << attempt.min(MAX_EXPONENT))
        } else {
            base_ms
        };

        let delay_ms = if self.jitter {
            // A quarter of the nominal delay on each side: [0.75x, 1.25x].
            let spread = (nominal_ms as f64 * 0.25) as u64;
            let lo = nominal_ms.saturating_sub(spread);
            let hi = nominal_ms.saturating_add(spread);
            // Cannot overflow: hi - lo <= 2 * spread <= nominal_ms / 2.
            let width = hi - lo + 1;
            // Cheap deterministic "pseudo-random" offset in [0, width).
            let offset = (attempt as u64).wrapping_mul(1_103_515_245) % width;
            lo + offset
        } else {
            nominal_ms
        };

        Duration::from_millis(delay_ms.max(MIN_DELAY_MS))
    }
}
/// Classify an error into a category for retry/fallback decisions.
///
/// Inspects both the HTTP status code (when available) and the error message
/// content to determine the most appropriate category.
pub fn classify_error(
error: &AiError,
http_status: Option<u16>,
) -> ErrorCategory {
// HTTP status-based classification takes precedence
let from_status = match http_status {
Some(429) => Some(ErrorCategory::Retryable {
reason: "rate limited (HTTP 429)".to_string(),
}),
Some(401) | Some(403) => Some(ErrorCategory::FallbackModel {
reason: format!(
"authentication failed (HTTP {})",
http_status.unwrap()
),
}),
Some(502) | Some(503) => Some(ErrorCategory::Overloaded {
reason: format!(
"provider unavailable (HTTP {})",
http_status.unwrap()
),
}),
Some(504) => Some(ErrorCategory::Timeout),
Some(413) => Some(ErrorCategory::ContextWindowExceeded {
reason: "payload too large (HTTP 413)".to_string(),
}),
Some(s) if (400..500).contains(&s) => Some(ErrorCategory::Fatal {
reason: format!("client error (HTTP {})", s),
}),
Some(s) if (500..600).contains(&s) => Some(ErrorCategory::Retryable {
reason: format!("server error (HTTP {})", s),
}),
_ => None,
};
if let Some(cat) = from_status {
return cat;
}
// Message-based classification
match error {
AiError::Timeout { .. } => ErrorCategory::Timeout,
AiError::TokenBudgetExceeded { .. } => {
ErrorCategory::TokenBudgetExceeded
}
AiError::Api(msg) => classify_api_message(msg),
AiError::Response(msg) => classify_response_message(msg),
AiError::ModelRetriesExhausted { .. } => ErrorCategory::Fatal {
reason: error.to_string(),
},
_ => ErrorCategory::Fatal {
reason: error.to_string(),
},
}
}
/// Classify an API error message into an [`ErrorCategory`] by keyword.
///
/// Matching is case-insensitive and checked in priority order: rate limiting,
/// overload, authentication/quota, then context-window overflow. A message
/// matching several groups therefore gets the earliest one. Anything
/// unrecognised is [`ErrorCategory::Fatal`]. Every category that carries a
/// `reason` receives the original (not lowercased) message.
fn classify_api_message(msg: &str) -> ErrorCategory {
    /// True if `needle` occurs in `haystack` at the start of a word, i.e. not
    /// directly preceded by an alphanumeric character.
    fn contains_at_word_start(haystack: &str, needle: &str) -> bool {
        haystack.match_indices(needle).any(|(idx, _)| {
            !matches!(haystack[..idx].chars().next_back(), Some(c) if c.is_alphanumeric())
        })
    }

    const RATE_LIMIT: &[&str] = &["too many requests", "throttl"];
    const OVERLOADED: &[&str] = &["overloaded", "capacity", "too busy", "service unavailable"];
    const AUTH_OR_QUOTA: &[&str] = &[
        "unauthorized",
        "api key", // also covers "invalid api key"
        "forbidden",
        "quota exceeded",
        "insufficient",
        "billing",
    ];
    const CONTEXT_WINDOW: &[&str] = &[
        "context length",
        "context window",
        "maximum context",
        "too many tokens",
        "max_tokens",
    ];

    let lower = msg.to_lowercase();
    let mentions = |phrases: &[&str]| phrases.iter().any(|phrase| lower.contains(*phrase));
    let reason = msg.to_string();

    // "rate" must start a word. A plain substring test also fires on
    // "generate", "moderate", "accurate", "separate", ... and would make
    // unrelated failures retryable. Word-start matching still catches
    // "rate limit", "Rate exceeded", "rate_limit_exceeded" and "x-ratelimit-...".
    if contains_at_word_start(&lower, "rate") || mentions(RATE_LIMIT) {
        ErrorCategory::Retryable { reason }
    } else if mentions(OVERLOADED) {
        ErrorCategory::Overloaded { reason }
    } else if mentions(AUTH_OR_QUOTA) {
        ErrorCategory::FallbackModel { reason }
    } else if mentions(CONTEXT_WINDOW) {
        ErrorCategory::ContextWindowExceeded { reason }
    } else {
        ErrorCategory::Fatal { reason }
    }
}
/// Classify a response-stage error message by keyword.
///
/// Matching is case-insensitive. Cancellation wording (either spelling) takes
/// priority over timeout wording. Anything else becomes
/// [`ErrorCategory::Fatal`] carrying the original message.
fn classify_response_message(msg: &str) -> ErrorCategory {
    let lowered = msg.to_lowercase();
    let was_cancelled = ["cancelled", "canceled"]
        .iter()
        .copied()
        .any(|word| lowered.contains(word));
    let timed_out = ["timeout", "timed out"]
        .iter()
        .copied()
        .any(|word| lowered.contains(word));

    if was_cancelled {
        ErrorCategory::Cancelled
    } else if timed_out {
        ErrorCategory::Timeout
    } else {
        ErrorCategory::Fatal {
            reason: msg.to_string(),
        }
    }
}
/// Build the recommended [`RetryPolicy`] for an error category.
///
/// `max_attempts` and `base_delay_ms` are the caller's defaults. Each category
/// passes them through, clamps them, or overrides them:
///
/// | Category                 | attempts               | delay (ms)                 | jitter | exponential | fallback |
/// |--------------------------|------------------------|----------------------------|--------|-------------|----------|
/// | `Retryable`              | `max_attempts`         | `base_delay_ms`            | yes    | yes         | no       |
/// | `Overloaded`             | `min(max_attempts, 5)` | `max(base_delay_ms, 5000)` | yes    | yes         | yes      |
/// | `FallbackModel`          | 1                      | 500                        | no     | no          | yes      |
/// | `ContextWindowExceeded`  | 1                      | 0                          | no     | no          | no       |
/// | `Timeout`                | `min(max_attempts, 2)` | `max(base_delay_ms, 2000)` | yes    | no          | no       |
/// | budget / cancel / fatal  | 0                      | 0                          | no     | no          | no       |
pub fn retry_policy_for(
    category: &ErrorCategory,
    max_attempts: usize,
    base_delay_ms: u64,
) -> RetryPolicy {
    // Each arm yields: (attempts, delay_ms, jitter, exponential, switch_to_fallback).
    let (attempts, delay_ms, jitter, exponential, switch_to_fallback) = match category {
        ErrorCategory::Retryable { .. } => (max_attempts, base_delay_ms, true, true, false),
        ErrorCategory::Overloaded { .. } => {
            (max_attempts.min(5), base_delay_ms.max(5_000), true, true, true)
        }
        ErrorCategory::FallbackModel { .. } => (1, 500, false, false, true),
        ErrorCategory::ContextWindowExceeded { .. } => (1, 0, false, false, false),
        ErrorCategory::Timeout => {
            (max_attempts.min(2), base_delay_ms.max(2_000), true, false, false)
        }
        ErrorCategory::TokenBudgetExceeded
        | ErrorCategory::Cancelled
        | ErrorCategory::Fatal { .. } => (0, 0, false, false, false),
    };

    RetryPolicy {
        max_attempts: attempts,
        base_delay: Duration::from_millis(delay_ms),
        jitter,
        exponential,
        switch_to_fallback,
    }
}
/// Report whether `category` calls for moving to a fallback model.
///
/// Returns `true` for authentication/quota failures
/// ([`ErrorCategory::FallbackModel`]) and provider overload
/// ([`ErrorCategory::Overloaded`]), and `false` for every other category.
/// The match is exhaustive so that adding a variant forces a decision here.
pub fn should_switch_to_fallback(category: &ErrorCategory) -> bool {
    match category {
        ErrorCategory::FallbackModel { .. } | ErrorCategory::Overloaded { .. } => true,
        ErrorCategory::Retryable { .. }
        | ErrorCategory::Fatal { .. }
        | ErrorCategory::ContextWindowExceeded { .. }
        | ErrorCategory::TokenBudgetExceeded
        | ErrorCategory::Timeout
        | ErrorCategory::Cancelled => false,
    }
}
/// Report whether context compaction should run before the next retry.
///
/// Only [`ErrorCategory::ContextWindowExceeded`] returns `true`. The match is
/// exhaustive so that adding a variant forces a decision here.
pub fn should_compact_before_retry(category: &ErrorCategory) -> bool {
    match category {
        ErrorCategory::ContextWindowExceeded { .. } => true,
        ErrorCategory::Retryable { .. }
        | ErrorCategory::FallbackModel { .. }
        | ErrorCategory::Fatal { .. }
        | ErrorCategory::Overloaded { .. }
        | ErrorCategory::TokenBudgetExceeded
        | ErrorCategory::Timeout
        | ErrorCategory::Cancelled => false,
    }
}