package jwt

import (
	"reflect"
	"time"

	"github.com/kataras/iris/v12/context"

	"github.com/kataras/jwt"
)

// Context value keys under which the Verify middleware stores its results.
const (
	// claimsContextKey holds the decoded claims value (read by Get).
	claimsContextKey = "iris.jwt.claims"
	// verifiedTokenContextKey holds the *VerifiedToken (read by GetVerifiedToken).
	verifiedTokenContextKey = "iris.jwt.token"
)

// Get returns the claims value that a Verify middleware decoded and stored
// on ctx, or nil when nothing is stored (for example on an unprotected route,
// or when the middleware was built with a nil "claimsType").
func Get(ctx *context.Context) interface{} {
	return ctx.Values().Get(claimsContextKey)
}

// GetVerifiedToken returns the *VerifiedToken stored on ctx by a Verify
// middleware. The structure exposes the raw token, its payload and its
// standard claims. It returns nil when no value is stored under the
// verified-token key, or when the stored value is not a *VerifiedToken.
func GetVerifiedToken(ctx *context.Context) *VerifiedToken {
	verified, _ := ctx.Values().Get(verifiedTokenContextKey).(*VerifiedToken)
	return verified
}

// Verifier holds common options to verify an incoming token.
// Its Verify method can be used as a middleware to allow authorized clients to access an API.
//
// It does not support JWE, JWK.
type Verifier struct {
	// Alg is the token signature algorithm and Key its (public) key; both are
	// passed to jwt.VerifyEncrypted by VerifyToken. NewVerifier converts a
	// string Key to []byte for HMAC secrets.
	Alg Alg
	Key interface{}

	// Decrypt is passed to jwt.VerifyEncrypted by VerifyToken. It is nil
	// unless set explicitly or through WithDecryption (AES-GCM).
	Decrypt func([]byte) ([]byte, error)

	// Extractors are tried in order by RequestToken; the first non-empty
	// result is used as the token.
	// Blocklist, when non-nil, is used by Verify as a token validator and its
	// InvalidateToken method is called by the logout function Verify registers.
	// Validators holds the verifier-wide token validators given to NewVerifier.
	Extractors []TokenExtractor
	Blocklist  Blocklist
	Validators []TokenValidator

	// ErrorHandler receives every error raised by the Verify middleware
	// (blocklist connection, token verification, claims unmarshal and claims
	// validation). NewVerifier's default stops the request with status 401
	// and a private error.
	ErrorHandler func(ctx *context.Context, err error)
	// DisableContextUser disables the registration of the claims as context User.
	DisableContextUser bool
}

// NewVerifier accepts the algorithm for the token's signature along with its (public) key
// and optionally some token validators for all verify middlewares that may be initialized under this Verifier.
// See its Verify method.
//
// For the HMAC algorithms (HS256, HS384, HS512) a string key is accepted
// and converted to []byte.
//
// The returned Verifier extracts the token through FromHeader and then FromQuery,
// and its default ErrorHandler stops the request with a 401 status code.
//
// Usage:
//
//	verifier := NewVerifier(HS256, secret)
//
// OR
//
//	verifier := NewVerifier(HS256, secret, Expected{Issuer: "my-app"})
//
//	claimsGetter := func() interface{} { return new(userClaims) }
//	middleware := verifier.Verify(claimsGetter)
//
// OR
//
//	middleware := verifier.Verify(claimsGetter, Expected{Issuer: "my-app"})
//
// Register the middleware, e.g.
//
//	app.Use(middleware)
//
// Get the claims:
//
//	claims := jwt.Get(ctx).(*userClaims)
//	username := claims.Username
//
// Get the context user:
//
//	username, err := ctx.User().GetUsername()
func NewVerifier(signatureAlg Alg, signatureKey interface{}, validators ...TokenValidator) *Verifier {
	// A tiny helper if the end-developer uses string instead of []byte for hmac keys.
	// Every HMAC variant needs a []byte secret, not just HS256.
	switch signatureAlg {
	case jwt.HS256, jwt.HS384, jwt.HS512:
		if k, ok := signatureKey.(string); ok {
			signatureKey = []byte(k)
		}
	}

	return &Verifier{
		Alg:        signatureAlg,
		Key:        signatureKey,
		Extractors: []TokenExtractor{FromHeader, FromQuery},
		ErrorHandler: func(ctx *context.Context, err error) {
			ctx.StopWithError(401, context.PrivateError(err))
		},
		Validators: validators,
	}
}

