// Package rate implements rate limiter for Iris client requests. // Example can be found at: _examples/request-ratelimit/main.go. package rate import ( "math" "sync" "time" "github.com/kataras/iris/v12/context" "golang.org/x/time/rate" ) func init() { context.SetHandlerName("iris/middleware/rate.(*Limiter).serveHTTP-fm", "iris.ratelimit") } // Option declares a function which can be passed on `Limit` package-level // to modify its internal fields. Available Options are: // * ExceedHandler // * ClientData // * PurgeEvery type Option func(*Limiter) // ExceedHandler is an `Option` that can be passed at the `Limit` package-level function. // It accepts a handler that will be executed every time a client tries to reach a page/resource // which is not accessible for that moment. func ExceedHandler(handler context.Handler) Option { return func(l *Limiter) { l.exceedHandler = handler } } // ClientData is an `Option` that can be passed at the `Limit` package-level function. // It accepts a function which provides the Iris Context and should return custom data // that will be stored to the Client and be retrieved as `Get(ctx).Client.Data` later on. func ClientData(clientDataFunc func(ctx *context.Context) interface{}) Option { return func(l *Limiter) { l.clientDataFunc = clientDataFunc } } // PurgeEvery is an `Option` that can be passed at the `Limit` package-level function. // This function will check for old entries and remove them. // // E.g. Limit(..., PurgeEvery(time.Minute, 5*time.Minute)) to // check every 1 minute if a client's last visit was 5 minutes ago ("old" entry) // and remove it from the memory. func PurgeEvery(every time.Duration, maxLifetime time.Duration) Option { condition := func(c *Client) bool { // for a custom purger the end-developer may use the c.Data filled from a `ClientData` option. return time.Since(c.LastSeen()) > maxLifetime } return func(l *Limiter) { go func() { for { time.Sleep(every) l.Purge(condition) } }() } } // Every converts a minimum time interval between events to a limit. // Usage: Limit(Every(1*time.Minute), 3, options...) func Every(interval time.Duration) float64 { if interval <= 0 { return Inf } return 1 / interval.Seconds() } type ( // Limiter is featured with the necessary functions to limit requests per second. // It has a single exported method `Purge` which helps to manually remove // old clients from the memory. Limiter is not exposed by a function, // callers should use it inside an `Option` for the `Limit` package-level function. Limiter struct { clientDataFunc func(ctx *context.Context) interface{} // fill the Client's Data field. exceedHandler context.Handler // when too many requests. limit rate.Limit burstSize int clients map[string]*Client mu sync.RWMutex // mutex for clients. } // Client holds some request information and the rate limiter itself. // It can be retrieved by the `Get` package-level function. // It can be used to manually add RateLimit response headers. Client struct { ID string Data interface{} Limiter *rate.Limiter lastSeen time.Time mu sync.RWMutex // mutex for lastSeen. } ) // Inf is the infinite rate limit; it allows all events (even if burst is zero). const Inf = math.MaxFloat64 // Limit returns a new rate limiter handler that allows requests up to rate "limit" and permits // bursts of at most "burst" tokens. See `rate.SetKey(ctx, key string)` and `rate.Get` too. // // E.g. Limit(1, 5) to allow 1 request per second, with a maximum burst size of 5. // // See `ExceedHandler`, `ClientData` and `PurgeEvery` for the available "options". func Limit(limit float64, burst int, options ...Option) context.Handler { l := &Limiter{ clients: make(map[string]*Client), limit: rate.Limit(limit), burstSize: burst, exceedHandler: func(ctx *context.Context) { ctx.StopWithStatus(429) // Too Many Requests. }, } for _, opt := range options { opt(l) } return l.serveHTTP } // Purge removes client entries from the memory based on the given "condition". func (l *Limiter) Purge(condition func(*Client) bool) { l.mu.Lock() for id, client := range l.clients { if condition(client) { delete(l.clients, id) } } l.mu.Unlock() } func (l *Limiter) serveHTTP(ctx *context.Context) { id := getIdentifier(ctx) l.mu.RLock() client, ok := l.clients[id] l.mu.RUnlock() if !ok { client = &Client{ ID: id, Limiter: rate.NewLimiter(l.limit, l.burstSize), } if l.clientDataFunc != nil { client.Data = l.clientDataFunc(ctx) } // if l.store(ctx, client) { // ^ no, let's keep it simple. l.mu.Lock() l.clients[id] = client l.mu.Unlock() } client.mu.Lock() client.lastSeen = time.Now() client.mu.Unlock() ctx.Values().Set(clientContextKey, client) if client.Limiter.Allow() { ctx.Next() return } if l.exceedHandler != nil { l.exceedHandler(ctx) } } const identifierContextKey = "iris.ratelimit.identifier" // SetIdentifier can be called manually from a handler or a middleare // to change the identifier per client. The default key for a client is its Remote IP. func SetIdentifier(ctx *context.Context, key string) { ctx.Values().Set(identifierContextKey, key) } func getIdentifier(ctx *context.Context) string { if entry, ok := ctx.Values().GetEntry(identifierContextKey); ok { return entry.ValueRaw.(string) } return ctx.RemoteAddr() } const clientContextKey = "iris.ratelimit.client" // Get returns the current rate limited `Client`. // Use it when you want to log or add response headers based on the current request limitation. // // You can read more about X-RateLimit response headers at: // https://tools.ietf.org/id/draft-polli-ratelimit-headers-00.html. // A good example of that is the GitHub API itself: https://developer.github.com/v3/#rate-limiting func Get(ctx *context.Context) *Client { if v := ctx.Values().Get(clientContextKey); v != nil { if c, ok := v.(*Client); ok { return c } } return nil } // LastSeen reports the last Client's visit. func (c *Client) LastSeen() time.Time { c.mu.RLock() t := c.lastSeen c.mu.RUnlock() return t } // TokensFromDuration is a unit conversion function from a time duration to the number of tokens // which could be accumulated during that duration at a rate of limit tokens per second. func (c *Client) TokensFromDuration(d time.Duration) float64 { // rate.go#tokensFromDuration limit := float64(c.Limiter.Limit()) sec := float64(d/time.Second) * limit nsec := float64(d%time.Second) * limit return sec + nsec/1e9 } // DurationFromTokens is a unit conversion function from the number of tokens to the duration // of time it takes to accumulate them at a rate of limit tokens per second. func (c *Client) DurationFromTokens(tokens float64) time.Duration { // rate.go#durationFromTokens seconds := tokens / float64(c.Limiter.Limit()) return time.Nanosecond * time.Duration(1e9*seconds) }