221 lines
6.3 KiB
Rust
221 lines
6.3 KiB
Rust
use std::sync::Arc;
|
|
use std::time::{Duration, Instant};
|
|
use tokio::sync::Mutex;
|
|
use track::CounterVec;
|
|
|
|
use crate::ChannelError;
|
|
|
|
// Numeric tags stored in `CircuitState::status`.

/// Closed: calls are admitted and failures are counted.
const STATUS_CLOSED: u8 = 0;
/// Open: calls are rejected until the configured timeout has elapsed.
const STATUS_OPEN: u8 = 1;
/// Half-open: a limited number of probe calls are admitted.
const STATUS_HALF_OPEN: u8 = 2;
|
|
|
|
#[derive(Clone)]
|
|
pub struct CircuitBreaker {
|
|
inner: Arc<Inner>,
|
|
metrics: Option<CircuitBreakerMetrics>,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
struct CircuitBreakerMetrics {
|
|
transitions: CounterVec,
|
|
calls: CounterVec,
|
|
}
|
|
|
|
impl CircuitBreakerMetrics {
|
|
fn new(registry: &track::MetricsRegistry) -> Self {
|
|
Self {
|
|
transitions: registry
|
|
.register_counter_vec(
|
|
"circuit_breaker_transitions_total",
|
|
"Circuit breaker state transitions",
|
|
&["transition"],
|
|
)
|
|
.expect("failed to register circuit_breaker_transitions_total"),
|
|
calls: registry
|
|
.register_counter_vec(
|
|
"circuit_breaker_calls_total",
|
|
"Circuit breaker call outcomes",
|
|
&["outcome"],
|
|
)
|
|
.expect("failed to register circuit_breaker_calls_total"),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Mutable bookkeeping for one circuit breaker.
struct CircuitState {
    /// One of the `STATUS_*` tags.
    status: u8,
    /// Failures recorded since the last success.
    failure_count: u32,
    /// Successes recorded while half-open.
    success_count: u32,
    /// Calls admitted since the breaker last became half-open.
    half_open_calls: u32,
    /// When the most recent failure was recorded, if any.
    last_failure_time: Option<Instant>,
}
|
|
|
|
struct Inner {
|
|
state: Mutex<CircuitState>,
|
|
config: CircuitConfig,
|
|
}
|
|
|
|
/// Tuning parameters for a circuit breaker.
#[derive(Clone)]
struct CircuitConfig {
    /// Failure count at which a closed breaker opens.
    failure_threshold: u32,
    /// Half-open success count at which the breaker closes.
    success_threshold: u32,
    /// How long after the last failure an open breaker keeps rejecting calls.
    timeout: Duration,
    /// Upper bound compared against the half-open call counter when admitting calls.
    half_open_max_calls: u32,
}
|
|
|
|
impl CircuitBreaker {
|
|
pub fn new() -> Self {
|
|
Self::with_config(5, 2, Duration::from_secs(60), 3)
|
|
}
|
|
|
|
pub fn with_config(
|
|
failure_threshold: u32,
|
|
success_threshold: u32,
|
|
timeout: Duration,
|
|
half_open_max_calls: u32,
|
|
) -> Self {
|
|
Self {
|
|
inner: Arc::new(Inner {
|
|
state: Mutex::new(CircuitState {
|
|
status: STATUS_CLOSED,
|
|
failure_count: 0,
|
|
success_count: 0,
|
|
half_open_calls: 0,
|
|
last_failure_time: None,
|
|
}),
|
|
config: CircuitConfig {
|
|
failure_threshold,
|
|
success_threshold,
|
|
timeout,
|
|
half_open_max_calls,
|
|
},
|
|
}),
|
|
metrics: None,
|
|
}
|
|
}
|
|
|
|
pub fn set_metrics(&mut self, registry: &track::MetricsRegistry) {
|
|
self.metrics = Some(CircuitBreakerMetrics::new(registry));
|
|
}
|
|
|
|
pub async fn call<F, T>(&self, f: F) -> Result<T, CircuitBreakerError>
|
|
where
|
|
F: std::future::Future<Output = Result<T, ChannelError>>,
|
|
{
|
|
let slot_reserved = {
|
|
let mut state = self.inner.state.lock().await;
|
|
match state.status {
|
|
STATUS_OPEN => match state.last_failure_time {
|
|
Some(t) if t.elapsed() > self.inner.config.timeout => {
|
|
state.status = STATUS_HALF_OPEN;
|
|
state.half_open_calls = 1;
|
|
state.success_count = 0;
|
|
true
|
|
}
|
|
_ => false,
|
|
},
|
|
STATUS_HALF_OPEN => {
|
|
if state.half_open_calls
|
|
< self.inner.config.half_open_max_calls
|
|
{
|
|
state.half_open_calls += 1;
|
|
true
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
_ => true,
|
|
}
|
|
};
|
|
|
|
if !slot_reserved {
|
|
if let Some(m) = &self.metrics {
|
|
m.calls.with_label_values(&["rejected"]).inc();
|
|
}
|
|
return Err(CircuitBreakerError::Open);
|
|
}
|
|
|
|
match f.await {
|
|
Ok(result) => {
|
|
if let Some(m) = &self.metrics {
|
|
m.calls.with_label_values(&["success"]).inc();
|
|
}
|
|
self.on_success().await;
|
|
Ok(result)
|
|
}
|
|
Err(e) => {
|
|
if let Some(m) = &self.metrics {
|
|
m.calls.with_label_values(&["failure"]).inc();
|
|
}
|
|
self.on_failure().await;
|
|
Err(CircuitBreakerError::Inner(e))
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn on_success(&self) {
|
|
let mut state = self.inner.state.lock().await;
|
|
state.failure_count = 0;
|
|
|
|
if state.status == STATUS_HALF_OPEN {
|
|
state.success_count += 1;
|
|
if state.success_count >= self.inner.config.success_threshold {
|
|
state.status = STATUS_CLOSED;
|
|
state.success_count = 0;
|
|
state.half_open_calls = 0;
|
|
if let Some(m) = &self.metrics {
|
|
m.transitions
|
|
.with_label_values(&["half_open_to_closed"])
|
|
.inc();
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn on_failure(&self) {
|
|
let mut state = self.inner.state.lock().await;
|
|
state.failure_count += 1;
|
|
state.last_failure_time = Some(Instant::now());
|
|
|
|
if state.status == STATUS_HALF_OPEN {
|
|
state.status = STATUS_OPEN;
|
|
state.success_count = 0;
|
|
state.half_open_calls = 0;
|
|
if let Some(m) = &self.metrics {
|
|
m.transitions
|
|
.with_label_values(&["half_open_to_open"])
|
|
.inc();
|
|
}
|
|
} else if state.status == STATUS_CLOSED
|
|
&& state.failure_count >= self.inner.config.failure_threshold
|
|
{
|
|
state.status = STATUS_OPEN;
|
|
if let Some(m) = &self.metrics {
|
|
m.transitions.with_label_values(&["closed_to_open"]).inc();
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn is_open(&self) -> bool {
|
|
let state = self.inner.state.lock().await;
|
|
state.status == STATUS_OPEN
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub enum CircuitBreakerError {
|
|
Open,
|
|
Inner(ChannelError),
|
|
}
|
|
|
|
impl std::fmt::Display for CircuitBreakerError {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
match self {
|
|
CircuitBreakerError::Open => write!(f, "Circuit breaker is open"),
|
|
CircuitBreakerError::Inner(e) => write!(f, "{}", e),
|
|
}
|
|
}
|
|
}
|
|
|
|
// Marker impl: the empty body keeps every default method of the trait, so
// `source()` is not overridden and the wrapped `ChannelError` is reachable
// only through the `Inner` variant and the `Display` output.
impl std::error::Error for CircuitBreakerError {}
|