diff --git a/src/lib.rs b/src/lib.rs index bbb3153..9cdf1cf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -246,6 +246,8 @@ pub struct Server { rate_limits: RoutingNode>, #[cfg(feature="user_management")] data_dir: PathBuf, + #[cfg(feature="user_management")] + database: Option, #[cfg(feature="certgen")] certgen_mode: CertGenMode, } @@ -263,6 +265,8 @@ impl Server { rate_limits: RoutingNode::default(), #[cfg(feature="user_management")] data_dir: "data".into(), + #[cfg(feature="user_management")] + database: None, #[cfg(feature="certgen")] certgen_mode: CertGenMode::Interactive, } @@ -271,12 +275,30 @@ impl Server { #[cfg(feature="user_management")] /// Sets the directory to store user data in /// + /// This will only be used if a database is not provided with [`set_database()`]. + /// /// Defaults to `./data` if not specified + /// + /// [`set_database()`]: Self::set_database() pub fn set_database_dir(mut self, path: impl Into) -> Self { self.data_dir = path.into(); self } + #[cfg(feature="user_management")] + /// Sets a specific database to use + /// + /// This opens to trees within the database, both namespaced to avoid collisions. + /// + /// If this is not provided, a database will be opened at the directory provided by + /// [`set_database_dir()`] + /// + /// [`set_database_dir()`]: Self::set_database_dir() + pub fn set_database(mut self, db: sled::Db) -> Self { + self.database = Some(db); + self + } + #[cfg(feature="certgen")] /// Determine where certificate config comes from, if generation is required /// @@ -429,6 +451,9 @@ impl Server { self.routes.shrink(); + #[cfg(feature="user_management")] + let data_dir = self.data_dir; + let server = ServerInner { tls_acceptor: TlsAcceptor::from(config), routes: Arc::new(self.routes), @@ -437,7 +462,9 @@ impl Server { #[cfg(feature="ratelimiting")] rate_limits: Arc::new(self.rate_limits), #[cfg(feature="user_management")] - manager: UserManager::new(self.data_dir)?, + manager: UserManager::new( + self.database.unwrap_or_else(move|| sled::open(data_dir).unwrap()) + )?, }; server.serve(listener).await diff --git a/src/user_management/manager.rs b/src/user_management/manager.rs index 61035d9..4dd639a 100644 --- a/src/user_management/manager.rs +++ b/src/user_management/manager.rs @@ -1,8 +1,6 @@ use rustls::Certificate; use serde::{Deserialize, Serialize, de::DeserializeOwned}; -use std::path::Path; - use crate::user_management::{User, Result}; use crate::user_management::user::{RegisteredUser, NotSignedInUser, PartialUser}; @@ -37,8 +35,7 @@ impl UserManager { /// /// The `dir` argument is the path to a data directory, to be populated using sled. /// This will be created if it does not exist. - pub fn new(dir: impl AsRef) -> Result { - let db = sled::open(dir)?; + pub fn new(db: sled::Db) -> Result { Ok(Self { users: db.open_tree("gay.emii.kochab.users")?, certificates: db.open_tree("gay.emii.kochab.certificates")?,