iris/middleware/jwt/verifier.go

211 lines
5.7 KiB
Go
Raw Normal View History

package jwt
import (
"reflect"
"time"
"github.com/kataras/iris/v12/context"
"github.com/kataras/jwt"
)
const (
claimsContextKey = "iris.jwt.claims"
verifiedTokenContextKey = "iris.jwt.token"
)
// Get returns the claims decoded by a verifier.
func Get(ctx *context.Context) interface{} {
if v := ctx.Values().Get(claimsContextKey); v != nil {
return v
}
return nil
}
// GetVerifiedToken returns the verified token structure
// which holds information about the decoded token
// and its standard claims.
func GetVerifiedToken(ctx *context.Context) *VerifiedToken {
if v := ctx.Values().Get(verifiedTokenContextKey); v != nil {
if tok, ok := v.(*VerifiedToken); ok {
return tok
}
}
return nil
}
// Verifier holds common options to verify an incoming token.
// Its Verify method can be used as a middleware to allow authorized clients to access an API.
type Verifier struct {
Alg Alg
Key interface{}
Decrypt func([]byte) ([]byte, error)
Extractors []TokenExtractor
Blocklist Blocklist
Validators []TokenValidator
ErrorHandler func(ctx *context.Context, err error)
}
// NewVerifier accepts the algorithm for the token's signature among with its (private) key
// and optionally some token validators for all verify middlewares that may initialized under this Verifier.
//
// See its Verify method.
func NewVerifier(signatureAlg Alg, signatureKey interface{}, validators ...TokenValidator) *Verifier {
return &Verifier{
Alg: signatureAlg,
Key: signatureKey,
Extractors: []TokenExtractor{FromHeader, FromQuery},
ErrorHandler: func(ctx *context.Context, err error) {
ctx.StopWithError(401, context.PrivateError(err))
},
Validators: validators,
}
}
// WithGCM enables AES-GCM payload encryption.
func (v *Verifier) WithGCM(key, additionalData []byte) *Verifier {
_, decrypt, err := jwt.GCM(key, additionalData)
if err != nil {
panic(err) // important error before serve, stop everything.
}
v.Decrypt = decrypt
return v
}
// WithDefaultBlocklist attaches an in-memory blocklist storage
// to invalidate tokens through server-side.
// To invalidate a token simply call the Context.Logout method.
func (v *Verifier) WithDefaultBlocklist() *Verifier {
v.Blocklist = jwt.NewBlocklist(30 * time.Minute)
return v
}
func (v *Verifier) invalidate(ctx *context.Context) {
if verifiedToken := GetVerifiedToken(ctx); verifiedToken != nil {
v.Blocklist.InvalidateToken(verifiedToken.Token, verifiedToken.StandardClaims.Expiry)
ctx.Values().Remove(claimsContextKey)
ctx.Values().Remove(verifiedTokenContextKey)
ctx.SetUser(nil)
ctx.SetLogoutFunc(nil)
}
}
// RequestToken extracts the token from the request.
func (v *Verifier) RequestToken(ctx *context.Context) (token string) {
for _, extract := range v.Extractors {
if token = extract(ctx); token != "" {
break // ok we found it.
}
}
return
}
type (
// ClaimsValidator is a special interface which, if the destination claims
// implements it then the verifier runs its Validate method before return.
ClaimsValidator interface {
Validate() error
}
// ClaimsContextValidator same as ClaimsValidator but it accepts
// a request context which can be used for further checks before
// validating the incoming token's claims.
ClaimsContextValidator interface {
Validate(*context.Context) error
}
)
// VerifyToken simply verifies the given "token" and validates its standard claims (such as expiration).
// Returns a structure which holds the token's information. See the Verify method instead.
func (v *Verifier) VerifyToken(token []byte, validators ...TokenValidator) (*VerifiedToken, error) {
return jwt.VerifyEncrypted(v.Alg, v.Key, v.Decrypt, token, validators...)
}
// Verify is the most important piece of code inside the Verifier.
// It accepts the "claimsType" function which should return a pointer to a custom structure
// which the token's decode claims valuee will be binded and validated to.
// Returns a common Iris handler which can be used as a middleware to protect an API
// from unauthorized client requests. After this, the route handlers can access the claims
// through the jwt.Get package-level function.
//
// Example Code:
func (v *Verifier) Verify(claimsType func() interface{}, validators ...TokenValidator) context.Handler {
unmarshal := jwt.Unmarshal
if claimsType != nil {
c := claimsType()
if hasRequired(c) {
unmarshal = jwt.UnmarshalWithRequired
}
}
if v.Blocklist != nil {
validators = append([]TokenValidator{v.Blocklist}, append(v.Validators, validators...)...)
}
return func(ctx *context.Context) {
token := []byte(v.RequestToken(ctx))
verifiedToken, err := v.VerifyToken(token, validators...)
if err != nil {
v.ErrorHandler(ctx, err)
return
}
if claimsType != nil {
dest := claimsType()
if err = unmarshal(verifiedToken.Payload, dest); err != nil {
v.ErrorHandler(ctx, err)
return
}
if validator, ok := dest.(ClaimsValidator); ok {
if err = validator.Validate(); err != nil {
v.ErrorHandler(ctx, err)
return
}
} else if contextValidator, ok := dest.(ClaimsContextValidator); ok {
if err = contextValidator.Validate(ctx); err != nil {
v.ErrorHandler(ctx, err)
return
}
}
if u, ok := dest.(context.User); ok {
ctx.SetUser(u)
}
ctx.Values().Set(claimsContextKey, dest)
}
if v.Blocklist != nil {
ctx.SetLogoutFunc(v.invalidate)
}
ctx.Values().Set(verifiedTokenContextKey, verifiedToken)
ctx.Next()
}
}
func hasRequired(i interface{}) bool {
val := reflect.Indirect(reflect.ValueOf(i))
typ := val.Type()
if typ.Kind() != reflect.Struct {
return false
}
for i := 0; i < val.NumField(); i++ {
field := typ.Field(i)
if jwt.HasRequiredJSONTag(field) {
return true
}
}
return false
}