107 lines
3.9 KiB
Rust
107 lines
3.9 KiB
Rust
use dashmap::DashMap;
|
|
|
|
use std::{fmt::Display, collections::VecDeque, hash::Hash, time::{Duration, Instant}};
|
|
|
|
/// A simple struct to manage rate limiting.
|
|
///
|
|
/// Does not require a leaky bucket thread to empty it out, but may occassionally need to
|
|
/// trim old keys using [`trim_keys()`].
|
|
///
|
|
/// [`trim_keys()`]: Self::trim_keys()
|
|
pub struct RateLimiter<K: Eq + Hash> {
|
|
log: DashMap<K, VecDeque<Instant>>,
|
|
burst: usize,
|
|
period: Duration,
|
|
}
|
|
|
|
impl<K: Eq + Hash> RateLimiter<K> {
|
|
/// Create a new ratelimiter that allows at most `burst` connections in `period`
|
|
pub fn new(period: Duration, burst: usize) -> Self {
|
|
Self {
|
|
log: DashMap::with_capacity(8),
|
|
period,
|
|
burst,
|
|
}
|
|
}
|
|
|
|
/// Check if a key may pass
|
|
///
|
|
/// If the key has made less than `self.burst` connections in the last `self.period`,
|
|
/// then the key is allowed to connect, which is denoted by an `Ok` result. This will
|
|
/// register as a new connection from that key.
|
|
///
|
|
/// If the key is not allowed to connect, than a [`Duration`] denoting the amount of
|
|
/// time until the key is permitted is returned, wrapped in an `Err`
|
|
pub fn check_key(&self, key: K) -> Result<(), Duration> {
|
|
let now = Instant::now();
|
|
let count_after = now - self.period;
|
|
|
|
let mut connections = self.log.entry(key)
|
|
.or_insert_with(||VecDeque::with_capacity(self.burst));
|
|
let connections = connections.value_mut();
|
|
|
|
// Chcek if space can be made available. We don't need to trim all expired
|
|
// connections, just the one in question to allow this connection.
|
|
if let Some(earliest_conn) = connections.front() {
|
|
if earliest_conn < &count_after {
|
|
connections.pop_front();
|
|
}
|
|
}
|
|
|
|
// Check if the connection should be allowed
|
|
if connections.len() == self.burst {
|
|
Err(connections[0] + self.period - now)
|
|
} else {
|
|
connections.push_back(now);
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
/// Remove any expired keys from the ratelimiter
|
|
///
|
|
/// This only needs to be called if keys are continuously being added. If keys are
|
|
/// being reused, or come from a finite set, then you don't need to worry about this.
|
|
///
|
|
/// If you have many keys coming from a large set, you should infrequently call this
|
|
/// to prevent a memory leak.
|
|
///
|
|
/// If debug level logging is enabled, this prints an *approximate* number of keys
|
|
/// removed to the log. For more precise output, use [`trim_keys_verbose()`]
|
|
///
|
|
/// [`trim_keys_verbose()`]: RateLimiter::trim_keys_verbose()
|
|
pub fn trim_keys(&self) {
|
|
let count_after = Instant::now() - self.period;
|
|
|
|
let len: isize = self.log.len() as isize;
|
|
self.log.retain(|_, conns| conns.back().unwrap() > &count_after);
|
|
let removed = len - self.log.len() as isize;
|
|
if removed.is_positive() {
|
|
debug!("Pruned approximately {} expired ratelimit keys", removed);
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<K: Eq + Hash + Display> RateLimiter<K> {
|
|
|
|
/// Remove any expired keys from the ratelimiter
|
|
///
|
|
/// This only needs to be called if keys are continuously being added. If keys are
|
|
/// being reused, or come from a finite set, then you don't need to worry about this.
|
|
///
|
|
/// If you have many keys coming from a large set, you should infrequently call this
|
|
/// to prevent a memory leak.
|
|
///
|
|
/// If debug level logging is on, this prints out any removed keys.
|
|
pub fn trim_keys_verbose(&self) {
|
|
let count_after = Instant::now() - self.period;
|
|
|
|
self.log.retain(|ip, conns| {
|
|
let should_keep = conns.back().unwrap() > &count_after;
|
|
if !should_keep {
|
|
debug!("Pruned expired ratelimit key: {}", ip);
|
|
}
|
|
should_keep
|
|
});
|
|
}
|
|
}
|