use std::{ cell::{Ref, RefCell}, convert::Infallible, error::Error as StdError, future::Future, mem, pin::Pin, rc::Rc, }; use actix_utils::future::{Ready, ready}; use actix_web::{ FromRequest, HttpMessage, HttpRequest, HttpResponse, ResponseError, body::BoxBody, dev::{Extensions, Payload, ServiceRequest, ServiceResponse}, }; use anyhow::Context; use derive_more::derive::{Display, From}; use serde::{Serialize, de::DeserializeOwned}; use serde_json::{Map, Value}; use uuid::Uuid; const SESSION_USER_KEY: &str = "session:user_uid"; const SESSION_WORKSPACE_KEY: &str = "session:workspace_id"; #[derive(Clone)] pub struct Session(Rc>); #[derive(Debug, Clone, Default, PartialEq, Eq)] pub enum SessionStatus { Changed, Purged, Renewed, #[default] Unchanged, } #[derive(Default)] struct SessionInner { state: Map, status: SessionStatus, } impl Session { pub fn get(&self, key: &str) -> Result, SessionGetError> { if let Some(value) = self.0.borrow().state.get(key) { Ok(Some( serde_json::from_value::(value.clone()) .with_context(|| { format!( "Failed to deserialize the JSON-encoded session data attached to key \ `{}` as a `{}` type", key, std::any::type_name::() ) }) .map_err(SessionGetError)?, )) } else { Ok(None) } } pub fn contains_key(&self, key: &str) -> bool { self.0.borrow().state.contains_key(key) } pub fn entries(&self) -> Ref<'_, Map> { Ref::map(self.0.borrow(), |inner| &inner.state) } pub fn status(&self) -> SessionStatus { Ref::map(self.0.borrow(), |inner| &inner.status).clone() } pub fn insert( &self, key: impl Into, value: T, ) -> Result<(), SessionInsertError> { let mut inner = self.0.borrow_mut(); if inner.status != SessionStatus::Purged { if inner.status != SessionStatus::Renewed { inner.status = SessionStatus::Changed; } let key = key.into(); let val = serde_json::to_value(&value) .with_context(|| { format!( "Failed to serialize the provided `{}` type instance as JSON in order to \ attach as session data to the `{key}` key", std::any::type_name::(), ) }) .map_err(SessionInsertError)?; inner.state.insert(key, val); } Ok(()) } pub fn update( &self, key: impl Into, updater: F, ) -> Result<(), SessionUpdateError> where F: FnOnce(T) -> T, { let mut inner = self.0.borrow_mut(); let key_str = key.into(); if let Some(val) = inner.state.get(&key_str) { let value = serde_json::from_value(val.clone()) .with_context(|| { format!( "Failed to deserialize the JSON-encoded session data attached to key \ `{key_str}` as a `{}` type", std::any::type_name::() ) }) .map_err(SessionUpdateError)?; let val = serde_json::to_value(updater(value)) .with_context(|| { format!( "Failed to serialize the provided `{}` type instance as JSON in order to \ attach as session data to the `{key_str}` key", std::any::type_name::(), ) }) .map_err(SessionUpdateError)?; inner.state.insert(key_str, val); } Ok(()) } pub fn update_or( &self, key: &str, default_value: T, updater: F, ) -> Result<(), SessionUpdateError> where F: FnOnce(T) -> T, { if self.contains_key(key) { self.update(key, updater) } else { self.insert(key, default_value) .map_err(|err| SessionUpdateError(err.into())) } } pub fn remove(&self, key: &str) -> Option { let mut inner = self.0.borrow_mut(); if inner.status != SessionStatus::Purged { if inner.status != SessionStatus::Renewed { inner.status = SessionStatus::Changed; } return inner.state.remove(key); } None } pub fn remove_as(&self, key: &str) -> Option> { self.remove(key) .map(|value| match serde_json::from_value::(value.clone()) { Ok(val) => Ok(val), Err(_err) => Err(value), }) } pub fn clear(&self) { let mut inner = self.0.borrow_mut(); if inner.status != SessionStatus::Purged { if inner.status != SessionStatus::Renewed { inner.status = SessionStatus::Changed; } inner.state.clear() } } pub fn purge(&self) { let mut inner = self.0.borrow_mut(); inner.status = SessionStatus::Purged; inner.state.clear(); } pub fn renew(&self) { let mut inner = self.0.borrow_mut(); if inner.status != SessionStatus::Purged { inner.status = SessionStatus::Renewed; } } pub fn user(&self) -> Option { self.get::(SESSION_USER_KEY).ok().flatten() } pub fn set_user(&self, uid: Uuid) { let _ = self.insert(SESSION_USER_KEY, uid); } pub fn clear_user(&self) { let _ = self.remove(SESSION_USER_KEY); } pub fn current_workspace_id(&self) -> Option { self.get::(SESSION_WORKSPACE_KEY).ok().flatten() } pub fn set_current_workspace_id(&self, id: Uuid) { let _ = self.insert(SESSION_WORKSPACE_KEY, id); } pub fn clear_current_workspace_id(&self) { let _ = self.remove(SESSION_WORKSPACE_KEY); } pub fn ip_address(&self) -> Option { self.get::("session:ip_address").ok().flatten() } pub fn user_agent(&self) -> Option { self.get::("session:user_agent").ok().flatten() } pub fn set_request_info(req: &HttpRequest) { let extensions = req.extensions_mut(); if let Some(inner) = extensions.get::>>() { let mut inner = inner.borrow_mut(); if let Some(ua) = req.headers().get("user-agent") { if let Ok(ua) = ua.to_str() { let _ = inner .state .insert("session:user_agent".to_string(), serde_json::json!(ua)); } } let addr = req .connection_info() .realip_remote_addr() .map(|s| s.to_string()); if let Some(ip) = addr { let _ = inner .state .insert("session:ip_address".to_string(), serde_json::json!(ip)); } } } #[allow(clippy::needless_pass_by_ref_mut)] pub(crate) fn set_session( req: &mut ServiceRequest, data: impl IntoIterator, ) { let session = Session::get_session(&mut req.extensions_mut()); let mut inner = session.0.borrow_mut(); inner.state.extend(data); } #[allow(clippy::needless_pass_by_ref_mut)] pub(crate) fn get_changes( res: &mut ServiceResponse, ) -> (SessionStatus, Map) { if let Some(s_impl) = res .request() .extensions() .get::>>() { let state = mem::take(&mut s_impl.borrow_mut().state); (s_impl.borrow().status.clone(), state) } else { (SessionStatus::Unchanged, Map::new()) } } /// This is used internally by the FromRequest impl, but also exposed for WS/manual usage. pub fn get_session(extensions: &mut Extensions) -> Session { if let Some(s_impl) = extensions.get::>>() { return Session(Rc::clone(s_impl)); } let inner = Rc::new(RefCell::new(SessionInner::default())); extensions.insert(inner.clone()); Session(inner) } } impl FromRequest for Session { type Error = Infallible; type Future = Ready>; #[inline] fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { ready(Ok(Session::get_session(&mut req.extensions_mut()))) } } /// Extractor for the authenticated user ID from session. /// Fails with 401 if the session has no logged-in user. #[derive(Clone, Copy)] pub struct SessionUser(pub Uuid); impl FromRequest for SessionUser { type Error = SessionGetError; type Future = Pin>>>; fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { let req = req.clone(); Box::pin(async move { let uid = { let mut extensions = req.extensions_mut(); let session = Session::get_session(&mut extensions); session .user() .ok_or_else(|| SessionGetError(anyhow::anyhow!("not authenticated")))? }; Ok(SessionUser(uid)) }) } } /// Extractor for the current workspace ID from session. /// Returns None if no workspace is selected (workspace selection is optional). #[derive(Clone, Copy)] pub struct SessionWorkspace(pub Option); impl FromRequest for SessionWorkspace { type Error = Infallible; type Future = Ready>; #[inline] fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { let mut extensions = req.extensions_mut(); let session = Session::get_session(&mut extensions); ready(Ok(SessionWorkspace(session.current_workspace_id()))) } } #[derive(Debug, Display, From)] #[display("{_0}")] pub struct SessionGetError(anyhow::Error); impl StdError for SessionGetError { fn source(&self) -> Option<&(dyn StdError + 'static)> { Some(self.0.as_ref()) } } impl ResponseError for SessionGetError { fn error_response(&self) -> HttpResponse { HttpResponse::build(self.status_code()) .content_type("text/plain") .body(self.to_string()) } } #[derive(Debug, Display, From)] #[display("{_0}")] pub struct SessionInsertError(anyhow::Error); impl StdError for SessionInsertError { fn source(&self) -> Option<&(dyn StdError + 'static)> { Some(self.0.as_ref()) } } impl ResponseError for SessionInsertError { fn error_response(&self) -> HttpResponse { HttpResponse::build(self.status_code()) .content_type("text/plain") .body(self.to_string()) } } #[derive(Debug, Display, From)] #[display("{_0}")] pub struct SessionUpdateError(anyhow::Error); impl StdError for SessionUpdateError { fn source(&self) -> Option<&(dyn StdError + 'static)> { Some(self.0.as_ref()) } } impl ResponseError for SessionUpdateError { fn error_response(&self) -> HttpResponse { HttpResponse::build(self.status_code()) .content_type("text/plain") .body(self.to_string()) } }