mirror of
https://github.com/kataras/iris.git
synced 2025-01-23 18:51:03 +01:00
186 lines
4.4 KiB
Go
186 lines
4.4 KiB
Go
|
package redis
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"io"
|
||
|
"sync/atomic"
|
||
|
|
||
|
"github.com/kataras/iris/v12/core/host"
|
||
|
"github.com/kataras/iris/v12/middleware/jwt"
|
||
|
|
||
|
"github.com/go-redis/redis/v8"
|
||
|
)
|
||
|
|
||
|
// defaultContext is the context passed to every redis command issued in this file.
var defaultContext = context.Background()
|
||
|
|
||
|
type (
|
||
|
// Options is just a type alias for the go-redis Client Options.
|
||
|
Options = redis.Options
|
||
|
// ClusterOptions is just a type alias for the go-redis Cluster Client Options.
|
||
|
ClusterOptions = redis.ClusterOptions
|
||
|
)
|
||
|
|
||
|
// Client is the interface which both
|
||
|
// go-redis Client and Cluster Client implements.
|
||
|
type Client interface {
|
||
|
redis.Cmdable // Commands.
|
||
|
io.Closer // CloseConnection.
|
||
|
}
|
||
|
|
||
|
// Blocklist is a jwt.Blocklist backed by Redis.
|
||
|
type Blocklist struct {
|
||
|
// GetKey is a function which can be used how to extract
|
||
|
// the unique identifier for a token.
|
||
|
// Required. By default the token key is extracted through the claims.ID ("jti").
|
||
|
GetKey func(token []byte, claims jwt.Claims) string
|
||
|
// Prefix the token key into the redis database.
|
||
|
// Note that if you can also select a different database
|
||
|
// through ClientOptions (or ClusterOptions).
|
||
|
// Defaults to empty string (no prefix).
|
||
|
Prefix string
|
||
|
// Both Client and ClusterClient implements this interface.
|
||
|
client Client
|
||
|
connected uint32
|
||
|
// Customize any go-redis fields manually
|
||
|
// before Connect.
|
||
|
ClientOptions Options
|
||
|
ClusterOptions ClusterOptions
|
||
|
}
|
||
|
|
||
|
// Compile-time check that *Blocklist satisfies the jwt.Blocklist interface.
var _ jwt.Blocklist = (*Blocklist)(nil)
|
||
|
|
||
|
// NewBlocklist returns a new redis-based Blocklist.
|
||
|
// Modify its ClientOptions or ClusterOptions depending the application needs
|
||
|
// and call its Connect.
|
||
|
//
|
||
|
// Usage:
|
||
|
//
|
||
|
// blocklist := NewBlocklist()
|
||
|
// blocklist.ClientOptions.Addr = ...
|
||
|
// err := blocklist.Connect()
|
||
|
//
|
||
|
// And register it:
|
||
|
//
|
||
|
// verifier := jwt.NewVerifier(...)
|
||
|
// verifier.Blocklist = blocklist
|
||
|
func NewBlocklist() *Blocklist {
|
||
|
return &Blocklist{
|
||
|
GetKey: defaultGetKey,
|
||
|
Prefix: "",
|
||
|
ClientOptions: Options{
|
||
|
Addr: "127.0.0.1:6379",
|
||
|
// The rest are defaulted to good values already.
|
||
|
},
|
||
|
// If its Addrs > 0 before connect then cluster client is used instead.
|
||
|
ClusterOptions: ClusterOptions{},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func defaultGetKey(_ []byte, claims jwt.Claims) string {
|
||
|
return claims.ID
|
||
|
}
|
||
|
|
||
|
// Connect prepares the redis client and fires a ping response to it.
|
||
|
func (b *Blocklist) Connect() error {
|
||
|
if b.Prefix != "" {
|
||
|
getKey := b.GetKey
|
||
|
b.GetKey = func(token []byte, claims jwt.Claims) string {
|
||
|
return b.Prefix + getKey(token, claims)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if len(b.ClusterOptions.Addrs) > 0 {
|
||
|
// Use cluster client.
|
||
|
b.client = redis.NewClusterClient(&b.ClusterOptions)
|
||
|
} else {
|
||
|
b.client = redis.NewClient(&b.ClientOptions)
|
||
|
}
|
||
|
|
||
|
_, err := b.client.Ping(defaultContext).Result()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
host.RegisterOnInterrupt(func() {
|
||
|
atomic.StoreUint32(&b.connected, 0)
|
||
|
b.client.Close()
|
||
|
})
|
||
|
atomic.StoreUint32(&b.connected, 1)
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// IsConnected reports whether the Connect function was called.
|
||
|
func (b *Blocklist) IsConnected() bool {
|
||
|
return atomic.LoadUint32(&b.connected) > 0
|
||
|
}
|
||
|
|
||
|
// ValidateToken checks if the token exists and
|
||
|
func (b *Blocklist) ValidateToken(token []byte, c jwt.Claims, err error) error {
|
||
|
if err != nil {
|
||
|
if err == jwt.ErrExpired {
|
||
|
b.Del(b.GetKey(token, c))
|
||
|
}
|
||
|
|
||
|
return err // respect the previous error.
|
||
|
}
|
||
|
|
||
|
has, err := b.Has(b.GetKey(token, c))
|
||
|
if err != nil {
|
||
|
return err
|
||
|
} else if has {
|
||
|
return jwt.ErrBlocked
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// InvalidateToken invalidates a verified JWT token.
|
||
|
func (b *Blocklist) InvalidateToken(token []byte, c jwt.Claims) error {
|
||
|
key := b.GetKey(token, c)
|
||
|
return b.client.SetEX(defaultContext, key, token, c.Timeleft()).Err()
|
||
|
}
|
||
|
|
||
|
// Del removes a token from the storage.
|
||
|
func (b *Blocklist) Del(key string) error {
|
||
|
return b.client.Del(defaultContext, key).Err()
|
||
|
}
|
||
|
|
||
|
// Has reports whether a specific token exists in the storage.
|
||
|
func (b *Blocklist) Has(key string) (bool, error) {
|
||
|
n, err := b.client.Exists(defaultContext, key).Result()
|
||
|
return n > 0, err
|
||
|
}
|
||
|
|
||
|
// Count returns the total amount of tokens stored.
|
||
|
func (b *Blocklist) Count() (int64, error) {
|
||
|
if b.Prefix == "" {
|
||
|
return b.client.DBSize(defaultContext).Result()
|
||
|
}
|
||
|
|
||
|
keys, err := b.getKeys(0)
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
|
||
|
return int64(len(keys)), nil
|
||
|
}
|
||
|
|
||
|
func (b *Blocklist) getKeys(cursor uint64) ([]string, error) {
|
||
|
keys, cursor, err := b.client.Scan(defaultContext, cursor, b.Prefix+"*", 300000).Result()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
if cursor != 0 {
|
||
|
moreKeys, err := b.getKeys(cursor)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
keys = append(keys, moreKeys...)
|
||
|
}
|
||
|
|
||
|
return keys, nil
|
||
|
}
|