diff --git a/_examples/miscellaneous/ratelimit/main.go b/_examples/miscellaneous/ratelimit/main.go index 2260840a..2f1b235e 100644 --- a/_examples/miscellaneous/ratelimit/main.go +++ b/_examples/miscellaneous/ratelimit/main.go @@ -11,33 +11,68 @@ func main() { app := newApp() app.Logger().SetLevel("debug") + // * http://localhost:8080/v1 + // * http://localhost:8080/v1/other + // * http://localhost:8080/v2/list (with X-API-Key request header) app.Listen(":8080") } func newApp() *iris.Application { app := iris.New() - // Register the rate limiter middleware at the root router. - // - // Fist and second input parameters: - // Allow 1 request per second, with a maximum burst size of 5. - // - // Third optional variadic input parameter: - // Can be a cleanup function. - // Iris provides a cleanup function that will check for old entries and remove them. - // You can customize it, e.g. check every 1 minute - // if a client's last visit was 5 minutes ago ("old" entry) - // and remove it from the memory. - rateLimiter := rate.Limit(1, 5, rate.PurgeEvery(time.Minute, 5*time.Minute)) - app.Use(rateLimiter) + v1 := app.Party("/v1") + { + // Register the rate limiter middleware at the "/v1" subrouter. + // + // Fist and second input parameters: + // Allow 1 request per second, with a maximum burst size of 5. + // + // Third optional variadic input parameter: + // Can be a cleanup function. + // Iris provides a cleanup function that will check for old entries and remove them. + // You can customize it, e.g. check every 1 minute + // if a client's last visit was 5 minutes ago ("old" entry) + // and remove it from the memory. + limitV1 := rate.Limit(1, 5, rate.PurgeEvery(time.Minute, 5*time.Minute)) + // rate.Every helper: 1 request per minute (with burst of 5): + // rate.Limit(rate.Every(1*time.Minute), 5) + v1.Use(limitV1) - // Routes. - app.Get("/", index) - app.Get("/other", other) + v1.Get("/", index) + v1.Get("/other", other) + } + + v2 := app.Party("/v2") + { + v2.Use(useAPIKey) + // Initialize a new rate limit middleware to limit requests + // per API Key(see `useAPIKey` below) instead of client's Remote IP Address. + limitV2 := rate.Limit(1, 5, rate.PurgeEvery(time.Minute, 5*time.Minute)) + v2.Use(limitV2) + + v2.Get("/list", list) + } return app } +func useAPIKey(ctx iris.Context) { + apiKey := ctx.GetHeader("X-API-Key") + if apiKey == "" { // [validate your API Key here...] + ctx.StopWithStatus(iris.StatusForbidden) + return + } + + // Change the method that rate limit matches the requests with a specific user + // and set our own api key as theirs identifier. + rate.SetIdentifier(ctx, apiKey) + ctx.Next() +} + +func list(ctx iris.Context) { + ctx.JSON(iris.Map{"key": "value"}) +} + func index(ctx iris.Context) { ctx.HTML("

Index Page

") } diff --git a/middleware/rate/rate.go b/middleware/rate/rate.go index 04036053..4931c8ea 100644 --- a/middleware/rate/rate.go +++ b/middleware/rate/rate.go @@ -64,6 +64,15 @@ func PurgeEvery(every time.Duration, maxLifetime time.Duration) Option { } } +// Every converts a minimum time interval between events to a limit. +// Usage: Limit(Every(1*time.Minute), 3, options...) +func Every(interval time.Duration) float64 { + if interval <= 0 { + return Inf + } + return 1 / interval.Seconds() +} + type ( // Limiter is featured with the necessary functions to limit requests per second. // It has a single exported method `Purge` which helps to manually remove @@ -72,8 +81,9 @@ type ( Limiter struct { clientDataFunc func(ctx context.Context) interface{} // fill the Client's Data field. exceedHandler context.Handler // when too many requests. - limit rate.Limit - burstSize int + + limit rate.Limit + burstSize int clients map[string]*Client mu sync.RWMutex // mutex for clients. @@ -83,9 +93,9 @@ type ( // It can be retrieved by the `Get` package-level function. // It can be used to manually add RateLimit response headers. Client struct { - Limiter *rate.Limiter - IP string + ID string Data interface{} + Limiter *rate.Limiter lastSeen time.Time mu sync.RWMutex // mutex for lastSeen. @@ -96,7 +106,8 @@ type ( const Inf = math.MaxFloat64 // Limit returns a new rate limiter handler that allows requests up to rate "limit" and permits -// bursts of at most "burst" tokens. +// bursts of at most "burst" tokens. See `rate.SetKey(ctx, key string)` and `rate.Get` too. +// // E.g. Limit(1, 5) to allow 1 request per second, with a maximum burst size of 5. // // See `ExceedHandler`, `ClientData` and `PurgeEvery` for the available "options". @@ -120,24 +131,24 @@ func Limit(limit float64, burst int, options ...Option) context.Handler { // Purge removes client entries from the memory based on the given "condition". func (l *Limiter) Purge(condition func(*Client) bool) { l.mu.Lock() - for ip, client := range l.clients { + for id, client := range l.clients { if condition(client) { - delete(l.clients, ip) + delete(l.clients, id) } } l.mu.Unlock() } func (l *Limiter) serveHTTP(ctx context.Context) { - ip := ctx.RemoteAddr() + id := getIdentifier(ctx) l.mu.RLock() - client, ok := l.clients[ip] + client, ok := l.clients[id] l.mu.RUnlock() if !ok { client = &Client{ + ID: id, Limiter: rate.NewLimiter(l.limit, l.burstSize), - IP: ip, } if l.clientDataFunc != nil { @@ -147,7 +158,7 @@ func (l *Limiter) serveHTTP(ctx context.Context) { // if l.store(ctx, client) { // ^ no, let's keep it simple. l.mu.Lock() - l.clients[ip] = client + l.clients[id] = client l.mu.Unlock() } @@ -169,6 +180,22 @@ func (l *Limiter) serveHTTP(ctx context.Context) { } } +const identifierContextKey = "iris.ratelimit.identifier" + +// SetIdentifier can be called manually from a handler or a middleare +// to change the identifier per client. The default key for a client is its Remote IP. +func SetIdentifier(ctx context.Context, key string) { + ctx.Values().Set(identifierContextKey, key) +} + +func getIdentifier(ctx context.Context) string { + if entry, ok := ctx.Values().GetEntry(identifierContextKey); ok { + return entry.ValueRaw.(string) + } + + return ctx.RemoteAddr() +} + const clientContextKey = "iris.ratelimit.client" // Get returns the current rate limited `Client`. @@ -188,9 +215,9 @@ func Get(ctx context.Context) *Client { } // LastSeen reports the last Client's visit. -func (c *Client) LastSeen() (t time.Time) { +func (c *Client) LastSeen() time.Time { c.mu.RLock() - t = c.lastSeen + t := c.lastSeen c.mu.RUnlock() return t }