mirror of
https://github.com/kataras/iris.git
synced 2025-02-02 07:20:35 +01:00
jwt: add redis blocklist
This commit is contained in:
parent
836fb18c57
commit
f1ebddb6d9
110
_examples/auth/jwt/blocklist/main.go
Normal file
110
_examples/auth/jwt/blocklist/main.go
Normal file
|
@ -0,0 +1,110 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/kataras/iris/v12"
|
||||
"github.com/kataras/iris/v12/middleware/jwt"
|
||||
"github.com/kataras/iris/v12/middleware/jwt/blocklist/redis"
|
||||
|
||||
// Optionally to set token identifier.
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var (
|
||||
signatureSharedKey = []byte("sercrethatmaycontainch@r32length")
|
||||
|
||||
signer = jwt.NewSigner(jwt.HS256, signatureSharedKey, 15*time.Minute)
|
||||
verifier = jwt.NewVerifier(jwt.HS256, signatureSharedKey)
|
||||
)
|
||||
|
||||
type userClaims struct {
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
app := iris.New()
|
||||
|
||||
// IMPORTANT
|
||||
//
|
||||
// To use the in-memory blocklist just:
|
||||
// verifier.WithDefaultBlocklist()
|
||||
// To use a persistence blocklist, e.g. redis,
|
||||
// start your redis-server and:
|
||||
blocklist := redis.NewBlocklist()
|
||||
// To configure single client or a cluster one:
|
||||
// blocklist.ClientOptions.Addr = "127.0.0.1:6379"
|
||||
// blocklist.ClusterOptions.Addrs = []string{...}
|
||||
// To set a prefix for jwt ids:
|
||||
// blocklist.Prefix = "myapp-"
|
||||
//
|
||||
// To manually connect and check its error before continue:
|
||||
// err := blocklist.Connect()
|
||||
// By default the verifier will try to connect, if failed then it will throw http error.
|
||||
//
|
||||
// And then register it:
|
||||
verifier.Blocklist = blocklist
|
||||
verifyMiddleware := verifier.Verify(func() interface{} {
|
||||
return new(userClaims)
|
||||
})
|
||||
|
||||
app.Get("/", authenticate)
|
||||
|
||||
protectedAPI := app.Party("/protected", verifyMiddleware)
|
||||
protectedAPI.Get("/", protected)
|
||||
protectedAPI.Get("/logout", logout)
|
||||
|
||||
// http://localhost:8080
|
||||
// http://localhost:8080/protected?token=$token
|
||||
// http://localhost:8080/logout?token=$token
|
||||
// http://localhost:8080/protected?token=$token (401)
|
||||
app.Listen(":8080")
|
||||
}
|
||||
|
||||
// generateID optionally to set the value for `jwt.ID` on Sign,
|
||||
// which sets the standard claims value "jti".
|
||||
// If you use a blocklist with the default Blocklist.GetKey you have to set it.
|
||||
var generateID = func(*context.Context) string {
|
||||
id, _ := uuid.NewRandom()
|
||||
return id.String()
|
||||
}
|
||||
|
||||
func authenticate(ctx iris.Context) {
|
||||
claims := userClaims{
|
||||
Username: "kataras",
|
||||
}
|
||||
|
||||
// Generate JWT ID.
|
||||
random, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
ctx.StopWithError(iris.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
id := random.String()
|
||||
|
||||
// Set the ID with the jwt.ID.
|
||||
token, err := signer.Sign(claims, jwt.ID(id))
|
||||
|
||||
if err != nil {
|
||||
ctx.StopWithError(iris.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Write(token)
|
||||
}
|
||||
|
||||
func protected(ctx iris.Context) {
|
||||
claims := jwt.Get(ctx).(*userClaims)
|
||||
|
||||
// To the standard claims, e.g. the generated ID:
|
||||
// jwt.GetVerifiedToken(ctx).StandardClaims.ID
|
||||
|
||||
ctx.WriteString(claims.Username)
|
||||
}
|
||||
|
||||
func logout(ctx iris.Context) {
|
||||
ctx.Logout()
|
||||
|
||||
ctx.Redirect("/", iris.StatusTemporaryRedirect)
|
||||
}
|
2
go.mod
2
go.mod
|
@ -21,7 +21,7 @@ require (
|
|||
github.com/json-iterator/go v1.1.10
|
||||
github.com/kataras/blocks v0.0.4
|
||||
github.com/kataras/golog v0.1.5
|
||||
github.com/kataras/jwt v0.0.4
|
||||
github.com/kataras/jwt v0.0.5
|
||||
github.com/kataras/neffos v0.0.16
|
||||
github.com/kataras/pio v0.0.10
|
||||
github.com/kataras/sitemap v0.0.5
|
||||
|
|
|
@ -1,7 +1,59 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"github.com/kataras/jwt"
|
||||
import "github.com/kataras/jwt"
|
||||
|
||||
// Error values.
|
||||
var (
|
||||
ErrBlocked = jwt.ErrBlocked
|
||||
ErrDecrypt = jwt.ErrDecrypt
|
||||
ErrExpected = jwt.ErrExpected
|
||||
ErrExpired = jwt.ErrExpired
|
||||
ErrInvalidKey = jwt.ErrInvalidKey
|
||||
ErrIssuedInTheFuture = jwt.ErrIssuedInTheFuture
|
||||
ErrMissing = jwt.ErrMissing
|
||||
ErrMissingKey = jwt.ErrMissingKey
|
||||
ErrNotValidYet = jwt.ErrNotValidYet
|
||||
ErrTokenAlg = jwt.ErrTokenAlg
|
||||
ErrTokenForm = jwt.ErrTokenForm
|
||||
ErrTokenSignature = jwt.ErrTokenSignature
|
||||
)
|
||||
|
||||
// Signature algorithms.
|
||||
var (
|
||||
EdDSA = jwt.EdDSA
|
||||
HS256 = jwt.HS256
|
||||
HS384 = jwt.HS384
|
||||
HS512 = jwt.HS512
|
||||
RS256 = jwt.RS256
|
||||
RS384 = jwt.RS384
|
||||
RS512 = jwt.RS512
|
||||
ES256 = jwt.ES256
|
||||
ES384 = jwt.ES384
|
||||
ES512 = jwt.ES512
|
||||
PS256 = jwt.PS256
|
||||
PS384 = jwt.PS384
|
||||
PS512 = jwt.PS512
|
||||
)
|
||||
|
||||
// Signature algorithm helpers.
|
||||
var (
|
||||
MustLoadHMAC = jwt.MustLoadHMAC
|
||||
LoadHMAC = jwt.LoadHMAC
|
||||
MustLoadRSA = jwt.MustLoadRSA
|
||||
LoadPrivateKeyRSA = jwt.LoadPrivateKeyRSA
|
||||
LoadPublicKeyRSA = jwt.LoadPublicKeyRSA
|
||||
ParsePrivateKeyRSA = jwt.ParsePrivateKeyRSA
|
||||
ParsePublicKeyRSA = jwt.ParsePublicKeyRSA
|
||||
MustLoadECDSA = jwt.MustLoadECDSA
|
||||
LoadPrivateKeyECDSA = jwt.LoadPrivateKeyECDSA
|
||||
LoadPublicKeyECDSA = jwt.LoadPublicKeyECDSA
|
||||
ParsePrivateKeyECDSA = jwt.ParsePrivateKeyECDSA
|
||||
ParsePublicKeyECDSA = jwt.ParsePublicKeyECDSA
|
||||
MustLoadEdDSA = jwt.MustLoadEdDSA
|
||||
LoadPrivateKeyEdDSA = jwt.LoadPrivateKeyEdDSA
|
||||
LoadPublicKeyEdDSA = jwt.LoadPublicKeyEdDSA
|
||||
ParsePrivateKeyEdDSA = jwt.ParsePrivateKeyEdDSA
|
||||
ParsePublicKeyEdDSA = jwt.ParsePublicKeyEdDSA
|
||||
)
|
||||
|
||||
// Type alises for the underline jwt package.
|
||||
|
@ -31,23 +83,6 @@ type (
|
|||
TokenPair = jwt.TokenPair
|
||||
)
|
||||
|
||||
// Signature algorithms.
|
||||
var (
|
||||
EdDSA = jwt.EdDSA
|
||||
HS256 = jwt.HS256
|
||||
HS384 = jwt.HS384
|
||||
HS512 = jwt.HS512
|
||||
RS256 = jwt.RS256
|
||||
RS384 = jwt.RS384
|
||||
RS512 = jwt.RS512
|
||||
ES256 = jwt.ES256
|
||||
ES384 = jwt.ES384
|
||||
ES512 = jwt.ES512
|
||||
PS256 = jwt.PS256
|
||||
PS384 = jwt.PS384
|
||||
PS512 = jwt.PS512
|
||||
)
|
||||
|
||||
// Encryption algorithms.
|
||||
var (
|
||||
GCM = jwt.GCM
|
||||
|
@ -73,6 +108,13 @@ var (
|
|||
// Usage:
|
||||
// signer.Sign(..., jwt.MaxAge(15*time.Minute))
|
||||
MaxAge = jwt.MaxAge
|
||||
|
||||
// ID is a shurtcut to set jwt ID on Sign.
|
||||
ID = func(id string) jwt.SignOptionFunc {
|
||||
return func(c *Claims) {
|
||||
c.ID = id
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
// Shortcuts for Signing and Verifying.
|
||||
|
@ -82,24 +124,3 @@ var (
|
|||
Sign = jwt.Sign
|
||||
SignEncrypted = jwt.SignEncrypted
|
||||
)
|
||||
|
||||
// Signature algorithm helpers.
|
||||
var (
|
||||
MustLoadHMAC = jwt.MustLoadHMAC
|
||||
LoadHMAC = jwt.LoadHMAC
|
||||
MustLoadRSA = jwt.MustLoadRSA
|
||||
LoadPrivateKeyRSA = jwt.LoadPrivateKeyRSA
|
||||
LoadPublicKeyRSA = jwt.LoadPublicKeyRSA
|
||||
ParsePrivateKeyRSA = jwt.ParsePrivateKeyRSA
|
||||
ParsePublicKeyRSA = jwt.ParsePublicKeyRSA
|
||||
MustLoadECDSA = jwt.MustLoadECDSA
|
||||
LoadPrivateKeyECDSA = jwt.LoadPrivateKeyECDSA
|
||||
LoadPublicKeyECDSA = jwt.LoadPublicKeyECDSA
|
||||
ParsePrivateKeyECDSA = jwt.ParsePrivateKeyECDSA
|
||||
ParsePublicKeyECDSA = jwt.ParsePublicKeyECDSA
|
||||
MustLoadEdDSA = jwt.MustLoadEdDSA
|
||||
LoadPrivateKeyEdDSA = jwt.LoadPrivateKeyEdDSA
|
||||
LoadPublicKeyEdDSA = jwt.LoadPublicKeyEdDSA
|
||||
ParsePrivateKeyEdDSA = jwt.ParsePrivateKeyEdDSA
|
||||
ParsePublicKeyEdDSA = jwt.ParsePublicKeyEdDSA
|
||||
)
|
||||
|
|
|
@ -16,11 +16,16 @@ type Blocklist interface {
|
|||
jwt.TokenValidator
|
||||
|
||||
// InvalidateToken should invalidate a verified JWT token.
|
||||
InvalidateToken(token []byte, expiry int64)
|
||||
InvalidateToken(token []byte, c Claims) error
|
||||
// Del should remove a token from the storage.
|
||||
Del(token []byte)
|
||||
// Count should return the total amount of tokens stored.
|
||||
Count() int
|
||||
Del(key string) error
|
||||
// Has should report whether a specific token exists in the storage.
|
||||
Has(token []byte) bool
|
||||
Has(key string) (bool, error)
|
||||
// Count should return the total amount of tokens stored.
|
||||
Count() (int64, error)
|
||||
}
|
||||
|
||||
type blocklistConnect interface {
|
||||
Connect() error
|
||||
IsConnected() bool
|
||||
}
|
||||
|
|
188
middleware/jwt/blocklist/redis/blocklist.go
Normal file
188
middleware/jwt/blocklist/redis/blocklist.go
Normal file
|
@ -0,0 +1,188 @@
|
|||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/kataras/iris/v12/core/host"
|
||||
"github.com/kataras/iris/v12/middleware/jwt"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
var defaultContext = context.Background()
|
||||
|
||||
type (
|
||||
// Options is just a type alias for the go-redis Client Options.
|
||||
Options = redis.Options
|
||||
// ClusterOptions is just a type alias for the go-redis Cluster Client Options.
|
||||
ClusterOptions = redis.ClusterOptions
|
||||
)
|
||||
|
||||
// Client is the interface which both
|
||||
// go-redis Client and Cluster Client implements.
|
||||
type Client interface {
|
||||
redis.Cmdable // Commands.
|
||||
io.Closer // CloseConnection.
|
||||
}
|
||||
|
||||
// Blocklist is a jwt.Blocklist backed by Redis.
|
||||
type Blocklist struct {
|
||||
Clock func() time.Time // Required. Defaults to time.Now.
|
||||
// GetKey is a function which can be used how to extract
|
||||
// the unique identifier for a token.
|
||||
// Required. By default the token key is extracted through the claims.ID ("jti").
|
||||
GetKey func(token []byte, claims jwt.Claims) string
|
||||
// Prefix the token key into the redis database.
|
||||
// Note that if you can also select a different database
|
||||
// through ClientOptions (or ClusterOptions).
|
||||
// Defaults to empty string (no prefix).
|
||||
Prefix string
|
||||
// Both Client and ClusterClient implements this interface.
|
||||
client Client
|
||||
connected uint32
|
||||
// Customize any go-redis fields manually
|
||||
// before Connect.
|
||||
ClientOptions Options
|
||||
ClusterOptions ClusterOptions
|
||||
}
|
||||
|
||||
var _ jwt.Blocklist = (*Blocklist)(nil)
|
||||
|
||||
// NewBlocklist returns a new redis-based Blocklist.
|
||||
// Modify its ClientOptions or ClusterOptions depending the application needs
|
||||
// and call its Connect.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// blocklist := NewBlocklist()
|
||||
// blocklist.ClientOptions.Addr = ...
|
||||
// err := blocklist.Connect()
|
||||
//
|
||||
// And register it:
|
||||
//
|
||||
// verifier := jwt.NewVerifier(...)
|
||||
// verifier.Blocklist = blocklist
|
||||
func NewBlocklist() *Blocklist {
|
||||
return &Blocklist{
|
||||
Clock: time.Now,
|
||||
GetKey: defaultGetKey,
|
||||
Prefix: "",
|
||||
ClientOptions: Options{
|
||||
Addr: "127.0.0.1:6379",
|
||||
// The rest are defaulted to good values already.
|
||||
},
|
||||
// If its Addrs > 0 before connect then cluster client is used instead.
|
||||
ClusterOptions: ClusterOptions{},
|
||||
}
|
||||
}
|
||||
|
||||
func defaultGetKey(_ []byte, claims jwt.Claims) string {
|
||||
return claims.ID
|
||||
}
|
||||
|
||||
// Connect prepares the redis client and fires a ping response to it.
|
||||
func (b *Blocklist) Connect() error {
|
||||
if b.Prefix != "" {
|
||||
getKey := b.GetKey
|
||||
b.GetKey = func(token []byte, claims jwt.Claims) string {
|
||||
return b.Prefix + getKey(token, claims)
|
||||
}
|
||||
}
|
||||
|
||||
if len(b.ClusterOptions.Addrs) > 0 {
|
||||
// Use cluster client.
|
||||
b.client = redis.NewClusterClient(&b.ClusterOptions)
|
||||
} else {
|
||||
b.client = redis.NewClient(&b.ClientOptions)
|
||||
}
|
||||
|
||||
_, err := b.client.Ping(defaultContext).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
host.RegisterOnInterrupt(func() {
|
||||
atomic.StoreUint32(&b.connected, 0)
|
||||
b.client.Close()
|
||||
})
|
||||
atomic.StoreUint32(&b.connected, 1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsConnected reports whether the Connect function was called.
|
||||
func (b *Blocklist) IsConnected() bool {
|
||||
return atomic.LoadUint32(&b.connected) > 0
|
||||
}
|
||||
|
||||
// ValidateToken checks if the token exists and
|
||||
func (b *Blocklist) ValidateToken(token []byte, c jwt.Claims, err error) error {
|
||||
if err != nil {
|
||||
if err == jwt.ErrExpired {
|
||||
b.Del(b.GetKey(token, c))
|
||||
}
|
||||
|
||||
return err // respect the previous error.
|
||||
}
|
||||
|
||||
has, err := b.Has(b.GetKey(token, c))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if has {
|
||||
return jwt.ErrBlocked
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InvalidateToken invalidates a verified JWT token.
|
||||
func (b *Blocklist) InvalidateToken(token []byte, c jwt.Claims) error {
|
||||
key := b.GetKey(token, c)
|
||||
return b.client.SetEX(defaultContext, key, token, c.Timeleft()).Err()
|
||||
}
|
||||
|
||||
// Del removes a token from the storage.
|
||||
func (b *Blocklist) Del(key string) error {
|
||||
return b.client.Del(defaultContext, key).Err()
|
||||
}
|
||||
|
||||
// Has reports whether a specific token exists in the storage.
|
||||
func (b *Blocklist) Has(key string) (bool, error) {
|
||||
n, err := b.client.Exists(defaultContext, key).Result()
|
||||
return n > 0, err
|
||||
}
|
||||
|
||||
// Count returns the total amount of tokens stored.
|
||||
func (b *Blocklist) Count() (int64, error) {
|
||||
if b.Prefix == "" {
|
||||
return b.client.DBSize(defaultContext).Result()
|
||||
}
|
||||
|
||||
keys, err := b.getKeys(0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return int64(len(keys)), nil
|
||||
}
|
||||
|
||||
func (b *Blocklist) getKeys(cursor uint64) ([]string, error) {
|
||||
keys, cursor, err := b.client.Scan(defaultContext, cursor, b.Prefix+"*", 300000).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cursor != 0 {
|
||||
moreKeys, err := b.getKeys(cursor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keys = append(keys, moreKeys...)
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
|
@ -16,6 +16,9 @@ type Signer struct {
|
|||
Alg Alg
|
||||
Key interface{}
|
||||
|
||||
// MaxAge to set "exp" and "iat".
|
||||
// Recommended value for access tokens: 15 minutes.
|
||||
// Defaults to 0, no limit.
|
||||
MaxAge time.Duration
|
||||
|
||||
Encrypt func([]byte) ([]byte, error)
|
||||
|
|
|
@ -112,7 +112,7 @@ func (v *Verifier) WithDefaultBlocklist() *Verifier {
|
|||
|
||||
func (v *Verifier) invalidate(ctx *context.Context) {
|
||||
if verifiedToken := GetVerifiedToken(ctx); verifiedToken != nil {
|
||||
v.Blocklist.InvalidateToken(verifiedToken.Token, verifiedToken.StandardClaims.Expiry)
|
||||
v.Blocklist.InvalidateToken(verifiedToken.Token, verifiedToken.StandardClaims)
|
||||
ctx.Values().Remove(claimsContextKey)
|
||||
ctx.Values().Remove(verifiedTokenContextKey)
|
||||
ctx.SetUser(nil)
|
||||
|
@ -179,6 +179,19 @@ func (v *Verifier) Verify(claimsType func() interface{}, validators ...TokenVali
|
|||
}
|
||||
|
||||
if v.Blocklist != nil {
|
||||
// If blocklist implements the connect interface,
|
||||
// try to connect if it's not already connected manually by developer,
|
||||
// if errored then just return a handler which will fire this error every single time.
|
||||
if bc, ok := v.Blocklist.(blocklistConnect); ok {
|
||||
if !bc.IsConnected() {
|
||||
if err := bc.Connect(); err != nil {
|
||||
return func(ctx *context.Context) {
|
||||
v.ErrorHandler(ctx, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
validators = append([]TokenValidator{v.Blocklist}, append(v.Validators, validators...)...)
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ const (
|
|||
DefaultRedisNetwork = "tcp"
|
||||
// DefaultRedisAddr the redis address option, "127.0.0.1:6379".
|
||||
DefaultRedisAddr = "127.0.0.1:6379"
|
||||
// DefaultRedisTimeout the redis idle timeout option, time.Duration(30) * time.Second
|
||||
// DefaultRedisTimeout the redis idle timeout option, time.Duration(30) * time.Second.
|
||||
DefaultRedisTimeout = time.Duration(30) * time.Second
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user