Fix user management module, rework certificates to use hashes
This commit is contained in:
parent
244fd25112
commit
8b9fbce489
|
@ -14,7 +14,7 @@ include = ["src/**", "Cargo.*", "CHANGELOG.md", "LICENSE*", "README.md"]
|
||||||
[features]
|
[features]
|
||||||
default = ["scgi_srv"]
|
default = ["scgi_srv"]
|
||||||
user_management = ["sled", "bincode", "serde/derive", "crc32fast", "lazy_static"]
|
user_management = ["sled", "bincode", "serde/derive", "crc32fast", "lazy_static"]
|
||||||
user_management_advanced = ["rust-argon2", "ring", "user_management"]
|
user_management_advanced = ["rust-argon2", "user_management"]
|
||||||
user_management_routes = ["user_management"]
|
user_management_routes = ["user_management"]
|
||||||
serve_dir = ["mime_guess", "tokio/fs"]
|
serve_dir = ["mime_guess", "tokio/fs"]
|
||||||
ratelimiting = ["dashmap"]
|
ratelimiting = ["dashmap"]
|
||||||
|
@ -28,6 +28,7 @@ tokio = { version = "0.3.1", features = ["io-util","net","time", "rt"] }
|
||||||
uriparse = "0.6.3"
|
uriparse = "0.6.3"
|
||||||
percent-encoding = "2.1.0"
|
percent-encoding = "2.1.0"
|
||||||
log = "0.4.11"
|
log = "0.4.11"
|
||||||
|
ring = "0.16.15"
|
||||||
lazy_static = { version = "1.4.0", optional = true }
|
lazy_static = { version = "1.4.0", optional = true }
|
||||||
rustls = { version = "0.18.1", features = ["dangerous_configuration"], optional = true}
|
rustls = { version = "0.18.1", features = ["dangerous_configuration"], optional = true}
|
||||||
webpki = { version = "0.21.0", optional = true}
|
webpki = { version = "0.21.0", optional = true}
|
||||||
|
@ -39,7 +40,6 @@ bincode = { version = "1.3.1", optional = true }
|
||||||
serde = { version = "1.0", optional = true }
|
serde = { version = "1.0", optional = true }
|
||||||
rust-argon2 = { version = "0.8.2", optional = true }
|
rust-argon2 = { version = "0.8.2", optional = true }
|
||||||
crc32fast = { version = "1.2.1", optional = true }
|
crc32fast = { version = "1.2.1", optional = true }
|
||||||
ring = { version = "0.16.15", optional = true }
|
|
||||||
rcgen = { version = "0.8.5", optional = true }
|
rcgen = { version = "0.8.5", optional = true }
|
||||||
squeegee = { git = "https://gitlab.com/Alch_Emi/squeegee.git", branch = "main", optional = true }
|
squeegee = { git = "https://gitlab.com/Alch_Emi/squeegee.git", branch = "main", optional = true }
|
||||||
|
|
||||||
|
|
29
src/lib.rs
29
src/lib.rs
|
@ -5,16 +5,15 @@ use std::{
|
||||||
time::Duration,
|
time::Duration,
|
||||||
};
|
};
|
||||||
#[cfg(feature = "gemini_srv")]
|
#[cfg(feature = "gemini_srv")]
|
||||||
use std::{
|
use std::convert::TryFrom;
|
||||||
convert::TryFrom,
|
|
||||||
path::PathBuf,
|
|
||||||
};
|
|
||||||
#[cfg(feature = "scgi_srv")]
|
#[cfg(feature = "scgi_srv")]
|
||||||
use std::{
|
use std::{
|
||||||
collections::HashMap,
|
collections::HashMap,
|
||||||
net::SocketAddr,
|
net::SocketAddr,
|
||||||
str::FromStr,
|
str::FromStr,
|
||||||
};
|
};
|
||||||
|
#[cfg(any(feature = "gemini_srv", feature = "user_management"))]
|
||||||
|
use std::path::PathBuf;
|
||||||
use tokio::{
|
use tokio::{
|
||||||
io,
|
io,
|
||||||
io::BufReader,
|
io::BufReader,
|
||||||
|
@ -125,7 +124,7 @@ impl ServerInner {
|
||||||
#[cfg(feature = "gemini_srv")]
|
#[cfg(feature = "gemini_srv")]
|
||||||
stream: TcpStream,
|
stream: TcpStream,
|
||||||
#[cfg(feature = "scgi_srv")]
|
#[cfg(feature = "scgi_srv")]
|
||||||
stream: impl AsyncWrite + AsyncRead + Unpin,
|
stream: impl AsyncWrite + AsyncRead + Unpin + Send,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let fut_accept_request = async {
|
let fut_accept_request = async {
|
||||||
#[cfg(feature = "gemini_srv")]
|
#[cfg(feature = "gemini_srv")]
|
||||||
|
@ -218,7 +217,7 @@ impl ServerInner {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send_response(&self, mut response: Response, stream: &mut (impl AsyncWrite + Unpin)) -> Result<()> {
|
async fn send_response(&self, mut response: Response, stream: &mut (impl AsyncWrite + Unpin + Send)) -> Result<()> {
|
||||||
let maybe_body = response.take_body();
|
let maybe_body = response.take_body();
|
||||||
let header = response.header();
|
let header = response.header();
|
||||||
|
|
||||||
|
@ -280,7 +279,7 @@ impl ServerInner {
|
||||||
#[cfg(feature = "gemini_srv")]
|
#[cfg(feature = "gemini_srv")]
|
||||||
async fn receive_request(
|
async fn receive_request(
|
||||||
&self,
|
&self,
|
||||||
stream: &mut (impl AsyncBufRead + Unpin),
|
stream: &mut (impl AsyncBufRead + Unpin + Send),
|
||||||
) -> Result<Request> {
|
) -> Result<Request> {
|
||||||
const HEADER_LIMIT: usize = REQUEST_URI_MAX_LEN + "\r\n".len();
|
const HEADER_LIMIT: usize = REQUEST_URI_MAX_LEN + "\r\n".len();
|
||||||
let mut stream = stream.take(HEADER_LIMIT as u64);
|
let mut stream = stream.take(HEADER_LIMIT as u64);
|
||||||
|
@ -377,7 +376,13 @@ impl ServerInner {
|
||||||
|
|
||||||
trace!("Headers received: {:?}", headers);
|
trace!("Headers received: {:?}", headers);
|
||||||
|
|
||||||
Ok(Request::new(headers)?)
|
Ok(
|
||||||
|
Request::new(
|
||||||
|
headers,
|
||||||
|
#[cfg(feature = "user_management")]
|
||||||
|
self.manager.clone(),
|
||||||
|
)?
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -653,7 +658,7 @@ impl Server {
|
||||||
///
|
///
|
||||||
/// `addr` can be anything `tokio` can parse, including just a string like
|
/// `addr` can be anything `tokio` can parse, including just a string like
|
||||||
/// "localhost:1965"
|
/// "localhost:1965"
|
||||||
pub async fn serve_ip(self, addr: impl ToSocketAddrs) -> Result<()> {
|
pub async fn serve_ip(self, addr: impl ToSocketAddrs + Send) -> Result<()> {
|
||||||
let server = self.build()?;
|
let server = self.build()?;
|
||||||
let socket = TcpListener::bind(addr).await?;
|
let socket = TcpListener::bind(addr).await?;
|
||||||
server.serve_ip(socket).await
|
server.serve_ip(socket).await
|
||||||
|
@ -677,7 +682,7 @@ impl Default for Server {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send_response_header(header: &ResponseHeader, stream: &mut (impl AsyncWrite + Unpin)) -> Result<()> {
|
async fn send_response_header(header: &ResponseHeader, stream: &mut (impl AsyncWrite + Unpin + Send)) -> Result<()> {
|
||||||
let header = format!(
|
let header = format!(
|
||||||
"{status} {meta}\r\n",
|
"{status} {meta}\r\n",
|
||||||
status = header.status.code(),
|
status = header.status.code(),
|
||||||
|
@ -690,7 +695,7 @@ async fn send_response_header(header: &ResponseHeader, stream: &mut (impl AsyncW
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn maybe_send_response_body(maybe_body: Option<Body>, stream: &mut (impl AsyncWrite + Unpin)) -> Result<()> {
|
async fn maybe_send_response_body(maybe_body: Option<Body>, stream: &mut (impl AsyncWrite + Unpin + Send)) -> Result<()> {
|
||||||
if let Some(body) = maybe_body {
|
if let Some(body) = maybe_body {
|
||||||
send_response_body(body, stream).await?;
|
send_response_body(body, stream).await?;
|
||||||
}
|
}
|
||||||
|
@ -698,7 +703,7 @@ async fn maybe_send_response_body(maybe_body: Option<Body>, stream: &mut (impl A
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send_response_body(body: Body, stream: &mut (impl AsyncWrite + Unpin)) -> Result<()> {
|
async fn send_response_body(body: Body, stream: &mut (impl AsyncWrite + Unpin + Send)) -> Result<()> {
|
||||||
match body {
|
match body {
|
||||||
Body::Bytes(bytes) => stream.write_all(&bytes).await?,
|
Body::Bytes(bytes) => stream.write_all(&bytes).await?,
|
||||||
Body::Reader(mut reader) => { io::copy(&mut reader, stream).await?; },
|
Body::Reader(mut reader) => { io::copy(&mut reader, stream).await?; },
|
||||||
|
|
|
@ -1,7 +1,3 @@
|
||||||
#[cfg(feature = "gemini_srv")]
|
|
||||||
pub use rustls::Certificate;
|
|
||||||
#[cfg(feature = "scgi_srv")]
|
|
||||||
pub type Certificate = String;
|
|
||||||
pub use uriparse::URIReference;
|
pub use uriparse::URIReference;
|
||||||
|
|
||||||
mod meta;
|
mod meta;
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
use std::ops;
|
use std::ops;
|
||||||
|
use std::convert::TryInto;
|
||||||
#[cfg(feature = "scgi_srv")]
|
#[cfg(feature = "scgi_srv")]
|
||||||
use std::{
|
use std::{
|
||||||
collections::HashMap,
|
collections::HashMap,
|
||||||
|
@ -7,9 +8,10 @@ use std::{
|
||||||
use anyhow::*;
|
use anyhow::*;
|
||||||
use percent_encoding::percent_decode_str;
|
use percent_encoding::percent_decode_str;
|
||||||
use uriparse::URIReference;
|
use uriparse::URIReference;
|
||||||
use crate::types::Certificate;
|
|
||||||
#[cfg(feature="user_management")]
|
#[cfg(feature="user_management")]
|
||||||
use serde::{Serialize, de::DeserializeOwned};
|
use serde::{Serialize, de::DeserializeOwned};
|
||||||
|
#[cfg(feature = "gemini_srv")]
|
||||||
|
use ring::digest;
|
||||||
|
|
||||||
#[cfg(feature="user_management")]
|
#[cfg(feature="user_management")]
|
||||||
use crate::user_management::{UserManager, User};
|
use crate::user_management::{UserManager, User};
|
||||||
|
@ -17,7 +19,7 @@ use crate::user_management::{UserManager, User};
|
||||||
pub struct Request {
|
pub struct Request {
|
||||||
uri: URIReference<'static>,
|
uri: URIReference<'static>,
|
||||||
input: Option<String>,
|
input: Option<String>,
|
||||||
certificate: Option<Certificate>,
|
certificate: Option<[u8; 32]>,
|
||||||
trailing_segments: Option<Vec<String>>,
|
trailing_segments: Option<Vec<String>>,
|
||||||
#[cfg(feature="user_management")]
|
#[cfg(feature="user_management")]
|
||||||
manager: UserManager,
|
manager: UserManager,
|
||||||
|
@ -49,7 +51,13 @@ impl Request {
|
||||||
)
|
)
|
||||||
.context("Request URI is invalid")?
|
.context("Request URI is invalid")?
|
||||||
.into_owned(),
|
.into_owned(),
|
||||||
headers.get("TLS_CLIENT_HASH").cloned(),
|
headers.get("TLS_CLIENT_HASH")
|
||||||
|
.map(|hsh| {
|
||||||
|
ring::test::from_hex(hsh.as_str())
|
||||||
|
.expect("Received invalid certificate fingerprint from upstream")
|
||||||
|
.try_into()
|
||||||
|
.expect("Received certificate fingerprint of invalid lenght from upstream")
|
||||||
|
}),
|
||||||
);
|
);
|
||||||
|
|
||||||
uri.normalize();
|
uri.normalize();
|
||||||
|
@ -140,8 +148,14 @@ impl Request {
|
||||||
&self.headers
|
&self.headers
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_cert(&mut self, cert: Option<Certificate>) {
|
#[cfg(feature = "gemini_srv")]
|
||||||
self.certificate = cert;
|
pub (crate) fn set_cert(&mut self, cert: Option<rustls::Certificate>) {
|
||||||
|
self.certificate = cert.map(|cert| {
|
||||||
|
digest::digest(&digest::SHA256, cert.0.as_ref())
|
||||||
|
.as_ref()
|
||||||
|
.try_into()
|
||||||
|
.expect("SHA256 didn't return 256 bits")
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_trailing(&mut self, segments: Vec<String>) {
|
pub fn set_trailing(&mut self, segments: Vec<String>) {
|
||||||
|
@ -149,7 +163,8 @@ impl Request {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::missing_const_for_fn)]
|
#[allow(clippy::missing_const_for_fn)]
|
||||||
pub fn certificate(&self) -> Option<&Certificate> {
|
/// Get the fingerprint of the certificate the user is connecting with
|
||||||
|
pub fn certificate(&self) -> Option<&[u8; 32]> {
|
||||||
self.certificate.as_ref()
|
self.certificate.as_ref()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,24 +1,8 @@
|
||||||
use rustls::Certificate;
|
use serde::{Serialize, de::DeserializeOwned};
|
||||||
use serde::{Deserialize, Serialize, de::DeserializeOwned};
|
|
||||||
|
|
||||||
use crate::user_management::{User, Result};
|
use crate::user_management::{User, Result};
|
||||||
use crate::user_management::user::{RegisteredUser, NotSignedInUser, PartialUser};
|
use crate::user_management::user::{RegisteredUser, NotSignedInUser, PartialUser};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
/// Data stored in the certificate tree about a certain certificate
|
|
||||||
pub struct CertificateData {
|
|
||||||
#[serde(with = "CertificateDef")]
|
|
||||||
/// The certificate in question
|
|
||||||
pub certificate: Certificate,
|
|
||||||
|
|
||||||
/// The username of the user to which this certificate is registered
|
|
||||||
pub owner_username: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
|
||||||
#[serde(remote = "Certificate")]
|
|
||||||
struct CertificateDef(Vec<u8>);
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
/// A struct containing information for managing users.
|
/// A struct containing information for managing users.
|
||||||
///
|
///
|
||||||
|
@ -43,21 +27,14 @@ impl UserManager {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Produce a u32 hash from a certificate, used for [`lookup_certificate()`](Self::lookup_certificate())
|
/// Lookup the owner of a certificate based on it's fingerprint
|
||||||
pub fn hash_certificate(cert: &Certificate) -> u32 {
|
|
||||||
let mut hasher = crc32fast::Hasher::new();
|
|
||||||
hasher.update(cert.0.as_ref());
|
|
||||||
hasher.finalize()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Lookup information about a certificate based on it's u32 hash
|
|
||||||
///
|
///
|
||||||
/// # Errors
|
/// # Errors
|
||||||
/// An error is thrown if there is an error reading from the database or if data
|
/// An error is thrown if there is an error reading from the database or if data
|
||||||
/// recieved from the database is corrupt
|
/// recieved from the database is corrupt
|
||||||
pub fn lookup_certificate(&self, cert: u32) -> Result<Option<CertificateData>> {
|
pub fn lookup_certificate(&self, cert: [u8; 32]) -> Result<Option<String>> {
|
||||||
if let Some(bytes) = self.certificates.get(cert.to_le_bytes())? {
|
if let Some(bytes) = self.certificates.get(cert)? {
|
||||||
Ok(Some(bincode::deserialize(&bytes)?))
|
Ok(Some(std::str::from_utf8(bytes.as_ref())?.to_string()))
|
||||||
} else {
|
} else {
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
|
@ -116,20 +93,19 @@ impl UserManager {
|
||||||
/// Pancis if the database is corrupt
|
/// Pancis if the database is corrupt
|
||||||
pub fn get_user<UserData>(
|
pub fn get_user<UserData>(
|
||||||
&self,
|
&self,
|
||||||
cert: Option<&Certificate>
|
cert: Option<&[u8; 32]>
|
||||||
) -> Result<User<UserData>>
|
) -> Result<User<UserData>>
|
||||||
where
|
where
|
||||||
UserData: Serialize + DeserializeOwned
|
UserData: Serialize + DeserializeOwned
|
||||||
{
|
{
|
||||||
if let Some(certificate) = cert {
|
if let Some(certificate) = cert {
|
||||||
let cert_hash = Self::hash_certificate(certificate);
|
if let Some(username) = self.lookup_certificate(*certificate)? {
|
||||||
if let Some(certificate_data) = self.lookup_certificate(cert_hash)? {
|
let user_inner = self.lookup_user(&username)?
|
||||||
let user_inner = self.lookup_user(&certificate_data.owner_username)?
|
|
||||||
.expect("Database corruption: Certificate data refers to non-existant user");
|
.expect("Database corruption: Certificate data refers to non-existant user");
|
||||||
Ok(User::SignedIn(user_inner.with_cert(certificate_data.certificate)))
|
Ok(User::SignedIn(user_inner.with_cert(*certificate)))
|
||||||
} else {
|
} else {
|
||||||
Ok(User::NotSignedIn(NotSignedInUser {
|
Ok(User::NotSignedIn(NotSignedInUser {
|
||||||
certificate: certificate.clone(),
|
certificate: *certificate,
|
||||||
manager: self.clone(),
|
manager: self.clone(),
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,7 +26,6 @@ mod routes;
|
||||||
pub use routes::UserManagementRoutes;
|
pub use routes::UserManagementRoutes;
|
||||||
pub use manager::UserManager;
|
pub use manager::UserManager;
|
||||||
pub use user::User;
|
pub use user::User;
|
||||||
pub use manager::CertificateData;
|
|
||||||
// Imports for docs
|
// Imports for docs
|
||||||
#[allow(unused_imports)]
|
#[allow(unused_imports)]
|
||||||
use user::{NotSignedInUser, RegisteredUser};
|
use user::{NotSignedInUser, RegisteredUser};
|
||||||
|
@ -39,7 +38,8 @@ pub enum UserManagerError {
|
||||||
PasswordNotSet,
|
PasswordNotSet,
|
||||||
DatabaseError(sled::Error),
|
DatabaseError(sled::Error),
|
||||||
DatabaseTransactionError(sled::transaction::TransactionError),
|
DatabaseTransactionError(sled::transaction::TransactionError),
|
||||||
DeserializeError(bincode::Error),
|
DeserializeBincodeError(bincode::Error),
|
||||||
|
DeserializeUtf8Error(std::str::Utf8Error),
|
||||||
#[cfg(feature = "user_management_advanced")]
|
#[cfg(feature = "user_management_advanced")]
|
||||||
Argon2Error(argon2::Error),
|
Argon2Error(argon2::Error),
|
||||||
}
|
}
|
||||||
|
@ -58,7 +58,13 @@ impl From<sled::transaction::TransactionError> for UserManagerError {
|
||||||
|
|
||||||
impl From<bincode::Error> for UserManagerError {
|
impl From<bincode::Error> for UserManagerError {
|
||||||
fn from(error: bincode::Error) -> Self {
|
fn from(error: bincode::Error) -> Self {
|
||||||
Self::DeserializeError(error)
|
Self::DeserializeBincodeError(error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::str::Utf8Error> for UserManagerError {
|
||||||
|
fn from(error: std::str::Utf8Error) -> Self {
|
||||||
|
Self::DeserializeUtf8Error(error)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -74,7 +80,8 @@ impl std::error::Error for UserManagerError {
|
||||||
match self {
|
match self {
|
||||||
Self::DatabaseError(e) => Some(e),
|
Self::DatabaseError(e) => Some(e),
|
||||||
Self::DatabaseTransactionError(e) => Some(e),
|
Self::DatabaseTransactionError(e) => Some(e),
|
||||||
Self::DeserializeError(e) => Some(e),
|
Self::DeserializeBincodeError(e) => Some(e),
|
||||||
|
Self::DeserializeUtf8Error(e) => Some(e),
|
||||||
#[cfg(feature = "user_management_advanced")]
|
#[cfg(feature = "user_management_advanced")]
|
||||||
Self::Argon2Error(e) => Some(e),
|
Self::Argon2Error(e) => Some(e),
|
||||||
_ => None
|
_ => None
|
||||||
|
@ -93,8 +100,10 @@ impl std::fmt::Display for UserManagerError {
|
||||||
write!(f, "Error accessing the user database: {}", e),
|
write!(f, "Error accessing the user database: {}", e),
|
||||||
Self::DatabaseTransactionError(e) =>
|
Self::DatabaseTransactionError(e) =>
|
||||||
write!(f, "Error accessing the user database: {}", e),
|
write!(f, "Error accessing the user database: {}", e),
|
||||||
Self::DeserializeError(e) =>
|
Self::DeserializeBincodeError(e) =>
|
||||||
write!(f, "Recieved messy data from database, possible corruption: {}", e),
|
write!(f, "Recieved messy data from database, possible corruption: {}", e),
|
||||||
|
Self::DeserializeUtf8Error(e) =>
|
||||||
|
write!(f, "Recieved invalid UTF-8 from database, possible corruption: {}", e),
|
||||||
#[cfg(feature = "user_management_advanced")]
|
#[cfg(feature = "user_management_advanced")]
|
||||||
Self::Argon2Error(e) =>
|
Self::Argon2Error(e) =>
|
||||||
write!(f, "Argon2 Error, likely malformed password hash, possible database corruption: {}", e),
|
write!(f, "Argon2 Error, likely malformed password hash, possible database corruption: {}", e),
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use tokio::net::ToSocketAddrs;
|
|
||||||
use serde::{Serialize, de::DeserializeOwned};
|
use serde::{Serialize, de::DeserializeOwned};
|
||||||
|
|
||||||
#[cfg(feature = "dashmap")]
|
#[cfg(feature = "dashmap")]
|
||||||
|
@ -16,7 +15,6 @@ use crate::types::document::HeadingLevel;
|
||||||
use crate::user_management::{
|
use crate::user_management::{
|
||||||
User,
|
User,
|
||||||
RegisteredUser,
|
RegisteredUser,
|
||||||
UserManager,
|
|
||||||
UserManagerError,
|
UserManagerError,
|
||||||
user::NotSignedInUser,
|
user::NotSignedInUser,
|
||||||
};
|
};
|
||||||
|
@ -91,7 +89,7 @@ pub trait UserManagementRoutes: private::Sealed {
|
||||||
F: Send + Sync + 'static + Future<Output = Result<Response>>;
|
F: Send + Sync + 'static + Future<Output = Result<Response>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<A: ToSocketAddrs> UserManagementRoutes for crate::Server<A> {
|
impl UserManagementRoutes for crate::Server {
|
||||||
/// Add pre-configured routes to the serve to handle authentication
|
/// Add pre-configured routes to the serve to handle authentication
|
||||||
///
|
///
|
||||||
/// See [`UserManagementRoutes::add_um_routes()`]
|
/// See [`UserManagementRoutes::add_um_routes()`]
|
||||||
|
@ -187,7 +185,7 @@ lazy_static::lazy_static! {
|
||||||
|
|
||||||
#[cfg(not(feature = "dashmap"))]
|
#[cfg(not(feature = "dashmap"))]
|
||||||
lazy_static::lazy_static! {
|
lazy_static::lazy_static! {
|
||||||
static ref PENDING_REDIRECTS: RwLock<HashMap<u32, String>> = Default::default();
|
static ref PENDING_REDIRECTS: RwLock<HashMap<[u8; 32], String>> = Default::default();
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_base<UserData: Serialize + DeserializeOwned>(request: Request) -> Result<Response> {
|
async fn handle_base<UserData: Serialize + DeserializeOwned>(request: Request) -> Result<Response> {
|
||||||
|
@ -408,28 +406,27 @@ fn save_redirect<'a>(
|
||||||
#[cfg(not(feature = "dashmap"))]
|
#[cfg(not(feature = "dashmap"))]
|
||||||
let mut ref_to_map = PENDING_REDIRECTS.write().unwrap();
|
let mut ref_to_map = PENDING_REDIRECTS.write().unwrap();
|
||||||
|
|
||||||
let cert_hash = UserManager::hash_certificate(&user.certificate);
|
debug!("Added \"{}\" as redirect for cert {:x?}", redirect, &user.certificate);
|
||||||
debug!("Added \"{}\" as redirect for cert {:x}", redirect, cert_hash);
|
ref_to_map.insert(user.certificate, redirect);
|
||||||
ref_to_map.insert(cert_hash, redirect);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_redirect<T: Serialize + DeserializeOwned>(user: &RegisteredUser<T>) -> String {
|
fn get_redirect<T: Serialize + DeserializeOwned>(user: &RegisteredUser<T>) -> String {
|
||||||
let cert_hash = UserManager::hash_certificate(user.active_certificate().unwrap());
|
let cert = user.active_certificate().unwrap();
|
||||||
|
|
||||||
#[cfg(feature = "dashmap")]
|
#[cfg(feature = "dashmap")]
|
||||||
let maybe_redir = PENDING_REDIRECTS.get(&cert_hash).map(|r| r.clone());
|
let maybe_redir = PENDING_REDIRECTS.get(cert).map(|r| r.clone());
|
||||||
#[cfg(not(feature = "dashmap"))]
|
#[cfg(not(feature = "dashmap"))]
|
||||||
let ref_to_map = PENDING_REDIRECTS.read().unwrap();
|
let ref_to_map = PENDING_REDIRECTS.read().unwrap();
|
||||||
#[cfg(not(feature = "dashmap"))]
|
#[cfg(not(feature = "dashmap"))]
|
||||||
let maybe_redir = ref_to_map.get(&cert_hash).cloned();
|
let maybe_redir = ref_to_map.get(cert).cloned();
|
||||||
|
|
||||||
let redirect = maybe_redir.unwrap_or_else(||"/".to_string());
|
let redirect = maybe_redir.unwrap_or_else(||"/".to_string());
|
||||||
debug!("Accessed redirect to \"{}\" for cert {:x}", redirect, cert_hash);
|
debug!("Accessed redirect to \"{}\" for cert {:x?}", redirect, cert);
|
||||||
redirect
|
redirect
|
||||||
}
|
}
|
||||||
|
|
||||||
mod private {
|
mod private {
|
||||||
pub trait Sealed {}
|
pub trait Sealed {}
|
||||||
impl<A> Sealed for crate::Server<A> {}
|
impl Sealed for crate::Server {}
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,13 +13,11 @@
|
||||||
//! the data stored for almost all users. This is accomplished through the
|
//! the data stored for almost all users. This is accomplished through the
|
||||||
//! [`as_mut()`](RegisteredUser::as_mut) method. Changes made this way must be persisted
|
//! [`as_mut()`](RegisteredUser::as_mut) method. Changes made this way must be persisted
|
||||||
//! using [`save()`](RegisteredUser::save()) or by dropping the user.
|
//! using [`save()`](RegisteredUser::save()) or by dropping the user.
|
||||||
use rustls::Certificate;
|
|
||||||
use serde::{Deserialize, Serialize, de::DeserializeOwned};
|
use serde::{Deserialize, Serialize, de::DeserializeOwned};
|
||||||
use sled::Transactional;
|
use sled::Transactional;
|
||||||
|
|
||||||
use crate::user_management::UserManager;
|
use crate::user_management::UserManager;
|
||||||
use crate::user_management::Result;
|
use crate::user_management::Result;
|
||||||
use crate::user_management::manager::CertificateData;
|
|
||||||
|
|
||||||
#[cfg(feature = "user_management_advanced")]
|
#[cfg(feature = "user_management_advanced")]
|
||||||
const ARGON2_CONFIG: argon2::Config = argon2::Config {
|
const ARGON2_CONFIG: argon2::Config = argon2::Config {
|
||||||
|
@ -46,7 +44,7 @@ lazy_static::lazy_static! {
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
pub (crate) struct PartialUser<UserData> {
|
pub (crate) struct PartialUser<UserData> {
|
||||||
pub data: UserData,
|
pub data: UserData,
|
||||||
pub certificates: Vec<u32>,
|
pub certificates: Vec<[u8; 32]>,
|
||||||
#[cfg(feature = "user_management_advanced")]
|
#[cfg(feature = "user_management_advanced")]
|
||||||
pub pass_hash: Option<(Vec<u8>, [u8; 32])>,
|
pub pass_hash: Option<(Vec<u8>, [u8; 32])>,
|
||||||
}
|
}
|
||||||
|
@ -94,7 +92,7 @@ pub enum User<UserData: Serialize + DeserializeOwned> {
|
||||||
///
|
///
|
||||||
/// For more information about the user lifecycle and sign-in stages, see [`User`]
|
/// For more information about the user lifecycle and sign-in stages, see [`User`]
|
||||||
pub struct NotSignedInUser {
|
pub struct NotSignedInUser {
|
||||||
pub (crate) certificate: Certificate,
|
pub (crate) certificate: [u8; 32],
|
||||||
pub (crate) manager: UserManager,
|
pub (crate) manager: UserManager,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -120,7 +118,7 @@ impl NotSignedInUser {
|
||||||
|
|
||||||
let mut newser = RegisteredUser::new(
|
let mut newser = RegisteredUser::new(
|
||||||
username,
|
username,
|
||||||
Some(self.certificate.clone()),
|
Some(self.certificate),
|
||||||
self.manager,
|
self.manager,
|
||||||
PartialUser {
|
PartialUser {
|
||||||
data: UserData::default(),
|
data: UserData::default(),
|
||||||
|
@ -179,9 +177,8 @@ impl NotSignedInUser {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let certhash = UserManager::hash_certificate(&self.certificate);
|
info!("User {} attached certificate with fingerprint {:x?}", username, &self.certificate[..]);
|
||||||
info!("User {} attached certificate with hash {:x}", username, certhash);
|
user.add_certificate(self.certificate)?;
|
||||||
user.add_certificate(self.certificate.clone())?;
|
|
||||||
user.active_certificate = Some(self.certificate);
|
user.active_certificate = Some(self.certificate);
|
||||||
Ok(Some(user))
|
Ok(Some(user))
|
||||||
} else {
|
} else {
|
||||||
|
@ -196,7 +193,7 @@ impl NotSignedInUser {
|
||||||
/// For more information about the user lifecycle and sign-in stages, see [`User`]
|
/// For more information about the user lifecycle and sign-in stages, see [`User`]
|
||||||
pub struct RegisteredUser<UserData: Serialize + DeserializeOwned> {
|
pub struct RegisteredUser<UserData: Serialize + DeserializeOwned> {
|
||||||
username: String,
|
username: String,
|
||||||
active_certificate: Option<Certificate>,
|
active_certificate: Option<[u8; 32]>,
|
||||||
manager: UserManager,
|
manager: UserManager,
|
||||||
inner: PartialUser<UserData>,
|
inner: PartialUser<UserData>,
|
||||||
/// Indicates that [`RegisteredUser::as_mut()`] has been called, but [`RegisteredUser::save()`] has not
|
/// Indicates that [`RegisteredUser::as_mut()`] has been called, but [`RegisteredUser::save()`] has not
|
||||||
|
@ -208,7 +205,7 @@ impl<UserData: Serialize + DeserializeOwned> RegisteredUser<UserData> {
|
||||||
/// Create a new user from parts
|
/// Create a new user from parts
|
||||||
pub (crate) fn new(
|
pub (crate) fn new(
|
||||||
username: String,
|
username: String,
|
||||||
active_certificate: Option<Certificate>,
|
active_certificate: Option<[u8; 32]>,
|
||||||
manager: UserManager,
|
manager: UserManager,
|
||||||
inner: PartialUser<UserData>
|
inner: PartialUser<UserData>
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
@ -226,30 +223,22 @@ impl<UserData: Serialize + DeserializeOwned> RegisteredUser<UserData> {
|
||||||
/// This is not to be confused with [`RegisteredUser::add_certificate`], which
|
/// This is not to be confused with [`RegisteredUser::add_certificate`], which
|
||||||
/// performs the database operations needed to register a new certificate to a user.
|
/// performs the database operations needed to register a new certificate to a user.
|
||||||
/// This literally just marks the active certificate.
|
/// This literally just marks the active certificate.
|
||||||
pub (crate) fn with_cert(mut self, cert: Certificate) -> Self {
|
pub (crate) fn with_cert(mut self, cert: [u8; 32]) -> Self {
|
||||||
self.active_certificate = Some(cert);
|
self.active_certificate = Some(cert);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the [`Certificate`] that the user is currently using to connect.
|
/// Get the fingerprint of the certificate that the user is currently using.
|
||||||
///
|
///
|
||||||
/// If this user was retrieved by a [`UserManager::lookup_user()`], this will be
|
/// If this user was retrieved by a [`UserManager::lookup_user()`], this will be
|
||||||
/// [`None`]. In all other cases, this will be [`Some`].
|
/// [`None`]. In all other cases, this will be [`Some`].
|
||||||
pub fn active_certificate(&self) -> Option<&Certificate> {
|
pub fn active_certificate(&self) -> Option<&[u8; 32]> {
|
||||||
self.active_certificate.as_ref()
|
self.active_certificate.as_ref()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Produce a list of all [`Certificate`]s registered to this account
|
/// Produce a list of all certificate fingerprints registered to this account
|
||||||
pub fn all_certificates(&self) -> Vec<Certificate> {
|
pub fn all_certificates(&self) -> &Vec<[u8; 32]> {
|
||||||
self.inner.certificates
|
&self.inner.certificates
|
||||||
.iter()
|
|
||||||
.map(
|
|
||||||
|cid| self.manager.lookup_certificate(*cid)
|
|
||||||
.expect("Database corruption: User refers to non-existant certificate")
|
|
||||||
.expect("Error accessing database")
|
|
||||||
.certificate
|
|
||||||
)
|
|
||||||
.collect()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the user's current username.
|
/// Get the user's current username.
|
||||||
|
@ -335,18 +324,10 @@ impl<UserData: Serialize + DeserializeOwned> RegisteredUser<UserData> {
|
||||||
/// If you have a [`NotSignedInUser`] and are looking for a way to link them to an
|
/// If you have a [`NotSignedInUser`] and are looking for a way to link them to an
|
||||||
/// existing user, consider [`NotSignedInUser::attach()`], which contains facilities for
|
/// existing user, consider [`NotSignedInUser::attach()`], which contains facilities for
|
||||||
/// password checking and automatically performs the user lookup.
|
/// password checking and automatically performs the user lookup.
|
||||||
pub fn add_certificate(&mut self, certificate: Certificate) -> Result<()> {
|
pub fn add_certificate(&mut self, certificate: [u8; 32]) -> Result<()> {
|
||||||
let cert_hash = UserManager::hash_certificate(&certificate);
|
self.inner.certificates.push(certificate);
|
||||||
|
|
||||||
self.inner.certificates.push(cert_hash);
|
|
||||||
|
|
||||||
let cert_info = CertificateData {
|
|
||||||
certificate,
|
|
||||||
owner_username: self.username.clone(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let inner_serialized = bincode::serialize(&self.inner)?;
|
let inner_serialized = bincode::serialize(&self.inner)?;
|
||||||
let cert_info_serialized = bincode::serialize(&cert_info)?;
|
|
||||||
|
|
||||||
(&self.manager.users, &self.manager.certificates)
|
(&self.manager.users, &self.manager.certificates)
|
||||||
.transaction(|(tx_usr, tx_crt)| {
|
.transaction(|(tx_usr, tx_crt)| {
|
||||||
|
@ -355,8 +336,8 @@ impl<UserData: Serialize + DeserializeOwned> RegisteredUser<UserData> {
|
||||||
inner_serialized.clone(),
|
inner_serialized.clone(),
|
||||||
)?;
|
)?;
|
||||||
tx_crt.insert(
|
tx_crt.insert(
|
||||||
cert_hash.to_le_bytes().as_ref(),
|
&certificate,
|
||||||
cert_info_serialized.clone(),
|
self.username.as_bytes(),
|
||||||
)?;
|
)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
})?;
|
})?;
|
||||||
|
|
Loading…
Reference in New Issue