use std::time::Duration; use crate::error::AiError; /// Categorized error for deciding retry/fallback/fatal strategy. #[derive(Debug, Clone, PartialEq, Eq)] pub enum ErrorCategory { /// Transient error, safe to retry with backoff. Retryable { reason: String }, /// Authentication or quota error, switch to fallback model. FallbackModel { reason: String }, /// Non-recoverable error, do not retry. Fatal { reason: String }, /// Token budget exceeded for this run. TokenBudgetExceeded, /// Request timed out. Timeout, /// Request was cancelled by the caller. Cancelled, /// Provider is overloaded or at capacity, retry with longer delay. Overloaded { reason: String }, /// Context window exceeded, needs compaction before retry. ContextWindowExceeded { reason: String }, } #[derive(Debug, Clone)] pub struct RetryPolicy { pub max_attempts: usize, pub base_delay: Duration, pub jitter: bool, pub exponential: bool, pub switch_to_fallback: bool, } impl RetryPolicy { pub fn delay_for_attempt(&self, attempt: usize) -> Duration { let ms = if self.exponential { self.base_delay.as_millis() as u64 * (1u64 << attempt.min(6)) } else { self.base_delay.as_millis() as u64 }; let ms = if self.jitter { let half = (ms as f64 * 0.25) as u64; let lo = ms.saturating_sub(half); let hi = ms.saturating_add(half); let mix = ((attempt as u64).wrapping_mul(1_103_515_245)) % (hi - lo + 1); lo + mix } else { ms }; Duration::from_millis(ms.max(100)) } } /// Classify an error into a category for retry/fallback decisions. /// /// Inspects both the HTTP status code (when available) and the error message /// content to determine the most appropriate category. pub fn classify_error(error: &AiError, http_status: Option) -> ErrorCategory { // HTTP status-based classification takes precedence let from_status = match http_status { Some(429) => Some(ErrorCategory::Retryable { reason: "rate limited (HTTP 429)".to_string(), }), Some(401) | Some(403) => Some(ErrorCategory::FallbackModel { reason: format!("authentication failed (HTTP {})", http_status.unwrap()), }), Some(502) | Some(503) => Some(ErrorCategory::Overloaded { reason: format!("provider unavailable (HTTP {})", http_status.unwrap()), }), Some(504) => Some(ErrorCategory::Timeout), Some(413) => Some(ErrorCategory::ContextWindowExceeded { reason: "payload too large (HTTP 413)".to_string(), }), Some(s) if (400..500).contains(&s) => Some(ErrorCategory::Fatal { reason: format!("client error (HTTP {})", s), }), Some(s) if (500..600).contains(&s) => Some(ErrorCategory::Retryable { reason: format!("server error (HTTP {})", s), }), _ => None, }; if let Some(cat) = from_status { return cat; } // Message-based classification match error { AiError::Timeout { .. } => ErrorCategory::Timeout, AiError::TokenBudgetExceeded { .. } => ErrorCategory::TokenBudgetExceeded, AiError::Api(msg) => classify_api_message(msg), AiError::Response(msg) => classify_response_message(msg), AiError::ModelRetriesExhausted { .. } => ErrorCategory::Fatal { reason: error.to_string(), }, _ => ErrorCategory::Fatal { reason: error.to_string(), }, } } /// Classify API error messages by keyword patterns. fn classify_api_message(msg: &str) -> ErrorCategory { let lower = msg.to_lowercase(); // Rate limiting if lower.contains("rate") || lower.contains("too many requests") || lower.contains("throttl") { return ErrorCategory::Retryable { reason: msg.to_string(), }; } // Overloaded / capacity if lower.contains("overloaded") || lower.contains("capacity") || lower.contains("too busy") || lower.contains("service unavailable") { return ErrorCategory::Overloaded { reason: msg.to_string(), }; } // Authentication / quota if lower.contains("unauthorized") || lower.contains("invalid api key") || lower.contains("api key") || lower.contains("forbidden") || lower.contains("quota exceeded") || lower.contains("insufficient") || lower.contains("billing") { return ErrorCategory::FallbackModel { reason: msg.to_string(), }; } // Context window exceeded if lower.contains("context length") || lower.contains("context window") || lower.contains("maximum context") || lower.contains("too many tokens") || lower.contains("max_tokens") { return ErrorCategory::ContextWindowExceeded { reason: msg.to_string(), }; } ErrorCategory::Fatal { reason: msg.to_string(), } } /// Classify response error messages by keyword patterns. fn classify_response_message(msg: &str) -> ErrorCategory { let lower = msg.to_lowercase(); if lower.contains("cancelled") || lower.contains("canceled") { return ErrorCategory::Cancelled; } if lower.contains("timeout") || lower.contains("timed out") { return ErrorCategory::Timeout; } ErrorCategory::Fatal { reason: msg.to_string(), } } /// Get the recommended retry policy for an error category. pub fn retry_policy_for( category: &ErrorCategory, max_attempts: usize, base_delay_ms: u64, ) -> RetryPolicy { match category { ErrorCategory::Retryable { .. } => RetryPolicy { max_attempts, base_delay: Duration::from_millis(base_delay_ms), jitter: true, exponential: true, switch_to_fallback: false, }, ErrorCategory::Overloaded { .. } => RetryPolicy { max_attempts: max_attempts.min(5), base_delay: Duration::from_millis(base_delay_ms.max(5_000)), jitter: true, exponential: true, switch_to_fallback: true, }, ErrorCategory::FallbackModel { .. } => RetryPolicy { max_attempts: 1, base_delay: Duration::from_millis(500), jitter: false, exponential: false, switch_to_fallback: true, }, ErrorCategory::ContextWindowExceeded { .. } => RetryPolicy { max_attempts: 1, base_delay: Duration::from_millis(0), jitter: false, exponential: false, switch_to_fallback: false, }, ErrorCategory::Timeout => RetryPolicy { max_attempts: max_attempts.min(2), base_delay: Duration::from_millis(base_delay_ms.max(2_000)), jitter: true, exponential: false, switch_to_fallback: false, }, ErrorCategory::TokenBudgetExceeded | ErrorCategory::Cancelled | ErrorCategory::Fatal { .. } => { RetryPolicy { max_attempts: 0, base_delay: Duration::from_millis(0), jitter: false, exponential: false, switch_to_fallback: false, } } } } /// Determine whether the error warrants switching to a fallback model. pub fn should_switch_to_fallback(category: &ErrorCategory) -> bool { matches!( category, ErrorCategory::FallbackModel { .. } | ErrorCategory::Overloaded { .. } ) } /// Determine whether compaction should be attempted before retry. pub fn should_compact_before_retry(category: &ErrorCategory) -> bool { matches!(category, ErrorCategory::ContextWindowExceeded { .. }) }