gitdataai/libs/session/session.rs
2026-04-14 19:02:01 +08:00

404 lines
12 KiB
Rust

use std::{
cell::{Ref, RefCell},
convert::Infallible,
error::Error as StdError,
future::Future,
mem,
pin::Pin,
rc::Rc,
};
use actix_utils::future::{Ready, ready};
use actix_web::{
FromRequest, HttpMessage, HttpRequest, HttpResponse, ResponseError,
body::BoxBody,
dev::{Extensions, Payload, ServiceRequest, ServiceResponse},
};
use anyhow::Context;
use derive_more::derive::{Display, From};
use serde::{Serialize, de::DeserializeOwned};
use serde_json::{Map, Value};
use uuid::Uuid;
/// Session-state key under which the logged-in user's UUID is stored.
const SESSION_USER_KEY: &str = "session:user_uid";
/// Session-state key under which the currently selected workspace UUID is stored.
const SESSION_WORKSPACE_KEY: &str = "session:workspace_id";
/// Handle to the session data attached to a request.
///
/// Cloning is cheap: every clone shares the same `SessionInner` through an
/// `Rc<RefCell<_>>`, so a change made through one handle is visible through all others.
#[derive(Clone)]
pub struct Session(Rc<RefCell<SessionInner>>);
/// Lifecycle marker describing what happened to a session while a request was handled.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum SessionStatus {
/// Entries were inserted, removed or cleared (set by `insert`, `remove` and `clear`).
Changed,
/// The session was purged via `purge`; later `insert`/`remove`/`clear`/`renew` calls
/// are ignored.
Purged,
/// The session was flagged via `renew`; later modifications keep this status.
Renewed,
/// Nothing has touched the session yet (the initial state).
#[default]
Unchanged,
}
/// Shared, mutable backing data behind every `Session` handle.
#[derive(Default)]
struct SessionInner {
// Key/value session data, stored as already-encoded JSON values.
state: Map<String, Value>,
// What has happened to the session so far; starts as `Unchanged`.
status: SessionStatus,
}
impl Session {
/// Looks up `key` and decodes the stored JSON value as a `T`.
///
/// Returns `Ok(None)` when nothing is stored under `key`.
///
/// # Errors
///
/// Returns [`SessionGetError`] when a value exists but cannot be deserialized as `T`.
pub fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, SessionGetError> {
    let inner = self.0.borrow();

    let Some(raw) = inner.state.get(key) else {
        return Ok(None);
    };

    // `from_value` consumes its argument, so decode from a copy of the stored value.
    serde_json::from_value::<T>(raw.clone())
        .with_context(|| {
            format!(
                "Failed to deserialize the JSON-encoded session data attached to key \
                 `{}` as a `{}` type",
                key,
                std::any::type_name::<T>()
            )
        })
        .map(Some)
        .map_err(SessionGetError)
}
/// Returns `true` when the session state holds an entry for `key`.
pub fn contains_key(&self, key: &str) -> bool {
    let inner = self.0.borrow();
    inner.state.contains_key(key)
}
/// Borrows the whole key/value map of the session.
///
/// The returned guard keeps the session immutably borrowed; mutating the session while
/// it is alive panics, as with any `RefCell`.
pub fn entries(&self) -> Ref<'_, Map<String, Value>> {
    let inner = self.0.borrow();
    Ref::map(inner, |session| &session.state)
}
pub fn status(&self) -> SessionStatus {
Ref::map(self.0.borrow(), |inner| &inner.status).clone()
}
/// Stores `value` under `key`, JSON-encoding it first.
///
/// Does nothing (and returns `Ok`) when the session has been purged. Otherwise the
/// status becomes `Changed`, unless it is already `Renewed`.
///
/// # Errors
///
/// Returns [`SessionInsertError`] when `value` cannot be serialized to JSON. The status
/// has already been updated at that point.
pub fn insert<T: Serialize>(
    &self,
    key: impl Into<String>,
    value: T,
) -> Result<(), SessionInsertError> {
    let mut inner = self.0.borrow_mut();

    // A purged session silently ignores writes.
    if matches!(inner.status, SessionStatus::Purged) {
        return Ok(());
    }

    // `Renewed` takes precedence over `Changed`, so only overwrite other states.
    if !matches!(inner.status, SessionStatus::Renewed) {
        inner.status = SessionStatus::Changed;
    }

    let key = key.into();
    let encoded = serde_json::to_value(&value)
        .with_context(|| {
            format!(
                "Failed to serialize the provided `{}` type instance as JSON in order to \
                 attach as session data to the `{key}` key",
                std::any::type_name::<T>(),
            )
        })
        .map_err(SessionInsertError)?;

    inner.state.insert(key, encoded);
    Ok(())
}
pub fn update<T: Serialize + DeserializeOwned, F>(
&self,
key: impl Into<String>,
updater: F,
) -> Result<(), SessionUpdateError>
where
F: FnOnce(T) -> T,
{
let mut inner = self.0.borrow_mut();
let key_str = key.into();
if let Some(val) = inner.state.get(&key_str) {
let value = serde_json::from_value(val.clone())
.with_context(|| {
format!(
"Failed to deserialize the JSON-encoded session data attached to key \
`{key_str}` as a `{}` type",
std::any::type_name::<T>()
)
})
.map_err(SessionUpdateError)?;
let val = serde_json::to_value(updater(value))
.with_context(|| {
format!(
"Failed to serialize the provided `{}` type instance as JSON in order to \
attach as session data to the `{key_str}` key",
std::any::type_name::<T>(),
)
})
.map_err(SessionUpdateError)?;
inner.state.insert(key_str, val);
}
Ok(())
}
pub fn update_or<T: Serialize + DeserializeOwned, F>(
&self,
key: &str,
default_value: T,
updater: F,
) -> Result<(), SessionUpdateError>
where
F: FnOnce(T) -> T,
{
if self.contains_key(key) {
self.update(key, updater)
} else {
self.insert(key, default_value)
.map_err(|err| SessionUpdateError(err.into()))
}
}
/// Removes `key` from the session and returns the raw JSON value, if any.
///
/// Returns `None` without touching anything when the session has been purged.
/// Otherwise the status becomes `Changed` (unless already `Renewed`), even when `key`
/// was not present.
pub fn remove(&self, key: &str) -> Option<Value> {
    let mut inner = self.0.borrow_mut();

    if matches!(inner.status, SessionStatus::Purged) {
        return None;
    }

    if !matches!(inner.status, SessionStatus::Renewed) {
        inner.status = SessionStatus::Changed;
    }

    inner.state.remove(key)
}
/// Removes `key` and tries to decode the removed value as a `T`.
///
/// Returns `None` when nothing was removed, `Some(Ok(value))` on success, and
/// `Some(Err(raw))` with the raw JSON value when it cannot be deserialized as `T`.
/// The entry is removed from the session in either `Some` case.
pub fn remove_as<T: DeserializeOwned>(&self, key: &str) -> Option<Result<T, Value>> {
    let removed = self.remove(key)?;
    // Decode from a copy so the raw value can be handed back on failure.
    Some(serde_json::from_value::<T>(removed.clone()).map_err(|_| removed))
}
/// Removes every entry from the session.
///
/// Ignored when the session has been purged. Otherwise the status becomes `Changed`,
/// unless it is already `Renewed`.
pub fn clear(&self) {
    let mut inner = self.0.borrow_mut();

    if matches!(inner.status, SessionStatus::Purged) {
        return;
    }

    if !matches!(inner.status, SessionStatus::Renewed) {
        inner.status = SessionStatus::Changed;
    }

    inner.state.clear();
}
/// Drops all session data and marks the session as `Purged`.
///
/// This applies from any prior status, including `Renewed`.
pub fn purge(&self) {
    let mut inner = self.0.borrow_mut();
    inner.state.clear();
    inner.status = SessionStatus::Purged;
}
/// Marks the session as `Renewed`, leaving its data intact.
///
/// Ignored when the session has already been purged.
pub fn renew(&self) {
    let mut inner = self.0.borrow_mut();

    if matches!(inner.status, SessionStatus::Purged) {
        return;
    }

    inner.status = SessionStatus::Renewed;
}
/// Returns the user UUID stored in the session, if any.
///
/// A stored value that cannot be decoded as a `Uuid` is treated the same as no value.
pub fn user(&self) -> Option<Uuid> {
    match self.get::<Uuid>(SESSION_USER_KEY) {
        Ok(uid) => uid,
        Err(_) => None,
    }
}
/// Stores `uid` as the session's user.
///
/// Any error from `insert` is discarded; on a purged session this is a no-op.
pub fn set_user(&self, uid: Uuid) {
    drop(self.insert(SESSION_USER_KEY, uid));
}
/// Removes the stored user from the session, discarding the removed value.
pub fn clear_user(&self) {
    self.remove(SESSION_USER_KEY);
}
/// Returns the workspace UUID stored in the session, if any.
///
/// A stored value that cannot be decoded as a `Uuid` is treated the same as no value.
pub fn current_workspace_id(&self) -> Option<Uuid> {
    match self.get::<Uuid>(SESSION_WORKSPACE_KEY) {
        Ok(workspace) => workspace,
        Err(_) => None,
    }
}
/// Stores `id` as the session's current workspace.
///
/// Any error from `insert` is discarded; on a purged session this is a no-op.
pub fn set_current_workspace_id(&self, id: Uuid) {
    drop(self.insert(SESSION_WORKSPACE_KEY, id));
}
/// Removes the stored workspace from the session, discarding the removed value.
pub fn clear_current_workspace_id(&self) {
    self.remove(SESSION_WORKSPACE_KEY);
}
/// Returns the client IP address recorded under `session:ip_address`, if any.
///
/// A stored value that is not a JSON string is treated the same as no value.
pub fn ip_address(&self) -> Option<String> {
    let recorded = self.get::<String>("session:ip_address");
    recorded.unwrap_or_default()
}
/// Returns the user agent recorded under `session:user_agent`, if any.
///
/// A stored value that is not a JSON string is treated the same as no value.
pub fn user_agent(&self) -> Option<String> {
    let recorded = self.get::<String>("session:user_agent");
    recorded.unwrap_or_default()
}
pub fn set_request_info(req: &HttpRequest) {
let extensions = req.extensions_mut();
if let Some(inner) = extensions.get::<Rc<RefCell<SessionInner>>>() {
let mut inner = inner.borrow_mut();
if let Some(ua) = req.headers().get("user-agent") {
if let Ok(ua) = ua.to_str() {
let _ = inner
.state
.insert("session:user_agent".to_string(), serde_json::json!(ua));
}
}
let addr = req
.connection_info()
.realip_remote_addr()
.map(|s| s.to_string());
if let Some(ip) = addr {
let _ = inner
.state
.insert("session:ip_address".to_string(), serde_json::json!(ip));
}
}
}
#[allow(clippy::needless_pass_by_ref_mut)]
pub(crate) fn set_session(
req: &mut ServiceRequest,
data: impl IntoIterator<Item = (String, Value)>,
) {
let session = Session::get_session(&mut req.extensions_mut());
let mut inner = session.0.borrow_mut();
inner.state.extend(data);
}
/// Extracts the session status and state from the request behind `res`.
///
/// The state map is moved out, leaving the session with an empty map. When no session
/// is attached to the request, returns `Unchanged` with an empty map.
// NOTE(review): `res` is never mutated through the `&mut`, hence the lint allowance;
// presumably the signature is kept for API compatibility — confirm.
#[allow(clippy::needless_pass_by_ref_mut)]
pub(crate) fn get_changes<B>(
    res: &mut ServiceResponse<B>,
) -> (SessionStatus, Map<String, Value>) {
    let extensions = res.request().extensions();

    let Some(shared) = extensions.get::<Rc<RefCell<SessionInner>>>() else {
        return (SessionStatus::Unchanged, Map::new());
    };

    let mut inner = shared.borrow_mut();
    let state = mem::take(&mut inner.state);
    let status = inner.status.clone();

    (status, state)
}
/// Returns the session stored in `extensions`, creating an empty one on first use.
///
/// This is used internally by the `FromRequest` impl, but also exposed for WS/manual
/// usage. Repeated calls on the same extensions yield handles to the same session.
pub fn get_session(extensions: &mut Extensions) -> Session {
    let existing = extensions
        .get::<Rc<RefCell<SessionInner>>>()
        .map(Rc::clone);

    let shared = existing.unwrap_or_else(|| {
        let fresh = Rc::new(RefCell::new(SessionInner::default()));
        extensions.insert(Rc::clone(&fresh));
        fresh
    });

    Session(shared)
}
}
/// Extracts the request's [`Session`], creating an empty one when none exists yet.
/// Extraction never fails.
impl FromRequest for Session {
    type Error = Infallible;
    type Future = Ready<Result<Self, Self::Error>>;

