iris/middleware/jwt/blocklist/redis/blocklist.go
2020-11-02 18:46:38 +02:00

186 lines
4.4 KiB
Go

package redis
import (
"context"
"io"
"sync/atomic"
"github.com/kataras/iris/v12/core/host"
"github.com/kataras/iris/v12/middleware/jwt"
"github.com/go-redis/redis/v8"
)
// defaultContext is the context passed to every redis command issued by
// this package.
var defaultContext = context.Background()

// Options is a type alias for the go-redis Client Options.
type Options = redis.Options

// ClusterOptions is a type alias for the go-redis Cluster Client Options.
type ClusterOptions = redis.ClusterOptions
// Client is the set of methods the Blocklist needs from a redis client:
// the redis commands plus Close. Both the go-redis Client and the
// go-redis ClusterClient satisfy it.
type Client interface {
	io.Closer     // Close, used to release the connection on interrupt.
	redis.Cmdable // The redis commands.
}
// Blocklist is a jwt.Blocklist backed by Redis.
//
// Create it with NewBlocklist, adjust its exported fields and
// call Connect before using it.
type Blocklist struct {
// GetKey extracts the unique identifier of a token; the identifier
// is used as the token's key in redis.
// Required. NewBlocklist sets it to a function that returns claims.ID ("jti").
GetKey func(token []byte, claims jwt.Claims) string
// Prefix is prepended to every token key stored in the redis database.
// Note that you can also select a different database
// through ClientOptions (or ClusterOptions).
// Defaults to empty string (no prefix).
Prefix string
// client is created by Connect.
// Both Client and ClusterClient implements this interface.
client Client
// connected is set to 1 after a successful Connect and back to 0 by the
// interrupt handler registered in Connect. It is read and written only
// through sync/atomic.
connected uint32
// Customize any go-redis fields manually
// before Connect. When ClusterOptions.Addrs is not empty, Connect
// creates a cluster client from ClusterOptions; otherwise it creates
// a single-node client from ClientOptions.
ClientOptions Options
ClusterOptions ClusterOptions
}
// Compile-time check that *Blocklist implements jwt.Blocklist.
var _ jwt.Blocklist = (*Blocklist)(nil)
// NewBlocklist returns a new redis-based Blocklist configured for a single
// redis server at 127.0.0.1:6379. Tokens are keyed by their "jti" claim
// (claims.ID) and no key prefix is used.
//
// Adjust its ClientOptions as the application needs, or fill
// ClusterOptions.Addrs to use a cluster client instead. Then call Connect.
//
// Usage:
//
//	blocklist := NewBlocklist()
//	blocklist.ClientOptions.Addr = ...
//	err := blocklist.Connect()
//
// And register it:
//
//	verifier := jwt.NewVerifier(...)
//	verifier.Blocklist = blocklist
func NewBlocklist() *Blocklist {
	b := new(Blocklist)
	b.GetKey = defaultGetKey
	// Only the address is set; go-redis defaults the remaining client options.
	b.ClientOptions.Addr = "127.0.0.1:6379"
	// Prefix and ClusterOptions stay at their zero values: no prefix and,
	// since ClusterOptions.Addrs is empty, a single-node client.
	return b
}
// defaultGetKey is the GetKey installed by NewBlocklist. It ignores the raw
// token bytes and identifies the token by its "jti" claim (claims.ID).
func defaultGetKey(_ []byte, claims jwt.Claims) string {
	id := claims.ID
	return id
}
// Connect creates the redis client and checks the connection with a PING.
//
// A cluster client is created from ClusterOptions when ClusterOptions.Addrs
// is not empty; otherwise a single-node client is created from ClientOptions.
// After a successful PING the Blocklist is marked connected, and a handler
// is registered with host.RegisterOnInterrupt. That handler marks it
// disconnected and closes the client.
//
// When Prefix is not empty, GetKey is wrapped so that every key it returns
// starts with Prefix. The wrapping (and the interrupt handler registration)
// happens only when the Blocklist goes from "not connected" to "connected".
// As a result, retrying Connect after a failed PING, or calling it again
// while connected, does not prefix keys twice. A nil GetKey falls back to
// the default one, which uses claims.ID.
//
// The PING error, if any, is returned as is.
func (b *Blocklist) Connect() error {
	if b.GetKey == nil {
		b.GetKey = defaultGetKey
	}

	wasConnected := atomic.LoadUint32(&b.connected) != 0

	// A client kept from a previous Connect whose PING failed is about to be
	// replaced; close it so its connection pool is not leaked. Its Close
	// error has no bearing on the new connection attempt.
	if b.client != nil && !wasConnected {
		_ = b.client.Close()
	}

	if len(b.ClusterOptions.Addrs) > 0 {
		// Use cluster client.
		b.client = redis.NewClusterClient(&b.ClusterOptions)
	} else {
		b.client = redis.NewClient(&b.ClientOptions)
	}

	if err := b.client.Ping(defaultContext).Err(); err != nil {
		// The client stays assigned (as before); a later Connect closes it.
		return err
	}

	if !wasConnected {
		// Wrap only once per connect transition; b.Prefix is read at call time.
		if b.Prefix != "" {
			getKey := b.GetKey
			b.GetKey = func(token []byte, claims jwt.Claims) string {
				return b.Prefix + getKey(token, claims)
			}
		}

		host.RegisterOnInterrupt(func() {
			atomic.StoreUint32(&b.connected, 0)
			// Shutdown path: there is no caller to report a Close error to.
			_ = b.client.Close()
		})
	}

	atomic.StoreUint32(&b.connected, 1)
	return nil
}
// IsConnected reports whether a Connect call has succeeded (its PING got a
// reply). It returns false again once the interrupt handler registered by
// Connect has closed the client.
func (b *Blocklist) IsConnected() bool {
	state := atomic.LoadUint32(&b.connected)
	return state != 0
}
// ValidateToken checks a token against the blocklist. The err parameter
// is the previous error (if any) produced for this token before the
// blocklist check.
//
//   - If err is not nil, it is returned unchanged. When err is
//     jwt.ErrExpired, the token's entry is also deleted from redis as a
//     best-effort cleanup.
//   - Otherwise it returns jwt.ErrBlocked when the token's key exists in
//     redis, the redis error when the lookup fails, or nil.
func (b *Blocklist) ValidateToken(token []byte, c jwt.Claims, err error) error {
	if err != nil {
		if err == jwt.ErrExpired {
			// Best-effort cleanup: the expired token is rejected with err
			// regardless, so a failed delete must not replace that error.
			_ = b.Del(b.GetKey(token, c))
		}
		return err // respect the previous error.
	}

	blocked, lookupErr := b.Has(b.GetKey(token, c))
	if lookupErr != nil {
		return lookupErr
	}
	if blocked {
		return jwt.ErrBlocked
	}
	return nil
}
// InvalidateToken blocks a verified JWT token. It stores the raw token under
// the key returned by GetKey with SETEX, using c.Timeleft() as the expiration.
// The redis error, if any, is returned.
//
// NOTE(review): Redis rejects SETEX with a non-positive expiration, so a
// token whose Timeleft is <= 0 presumably yields an error here — confirm
// the intended behavior for already-expired or non-expiring tokens.
func (b *Blocklist) InvalidateToken(token []byte, c jwt.Claims) error {
	key := b.GetKey(token, c)
	ttl := c.Timeleft()
	cmd := b.client.SetEX(defaultContext, key, token, ttl)
	return cmd.Err()
}
// Del deletes the entry stored under key (a key as returned by GetKey)
// and returns the redis error, if any.
func (b *Blocklist) Del(key string) error {
	cmd := b.client.Del(defaultContext, key)
	return cmd.Err()
}
// Has reports whether an entry is stored under key (a key as returned by
// GetKey), using the redis EXISTS command. The redis error, if any, is
// returned alongside the result.
func (b *Blocklist) Has(key string) (bool, error) {
	cmd := b.client.Exists(defaultContext, key)
	return cmd.Val() > 0, cmd.Err()
}
// Count returns the number of stored keys.
//
// With a Prefix, only keys matching Prefix+"*" are counted, by scanning the
// keyspace. Without a Prefix, the result is DBSIZE of the selected database,
// which also includes any non-blocklist keys kept there.
func (b *Blocklist) Count() (int64, error) {
	if b.Prefix != "" {
		keys, err := b.getKeys(0)
		if err != nil {
			return 0, err
		}
		return int64(len(keys)), nil
	}

	return b.client.DBSize(defaultContext).Result()
}
// getKeys returns every key matching b.Prefix+"*". The SCAN iteration
// starts at cursor; callers pass 0 to cover the whole keyspace.
//
// A full SCAN iteration may return the same key more than once, so the
// result is de-duplicated; otherwise Count would over-count. The loop stops
// when redis returns cursor 0, as SCAN requires.
//
// NOTE(review): with a cluster client, SCAN is presumably served by a single
// node, so keys on other masters may be missed — confirm against go-redis.
func (b *Blocklist) getKeys(cursor uint64) ([]string, error) {
	var (
		keys []string
		seen = make(map[string]struct{})
	)

	for {
		// 300000 is the COUNT hint: roughly how much work redis does per
		// call, not a limit on the number of returned keys.
		batch, next, err := b.client.Scan(defaultContext, cursor, b.Prefix+"*", 300000).Result()
		if err != nil {
			return nil, err
		}

		for _, key := range batch {
			if _, dup := seen[key]; dup {
				continue
			}
			seen[key] = struct{}{}
			keys = append(keys, key)
		}

		if next == 0 {
			return keys, nil
		}
		cursor = next
	}
}