73 lines
2.5 KiB
Rust
73 lines
2.5 KiB
Rust
|
use dashmap::DashMap;
|
||
|
|
||
|
use std::{hash::Hash, collections::VecDeque, time::{Duration, Instant}};
|
||
|
|
||
|
/// A simple struct to manage rate limiting.
|
||
|
///
|
||
|
/// Does not require a leaky bucket thread to empty it out, but may occassionally need to
|
||
|
/// trim old keys using [`trim_keys()`].
|
||
|
///
|
||
|
/// [`trim_keys()`][Self::trim_keys()]
|
||
|
pub struct RateLimiter<K: Eq + Hash> {
|
||
|
log: DashMap<K, VecDeque<Instant>>,
|
||
|
burst: usize,
|
||
|
period: Duration,
|
||
|
}
|
||
|
|
||
|
impl<K: Eq + Hash> RateLimiter<K> {
|
||
|
/// Create a new ratelimiter that allows at most `burst` connections in `period`
|
||
|
pub fn new(period: Duration, burst: usize) -> Self {
|
||
|
Self {
|
||
|
log: DashMap::with_capacity(8),
|
||
|
period,
|
||
|
burst,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
/// Check if a key may pass
|
||
|
///
|
||
|
/// If the key has made less than `self.burst` connections in the last `self.period`,
|
||
|
/// then the key is allowed to connect, which is denoted by an `Ok` result. This will
|
||
|
/// register as a new connection from that key.
|
||
|
///
|
||
|
/// If the key is not allowed to connect, than a [`Duration`] denoting the amount of
|
||
|
/// time until the key is permitted is returned, wrapped in an `Err`
|
||
|
pub fn check_key(&self, key: K) -> Result<(), Duration> {
|
||
|
let now = Instant::now();
|
||
|
let count_after = now - self.period;
|
||
|
|
||
|
let mut connections = self.log.entry(key)
|
||
|
.or_insert_with(||VecDeque::with_capacity(self.burst));
|
||
|
let connections = connections.value_mut();
|
||
|
|
||
|
// Chcek if space can be made available. We don't need to trim all expired
|
||
|
// connections, just the one in question to allow this connection.
|
||
|
if let Some(earliest_conn) = connections.front() {
|
||
|
if earliest_conn < &count_after {
|
||
|
connections.pop_front();
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Check if the connection should be allowed
|
||
|
if connections.len() == self.burst {
|
||
|
Err(connections[0] + self.period - now)
|
||
|
} else {
|
||
|
connections.push_back(now);
|
||
|
Ok(())
|
||
|
}
|
||
|
}
|
||
|
|
||
|
/// Remove any expired keys from the ratelimiter
|
||
|
///
|
||
|
/// This only needs to be called if keys are continuously being added. If keys are
|
||
|
/// being reused, or come from a finite set, then you don't need to worry about this.
|
||
|
///
|
||
|
/// If you have many keys coming from a large set, you should infrequently call this
|
||
|
/// to prevent a memory leak.
|
||
|
pub fn trim_keys(&self) {
|
||
|
let count_after = Instant::now() - self.period;
|
||
|
|
||
|
self.log.retain(|_, conns| conns.back().unwrap() > &count_after);
|
||
|
}
|
||
|
}
|