    #[inline]
    fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
        let mut extensions = req.extensions_mut();
        let session = Session::get_session(&mut extensions);
        ready(Ok(session))
    }
}
/// Extractor for the authenticated user ID from session.
/// Fails with 401 if the session has no logged-in user.
#[derive(Clone, Copy)]
pub struct SessionUser(pub Uuid);
impl FromRequest for SessionUser {
type Error = SessionGetError;
type Future = Pin<Box<dyn Future<Output = Result<Self, Self::Error>>>>;
fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
let req = req.clone();
Box::pin(async move {
let uid = {
let mut extensions = req.extensions_mut();
let session = Session::get_session(&mut extensions);
session
.user()
.ok_or_else(|| SessionGetError(anyhow::anyhow!("not authenticated")))?
};
Ok(SessionUser(uid))
})
}
}
/// Extractor for the current workspace ID from session.
/// Returns None if no workspace is selected (workspace selection is optional).
#[derive(Clone, Copy)]
pub struct SessionWorkspace(pub Option<Uuid>);
impl FromRequest for SessionWorkspace {
type Error = Infallible;
type Future = Ready<Result<Self, Self::Error>>;
#[inline]
fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
let mut extensions = req.extensions_mut();
let session = Session::get_session(&mut extensions);
ready(Ok(SessionWorkspace(session.current_workspace_id())))
}
}
#[derive(Debug, Display, From)]
#[display("{_0}")]
pub struct SessionGetError(anyhow::Error);
impl StdError for SessionGetError {
fn source(&self) -> Option<&(dyn StdError + 'static)> {
Some(self.0.as_ref())
}
}
impl ResponseError for SessionGetError {
fn error_response(&self) -> HttpResponse<BoxBody> {
HttpResponse::build(self.status_code())
.content_type("text/plain")
.body(self.to_string())
}
}
/// Error produced when a value cannot be serialized for storage in the session.
///
/// Returned by `Session::insert`; wraps the underlying `anyhow::Error`.
#[derive(Debug, Display, From)]
#[display("{_0}")]
pub struct SessionInsertError(anyhow::Error);

impl StdError for SessionInsertError {
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        let cause: &(dyn StdError + 'static) = self.0.as_ref();
        Some(cause)
    }
}

impl ResponseError for SessionInsertError {
    /// Renders the error message as a `text/plain` body using the trait's default
    /// status code.
    fn error_response(&self) -> HttpResponse<BoxBody> {
        let message = self.to_string();
        let mut builder = HttpResponse::build(self.status_code());
        builder.content_type("text/plain");
        builder.body(message)
    }
}
/// Error produced when updating a session value fails.
///
/// Returned by `Session::update` and `Session::update_or`; wraps the underlying
/// `anyhow::Error`.
#[derive(Debug, Display, From)]
#[display("{_0}")]
pub struct SessionUpdateError(anyhow::Error);

impl StdError for SessionUpdateError {
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        let cause: &(dyn StdError + 'static) = self.0.as_ref();
        Some(cause)
    }
}

impl ResponseError for SessionUpdateError {
    /// Renders the error message as a `text/plain` body using the trait's default
    /// status code.
    fn error_response(&self) -> HttpResponse<BoxBody> {
        let message = self.to_string();
        let mut builder = HttpResponse::build(self.status_code());
        builder.content_type("text/plain");
        builder.body(message)
    }
}