diff --git a/examples/user_management.rs b/examples/user_management.rs index 740a0f1..bbb14f1 100644 --- a/examples/user_management.rs +++ b/examples/user_management.rs @@ -33,7 +33,7 @@ async fn main() -> Result<()> { .add_authenticated_input_route("/update", "Enter your new string:", handle_update) // Add routes for handling user authentication - .add_um_routes::("/") + .add_um_routes::() // Start the server .serve() diff --git a/src/user_management/pages/unauth.gmi b/src/user_management/pages/unauth.gmi index 609ac5c..a234a4c 100644 --- a/src/user_management/pages/unauth.gmi +++ b/src/user_management/pages/unauth.gmi @@ -2,8 +2,8 @@ It seems like you don't have a client certificate enabled. In order to log in, you need to connect using a client certificate. If your client supports it, you can use the link below to activate a certificate. -=> /account/askcert Choose a Certificate +=> /account/askcert/{redirect} Choose a Certificate If your client can't automatically manage client certificates, check the link below for a list of clients that support client certificates. -=> /account/clients Clients +=> /account/clients/{redirect} Clients diff --git a/src/user_management/routes.rs b/src/user_management/routes.rs index 0aeb28b..6c46383 100644 --- a/src/user_management/routes.rs +++ b/src/user_management/routes.rs @@ -2,17 +2,24 @@ use anyhow::Result; use tokio::net::ToSocketAddrs; use serde::{Serialize, de::DeserializeOwned}; +#[cfg(feature = "dashmap")] +use dashmap::DashMap; +#[cfg(not(feature = "dashmap"))] +use std::collections::HashMap; +#[cfg(not(feature = "dashmap"))] +use std::sync::RwLock; + use std::future::Future; use crate::{Document, Request, Response}; use crate::types::document::HeadingLevel; -use crate::user_management::{User, RegisteredUser, UserManagerError}; - -const UNAUTH: &str = include_str!("pages/unauth.gmi"); -#[cfg(feature = "user_management_advanced")] -const NSI: &str = include_str!("pages/nsi.gmi"); -#[cfg(not(feature = "user_management_advanced"))] -const NSI: &str = include_str!("pages/nopass/nsi.gmi"); +use crate::user_management::{ + User, + RegisteredUser, + UserManager, + UserManagerError, + user::NotSignedInUser, +}; /// Import this trait to use [`add_um_routes()`](Self::add_um_routes()) pub trait UserManagementRoutes: private::Sealed { @@ -31,7 +38,7 @@ pub trait UserManagementRoutes: private::Sealed { /// /// The `redir` argument allows you to specify the point that users will be directed /// to return to once their account has been created. - fn add_um_routes(self, redir: &'static str) -> Self; + fn add_um_routes(self) -> Self; /// Add a special route that requires users to be logged in /// @@ -88,15 +95,15 @@ impl UserManagementRoutes for crate::Server { /// Add pre-configured routes to the serve to handle authentication /// /// See [`UserManagementRoutes::add_um_routes()`] - fn add_um_routes(self, redir: &'static str) -> Self { + fn add_um_routes(self) -> Self { #[allow(unused_mut)] - let mut modified_self = self.add_route("/account", move|r|handle_base::(r, redir)) - .add_route("/account/askcert", move|r|handle_ask_cert::(r, redir)) - .add_route("/account/register", move|r|handle_register::(r, redir)); + let mut modified_self = self.add_route("/account", handle_base::) + .add_route("/account/askcert", handle_ask_cert::) + .add_route("/account/register", handle_register::); #[cfg(feature = "user_management_advanced")] { modified_self = modified_self - .add_route("/account/login", move|r|handle_login::(r, redir)) + .add_route("/account/login", handle_login::) .add_route("/account/password", handle_password::); } @@ -119,11 +126,14 @@ impl UserManagementRoutes for crate::Server { self.add_route(path, move|request: Request| { let handler = handler.clone(); async move { + let segments = request.path_segments(); + let segments = segments.iter().map(String::as_ref).collect::>(); Ok(match request.user::()? { User::Unauthenticated => { - Response::success_gemini(UNAUTH) + render_unauth_page(segments) }, - User::NotSignedIn(_) => { + User::NotSignedIn(user) => { + save_redirect(&user, segments); Response::success_gemini(NSI) }, User::SignedIn(user) => { @@ -161,26 +171,46 @@ impl UserManagementRoutes for crate::Server { } } -async fn handle_base(request: Request, redirect: &'static str) -> Result { +#[cfg(feature = "user_management_advanced")] +const NSI: &str = include_str!("pages/nsi.gmi"); +#[cfg(not(feature = "user_management_advanced"))] +const NSI: &str = include_str!("pages/nopass/nsi.gmi"); + +// TODO periodically clean these +#[cfg(feature = "dashmap")] +lazy_static::lazy_static! { + static ref PENDING_REDIRECTS: DashMap = Default::default(); +} + +#[cfg(not(feature = "dashmap"))] +lazy_static::lazy_static! { + static ref PENDING_REDIRECTS: RwLock> = Default::default(); +} + +async fn handle_base(request: Request) -> Result { + let segments = request.trailing_segments().iter().map(String::as_str).collect::>(); Ok(match request.user::()? { User::Unauthenticated => { - Response::success_gemini(UNAUTH) + render_unauth_page(segments) }, - User::NotSignedIn(_) => { + User::NotSignedIn(usr) => { + save_redirect(&usr, segments); Response::success_gemini(NSI) }, User::SignedIn(user) => { - render_settings_menu(user, redirect) + render_settings_menu(user) }, }) } -async fn handle_ask_cert(request: Request, redirect: &'static str) -> Result { +async fn handle_ask_cert(request: Request) -> Result { Ok(match request.user::()? { User::Unauthenticated => { Response::client_certificate_required() }, - User::NotSignedIn(_) => { + User::NotSignedIn(nsi) => { + let segments = request.trailing_segments().iter().map(String::as_str).collect::>(); + save_redirect(&nsi, segments); #[cfg(feature = "user_management_advanced")] { Response::success_gemini(include_str!("pages/askcert/success.gmi")) } @@ -192,16 +222,16 @@ async fn handle_ask_cert(request: Reques Response::success_gemini(format!( include_str!("pages/askcert/exists.gmi"), username = user.username(), - redirect = redirect, + redirect = get_redirect(&user), )) }, }) } -async fn handle_register(request: Request, redirect: &'static str) -> Result { +async fn handle_register(request: Request) -> Result { Ok(match request.user::()? { User::Unauthenticated => { - Response::success_gemini(UNAUTH) + render_unauth_page(&[""]) }, User::NotSignedIn(nsi) => { if let Some(username) = request.input() { @@ -220,19 +250,19 @@ async fn handle_register(reque )) } }, - Ok(_) => { + Ok(user) => { #[cfg(feature = "user_management_advanced")] { Response::success_gemini(format!( include_str!("pages/register/success.gmi"), username = username, - redirect = redirect, + redirect = get_redirect(&user), )) } #[cfg(not(feature = "user_management_advanced"))] { Response::success_gemini(format!( include_str!("pages/nopass/register/success.gmi"), username = username, - redirect = redirect, + redirect = get_redirect(&user), )) } }, @@ -243,16 +273,16 @@ async fn handle_register(reque } }, User::SignedIn(user) => { - render_settings_menu(user, redirect) + render_settings_menu(user) }, }) } #[cfg(feature = "user_management_advanced")] -async fn handle_login(request: Request, redirect: &'static str) -> Result { +async fn handle_login(request: Request) -> Result { Ok(match request.user::()? { User::Unauthenticated => { - Response::success_gemini(UNAUTH) + render_unauth_page(&[""]) }, User::NotSignedIn(nsi) => { if let Some(username) = request.trailing_segments().get(0) { @@ -264,11 +294,11 @@ async fn handle_login(request: username = username, )) }, - Ok(_) => { + Ok(Some(user)) => { Response::success_gemini(format!( include_str!("pages/login/success.gmi"), username = username, - redirect = redirect, + redirect = get_redirect(&user), )) }, Err(e) => return Err(e.into()), @@ -285,7 +315,7 @@ async fn handle_login(request: } }, User::SignedIn(user) => { - render_settings_menu(user, redirect) + render_settings_menu(user) }, }) } @@ -294,9 +324,10 @@ async fn handle_login(request: async fn handle_password(request: Request) -> Result { Ok(match request.user::()? { User::Unauthenticated => { - Response::success_gemini(UNAUTH) + render_unauth_page(&[""]) }, - User::NotSignedIn(_) => { + User::NotSignedIn(nsi) => { + save_redirect(&nsi, &[""]); Response::success_gemini(NSI) }, User::SignedIn(mut user) => { @@ -318,10 +349,8 @@ async fn handle_password(reque }) } - fn render_settings_menu( - user: RegisteredUser, - redirect: &str + user: RegisteredUser ) -> Response { let mut document = Document::new(); document @@ -329,7 +358,7 @@ fn render_settings_menu( .add_blank_line() .add_text(&format!("Welcome {}!", user.username())) .add_blank_line() - .add_link(redirect, "Back to the app") + .add_link(get_redirect(&user).as_str(), "Back to the app") .add_blank_line(); #[cfg(feature = "user_management_advanced")] @@ -355,6 +384,48 @@ fn render_settings_menu( document.into() } +fn render_unauth_page<'a>( + redirect: impl AsRef<[&'a str]>, +) -> Response { + Response::success_gemini(format!( + include_str!("pages/unauth.gmi"), + redirect = redirect.as_ref().join("/"), + )) +} + +fn save_redirect<'a>( + user: &NotSignedInUser, + redirect: impl AsRef<[&'a str]>, +) { + let mut redirect = redirect.as_ref().join("/"); + redirect.insert(0, '/'); + if redirect.len() > 1 { + #[cfg(feature = "dashmap")] + let ref_to_map = &*PENDING_REDIRECTS; + #[cfg(not(feature = "dashmap"))] + let mut ref_to_map = PENDING_REDIRECTS.write().unwrap(); + + let cert_hash = UserManager::hash_certificate(&user.certificate); + debug!("Added \"{}\" as redirect for cert {:x}", redirect, cert_hash); + ref_to_map.insert(cert_hash, redirect); + } +} + +fn get_redirect(user: &RegisteredUser) -> String { + let cert_hash = UserManager::hash_certificate(user.active_certificate().unwrap()); + + #[cfg(feature = "dashmap")] + let maybe_redir = PENDING_REDIRECTS.get(&cert_hash).map(|r| r.clone()); + #[cfg(not(feature = "dashmap"))] + let ref_to_map = PENDING_REDIRECTS.read().unwrap(); + #[cfg(not(feature = "dashmap"))] + let maybe_redir = ref_to_map.get(&cert_hash).cloned(); + + let redirect = maybe_redir.unwrap_or_else(||"/".to_string()); + debug!("Accessed redirect to \"{}\" for cert {:x}", redirect, cert_hash); + redirect +} + mod private { pub trait Sealed {} impl Sealed for crate::Server {}