// WithDecryption enables AES-GCM payload-only encryption: the decrypt
// function produced by jwt.GCM for "key" and "additionalData" is stored as
// v.Decrypt. It panics if jwt.GCM returns an error, because this is a setup
// misconfiguration that must surface before the server starts.
// It returns v to allow chaining.
func (v *Verifier) WithDecryption(key, additionalData []byte) *Verifier {
	_, decryptFn, gcmErr := jwt.GCM(key, additionalData)
	if gcmErr != nil {
		panic(gcmErr) // important error before serve, stop everything.
	}

	v.Decrypt = decryptFn
	return v
}

// WithDefaultBlocklist attaches an in-memory blocklist storage, created by
// jwt.NewBlocklist, so tokens can be invalidated server-side.
// To invalidate a token simply call the Context.Logout method.
// It returns v to allow chaining.
func (v *Verifier) WithDefaultBlocklist() *Verifier {
	// NOTE(review): presumably the interval at which expired entries are
	// purged by jwt.NewBlocklist — confirm against kataras/jwt.
	const blocklistInterval = 30 * time.Minute

	v.Blocklist = jwt.NewBlocklist(blocklistInterval)
	return v
}

// invalidate is registered by Verify as the request's logout function when a
// Blocklist is configured. It passes the stored verified token and its
// standard claims to v.Blocklist.InvalidateToken; any result of that call is
// discarded. It then removes the claims and verified token from ctx and
// clears the context user and the logout function.
// It does nothing when no verified token is stored on ctx.
func (v *Verifier) invalidate(ctx *context.Context) {
	current := GetVerifiedToken(ctx)
	if current == nil {
		return
	}

	v.Blocklist.InvalidateToken(current.Token, current.StandardClaims)
	ctx.Values().Remove(claimsContextKey)
	ctx.Values().Remove(verifiedTokenContextKey)
	ctx.SetUser(nil)
	ctx.SetLogoutFunc(nil)
}

// RequestToken extracts the token from the request by calling each of the
// Verifier's Extractors in order and returning the first non-empty result.
// It returns an empty string when no extractor yields a token.
func (v *Verifier) RequestToken(ctx *context.Context) (token string) {
	for _, extractor := range v.Extractors {
		if found := extractor(ctx); found != "" {
			return found // ok we found it.
		}
	}

	return ""
}

// ClaimsValidator is an optional interface for claims types. When the value
// produced by Verify's "claimsType" implements it, the Verify middleware calls
// Validate after decoding the token's payload into it. A non-nil error is
// passed to the Verifier's ErrorHandler.
type ClaimsValidator interface {
	Validate() error
}

// ClaimsContextValidator is like ClaimsValidator, except that Validate also
// receives the current request context, which can be used for further checks
// before accepting the incoming token's claims. Because both interfaces
// declare a method named Validate, a type can implement only one of them.
type ClaimsContextValidator interface {
	Validate(*context.Context) error
}

// VerifyToken verifies the given "token" and validates its standard claims
// (such as expiration), then runs the given validators. The Verifier's Alg,
// Key and Decrypt are handed to jwt.VerifyEncrypted. It returns the structure
// holding the token's information. Prefer the Verify method for middleware use.
func (v *Verifier) VerifyToken(token []byte, validators ...TokenValidator) (*VerifiedToken, error) {
	alg, key, decrypt := v.Alg, v.Key, v.Decrypt
	return jwt.VerifyEncrypted(alg, key, decrypt, token, validators...)
}

