use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::Mutex; use crate::ChannelError; const STATUS_CLOSED: u8 = 0; const STATUS_OPEN: u8 = 1; const STATUS_HALF_OPEN: u8 = 2; #[derive(Clone)] pub struct CircuitBreaker { inner: Arc, } struct CircuitState { status: u8, failure_count: u32, success_count: u32, half_open_calls: u32, last_failure_time: Option, } struct Inner { state: Mutex, config: CircuitConfig, } #[derive(Clone)] struct CircuitConfig { failure_threshold: u32, success_threshold: u32, timeout: Duration, half_open_max_calls: u32, } impl CircuitBreaker { pub fn new() -> Self { Self::with_config(5, 2, Duration::from_secs(60), 3) } pub fn with_config( failure_threshold: u32, success_threshold: u32, timeout: Duration, half_open_max_calls: u32, ) -> Self { Self { inner: Arc::new(Inner { state: Mutex::new(CircuitState { status: STATUS_CLOSED, failure_count: 0, success_count: 0, half_open_calls: 0, last_failure_time: None, }), config: CircuitConfig { failure_threshold, success_threshold, timeout, half_open_max_calls, }, }), } } pub async fn call(&self, f: F) -> Result where F: std::future::Future>, { let slot_reserved = { let mut state = self.inner.state.lock().await; match state.status { STATUS_OPEN => match state.last_failure_time { Some(t) if t.elapsed() > self.inner.config.timeout => { state.status = STATUS_HALF_OPEN; state.half_open_calls = 1; state.success_count = 0; true } _ => false, }, STATUS_HALF_OPEN => { if state.half_open_calls < self.inner.config.half_open_max_calls { state.half_open_calls += 1; true } else { false } } _ => true, // Closed → allow } }; // Lock released before executing the call. if !slot_reserved { return Err(CircuitBreakerError::Open); } match f.await { Ok(result) => { self.on_success().await; Ok(result) } Err(e) => { self.on_failure().await; Err(CircuitBreakerError::Inner(e)) } } } async fn on_success(&self) { let mut state = self.inner.state.lock().await; state.failure_count = 0; if state.status == STATUS_HALF_OPEN { state.success_count += 1; if state.success_count >= self.inner.config.success_threshold { state.status = STATUS_CLOSED; state.success_count = 0; state.half_open_calls = 0; } } } async fn on_failure(&self) { let mut state = self.inner.state.lock().await; state.failure_count += 1; state.last_failure_time = Some(Instant::now()); if state.status == STATUS_HALF_OPEN { state.status = STATUS_OPEN; state.success_count = 0; state.half_open_calls = 0; } else if state.status == STATUS_CLOSED && state.failure_count >= self.inner.config.failure_threshold { state.status = STATUS_OPEN; } } pub async fn is_open(&self) -> bool { let state = self.inner.state.lock().await; state.status == STATUS_OPEN } } #[derive(Debug)] pub enum CircuitBreakerError { Open, Inner(ChannelError), } impl std::fmt::Display for CircuitBreakerError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { CircuitBreakerError::Open => write!(f, "Circuit breaker is open"), CircuitBreakerError::Inner(e) => write!(f, "{}", e), } } } impl std::error::Error for CircuitBreakerError {}