gitdataai/lib/cache/local.rs

144 lines
3.6 KiB
Rust

use std::{sync::Arc, time::Duration};
use moka::future::{Cache, CacheBuilder};
use serde::{Serialize, de::DeserializeOwned};
use crate::{CacheError, CacheResult};
/// Maximum capacity used by [`LocalCacheConfig::default`] (an entry count
/// when no weigher is configured on the moka builder).
const DEFAULT_LOCAL_MAX_CAPACITY: u64 = 10_000;

/// Time-to-live used by [`LocalCacheConfig::default`].
const DEFAULT_LOCAL_TIME_TO_LIVE: Duration = Duration::from_secs(300);

/// Tuning knobs for the in-process moka-backed cache.
///
/// `None` for either expiry field means that expiry policy is not applied.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LocalCacheConfig {
    /// Maximum capacity handed to moka's `CacheBuilder::new`.
    pub max_capacity: u64,
    /// Optional time-to-live: an entry expires this long after it was written.
    pub time_to_live: Option<Duration>,
    /// Optional time-to-idle: an entry expires this long after it was last
    /// read or written.
    pub time_to_idle: Option<Duration>,
}

impl Default for LocalCacheConfig {
    /// Capacity of 10 000, a 300-second TTL, and no idle expiry.
    fn default() -> Self {
        Self {
            max_capacity: DEFAULT_LOCAL_MAX_CAPACITY,
            time_to_live: Some(DEFAULT_LOCAL_TIME_TO_LIVE),
            time_to_idle: None,
        }
    }
}
/// In-process async cache backed by [`moka::future::Cache`].
///
/// Values are stored as JSON bytes produced by `serde_json`, so any
/// `Serialize` type can be written and read back as any compatible
/// `DeserializeOwned` type under a string key.
#[derive(Clone)]
pub struct MokaCache {
    /// Underlying moka cache. Values are the JSON encodings written by
    /// [`MokaCache::set`].
    pub inner: Cache<Arc<str>, Arc<[u8]>>,
}

impl MokaCache {
    /// Builds a cache using [`LocalCacheConfig::default`].
    pub fn init() -> Self {
        Self::with_config(LocalCacheConfig::default())
    }

    /// Builds a cache from an explicit [`LocalCacheConfig`].
    ///
    /// Invalidation closures are always enabled. Without
    /// `support_invalidation_closures()`, moka rejects every predicate given
    /// to `invalidate_entries_if`, and [`MokaCache::invalidate_entries_if`]
    /// would silently do nothing.
    pub fn with_config(config: LocalCacheConfig) -> Self {
        let mut builder =
            CacheBuilder::new(config.max_capacity).support_invalidation_closures();
        if let Some(time_to_live) = config.time_to_live {
            builder = builder.time_to_live(time_to_live);
        }
        if let Some(time_to_idle) = config.time_to_idle {
            builder = builder.time_to_idle(time_to_idle);
        }
        Self {
            inner: builder.build(),
        }
    }

    /// Looks up `key` and deserializes the stored JSON into `T`.
    ///
    /// Returns `Ok(None)` when the key is absent (or already expired).
    ///
    /// # Errors
    /// Returns `CacheError::Serialize` if the stored bytes cannot be
    /// deserialized as `T`.
    pub async fn get<T>(&self, key: &str) -> CacheResult<Option<T>>
    where
        T: DeserializeOwned,
    {
        match self.inner.get(key).await {
            Some(value) => serde_json::from_slice(value.as_ref())
                .map(Some)
                .map_err(CacheError::Serialize),
            None => Ok(None),
        }
    }

    /// Untyped variant of [`MokaCache::get`] that yields a
    /// `serde_json::Value`.
    ///
    /// # Errors
    /// Same as [`MokaCache::get`].
    pub async fn get_json(&self, key: &str) -> CacheResult<Option<serde_json::Value>> {
        self.get(key).await
    }

    /// Serializes `value` to JSON and stores it under `key`, replacing any
    /// previous entry.
    ///
    /// # Errors
    /// Returns `CacheError::Serialize` if `value` cannot be serialized. The
    /// cache is left unchanged in that case.
    pub async fn set<T>(&self, key: &str, value: &T) -> CacheResult<()>
    where
        T: Serialize + ?Sized,
    {
        let value = serde_json::to_vec(value).map_err(CacheError::Serialize)?;
        self.inner
            .insert(Arc::<str>::from(key), Arc::<[u8]>::from(value))
            .await;
        Ok(())
    }

    /// Removes `key` if present. The removed value is discarded.
    pub async fn remove(&self, key: &str) {
        self.inner.remove(key).await;
    }

    /// Returns whether `key` is currently present.
    ///
    /// The underlying check is synchronous. The method stays `async` so
    /// existing `.await` call sites keep compiling.
    pub async fn contains_key(&self, key: &str) -> bool {
        self.inner.contains_key(key)
    }

    /// Invalidates every entry in the cache.
    ///
    /// The underlying call is synchronous. The method stays `async` so
    /// existing `.await` call sites keep compiling.
    pub async fn invalidate_all(&self) {
        self.inner.invalidate_all();
    }

    /// Registers `predicate` so that entries whose key matches it are
    /// invalidated.
    ///
    /// Per moka's documentation, the predicate applies to entries inserted
    /// before this call, and removal is carried out by moka asynchronously.
    pub fn invalidate_entries_if<F>(&self, predicate: F)
    where
        F: Fn(&str) -> bool + Send + Sync + 'static,
    {
        // `with_config` enables invalidation closures, so registration can
        // only fail if `inner` (a public field) was replaced by a cache built
        // without them. The `()` return is kept for API compatibility, so
        // that case remains best-effort.
        let _ = self
            .inner
            .invalidate_entries_if(move |key, _| predicate(key));
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use serde::{Deserialize, Serialize};

    use super::{LocalCacheConfig, MokaCache};

    /// Small payload that round-trips through serde.
    #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
    struct User {
        id: u64,
        name: String,
    }

    /// A value written with `set` is returned unchanged by `get`.
    #[tokio::test]
    async fn stores_and_reads_typed_values() {
        let cache = MokaCache::init();
        let alice = User {
            id: 7,
            name: String::from("alice"),
        };

        cache.set("user:7", &alice).await.unwrap();
        let loaded: Option<User> = cache.get("user:7").await.unwrap();

        assert_eq!(loaded, Some(alice));
    }

    /// An entry disappears once its time-to-live has elapsed.
    #[tokio::test]
    async fn expires_values_by_ttl() {
        let ttl = Duration::from_millis(25);
        let wait = Duration::from_millis(60);
        let config = LocalCacheConfig {
            max_capacity: 16,
            time_to_live: Some(ttl),
            time_to_idle: None,
        };
        let cache = MokaCache::with_config(config);

        cache.set("short", &"value").await.unwrap();
        tokio::time::sleep(wait).await;

        let loaded: Option<String> = cache.get("short").await.unwrap();
        assert_eq!(loaded, None);
    }
}