use dashmap::DashMap; use std::{fmt::Display, collections::VecDeque, hash::Hash, time::{Duration, Instant}}; /// A simple struct to manage rate limiting. /// /// Does not require a leaky bucket thread to empty it out, but may occassionally need to /// trim old keys using [`trim_keys()`]. /// /// [`trim_keys()`]: Self::trim_keys() pub struct RateLimiter { log: DashMap>, burst: usize, period: Duration, } impl RateLimiter { /// Create a new ratelimiter that allows at most `burst` connections in `period` pub fn new(period: Duration, burst: usize) -> Self { Self { log: DashMap::with_capacity(8), period, burst, } } /// Check if a key may pass /// /// If the key has made less than `self.burst` connections in the last `self.period`, /// then the key is allowed to connect, which is denoted by an `Ok` result. This will /// register as a new connection from that key. /// /// If the key is not allowed to connect, than a [`Duration`] denoting the amount of /// time until the key is permitted is returned, wrapped in an `Err` pub fn check_key(&self, key: K) -> Result<(), Duration> { let now = Instant::now(); let count_after = now - self.period; let mut connections = self.log.entry(key) .or_insert_with(||VecDeque::with_capacity(self.burst)); let connections = connections.value_mut(); // Chcek if space can be made available. We don't need to trim all expired // connections, just the one in question to allow this connection. if let Some(earliest_conn) = connections.front() { if earliest_conn < &count_after { connections.pop_front(); } } // Check if the connection should be allowed if connections.len() == self.burst { Err(connections[0] + self.period - now) } else { connections.push_back(now); Ok(()) } } /// Remove any expired keys from the ratelimiter /// /// This only needs to be called if keys are continuously being added. If keys are /// being reused, or come from a finite set, then you don't need to worry about this. /// /// If you have many keys coming from a large set, you should infrequently call this /// to prevent a memory leak. /// /// If debug level logging is enabled, this prints an *approximate* number of keys /// removed to the log. For more precise output, use [`trim_keys_verbose()`] /// /// [`trim_keys_verbose()`]: RateLimiter::trim_keys_verbose() pub fn trim_keys(&self) { let count_after = Instant::now() - self.period; let len: isize = self.log.len() as isize; self.log.retain(|_, conns| conns.back().unwrap() > &count_after); let removed = len - self.log.len() as isize; if removed.is_positive() { debug!("Pruned approximately {} expired ratelimit keys", removed); } } } impl RateLimiter { /// Remove any expired keys from the ratelimiter /// /// This only needs to be called if keys are continuously being added. If keys are /// being reused, or come from a finite set, then you don't need to worry about this. /// /// If you have many keys coming from a large set, you should infrequently call this /// to prevent a memory leak. /// /// If debug level logging is on, this prints out any removed keys. pub fn trim_keys_verbose(&self) { let count_after = Instant::now() - self.period; self.log.retain(|ip, conns| { let should_keep = conns.back().unwrap() > &count_after; if !should_keep { debug!("Pruned expired ratelimit key: {}", ip); } should_keep }); } }