jwt: add redis blocklist

This commit is contained in:
Gerasimos (Makis) Maropoulos 2020-11-02 06:31:28 +02:00
parent 836fb18c57
commit f1ebddb6d9
No known key found for this signature in database
GPG Key ID: 5DBE766BD26A54E7
8 changed files with 388 additions and 48 deletions

View File

@ -0,0 +1,110 @@
package main
import (
"context"
"time"
"github.com/kataras/iris/v12"
"github.com/kataras/iris/v12/middleware/jwt"
"github.com/kataras/iris/v12/middleware/jwt/blocklist/redis"
// Optionally to set token identifier.
"github.com/google/uuid"
)
var (
signatureSharedKey = []byte("sercrethatmaycontainch@r32length")
signer = jwt.NewSigner(jwt.HS256, signatureSharedKey, 15*time.Minute)
verifier = jwt.NewVerifier(jwt.HS256, signatureSharedKey)
)
type userClaims struct {
Username string `json:"username"`
}
func main() {
app := iris.New()
// IMPORTANT
//
// To use the in-memory blocklist just:
// verifier.WithDefaultBlocklist()
// To use a persistence blocklist, e.g. redis,
// start your redis-server and:
blocklist := redis.NewBlocklist()
// To configure single client or a cluster one:
// blocklist.ClientOptions.Addr = "127.0.0.1:6379"
// blocklist.ClusterOptions.Addrs = []string{...}
// To set a prefix for jwt ids:
// blocklist.Prefix = "myapp-"
//
// To manually connect and check its error before continue:
// err := blocklist.Connect()
// By default the verifier will try to connect, if failed then it will throw http error.
//
// And then register it:
verifier.Blocklist = blocklist
verifyMiddleware := verifier.Verify(func() interface{} {
return new(userClaims)
})
app.Get("/", authenticate)
protectedAPI := app.Party("/protected", verifyMiddleware)
protectedAPI.Get("/", protected)
protectedAPI.Get("/logout", logout)
// http://localhost:8080
// http://localhost:8080/protected?token=$token
// http://localhost:8080/logout?token=$token
// http://localhost:8080/protected?token=$token (401)
app.Listen(":8080")
}
// generateID optionally to set the value for `jwt.ID` on Sign,
// which sets the standard claims value "jti".
// If you use a blocklist with the default Blocklist.GetKey you have to set it.
var generateID = func(*context.Context) string {
id, _ := uuid.NewRandom()
return id.String()
}
func authenticate(ctx iris.Context) {
claims := userClaims{
Username: "kataras",
}
// Generate JWT ID.
random, err := uuid.NewRandom()
if err != nil {
ctx.StopWithError(iris.StatusInternalServerError, err)
return
}
id := random.String()
// Set the ID with the jwt.ID.
token, err := signer.Sign(claims, jwt.ID(id))
if err != nil {
ctx.StopWithError(iris.StatusInternalServerError, err)
return
}
ctx.Write(token)
}
func protected(ctx iris.Context) {
claims := jwt.Get(ctx).(*userClaims)
// To the standard claims, e.g. the generated ID:
// jwt.GetVerifiedToken(ctx).StandardClaims.ID
ctx.WriteString(claims.Username)
}
func logout(ctx iris.Context) {
ctx.Logout()
ctx.Redirect("/", iris.StatusTemporaryRedirect)
}

2
go.mod
View File

