use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::Mutex; use track::CounterVec; use crate::ChannelError; const STATUS_CLOSED: u8 = 0; const STATUS_OPEN: u8 = 1; const STATUS_HALF_OPEN: u8 = 2; #[derive(Clone)] pub struct CircuitBreaker { inner: Arc, metrics: Option, } #[derive(Clone)] struct CircuitBreakerMetrics { transitions: CounterVec, calls: CounterVec, } impl CircuitBreakerMetrics { fn new(registry: &track::MetricsRegistry) -> Self { Self { transitions: registry .register_counter_vec( "circuit_breaker_transitions_total", "Circuit breaker state transitions", &["transition"], ) .expect("failed to register circuit_breaker_transitions_total"), calls: registry .register_counter_vec( "circuit_breaker_calls_total", "Circuit breaker call outcomes", &["outcome"], ) .expect("failed to register circuit_breaker_calls_total"), } } } struct CircuitState { status: u8, failure_count: u32, success_count: u32, half_open_calls: u32, last_failure_time: Option, } struct Inner { state: Mutex, config: CircuitConfig, } #[derive(Clone)] struct CircuitConfig { failure_threshold: u32, success_threshold: u32, timeout: Duration, half_open_max_calls: u32, } impl CircuitBreaker { pub fn new() -> Self { Self::with_config(5, 2, Duration::from_secs(60), 3) } pub fn with_config( failure_threshold: u32, success_threshold: u32, timeout: Duration, half_open_max_calls: u32, ) -> Self { Self { inner: Arc::new(Inner { state: Mutex::new(CircuitState { status: STATUS_CLOSED, failure_count: 0, success_count: 0, half_open_calls: 0, last_failure_time: None, }), config: CircuitConfig { failure_threshold, success_threshold, timeout, half_open_max_calls, }, }), metrics: None, } } pub fn set_metrics(&mut self, registry: &track::MetricsRegistry) { self.metrics = Some(CircuitBreakerMetrics::new(registry)); } pub async fn call(&self, f: F) -> Result where F: std::future::Future>, { let slot_reserved = { let mut state = self.inner.state.lock().await; match state.status { STATUS_OPEN => match state.last_failure_time { Some(t) if t.elapsed() > self.inner.config.timeout => { state.status = STATUS_HALF_OPEN; state.half_open_calls = 1; state.success_count = 0; true } _ => false, }, STATUS_HALF_OPEN => { if state.half_open_calls < self.inner.config.half_open_max_calls { state.half_open_calls += 1; true } else { false } } _ => true, } }; if !slot_reserved { if let Some(m) = &self.metrics { m.calls.with_label_values(&["rejected"]).inc(); } return Err(CircuitBreakerError::Open); } match f.await { Ok(result) => { if let Some(m) = &self.metrics { m.calls.with_label_values(&["success"]).inc(); } self.on_success().await; Ok(result) } Err(e) => { if let Some(m) = &self.metrics { m.calls.with_label_values(&["failure"]).inc(); } self.on_failure().await; Err(CircuitBreakerError::Inner(e)) } } } async fn on_success(&self) { let mut state = self.inner.state.lock().await; state.failure_count = 0; if state.status == STATUS_HALF_OPEN { state.success_count += 1; if state.success_count >= self.inner.config.success_threshold { state.status = STATUS_CLOSED; state.success_count = 0; state.half_open_calls = 0; if let Some(m) = &self.metrics { m.transitions .with_label_values(&["half_open_to_closed"]) .inc(); } } } } async fn on_failure(&self) { let mut state = self.inner.state.lock().await; state.failure_count += 1; state.last_failure_time = Some(Instant::now()); if state.status == STATUS_HALF_OPEN { state.status = STATUS_OPEN; state.success_count = 0; state.half_open_calls = 0; if let Some(m) = &self.metrics { m.transitions .with_label_values(&["half_open_to_open"]) .inc(); } } else if state.status == STATUS_CLOSED && state.failure_count >= self.inner.config.failure_threshold { state.status = STATUS_OPEN; if let Some(m) = &self.metrics { m.transitions.with_label_values(&["closed_to_open"]).inc(); } } } pub async fn is_open(&self) -> bool { let state = self.inner.state.lock().await; state.status == STATUS_OPEN } } #[derive(Debug)] pub enum CircuitBreakerError { Open, Inner(ChannelError), } impl std::fmt::Display for CircuitBreakerError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { CircuitBreakerError::Open => write!(f, "Circuit breaker is open"), CircuitBreakerError::Inner(e) => write!(f, "{}", e), } } } impl std::error::Error for CircuitBreakerError {}