kochab/src/ratelimiting.rs

use dashmap::DashMap;
use std::{collections::VecDeque, fmt::Display, hash::Hash, time::{Duration, Instant}};
use log::debug;

/// A simple struct to manage rate limiting.
///
/// Does not require a leaky bucket thread to empty it out, but may occasionally need to
/// trim old keys using [`trim_keys()`].
///
/// [`trim_keys()`]: Self::trim_keys()
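///
/// A minimal usage sketch (the key value and timings here are illustrative
/// assumptions, and the import of `RateLimiter` is omitted):
///
/// ```ignore
/// use std::time::Duration;
///
/// // Allow at most 5 connections per key within any 10 second window.
/// let limiter = RateLimiter::new(Duration::from_secs(10), 5);
///
/// match limiter.check_key("198.51.100.7") {
///     Ok(()) => { /* handle the connection */ }
///     Err(wait) => println!("rate limited, retry in {:?}", wait),
/// }
/// ```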
pub struct RateLimiter<K: Eq + Hash> {
log: DashMap<K, VecDeque<Instant>>,
burst: usize,
period: Duration,
}
impl<K: Eq + Hash> RateLimiter<K> {
/// Create a new rate limiter that allows each key at most `burst` connections in any `period`
pub fn new(period: Duration, burst: usize) -> Self {
Self {
log: DashMap::with_capacity(8),
period,
burst,
}
}
/// Check if a key may pass
///
/// If the key has made fewer than `self.burst` connections in the last `self.period`,
/// then the key is allowed to connect, which is denoted by an `Ok` result. This call
/// also registers a new connection from that key.
///
/// If the key is not allowed to connect, then a [`Duration`] denoting the amount of
/// time until the key is next permitted is returned, wrapped in an `Err`.
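///
/// For example (an illustrative scenario): with `burst = 2` and a `period` of
/// 60 seconds, connections at t = 0s and t = 10s both pass, while a third
/// attempt at t = 30s returns `Err` with roughly 30 seconds remaining, because
/// the oldest logged connection only leaves the window at t = 60s.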
pub fn check_key(&self, key: K) -> Result<(), Duration> {
let now = Instant::now();
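// Connections recorded before this instant have aged out of the window.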
let count_after = now - self.period;
let mut connections = self.log.entry(key)
.or_insert_with(||VecDeque::with_capacity(self.burst));
let connections = connections.value_mut();
// Check whether space can be made available. We don't need to trim all expired
// connections here, just the oldest one, to make room for this connection.
if let Some(earliest_conn) = connections.front() {
if earliest_conn < &count_after {
connections.pop_front();
}
}
// Check if the connection should be allowed
if connections.len() == self.burst {
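// Still at the burst limit: the caller must wait until the oldest logged
// connection leaves the window.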
Err(connections[0] + self.period - now)
} else {
connections.push_back(now);
Ok(())
}
}
/// Remove any expired keys from the ratelimiter
///
/// This only needs to be called if new keys are continuously being added. If keys are
/// reused, or come from a finite set, then you don't need to worry about this.
///
/// If keys are drawn from a large or unbounded set, call this occasionally to keep
/// expired entries from accumulating indefinitely.
///
/// If debug level logging is enabled, this logs an *approximate* count of the keys
/// removed. For more precise output, use [`trim_keys_verbose()`].
///
/// [`trim_keys_verbose()`]: RateLimiter::trim_keys_verbose()
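///
/// A sketch of periodic trimming (the interval and the way the loop is driven
/// are assumptions, not part of this crate):
///
/// ```ignore
/// // `limiter` shared with the accept loop, e.g. behind an `Arc`, on a
/// // dedicated thread or background task.
/// loop {
///     std::thread::sleep(std::time::Duration::from_secs(300));
///     limiter.trim_keys();
/// }
/// ```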
pub fn trim_keys(&self) {
let count_after = Instant::now() - self.period;
let len: isize = self.log.len() as isize;
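// Keep a key only if its most recent connection is still inside the window.
// `check_key` never leaves an empty queue behind, so `back()` is expected to
// be `Some` here.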
self.log.retain(|_, conns| conns.back().unwrap() > &count_after);
let removed = len - self.log.len() as isize;
if removed.is_positive() {
debug!("Pruned approximately {} expired ratelimit keys", removed);
}
}
}
impl<K: Eq + Hash + Display> RateLimiter<K> {
/// Remove any expired keys from the ratelimiter
///
/// This only needs to be called if new keys are continuously being added. If keys are
/// reused, or come from a finite set, then you don't need to worry about this.
///
/// If keys are drawn from a large or unbounded set, call this occasionally to keep
/// expired entries from accumulating indefinitely.
///
/// If debug level logging is enabled, this logs each removed key.
pub fn trim_keys_verbose(&self) {
let count_after = Instant::now() - self.period;
self.log.retain(|key, conns| {
let should_keep = conns.back().unwrap() > &count_after;
if !should_keep {
debug!("Pruned expired ratelimit key: {}", key);
}
should_keep
});
}
}
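
// An illustrative test sketch (not part of the original file) exercising the
// behavior described above: a key may connect `burst` times per `period`, and
// further attempts are told how long to wait.
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn allows_burst_then_rejects() {
        // At most two connections per key in any one-second window.
        let limiter = RateLimiter::new(Duration::from_secs(1), 2);

        // The first `burst` checks pass and are recorded.
        assert!(limiter.check_key("client-a").is_ok());
        assert!(limiter.check_key("client-a").is_ok());

        // A third check within the same period is rejected, with the remaining wait.
        let wait = limiter.check_key("client-a").unwrap_err();
        assert!(wait <= Duration::from_secs(1));

        // Keys are tracked independently of one another.
        assert!(limiter.check_key("client-b").is_ok());
    }
}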