I created a back-end web service to hash remote images using Rust and learned a lot about Rust in the process. Every programming language has its own unique set of strengths and weaknesses. I really appreciated how writing macros in Rust could eliminate boilerplate code and streamline the development process. I also really liked the simplicity of the actix-web framework. I even wrote my own rate limiting middleware to rate limit the server routes. I wanted to run through this same approach again, this time using Go.
The `RateLimiter` struct maps IP addresses to their respective limiter, a type provided by the `rate` package.
A `RateLimiter` is created by calling the `NewRateLimiter` function.
The `RateLimited` function serves as middleware between a request reaching an endpoint and the handler function mapped to that endpoint. As such, it receives an `http.HandlerFunc` as its sole argument (the handler function that handles the request) and also returns a handler function that contains the rate limiting logic. We will see how this works with an anonymous function later.
The `GetLimiter` function is used to obtain the limiter associated with an IP address in the `limiters` map.
The `getIP` function extracts the IP address from which the GET request originated. It contains logic to handle IP addresses that include a port number as well as IPv6 addresses.
package main
import (
	"log"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"golang.org/x/time/rate"
)
// Rate-limit settings applied to every per-IP limiter.
const (
	// NUM_REQUESTS is handed to rate.NewLimiter as the burst size.
	NUM_REQUESTS = 2
	// NUM_SECONDS is the number of seconds used to build the refill
	// interval handed to rate.Every.
	NUM_SECONDS = 20
)
// RateLimiter keeps one rate.Limiter per client IP address.
type RateLimiter struct {
	// mu is locked by GetLimiter around every access to limiters.
	mu sync.Mutex
	// limiters is keyed by the IP string passed to GetLimiter.
	limiters map[string]*rate.Limiter
}
// rateLimiter is the package-level RateLimiter consulted by the RateLimited
// middleware for every incoming request.
var rateLimiter = NewRateLimiter()
// RateLimited wraps next in middleware that only lets GET requests through
// and applies the per-IP limit held by rateLimiter.
//
// Non-GET requests are answered with 405 Method Not Allowed. GET requests
// whose client IP (as reported by getIP) has exhausted its limiter are
// answered with 429 Too Many Requests. All other requests are forwarded to
// next.
func RateLimited(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodGet {
			// RFC 9110 section 15.5.6: a 405 response must carry an Allow
			// header listing the methods the resource supports.
			w.Header().Set("Allow", http.MethodGet)
			http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
			return
		}

		ip := getIP(r)
		if !rateLimiter.GetLimiter(ip).Allow() {
			log.Println("Request rate limited for IP:", ip)
			http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
			return
		}

		log.Println("Request allowed for IP:", ip)
		next.ServeHTTP(w, r)
	}
}
// NewRateLimiter returns a RateLimiter whose limiter map is initialised and
// empty, ready for GetLimiter to populate.
func NewRateLimiter() *RateLimiter {
	rl := new(RateLimiter)
	rl.limiters = map[string]*rate.Limiter{}
	return rl
}
// GetLimiter returns the limiter associated with ip, creating and storing a
// new one the first time an address is seen. The map is accessed under r.mu,
// so concurrent handlers may call it safely.
//
// Each limiter allows a burst of NUM_REQUESTS and refills at a sustained
// rate of NUM_REQUESTS per NUM_SECONDS seconds.
//
// NOTE(review): entries are never removed, so the map grows with the number
// of distinct client addresses seen over the process lifetime.
func (r *RateLimiter) GetLimiter(ip string) *rate.Limiter {
	r.mu.Lock()
	defer r.mu.Unlock()

	if limiter, exists := r.limiters[ip]; exists {
		return limiter
	}

	// rate.Every takes the interval between single tokens, so the window
	// must be divided by the number of requests allowed in it. Passing the
	// whole window would refill only one token per NUM_SECONDS.
	limiter := rate.NewLimiter(rate.Every(NUM_SECONDS*time.Second/NUM_REQUESTS), NUM_REQUESTS)
	r.limiters[ip] = limiter
	return limiter
}
// getIP returns the client IP address for r, without port or IPv6 brackets.
//
// The X-Forwarded-For header is preferred when present; because that header
// may carry a comma-separated chain of addresses, only its first entry is
// used. Otherwise r.RemoteAddr is used. Values in host:port or [host]:port
// form have the port removed; bare IPv4 and IPv6 addresses are returned
// unchanged.
//
// NOTE(review): X-Forwarded-For is supplied by the client unless a trusted
// proxy overwrites it, so a caller can change this value at will and obtain
// a fresh limiter per request. Confirm the service is only reachable through
// such a proxy.
func getIP(r *http.Request) string {
	ip := r.Header.Get("X-Forwarded-For")
	if ip == "" {
		ip = r.RemoteAddr
	} else if i := strings.IndexByte(ip, ','); i >= 0 {
		// "client, proxy1, proxy2": the first entry is the originating client.
		ip = ip[:i]
	}
	ip = strings.TrimSpace(ip)

	if host, _, err := net.SplitHostPort(ip); err == nil {
		ip = host
	} else {
		// No port present (bare IPv4/IPv6, possibly bracketed): keep the
		// address intact instead of cutting at the last colon.
		ip = strings.Trim(ip, "[]")
	}

	log.Println("Extracted IP address:", ip)
	return ip
}
The `downloadAndHashImage` function expects a string corresponding to the URL of the remote image to hash.
The `downloadAndHashImages` function calls that function once per image, passing the remote image URL obtained from a JSON config file.
The resulting hashes are stored in the `AppState` struct, shown below:
// AppState holds the most recently computed hash string for each image the
// service tracks, one field per image.
type AppState struct {
	EnImageHash, EnPImageHash string
	EsImageHash, EsPImageHash string
	FrImageHash               string
	PoImageHash               string
	ItImageHash               string
	DeImageHash               string

	// mu is the lock taken around reads and writes of the hash fields.
	mu sync.Mutex
}
The `getHash` function (which will be wrapped by the rate limiting middleware described earlier) is mapped to each endpoint.
It writes the requested hash to the `http.ResponseWriter` passed to the initial `http.HandleFunc` invocation (shown below).
package main
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net/http"
"time"
"github.com/tidwall/gjson"
)
// refreshHashInSeconds is how long, in seconds, downloadAndHashImages sleeps
// between refresh passes.
const refreshHashInSeconds = 60
// imageClient performs the image downloads. The timeout bounds the whole
// request (connect, headers and body) so that an unresponsive host cannot
// stall a refresh indefinitely.
var imageClient = &http.Client{Timeout: 30 * time.Second}

