Creating rate-limiting middleware for Actix Web

Some time ago I used Actix Web to create a back-end web service to hash remote images. At the time, I was only slightly disappointed that the framework lacked built-in support for rate-limiting server endpoints, but I also knew that this would make a perfect future project. This page is devoted to creating Actix Web middleware that can "wrap" my server's endpoints in order to limit how often they receive HTTP GET requests. The requirements were very simple: create middleware that enforces a limit on how many requests are forwarded to server endpoints within a specific time period. Requests that exceed this limit are not forwarded to the endpoint; instead, the middleware short-circuits the request and sends an HTTP 429 "Too Many Requests" response.



  • This package uses the standard library along with third-party crates, most notably Actix Web and chrono.
  • The Limiter struct is the core logical component of the rate-limiting middleware. It maps each client IP address to a TimeCount, which records the time of the last request and the number of requests received, and it stores the duration of the rate-limiting window along with an integer giving the maximum number of requests allowed within that duration.
  • A LimiterBuilder is used to construct a Limiter guarded by a Mutex and wrapped in an Arc (Arc<Mutex<Limiter>>). An Arc (atomically reference-counted pointer) lets several owners share one heap allocation: cloning an Arc does not copy the Limiter, it increments the reference count and returns a new pointer to the same value, which is freed only when the last Arc is dropped. The Mutex ensures that only one thread at a time can access the Limiter (or whatever resource it guards), so it can be mutated safely from multiple threads. Thread safety is a large topic, but this pairing is a common way to share mutable state between threads in Rust.
  • I've created a dedicated RateLimiter struct that holds only the shared Arc<Mutex<Limiter>>, along with a RateLimiterMiddleware struct that additionally holds the wrapped service. These two types implement the traits required by Actix Web middleware (described in the next section).
pub struct RateLimiter {
    pub(crate) limiter: Arc<Mutex<Limiter>>,
}

impl RateLimiter {
    pub fn new(limiter: Arc<Mutex<Limiter>>) -> Self {
        Self { limiter }
    }
}

pub struct RateLimiterMiddleware<S> {
    pub(crate) service: Arc<S>,
    pub(crate) limiter: Arc<Mutex<Limiter>>,
}
                                                
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::net::IpAddr;
use std::sync::{Arc, Mutex};
use chrono::{DateTime, Duration, Utc};
use actix_service::{Service, Transform};
use actix_web::dev::{ServiceRequest, ServiceResponse};
use actix_web::{Error, HttpResponse};
use futures::future::{ok, Ready};
use std::task::{Context, Poll};
use actix_web::body::{BoxBody, EitherBody, MessageBody};

#[derive(Clone)]
pub struct TimeCount {
    last_request: DateTime<Utc>,
    num_requests: usize,
}

pub struct Limiter {
    pub ip_addresses: HashMap<IpAddr, TimeCount>,
    pub duration: Duration,
    pub num_requests: usize,
}

pub struct LimiterBuilder {
    duration: Duration,
    num_requests: usize,
}

impl LimiterBuilder {
    pub fn new() -> Self {
        Self {
            duration: Duration::days(1),
            num_requests: 1,
        }
    }

    pub fn with_duration(mut self, duration: Duration) -> Self {
        self.duration = duration;
        self
    }

    pub fn with_num_requests(mut self, num_requests: usize) -> Self {
        self.num_requests = num_requests;
        self
    }

    pub fn build(self) -> Arc<Mutex<Limiter>> {
        let ip_addresses = HashMap::new();

        Arc::new(Mutex::new(Limiter {
            ip_addresses,
            duration: self.duration,
            num_requests: self.num_requests,
        }))
    }
}                                          
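To make the Arc and Mutex behavior described above concrete, here is a small std-only sketch. It is not part of the crate: a HashMap<String, usize> stands in for the Limiter, and the run_demo function and thread count are made up for illustration. Each thread receives its own Arc clone, updates the shared map under the lock, and once all threads have been joined only the original reference remains.

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread;

// Simplified shared state: request counts keyed by IP string (a stand-in for Limiter).
type SharedCounts = Arc<Mutex<HashMap<String, usize>>>;

// Hypothetical demo: spawns `threads` workers that each record one request,
// alternating between two IPs. Returns the final strong count and a snapshot of the map.
fn run_demo(threads: usize) -> (usize, HashMap<String, usize>) {
    let shared: SharedCounts = Arc::new(Mutex::new(HashMap::new()));

    let handles: Vec<_> = (0..threads)
        .map(|i| {
            // Arc::clone copies the pointer and bumps the reference count;
            // the HashMap itself is not copied.
            let shared = Arc::clone(&shared);
            thread::spawn(move || {
                let ip = if i % 2 == 0 { "127.0.0.1" } else { "127.0.0.2" };
                // lock() blocks until no other thread holds the guard.
                let mut map = shared.lock().unwrap();
                *map.entry(ip.to_string()).or_insert(0) += 1;
            }) // the guard and this thread's Arc are dropped when the closure returns
        })
        .collect();

    for handle in handles {
        handle.join().unwrap();
    }

    let strong = Arc::strong_count(&shared);
    let snapshot = shared.lock().unwrap().clone();
    (strong, snapshot)
}

fn main() {
    let (strong, counts) = run_demo(4);
    println!("strong count after join: {}", strong);
    println!("127.0.0.1: {}, 127.0.0.2: {}", counts["127.0.0.1"], counts["127.0.0.2"]);
}
```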
                        


Implementing the required Transform and Service traits

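See the official documentation for the Transform trait. RateLimiter implements Transform and acts as a factory: new_transform wraps the next service in an Arc and returns a RateLimiterMiddleware that shares the same Arc<Mutex<Limiter>>: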
impl<S, B> Transform<S, ServiceRequest> for RateLimiter
where
    S: Service<ServiceRequest, Response=ServiceResponse<B>, Error=Error> + 'static,
    S::Future: 'static,
    B: 'static + MessageBody,
{
    type Response = ServiceResponse<EitherBody<B, BoxBody>>;
    type Error = Error;
    type Transform = RateLimiterMiddleware<S>;
    type InitError = ();
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: S) -> Self::Future {
        ok(RateLimiterMiddleware {
            service: Arc::new(service),
            limiter: self.limiter.clone(),
        })
    }
}
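See also the official documentation for the Service trait. RateLimiterMiddleware implements Service: poll_ready delegates readiness to the wrapped service, and call clones both Arcs so that the boxed future returned by handle_rate_limiting can own them: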
impl<S, B> Service<ServiceRequest> for RateLimiterMiddleware<S>
where
    S: Service<ServiceRequest, Response=ServiceResponse<B>, Error=Error> + 'static,
    S::Future: 'static,
    B: 'static + MessageBody,
{
    type Response = ServiceResponse<EitherBody<B, BoxBody>>;
    type Error = Error;
    type Future = Pin<Box<dyn Future<Output=Result<Self::Response, Self::Error>>>>;

    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(cx)
    }

    fn call(&self, req: ServiceRequest) -> Self::Future {
        let limiter = Arc::clone(&self.limiter);
        let service = Arc::clone(&self.service);

        Box::pin(handle_rate_limiting(req, limiter, service))
    }
}
                            


  • The handle_rate_limiting function is described in the next section.
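The wrapping pattern itself does not depend on Actix. The following is a simplified, synchronous toy model, assuming a hypothetical Handler trait in place of Actix's Service and a hypothetical new_transform function in place of Transform::new_transform. It keeps only the parts relevant here: each middleware instance owns the service it wraps, all instances share one Arc<Mutex<...>>, and a rejected request never reaches the inner service. Time windows are deliberately left out, so counts never reset.

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

// Hypothetical, synchronous stand-in for Actix's Service trait:
// takes a client IP and returns an HTTP status code.
trait Handler {
    fn call(&self, ip: &str) -> u16;
}

// The wrapped endpoint; it always answers 200 OK.
struct Endpoint;

impl Handler for Endpoint {
    fn call(&self, _ip: &str) -> u16 {
        200
    }
}

type Counts = Arc<Mutex<HashMap<String, usize>>>;

// Plays the role of RateLimiterMiddleware<S>: the wrapped service plus shared state.
struct LimitMiddleware<S> {
    inner: S,
    max: usize,
    counts: Counts,
}

impl<S: Handler> Handler for LimitMiddleware<S> {
    fn call(&self, ip: &str) -> u16 {
        {
            let mut counts = self.counts.lock().unwrap();
            let n = counts.entry(ip.to_string()).or_insert(0);
            if *n >= self.max {
                // Short-circuit: the inner service never sees this request.
                return 429;
            }
            *n += 1;
        } // lock released here, before delegating
        self.inner.call(ip)
    }
}

// Plays the role of Transform::new_transform: each call builds a middleware
// instance around a service, but all instances share the same counters.
fn new_transform<S: Handler>(inner: S, max: usize, counts: Counts) -> LimitMiddleware<S> {
    LimitMiddleware { inner, max, counts }
}

fn main() {
    let counts: Counts = Arc::new(Mutex::new(HashMap::new()));
    // Two instances, as if built for two different worker threads.
    let worker_a = new_transform(Endpoint, 2, Arc::clone(&counts));
    let worker_b = new_transform(Endpoint, 2, Arc::clone(&counts));
    for (worker, ip) in [(&worker_a, "127.0.0.1"), (&worker_b, "127.0.0.1"), (&worker_a, "127.0.0.1")] {
        println!("{} -> {}", ip, worker.call(ip));
    }
}
```

Sharing the counters matters because HttpServer calls its app factory closure once per worker thread, so Actix builds a separate middleware instance for each worker. Creating the limiter outside that closure is what lets every worker enforce a single limit.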

Inspecting the core rate-limiting logic



  • First, the client IP address is extracted from the peer socket address of the incoming ServiceRequest.
  • The HashMap entry for that IP address provides the timestamp of the last request and the number of requests received within the current duration. If no entry exists for this IP address, one is created with the current time and a request count of zero (the count is incremented later).
  • Mutable copies (new_last_request_time and new_request_count) are declared so that the following block can compute the updated values before they are written back to the HashMap.
  • The time of the current request is compared against the stored last_request time. Because last_request is only overwritten when the duration has elapsed and the counter resets, it effectively marks the start of a fixed window. If the current request falls inside that window and the request count has already reached the limit defined in the Limiter struct, an HttpResponse is returned to the client with HTTP status code 429 and additional headers that describe the rate limit.
  • Information in the Limiter struct is updated to reflect the latest request.
pub async fn handle_rate_limiting<S, B>(
    req: ServiceRequest,
    limiter: Arc<Mutex<Limiter>>,
    service: Arc<S>,
) -> Result<ServiceResponse<EitherBody<B>>, Error>
where
    S: Service<ServiceRequest, Response=ServiceResponse<B>, Error=Error>,
    S::Future: 'static,
    B: 'static + MessageBody,
{
    let ip = match req.peer_addr() {
        Some(addr) => addr.ip(),
        None => {
            // peer_addr only returns None in unit tests: https://docs.rs/actix-web/latest/actix_web/struct.HttpRequest.html#method.peer_addr
            let res: ServiceResponse<B> = service.call(req).await?;
            return Ok(res.map_into_left_body());
        }
    };

    let now = Utc::now();
    let mut limiter = limiter.lock().unwrap();

    // Read this IP's window state, creating it on first sight (IpAddr is Copy, so no clone is needed)
    let time_count = limiter
        .ip_addresses
        .entry(ip)
        .or_insert(TimeCount {
            last_request: now,
            num_requests: 0,
        })
        .clone();

    let last_request_time = time_count.last_request;
    let request_count = time_count.num_requests;

    println!("IP: {} - Last Request Time: {}, Request Count: {}", ip, last_request_time, request_count);

    let mut too_many_requests = false;
    let mut new_last_request_time = last_request_time;
    let mut new_request_count = request_count;

    let limiter_duration_secs = limiter.duration.num_seconds();

    // last_request is only reset when the window expires, so it marks the start of the window
    if now - last_request_time <= Duration::seconds(limiter_duration_secs) {
        if request_count >= limiter.num_requests {
            too_many_requests = true;
        } else {
            new_request_count += 1;
            println!("IP: {} - Incremented Request Count: {}", ip, new_request_count);
        }
    } else {
        // Window expired: reset time and count (this request is the first of the new window)
        new_last_request_time = now;
        new_request_count = 1;
        println!("IP: {} - Reset Request Count and Time", ip);
    }

    // Write the updated state back to the map
    let entry = limiter.ip_addresses.entry(ip).or_insert(TimeCount { last_request: now, num_requests: 0 });
    entry.last_request = new_last_request_time;
    entry.num_requests = new_request_count;

    if too_many_requests {
        println!("IP: {} - Too Many Requests", ip);
        // Time left until the current window expires
        let remaining_time = limiter.duration - (now - last_request_time);
        let message = format!("Too many requests. Please try again in {} seconds.", limiter_duration_secs);

        let too_many_requests_response = HttpResponse::TooManyRequests()
            .content_type("text/plain")
            .insert_header(("Retry-After", limiter_duration_secs.to_string()))
            .insert_header(("X-RateLimit-Limit", limiter.num_requests.to_string()))
            .insert_header(("X-RateLimit-Remaining", limiter.num_requests.saturating_sub(new_request_count).to_string()))
            .insert_header(("X-RateLimit-Reset", remaining_time.num_seconds().to_string()))
            .body(message);
        return Ok(ServiceResponse::new(req.request().clone(), too_many_requests_response)
            .map_into_boxed_body()
            .map_into_right_body());
    }

    // Release the lock before awaiting the inner service; holding a std::sync::Mutex guard
    // across .await would keep the limiter locked while the endpoint runs, blocking other requests
    drop(limiter);

    let res: ServiceResponse<B> = service.call(req).await?;
    Ok(res.map_into_left_body())
}
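The decision logic in handle_rate_limiting can also be written as a plain function that takes the current time as an argument, which makes the window reset easy to test without waiting for the duration to elapse. The sketch below is an illustration under that assumption, not the crate's code: it uses std::time::Instant instead of chrono, and the FixedWindow, Window, and Decision names are hypothetical. It follows the same fixed-window rules: a window starts at a client's first request, requests are counted until the limit is reached, and the count resets once more than the configured duration has passed.

```rust
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr};
use std::time::{Duration, Instant};

// Per-IP state. `window_start` plays the role of `last_request` in the
// middleware, which is only overwritten when a window is reset.
struct Window {
    window_start: Instant,
    count: usize,
}

#[derive(Debug, PartialEq)]
enum Decision {
    // Forward the request; `remaining` more requests fit in this window.
    Allow { remaining: usize },
    // Answer 429; the window resets after `reset_in`.
    Reject { reset_in: Duration },
}

struct FixedWindow {
    limit: usize,
    window: Duration,
    clients: HashMap<IpAddr, Window>,
}

impl FixedWindow {
    fn new(limit: usize, window: Duration) -> Self {
        Self { limit, window, clients: HashMap::new() }
    }

    // The clock is a parameter so tests can move time forward without sleeping.
    fn check(&mut self, ip: IpAddr, now: Instant) -> Decision {
        let entry = self
            .clients
            .entry(ip)
            .or_insert(Window { window_start: now, count: 0 });
        let elapsed = now.duration_since(entry.window_start);

        if elapsed > self.window {
            // Window expired: start a new one that already counts this request.
            entry.window_start = now;
            entry.count = 1;
        } else if entry.count >= self.limit {
            // Limit reached inside the window: reject without touching the count.
            return Decision::Reject { reset_in: self.window - elapsed };
        } else {
            entry.count += 1;
        }
        Decision::Allow { remaining: self.limit.saturating_sub(entry.count) }
    }
}

fn main() {
    let mut limiter = FixedWindow::new(2, Duration::from_secs(20));
    let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
    let t0 = Instant::now();
    for offset in [0, 1, 5, 21] {
        let decision = limiter.check(ip, t0 + Duration::from_secs(offset));
        println!("t+{}s: {:?}", offset, decision);
    }
}
```

With a limit of 2 and a 20-second window, main() reports two Allow decisions (at t+0s and t+1s), a Reject with 15 seconds left at t+5s, and a fresh window at t+21s. Keeping the clock out of the function is the design choice that makes this reset path testable in milliseconds.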
                        

Writing tests in Rust

One of my favorite features in Rust is how easy it is to write and run tests. To keep this page succinct, only one test case is examined below, but I encourage interested readers to view all of the test cases in the project source code.



  • The #[actix_rt::test] attribute is a macro provided by the Actix runtime crate (actix-rt) that starts an Actix system so that an async test function can be run, sparing the tester the lower-level runtime setup.
  • It is still necessary to build a Limiter and to initialize an App wrapped by the middleware with init_service.
  • Once this is done, a series of requests from different peer addresses is made to check that the rate-limiting logic behaves as expected and that the headers are set correctly.
  • Running the tests only requires cargo test, as shown in the terminal output below.
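use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;

use actix_web::http::StatusCode;
use actix_web::test::{call_service, init_service, TestRequest};
use actix_web::{web, App, HttpResponse};
use chrono::Duration;

// Assumes the tests live in a child module (e.g. mod tests) of the middleware module
use super::{LimiterBuilder, RateLimiter};

#[actix_rt::test]
async fn test_rate_limiter_with_different_ips() {
    let limiter = LimiterBuilder::new()
        .with_duration(Duration::seconds(20))
        .with_num_requests(2)
        .build();

    // Build an in-memory App whose routes are wrapped by the rate limiter
    let service = init_service(
        App::new()
            .wrap(RateLimiter::new(Arc::clone(&limiter)))
            .route("/", web::get().to(HttpResponse::Ok)),
    )
    .await;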

    // First request from IP 127.0.0.1
    let req = TestRequest::default()
        .peer_addr(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345))
        .to_request();
    let resp = call_service(&service, req).await;
    assert_eq!(resp.status(), StatusCode::OK);

    // First request from IP 127.0.0.2
    let req = TestRequest::default()
        .peer_addr(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)), 12345))
        .to_request();
    let resp = call_service(&service, req).await;
    assert_eq!(resp.status(), StatusCode::OK);

    // Second request from IP 127.0.0.1
    let req = TestRequest::default()
        .peer_addr(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345))
        .to_request();
    let resp = call_service(&service, req).await;
    assert_eq!(resp.status(), StatusCode::OK);

    // Third request from IP 127.0.0.1 within 20 seconds should be rate-limited
    let req = TestRequest::default()
        .peer_addr(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345))
        .to_request();
    let resp = call_service(&service, req).await;
    assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
    assert_eq!(resp.headers().get("Retry-After").unwrap(), "20");
    assert_eq!(resp.headers().get("X-RateLimit-Limit").unwrap(), "2");
    assert_eq!(resp.headers().get("X-RateLimit-Remaining").unwrap(), "0");
    // Check if X-RateLimit-Reset header exists
    assert!(resp.headers().get("X-RateLimit-Reset").is_some());
}

                            
terminal output showing that tests passed


Integrating the middleware into a server application

Similar to what was shown in the testing environment, we can use the new middleware by wrapping our server endpoints:

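let limiter = LimiterBuilder::new()
    .with_duration(Duration::seconds(20))
    .with_num_requests(2)
    .build();

// The limiter is built once, outside the app factory closure, so every worker
// shares the same Arc<Mutex<Limiter>>. app_state and en_image_hash are defined
// elsewhere in the hash service.
HttpServer::new(move || {
    App::new()
        .app_data(web::Data::new(app_state.clone()))
        .wrap(Logger::default())
        .wrap(RateLimiter::new(Arc::clone(&limiter)))
        .service(en_image_hash)
})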

Now, when we make a series of requests to that server running locally, we can verify that the rate limit is enforced:



terminal output showing that tests passed

Adding support for logging

The log crate (maintained by the Rust project, but not part of the standard library) provides a lightweight logging facade that can be added to any library or binary crate. I've used it to add info and warn messages to my middleware. Consumers of my crate can tap into these logs if they choose; if no logger is installed, each log statement costs little more than a level check, so the overhead is very low. The code below shows how I integrated the log crate:

use log::{info, warn};

...

let ip = match req.peer_addr() {
    Some(addr) => addr.ip(),
    None => {
        // peer_addr only returns None in unit tests: https://docs.rs/actix-web/latest/actix_web/struct.HttpRequest.html#method.peer_addr
        warn!("Requester socket address was found to be None type and will not be rate limited");
        let res: ServiceResponse<B> = service.call(req).await?;
        return Ok(res.map_into_left_body());
    }
};

...

if now - last_request_time <= Duration::seconds(limiter_duration_secs) {
    if request_count >= limiter.num_requests {
        too_many_requests = true;
    } else {
        new_request_count += 1;
        info!("Incremented request count for {} to {} requests in current duration", ip, new_request_count);
    }
} else {
    new_last_request_time = now;
    new_request_count = 1;
    info!("Reset duration and request count for {}", ip);
}

...

if too_many_requests {
    info!("Sending 429 response to {}", ip);
    // Time left until the current window expires
    let remaining_time = limiter.duration - (now - last_request_time);
    let message = format!("Too many requests. Please try again in {} seconds.", remaining_time.num_seconds());

    let too_many_requests_response = HttpResponse::TooManyRequests()
        .content_type("text/plain")
        .insert_header(("Retry-After", limiter_duration_secs.to_string()))
        .insert_header(("X-RateLimit-Limit", limiter.num_requests.to_string()))
        .insert_header(("X-RateLimit-Remaining", limiter.num_requests.saturating_sub(new_request_count).to_string()))
        .insert_header(("X-RateLimit-Reset", remaining_time.num_seconds().to_string()))
        .body(message);
    return Ok(ServiceResponse::new(req.request().clone(), too_many_requests_response)
        .map_into_boxed_body()
        .map_into_right_body());
}

// Release the limiter lock before awaiting the wrapped service
drop(limiter);

info!("Forwarding request from {} to {}", ip, req.path());
let res: ServiceResponse<B> = service.call(req).await?;
Ok(res.map_into_left_body())
}                            
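The low overhead comes from how the log facade works: its macros compare a message's level against a global maximum level before doing any formatting, and that maximum stays at Off until a logger such as env_logger raises it. The toy model below imitates that check using only the standard library; it is not the log crate's implementation, and toy_log!, MAX_LEVEL, SINK, and handle are hypothetical names.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Mutex;

// Numeric levels mirror the ordering of log::LevelFilter (Off < Error < Warn < Info < Debug).
const OFF: usize = 0;
const WARN: usize = 2;
const INFO: usize = 3;
const DEBUG: usize = 4;

// Starts at OFF: with no logger installed, every message is filtered out.
static MAX_LEVEL: AtomicUsize = AtomicUsize::new(OFF);
// Stand-in for an installed logger's output.
static SINK: Mutex<Vec<String>> = Mutex::new(Vec::new());

fn set_max_level(level: usize) {
    MAX_LEVEL.store(level, Ordering::Relaxed);
}

// One atomic load decides whether any formatting work happens at all.
macro_rules! toy_log {
    ($level:expr, $($arg:tt)*) => {
        if $level <= MAX_LEVEL.load(Ordering::Relaxed) {
            SINK.lock().unwrap().push(format!($($arg)*));
        }
    };
}

// Hypothetical request handler that logs at two levels.
fn handle(ip: &str) {
    toy_log!(INFO, "Forwarding request from {}", ip);
    toy_log!(DEBUG, "Request from {} passed the rate limit check", ip);
}

fn main() {
    handle("127.0.0.1"); // max level is OFF: nothing is recorded
    set_max_level(INFO); // comparable to filter_level(LevelFilter::Info)
    handle("127.0.0.2"); // only the INFO message is recorded
    toy_log!(WARN, "Requester socket address was None");
    for line in SINK.lock().unwrap().iter() {
        println!("{}", line);
    }
}
```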
                        

To tap into the logs from a consumer crate, a logger implementation is needed in addition to the log crate. Here I use the env_logger crate and configure it with a LevelFilter:

use log::LevelFilter;
use hash_svc::server;

fn main() {
    env_logger::builder().filter_level(LevelFilter::Info).init();

    match server::run() {
        Ok(_) => println!("Server ran successfully"),
        Err(e) => eprintln!("Server error: {:?}", e),
    }
}
                        

Now, when the main() program shown above runs, the following log messages are printed to the terminal (env_logger writes to stderr by default):

log messages shown in the terminal


Thanks for reading

Check out my published crate at crates.io.

View the project source code on GitHub
