diff --git a/_examples/miscellaneous/ratelimit/main.go b/_examples/miscellaneous/ratelimit/main.go
index 2260840a..2f1b235e 100644
--- a/_examples/miscellaneous/ratelimit/main.go
+++ b/_examples/miscellaneous/ratelimit/main.go
@@ -11,33 +11,68 @@ func main() {
app := newApp()
app.Logger().SetLevel("debug")
+ // * http://localhost:8080/v1
+ // * http://localhost:8080/v1/other
+ // * http://localhost:8080/v2/list (with X-API-Key request header)
app.Listen(":8080")
}
func newApp() *iris.Application {
app := iris.New()
- // Register the rate limiter middleware at the root router.
- //
- // Fist and second input parameters:
- // Allow 1 request per second, with a maximum burst size of 5.
- //
- // Third optional variadic input parameter:
- // Can be a cleanup function.
- // Iris provides a cleanup function that will check for old entries and remove them.
- // You can customize it, e.g. check every 1 minute
- // if a client's last visit was 5 minutes ago ("old" entry)
- // and remove it from the memory.
- rateLimiter := rate.Limit(1, 5, rate.PurgeEvery(time.Minute, 5*time.Minute))
- app.Use(rateLimiter)
+ v1 := app.Party("/v1")
+ {
+ // Register the rate limiter middleware at the "/v1" subrouter.
+ //
+ // Fist and second input parameters:
+ // Allow 1 request per second, with a maximum burst size of 5.
+ //
+ // Third optional variadic input parameter:
+ // Can be a cleanup function.
+ // Iris provides a cleanup function that will check for old entries and remove them.
+ // You can customize it, e.g. check every 1 minute
+ // if a client's last visit was 5 minutes ago ("old" entry)
+ // and remove it from the memory.
+ limitV1 := rate.Limit(1, 5, rate.PurgeEvery(time.Minute, 5*time.Minute))
+ // rate.Every helper: 1 request per minute (with burst of 5):
+ // rate.Limit(rate.Every(1*time.Minute), 5)
+ v1.Use(limitV1)
- // Routes.
- app.Get("/", index)
- app.Get("/other", other)
+ v1.Get("/", index)
+ v1.Get("/other", other)
+ }
+
+ v2 := app.Party("/v2")
+ {
+ v2.Use(useAPIKey)
+ // Initialize a new rate limit middleware to limit requests
+ // per API Key(see `useAPIKey` below) instead of client's Remote IP Address.
+ limitV2 := rate.Limit(1, 5, rate.PurgeEvery(time.Minute, 5*time.Minute))
+ v2.Use(limitV2)
+
+ v2.Get("/list", list)
+ }
return app
}
+func useAPIKey(ctx iris.Context) {
+ apiKey := ctx.GetHeader("X-API-Key")
+ if apiKey == "" { // [validate your API Key here...]
+ ctx.StopWithStatus(iris.StatusForbidden)
+ return
+ }
+
+ // Change the method that rate limit matches the requests with a specific user
+ // and set our own api key as theirs identifier.
+ rate.SetIdentifier(ctx, apiKey)
+ ctx.Next()
+}
+
+func list(ctx iris.Context) {
+ ctx.JSON(iris.Map{"key": "value"})
+}
+
func index(ctx iris.Context) {
ctx.HTML("
Index Page
")
}
diff --git a/middleware/rate/rate.go b/middleware/rate/rate.go
index 04036053..4931c8ea 100644
--- a/middleware/rate/rate.go
+++ b/middleware/rate/rate.go
@@ -64,6 +64,15 @@ func PurgeEvery(every time.Duration, maxLifetime time.Duration) Option {
}
}
+// Every converts a minimum time interval between events to a limit.
+// Usage: Limit(Every(1*time.Minute), 3, options...)
+func Every(interval time.Duration) float64 {
+ if interval <= 0 {
+ return Inf
+ }
+ return 1 / interval.Seconds()
+}
+
type (
// Limiter is featured with the necessary functions to limit requests per second.
// It has a single exported method `Purge` which helps to manually remove
@@ -72,8 +81,9 @@ type (
Limiter struct {
clientDataFunc func(ctx context.Context) interface{} // fill the Client's Data field.
exceedHandler context.Handler // when too many requests.
- limit rate.Limit
- burstSize int
+
+ limit rate.Limit
+ burstSize int
clients map[string]*Client
mu sync.RWMutex // mutex for clients.
@@ -83,9 +93,9 @@ type (
// It can be retrieved by the `Get` package-level function.
// It can be used to manually add RateLimit response headers.
Client struct {
- Limiter *rate.Limiter
- IP string
+ ID string
Data interface{}
+ Limiter *rate.Limiter
lastSeen time.Time
mu sync.RWMutex // mutex for lastSeen.
@@ -96,7 +106,8 @@ type (
const Inf = math.MaxFloat64
// Limit returns a new rate limiter handler that allows requests up to rate "limit" and permits
-// bursts of at most "burst" tokens.
+// bursts of at most "burst" tokens. See `rate.SetKey(ctx, key string)` and `rate.Get` too.
+//
// E.g. Limit(1, 5) to allow 1 request per second, with a maximum burst size of 5.
//
// See `ExceedHandler`, `ClientData` and `PurgeEvery` for the available "options".
@@ -120,24 +131,24 @@ func Limit(limit float64, burst int, options ...Option) context.Handler {
// Purge removes client entries from the memory based on the given "condition".
func (l *Limiter) Purge(condition func(*Client) bool) {
l.mu.Lock()
- for ip, client := range l.clients {
+ for id, client := range l.clients {
if condition(client) {
- delete(l.clients, ip)
+ delete(l.clients, id)
}
}
l.mu.Unlock()
}
func (l *Limiter) serveHTTP(ctx context.Context) {
- ip := ctx.RemoteAddr()
+ id := getIdentifier(ctx)
l.mu.RLock()
- client, ok := l.clients[ip]
+ client, ok := l.clients[id]
l.mu.RUnlock()
if !ok {
client = &Client{
+ ID: id,
Limiter: rate.NewLimiter(l.limit, l.burstSize),
- IP: ip,
}
if l.clientDataFunc != nil {
@@ -147,7 +158,7 @@ func (l *Limiter) serveHTTP(ctx context.Context) {
// if l.store(ctx, client) {
// ^ no, let's keep it simple.
l.mu.Lock()
- l.clients[ip] = client
+ l.clients[id] = client
l.mu.Unlock()
}
@@ -169,6 +180,22 @@ func (l *Limiter) serveHTTP(ctx context.Context) {
}
}
+const identifierContextKey = "iris.ratelimit.identifier"
+
+// SetIdentifier can be called manually from a handler or a middleare
+// to change the identifier per client. The default key for a client is its Remote IP.
+func SetIdentifier(ctx context.Context, key string) {
+ ctx.Values().Set(identifierContextKey, key)
+}
+
+func getIdentifier(ctx context.Context) string {
+ if entry, ok := ctx.Values().GetEntry(identifierContextKey); ok {
+ return entry.ValueRaw.(string)
+ }
+
+ return ctx.RemoteAddr()
+}
+
const clientContextKey = "iris.ratelimit.client"
// Get returns the current rate limited `Client`.
@@ -188,9 +215,9 @@ func Get(ctx context.Context) *Client {
}
// LastSeen reports the last Client's visit.
-func (c *Client) LastSeen() (t time.Time) {
+func (c *Client) LastSeen() time.Time {
c.mu.RLock()
- t = c.lastSeen
+ t := c.lastSeen
c.mu.RUnlock()
return t
}