Add rate limiting feature

ugghhhhhh I hate how big the governer crate is but there's no alternative besides DIY
This commit is contained in:
Emi Tatsuo 2020-11-20 21:15:37 -05:00
parent 536e404fdf
commit 349f6da698
Signed by: Emi
GPG key ID: 68FAB2E2E6DFC98B
2 changed files with 63 additions and 0 deletions

View file

@ -11,6 +11,7 @@ documentation = "https://docs.rs/northstar"
[features] [features]
default = ["serve_dir"] default = ["serve_dir"]
serve_dir = ["mime_guess", "tokio/fs"] serve_dir = ["mime_guess", "tokio/fs"]
rate-limiting = ["governor"]
[dependencies] [dependencies]
anyhow = "1.0.33" anyhow = "1.0.33"
@ -25,6 +26,7 @@ log = "0.4.11"
webpki = "0.21.0" webpki = "0.21.0"
lazy_static = "1.4.0" lazy_static = "1.4.0"
mime_guess = { version = "2.0.3", optional = true } mime_guess = { version = "2.0.3", optional = true }
governor = { version = "0.3.1", optional = true }
[dev-dependencies] [dev-dependencies]
env_logger = "0.8.1" env_logger = "0.8.1"

View file

@ -21,6 +21,8 @@ use tokio_rustls::{rustls, TlsAcceptor};
use rustls::*; use rustls::*;
use anyhow::*; use anyhow::*;
use lazy_static::lazy_static; use lazy_static::lazy_static;
#[cfg(feature="rate-limiting")]
use governor::clock::{Clock, DefaultClock};
use crate::util::opt_timeout; use crate::util::opt_timeout;
use routing::RoutingNode; use routing::RoutingNode;
@ -30,13 +32,26 @@ pub mod routing;
pub use mime; pub use mime;
pub use uriparse as uri; pub use uriparse as uri;
#[cfg(feature="rate-limiting")]
pub use governor::Quota;
pub use types::*; pub use types::*;
pub const REQUEST_URI_MAX_LEN: usize = 1024; pub const REQUEST_URI_MAX_LEN: usize = 1024;
pub const GEMINI_PORT: u16 = 1965; pub const GEMINI_PORT: u16 = 1965;
#[cfg(feature="rate-limiting")]
lazy_static! {
static ref CLOCK: DefaultClock = DefaultClock::default();
}
type Handler = Arc<dyn Fn(Request) -> HandlerResponse + Send + Sync>; type Handler = Arc<dyn Fn(Request) -> HandlerResponse + Send + Sync>;
pub (crate) type HandlerResponse = BoxFuture<'static, Result<Response>>; pub (crate) type HandlerResponse = BoxFuture<'static, Result<Response>>;
#[cfg(feature="rate-limiting")]
type RateLimiter = governor::RateLimiter<
std::net::IpAddr,
governor::state::keyed::DefaultKeyedStateStore<std::net::IpAddr>,
governor::clock::DefaultClock,
>;
#[derive(Clone)] #[derive(Clone)]
pub struct Server { pub struct Server {
@ -45,6 +60,8 @@ pub struct Server {
routes: Arc<RoutingNode<Handler>>, routes: Arc<RoutingNode<Handler>>,
timeout: Duration, timeout: Duration,
complex_timeout: Option<Duration>, complex_timeout: Option<Duration>,
#[cfg(feature="rate-limiting")]
rate_limits: Arc<RoutingNode<RateLimiter>>,
} }
impl Server { impl Server {
@ -67,6 +84,9 @@ impl Server {
} }
async fn serve_client(self, stream: TcpStream) -> Result<()> { async fn serve_client(self, stream: TcpStream) -> Result<()> {
#[cfg(feature="rate-limiting")]
let peer_addr = stream.peer_addr()?.ip();
let fut_accept_request = async { let fut_accept_request = async {
let stream = self.tls_acceptor.accept(stream).await let stream = self.tls_acceptor.accept(stream).await
.context("Failed to establish TLS session")?; .context("Failed to establish TLS session")?;
@ -83,6 +103,13 @@ impl Server {
let (mut request, mut stream) = fut_accept_request.await let (mut request, mut stream) = fut_accept_request.await
.context("Client timed out while waiting for response")??; .context("Client timed out while waiting for response")??;
#[cfg(feature="rate-limiting")]
if let Some(resp) = self.check_rate_limits(peer_addr, &request) {
self.send_response(resp, &mut stream).await
.context("Failed to send response")?;
return Ok(())
}
debug!("Client requested: {}", request.uri()); debug!("Client requested: {}", request.uri());
// Identify the client certificate from the tls stream. This is the first // Identify the client certificate from the tls stream. This is the first
@ -164,6 +191,21 @@ impl Server {
Ok(()) Ok(())
} }
#[cfg(feature="rate-limiting")]
fn check_rate_limits(&self, addr: std::net::IpAddr, req: &Request) -> Option<Response> {
if let Some((_, limiter)) = self.rate_limits.match_request(req) {
if let Err(when) = limiter.check_key(&addr) {
return Some(Response::new(ResponseHeader {
status: Status::SLOW_DOWN,
meta: Meta::new(
when.wait_time_from(CLOCK.now()).as_secs().to_string()
).unwrap()
}))
}
}
None
}
} }
pub struct Builder<A> { pub struct Builder<A> {
@ -173,6 +215,8 @@ pub struct Builder<A> {
timeout: Duration, timeout: Duration,
complex_body_timeout_override: Option<Duration>, complex_body_timeout_override: Option<Duration>,
routes: RoutingNode<Handler>, routes: RoutingNode<Handler>,
#[cfg(feature="rate-limiting")]
rate_limits: RoutingNode<RateLimiter>,
} }
impl<A: ToSocketAddrs> Builder<A> { impl<A: ToSocketAddrs> Builder<A> {
@ -184,6 +228,8 @@ impl<A: ToSocketAddrs> Builder<A> {
cert_path: PathBuf::from("cert/cert.pem"), cert_path: PathBuf::from("cert/cert.pem"),
key_path: PathBuf::from("cert/key.pem"), key_path: PathBuf::from("cert/key.pem"),
routes: RoutingNode::default(), routes: RoutingNode::default(),
#[cfg(feature="rate-limiting")]
rate_limits: RoutingNode::default(),
} }
} }
@ -300,6 +346,19 @@ impl<A: ToSocketAddrs> Builder<A> {
self self
} }
#[cfg(feature="rate-limiting")]
/// Add a rate limit to a route
///
/// A route must be an absolute path, for example "/endpoint" or "/", but not
/// "endpoint". Entering a relative or malformed path will result in a panic.
///
/// For more information about routing mechanics, see the docs for [`RoutingNode`].
pub fn rate_limit(mut self, path: &'static str, quota: Quota) -> Self {
let limiter = RateLimiter::dashmap_with_clock(quota, &CLOCK);
self.rate_limits.add_route(path, limiter);
self
}
pub async fn serve(mut self) -> Result<()> { pub async fn serve(mut self) -> Result<()> {
let config = tls_config(&self.cert_path, &self.key_path) let config = tls_config(&self.cert_path, &self.key_path)
.context("Failed to create TLS config")?; .context("Failed to create TLS config")?;
@ -315,6 +374,8 @@ impl<A: ToSocketAddrs> Builder<A> {
routes: Arc::new(self.routes), routes: Arc::new(self.routes),
timeout: self.timeout, timeout: self.timeout,
complex_timeout: self.complex_body_timeout_override, complex_timeout: self.complex_body_timeout_override,
#[cfg(feature="rate-limiting")]
rate_limits: Arc::new(self.rate_limits),
}; };
server.serve().await server.serve().await