diff --git a/CHANGELOG.md b/CHANGELOG.md index 47c5d1b..435e8ee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `server_dir` default feature for serve_dir utils [@Alch-Emi](https://github.com/Alch-Emi) ### Improved - build time and size by [@Alch-Emi](https://github.com/Alch-Emi) +### Changed +- Added route API [@Alch-Emi](https://github.com/Alch-Emi) ## [0.3.0] - 2020-11-14 ### Added diff --git a/examples/certificates.rs b/examples/certificates.rs index 541fbe5..143c71c 100644 --- a/examples/certificates.rs +++ b/examples/certificates.rs @@ -19,7 +19,8 @@ async fn main() -> Result<()> { let users = Arc::>>::default(); Server::bind(("0.0.0.0", GEMINI_PORT)) - .serve(move|req| handle_request(users.clone(), req)) + .add_route("/", move|req| handle_request(users.clone(), req)) + .serve() .await } diff --git a/examples/document.rs b/examples/document.rs index 8ff6bbb..cc889c6 100644 --- a/examples/document.rs +++ b/examples/document.rs @@ -12,7 +12,8 @@ async fn main() -> Result<()> { .init(); Server::bind(("localhost", GEMINI_PORT)) - .serve(handle_request) + .add_route("/",handle_request) + .serve() .await } diff --git a/examples/routing.rs b/examples/routing.rs new file mode 100644 index 0000000..04bded6 --- /dev/null +++ b/examples/routing.rs @@ -0,0 +1,56 @@ +use anyhow::*; +use futures_core::future::BoxFuture; +use futures_util::FutureExt; +use log::LevelFilter; +use northstar::{Document, document::HeadingLevel, Request, Response, GEMINI_PORT}; + +#[tokio::main] +async fn main() -> Result<()> { + env_logger::builder() + .filter_module("northstar", LevelFilter::Debug) + .init(); + + northstar::Server::bind(("localhost", GEMINI_PORT)) + .add_route("/", handle_base) + .add_route("/route", handle_short) + .add_route("/route/long", handle_long) + .serve() + .await +} + +fn handle_base(req: Request) -> BoxFuture<'static, Result> { + let doc = generate_doc("base", &req); + async move { + Ok(Response::document(doc)) + }.boxed() +} + +fn handle_short(req: Request) -> BoxFuture<'static, Result> { + let doc = generate_doc("short", &req); + async move { + Ok(Response::document(doc)) + }.boxed() +} + +fn handle_long(req: Request) -> BoxFuture<'static, Result> { + let doc = generate_doc("long", &req); + async move { + Ok(Response::document(doc)) + }.boxed() +} + +fn generate_doc(route_name: &str, req: &Request) -> Document { + let trailing = req.trailing_segments().join("/"); + let mut doc = Document::new(); + doc.add_heading(HeadingLevel::H1, "Routing Demo") + .add_text(&format!("You're currently on the {} route", route_name)) + .add_text(&format!("Trailing segments: /{}", trailing)) + .add_blank_line() + .add_text("Here's some links to try:") + .add_link_without_label("/") + .add_link_without_label("/route") + .add_link_without_label("/route/long") + .add_link_without_label("/route/not_real") + .add_link_without_label("/rowte"); + doc +} diff --git a/examples/serve_dir.rs b/examples/serve_dir.rs index fd26ac4..de3e0b0 100644 --- a/examples/serve_dir.rs +++ b/examples/serve_dir.rs @@ -11,7 +11,8 @@ async fn main() -> Result<()> { .init(); Server::bind(("localhost", GEMINI_PORT)) - .serve(handle_request) + .add_route("/", handle_request) + .serve() .await } diff --git a/examples/user_management.rs b/examples/user_management.rs index a597934..8f9d570 100644 --- a/examples/user_management.rs +++ b/examples/user_management.rs @@ -18,7 +18,8 @@ async fn main() -> Result<()> { .init(); Server::bind(("0.0.0.0", GEMINI_PORT)) - .serve(handle_request) + .add_route("/", handle_request) + .serve() .await } diff --git a/src/lib.rs b/src/lib.rs index 84a4cbf..c6272d8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,9 +22,11 @@ use rustls::*; use anyhow::*; use lazy_static::lazy_static; use crate::util::opt_timeout; +use routing::RoutingNode; pub mod types; pub mod util; +pub mod routing; #[cfg(feature = "user_management")] pub mod user_management; @@ -45,7 +47,7 @@ pub (crate) type HandlerResponse = BoxFuture<'static, Result>; pub struct Server { tls_acceptor: TlsAcceptor, listener: Arc, - handler: Handler, + routes: Arc>, timeout: Duration, complex_timeout: Option, #[cfg(feature="user_management")] @@ -104,19 +106,26 @@ impl Server { request.set_cert(client_cert); - let handler = (self.handler)(request); - let handler = AssertUnwindSafe(handler); + let response = if let Some((trailing, handler)) = self.routes.match_request(&request) { - let response = util::HandlerCatchUnwind::new(handler).await - .unwrap_or_else(|_| Response::server_error("")) - .or_else(|err| { - error!("Handler failed: {:?}", err); - Response::server_error("") - }) - .context("Request handler failed")?; + request.set_trailing(trailing); - self.send_response(response, &mut stream).await - .context("Failed to send response")?; + let handler = (handler)(request); + let handler = AssertUnwindSafe(handler); + + util::HandlerCatchUnwind::new(handler).await + .unwrap_or_else(|_| Response::server_error("")) + .or_else(|err| { + error!("Handler failed: {:?}", err); + Response::server_error("") + }) + .context("Request handler failed")? + } else { + Response::not_found() + }; + + self.send_response(response, &mut stream).await + .context("Failed to send response")?; Ok(()) } @@ -209,6 +218,7 @@ pub struct Builder { key_path: PathBuf, timeout: Duration, complex_body_timeout_override: Option, + routes: RoutingNode, #[cfg(feature="user_management")] data_dir: PathBuf, } @@ -221,6 +231,7 @@ impl Builder { complex_body_timeout_override: Some(Duration::from_secs(30)), cert_path: PathBuf::from("cert/cert.pem"), key_path: PathBuf::from("cert/key.pem"), + routes: RoutingNode::default(), #[cfg(feature="user_management")] data_dir: "data".into(), } @@ -334,20 +345,33 @@ impl Builder { self } - pub async fn serve(self, handler: F) -> Result<()> + /// Add a handler for a route + /// + /// A route must be an absolute path, for example "/endpoint" or "/", but not + /// "endpoint". Entering a relative or malformed path will result in a panic. + /// + /// For more information about routing mechanics, see the docs for [`RoutingNode`]. + pub fn add_route(mut self, path: &'static str, handler: H) -> Self where - F: Fn(Request) -> HandlerResponse + Send + Sync + 'static, + H: Fn(Request) -> HandlerResponse + Send + Sync + 'static, { + self.routes.add_route(path, Arc::new(handler)); + self + } + + pub async fn serve(mut self) -> Result<()> { let config = tls_config(&self.cert_path, &self.key_path) .context("Failed to create TLS config")?; let listener = TcpListener::bind(self.addr).await .context("Failed to create socket")?; + self.routes.shrink(); + let server = Server { tls_acceptor: TlsAcceptor::from(config), listener: Arc::new(listener), - handler: Arc::new(handler), + routes: Arc::new(self.routes), timeout: self.timeout, complex_timeout: self.complex_body_timeout_override, #[cfg(feature="user_management")] diff --git a/src/routing.rs b/src/routing.rs new file mode 100644 index 0000000..20708f7 --- /dev/null +++ b/src/routing.rs @@ -0,0 +1,157 @@ +//! Utilities for routing requests +//! +//! See [`RoutingNode`] for details on how routes are matched. + +use uriparse::path::{Path, Segment}; + +use std::collections::HashMap; +use std::convert::TryInto; + +use crate::types::Request; + +/// A node for linking values to routes +/// +/// Routing is processed by a tree, with each child being a single path segment. For +/// example, if an entry existed at "/trans/rights", then the root-level node would have +/// a child "trans", which would have a child "rights". "rights" would have no children, +/// but would have an attached entry. +/// +/// If one route is shorter than another, say "/trans/rights" and +/// "/trans/rights/r/human", then the longer route always matches first, so a request for +/// "/trans/rights/r/human/rights" would be routed to "/trans/rights/r/human", and +/// "/trans/rights/now" would route to "/trans/rights" +/// +/// Routing is only performed on normalized paths, so "/endpoint" and "/endpoint/" are +/// considered to be the same route. +pub struct RoutingNode(Option, HashMap); + +impl RoutingNode { + /// Attempt to find and entry based on path segments + /// + /// This searches the network of routing nodes attempting to match a specific request, + /// represented as a sequence of path segments. For example, "/dir/image.png?text" + /// should be represented as `&["dir", "image.png"]`. + /// + /// If a match is found, it is returned, along with the segments of the path trailing + /// the subpath matcing the route. For example, a route `/foo` recieving a request to + /// `/foo/bar` would produce `vec!["bar"]` + /// + /// See [`RoutingNode`] for details on how routes are matched. + pub fn match_path(&self, path: I) -> Option<(Vec, &T)> + where + I: IntoIterator, + S: AsRef, + { + let mut node = self; + let mut path = path.into_iter().filter(|seg| !seg.as_ref().is_empty()); + let mut last_seen_handler = None; + let mut since_last_handler = Vec::new(); + loop { + let Self(maybe_handler, map) = node; + + if maybe_handler.is_some() { + last_seen_handler = maybe_handler.as_ref(); + since_last_handler.clear(); + } + + if let Some(segment) = path.next() { + let maybe_route = map.get(segment.as_ref()); + since_last_handler.push(segment); + + if let Some(route) = maybe_route { + node = route; + } else { + break; + } + } else { + break; + } + }; + + if let Some(handler) = last_seen_handler { + since_last_handler.extend(path); + Some((since_last_handler, handler)) + } else { + None + } + } + + /// Attempt to identify a route for a given [`Request`] + /// + /// See [`RoutingNode::match_path()`] for more information + pub fn match_request(&self, req: &Request) -> Option<(Vec, &T)> { + let mut path = req.path().to_borrowed(); + path.normalize(false); + self.match_path(path.segments()) + .map(|(segs, h)| ( + segs.into_iter() + .map(Segment::as_str) + .map(str::to_owned) + .collect(), + h, + )) + } + + /// Add a route to the network + /// + /// This method wraps [`add_route_by_path()`](Self::add_route_by_path()) while + /// unwrapping any errors that might occur. For this reason, this method only takes + /// static strings. If you would like to add a string dynamically, please use + /// [`RoutingNode::add_route_by_path()`] in order to appropriately deal with any + /// errors that might arise. + pub fn add_route(&mut self, path: &'static str, data: T) { + let path: Path = path.try_into().expect("Malformed path route received"); + self.add_route_by_path(path, data).unwrap(); + } + + /// Add a route to the network + /// + /// The path provided MUST be absolute. Callers should verify this before calling + /// this method. + /// + /// For information about how routes work, see [`RoutingNode::match_path()`] + pub fn add_route_by_path(&mut self, mut path: Path, data: T) -> Result<(), ConflictingRouteError>{ + debug_assert!(path.is_absolute()); + path.normalize(false); + + let mut node = self; + for segment in path.segments() { + if segment != "" { + node = node.1.entry(segment.to_string()).or_default(); + } + } + + if node.0.is_some() { + Err(ConflictingRouteError()) + } else { + node.0 = Some(data); + Ok(()) + } + } + + /// Recursively shrink maps to fit + pub fn shrink(&mut self) { + let mut to_shrink = vec![&mut self.1]; + while let Some(shrink) = to_shrink.pop() { + shrink.shrink_to_fit(); + to_shrink.extend(shrink.values_mut().map(|n| &mut n.1)); + } + } +} + +impl Default for RoutingNode { + fn default() -> Self { + Self(None, HashMap::default()) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct ConflictingRouteError(); + +impl std::error::Error for ConflictingRouteError { } + +impl std::fmt::Display for ConflictingRouteError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "Attempted to create a route with the same matcher as an existing route") + } +} diff --git a/src/types/body.rs b/src/types/body.rs index d1356cc..d2da102 100644 --- a/src/types/body.rs +++ b/src/types/body.rs @@ -11,7 +11,7 @@ pub enum Body { impl From for Body { fn from(document: Document) -> Self { - Body::from(document.to_string()) + Self::from(document.to_string()) } } diff --git a/src/types/document.rs b/src/types/document.rs index e322357..d71c851 100644 --- a/src/types/document.rs +++ b/src/types/document.rs @@ -183,7 +183,7 @@ impl Document { .map(URIReference::into_owned) .or_else(|_| ".".try_into()).expect("Northstar BUG"); let label = LinkLabel::from_lossy(label); - let link = Link { uri, label: Some(label) }; + let link = Link { uri: Box::new(uri), label: Some(label) }; let link = Item::Link(link); self.add_item(link); @@ -213,7 +213,7 @@ impl Document { .map(URIReference::into_owned) .or_else(|_| ".".try_into()).expect("Northstar BUG"); let link = Link { - uri, + uri: Box::new(uri), label: None, }; let link = Item::Link(link); @@ -391,6 +391,7 @@ impl fmt::Display for Document { } } +#[allow(clippy::enum_variant_names)] enum Item { Text(Text), Link(Link), @@ -414,7 +415,7 @@ impl Text { } struct Link { - uri: URIReference<'static>, + uri: Box>, label: Option, } @@ -424,7 +425,7 @@ impl LinkLabel { fn from_lossy(line: impl Cowy) -> Self { let line = strip_newlines(line); - LinkLabel(line) + Self(line) } } diff --git a/src/types/meta.rs b/src/types/meta.rs index ccc17ba..bfb36e5 100644 --- a/src/types/meta.rs +++ b/src/types/meta.rs @@ -12,7 +12,7 @@ impl Meta { /// Creates a new "Meta" string. /// Fails if `meta` contains `\n`. pub fn new(meta: impl Cowy) -> Result { - ensure!(!meta.as_ref().contains("\n"), "Meta must not contain newlines"); + ensure!(!meta.as_ref().contains('\n'), "Meta must not contain newlines"); ensure!(meta.as_ref().len() <= Self::MAX_LEN, "Meta must not exceed {} bytes", Self::MAX_LEN); Ok(Self(meta.into())) diff --git a/src/types/request.rs b/src/types/request.rs index 4d4e678..386c1ce 100644 --- a/src/types/request.rs +++ b/src/types/request.rs @@ -13,6 +13,7 @@ pub struct Request { uri: URIReference<'static>, input: Option, certificate: Option, + trailing_segments: Option>, #[cfg(feature="user_management")] manager: UserManager, } @@ -54,15 +55,41 @@ impl Request { uri, input, certificate, + trailing_segments: None, #[cfg(feature="user_management")] manager, }) } - pub fn uri(&self) -> &URIReference { + pub const fn uri(&self) -> &URIReference { &self.uri } + #[allow(clippy::missing_const_for_fn)] + /// All of the path segments following the route to which this request was bound. + /// + /// For example, if this handler was bound to the `/api` route, and a request was + /// received to `/api/v1/endpoint`, then this value would be `["v1", "endpoint"]`. + /// This should not be confused with [`path_segments()`](Self::path_segments()), which + /// contains *all* of the segments, not just those trailing the route. + /// + /// If the trailing segments have not been set, this method will panic, but this + /// should only be possible if you are constructing the Request yourself. Requests + /// to handlers registered through [`add_route`](northstar::Builder::add_route()) will + /// always have trailing segments set. + pub fn trailing_segments(&self) -> &Vec { + self.trailing_segments.as_ref().unwrap() + } + + /// All of the segments in this path, percent decoded + /// + /// For example, for a request to `/api/v1/endpoint`, this would return `["api", "v1", + /// "endpoint"]`, no matter what route the handler that recieved this request was + /// bound to. This is not to be confused with + /// [`trailing_segments()`](Self::trailing_segments), which contains only the segments + /// following the bound route. + /// + /// Additionally, unlike `trailing_segments()`, this method percent decodes the path. pub fn path_segments(&self) -> Vec { self.uri() .path() @@ -80,7 +107,11 @@ impl Request { self.certificate = cert; } - pub fn certificate(&self) -> Option<&Certificate> { + pub fn set_trailing(&mut self, segments: Vec) { + self.trailing_segments = Some(segments); + } + + pub const fn certificate(&self) -> Option<&Certificate> { self.certificate.as_ref() } diff --git a/src/types/response.rs b/src/types/response.rs index 3e4a84a..991d511 100644 --- a/src/types/response.rs +++ b/src/types/response.rs @@ -12,7 +12,7 @@ pub struct Response { } impl Response { - pub fn new(header: ResponseHeader) -> Self { + pub const fn new(header: ResponseHeader) -> Self { Self { header, body: None, @@ -34,7 +34,7 @@ impl Response { } pub fn success(mime: &Mime) -> Self { - let header = ResponseHeader::success(&mime); + let header = ResponseHeader::success(mime); Self::new(header) } @@ -86,7 +86,7 @@ impl Response { self } - pub fn header(&self) -> &ResponseHeader { + pub const fn header(&self) -> &ResponseHeader { &self.header } diff --git a/src/types/response_header.rs b/src/types/response_header.rs index 56f2af3..b2b5e20 100644 --- a/src/types/response_header.rs +++ b/src/types/response_header.rs @@ -88,11 +88,11 @@ impl ResponseHeader { } } - pub fn status(&self) -> &Status { + pub const fn status(&self) -> &Status { &self.status } - pub fn meta(&self) -> &Meta { + pub const fn meta(&self) -> &Meta { &self.meta } } diff --git a/src/types/status.rs b/src/types/status.rs index a06e9f4..18c58a1 100644 --- a/src/types/status.rs +++ b/src/types/status.rs @@ -1,4 +1,3 @@ - #[derive(Debug,Copy,Clone,PartialEq,Eq)] pub struct Status(u8); @@ -22,7 +21,7 @@ impl Status { pub const CERTIFICATE_NOT_AUTHORIZED: Self = Self(61); pub const CERTIFICATE_NOT_VALID: Self = Self(62); - pub fn code(&self) -> u8 { + pub const fn code(&self) -> u8 { self.0 } @@ -30,7 +29,7 @@ impl Status { self.category().is_success() } - pub fn category(&self) -> StatusCategory { + pub const fn category(&self) -> StatusCategory { let class = self.0 / 10; match class {