diff --git a/Cargo.toml b/Cargo.toml index 9ad991d..d53d578 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ documentation = "https://docs.rs/northstar" [features] default = ["serve_dir"] serve_dir = ["mime_guess", "tokio/fs"] +rate-limiting = ["governor"] [dependencies] anyhow = "1.0.33" @@ -25,6 +26,7 @@ log = "0.4.11" webpki = "0.21.0" lazy_static = "1.4.0" mime_guess = { version = "2.0.3", optional = true } +governor = { version = "0.3.1", optional = true } [dev-dependencies] env_logger = "0.8.1" diff --git a/src/lib.rs b/src/lib.rs index e957262..eaafc9a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,6 +21,8 @@ use tokio_rustls::{rustls, TlsAcceptor}; use rustls::*; use anyhow::*; use lazy_static::lazy_static; +#[cfg(feature="rate-limiting")] +use governor::clock::{Clock, DefaultClock}; use crate::util::opt_timeout; use routing::RoutingNode; @@ -30,13 +32,26 @@ pub mod routing; pub use mime; pub use uriparse as uri; +#[cfg(feature="rate-limiting")] +pub use governor::Quota; pub use types::*; pub const REQUEST_URI_MAX_LEN: usize = 1024; pub const GEMINI_PORT: u16 = 1965; +#[cfg(feature="rate-limiting")] +lazy_static! { + static ref CLOCK: DefaultClock = DefaultClock::default(); +} + type Handler = Arc HandlerResponse + Send + Sync>; pub (crate) type HandlerResponse = BoxFuture<'static, Result>; +#[cfg(feature="rate-limiting")] +type RateLimiter = governor::RateLimiter< + std::net::IpAddr, + governor::state::keyed::DefaultKeyedStateStore, + governor::clock::DefaultClock, +>; #[derive(Clone)] pub struct Server { @@ -45,6 +60,8 @@ pub struct Server { routes: Arc>, timeout: Duration, complex_timeout: Option, + #[cfg(feature="rate-limiting")] + rate_limits: Arc>, } impl Server { @@ -67,6 +84,9 @@ impl Server { } async fn serve_client(self, stream: TcpStream) -> Result<()> { + #[cfg(feature="rate-limiting")] + let peer_addr = stream.peer_addr()?.ip(); + let fut_accept_request = async { let stream = self.tls_acceptor.accept(stream).await .context("Failed to establish TLS session")?; @@ -83,6 +103,13 @@ impl Server { let (mut request, mut stream) = fut_accept_request.await .context("Client timed out while waiting for response")??; + #[cfg(feature="rate-limiting")] + if let Some(resp) = self.check_rate_limits(peer_addr, &request) { + self.send_response(resp, &mut stream).await + .context("Failed to send response")?; + return Ok(()) + } + debug!("Client requested: {}", request.uri()); // Identify the client certificate from the tls stream. This is the first @@ -164,6 +191,21 @@ impl Server { Ok(()) } + + #[cfg(feature="rate-limiting")] + fn check_rate_limits(&self, addr: std::net::IpAddr, req: &Request) -> Option { + if let Some((_, limiter)) = self.rate_limits.match_request(req) { + if let Err(when) = limiter.check_key(&addr) { + return Some(Response::new(ResponseHeader { + status: Status::SLOW_DOWN, + meta: Meta::new( + when.wait_time_from(CLOCK.now()).as_secs().to_string() + ).unwrap() + })) + } + } + None + } } pub struct Builder { @@ -173,6 +215,8 @@ pub struct Builder { timeout: Duration, complex_body_timeout_override: Option, routes: RoutingNode, + #[cfg(feature="rate-limiting")] + rate_limits: RoutingNode, } impl Builder { @@ -184,6 +228,8 @@ impl Builder { cert_path: PathBuf::from("cert/cert.pem"), key_path: PathBuf::from("cert/key.pem"), routes: RoutingNode::default(), + #[cfg(feature="rate-limiting")] + rate_limits: RoutingNode::default(), } } @@ -300,6 +346,19 @@ impl Builder { self } + #[cfg(feature="rate-limiting")] + /// Add a rate limit to a route + /// + /// A route must be an absolute path, for example "/endpoint" or "/", but not + /// "endpoint". Entering a relative or malformed path will result in a panic. + /// + /// For more information about routing mechanics, see the docs for [`RoutingNode`]. + pub fn rate_limit(mut self, path: &'static str, quota: Quota) -> Self { + let limiter = RateLimiter::dashmap_with_clock(quota, &CLOCK); + self.rate_limits.add_route(path, limiter); + self + } + pub async fn serve(mut self) -> Result<()> { let config = tls_config(&self.cert_path, &self.key_path) .context("Failed to create TLS config")?; @@ -315,6 +374,8 @@ impl Builder { routes: Arc::new(self.routes), timeout: self.timeout, complex_timeout: self.complex_body_timeout_override, + #[cfg(feature="rate-limiting")] + rate_limits: Arc::new(self.rate_limits), }; server.serve().await