Login flow now redirects back to where it started
Closes #8 Still in testing, but seems good I think
This commit is contained in:
parent
98583e737f
commit
e83f2ca109
|
@ -33,7 +33,7 @@ async fn main() -> Result<()> {
|
||||||
.add_authenticated_input_route("/update", "Enter your new string:", handle_update)
|
.add_authenticated_input_route("/update", "Enter your new string:", handle_update)
|
||||||
|
|
||||||
// Add routes for handling user authentication
|
// Add routes for handling user authentication
|
||||||
.add_um_routes::<String>("/")
|
.add_um_routes::<String>()
|
||||||
|
|
||||||
// Start the server
|
// Start the server
|
||||||
.serve()
|
.serve()
|
||||||
|
|
|
@ -2,8 +2,8 @@
|
||||||
|
|
||||||
It seems like you don't have a client certificate enabled. In order to log in, you need to connect using a client certificate. If your client supports it, you can use the link below to activate a certificate.
|
It seems like you don't have a client certificate enabled. In order to log in, you need to connect using a client certificate. If your client supports it, you can use the link below to activate a certificate.
|
||||||
|
|
||||||
=> /account/askcert Choose a Certificate
|
=> /account/askcert/{redirect} Choose a Certificate
|
||||||
|
|
||||||
If your client can't automatically manage client certificates, check the link below for a list of clients that support client certificates.
|
If your client can't automatically manage client certificates, check the link below for a list of clients that support client certificates.
|
||||||
|
|
||||||
=> /account/clients Clients
|
=> /account/clients/{redirect} Clients
|
||||||
|
|
|
@ -2,17 +2,24 @@ use anyhow::Result;
|
||||||
use tokio::net::ToSocketAddrs;
|
use tokio::net::ToSocketAddrs;
|
||||||
use serde::{Serialize, de::DeserializeOwned};
|
use serde::{Serialize, de::DeserializeOwned};
|
||||||
|
|
||||||
|
#[cfg(feature = "dashmap")]
|
||||||
|
use dashmap::DashMap;
|
||||||
|
#[cfg(not(feature = "dashmap"))]
|
||||||
|
use std::collections::HashMap;
|
||||||
|
#[cfg(not(feature = "dashmap"))]
|
||||||
|
use std::sync::RwLock;
|
||||||
|
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
|
|
||||||
use crate::{Document, Request, Response};
|
use crate::{Document, Request, Response};
|
||||||
use crate::types::document::HeadingLevel;
|
use crate::types::document::HeadingLevel;
|
||||||
use crate::user_management::{User, RegisteredUser, UserManagerError};
|
use crate::user_management::{
|
||||||
|
User,
|
||||||
const UNAUTH: &str = include_str!("pages/unauth.gmi");
|
RegisteredUser,
|
||||||
#[cfg(feature = "user_management_advanced")]
|
UserManager,
|
||||||
const NSI: &str = include_str!("pages/nsi.gmi");
|
UserManagerError,
|
||||||
#[cfg(not(feature = "user_management_advanced"))]
|
user::NotSignedInUser,
|
||||||
const NSI: &str = include_str!("pages/nopass/nsi.gmi");
|
};
|
||||||
|
|
||||||
/// Import this trait to use [`add_um_routes()`](Self::add_um_routes())
|
/// Import this trait to use [`add_um_routes()`](Self::add_um_routes())
|
||||||
pub trait UserManagementRoutes: private::Sealed {
|
pub trait UserManagementRoutes: private::Sealed {
|
||||||
|
@ -31,7 +38,7 @@ pub trait UserManagementRoutes: private::Sealed {
|
||||||
///
|
///
|
||||||
/// The `redir` argument allows you to specify the point that users will be directed
|
/// The `redir` argument allows you to specify the point that users will be directed
|
||||||
/// to return to once their account has been created.
|
/// to return to once their account has been created.
|
||||||
fn add_um_routes<UserData: Serialize + DeserializeOwned + Default + 'static>(self, redir: &'static str) -> Self;
|
fn add_um_routes<UserData: Serialize + DeserializeOwned + Default + 'static>(self) -> Self;
|
||||||
|
|
||||||
/// Add a special route that requires users to be logged in
|
/// Add a special route that requires users to be logged in
|
||||||
///
|
///
|
||||||
|
@ -88,15 +95,15 @@ impl<A: ToSocketAddrs> UserManagementRoutes for crate::Server<A> {
|
||||||
/// Add pre-configured routes to the serve to handle authentication
|
/// Add pre-configured routes to the serve to handle authentication
|
||||||
///
|
///
|
||||||
/// See [`UserManagementRoutes::add_um_routes()`]
|
/// See [`UserManagementRoutes::add_um_routes()`]
|
||||||
fn add_um_routes<UserData: Serialize + DeserializeOwned + Default + 'static>(self, redir: &'static str) -> Self {
|
fn add_um_routes<UserData: Serialize + DeserializeOwned + Default + 'static>(self) -> Self {
|
||||||
#[allow(unused_mut)]
|
#[allow(unused_mut)]
|
||||||
let mut modified_self = self.add_route("/account", move|r|handle_base::<UserData>(r, redir))
|
let mut modified_self = self.add_route("/account", handle_base::<UserData>)
|
||||||
.add_route("/account/askcert", move|r|handle_ask_cert::<UserData>(r, redir))
|
.add_route("/account/askcert", handle_ask_cert::<UserData>)
|
||||||
.add_route("/account/register", move|r|handle_register::<UserData>(r, redir));
|
.add_route("/account/register", handle_register::<UserData>);
|
||||||
|
|
||||||
#[cfg(feature = "user_management_advanced")] {
|
#[cfg(feature = "user_management_advanced")] {
|
||||||
modified_self = modified_self
|
modified_self = modified_self
|
||||||
.add_route("/account/login", move|r|handle_login::<UserData>(r, redir))
|
.add_route("/account/login", handle_login::<UserData>)
|
||||||
.add_route("/account/password", handle_password::<UserData>);
|
.add_route("/account/password", handle_password::<UserData>);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,11 +126,14 @@ impl<A: ToSocketAddrs> UserManagementRoutes for crate::Server<A> {
|
||||||
self.add_route(path, move|request: Request| {
|
self.add_route(path, move|request: Request| {
|
||||||
let handler = handler.clone();
|
let handler = handler.clone();
|
||||||
async move {
|
async move {
|
||||||
|
let segments = request.path_segments();
|
||||||
|
let segments = segments.iter().map(String::as_ref).collect::<Vec<&str>>();
|
||||||
Ok(match request.user::<UserData>()? {
|
Ok(match request.user::<UserData>()? {
|
||||||
User::Unauthenticated => {
|
User::Unauthenticated => {
|
||||||
Response::success_gemini(UNAUTH)
|
render_unauth_page(segments)
|
||||||
},
|
},
|
||||||
User::NotSignedIn(_) => {
|
User::NotSignedIn(user) => {
|
||||||
|
save_redirect(&user, segments);
|
||||||
Response::success_gemini(NSI)
|
Response::success_gemini(NSI)
|
||||||
},
|
},
|
||||||
User::SignedIn(user) => {
|
User::SignedIn(user) => {
|
||||||
|
@ -161,26 +171,46 @@ impl<A: ToSocketAddrs> UserManagementRoutes for crate::Server<A> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_base<UserData: Serialize + DeserializeOwned>(request: Request, redirect: &'static str) -> Result<Response> {
|
#[cfg(feature = "user_management_advanced")]
|
||||||
|
const NSI: &str = include_str!("pages/nsi.gmi");
|
||||||
|
#[cfg(not(feature = "user_management_advanced"))]
|
||||||
|
const NSI: &str = include_str!("pages/nopass/nsi.gmi");
|
||||||
|
|
||||||
|
// TODO periodically clean these
|
||||||
|
#[cfg(feature = "dashmap")]
|
||||||
|
lazy_static::lazy_static! {
|
||||||
|
static ref PENDING_REDIRECTS: DashMap<u32, String> = Default::default();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "dashmap"))]
|
||||||
|
lazy_static::lazy_static! {
|
||||||
|
static ref PENDING_REDIRECTS: RwLock<HashMap<u32, String>> = Default::default();
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_base<UserData: Serialize + DeserializeOwned>(request: Request) -> Result<Response> {
|
||||||
|
let segments = request.trailing_segments().iter().map(String::as_str).collect::<Vec<&str>>();
|
||||||
Ok(match request.user::<UserData>()? {
|
Ok(match request.user::<UserData>()? {
|
||||||
User::Unauthenticated => {
|
User::Unauthenticated => {
|
||||||
Response::success_gemini(UNAUTH)
|
render_unauth_page(segments)
|
||||||
},
|
},
|
||||||
User::NotSignedIn(_) => {
|
User::NotSignedIn(usr) => {
|
||||||
|
save_redirect(&usr, segments);
|
||||||
Response::success_gemini(NSI)
|
Response::success_gemini(NSI)
|
||||||
},
|
},
|
||||||
User::SignedIn(user) => {
|
User::SignedIn(user) => {
|
||||||
render_settings_menu(user, redirect)
|
render_settings_menu(user)
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_ask_cert<UserData: Serialize + DeserializeOwned>(request: Request, redirect: &'static str) -> Result<Response> {
|
async fn handle_ask_cert<UserData: Serialize + DeserializeOwned>(request: Request) -> Result<Response> {
|
||||||
Ok(match request.user::<UserData>()? {
|
Ok(match request.user::<UserData>()? {
|
||||||
User::Unauthenticated => {
|
User::Unauthenticated => {
|
||||||
Response::client_certificate_required()
|
Response::client_certificate_required()
|
||||||
},
|
},
|
||||||
User::NotSignedIn(_) => {
|
User::NotSignedIn(nsi) => {
|
||||||
|
let segments = request.trailing_segments().iter().map(String::as_str).collect::<Vec<&str>>();
|
||||||
|
save_redirect(&nsi, segments);
|
||||||
#[cfg(feature = "user_management_advanced")] {
|
#[cfg(feature = "user_management_advanced")] {
|
||||||
Response::success_gemini(include_str!("pages/askcert/success.gmi"))
|
Response::success_gemini(include_str!("pages/askcert/success.gmi"))
|
||||||
}
|
}
|
||||||
|
@ -192,16 +222,16 @@ async fn handle_ask_cert<UserData: Serialize + DeserializeOwned>(request: Reques
|
||||||
Response::success_gemini(format!(
|
Response::success_gemini(format!(
|
||||||
include_str!("pages/askcert/exists.gmi"),
|
include_str!("pages/askcert/exists.gmi"),
|
||||||
username = user.username(),
|
username = user.username(),
|
||||||
redirect = redirect,
|
redirect = get_redirect(&user),
|
||||||
))
|
))
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_register<UserData: Serialize + DeserializeOwned + Default>(request: Request, redirect: &'static str) -> Result<Response> {
|
async fn handle_register<UserData: Serialize + DeserializeOwned + Default>(request: Request) -> Result<Response> {
|
||||||
Ok(match request.user::<UserData>()? {
|
Ok(match request.user::<UserData>()? {
|
||||||
User::Unauthenticated => {
|
User::Unauthenticated => {
|
||||||
Response::success_gemini(UNAUTH)
|
render_unauth_page(&[""])
|
||||||
},
|
},
|
||||||
User::NotSignedIn(nsi) => {
|
User::NotSignedIn(nsi) => {
|
||||||
if let Some(username) = request.input() {
|
if let Some(username) = request.input() {
|
||||||
|
@ -220,19 +250,19 @@ async fn handle_register<UserData: Serialize + DeserializeOwned + Default>(reque
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
Ok(_) => {
|
Ok(user) => {
|
||||||
#[cfg(feature = "user_management_advanced")] {
|
#[cfg(feature = "user_management_advanced")] {
|
||||||
Response::success_gemini(format!(
|
Response::success_gemini(format!(
|
||||||
include_str!("pages/register/success.gmi"),
|
include_str!("pages/register/success.gmi"),
|
||||||
username = username,
|
username = username,
|
||||||
redirect = redirect,
|
redirect = get_redirect(&user),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
#[cfg(not(feature = "user_management_advanced"))] {
|
#[cfg(not(feature = "user_management_advanced"))] {
|
||||||
Response::success_gemini(format!(
|
Response::success_gemini(format!(
|
||||||
include_str!("pages/nopass/register/success.gmi"),
|
include_str!("pages/nopass/register/success.gmi"),
|
||||||
username = username,
|
username = username,
|
||||||
redirect = redirect,
|
redirect = get_redirect(&user),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -243,16 +273,16 @@ async fn handle_register<UserData: Serialize + DeserializeOwned + Default>(reque
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
User::SignedIn(user) => {
|
User::SignedIn(user) => {
|
||||||
render_settings_menu(user, redirect)
|
render_settings_menu(user)
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "user_management_advanced")]
|
#[cfg(feature = "user_management_advanced")]
|
||||||
async fn handle_login<UserData: Serialize + DeserializeOwned + Default>(request: Request, redirect: &'static str) -> Result<Response> {
|
async fn handle_login<UserData: Serialize + DeserializeOwned + Default>(request: Request) -> Result<Response> {
|
||||||
Ok(match request.user::<UserData>()? {
|
Ok(match request.user::<UserData>()? {
|
||||||
User::Unauthenticated => {
|
User::Unauthenticated => {
|
||||||
Response::success_gemini(UNAUTH)
|
render_unauth_page(&[""])
|
||||||
},
|
},
|
||||||
User::NotSignedIn(nsi) => {
|
User::NotSignedIn(nsi) => {
|
||||||
if let Some(username) = request.trailing_segments().get(0) {
|
if let Some(username) = request.trailing_segments().get(0) {
|
||||||
|
@ -264,11 +294,11 @@ async fn handle_login<UserData: Serialize + DeserializeOwned + Default>(request:
|
||||||
username = username,
|
username = username,
|
||||||
))
|
))
|
||||||
},
|
},
|
||||||
Ok(_) => {
|
Ok(Some(user)) => {
|
||||||
Response::success_gemini(format!(
|
Response::success_gemini(format!(
|
||||||
include_str!("pages/login/success.gmi"),
|
include_str!("pages/login/success.gmi"),
|
||||||
username = username,
|
username = username,
|
||||||
redirect = redirect,
|
redirect = get_redirect(&user),
|
||||||
))
|
))
|
||||||
},
|
},
|
||||||
Err(e) => return Err(e.into()),
|
Err(e) => return Err(e.into()),
|
||||||
|
@ -285,7 +315,7 @@ async fn handle_login<UserData: Serialize + DeserializeOwned + Default>(request:
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
User::SignedIn(user) => {
|
User::SignedIn(user) => {
|
||||||
render_settings_menu(user, redirect)
|
render_settings_menu(user)
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -294,9 +324,10 @@ async fn handle_login<UserData: Serialize + DeserializeOwned + Default>(request:
|
||||||
async fn handle_password<UserData: Serialize + DeserializeOwned + Default>(request: Request) -> Result<Response> {
|
async fn handle_password<UserData: Serialize + DeserializeOwned + Default>(request: Request) -> Result<Response> {
|
||||||
Ok(match request.user::<UserData>()? {
|
Ok(match request.user::<UserData>()? {
|
||||||
User::Unauthenticated => {
|
User::Unauthenticated => {
|
||||||
Response::success_gemini(UNAUTH)
|
render_unauth_page(&[""])
|
||||||
},
|
},
|
||||||
User::NotSignedIn(_) => {
|
User::NotSignedIn(nsi) => {
|
||||||
|
save_redirect(&nsi, &[""]);
|
||||||
Response::success_gemini(NSI)
|
Response::success_gemini(NSI)
|
||||||
},
|
},
|
||||||
User::SignedIn(mut user) => {
|
User::SignedIn(mut user) => {
|
||||||
|
@ -318,10 +349,8 @@ async fn handle_password<UserData: Serialize + DeserializeOwned + Default>(reque
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
fn render_settings_menu<UserData: Serialize + DeserializeOwned>(
|
fn render_settings_menu<UserData: Serialize + DeserializeOwned>(
|
||||||
user: RegisteredUser<UserData>,
|
user: RegisteredUser<UserData>
|
||||||
redirect: &str
|
|
||||||
) -> Response {
|
) -> Response {
|
||||||
let mut document = Document::new();
|
let mut document = Document::new();
|
||||||
document
|
document
|
||||||
|
@ -329,7 +358,7 @@ fn render_settings_menu<UserData: Serialize + DeserializeOwned>(
|
||||||
.add_blank_line()
|
.add_blank_line()
|
||||||
.add_text(&format!("Welcome {}!", user.username()))
|
.add_text(&format!("Welcome {}!", user.username()))
|
||||||
.add_blank_line()
|
.add_blank_line()
|
||||||
.add_link(redirect, "Back to the app")
|
.add_link(get_redirect(&user).as_str(), "Back to the app")
|
||||||
.add_blank_line();
|
.add_blank_line();
|
||||||
|
|
||||||
#[cfg(feature = "user_management_advanced")]
|
#[cfg(feature = "user_management_advanced")]
|
||||||
|
@ -355,6 +384,48 @@ fn render_settings_menu<UserData: Serialize + DeserializeOwned>(
|
||||||
document.into()
|
document.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn render_unauth_page<'a>(
|
||||||
|
redirect: impl AsRef<[&'a str]>,
|
||||||
|
) -> Response {
|
||||||
|
Response::success_gemini(format!(
|
||||||
|
include_str!("pages/unauth.gmi"),
|
||||||
|
redirect = redirect.as_ref().join("/"),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn save_redirect<'a>(
|
||||||
|
user: &NotSignedInUser,
|
||||||
|
redirect: impl AsRef<[&'a str]>,
|
||||||
|
) {
|
||||||
|
let mut redirect = redirect.as_ref().join("/");
|
||||||
|
redirect.insert(0, '/');
|
||||||
|
if redirect.len() > 1 {
|
||||||
|
#[cfg(feature = "dashmap")]
|
||||||
|
let ref_to_map = &*PENDING_REDIRECTS;
|
||||||
|
#[cfg(not(feature = "dashmap"))]
|
||||||
|
let mut ref_to_map = PENDING_REDIRECTS.write().unwrap();
|
||||||
|
|
||||||
|
let cert_hash = UserManager::hash_certificate(&user.certificate);
|
||||||
|
debug!("Added \"{}\" as redirect for cert {:x}", redirect, cert_hash);
|
||||||
|
ref_to_map.insert(cert_hash, redirect);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_redirect<T: Serialize + DeserializeOwned>(user: &RegisteredUser<T>) -> String {
|
||||||
|
let cert_hash = UserManager::hash_certificate(user.active_certificate().unwrap());
|
||||||
|
|
||||||
|
#[cfg(feature = "dashmap")]
|
||||||
|
let maybe_redir = PENDING_REDIRECTS.get(&cert_hash).map(|r| r.clone());
|
||||||
|
#[cfg(not(feature = "dashmap"))]
|
||||||
|
let ref_to_map = PENDING_REDIRECTS.read().unwrap();
|
||||||
|
#[cfg(not(feature = "dashmap"))]
|
||||||
|
let maybe_redir = ref_to_map.get(&cert_hash).cloned();
|
||||||
|
|
||||||
|
let redirect = maybe_redir.unwrap_or_else(||"/".to_string());
|
||||||
|
debug!("Accessed redirect to \"{}\" for cert {:x}", redirect, cert_hash);
|
||||||
|
redirect
|
||||||
|
}
|
||||||
|
|
||||||
mod private {
|
mod private {
|
||||||
pub trait Sealed {}
|
pub trait Sealed {}
|
||||||
impl<A> Sealed for crate::Server<A> {}
|
impl<A> Sealed for crate::Server<A> {}
|
||||||
|
|
Loading…
Reference in a new issue