255 lines
7.9 KiB
Rust
255 lines
7.9 KiB
Rust
use std::time::Duration;
|
|
|
|
use crate::error::AiError;
|
|
|
|
/// How a failed AI request should be handled: retried, rerouted to a fallback
/// model, compacted, or abandoned.
///
/// Produced by `classify_error` and consumed by `retry_policy_for`,
/// `should_switch_to_fallback` and `should_compact_before_retry`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ErrorCategory {
    /// A transient failure; retrying after a backoff delay may succeed.
    Retryable { reason: String },
    /// An authentication or quota problem; the request should move to a fallback model.
    FallbackModel { reason: String },
    /// A failure that retrying will not fix.
    Fatal { reason: String },
    /// The token budget allotted to this run has been used up.
    TokenBudgetExceeded,
    /// The request did not complete in time.
    Timeout,
    /// The caller cancelled the request.
    Cancelled,
    /// The provider is overloaded or out of capacity; retry after a longer delay.
    Overloaded { reason: String },
    /// The request no longer fits the model's context window; compact before retrying.
    ContextWindowExceeded { reason: String },
}
|
|
|
|
/// Retry parameters for one error category, as produced by `retry_policy_for`.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Maximum number of attempts for this category; `retry_policy_for` uses `0`
    /// for categories that must not be retried.
    pub max_attempts: usize,
    /// Starting delay; `delay_for_attempt` scales it when `exponential` is set.
    pub base_delay: Duration,
    /// Spread each delay by up to ±25% (see `delay_for_attempt`).
    pub jitter: bool,
    /// Double the delay on every attempt, up to 64× `base_delay`.
    pub exponential: bool,
    /// Whether the caller should move to a fallback model.
    pub switch_to_fallback: bool,
}

impl RetryPolicy {
    /// Returns how long to wait before retry number `attempt`.
    ///
    /// * With `exponential`, the delay is `base_delay * 2^attempt`, with the exponent
    ///   capped at 6 (64× `base_delay`); otherwise it is `base_delay`.
    /// * With `jitter`, the delay is moved into the window `[d - d/4, d + d/4]`.
    ///   The offset is derived deterministically from `attempt` (no RNG), so a given
    ///   attempt always yields the same delay and attempt `0` lands on the low end.
    /// * The result is never shorter than 100 ms.
    ///
    /// All arithmetic saturates: a huge `base_delay` yields the largest representable
    /// delay instead of truncating or overflowing.
    pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
        // Beyond this exponent the backoff stops growing (64x base_delay).
        const MAX_SHIFT: usize = 6;
        // Fixed multiplier that scatters successive attempts across the jitter window.
        const JITTER_MULTIPLIER: u64 = 1_103_515_245;
        // Floor applied to every delay.
        const MIN_DELAY_MS: u64 = 100;

        // `as_millis` returns u128; clamp rather than truncate with `as`.
        let base_ms = u64::try_from(self.base_delay.as_millis()).unwrap_or(u64::MAX);

        let ms = if self.exponential {
            base_ms.saturating_mul(1u64 << attempt.min(MAX_SHIFT))
        } else {
            base_ms
        };

        let ms = if self.jitter {
            // Quarter of the delay on either side of `ms`.
            let spread = ms / 4;
            let lo = ms - spread;
            let hi = ms.saturating_add(spread);
            // `hi - lo` is at most `ms / 2`, so `+ 1` cannot overflow, and
            // `lo + offset <= hi` stays in range.
            let offset = (attempt as u64).wrapping_mul(JITTER_MULTIPLIER) % (hi - lo + 1);
            lo + offset
        } else {
            ms
        };

