404 lines
12 KiB
Rust
404 lines
12 KiB
Rust
use std::{
|
|
cell::{Ref, RefCell},
|
|
convert::Infallible,
|
|
error::Error as StdError,
|
|
future::Future,
|
|
mem,
|
|
pin::Pin,
|
|
rc::Rc,
|
|
};
|
|
|
|
use actix_utils::future::{Ready, ready};
|
|
use actix_web::{
|
|
FromRequest, HttpMessage, HttpRequest, HttpResponse, ResponseError,
|
|
body::BoxBody,
|
|
dev::{Extensions, Payload, ServiceRequest, ServiceResponse},
|
|
};
|
|
use anyhow::Context;
|
|
use derive_more::derive::{Display, From};
|
|
use serde::{Serialize, de::DeserializeOwned};
|
|
use serde_json::{Map, Value};
|
|
use uuid::Uuid;
|
|
|
|
/// Session-state key holding the logged-in user's ID (written by `Session::set_user` as a
/// JSON-encoded `Uuid`).
const SESSION_USER_KEY: &str = "session:user_uid";

/// Session-state key holding the currently selected workspace ID (written by
/// `Session::set_current_workspace_id` as a JSON-encoded `Uuid`).
const SESSION_WORKSPACE_KEY: &str = "session:workspace_id";
|
|
|
|
/// Handle to the session state attached to the current request.
///
/// Cloning is cheap: every clone shares the same `Rc<RefCell<SessionInner>>`, so a change made
/// through one handle is visible through all of them.
#[derive(Clone)]
pub struct Session(Rc<RefCell<SessionInner>>);
|
|
|
|
/// What the current request has done to the session.
///
/// Reported together with the session state by `Session::get_changes`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum SessionStatus {
    /// The state was modified (e.g. by `insert`, `remove` or `clear`).
    Changed,
    /// `purge` was called: the state was cleared and later mutations are ignored.
    Purged,
    /// `renew` was called; mutators leave this status in place instead of downgrading it to
    /// `Changed`.
    Renewed,
    /// Nothing has modified the session yet (the default).
    #[default]
    Unchanged,
}
|
|
|
|
/// Shared state behind every [`Session`] handle for one request.
#[derive(Default)]
struct SessionInner {
    // Key/value session data; values are stored as `serde_json::Value`.
    state: Map<String, Value>,
    // What the current request has done to `state`; see `SessionStatus`.
    status: SessionStatus,
}
|
|
|
|
impl Session {
|
|
pub fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, SessionGetError> {
|
|
if let Some(value) = self.0.borrow().state.get(key) {
|
|
Ok(Some(
|
|
serde_json::from_value::<T>(value.clone())
|
|
.with_context(|| {
|
|
format!(
|
|
"Failed to deserialize the JSON-encoded session data attached to key \
|
|
`{}` as a `{}` type",
|
|
key,
|
|
std::any::type_name::<T>()
|
|
)
|
|
})
|
|
.map_err(SessionGetError)?,
|
|
))
|
|
} else {
|
|
Ok(None)
|
|
}
|
|
}
|
|
|
|
pub fn contains_key(&self, key: &str) -> bool {
|
|
self.0.borrow().state.contains_key(key)
|
|
}
|
|
|
|
pub fn entries(&self) -> Ref<'_, Map<String, Value>> {
|
|
Ref::map(self.0.borrow(), |inner| &inner.state)
|
|
}
|
|
|
|
pub fn status(&self) -> SessionStatus {
|
|
Ref::map(self.0.borrow(), |inner| &inner.status).clone()
|
|
}
|
|
|
|
pub fn insert<T: Serialize>(
|
|
&self,
|
|
key: impl Into<String>,
|
|
value: T,
|
|
) -> Result<(), SessionInsertError> {
|
|
let mut inner = self.0.borrow_mut();
|
|
|
|
if inner.status != SessionStatus::Purged {
|
|
if inner.status != SessionStatus::Renewed {
|
|
inner.status = SessionStatus::Changed;
|
|
}
|
|
|
|
let key = key.into();
|
|
let val = serde_json::to_value(&value)
|
|
.with_context(|| {
|
|
format!(
|
|
"Failed to serialize the provided `{}` type instance as JSON in order to \
|
|
attach as session data to the `{key}` key",
|
|
std::any::type_name::<T>(),
|
|
)
|
|
})
|
|
.map_err(SessionInsertError)?;
|
|
|
|
inner.state.insert(key, val);
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub fn update<T: Serialize + DeserializeOwned, F>(
|
|
&self,
|
|
key: impl Into<String>,
|
|
updater: F,
|
|
) -> Result<(), SessionUpdateError>
|
|
where
|
|
F: FnOnce(T) -> T,
|
|
{
|
|
let mut inner = self.0.borrow_mut();
|
|
let key_str = key.into();
|
|
|
|
if let Some(val) = inner.state.get(&key_str) {
|
|
let value = serde_json::from_value(val.clone())
|
|
.with_context(|| {
|
|
format!(
|
|
"Failed to deserialize the JSON-encoded session data attached to key \
|
|
`{key_str}` as a `{}` type",
|
|
std::any::type_name::<T>()
|
|
)
|
|
})
|
|
.map_err(SessionUpdateError)?;
|
|
|
|
let val = serde_json::to_value(updater(value))
|
|
.with_context(|| {
|
|
format!(
|
|
"Failed to serialize the provided `{}` type instance as JSON in order to \
|
|
attach as session data to the `{key_str}` key",
|
|
std::any::type_name::<T>(),
|
|
)
|
|
})
|
|
.map_err(SessionUpdateError)?;
|
|
|
|
inner.state.insert(key_str, val);
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub fn update_or<T: Serialize + DeserializeOwned, F>(
|
|
&self,
|
|
key: &str,
|
|
default_value: T,
|
|
updater: F,
|
|
) -> Result<(), SessionUpdateError>
|
|
where
|
|
F: FnOnce(T) -> T,
|
|
{
|
|
if self.contains_key(key) {
|
|
self.update(key, updater)
|
|
} else {
|
|
self.insert(key, default_value)
|
|
.map_err(|err| SessionUpdateError(err.into()))
|
|
}
|
|
}
|
|
|
|
pub fn remove(&self, key: &str) -> Option<Value> {
|
|
let mut inner = self.0.borrow_mut();
|
|
|
|
if inner.status != SessionStatus::Purged {
|
|
if inner.status != SessionStatus::Renewed {
|
|
inner.status = SessionStatus::Changed;
|
|
}
|
|
return inner.state.remove(key);
|
|
}
|
|
|
|
None
|
|
}
|
|
|
|
pub fn remove_as<T: DeserializeOwned>(&self, key: &str) -> Option<Result<T, Value>> {
|
|
self.remove(key)
|
|
.map(|value| match serde_json::from_value::<T>(value.clone()) {
|
|
Ok(val) => Ok(val),
|
|
Err(_err) => Err(value),
|
|
})
|
|
}
|
|
|
|
pub fn clear(&self) {
|
|
let mut inner = self.0.borrow_mut();
|
|
|
|
if inner.status != SessionStatus::Purged {
|
|
if inner.status != SessionStatus::Renewed {
|
|
inner.status = SessionStatus::Changed;
|
|
}
|
|
inner.state.clear()
|
|
}
|
|
}
|
|
|
|
pub fn purge(&self) {
|
|
let mut inner = self.0.borrow_mut();
|
|
inner.status = SessionStatus::Purged;
|
|
inner.state.clear();
|
|
}
|
|
|
|
pub fn renew(&self) {
|
|
let mut inner = self.0.borrow_mut();
|
|
|
|
if inner.status != SessionStatus::Purged {
|
|
inner.status = SessionStatus::Renewed;
|
|
}
|
|
}
|
|
|
|
pub fn user(&self) -> Option<Uuid> {
|
|
self.get::<Uuid>(SESSION_USER_KEY).ok().flatten()
|
|
}
|
|
|
|
pub fn set_user(&self, uid: Uuid) {
|
|
let _ = self.insert(SESSION_USER_KEY, uid);
|
|
}
|
|
|
|
pub fn clear_user(&self) {
|
|
let _ = self.remove(SESSION_USER_KEY);
|
|
}
|
|
|
|
pub fn current_workspace_id(&self) -> Option<Uuid> {
|
|
self.get::<Uuid>(SESSION_WORKSPACE_KEY).ok().flatten()
|
|
}
|
|
|
|
pub fn set_current_workspace_id(&self, id: Uuid) {
|
|
let _ = self.insert(SESSION_WORKSPACE_KEY, id);
|
|
}
|
|
|
|
pub fn clear_current_workspace_id(&self) {
|
|
let _ = self.remove(SESSION_WORKSPACE_KEY);
|
|
}
|
|
|
|
pub fn ip_address(&self) -> Option<String> {
|
|
self.get::<String>("session:ip_address").ok().flatten()
|
|
}
|
|
|
|
pub fn user_agent(&self) -> Option<String> {
|
|
self.get::<String>("session:user_agent").ok().flatten()
|
|
}
|
|
|
|
pub fn set_request_info(req: &HttpRequest) {
|
|
let extensions = req.extensions_mut();
|
|
if let Some(inner) = extensions.get::<Rc<RefCell<SessionInner>>>() {
|
|
let mut inner = inner.borrow_mut();
|
|
if let Some(ua) = req.headers().get("user-agent") {
|
|
if let Ok(ua) = ua.to_str() {
|
|
let _ = inner
|
|
.state
|
|
.insert("session:user_agent".to_string(), serde_json::json!(ua));
|
|
}
|
|
}
|
|
let addr = req
|
|
.connection_info()
|
|
.realip_remote_addr()
|
|
.map(|s| s.to_string());
|
|
if let Some(ip) = addr {
|
|
let _ = inner
|
|
.state
|
|
.insert("session:ip_address".to_string(), serde_json::json!(ip));
|
|
}
|
|
}
|
|
}
|
|
|
|
#[allow(clippy::needless_pass_by_ref_mut)]
|
|
pub(crate) fn set_session(
|
|
req: &mut ServiceRequest,
|
|
data: impl IntoIterator<Item = (String, Value)>,
|
|
) {
|
|
let session = Session::get_session(&mut req.extensions_mut());
|
|
let mut inner = session.0.borrow_mut();
|
|
inner.state.extend(data);
|
|
}
|
|
|
|
#[allow(clippy::needless_pass_by_ref_mut)]
|
|
pub(crate) fn get_changes<B>(
|
|
res: &mut ServiceResponse<B>,
|
|
) -> (SessionStatus, Map<String, Value>) {
|
|
if let Some(s_impl) = res
|
|
.request()
|
|
.extensions()
|
|
.get::<Rc<RefCell<SessionInner>>>()
|
|
{
|
|
let state = mem::take(&mut s_impl.borrow_mut().state);
|
|
(s_impl.borrow().status.clone(), state)
|
|
} else {
|
|
(SessionStatus::Unchanged, Map::new())
|
|
}
|
|
}
|
|
|
|
/// This is used internally by the FromRequest impl, but also exposed for WS/manual usage.
|
|
pub fn get_session(extensions: &mut Extensions) -> Session {
|
|
if let Some(s_impl) = extensions.get::<Rc<RefCell<SessionInner>>>() {
|
|
return Session(Rc::clone(s_impl));
|
|
}
|
|
|
|
let inner = Rc::new(RefCell::new(SessionInner::default()));
|
|
extensions.insert(inner.clone());
|
|
|
|
Session(inner)
|
|
}
|
|
}
|
|
|
|
impl FromRequest for Session {
|
|
type Error = Infallible;
|
|
type Future = Ready<Result<Self, Self::Error>>;
|
|
|
|
#[inline]
|
|
fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
|
|
ready(Ok(Session::get_session(&mut req.extensions_mut())))
|
|
}
|
|
}
|
|
|
|
/// Extractor for the authenticated user ID from session.
|
|
/// Fails with 401 if the session has no logged-in user.
|
|
#[derive(Clone, Copy)]
|
|
pub struct SessionUser(pub Uuid);
|
|
|
|
impl FromRequest for SessionUser {
|
|
type Error = SessionGetError;
|
|
type Future = Pin<Box<dyn Future<Output = Result<Self, Self::Error>>>>;
|
|
|
|
fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
|
|
let req = req.clone();
|
|
Box::pin(async move {
|
|
let uid = {
|
|
let mut extensions = req.extensions_mut();
|
|
let session = Session::get_session(&mut extensions);
|
|
session
|
|
.user()
|
|
.ok_or_else(|| SessionGetError(anyhow::anyhow!("not authenticated")))?
|
|
};
|
|
Ok(SessionUser(uid))
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Extractor for the current workspace ID from session.
|
|
/// Returns None if no workspace is selected (workspace selection is optional).
|
|
#[derive(Clone, Copy)]
|
|
pub struct SessionWorkspace(pub Option<Uuid>);
|
|
|
|
impl FromRequest for SessionWorkspace {
|
|
type Error = Infallible;
|
|
type Future = Ready<Result<Self, Self::Error>>;
|
|
|
|
#[inline]
|
|
fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
|
|
let mut extensions = req.extensions_mut();
|
|
let session = Session::get_session(&mut extensions);
|
|
ready(Ok(SessionWorkspace(session.current_workspace_id())))
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Display, From)]
|
|
#[display("{_0}")]
|
|
pub struct SessionGetError(anyhow::Error);
|
|
|
|
impl StdError for SessionGetError {
|
|
fn source(&self) -> Option<&(dyn StdError + 'static)> {
|
|
Some(self.0.as_ref())
|
|
}
|
|
}
|
|
|
|
impl ResponseError for SessionGetError {
|
|
fn error_response(&self) -> HttpResponse<BoxBody> {
|
|
HttpResponse::build(self.status_code())
|
|
.content_type("text/plain")
|
|
.body(self.to_string())
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Display, From)]
|
|
#[display("{_0}")]
|
|
pub struct SessionInsertError(anyhow::Error);
|
|
|
|
impl StdError for SessionInsertError {
|
|
fn source(&self) -> Option<&(dyn StdError + 'static)> {
|
|
Some(self.0.as_ref())
|
|
}
|
|
}
|
|
|
|
impl ResponseError for SessionInsertError {
|
|
fn error_response(&self) -> HttpResponse<BoxBody> {
|
|
HttpResponse::build(self.status_code())
|
|
.content_type("text/plain")
|
|
.body(self.to_string())
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Display, From)]
|
|
#[display("{_0}")]
|
|
pub struct SessionUpdateError(anyhow::Error);
|
|
|
|
impl StdError for SessionUpdateError {
|
|
fn source(&self) -> Option<&(dyn StdError + 'static)> {
|
|
Some(self.0.as_ref())
|
|
}
|
|
}
|
|
|
|
impl ResponseError for SessionUpdateError {
|
|
fn error_response(&self) -> HttpResponse<BoxBody> {
|
|
HttpResponse::build(self.status_code())
|
|
.content_type("text/plain")
|
|
.body(self.to_string())
|
|
}
|
|
}
|