// downloadAndHashImage fetches url with an HTTP GET and returns the
// hex-encoded SHA-256 digest of the response body.
//
// An error is returned when the request fails, when the server answers with
// a non-2xx status (so an error page is never hashed as if it were the
// image), or when the body cannot be read completely. The returned string is
// empty whenever the error is non-nil.
func downloadAndHashImage(url string) (string, error) {
	resp, err := imageClient.Get(url)
	if err != nil {
		return "", fmt.Errorf("error fetching image: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return "", fmt.Errorf("error fetching image: unexpected status %s", resp.Status)
	}

	// Stream the body into the hasher instead of buffering the whole image.
	hasher := sha256.New()
	if _, err := io.Copy(hasher, resp.Body); err != nil {
		return "", fmt.Errorf("error reading response bytes: %w", err)
	}
	return hex.EncodeToString(hasher.Sum(nil)), nil
}
// downloadAndHashImages refreshes every image hash held in state in an
// endless loop, sleeping refreshHashInSeconds seconds between passes. It
// never returns and is meant to run in its own goroutine.
//
// Image URLs are read from config under the "secrets.*" paths listed below.
// All downloads of a pass run without holding state.mu, so request handlers
// are not blocked on network I/O; the lock is taken only to publish the
// results of the pass in one step. When a download fails the failure is
// reported and the previously stored hash is kept instead of being
// overwritten with an empty string.
func downloadAndHashImages(config gjson.Result) {
	targets := []struct {
		key  string  // gjson path of the image URL inside config
		dest *string // state field that receives the hash
	}{
		{"secrets.en_image", &state.EnImageHash},
		{"secrets.en_image_p", &state.EnPImageHash},
		{"secrets.es_image", &state.EsImageHash},
		{"secrets.es_image_p", &state.EsPImageHash},
		{"secrets.fr_image", &state.FrImageHash},
		{"secrets.po_image", &state.PoImageHash},
		{"secrets.it_image", &state.ItImageHash},
		{"secrets.de_image", &state.DeImageHash},
	}
	hashes := make([]string, len(targets))
	fetched := make([]bool, len(targets))

	for {
		for i, target := range targets {
			hash, err := downloadAndHashImage(config.Get(target.key).String())
			fetched[i] = err == nil
			if err != nil {
				// NOTE(review): the error text can include the image URL,
				// which comes from the "secrets" section of the config;
				// confirm that is acceptable wherever stdout ends up.
				fmt.Printf("refreshing hash for %s failed: %v\n", target.key, err)
				continue
			}
			hashes[i] = hash
		}

		state.mu.Lock()
		for i, target := range targets {
			if fetched[i] {
				*target.dest = hashes[i]
			}
		}
		state.mu.Unlock()

		time.Sleep(time.Duration(refreshHashInSeconds) * time.Second)
	}
}
// getHash writes the string that hash points to into the response body.
//
// hash is expected to point at one of the fields of state. The value is
// copied while state.mu is held and the lock is released before writing, so
// a slow client cannot keep the mutex locked and stall the refresher or
// other handlers. The request argument is unused.
func getHash(w http.ResponseWriter, _ *http.Request, hash *string) {
	state.mu.Lock()
	current := *hash
	state.mu.Unlock()

	fmt.Fprint(w, current)
}
The `main()` function loads a configuration file containing the remote image URLs into the `config` value (a `gjson.Result`).
package main
import (
"fmt"
"log"
"net/http"
"os"
"github.com/tidwall/gjson"
)
// loadConfig reads the file at filename and returns its contents parsed as a
// gjson.Result.
//
// An error is returned when the file cannot be read or when its contents are
// not valid JSON. gjson.ParseBytes itself does not report malformed input,
// so the validity check is what keeps a broken config from being accepted
// silently.
func loadConfig(filename string) (gjson.Result, error) {
	data, err := os.ReadFile(filename)
	if err != nil {
		return gjson.Result{}, fmt.Errorf("unable to read config file: %w", err)
	}
	if !gjson.ValidBytes(data) {
		return gjson.Result{}, fmt.Errorf("config file %q is not valid JSON", filename)
	}
	return gjson.ParseBytes(data), nil
}
// main loads Config.json, starts the background goroutine that refreshes the
// image hashes, registers one rate-limited endpoint per hash and serves HTTP
// on port 9191. It exits via log.Fatal if the config cannot be loaded or the
// server stops.
func main() {
	config, err := loadConfig("Config.json")
	if err != nil {
		log.Fatalf("Error loading config: %v", err)
	}

	go downloadAndHashImages(config)

	// Each route serves the hash stored in the paired state field.
	routes := []struct {
		path string
		hash *string
	}{
		{"/en", &state.EnImageHash},
		{"/en_p", &state.EnPImageHash},
		{"/es", &state.EsImageHash},
		{"/es_p", &state.EsPImageHash},
		{"/fr", &state.FrImageHash},
		{"/po", &state.PoImageHash},
		{"/it", &state.ItImageHash},
		{"/de", &state.DeImageHash},
	}
	for _, route := range routes {
		hash := route.hash // per-iteration copy captured by the closure
		http.HandleFunc(route.path, RateLimited(func(w http.ResponseWriter, r *http.Request) {
			getHash(w, r, hash)
		}))
	}

	log.Println("Starting server on :9191")
	log.Fatal(http.ListenAndServe(":9191", nil))
}