gitdataai/lib/session/session.rs
2026-05-30 01:38:40 +08:00

439 lines
12 KiB
Rust

use std::{
cell::{Ref, RefCell},
convert::Infallible,
error::Error as StdError,
future::Future,
mem,
pin::Pin,
rc::Rc,
};
use actix_utils::future::{Ready, ready};
use actix_web::{
FromRequest, HttpMessage, HttpRequest, HttpResponse, ResponseError,
body::BoxBody,
dev::{Extensions, Payload, ServiceRequest, ServiceResponse},
};
use anyhow::Context;
use derive_more::derive::{Display, From};
use serde::{Serialize, de::DeserializeOwned};
use serde_json::{Map, Value};
use uuid::Uuid;
/// Session key under which the authenticated user's id is stored.
const SESSION_USER_KEY: &str = "session:user_uid";

/// Handle to the session data of the current request.
///
/// Cloning is cheap: every clone shares the same `Rc<RefCell<SessionInner>>`,
/// so changes made through one handle are visible through all of them.
#[derive(Clone)]
pub struct Session(Rc<RefCell<SessionInner>>);

/// What has happened to a session while handling the current request.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum SessionStatus {
    /// Session data was modified (insert, update, remove, clear, ...).
    Changed,
    /// [`Session::purge`] was called: the state was cleared and the
    /// data-mutating methods of [`Session`] become no-ops.
    Purged,
    /// [`Session::renew`] was called; later data changes keep this status
    /// instead of downgrading it to `Changed`.
    Renewed,
    /// Nothing has been modified (the initial status).
    #[default]
    Unchanged,
}

/// Shared backing storage behind a [`Session`] handle.
#[derive(Default)]
struct SessionInner {
    /// Session entries, stored as JSON values keyed by name.
    state: Map<String, Value>,
    /// Lifecycle status accumulated during the current request.
    status: SessionStatus,
}
impl Session {
    /// Session key caching the client's `User-Agent` header.
    const USER_AGENT_KEY: &'static str = "session:user_agent";
    /// Session key caching the client's real remote address.
    const IP_ADDRESS_KEY: &'static str = "session:ip_address";

    /// Gets the value stored under `key`, deserialized as `T`.
    ///
    /// Returns `Ok(None)` when no value is stored under `key`.
    ///
    /// # Errors
    ///
    /// Returns [`SessionGetError`] if the stored JSON cannot be deserialized as `T`.
    pub fn get<T: DeserializeOwned>(
        &self,
        key: &str,
    ) -> Result<Option<T>, SessionGetError> {
        if let Some(value) = self.0.borrow().state.get(key) {
            Ok(Some(
                serde_json::from_value::<T>(value.clone())
                    .with_context(|| {
                        format!(
                            "Failed to deserialize the JSON-encoded session data attached to key \
                             `{}` as a `{}` type",
                            key,
                            std::any::type_name::<T>()
                        )
                    })
                    .map_err(SessionGetError)?,
            ))
        } else {
            Ok(None)
        }
    }

    /// Returns `true` if a value is stored under `key`.
    pub fn contains_key(&self, key: &str) -> bool {
        self.0.borrow().state.contains_key(key)
    }

