Some time ago I used Actix Web to create a back-end web service that hashes remote images. At the time I was only slightly disappointed that this web framework lacked a built-in ability to rate-limit server endpoints, but I also knew it would make a perfect future project. This page is devoted to creating Actix Web middleware that can be used to "wrap" my server's endpoints in order to limit the frequency at which they receive HTTP GET requests. The requirements were very simple: create middleware that enforces limits on how many requests are forwarded to server endpoints within a specific time period. Requests exceeding these limits are not forwarded to the endpoint; instead, the middleware short-circuits the request and sends an HTTP 429 "Too Many Requests" response.
The Limiter struct is the core logical component of the rate-limiting middleware. It contains members used to track the IP addresses that have made requests to server endpoints, the time at which each address's last request occurred, a count of requests made, a duration, and an integer value representing the maximum number of requests allowed during that duration.
A LimiterBuilder is used to construct an Arc to a Limiter guarded by a Mutex. An Arc (an atomically reference-counted smart pointer) accounts for all references to whatever it wraps: if and when the resource is cloned, the reference count is incremented and the new handle points to the original resource location on the heap. A Mutex is used to ensure that the Limiter (or whatever resource it guards) can be accessed safely across multiple threads. Thread safety is a large topic in its own right; the sketch below only illustrates the pattern used here.
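This isn't part of the middleware; it's a minimal, self-contained sketch of the Arc/Mutex pattern described above, with a plain counter standing in for the shared Limiter state. Cloning the Arc hands each thread a reference to the same heap allocation, and locking the Mutex serializes access to the shared value.

use std::sync::{Arc, Mutex};
use std::thread;

fn main() {
    // Shared counter standing in for the shared Limiter state.
    let shared = Arc::new(Mutex::new(0u32));

    let handles: Vec<_> = (0..4)
        .map(|_| {
            // Cloning the Arc increments the reference count;
            // every clone points at the same heap allocation.
            let shared = Arc::clone(&shared);
            thread::spawn(move || {
                // Locking the Mutex gives exclusive access to the value.
                let mut guard = shared.lock().unwrap();
                *guard += 1;
            })
        })
        .collect();

    for handle in handles {
        handle.join().unwrap();
    }

    println!("final count: {}", *shared.lock().unwrap());
}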
The RateLimiter struct contains only a Limiter instance and exists to implement the required middleware traits (described in the next section); it is paired with a RateLimiterMiddleware struct:
pub struct RateLimiter {
    pub(crate) limiter: Arc<Mutex<Limiter>>,
}

impl RateLimiter {
    pub fn new(limiter: Arc<Mutex<Limiter>>) -> Self {
        Self { limiter }
    }
}

pub struct RateLimiterMiddleware<S> {
    pub(crate) service: Arc<S>,
    pub(crate) limiter: Arc<Mutex<Limiter>>,
}
use std::collections::HashMap;
use std::future::Future;
use std::net::IpAddr;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};

use actix_service::{Service, Transform};
use actix_web::body::{BoxBody, EitherBody, MessageBody};
use actix_web::dev::{ServiceRequest, ServiceResponse};
use actix_web::{Error, HttpResponse};
use chrono::{DateTime, Duration, Utc};
use futures::future::{ok, Ready};
#[derive(Clone)]
pub struct TimeCount {
    last_request: DateTime<Utc>,
    num_requests: usize,
}

pub struct Limiter {
    pub ip_addresses: HashMap<IpAddr, TimeCount>,
    pub duration: Duration,
    pub num_requests: usize,
}

pub struct LimiterBuilder {
    duration: Duration,
    num_requests: usize,
}
impl LimiterBuilder {
    pub fn new() -> Self {
        Self {
            duration: Duration::days(1),
            num_requests: 1,
        }
    }

    pub fn with_duration(mut self, duration: Duration) -> Self {
        self.duration = duration;
        self
    }

    pub fn with_num_requests(mut self, num_requests: usize) -> Self {
        self.num_requests = num_requests;
        self
    }

    pub fn build(self) -> Arc<Mutex<Limiter>> {
        let ip_addresses = HashMap::new();
        Arc::new(Mutex::new(Limiter {
            ip_addresses,
            duration: self.duration,
            num_requests: self.num_requests,
        }))
    }
}
impl<S, B> Transform<S, ServiceRequest> for RateLimiter
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
    S::Future: 'static,
    B: 'static + MessageBody,
{
    type Response = ServiceResponse<EitherBody<B, BoxBody>>;
    type Error = Error;
    type Transform = RateLimiterMiddleware<S>;
    type InitError = ();
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: S) -> Self::Future {
        ok(RateLimiterMiddleware {
            service: Arc::new(service),
            limiter: self.limiter.clone(),
        })
    }
}
impl<S, B> Service<ServiceRequest> for RateLimiterMiddleware<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
    S::Future: 'static,
    B: 'static + MessageBody,
{
    type Response = ServiceResponse<EitherBody<B, BoxBody>>;
    type Error = Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;

    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(cx)
    }

    fn call(&self, req: ServiceRequest) -> Self::Future {
        let limiter = Arc::clone(&self.limiter);
        let service = Arc::clone(&self.service);
        Box::pin(handle_rate_limiting(req, limiter, service))
    }
}
The handle_rate_limiting function is described in the next section. For each request it looks up the requester's IP address and its stored last_request_time and num_requests values. If the request would exceed the limits configured in the Limiter struct, an HttpResponse is returned to the client with HTTP status code 429 and additional headers that provide information to the client about the rate limits. Otherwise, the Limiter struct is updated to reflect the latest request and the request is forwarded to the wrapped service.
pub async fn handle_rate_limiting<S, B>(
    req: ServiceRequest,
    limiter: Arc<Mutex<Limiter>>,
    service: Arc<S>,
) -> Result<ServiceResponse<EitherBody<B>>, Error>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
    B: 'static + MessageBody,
{
    let ip = match req.peer_addr() {
        Some(addr) => addr.ip(),
        None => {
            // peer_addr only returns None during unit tests:
            // https://docs.rs/actix-web/latest/actix_web/struct.HttpRequest.html#method.peer_addr
            let res: ServiceResponse<B> = service.call(req).await?;
            return Ok(res.map_into_left_body());
        }
    };
    let now = Utc::now();
    let mut limiter = limiter.lock().unwrap();

    // Fetch (or initialize) the tracking entry for this IP address.
    let time_count = {
        limiter.ip_addresses.entry(ip).or_insert(TimeCount {
            last_request: now,
            num_requests: 0,
        }).clone()
    };
    let last_request_time = time_count.last_request;
    let request_count = time_count.num_requests;
    println!("IP: {} - Last Request Time: {}, Request Count: {}", ip, last_request_time, request_count);

    let mut too_many_requests = false;
    let mut new_last_request_time = last_request_time;
    let mut new_request_count = request_count;
    let limiter_duration_secs = limiter.duration.num_seconds();

    if now - last_request_time <= Duration::seconds(limiter_duration_secs) {
        if request_count >= limiter.num_requests {
            too_many_requests = true;
        } else {
            new_request_count += 1;
            println!("IP: {} - Incremented Request Count: {}", ip, new_request_count);
        }
    } else {
        // Reset time and count
        new_last_request_time = now;
        new_request_count = 1;
        println!("IP: {} - Reset Request Count and Time", ip);
    }

    let entry = limiter.ip_addresses.entry(ip).or_insert(TimeCount { last_request: now, num_requests: 0 });
    entry.last_request = new_last_request_time;
    entry.num_requests = new_request_count;

    if too_many_requests {
        println!("IP: {} - Too Many Requests", ip);
        // Time remaining in the current window, reported in the X-RateLimit-Reset header.
        let remaining_time = limiter.duration - (now - last_request_time);
        let message = format!("Too many requests. Please try again in {} seconds.", limiter_duration_secs);
        let too_many_requests_response = HttpResponse::TooManyRequests()
            .content_type("text/plain")
            .insert_header(("Retry-After", limiter_duration_secs.to_string()))
            .insert_header(("X-RateLimit-Limit", limiter.num_requests.to_string()))
            .insert_header(("X-RateLimit-Remaining", (limiter.num_requests - new_request_count).to_string()))
            .insert_header(("X-RateLimit-Reset", remaining_time.num_seconds().to_string()))
            .body(message);
        return Ok(ServiceResponse::new(req.request().clone(), too_many_requests_response)
            .map_into_boxed_body()
            .map_into_right_body());
    }

    let res: ServiceResponse<B> = service.call(req).await?;
    Ok(res.map_into_left_body())
}
One of my favorite features in Rust is how easy it is to write and run tests. In order to keep this page succinct, only one test case will be examined below, but I encourage the reader to view all test cases if they are interested.
The #[actix_rt::test] attribute is a macro defined in the Actix runtime that abstracts away the lower-level details of setting up a mock environment in which to run the test. The test below builds a Limiter and initializes the middleware service before issuing requests from different IP addresses:
#[actix_rt::test]
async fn test_rate_limiter_with_different_ips() {
    let limiter = LimiterBuilder::new()
        .with_duration(Duration::seconds(20))
        .with_num_requests(2)
        .build();

    let service = init_service(
        App::new()
            .wrap(RateLimiter::new(Arc::clone(&limiter)))
            .route("/", web::get().to(HttpResponse::Ok)),
    )
    .await;

    // First request from IP 127.0.0.1
    let req = TestRequest::default()
        .peer_addr(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345))
        .to_request();
    let resp = call_service(&service, req).await;
    assert_eq!(resp.status(), StatusCode::OK);

    // First request from IP 127.0.0.2
    let req = TestRequest::default()
        .peer_addr(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)), 12345))
        .to_request();
    let resp = call_service(&service, req).await;
    assert_eq!(resp.status(), StatusCode::OK);

    // Second request from IP 127.0.0.1
    let req = TestRequest::default()
        .peer_addr(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345))
        .to_request();
    let resp = call_service(&service, req).await;
    assert_eq!(resp.status(), StatusCode::OK);

    // Third request from IP 127.0.0.1 within 20 seconds should be rate-limited
    let req = TestRequest::default()
        .peer_addr(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345))
        .to_request();
    let resp = call_service(&service, req).await;
    assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
    assert_eq!(resp.headers().get("Retry-After").unwrap(), "20");
    assert_eq!(resp.headers().get("X-RateLimit-Limit").unwrap(), "2");
    assert_eq!(resp.headers().get("X-RateLimit-Remaining").unwrap(), "0");
    // Check if X-RateLimit-Reset header exists
    assert!(resp.headers().get("X-RateLimit-Reset").is_some());
}
Similar to what was shown in the testing environment, we can use the new middleware by wrapping our server endpoints:
let limiter = LimiterBuilder::new()
    .with_duration(Duration::seconds(20))
    .with_num_requests(2)
    .build();

HttpServer::new(move || {
    App::new()
        .app_data(web::Data::new(app_state.clone()))
        .wrap(Logger::default())
        .wrap(RateLimiter::new(Arc::clone(&limiter)))
        .service(en_image_hash)
})
And now, when we make a series of requests to the server running locally, we can verify the rate limit.
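One way to exercise the limit is with a small client loop. This is only a rough sketch: the address, port, and route here are assumptions rather than the actual paths exposed by the service above, and it uses the reqwest crate (with its "blocking" feature) purely for brevity.

// Hypothetical verification client; URL and route are assumptions.
fn main() -> Result<(), reqwest::Error> {
    for i in 1..=3 {
        // Repeated GETs from the same source IP should eventually trigger a 429.
        let resp = reqwest::blocking::get("http://127.0.0.1:8080/")?;
        println!("request {} -> {}", i, resp.status());
        if let Some(retry_after) = resp.headers().get("Retry-After") {
            println!("  Retry-After: {:?}", retry_after);
        }
    }
    Ok(())
}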
The log crate provides a lightweight logging facade that can be used to add logging support to any library or binary crate. I've used it to add info and warn messages to my middleware. Consumers of my crate have the flexibility to tap into these logs if they choose, and if they choose not to use them at all the overhead is very low. The code below shows how I integrated the log crate:
use log::{info, warn};

...

    let ip = match req.peer_addr() {
        Some(addr) => addr.ip(),
        None => {
            // peer_addr only returns None during unit tests:
            // https://docs.rs/actix-web/latest/actix_web/struct.HttpRequest.html#method.peer_addr
            warn!("Requester socket address was found to be None type and will not be rate limited");
            let res: ServiceResponse<B> = service.call(req).await?;
            return Ok(res.map_into_left_body());
        }
    };

...

    if now - last_request_time <= Duration::seconds(limiter_duration_secs) {
        if request_count >= limiter.num_requests {
            too_many_requests = true;
        } else {
            new_request_count += 1;
            info!("Incremented request count for {} to {} requests in current duration", ip, new_request_count);
        }
    } else {
        new_last_request_time = now;
        new_request_count = 1;
        info!("Reset duration and request count for {}", ip);
    }

...

    if too_many_requests {
        info!("Sending 429 response to {}", ip);
        let remaining_time = limiter.duration - (now - last_request_time);
        let message = format!("Too many requests. Please try again in {} seconds.", remaining_time.num_seconds());
        let too_many_requests_response = HttpResponse::TooManyRequests()
            .content_type("text/plain")
            .insert_header(("Retry-After", limiter_duration_secs.to_string()))
            .insert_header(("X-RateLimit-Limit", limiter.num_requests.to_string()))
            .insert_header(("X-RateLimit-Remaining", (limiter.num_requests - new_request_count).to_string()))
            .insert_header(("X-RateLimit-Reset", remaining_time.num_seconds().to_string()))
            .body(message);
        return Ok(ServiceResponse::new(req.request().clone(), too_many_requests_response)
            .map_into_boxed_body()
            .map_into_right_body());
    }

    info!("Forwarding request from {} to {}", ip, req.path());
    let res: ServiceResponse<B> = service.call(req).await?;
    Ok(res.map_into_left_body())
}
In order to tap into the logs from a consumer crate it is necessary to use the env_logger crate in addition to the log crate, and then configure env_logger with a LevelFilter:
use log::LevelFilter;
use hash_svc::server;

fn main() {
    env_logger::builder().filter_level(LevelFilter::Info).init();

    match server::run() {
        Ok(_) => println!("Server ran successfully"),
        Err(e) => eprintln!("Server error: {:?}", e),
    }
}
Now, when running the main() program shown above, the following logs are sent to stdout: