From df2350a8bb88c9a769a957ad7b841c857868b347 Mon Sep 17 00:00:00 2001
From: Emi Tatsuo
Date: Tue, 24 Nov 2020 13:58:18 -0500
Subject: [PATCH] Switch to a much lighter in-house rate-limiting solution,
 and use consistent naming of ratelimiting

---
 Cargo.toml                                  |  4 +-
 examples/{ratelimits.rs => ratelimiting.rs} |  6 +-
 src/lib.rs                                  | 55 +++++++---------
 src/ratelimiting.rs                         | 72 +++++++++++++++++++++
 4 files changed, 100 insertions(+), 37 deletions(-)
 rename examples/{ratelimits.rs => ratelimiting.rs} (91%)
 create mode 100644 src/ratelimiting.rs

diff --git a/Cargo.toml b/Cargo.toml
index d53d578..8b5432c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -11,7 +11,7 @@ documentation = "https://docs.rs/northstar"
 [features]
 default = ["serve_dir"]
 serve_dir = ["mime_guess", "tokio/fs"]
-rate-limiting = ["governor"]
+ratelimiting = ["dashmap"]
 
 [dependencies]
 anyhow = "1.0.33"
@@ -26,7 +26,7 @@ log = "0.4.11"
 webpki = "0.21.0"
 lazy_static = "1.4.0"
 mime_guess = { version = "2.0.3", optional = true }
-governor = { version = "0.3.1", optional = true }
+dashmap = { version = "3.11.10", optional = true }
 
 [dev-dependencies]
 env_logger = "0.8.1"
diff --git a/examples/ratelimits.rs b/examples/ratelimiting.rs
similarity index 91%
rename from examples/ratelimits.rs
rename to examples/ratelimiting.rs
index b9cf2ab..48b215c 100644
--- a/examples/ratelimits.rs
+++ b/examples/ratelimiting.rs
@@ -1,3 +1,5 @@
+use std::time::Duration;
+
 use anyhow::*;
 use futures_core::future::BoxFuture;
 use futures_util::FutureExt;
@@ -10,11 +12,9 @@ async fn main() -> Result<()> {
         .filter_module("northstar", LevelFilter::Debug)
         .init();
 
-    let two = std::num::NonZeroU32::new(2).unwrap();
-
     Server::bind(("localhost", GEMINI_PORT))
         .add_route("/", handle_request)
-        .rate_limit("/limit", northstar::Quota::per_minute(two))
+        .ratelimit("/limit", 2, Duration::from_secs(60))
         .serve()
         .await
 }
diff --git a/src/lib.rs b/src/lib.rs
index eaafc9a..3b99331 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -7,6 +7,7 @@ use std::{
     sync::Arc,
     path::PathBuf,
     time::Duration,
+    net::IpAddr,
 };
 use futures_core::future::BoxFuture;
 use tokio::{
@@ -21,37 +22,25 @@ use tokio_rustls::{rustls, TlsAcceptor};
 use rustls::*;
 use anyhow::*;
 use lazy_static::lazy_static;
-#[cfg(feature="rate-limiting")]
-use governor::clock::{Clock, DefaultClock};
 
 use crate::util::opt_timeout;
 use routing::RoutingNode;
+use ratelimiting::RateLimiter;
 
 pub mod types;
 pub mod util;
 pub mod routing;
+#[cfg(feature = "ratelimiting")]
+pub mod ratelimiting;
 
 pub use mime;
 pub use uriparse as uri;
-#[cfg(feature="rate-limiting")]
-pub use governor::Quota;
 pub use types::*;
 
 pub const REQUEST_URI_MAX_LEN: usize = 1024;
 pub const GEMINI_PORT: u16 = 1965;
 
-#[cfg(feature="rate-limiting")]
-lazy_static! {
-    static ref CLOCK: DefaultClock = DefaultClock::default();
-}
-
 type Handler = Arc<dyn Fn(Request) -> HandlerResponse + Send + Sync>;
 pub (crate) type HandlerResponse = BoxFuture<'static, Result<Response>>;
-#[cfg(feature="rate-limiting")]
-type RateLimiter = governor::RateLimiter<
-    std::net::IpAddr,
-    governor::state::keyed::DefaultKeyedStateStore<std::net::IpAddr>,
-    governor::clock::DefaultClock,
->;
 
 #[derive(Clone)]
 pub struct Server {
@@ -60,8 +49,8 @@ pub struct Server {
     routes: Arc<RoutingNode<Handler>>,
     timeout: Duration,
     complex_timeout: Option<Duration>,
-    #[cfg(feature="rate-limiting")]
-    rate_limits: Arc<RoutingNode<RateLimiter>>,
+    #[cfg(feature="ratelimiting")]
+    rate_limits: Arc<RoutingNode<RateLimiter<IpAddr>>>,
 }
 
 impl Server {
@@ -84,7 +73,7 @@ impl Server {
     }
 
     async fn serve_client(self, stream: TcpStream) -> Result<()> {
-        #[cfg(feature="rate-limiting")]
+        #[cfg(feature="ratelimiting")]
         let peer_addr = stream.peer_addr()?.ip();
 
         let fut_accept_request = async {
@@ -103,7 +92,7 @@ impl Server {
         let (mut request, mut stream) = fut_accept_request.await
             .context("Client timed out while waiting for response")??;
 
-        #[cfg(feature="rate-limiting")]
+        #[cfg(feature="ratelimiting")]
        if let Some(resp) = self.check_rate_limits(peer_addr, &request) {
             self.send_response(resp, &mut stream).await
                 .context("Failed to send response")?;
@@ -192,15 +181,13 @@ impl Server {
         Ok(())
     }
 
-    #[cfg(feature="rate-limiting")]
-    fn check_rate_limits(&self, addr: std::net::IpAddr, req: &Request) -> Option<Response> {
+    #[cfg(feature="ratelimiting")]
+    fn check_rate_limits(&self, addr: IpAddr, req: &Request) -> Option<Response> {
         if let Some((_, limiter)) = self.rate_limits.match_request(req) {
-            if let Err(when) = limiter.check_key(&addr) {
+            if let Err(when) = limiter.check_key(addr) {
                 return Some(Response::new(ResponseHeader {
                     status: Status::SLOW_DOWN,
-                    meta: Meta::new(
-                        when.wait_time_from(CLOCK.now()).as_secs().to_string()
-                    ).unwrap()
+                    meta: Meta::new(when.as_secs().to_string()).unwrap()
                 }))
             }
         }
@@ -215,8 +202,8 @@ pub struct Builder {
     timeout: Duration,
     complex_body_timeout_override: Option<Duration>,
     routes: RoutingNode<Handler>,
-    #[cfg(feature="rate-limiting")]
-    rate_limits: RoutingNode<RateLimiter>,
+    #[cfg(feature="ratelimiting")]
+    rate_limits: RoutingNode<RateLimiter<IpAddr>>,
 }
 
 impl Builder {
@@ -228,7 +215,7 @@ impl Builder {
             cert_path: PathBuf::from("cert/cert.pem"),
             key_path: PathBuf::from("cert/key.pem"),
             routes: RoutingNode::default(),
-            #[cfg(feature="rate-limiting")]
+            #[cfg(feature="ratelimiting")]
             rate_limits: RoutingNode::default(),
         }
     }
@@ -346,15 +333,19 @@ impl Builder {
         self
     }
 
-    #[cfg(feature="rate-limiting")]
+    #[cfg(feature="ratelimiting")]
     /// Add a rate limit to a route
     ///
+    /// The server will allow at most `burst` connections to any endpoints under this
+    /// route within a given `period`. All extra requests will receive a `SLOW_DOWN`
+    /// and will not be passed to the handler.
+    ///
     /// A route must be an absolute path, for example "/endpoint" or "/", but not
     /// "endpoint". Entering a relative or malformed path will result in a panic.
     ///
     /// For more information about routing mechanics, see the docs for [`RoutingNode`].
-    pub fn rate_limit(mut self, path: &'static str, quota: Quota) -> Self {
-        let limiter = RateLimiter::dashmap_with_clock(quota, &CLOCK);
+    pub fn ratelimit(mut self, path: &'static str, burst: usize, period: Duration) -> Self {
+        let limiter = RateLimiter::new(period, burst);
         self.rate_limits.add_route(path, limiter);
         self
     }
@@ -374,7 +365,7 @@ impl Builder {
             routes: Arc::new(self.routes),
             timeout: self.timeout,
             complex_timeout: self.complex_body_timeout_override,
-            #[cfg(feature="rate-limiting")]
+            #[cfg(feature="ratelimiting")]
             rate_limits: Arc::new(self.rate_limits),
         };
 
diff --git a/src/ratelimiting.rs b/src/ratelimiting.rs
new file mode 100644
index 0000000..df60e8e
--- /dev/null
+++ b/src/ratelimiting.rs
@@ -0,0 +1,72 @@
+use dashmap::DashMap;
+
+use std::{hash::Hash, collections::VecDeque, time::{Duration, Instant}};
+
+/// A simple struct to manage rate limiting.
+///
+/// Does not require a leaky bucket thread to empty it out, but may occasionally need to
+/// trim old keys using [`trim_keys()`].
+///
+/// [`trim_keys()`]: Self::trim_keys()
+pub struct RateLimiter<K: Hash + Eq> {
+    log: DashMap<K, VecDeque<Instant>>,
+    burst: usize,
+    period: Duration,
+}
+
+impl<K: Hash + Eq> RateLimiter<K> {
+    /// Create a new ratelimiter that allows at most `burst` connections in `period`
+    pub fn new(period: Duration, burst: usize) -> Self {
+        Self {
+            log: DashMap::with_capacity(8),
+            period,
+            burst,
+        }
+    }
+
+    /// Check if a key may pass
+    ///
+    /// If the key has made fewer than `self.burst` connections in the last `self.period`,
+    /// then the key is allowed to connect, which is denoted by an `Ok` result. This will
+    /// register as a new connection from that key.
+    ///
+    /// If the key is not allowed to connect, then a [`Duration`] denoting the amount of
+    /// time until the key is permitted is returned, wrapped in an `Err`.
+    pub fn check_key(&self, key: K) -> Result<(), Duration> {
+        let now = Instant::now();
+        let count_after = now - self.period;
+
+        let mut connections = self.log.entry(key)
+            .or_insert_with(|| VecDeque::with_capacity(self.burst));
+        let connections = connections.value_mut();
+
+        // Check if space can be made available. We don't need to trim all expired
+        // connections, just the one in question to allow this connection.
+        if let Some(earliest_conn) = connections.front() {
+            if earliest_conn < &count_after {
+                connections.pop_front();
+            }
+        }
+
+        // Check if the connection should be allowed
+        if connections.len() == self.burst {
+            Err(connections[0] + self.period - now)
+        } else {
+            connections.push_back(now);
+            Ok(())
+        }
+    }
+
+    /// Remove any expired keys from the ratelimiter
+    ///
+    /// This only needs to be called if new keys are continuously being added. If keys are
+    /// being reused, or come from a finite set, then you don't need to worry about this.
+    ///
+    /// If you have many keys coming from a large set, you should call this occasionally
+    /// to prevent a memory leak.
+    pub fn trim_keys(&self) {
+        let count_after = Instant::now() - self.period;
+
+        self.log.retain(|_, conns| conns.back().unwrap() > &count_after);
+    }
+}
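
The patch itself ships no tests for the new module. As a rough sanity check of the behaviour described in the doc comments above (a reviewer sketch, not part of the diff; the key values are arbitrary), a unit test along these lines could be appended to src/ratelimiting.rs and run with `cargo test --features ratelimiting`:

#[cfg(test)]
mod tests {
    use super::RateLimiter;
    use std::time::Duration;

    #[test]
    fn burst_then_slow_down() {
        // Allow at most 2 connections per 60-second window for any single key.
        let limiter = RateLimiter::new(Duration::from_secs(60), 2);

        // The first `burst` calls for a key are admitted and recorded...
        assert!(limiter.check_key("198.51.100.7").is_ok());
        assert!(limiter.check_key("198.51.100.7").is_ok());

        // ...and the next call is rejected with the time left until a slot frees up.
        let wait = limiter.check_key("198.51.100.7").expect_err("third call should be limited");
        assert!(wait <= Duration::from_secs(60));

        // Each key keeps its own connection log, so other keys are unaffected.
        assert!(limiter.check_key("203.0.113.9").is_ok());
    }
}

With `burst = 2`, the first two calls go through, the third returns `Err` with the remaining wait (at most the full period), and an unrelated key is still admitted.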