I wanted to implement functionality that lets administrators of a mobile app I developed send messages to the app's users. Push notifications weren't suitable for this use case, and users needed to be able to opt in to messages for a specific language channel supported by the app. This page focuses on how I accomplished this using Rust and Actix Web.
use log::error;
use std::str::FromStr;
/// Language channels that messages are published to.
///
/// Serde (de)serializes each variant by its name (e.g. `"English"`), so this is also the
/// form used for the `{lang}` segment of `GET /api/messages/{lang}` and for JSON bodies.
#[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, Eq, Hash, Debug)]
pub enum Langs {
    English,
    Spanish,
    French,
    Italian,
    Portuguese,
    German,
}
impl Langs {
fn from_str_internal(s: &str) -> Option<Self> {
match s {
"English" => Some(Langs::English),
"Spanish" => Some(Langs::Spanish),
"French" => Some(Langs::French),
"Italian" => Some(Langs::Italian),
"Portuguese" => Some(Langs::Portuguese),
"German" => Some(Langs::German),
_ => None,
}
}
}
/// Renders the language as its variant name (e.g. `English`).
///
/// Implemented as `Display` rather than `ToString`: the standard library's blanket
/// `impl<T: Display> ToString for T` still provides `.to_string()`, and `Display` also
/// makes `Langs` usable with `format!`/`{}`.
impl std::fmt::Display for Langs {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Langs::English => "English",
            Langs::Spanish => "Spanish",
            Langs::French => "French",
            Langs::Italian => "Italian",
            Langs::Portuguese => "Portuguese",
            Langs::German => "German",
        };
        f.write_str(name)
    }
}
impl From<String> for Langs {
fn from(s: String) -> Self {
Langs::from_str_internal(&s).unwrap_or_else(|| {
error!("Unknown language found in table: {}", s);
Langs::English
})
}
}
impl FromStr for Langs {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Langs::from_str_internal(s).ok_or_else(|| format!("Unknown language: {}", s))
}
}
/// How long a message is kept after its `created` timestamp.
///
/// Each discriminant is the lifetime in seconds. Serde (de)serializes these by variant
/// name, not by numeric value, so changing a value does not change the wire format.
#[derive(serde::Serialize, serde::Deserialize, Clone, Copy, Debug, PartialEq)]
pub enum Expiration {
    /// 1 hour.
    Hour = 60 * 60,
    /// 1 day.
    Day = 60 * 60 * 24,
    /// 7 days.
    Week = 60 * 60 * 24 * 7,
    /// 13 weeks (91 days): one quarter of a 52-week year. Previously 12 weeks (84 days),
    /// which fell a week short of a quarter.
    Quarter = 60 * 60 * 24 * 7 * 13,
    /// 365 days.
    Year = 60 * 60 * 24 * 365,
}
/// A message held in the shared in-memory repo and served to app users.
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug, PartialEq)]
pub struct Message {
    /// Server-assigned identifier (a random v4 UUID generated when the message is added).
    pub id: Uuid,
    /// Creation time in UTC; the expiry sweep measures the message's age from here.
    pub created: DateTime<Utc>,
    /// Message body text.
    pub content: String,
    /// Language channel the message is published to.
    pub lang: Langs,
    /// Lifetime after `created`; once exceeded, the periodic sweep removes the message.
    pub expires: Expiration,
    /// Message title.
    pub title: String,
    /// Optional image URL shown with the message.
    pub image_url : Option<String>,
}
/// JSON payload for `POST /api/messages`.
///
/// The server fills in `id` and `created` when turning this into a [`Message`].
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug, PartialEq)]
pub struct NewMessage {
    /// Message body text.
    pub content: String,
    /// Language channel to publish to.
    pub lang: Langs,
    /// How long the message should live after creation.
    pub expires: Expiration,
    /// Message title.
    pub title: String,
    /// Optional image URL.
    pub image_url : Option<String>,
}
/// JSON payload for `PATCH /api/messages`.
///
/// Identifies the message by `id` and replaces its title, content and image URL. The
/// message's language and expiration cannot be changed through this payload.
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug, PartialEq)]
pub struct EditMessage {
    /// Id of the message to edit.
    pub id: Uuid,
    /// New body text.
    pub content: String,
    /// New title.
    pub title: String,
    /// New image URL; `None` clears any existing one.
    pub image_url : Option<String>,
}
// Endpoint to post a new message to the shared message repo
#[post("/api/messages")]
pub async fn add_message(
repo: Data<Arc<Mutex<Vec<Message>>>>,
body: Json<NewMessage>,
) -> Result<HttpResponse, actix_web::Error> {
let mut repo = repo.lock().map_err(|_| {
actix_web::error::ErrorInternalServerError("Failed to acquire lock on message repo")
})?;
let new_message = Message {
id: Uuid::new_v4(),
created: Utc::now(),
content: body.content.clone(),
lang: body.lang.clone(),
expires: body.expires.clone(),
title: body.title.clone(),
image_url: body.image_url.clone(),
};
repo.push(new_message);
Ok(HttpResponse::Ok().finish())
}
/// Handles `GET /api/messages/{lang}`: returns every stored message for one language
/// channel as a JSON array (`application/json; charset=utf-8`).
///
/// `{lang}` is deserialized into [`Langs`] through its serde derive, i.e. by variant name
/// such as `English`. This handler does not check expiry itself, so an expired message
/// can still be returned until the next periodic sweep removes it.
///
/// # Errors
/// Returns `500 Internal Server Error` if the repo mutex is poisoned.
#[get("/api/messages/{lang}")]
pub async fn get_messages_by_lang(
    repo: Data<Arc<Mutex<Vec<Message>>>>,
    lang: Path<Langs>,
) -> Result<HttpResponse, actix_web::Error> {
    let lang = lang.into_inner();
    // Copy the matching messages out in a narrow scope so the guard is dropped before
    // JSON serialization; previously the lock was held until the function returned.
    let messages: Vec<Message> = {
        let repo = repo.lock().map_err(|_| {
            actix_web::error::ErrorInternalServerError("Failed to acquire lock on message repo")
        })?;
        repo.iter().filter(|msg| msg.lang == lang).cloned().collect()
    };
    Ok(HttpResponse::Ok()
        .content_type("application/json; charset=utf-8")
        .json(messages))
}
// Endpoint to edit a message
#[patch("/api/messages")]
pub async fn edit_message(
repo: Data<Arc<Mutex<Vec<Message>>>>,
body: Json<EditMessage>,
) -> Result<HttpResponse, actix_web::Error> {
let mut repo = repo.lock().map_err(|_| {
actix_web::error::ErrorInternalServerError("Failed to acquire lock on message repo")
})?;
if let Some(index) = repo.iter().position(|x| x.id == body.id) {
repo[index].title = body.title.clone();
repo[index].content = body.content.clone();
repo[index].image_url = body.image_url.clone();
Ok(HttpResponse::Ok().finish())
} else {
Ok(HttpResponse::NotFound().finish())
}
}
/// Handles `DELETE /api/messages/{id}`: removes the message with that UUID.
///
/// Responds `200 OK` if a message was removed and `404 Not Found` if none has that id.
///
/// # Errors
/// Returns `500 Internal Server Error` if the repo mutex is poisoned.
#[delete("/api/messages/{id}")]
pub async fn delete_message(
    repo: Data<Arc<Mutex<Vec<Message>>>>,
    id: Path<Uuid>,
) -> Result<HttpResponse, actix_web::Error> {
    let target = id.into_inner();
    let mut messages = repo.lock().map_err(|_| {
        actix_web::error::ErrorInternalServerError("Failed to acquire lock on message repo")
    })?;
    let found = messages.iter().position(|msg| msg.id == target);
    match found {
        Some(idx) => {
            // `remove` (not `swap_remove`) keeps the remaining messages in their original order.
            messages.remove(idx);
            Ok(HttpResponse::Ok().finish())
        }
        None => Ok(HttpResponse::NotFound().finish()),
    }
}
// Function to iterate over a Arc<Mutex<Vec<Message>>> and remove any messages exceeding a certain age
pub fn remove_old_messages(repo: Arc<Mutex<Vec<Message>>>) {
let mut repo = repo.lock().unwrap();
let now = Utc::now();
repo.retain(|msg| {
now.signed_duration_since(msg.created) < Duration::seconds(msg.expires as i64)
});
}
/// Application entry point.
///
/// Loads `.env` and starts a background task that prunes expired messages every 60 seconds.
/// It then runs two `HttpServer`s concurrently, sharing one in-memory message repo:
/// - plain HTTP on `LISTEN_HTTP`, wrapped with `SecurityHeaders` and a rate limiter built
///   with 60 requests / 1 minute, serving `routing::configure_insecure_message_routes`;
/// - HTTPS (rustls, using `TLS_CERT_PATH` / `TLS_KEY_PATH`) on `LISTEN_HTTPS`, wrapped with
///   `SecurityHeaders` and a rate limiter built with 3 requests / 1 minute, serving
///   `routing::configure_secure_message_routes` under `/admin` behind `ApiKeyMiddleware`.
///
/// # Panics
/// If `.env` cannot be read, or if `LISTEN_HTTP`, `LISTEN_HTTPS`, `TLS_CERT_PATH` or
/// `TLS_KEY_PATH` is not set.
///
/// # Errors
/// Propagates errors from loading the TLS config, binding either listener, or either
/// server's run future; `try_join!` returns the first error it sees.
// NOTE(review): no async runtime attribute (e.g. `#[actix_web::main]`) is visible on this
// function; presumably it was dropped from this excerpt — confirm it exists in the real file.
async fn main() -> std::io::Result<()> {
    dotenv().expect("Failed to read .env file");
    // NOTE(review): this unconditionally overrides any RUST_LOG from the environment or
    // `.env`, forcing debug logging; `set_var` is also `unsafe` as of the 2024 edition.
    std::env::set_var("RUST_LOG", "debug");
    // Must come after RUST_LOG is set: env_logger reads the filter when it initializes.
    env_logger::init();
    // Required configuration; a missing value aborts startup.
    let insecure_listen_addr = env::var("LISTEN_HTTP").expect("LISTEN_HTTP must be set");
    let secure_listen_addr = env::var("LISTEN_HTTPS").expect("LISTEN_HTTPS must be set");
    let cert_path = env::var("TLS_CERT_PATH").expect("TLS_CERT_PATH must be set");
    let key_path = env::var("TLS_KEY_PATH").expect("TLS_KEY_PATH must be set");
    let rustls_config = load_rustls_config(&cert_path, &key_path)?;
    // One in-memory repo; every clone below is another Arc handle to the same Vec.
    let messages: Arc<Mutex<Vec<Message>>> = Arc::new(Mutex::new(Vec::new()));
    let background_messages_clone = messages.clone();
    let tls_messages_clone = messages.clone();
    // Background expiry sweep.
    tokio::spawn(async move {
        let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(60));
        loop {
            // The first tick completes immediately; later ticks are 60 s apart.
            interval.tick().await;
            remove_old_messages(background_messages_clone.clone());
        }
    });
    // Each server gets its own limiter; the HTTPS (admin) one is built with a much
    // smaller request count (3 vs 60 per 1-minute duration).
    let limiter = LimiterBuilder::new()
        .with_duration(Duration::minutes(1))
        .with_num_requests(60)
        .build();
    let tls_limiter = LimiterBuilder::new()
        .with_duration(Duration::minutes(1))
        .with_num_requests(3)
        .build();
    // HttpServer calls this factory to build each worker's App. Each worker's `Data`
    // wraps a clone of the same Arc, so all workers see one shared repo.
    let insecure_app_factory = move || {
        let logger = Logger::default();
        App::new()
            .wrap(logger)
            .wrap(SecurityHeaders)
            .wrap(RateLimiter::new(Arc::clone(&limiter)))
            .app_data(Data::new(messages.clone()))
            .configure(routing::configure_insecure_message_routes)
    };
    let secure_app_factory = move || {
        let logger = Logger::default();
        App::new()
            .wrap(logger)
            .wrap(SecurityHeaders)
            .wrap(RateLimiter::new(Arc::clone(&tls_limiter)))
            .app_data(Data::new(tls_messages_clone.clone()))
            // The only service on this server is the /admin scope, and every route in
            // it passes through ApiKeyMiddleware.
            .service(
                scope("/admin")
                    .wrap(ApiKeyMiddleware)
                    .configure(routing::configure_secure_message_routes),
            )
    };
    let http_server = HttpServer::new(insecure_app_factory.clone())
        .bind(insecure_listen_addr)?
        .run();
    let https_server = HttpServer::new(secure_app_factory)
        .bind_rustls(&secure_listen_addr, rustls_config)?
        .run();
    // Drive both servers concurrently until both finish or one fails.
    futures_util::try_join!(http_server, https_server)?;
    Ok(())
}
I've used Actix Web middleware to add security headers to responses from all endpoints, and to wrap the sensitive (admin) endpoints with logic that validates the `x-api-key` header. I am not currently serving web pages from this server, so some of the security headers may not seem relevant, but they do no harm. In the future, I would like to compare the API key in constant time to mitigate timing attacks. For now, I am relying on rate limiting and network segmentation to mitigate that risk.
/// Middleware factory that adds security response headers (CSP, `X-Content-Type-Options`,
/// `X-Frame-Options`, HSTS) via [`SecurityHeadersMiddleware`].
pub struct SecurityHeaders;
impl<S, B> actix_service::Transform<S, ServiceRequest> for SecurityHeaders
where
    S: actix_service::Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
    S::Future: 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type Transform = SecurityHeadersMiddleware<S>;
    type InitError = ();
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    /// Wraps `inner` immediately; building the middleware cannot fail.
    fn new_transform(&self, inner: S) -> Self::Future {
        let middleware = SecurityHeadersMiddleware { service: inner };
        ok(middleware)
    }
}
/// Service produced by [`SecurityHeaders`]: forwards each request unchanged and stamps a
/// fixed set of security headers onto the successful response.
pub struct SecurityHeadersMiddleware<S> {
    service: S,
}
impl<S, B> actix_service::Service<ServiceRequest> for SecurityHeadersMiddleware<S>
where
    S: actix_service::Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
    S::Future: 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type Future = futures_util::future::LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(cx)
    }

    /// Calls the inner service, then inserts (replacing any existing values) the CSP,
    /// `X-Content-Type-Options`, `X-Frame-Options` and `Strict-Transport-Security` headers.
    ///
    /// NOTE(review): an `Err` from the inner service is returned via `?` before any header
    /// is inserted, so responses later rendered from that error do not carry these headers.
    fn call(&self, req: ServiceRequest) -> Self::Future {
        use actix_web::http::header::{self, HeaderValue};

        let fut = self.service.call(req);
        async move {
            let mut res = fut.await?;
            let headers = res.headers_mut();
            // `from_static` uses the constant values directly instead of parsing and
            // unwrapping them again on every response.
            headers.insert(
                header::CONTENT_SECURITY_POLICY,
                HeaderValue::from_static("default-src 'none'; frame-ancestors 'none'"),
            );
            headers.insert(
                header::X_CONTENT_TYPE_OPTIONS,
                HeaderValue::from_static("nosniff"),
            );
            headers.insert(header::X_FRAME_OPTIONS, HeaderValue::from_static("DENY"));
            // Harmless on the plain-HTTP server: user agents ignore HSTS received over an
            // insecure transport (RFC 6797).
            headers.insert(
                header::STRICT_TRANSPORT_SECURITY,
                HeaderValue::from_static("max-age=31536000; includeSubDomains"),
            );
            Ok(res)
        }
        .boxed_local()
    }
}
/// Middleware factory that gates the wrapped scope on the `x-api-key` request header,
/// checked against the `ADMIN_API_KEY` environment variable by [`ApiKeyMiddlewareService`].
pub struct ApiKeyMiddleware;
impl<S, B> actix_service::Transform<S, ServiceRequest> for ApiKeyMiddleware
where
    S: actix_service::Service<ServiceRequest, Response = actix_web::dev::ServiceResponse<B>, Error = Error> + 'static,
    S::Future: 'static,
{
    type Response = actix_web::dev::ServiceResponse<B>;
    type Error = Error;
    type Transform = ApiKeyMiddlewareService<S>;
    type InitError = ();
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    /// Wraps `inner` immediately; building the middleware cannot fail.
    fn new_transform(&self, inner: S) -> Self::Future {
        let guarded = ApiKeyMiddlewareService { service: inner };
        ok(guarded)
    }
}
pub struct ApiKeyMiddlewareService<S> {
service: S,
}
impl<S, B> actix_service::Service<ServiceRequest> for ApiKeyMiddlewareService<S>
where
S: actix_service::Service<ServiceRequest, Response = actix_web::dev::ServiceResponse<B>, Error = Error> + 'static,
S::Future: 'static,
{
type Response = actix_web::dev::ServiceResponse<B>;
type Error = Error;
type Future = futures_util::future::LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}
fn call(&self, req: ServiceRequest) -> Self::Future {
let api_key = req.headers().get("x-api-key").cloned();
let fut = self.service.call(req);
async move {
if let Some(api_key) = api_key {
let expected_api_key = env::var("ADMIN_API_KEY").expect("ADMIN_API_KEY must be set");
if api_key.to_str().unwrap_or("") == expected_api_key {
return fut.await;
}
}
Err(actix_web::error::ErrorUnauthorized("Invalid API key")).into()
}
.boxed_local()
}
}
Actix Web is a great framework for backend service development, and has quickly become my top pick for developing web services.