    /// Borrows all session entries.
    ///
    /// # Panics
    ///
    /// Any mutating method of this session panics (`RefCell` already borrowed)
    /// while the returned guard is alive.
    pub fn entries(&self) -> Ref<'_, Map<String, Value>> {
        Ref::map(self.0.borrow(), |inner| &inner.state)
    }

    /// Returns the current lifecycle status of the session.
    pub fn status(&self) -> SessionStatus {
        self.0.borrow().status.clone()
    }

    /// Serializes `value` as JSON and stores it under `key`.
    ///
    /// Does nothing (and returns `Ok`) once the session has been purged.
    /// A successful insert marks the session `Changed` unless it is `Renewed`.
    ///
    /// # Errors
    ///
    /// Returns [`SessionInsertError`] if `value` cannot be serialized; in that
    /// case the session is left untouched.
    pub fn insert<T: Serialize>(
        &self,
        key: impl Into<String>,
        value: T,
    ) -> Result<(), SessionInsertError> {
        if self.0.borrow().status == SessionStatus::Purged {
            return Ok(());
        }
        let key = key.into();
        // Serialize before touching the status so a failed insert does not
        // leave the session flagged as `Changed` with nothing changed.
        let val = serde_json::to_value(&value)
            .with_context(|| {
                format!(
                    "Failed to serialize the provided `{}` type instance as JSON in order to \
                     attach as session data to the `{key}` key",
                    std::any::type_name::<T>(),
                )
            })
            .map_err(SessionInsertError)?;
        let mut inner = self.0.borrow_mut();
        if inner.status != SessionStatus::Renewed {
            inner.status = SessionStatus::Changed;
        }
        inner.state.insert(key, val);
        Ok(())
    }

    /// Replaces the value under `key` with `updater(current_value)`.
    ///
    /// Does nothing when the key is absent or the session has been purged.
    /// `updater` runs without any session borrow held, so it may itself read
    /// the session.
    ///
    /// # Errors
    ///
    /// Returns [`SessionUpdateError`] if the stored value cannot be
    /// deserialized as `T`, or the updated value cannot be serialized.
    pub fn update<T: Serialize + DeserializeOwned, F>(
        &self,
        key: impl Into<String>,
        updater: F,
    ) -> Result<(), SessionUpdateError>
    where
        F: FnOnce(T) -> T,
    {
        let key_str = key.into();
        // Copy the current value out and release the borrow before calling
        // `updater`; holding `borrow_mut` across the closure would make any
        // session access inside it panic.
        let current = {
            let inner = self.0.borrow();
            if inner.status == SessionStatus::Purged {
                return Ok(());
            }
            let Some(val) = inner.state.get(&key_str) else {
                return Ok(());
            };
            val.clone()
        };
        let value: T = serde_json::from_value(current)
            .with_context(|| {
                format!(
                    "Failed to deserialize the JSON-encoded session data attached to key \
                     `{key_str}` as a `{}` type",
                    std::any::type_name::<T>()
                )
            })
            .map_err(SessionUpdateError)?;
        let val = serde_json::to_value(updater(value))
            .with_context(|| {
                format!(
                    "Failed to serialize the provided `{}` type instance as JSON in order to \
                     attach as session data to the `{key_str}` key",
                    std::any::type_name::<T>(),
                )
            })
            .map_err(SessionUpdateError)?;
        let mut inner = self.0.borrow_mut();
        // The updater may have purged the session; never resurrect purged data.
        if inner.status == SessionStatus::Purged {
            return Ok(());
        }
        if inner.status != SessionStatus::Renewed {
            inner.status = SessionStatus::Changed;
        }
        inner.state.insert(key_str, val);
        Ok(())
    }

    /// Updates the value under `key` with `updater`, or inserts
    /// `default_value` when the key is absent.
    ///
    /// # Errors
    ///
    /// Propagates the failure of [`Session::update`] or [`Session::insert`]
    /// as a [`SessionUpdateError`].
    pub fn update_or<T: Serialize + DeserializeOwned, F>(
        &self,
        key: &str,
        default_value: T,
        updater: F,
    ) -> Result<(), SessionUpdateError>
    where
        F: FnOnce(T) -> T,
    {
        if self.contains_key(key) {
            self.update(key, updater)
        } else {
            self.insert(key, default_value)
                .map_err(|err| SessionUpdateError(err.into()))
        }
    }

    /// Removes and returns the raw value stored under `key`.
    ///
    /// Returns `None` when the key is absent or the session has been purged.
    /// The session is marked `Changed` (unless `Renewed`) only when a value
    /// was actually removed.
    pub fn remove(&self, key: &str) -> Option<Value> {
        let mut inner = self.0.borrow_mut();
        if inner.status == SessionStatus::Purged {
            return None;
        }
        let removed = inner.state.remove(key);
        // Removing a missing key changes nothing and must not force a store write.
        if removed.is_some() && inner.status != SessionStatus::Renewed {
            inner.status = SessionStatus::Changed;
        }
        removed
    }

    /// Removes the value under `key` and tries to deserialize it as `T`.
    ///
    /// Returns `None` when nothing was removed, `Some(Ok(value))` on success,
    /// and `Some(Err(raw))` with the removed raw JSON when deserialization fails.
    pub fn remove_as<T: DeserializeOwned>(
        &self,
        key: &str,
    ) -> Option<Result<T, Value>> {
        self.remove(key)
            .map(|value| serde_json::from_value::<T>(value.clone()).map_err(|_| value))
    }

    /// Removes every entry from the session.
    ///
    /// No-op when the session is purged or already empty; otherwise marks it
    /// `Changed` (unless `Renewed`).
    pub fn clear(&self) {
        let mut inner = self.0.borrow_mut();
        // Clearing an empty session changes nothing; flagging it would
        // needlessly persist an empty session.
        if inner.status == SessionStatus::Purged || inner.state.is_empty() {
            return;
        }
        if inner.status != SessionStatus::Renewed {
            inner.status = SessionStatus::Changed;
        }
        inner.state.clear();
    }

    /// Marks the session `Purged` and drops all of its entries.
    pub fn purge(&self) {
        let mut inner = self.0.borrow_mut();
        inner.status = SessionStatus::Purged;
        inner.state.clear();
    }

    /// Marks the session `Renewed` unless it has already been purged.
    pub fn renew(&self) {
        let mut inner = self.0.borrow_mut();
        if inner.status != SessionStatus::Purged {
            inner.status = SessionStatus::Renewed;
        }
    }

    /// Returns the authenticated user's id, or `None` if absent or unreadable.
    pub fn user(&self) -> Option<Uuid> {
        self.get::<Uuid>(SESSION_USER_KEY).ok().flatten()
    }

    /// Stores `uid` as the authenticated user.
    ///
    /// NOTE(review): on login, callers presumably also want [`Session::renew`]
    /// to guard against session fixation — confirm at the call sites.
    pub fn set_user(&self, uid: Uuid) {
        // A `Uuid` always serializes, so the only "failure" is a purged
        // session, where the insert is intentionally a no-op.
        let _ = self.insert(SESSION_USER_KEY, uid);
    }

    /// Removes the authenticated user's id from the session.
    pub fn clear_user(&self) {
        let _ = self.remove(SESSION_USER_KEY);
    }

    /// Returns the client address cached by [`Session::set_request_info`].
    pub fn ip_address(&self) -> Option<String> {
        self.get::<String>(Self::IP_ADDRESS_KEY).ok().flatten()
    }

    /// Returns the `User-Agent` cached by [`Session::set_request_info`].
    pub fn user_agent(&self) -> Option<String> {
        self.get::<String>(Self::USER_AGENT_KEY).ok().flatten()
    }

    /// Caches the request's `User-Agent` header and real remote address in
    /// the session attached to `req`, if any.
    ///
    /// Does nothing when `req` has no session or the session is purged. The
    /// session is marked `Changed` (unless `Renewed`) only when a cached value
    /// actually differs from what was stored.
    pub fn set_request_info(req: &HttpRequest) {
        // Gather request data before borrowing the extensions:
        // `HttpRequest::connection_info()` borrows the request extensions
        // itself, so calling it while an extensions guard is alive panics.
        let user_agent = req
            .headers()
            .get("user-agent")
            .and_then(|ua| ua.to_str().ok())
            .map(str::to_owned);
        let ip_address = req
            .connection_info()
            .realip_remote_addr()
            .map(str::to_owned);

        // Clone the `Rc` so the extensions guard is released right away.
        let Some(shared) = req
            .extensions()
            .get::<Rc<RefCell<SessionInner>>>()
            .cloned()
        else {
            return;
        };
        let mut inner = shared.borrow_mut();
        // A purged session must stay purged; writing here would resurrect it.
        if inner.status == SessionStatus::Purged {
            return;
        }
        let mut changed = false;
        if let Some(ua) = user_agent {
            changed |= Self::store_if_different(&mut inner.state, Self::USER_AGENT_KEY, ua);
        }
        if let Some(ip) = ip_address {
            changed |= Self::store_if_different(&mut inner.state, Self::IP_ADDRESS_KEY, ip);
        }
        if changed && inner.status != SessionStatus::Renewed {
            inner.status = SessionStatus::Changed;
        }
    }

    /// Stores `value` as a JSON string under `key`; returns `true` only if
    /// the stored data actually changed.
    fn store_if_different(state: &mut Map<String, Value>, key: &str, value: String) -> bool {
        let value = Value::String(value);
        if state.get(key) == Some(&value) {
            return false;
        }
        state.insert(key.to_owned(), value);
        true
    }

    /// Seeds the request's session (creating it if needed) with `data`,
    /// without changing its status.
    // `&mut` is not needed by the body (`extensions_mut` takes `&self`); the
    // signature is presumably kept for API compatibility — confirm before changing.
    #[allow(clippy::needless_pass_by_ref_mut)]
    pub(crate) fn set_session(
        req: &mut ServiceRequest,
        data: impl IntoIterator<Item = (String, Value)>,
    ) {
        let session = Session::get_session(&mut req.extensions_mut());
        let mut inner = session.0.borrow_mut();
        inner.state.extend(data);
    }

    /// Takes the session state out of the response's request, returning it
    /// together with the session status.
    ///
    /// Returns `(Unchanged, empty)` when the request carries no session.
    // `&mut` is not needed by the body; presumably kept for API compatibility.
    #[allow(clippy::needless_pass_by_ref_mut)]
    pub(crate) fn get_changes<B>(
        res: &mut ServiceResponse<B>,
    ) -> (SessionStatus, Map<String, Value>) {
        if let Some(shared) = res
            .request()
            .extensions()
            .get::<Rc<RefCell<SessionInner>>>()
        {
            let mut inner = shared.borrow_mut();
            let state = mem::take(&mut inner.state);
            (inner.status.clone(), state)
        } else {
            (SessionStatus::Unchanged, Map::new())
        }
    }

    /// Creates a detached, empty session not attached to any request.
    pub fn no_op() -> Self {
        Self(Rc::new(RefCell::new(SessionInner::default())))
    }

    /// Returns the session stored in `extensions`, inserting a new empty one
    /// if none exists yet.
    pub fn get_session(extensions: &mut Extensions) -> Session {
        if let Some(s_impl) = extensions.get::<Rc<RefCell<SessionInner>>>() {
            return Session(Rc::clone(s_impl));
        }
        let inner = Rc::new(RefCell::new(SessionInner::default()));
        extensions.insert(inner.clone());
        Session(inner)
    }
}
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::{Session, SessionStatus, SessionUpdateError};

    /// Builds a detached session holding `counter = 1` with the given status.
    fn session_with_counter(status: SessionStatus) -> Session {
        let session = Session::no_op();
        {
            let mut inner = session.0.borrow_mut();
            inner.state.insert("counter".to_string(), json!(1_u64));
            inner.status = status;
        }
        session
    }

    /// Returns the raw JSON currently stored under `counter`.
    fn stored_counter(session: &Session) -> Option<serde_json::Value> {
        session.0.borrow().state.get("counter").cloned()
    }

    #[test]
    fn update_marks_session_as_changed() -> Result<(), SessionUpdateError> {
        let session = session_with_counter(SessionStatus::Unchanged);
        session.update("counter", |n: u64| n + 1)?;
        assert_eq!(session.status(), SessionStatus::Changed);
        assert_eq!(stored_counter(&session), Some(json!(2_u64)));
        Ok(())
    }

    #[test]
    fn update_preserves_renewed_status() -> Result<(), SessionUpdateError> {
        let session = session_with_counter(SessionStatus::Renewed);
        session.update("counter", |n: u64| n + 1)?;
        assert_eq!(session.status(), SessionStatus::Renewed);
        assert_eq!(stored_counter(&session), Some(json!(2_u64)));
        Ok(())
    }
}
/// Extracts the request's [`Session`], attaching a fresh empty one if the
/// request does not carry a session yet. Extraction never fails.
impl FromRequest for Session {
    type Error = Infallible;
    type Future = Ready<Result<Self, Self::Error>>;