// Verify is the most important piece of code inside the Verifier.
// It accepts the "claimsType" function which should return a pointer to a custom structure
// which the token's decoded claims value will be bound to and validated against.
// Returns a common Iris handler which can be used as a middleware to protect an API
// from unauthorized client requests. After this, the route handlers can access the claims
// through the jwt.Get package-level function.
//
// Token validators run in this order: the Blocklist (if any), then the
// Verifier's own Validators (see NewVerifier), then the "validators" given here.
//
// By default it extracts the token from Authorization: Bearer $token header and ?token URL Query parameter,
// to change that behavior modify its Extractors field.
//
// By default a 401 status code with a generic message will be sent to the client on
// a token verification or claims validation failure, to change that behavior
// modify its ErrorHandler field or register OnErrorCode(401, errorHandler) and
// retrieve the error through Context.GetErr method.
//
// If the "claimsType" is nil then only the jwt.GetVerifiedToken is available
// and the handler should unmarshal the payload to extract the claims by itself.
func (v *Verifier) Verify(claimsType func() interface{}, validators ...TokenValidator) context.Handler {
	unmarshal := jwt.Unmarshal
	if claimsType != nil {
		// Probe one sample value at build time so every request does not
		// need to re-check whether required-field enforcement is needed.
		if hasRequired(claimsType()) {
			unmarshal = jwt.UnmarshalWithRequired
		}
	}

	if v.Blocklist != nil {
		// If blocklist implements the connect interface,
		// try to connect if it's not already connected manually by developer,
		// if errored then just return a handler which will fire this error every single time.
		if bc, ok := v.Blocklist.(blocklistConnect); ok && !bc.IsConnected() {
			if err := bc.Connect(); err != nil {
				return func(ctx *context.Context) {
					v.ErrorHandler(ctx, err)
				}
			}
		}
	}

	// Build a fresh slice: appending onto v.Validators (or onto the caller's
	// variadic slice) could write into a backing array shared with them.
	// The verifier-wide Validators must apply whether or not a Blocklist is set.
	chain := make([]TokenValidator, 0, 1+len(v.Validators)+len(validators))
	if v.Blocklist != nil {
		chain = append(chain, v.Blocklist)
	}
	chain = append(chain, v.Validators...)
	chain = append(chain, validators...)

	return func(ctx *context.Context) {
		token := []byte(v.RequestToken(ctx))
		verifiedToken, err := v.VerifyToken(token, chain...)
		if err != nil {
			v.ErrorHandler(ctx, err)
			return
		}

		if claimsType != nil {
			dest := claimsType()
			if err = unmarshal(verifiedToken.Payload, dest); err != nil {
				v.ErrorHandler(ctx, err)
				return
			}

			// A type can satisfy at most one of these interfaces, since both
			// declare a method named Validate with different signatures.
			switch validator := dest.(type) {
			case ClaimsValidator:
				err = validator.Validate()
			case ClaimsContextValidator:
				err = validator.Validate(ctx)
			}
			if err != nil {
				v.ErrorHandler(ctx, err)
				return
			}

			if !v.DisableContextUser {
				ctx.SetUser(dest)
			}

			ctx.Values().Set(claimsContextKey, dest)
		}

		if v.Blocklist != nil {
			ctx.SetLogoutFunc(v.invalidate)
		}

		ctx.Values().Set(verifiedTokenContextKey, verifiedToken)
		ctx.Next()
	}
}

// hasRequired reports whether the claims value "i" is, or points to, a
// struct with at least one field for which jwt.HasRequiredJSONTag reports
// true (presumably a `json:",required"` tag option — confirm in kataras/jwt).
// Verify uses the answer to choose jwt.UnmarshalWithRequired over jwt.Unmarshal.
//
// Only type information is inspected, so a nil interface or a typed nil
// pointer (e.g. (*Claims)(nil)) yields a result instead of a reflect panic.
func hasRequired(i interface{}) bool {
	typ := reflect.TypeOf(i)
	if typ == nil {
		return false
	}

	// Look through a single level of pointer, like reflect.Indirect does.
	if typ.Kind() == reflect.Ptr {
		typ = typ.Elem()
	}

	if typ.Kind() != reflect.Struct {
		return false
	}

	for idx := 0; idx < typ.NumField(); idx++ {
		if jwt.HasRequiredJSONTag(typ.Field(idx)) {
			return true
		}
	}

	return false
}