@ -21,7 +21,7 @@ require (
github.com/json-iterator/go v1.1.10
github.com/kataras/blocks v0.0.4
github.com/kataras/golog v0.1.5
github.com/kataras/jwt v0.0.4
github.com/kataras/jwt v0.0.5
github.com/kataras/neffos v0.0.16
github.com/kataras/pio v0.0.10
github.com/kataras/sitemap v0.0.5

View File

@ -1,7 +1,59 @@
package jwt
import (
"github.com/kataras/jwt"
import "github.com/kataras/jwt"
// Error values.
var (
ErrBlocked = jwt.ErrBlocked
ErrDecrypt = jwt.ErrDecrypt
ErrExpected = jwt.ErrExpected
ErrExpired = jwt.ErrExpired
ErrInvalidKey = jwt.ErrInvalidKey
ErrIssuedInTheFuture = jwt.ErrIssuedInTheFuture
ErrMissing = jwt.ErrMissing
ErrMissingKey = jwt.ErrMissingKey
ErrNotValidYet = jwt.ErrNotValidYet
ErrTokenAlg = jwt.ErrTokenAlg
ErrTokenForm = jwt.ErrTokenForm
ErrTokenSignature = jwt.ErrTokenSignature
)
// Signature algorithms.
var (
EdDSA = jwt.EdDSA
HS256 = jwt.HS256
HS384 = jwt.HS384
HS512 = jwt.HS512
RS256 = jwt.RS256
RS384 = jwt.RS384
RS512 = jwt.RS512
ES256 = jwt.ES256
ES384 = jwt.ES384
ES512 = jwt.ES512
PS256 = jwt.PS256
PS384 = jwt.PS384
PS512 = jwt.PS512
)
// Signature algorithm helpers.
var (
MustLoadHMAC = jwt.MustLoadHMAC
LoadHMAC = jwt.LoadHMAC
MustLoadRSA = jwt.MustLoadRSA
LoadPrivateKeyRSA = jwt.LoadPrivateKeyRSA
LoadPublicKeyRSA = jwt.LoadPublicKeyRSA
ParsePrivateKeyRSA = jwt.ParsePrivateKeyRSA
ParsePublicKeyRSA = jwt.ParsePublicKeyRSA
MustLoadECDSA = jwt.MustLoadECDSA
LoadPrivateKeyECDSA = jwt.LoadPrivateKeyECDSA
LoadPublicKeyECDSA = jwt.LoadPublicKeyECDSA
ParsePrivateKeyECDSA = jwt.ParsePrivateKeyECDSA
ParsePublicKeyECDSA = jwt.ParsePublicKeyECDSA
MustLoadEdDSA = jwt.MustLoadEdDSA
LoadPrivateKeyEdDSA = jwt.LoadPrivateKeyEdDSA
LoadPublicKeyEdDSA = jwt.LoadPublicKeyEdDSA
ParsePrivateKeyEdDSA = jwt.ParsePrivateKeyEdDSA
ParsePublicKeyEdDSA = jwt.ParsePublicKeyEdDSA
)
// Type alises for the underline jwt package.
@ -31,23 +83,6 @@ type (
TokenPair = jwt.TokenPair
)
// Signature algorithms.
var (
EdDSA = jwt.EdDSA
HS256 = jwt.HS256
HS384 = jwt.HS384
HS512 = jwt.HS512
RS256 = jwt.RS256
RS384 = jwt.RS384
RS512 = jwt.RS512
ES256 = jwt.ES256
ES384 = jwt.ES384
ES512 = jwt.ES512
PS256 = jwt.PS256
PS384 = jwt.PS384
PS512 = jwt.PS512
)
// Encryption algorithms.
var (
GCM = jwt.GCM
@ -73,6 +108,13 @@ var (
// Usage:
// signer.Sign(..., jwt.MaxAge(15*time.Minute))
MaxAge = jwt.MaxAge
// ID is a shurtcut to set jwt ID on Sign.
ID = func(id string) jwt.SignOptionFunc {
return func(c *Claims) {
c.ID = id
}
}
)
// Shortcuts for Signing and Verifying.
@ -82,24 +124,3 @@ var (
Sign = jwt.Sign
SignEncrypted = jwt.SignEncrypted
)
// Signature algorithm helpers.
var (
MustLoadHMAC = jwt.MustLoadHMAC
LoadHMAC = jwt.LoadHMAC
MustLoadRSA = jwt.MustLoadRSA
LoadPrivateKeyRSA = jwt.LoadPrivateKeyRSA
LoadPublicKeyRSA = jwt.LoadPublicKeyRSA
ParsePrivateKeyRSA = jwt.ParsePrivateKeyRSA
ParsePublicKeyRSA = jwt.ParsePublicKeyRSA
MustLoadECDSA = jwt.MustLoadECDSA
LoadPrivateKeyECDSA = jwt.LoadPrivateKeyECDSA
LoadPublicKeyECDSA = jwt.LoadPublicKeyECDSA
ParsePrivateKeyECDSA = jwt.ParsePrivateKeyECDSA
ParsePublicKeyECDSA = jwt.ParsePublicKeyECDSA
MustLoadEdDSA = jwt.MustLoadEdDSA
LoadPrivateKeyEdDSA = jwt.LoadPrivateKeyEdDSA
LoadPublicKeyEdDSA = jwt.LoadPublicKeyEdDSA
ParsePrivateKeyEdDSA = jwt.ParsePrivateKeyEdDSA
ParsePublicKeyEdDSA = jwt.ParsePublicKeyEdDSA
)

View File

@ -16,11 +16,16 @@ type Blocklist interface {
jwt.TokenValidator
// InvalidateToken should invalidate a verified JWT token.
InvalidateToken(token []byte, expiry int64)
InvalidateToken(token []byte, c Claims) error
// Del should remove a token from the storage.
Del(token []byte)
// Count should return the total amount of tokens stored.
Count() int
Del(key string) error
// Has should report whether a specific token exists in the storage.
Has(token []byte) bool
Has(key string) (bool, error)
// Count should return the total amount of tokens stored.
Count() (int64, error)
}
type blocklistConnect interface {
Connect() error
IsConnected() bool
}

View File

@ -0,0 +1,188 @@
package redis
import (
"context"
"io"
"sync/atomic"
"time"
"github.com/kataras/iris/v12/core/host"
"github.com/kataras/iris/v12/middleware/jwt"
"github.com/go-redis/redis/v8"
)
var defaultContext = context.Background()
type (
// Options is just a type alias for the go-redis Client Options.
Options = redis.Options
// ClusterOptions is just a type alias for the go-redis Cluster Client Options.
ClusterOptions = redis.ClusterOptions
)
// Client is the interface which both
// go-redis Client and Cluster Client implements.
type Client interface {
redis.Cmdable // Commands.
io.Closer // CloseConnection.
}
// Blocklist is a jwt.Blocklist backed by Redis.
type Blocklist struct {
Clock func() time.Time // Required. Defaults to time.Now.
// GetKey is a function which can be used how to extract
// the unique identifier for a token.
// Required. By default the token key is extracted through the claims.ID ("jti").
GetKey func(token []byte, claims jwt.Claims) string
// Prefix the token key into the redis database.
// Note that if you can also select a different database
// through ClientOptions (or ClusterOptions).
// Defaults to empty string (no prefix).
Prefix string
// Both Client and ClusterClient implements this interface.
client Client
connected uint32
// Customize any go-redis fields manually
// before Connect.
ClientOptions Options
ClusterOptions ClusterOptions
}
var _ jwt.Blocklist = (*Blocklist)(nil)
// NewBlocklist returns a new redis-based Blocklist.
// Modify its ClientOptions or ClusterOptions depending the application needs
// and call its Connect.
//
// Usage:
//
// blocklist := NewBlocklist()
// blocklist.ClientOptions.Addr = ...
// err := blocklist.Connect()
//
// And register it:
//
// verifier := jwt.NewVerifier(...)
// verifier.Blocklist = blocklist
func NewBlocklist() *Blocklist {
return &Blocklist{
Clock: time.Now,
GetKey: defaultGetKey,
Prefix: "",
ClientOptions: Options{
Addr: "127.0.0.1:6379",
// The rest are defaulted to good values already.
},
// If its Addrs > 0 before connect then cluster client is used instead.
ClusterOptions: ClusterOptions{},
}
}
func defaultGetKey(_ []byte, claims jwt.Claims) string {
return claims.ID
}
// Connect prepares the redis client and fires a ping response to it.
func (b *Blocklist) Connect() error {
if b.Prefix != "" {
getKey := b.GetKey
b.GetKey = func(token []byte, claims jwt.Claims) string {
return b.Prefix + getKey(token, claims)
}
}
if len(b.ClusterOptions.Addrs) > 0 {
// Use cluster client.
b.client = redis.NewClusterClient(&b.ClusterOptions)
} else {
b.client = redis.NewClient(&b.ClientOptions)
}
_, err := b.client.Ping(defaultContext).Result()
if err != nil {
return err
}
host.RegisterOnInterrupt(func() {
atomic.StoreUint32(&b.connected, 0)
b.client.Close()
})
atomic.StoreUint32(&b.connected, 1)
return nil
}
// IsConnected reports whether the Connect function was called.
func (b *Blocklist) IsConnected() bool {
return atomic.LoadUint32(&b.connected) > 0
}
// ValidateToken checks if the token exists and
func (b *Blocklist) ValidateToken(token []byte, c jwt.Claims, err error) error {
if err != nil {
if err == jwt.ErrExpired {
b.Del(b.GetKey(token, c))
}
return err // respect the previous error.
}
has, err := b.Has(b.GetKey(token, c))
if err != nil {
return err
} else if has {
return jwt.ErrBlocked
}
return nil
}
// InvalidateToken invalidates a verified JWT token.
func (b *Blocklist) InvalidateToken(token []byte, c jwt.Claims) error {
key := b.GetKey(token, c)
return b.client.SetEX(defaultContext, key, token, c.Timeleft()).Err()
}
// Del removes a token from the storage.
func (b *Blocklist) Del(key string) error {
return b.client.Del(defaultContext, key).Err()
}
// Has reports whether a specific token exists in the storage.
func (b *Blocklist) Has(key string) (bool, error) {
n, err := b.client.Exists(defaultContext, key).Result()
return n > 0, err
}
// Count returns the total amount of tokens stored.
func (b *Blocklist) Count() (int64, error) {
if b.Prefix == "" {
return b.client.DBSize(defaultContext).Result()
}
keys, err := b.getKeys(0)
if err != nil {
return 0, err
}
return int64(len(keys)), nil
}
func (b *Blocklist) getKeys(cursor uint64) ([]string, error) {
keys, cursor, err := b.client.Scan(defaultContext, cursor, b.Prefix+"*", 300000).Result()
if err != nil {
return nil, err
}
if cursor != 0 {
moreKeys, err := b.getKeys(cursor)
if err != nil {
return nil, err
}
keys = append(keys, moreKeys...)
}
return keys, nil
}

View File

@ -16,6 +16,9 @@ type Signer struct {
Alg Alg
Key interface{}
// MaxAge to set "exp" and "iat".
// Recommended value for access tokens: 15 minutes.
// Defaults to 0, no limit.
MaxAge time.Duration
Encrypt func([]byte) ([]byte, error)

View File

@ -112,7 +112,7 @@ func (v *Verifier) WithDefaultBlocklist() *Verifier {
func (v *Verifier) invalidate(ctx *context.Context) {
if verifiedToken := GetVerifiedToken(ctx); verifiedToken != nil {
v.Blocklist.InvalidateToken(verifiedToken.Token, verifiedToken.StandardClaims.Expiry)
v.Blocklist.InvalidateToken(verifiedToken.Token, verifiedToken.StandardClaims)
ctx.Values().Remove(claimsContextKey)
ctx.Values().Remove(verifiedTokenContextKey)
ctx.SetUser(nil)
@ -179,6 +179,19 @@ func (v *Verifier) Verify(claimsType func() interface{}, validators ...TokenVali
}
if v.Blocklist != nil {
// If blocklist implements the connect interface,
// try to connect if it's not already connected manually by developer,
// if errored then just return a handler which will fire this error every single time.
if bc, ok := v.Blocklist.(blocklistConnect); ok {
if !bc.IsConnected() {
if err := bc.Connect(); err != nil {
return func(ctx *context.Context) {
v.ErrorHandler(ctx, err)
}
}
}
}
validators = append([]TokenValidator{v.Blocklist}, append(v.Validators, validators...)...)
}

View File

@ -16,7 +16,7 @@ const (
DefaultRedisNetwork = "tcp"
// DefaultRedisAddr the redis address option, "127.0.0.1:6379".
DefaultRedisAddr = "127.0.0.1:6379"
// DefaultRedisTimeout the redis idle timeout option, time.Duration(30) * time.Second
// DefaultRedisTimeout the redis idle timeout option, time.Duration(30) * time.Second.
DefaultRedisTimeout = time.Duration(30) * time.Second
)