    #[inline]
    fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future {
        let session = Session::get_session(&mut req.extensions_mut());
        ready(Ok(session))
    }
}
#[derive(Clone, Copy)]
pub struct SessionUser(pub Uuid);
impl FromRequest for SessionUser {
type Error = SessionGetError;
type Future = Pin<Box<dyn Future<Output = Result<Self, Self::Error>>>>;
fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
let req = req.clone();
Box::pin(async move {
let uid = {
let mut extensions = req.extensions_mut();
let session = Session::get_session(&mut extensions);
session.user().ok_or_else(|| {
SessionGetError(anyhow::anyhow!("not authenticated"))
})?
};
Ok(SessionUser(uid))
})
}
}
#[derive(Debug, Display, From)]
#[display("{_0}")]
pub struct SessionGetError(anyhow::Error);
impl StdError for SessionGetError {
fn source(&self) -> Option<&(dyn StdError + 'static)> {
Some(self.0.as_ref())
}
}
impl ResponseError for SessionGetError {
fn error_response(&self) -> HttpResponse<BoxBody> {
HttpResponse::build(self.status_code())
.content_type("text/plain")
.body(self.to_string())
}
}
/// Error returned by [`Session::insert`] when a value cannot be serialized
/// as JSON. Renders as a plain-text response with the default status code.
#[derive(Debug, Display, From)]
#[display("{_0}")]
pub struct SessionInsertError(anyhow::Error);

impl StdError for SessionInsertError {
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        let cause: &(dyn StdError + 'static) = self.0.as_ref();
        Some(cause)
    }
}

impl ResponseError for SessionInsertError {
    fn error_response(&self) -> HttpResponse<BoxBody> {
        let message = self.to_string();
        let mut builder = HttpResponse::build(self.status_code());
        builder.content_type("text/plain").body(message)
    }
}
/// Error returned by [`Session::update`] / [`Session::update_or`] when the
/// stored value cannot be deserialized or the new value cannot be serialized.
#[derive(Debug, Display, From)]
#[display("{_0}")]
pub struct SessionUpdateError(anyhow::Error);

impl StdError for SessionUpdateError {
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        let cause: &(dyn StdError + 'static) = self.0.as_ref();
        Some(cause)
    }
}

impl ResponseError for SessionUpdateError {
    fn error_response(&self) -> HttpResponse<BoxBody> {
        let message = self.to_string();
        let mut builder = HttpResponse::build(self.status_code());
        builder.content_type("text/plain").body(message)
    }
}