diff --git a/_examples/auth/jwt/blocklist/main.go b/_examples/auth/jwt/blocklist/main.go new file mode 100644 index 00000000..6549f658 --- /dev/null +++ b/_examples/auth/jwt/blocklist/main.go @@ -0,0 +1,110 @@ +package main + +import ( + "context" + "time" + + "github.com/kataras/iris/v12" + "github.com/kataras/iris/v12/middleware/jwt" + "github.com/kataras/iris/v12/middleware/jwt/blocklist/redis" + + // Optionally to set token identifier. + "github.com/google/uuid" +) + +var ( + signatureSharedKey = []byte("sercrethatmaycontainch@r32length") + + signer = jwt.NewSigner(jwt.HS256, signatureSharedKey, 15*time.Minute) + verifier = jwt.NewVerifier(jwt.HS256, signatureSharedKey) +) + +type userClaims struct { + Username string `json:"username"` +} + +func main() { + app := iris.New() + + // IMPORTANT + // + // To use the in-memory blocklist just: + // verifier.WithDefaultBlocklist() + // To use a persistence blocklist, e.g. redis, + // start your redis-server and: + blocklist := redis.NewBlocklist() + // To configure single client or a cluster one: + // blocklist.ClientOptions.Addr = "127.0.0.1:6379" + // blocklist.ClusterOptions.Addrs = []string{...} + // To set a prefix for jwt ids: + // blocklist.Prefix = "myapp-" + // + // To manually connect and check its error before continue: + // err := blocklist.Connect() + // By default the verifier will try to connect, if failed then it will throw http error. + // + // And then register it: + verifier.Blocklist = blocklist + verifyMiddleware := verifier.Verify(func() interface{} { + return new(userClaims) + }) + + app.Get("/", authenticate) + + protectedAPI := app.Party("/protected", verifyMiddleware) + protectedAPI.Get("/", protected) + protectedAPI.Get("/logout", logout) + + // http://localhost:8080 + // http://localhost:8080/protected?token=$token + // http://localhost:8080/logout?token=$token + // http://localhost:8080/protected?token=$token (401) + app.Listen(":8080") +} + +// generateID optionally to set the value for `jwt.ID` on Sign, +// which sets the standard claims value "jti". +// If you use a blocklist with the default Blocklist.GetKey you have to set it. +var generateID = func(*context.Context) string { + id, _ := uuid.NewRandom() + return id.String() +} + +func authenticate(ctx iris.Context) { + claims := userClaims{ + Username: "kataras", + } + + // Generate JWT ID. + random, err := uuid.NewRandom() + if err != nil { + ctx.StopWithError(iris.StatusInternalServerError, err) + return + } + id := random.String() + + // Set the ID with the jwt.ID. + token, err := signer.Sign(claims, jwt.ID(id)) + + if err != nil { + ctx.StopWithError(iris.StatusInternalServerError, err) + return + } + + ctx.Write(token) +} + +func protected(ctx iris.Context) { + claims := jwt.Get(ctx).(*userClaims) + + // To the standard claims, e.g. the generated ID: + // jwt.GetVerifiedToken(ctx).StandardClaims.ID + + ctx.WriteString(claims.Username) +} + +func logout(ctx iris.Context) { + ctx.Logout() + + ctx.Redirect("/", iris.StatusTemporaryRedirect) +} diff --git a/go.mod b/go.mod index 86a975a6..20a4c918 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,7 @@ require ( github.com/json-iterator/go v1.1.10 github.com/kataras/blocks v0.0.4 github.com/kataras/golog v0.1.5 - github.com/kataras/jwt v0.0.4 + github.com/kataras/jwt v0.0.5 github.com/kataras/neffos v0.0.16 github.com/kataras/pio v0.0.10 github.com/kataras/sitemap v0.0.5 diff --git a/middleware/jwt/alises.go b/middleware/jwt/alises.go index df27cb5d..81b4f94a 100644 --- a/middleware/jwt/alises.go +++ b/middleware/jwt/alises.go @@ -1,7 +1,59 @@ package jwt -import ( - "github.com/kataras/jwt" +import "github.com/kataras/jwt" + +// Error values. +var ( + ErrBlocked = jwt.ErrBlocked + ErrDecrypt = jwt.ErrDecrypt + ErrExpected = jwt.ErrExpected + ErrExpired = jwt.ErrExpired + ErrInvalidKey = jwt.ErrInvalidKey + ErrIssuedInTheFuture = jwt.ErrIssuedInTheFuture + ErrMissing = jwt.ErrMissing + ErrMissingKey = jwt.ErrMissingKey + ErrNotValidYet = jwt.ErrNotValidYet + ErrTokenAlg = jwt.ErrTokenAlg + ErrTokenForm = jwt.ErrTokenForm + ErrTokenSignature = jwt.ErrTokenSignature +) + +// Signature algorithms. +var ( + EdDSA = jwt.EdDSA + HS256 = jwt.HS256 + HS384 = jwt.HS384 + HS512 = jwt.HS512 + RS256 = jwt.RS256 + RS384 = jwt.RS384 + RS512 = jwt.RS512 + ES256 = jwt.ES256 + ES384 = jwt.ES384 + ES512 = jwt.ES512 + PS256 = jwt.PS256 + PS384 = jwt.PS384 + PS512 = jwt.PS512 +) + +// Signature algorithm helpers. +var ( + MustLoadHMAC = jwt.MustLoadHMAC + LoadHMAC = jwt.LoadHMAC + MustLoadRSA = jwt.MustLoadRSA + LoadPrivateKeyRSA = jwt.LoadPrivateKeyRSA + LoadPublicKeyRSA = jwt.LoadPublicKeyRSA + ParsePrivateKeyRSA = jwt.ParsePrivateKeyRSA + ParsePublicKeyRSA = jwt.ParsePublicKeyRSA + MustLoadECDSA = jwt.MustLoadECDSA + LoadPrivateKeyECDSA = jwt.LoadPrivateKeyECDSA + LoadPublicKeyECDSA = jwt.LoadPublicKeyECDSA + ParsePrivateKeyECDSA = jwt.ParsePrivateKeyECDSA + ParsePublicKeyECDSA = jwt.ParsePublicKeyECDSA + MustLoadEdDSA = jwt.MustLoadEdDSA + LoadPrivateKeyEdDSA = jwt.LoadPrivateKeyEdDSA + LoadPublicKeyEdDSA = jwt.LoadPublicKeyEdDSA + ParsePrivateKeyEdDSA = jwt.ParsePrivateKeyEdDSA + ParsePublicKeyEdDSA = jwt.ParsePublicKeyEdDSA ) // Type alises for the underline jwt package. @@ -31,23 +83,6 @@ type ( TokenPair = jwt.TokenPair ) -// Signature algorithms. -var ( - EdDSA = jwt.EdDSA - HS256 = jwt.HS256 - HS384 = jwt.HS384 - HS512 = jwt.HS512 - RS256 = jwt.RS256 - RS384 = jwt.RS384 - RS512 = jwt.RS512 - ES256 = jwt.ES256 - ES384 = jwt.ES384 - ES512 = jwt.ES512 - PS256 = jwt.PS256 - PS384 = jwt.PS384 - PS512 = jwt.PS512 -) - // Encryption algorithms. var ( GCM = jwt.GCM @@ -73,6 +108,13 @@ var ( // Usage: // signer.Sign(..., jwt.MaxAge(15*time.Minute)) MaxAge = jwt.MaxAge + + // ID is a shurtcut to set jwt ID on Sign. + ID = func(id string) jwt.SignOptionFunc { + return func(c *Claims) { + c.ID = id + } + } ) // Shortcuts for Signing and Verifying. @@ -82,24 +124,3 @@ var ( Sign = jwt.Sign SignEncrypted = jwt.SignEncrypted ) - -// Signature algorithm helpers. -var ( - MustLoadHMAC = jwt.MustLoadHMAC - LoadHMAC = jwt.LoadHMAC - MustLoadRSA = jwt.MustLoadRSA - LoadPrivateKeyRSA = jwt.LoadPrivateKeyRSA - LoadPublicKeyRSA = jwt.LoadPublicKeyRSA - ParsePrivateKeyRSA = jwt.ParsePrivateKeyRSA - ParsePublicKeyRSA = jwt.ParsePublicKeyRSA - MustLoadECDSA = jwt.MustLoadECDSA - LoadPrivateKeyECDSA = jwt.LoadPrivateKeyECDSA - LoadPublicKeyECDSA = jwt.LoadPublicKeyECDSA - ParsePrivateKeyECDSA = jwt.ParsePrivateKeyECDSA - ParsePublicKeyECDSA = jwt.ParsePublicKeyECDSA - MustLoadEdDSA = jwt.MustLoadEdDSA - LoadPrivateKeyEdDSA = jwt.LoadPrivateKeyEdDSA - LoadPublicKeyEdDSA = jwt.LoadPublicKeyEdDSA - ParsePrivateKeyEdDSA = jwt.ParsePrivateKeyEdDSA - ParsePublicKeyEdDSA = jwt.ParsePublicKeyEdDSA -) diff --git a/middleware/jwt/blocklist.go b/middleware/jwt/blocklist.go index 1c827969..da1ed6e6 100644 --- a/middleware/jwt/blocklist.go +++ b/middleware/jwt/blocklist.go @@ -16,11 +16,16 @@ type Blocklist interface { jwt.TokenValidator // InvalidateToken should invalidate a verified JWT token. - InvalidateToken(token []byte, expiry int64) + InvalidateToken(token []byte, c Claims) error // Del should remove a token from the storage. - Del(token []byte) - // Count should return the total amount of tokens stored. - Count() int + Del(key string) error // Has should report whether a specific token exists in the storage. - Has(token []byte) bool + Has(key string) (bool, error) + // Count should return the total amount of tokens stored. + Count() (int64, error) +} + +type blocklistConnect interface { + Connect() error + IsConnected() bool } diff --git a/middleware/jwt/blocklist/redis/blocklist.go b/middleware/jwt/blocklist/redis/blocklist.go new file mode 100644 index 00000000..831429ef --- /dev/null +++ b/middleware/jwt/blocklist/redis/blocklist.go @@ -0,0 +1,188 @@ +package redis + +import ( + "context" + "io" + "sync/atomic" + "time" + + "github.com/kataras/iris/v12/core/host" + "github.com/kataras/iris/v12/middleware/jwt" + + "github.com/go-redis/redis/v8" +) + +var defaultContext = context.Background() + +type ( + // Options is just a type alias for the go-redis Client Options. + Options = redis.Options + // ClusterOptions is just a type alias for the go-redis Cluster Client Options. + ClusterOptions = redis.ClusterOptions +) + +// Client is the interface which both +// go-redis Client and Cluster Client implements. +type Client interface { + redis.Cmdable // Commands. + io.Closer // CloseConnection. +} + +// Blocklist is a jwt.Blocklist backed by Redis. +type Blocklist struct { + Clock func() time.Time // Required. Defaults to time.Now. + // GetKey is a function which can be used how to extract + // the unique identifier for a token. + // Required. By default the token key is extracted through the claims.ID ("jti"). + GetKey func(token []byte, claims jwt.Claims) string + // Prefix the token key into the redis database. + // Note that if you can also select a different database + // through ClientOptions (or ClusterOptions). + // Defaults to empty string (no prefix). + Prefix string + // Both Client and ClusterClient implements this interface. + client Client + connected uint32 + // Customize any go-redis fields manually + // before Connect. + ClientOptions Options + ClusterOptions ClusterOptions +} + +var _ jwt.Blocklist = (*Blocklist)(nil) + +// NewBlocklist returns a new redis-based Blocklist. +// Modify its ClientOptions or ClusterOptions depending the application needs +// and call its Connect. +// +// Usage: +// +// blocklist := NewBlocklist() +// blocklist.ClientOptions.Addr = ... +// err := blocklist.Connect() +// +// And register it: +// +// verifier := jwt.NewVerifier(...) +// verifier.Blocklist = blocklist +func NewBlocklist() *Blocklist { + return &Blocklist{ + Clock: time.Now, + GetKey: defaultGetKey, + Prefix: "", + ClientOptions: Options{ + Addr: "127.0.0.1:6379", + // The rest are defaulted to good values already. + }, + // If its Addrs > 0 before connect then cluster client is used instead. + ClusterOptions: ClusterOptions{}, + } +} + +func defaultGetKey(_ []byte, claims jwt.Claims) string { + return claims.ID +} + +// Connect prepares the redis client and fires a ping response to it. +func (b *Blocklist) Connect() error { + if b.Prefix != "" { + getKey := b.GetKey + b.GetKey = func(token []byte, claims jwt.Claims) string { + return b.Prefix + getKey(token, claims) + } + } + + if len(b.ClusterOptions.Addrs) > 0 { + // Use cluster client. + b.client = redis.NewClusterClient(&b.ClusterOptions) + } else { + b.client = redis.NewClient(&b.ClientOptions) + } + + _, err := b.client.Ping(defaultContext).Result() + if err != nil { + return err + } + + host.RegisterOnInterrupt(func() { + atomic.StoreUint32(&b.connected, 0) + b.client.Close() + }) + atomic.StoreUint32(&b.connected, 1) + + return nil +} + +// IsConnected reports whether the Connect function was called. +func (b *Blocklist) IsConnected() bool { + return atomic.LoadUint32(&b.connected) > 0 +} + +// ValidateToken checks if the token exists and +func (b *Blocklist) ValidateToken(token []byte, c jwt.Claims, err error) error { + if err != nil { + if err == jwt.ErrExpired { + b.Del(b.GetKey(token, c)) + } + + return err // respect the previous error. + } + + has, err := b.Has(b.GetKey(token, c)) + if err != nil { + return err + } else if has { + return jwt.ErrBlocked + } + + return nil +} + +// InvalidateToken invalidates a verified JWT token. +func (b *Blocklist) InvalidateToken(token []byte, c jwt.Claims) error { + key := b.GetKey(token, c) + return b.client.SetEX(defaultContext, key, token, c.Timeleft()).Err() +} + +// Del removes a token from the storage. +func (b *Blocklist) Del(key string) error { + return b.client.Del(defaultContext, key).Err() +} + +// Has reports whether a specific token exists in the storage. +func (b *Blocklist) Has(key string) (bool, error) { + n, err := b.client.Exists(defaultContext, key).Result() + return n > 0, err +} + +// Count returns the total amount of tokens stored. +func (b *Blocklist) Count() (int64, error) { + if b.Prefix == "" { + return b.client.DBSize(defaultContext).Result() + } + + keys, err := b.getKeys(0) + if err != nil { + return 0, err + } + + return int64(len(keys)), nil +} + +func (b *Blocklist) getKeys(cursor uint64) ([]string, error) { + keys, cursor, err := b.client.Scan(defaultContext, cursor, b.Prefix+"*", 300000).Result() + if err != nil { + return nil, err + } + + if cursor != 0 { + moreKeys, err := b.getKeys(cursor) + if err != nil { + return nil, err + } + + keys = append(keys, moreKeys...) + } + + return keys, nil +} diff --git a/middleware/jwt/signer.go b/middleware/jwt/signer.go index acd74237..a121990a 100644 --- a/middleware/jwt/signer.go +++ b/middleware/jwt/signer.go @@ -16,6 +16,9 @@ type Signer struct { Alg Alg Key interface{} + // MaxAge to set "exp" and "iat". + // Recommended value for access tokens: 15 minutes. + // Defaults to 0, no limit. MaxAge time.Duration Encrypt func([]byte) ([]byte, error) diff --git a/middleware/jwt/verifier.go b/middleware/jwt/verifier.go index 1d354ec9..6df2788e 100644 --- a/middleware/jwt/verifier.go +++ b/middleware/jwt/verifier.go @@ -112,7 +112,7 @@ func (v *Verifier) WithDefaultBlocklist() *Verifier { func (v *Verifier) invalidate(ctx *context.Context) { if verifiedToken := GetVerifiedToken(ctx); verifiedToken != nil { - v.Blocklist.InvalidateToken(verifiedToken.Token, verifiedToken.StandardClaims.Expiry) + v.Blocklist.InvalidateToken(verifiedToken.Token, verifiedToken.StandardClaims) ctx.Values().Remove(claimsContextKey) ctx.Values().Remove(verifiedTokenContextKey) ctx.SetUser(nil) @@ -179,6 +179,19 @@ func (v *Verifier) Verify(claimsType func() interface{}, validators ...TokenVali } if v.Blocklist != nil { + // If blocklist implements the connect interface, + // try to connect if it's not already connected manually by developer, + // if errored then just return a handler which will fire this error every single time. + if bc, ok := v.Blocklist.(blocklistConnect); ok { + if !bc.IsConnected() { + if err := bc.Connect(); err != nil { + return func(ctx *context.Context) { + v.ErrorHandler(ctx, err) + } + } + } + } + validators = append([]TokenValidator{v.Blocklist}, append(v.Validators, validators...)...) } diff --git a/sessions/sessiondb/redis/database.go b/sessions/sessiondb/redis/database.go index 6eabee94..1773a208 100644 --- a/sessions/sessiondb/redis/database.go +++ b/sessions/sessiondb/redis/database.go @@ -16,7 +16,7 @@ const ( DefaultRedisNetwork = "tcp" // DefaultRedisAddr the redis address option, "127.0.0.1:6379". DefaultRedisAddr = "127.0.0.1:6379" - // DefaultRedisTimeout the redis idle timeout option, time.Duration(30) * time.Second + // DefaultRedisTimeout the redis idle timeout option, time.Duration(30) * time.Second. DefaultRedisTimeout = time.Duration(30) * time.Second )