        Duration::from_millis(ms.max(MIN_DELAY_MS))
    }
}
|
|
|
|
/// Classify an error into a category for retry/fallback decisions.
|
|
///
|
|
/// Inspects both the HTTP status code (when available) and the error message
|
|
/// content to determine the most appropriate category.
|
|
pub fn classify_error(
|
|
error: &AiError,
|
|
http_status: Option<u16>,
|
|
) -> ErrorCategory {
|
|
// HTTP status-based classification takes precedence
|
|
let from_status = match http_status {
|
|
Some(429) => Some(ErrorCategory::Retryable {
|
|
reason: "rate limited (HTTP 429)".to_string(),
|
|
}),
|
|
Some(401) | Some(403) => Some(ErrorCategory::FallbackModel {
|
|
reason: format!(
|
|
"authentication failed (HTTP {})",
|
|
http_status.unwrap()
|
|
),
|
|
}),
|
|
Some(502) | Some(503) => Some(ErrorCategory::Overloaded {
|
|
reason: format!(
|
|
"provider unavailable (HTTP {})",
|
|
http_status.unwrap()
|
|
),
|
|
}),
|
|
Some(504) => Some(ErrorCategory::Timeout),
|
|
Some(413) => Some(ErrorCategory::ContextWindowExceeded {
|
|
reason: "payload too large (HTTP 413)".to_string(),
|
|
}),
|
|
Some(s) if (400..500).contains(&s) => Some(ErrorCategory::Fatal {
|
|
reason: format!("client error (HTTP {})", s),
|
|
}),
|
|
Some(s) if (500..600).contains(&s) => Some(ErrorCategory::Retryable {
|
|
reason: format!("server error (HTTP {})", s),
|
|
}),
|
|
_ => None,
|
|
};
|
|
|
|
if let Some(cat) = from_status {
|
|
return cat;
|
|
}
|
|
|
|
// Message-based classification
|
|
match error {
|
|
AiError::Timeout { .. } => ErrorCategory::Timeout,
|
|
AiError::TokenBudgetExceeded { .. } => {
|
|
ErrorCategory::TokenBudgetExceeded
|
|
}
|
|
AiError::Api(msg) => classify_api_message(msg),
|
|
AiError::Response(msg) => classify_response_message(msg),
|
|
AiError::ModelRetriesExhausted { .. } => ErrorCategory::Fatal {
|
|
reason: error.to_string(),
|
|
},
|
|
_ => ErrorCategory::Fatal {
|
|
reason: error.to_string(),
|
|
},
|
|
}
|
|
}
|
|
|
|
/// Classify API error messages by keyword patterns.
|
|
fn classify_api_message(msg: &str) -> ErrorCategory {
|
|
let lower = msg.to_lowercase();
|
|
|
|
// Rate limiting
|
|
if lower.contains("rate")
|
|
|| lower.contains("too many requests")
|
|
|| lower.contains("throttl")
|
|
{
|
|
return ErrorCategory::Retryable {
|
|
reason: msg.to_string(),
|
|
};
|
|
}
|
|
|
|
// Overloaded / capacity
|
|
if lower.contains("overloaded")
|
|
|| lower.contains("capacity")
|
|
|| lower.contains("too busy")
|
|
|| lower.contains("service unavailable")
|
|
{
|
|
return ErrorCategory::Overloaded {
|
|
reason: msg.to_string(),
|
|
};
|
|
}
|
|
|
|
// Authentication / quota
|
|
if lower.contains("unauthorized")
|
|
|| lower.contains("invalid api key")
|
|
|| lower.contains("api key")
|
|
|| lower.contains("forbidden")
|
|
|| lower.contains("quota exceeded")
|
|
|| lower.contains("insufficient")
|
|
|| lower.contains("billing")
|
|
{
|
|
return ErrorCategory::FallbackModel {
|
|
reason: msg.to_string(),
|
|
};
|
|
}
|
|
|
|
// Context window exceeded
|
|
if lower.contains("context length")
|
|
|| lower.contains("context window")
|
|
|| lower.contains("maximum context")
|
|
|| lower.contains("too many tokens")
|
|
|| lower.contains("max_tokens")
|
|
{
|
|
return ErrorCategory::ContextWindowExceeded {
|
|
reason: msg.to_string(),
|
|
};
|
|
}
|
|
|
|
ErrorCategory::Fatal {
|
|
reason: msg.to_string(),
|
|
}
|
|
}
|
|
|
|
/// Classify response error messages by keyword patterns.
|
|
fn classify_response_message(msg: &str) -> ErrorCategory {
|
|
let lower = msg.to_lowercase();
|
|
|
|
if lower.contains("cancelled") || lower.contains("canceled") {
|
|
return ErrorCategory::Cancelled;
|
|
}
|
|
if lower.contains("timeout") || lower.contains("timed out") {
|
|
return ErrorCategory::Timeout;
|
|
}
|
|
|
|
ErrorCategory::Fatal {
|
|
reason: msg.to_string(),
|
|
}
|
|
}
|
|
|
|
/// Get the recommended retry policy for an error category.
|
|
pub fn retry_policy_for(
|
|
category: &ErrorCategory,
|
|
max_attempts: usize,
|
|
base_delay_ms: u64,
|
|
) -> RetryPolicy {
|
|
match category {
|
|
ErrorCategory::Retryable { .. } => RetryPolicy {
|
|
max_attempts,
|
|
base_delay: Duration::from_millis(base_delay_ms),
|
|
jitter: true,
|
|
exponential: true,
|
|
switch_to_fallback: false,
|
|
},
|
|
ErrorCategory::Overloaded { .. } => RetryPolicy {
|
|
max_attempts: max_attempts.min(5),
|
|
base_delay: Duration::from_millis(base_delay_ms.max(5_000)),
|
|
jitter: true,
|
|
exponential: true,
|
|
switch_to_fallback: true,
|
|
},
|
|
ErrorCategory::FallbackModel { .. } => RetryPolicy {
|
|
max_attempts: 1,
|
|
base_delay: Duration::from_millis(500),
|
|
jitter: false,
|
|
exponential: false,
|
|
switch_to_fallback: true,
|
|
},
|
|
ErrorCategory::ContextWindowExceeded { .. } => RetryPolicy {
|
|
max_attempts: 1,
|
|
base_delay: Duration::from_millis(0),
|
|
jitter: false,
|
|
exponential: false,
|
|
switch_to_fallback: false,
|
|
},
|
|
ErrorCategory::Timeout => RetryPolicy {
|
|
max_attempts: max_attempts.min(2),
|
|
base_delay: Duration::from_millis(base_delay_ms.max(2_000)),
|
|
jitter: true,
|
|
exponential: false,
|
|
switch_to_fallback: false,
|
|
},
|
|
ErrorCategory::TokenBudgetExceeded
|
|
| ErrorCategory::Cancelled
|
|
| ErrorCategory::Fatal { .. } => RetryPolicy {
|
|
max_attempts: 0,
|
|
base_delay: Duration::from_millis(0),
|
|
jitter: false,
|
|
exponential: false,
|
|
switch_to_fallback: false,
|
|
},
|
|
}
|
|
}
|
|
|
|
/// Determine whether the error warrants switching to a fallback model.
|
|
pub fn should_switch_to_fallback(category: &ErrorCategory) -> bool {
|
|
matches!(
|
|
category,
|
|
ErrorCategory::FallbackModel { .. } | ErrorCategory::Overloaded { .. }
|
|
)
|
|
}
|
|
|
|
/// Determine whether compaction should be attempted before retry.
|
|
pub fn should_compact_before_retry(category: &ErrorCategory) -> bool {
|
|
matches!(category, ErrorCategory::ContextWindowExceeded { .. })
|
|
}
|