package jwt

import (
	"crypto"
	"encoding/json"
	"errors"
	"os"
	"strings"
	"time"

	"github.com/kataras/iris/v12/context"

	"github.com/square/go-jose/v3"
	"github.com/square/go-jose/v3/jwt"
)

// init registers "iris.jwt" as the display name for handlers whose name
// matches "iris/middleware/jwt.*" (presumably the handlers of this package,
// so iris can report them under a short name — confirm against iris docs).
func init() {
	const (
		handlerPattern = "iris/middleware/jwt.*"
		handlerName    = "iris.jwt"
	)
	context.SetHandlerName(handlerPattern, handlerName)
}

// TokenExtractor is a function that takes a context as input and returns
// a raw token string. It should return an empty string, without any
// additional information, when no token is found.
//
// `JWT.VerifyToken` tries its `Extractors` in order and uses the first
// non-empty result.
type TokenExtractor func(*context.Context) string

// FromHeader is a token extractor.
// It reads the token from the Authorization request header of form:
// Authorization: "Bearer {token}".
//
// The "Bearer" scheme is matched case-insensitively and may be followed by
// one or more spaces (RFC 6750, section 2.1). An empty string is returned
// when the header is missing, uses another scheme, carries an empty token,
// or the token itself contains spaces.
func FromHeader(ctx *context.Context) string {
	authHeader := ctx.GetHeader("Authorization")
	if authHeader == "" {
		return ""
	}

	// Split "scheme token" once, at the first space; a leading space or a
	// header without any space is not a valid "Bearer {token}" value.
	sep := strings.IndexByte(authHeader, ' ')
	if sep <= 0 || !strings.EqualFold(authHeader[:sep], "bearer") {
		return ""
	}

	// RFC 6750 allows 1*SP between the scheme and the token.
	token := strings.TrimLeft(authHeader[sep+1:], " ")
	if token == "" || strings.IndexByte(token, ' ') >= 0 {
		// Empty token or extra space-separated parts: not a well-formed credential.
		return ""
	}

	return token
}

// FromQuery is a token extractor.
// It returns the value of the "token" URL query parameter,
// or an empty string when that parameter is absent.
func FromQuery(ctx *context.Context) string {
	const queryKey = "token"
	return ctx.URLParam(queryKey)
}

// FromJSON returns a token extractor that decodes the JSON request body
// and returns the string value stored under "jsonKey".
//
// The body is read (and consumed) only when the request's Content-Type,
// as reported by `GetContentTypeRequested`, equals the JSON content type.
// A different content type, a decode failure, a missing key or a non-string
// value all result in an empty string.
func FromJSON(jsonKey string) TokenExtractor {
	return func(ctx *context.Context) string {
		if ctx.GetContentTypeRequested() != context.ContentJSONHeaderValue {
			return ""
		}

		var body context.Map
		if err := ctx.ReadJSON(&body); err != nil {
			return ""
		}

		// Indexing a nil map is safe, and the comma-ok assertion yields ""
		// for both a missing key and a value of any non-string type.
		token, _ := body[jsonKey].(string)
		return token
	}
}

// JWT holds the necessary information the middleware needs
// to sign and verify (and, optionally, encrypt and decrypt) tokens.
//
// Use the `New`, `RSA` or `HMAC` package-level functions to create one
// and the `WithEncryption` method to enable token encryption.
type JWT struct {
	// MaxAge is the expiration duration of the generated tokens.
	// `Token` uses it to fill the expiry of standard `Claims`
	// and the `Expiry` method uses it for custom claims.
	MaxAge time.Duration

	// Extractors are used to extract a raw token string value
	// from the request. They are tried in order and the first
	// non-empty result is used.
	// Builtin extractors:
	// * FromHeader
	// * FromQuery
	// * FromJSON
	// `New` sets it to a slice of `FromHeader` and `FromQuery`.
	Extractors []TokenExtractor

	// Signer is used to sign the token.
	// It is set on `New` and the `RSA`/`HMAC` package-level functions.
	Signer jose.Signer
	// VerificationKey is used to verify the token's signature.
	// `New` sets it to the public half of the signing key when that key
	// exposes a `Public()` method, otherwise to the signing key itself.
	VerificationKey interface{}

	// Encrypter is used to, optionally, encrypt the token.
	// It is set on `WithEncryption` method.
	Encrypter jose.Encrypter
	// DecriptionKey is used to decrypt the token (the private key, or the
	// shared secret for symmetric algorithms). It is set on `WithEncryption`.
	// A non-nil value makes `Token` encrypt and `VerifyToken` decrypt.
	// The misspelled name is kept for backward compatibility.
	DecriptionKey interface{}
}

// privateKey matches asymmetric private keys that can expose their public
// half (e.g. *rsa.PrivateKey). `New` and `WithEncryption` use it to derive
// the verification and encryption-recipient keys from a private key.
type privateKey interface{ Public() crypto.PublicKey }

// New returns a new JWT instance that signs tokens with "alg" and "key"
// and expires them after "maxAge".
//
// When "key" exposes a `Public() crypto.PublicKey` method (an asymmetric
// private key such as *rsa.PrivateKey) its public half becomes the
// `VerificationKey`; otherwise (e.g. an HMAC []byte secret) "key" itself
// is used to verify.
//
// See the `WithEncryption` method to add token encryption too.
// Use the `Token` method to generate a new token string
// and the `VerifyToken` method to decrypt, verify and bind the claims of an incoming request token.
// Tokens are, by default, extracted from the "Authorization: Bearer {token}" request header and
// the "token" url query parameter. Token extractors can be modified through the `Extractors` field.
//
// For example, to sign and verify using an RSA-256 key:
// 1. Generate a key file, e.g:
// 		$ openssl genrsa -des3 -out private.pem 2048
// 2. Read the file contents, e.g. with ioutil.ReadFile("./private.pem")
// 3. Pass the []byte result to the `ParseRSAPrivateKey(contents, password)` package-level helper
// 4. Use the resulting *rsa.PrivateKey as the "key" input parameter of this `New` function.
//
// See aliases.go file for available algorithms.
func New(maxAge time.Duration, alg SignatureAlgorithm, key interface{}) (*JWT, error) {
	opts := (&jose.SignerOptions{}).WithType("JWT")
	signer, err := jose.NewSigner(jose.SigningKey{Algorithm: alg, Key: key}, opts)
	if err != nil {
		return nil, err
	}

	verificationKey := key
	if pk, ok := key.(privateKey); ok {
		// Asymmetric keys: verify with the public half only.
		verificationKey = pk.Public()
	}

	return &JWT{
		MaxAge:          maxAge,
		Extractors:      []TokenExtractor{FromHeader, FromQuery},
		Signer:          signer,
		VerificationKey: verificationKey,
	}, nil
}

// Default key filenames for `RSA`.
const (
	// DefaultSignFilename is the file `RSA` uses for the signing key
	// when no filename is given.
	DefaultSignFilename = "jwt_sign.key"
	// DefaultEncFilename is the file `RSA` uses for the encryption key
	// when no second filename is given.
	DefaultEncFilename  = "jwt_enc.key"
)

// RSA returns a new `JWT` instance that signs with RS256 and, optionally,
// encrypts with RSA1_5 key wrapping and A128CBC-HS256 content encryption.
//
// It tries to parse RSA keys from "filenames[0]" (defaults to "jwt_sign.key")
// for signing and "filenames[1]" (defaults to "jwt_enc.key") for encryption,
// or generates and exports new random 2048-bit keys (through `LoadRSA`).
//
// Encryption is skipped when the signing key file already exists but the
// encryption key file does not, so an existing sign-only setup stays sign-only.
//
// It panics on errors.
// Use the `New` package-level function instead for more options.
func RSA(maxAge time.Duration, filenames ...string) *JWT {
	// Each position is independent: passing two filenames must set both.
	signFilename := DefaultSignFilename
	if len(filenames) > 0 {
		signFilename = filenames[0]
	}

	encFilename := DefaultEncFilename
	if len(filenames) > 1 {
		encFilename = filenames[1]
	}

	// Do not try to create or load enc key if only sign key already exists.
	withEncryption := true
	if fileExists(signFilename) {
		withEncryption = fileExists(encFilename)
	}

	sigKey, err := LoadRSA(signFilename, 2048)
	if err != nil {
		panic(err)
	}

	j, err := New(maxAge, RS256, sigKey)
	if err != nil {
		panic(err)
	}

	if withEncryption {
		encKey, err := LoadRSA(encFilename, 2048)
		if err != nil {
			panic(err)
		}

		// NOTE(review): RSA1_5 (PKCS#1 v1.5) key wrapping is generally
		// discouraged in favour of RSA-OAEP; kept unchanged for compatibility.
		err = j.WithEncryption(A128CBCHS256, RSA15, encKey)
		if err != nil {
			panic(err)
		}
	}

	return j
}

// Environment variables read by `HMAC`; their values take precedence
// over the keys passed to it.
const (
	// signEnv holds the signing and verification secret.
	signEnv = "JWT_SECRET"
	// encEnv holds the encryption and decryption secret.
	encEnv  = "JWT_SECRET_ENC"
)

// getenv returns the value of the environment variable named by "key",
// or "def" when that variable is unset or set to the empty string.
func getenv(key string, def string) string {
	if v := os.Getenv(key); v != "" {
		return v
	}
	return def
}

// HMAC returns a new `JWT` instance that signs with HS256 and, optionally,
// encrypts with A128GCM using direct (shared-key) encryption.
//
// It tries to read the secret keys from system environment variables:
// * JWT_SECRET for the signing and verification key and
// * JWT_SECRET_ENC for the encryption and decryption key,
// defaulting them to keys[0] and keys[1] respectively when unset or empty.
// Encryption is enabled only when an encryption secret is available
// (A128GCM uses a 128-bit, i.e. 16-byte, key).
//
// It panics on errors, including an empty signing secret.
// Use the `New` package-level function instead for more options.
func HMAC(maxAge time.Duration, keys ...string) *JWT {
	// Each position is independent: passing two keys must set both.
	var defaultSignSecret, defaultEncSecret string
	if len(keys) > 0 {
		defaultSignSecret = keys[0]
	}
	if len(keys) > 1 {
		defaultEncSecret = keys[1]
	}

	signSecret := getenv(signEnv, defaultSignSecret)
	if signSecret == "" {
		// An empty HMAC key would let anyone forge tokens that verify.
		panic(errors.New("jwt: HMAC: empty signing secret: set " + signEnv + " or pass a key"))
	}
	encSecret := getenv(encEnv, defaultEncSecret)

	j, err := New(maxAge, HS256, []byte(signSecret))
	if err != nil {
		panic(err)
	}

	if encSecret != "" {
		err = j.WithEncryption(A128GCM, DIRECT, []byte(encSecret))
		if err != nil {
			panic(err)
		}
	}

	return j
}

// WithEncryption enables encryption of generated tokens and decryption of
// incoming ones.
//
// When "key" exposes a `Public() crypto.PublicKey` method (an asymmetric
// private key) its public half is used as the encryption recipient and the
// private key itself becomes the `DecriptionKey`; otherwise (e.g. a []byte
// shared secret) the same key is used for both. Encrypted tokens carry the
// "JWT" type and content type. On error, "j" is left unchanged.
func (j *JWT) WithEncryption(contentEncryption ContentEncryption, alg KeyAlgorithm, key interface{}) error {
	recipientKey := key
	if pk, ok := key.(privateKey); ok {
		recipientKey = pk.Public()
	}

	recipient := jose.Recipient{Algorithm: alg, Key: recipientKey}
	opts := (&jose.EncrypterOptions{}).WithType("JWT").WithContentType("JWT")

	encrypter, err := jose.NewEncrypter(contentEncryption, recipient, opts)
	if err != nil {
		return err
	}

	j.Encrypter, j.DecriptionKey = encrypter, key
	return nil
}

// Expiry returns a copy of "claims" whose `IssuedAt` field is set to the
// current time and whose `Expiry` field is set to the current time plus
// "maxAge". The caller's value is not modified (it is passed by value).
//
// See the `JWT.Expiry` method too.
func Expiry(maxAge time.Duration, claims Claims) Claims {
	issuedAt := time.Now()
	expiresAt := issuedAt.Add(maxAge)

	claims.IssuedAt = NewNumericDate(issuedAt)
	claims.Expiry = NewNumericDate(expiresAt)
	return claims
}

// Expiry is the `Expiry` package-level function bound to this JWT's
// `MaxAge`: it returns a copy of "claims" with the expiration fields filled.
//
// Only use it when these standard "claims" are embedded in a custom
// claims structure; plain `Claims` passed to `Token` are filled automatically.
// Usage:
// type UserClaims struct {
// 	jwt.Claims
// 	Username string
// }
// [...]
// standardClaims := j.Expiry(jwt.Claims{...})
// customClaims := UserClaims{
// 	Claims:   standardClaims,
// 	Username: "kataras",
// }
// j.WriteToken(ctx, customClaims)
func (j *JWT) Expiry(claims Claims) Claims {
	maxAge := j.MaxAge
	return Expiry(maxAge, claims)
}

// Token generates and returns a new signed (and, when `DecriptionKey` is
// set, encrypted) token string for "claims".
//
// Standard `Claims`, passed either by value or by pointer, get their
// "exp" and "iat" fields filled from `MaxAge`; the caller's value is never
// modified. Custom claims structs should embed `Claims` and use the
// `Expiry` method themselves. Maps are intentionally not handled.
//
// See `VerifyToken` too.
func (j *JWT) Token(claims interface{}) (string, error) {
	switch c := claims.(type) {
	case Claims:
		claims = Expiry(j.MaxAge, c)
	case *Claims:
		if c != nil {
			// Work on a copy so the caller's struct is left untouched.
			claims = Expiry(j.MaxAge, *c)
		}
	}

	var (
		token string
		err   error
	)

	// jwt.Builder and jwt.NestedBuilder expose the same methods but are
	// distinct types, hence the two separate call chains.
	if j.DecriptionKey != nil {
		token, err = jwt.SignedAndEncrypted(j.Signer, j.Encrypter).Claims(claims).CompactSerialize()
	} else {
		token, err = jwt.Signed(j.Signer).Claims(claims).CompactSerialize()
	}

	if err != nil {
		return "", err
	}

	return token, nil
}

/* Let's not support maps, typed claims are the way to go.
// validateMapClaims validates claims of map type.
func validateMapClaims(m map[string]interface{}, e jwt.Expected, leeway time.Duration) error {
	if !e.Time.IsZero() {
		if v, ok := m["nbf"]; ok {
			if notBefore, ok := v.(NumericDate); ok {
				if e.Time.Add(leeway).Before(notBefore.Time()) {
					return ErrNotValidYet
				}
			}
		}

		if v, ok := m["exp"]; ok {
			if exp, ok := v.(int64); ok {
				if e.Time.Add(-leeway).Before(time.Unix(exp, 0)) {
					return ErrExpired
				}
			}
		}

		if v, ok := m["iat"]; ok {
			if issuedAt, ok := v.(int64); ok {
				if e.Time.Add(leeway).Before(time.Unix(issuedAt, 0)) {
					return ErrIssuedInTheFuture
				}
			}
		}
	}

	return nil
}
*/

// WriteToken generates a token for "claims" (through the `Token` method)
// and writes it as the raw response body.
//
// If generation fails, the response status code is set to 500 and the
// error is returned without writing a body; otherwise the write error,
// if any, is returned.
// Use the `Token` method to get a new generated token raw string value.
func (j *JWT) WriteToken(ctx *context.Context, claims interface{}) error {
	token, err := j.Token(claims)
	if err == nil {
		_, err = ctx.WriteString(token)
		return err
	}

	ctx.StatusCode(500)
	return err
}

// Errors returned by `VerifyToken`, `ReadClaims` and `Get`.
var (
	// ErrMissing when token cannot be extracted from the request,
	// or no claims are stored on the request context.
	ErrMissing = errors.New("token is missing")
	// ErrExpired indicates that token is used after expiry time indicated in exp claim.
	ErrExpired = errors.New("token is expired (exp)")
	// ErrNotValidYet indicates that token is used before time indicated in nbf claim.
	ErrNotValidYet = errors.New("token not valid yet (nbf)")
	// ErrIssuedInTheFuture indicates that the iat field is in the future.
	ErrIssuedInTheFuture = errors.New("token issued in the future (iat)")
)

// Validation interfaces recognized by `validateClaims`, checked in this order.
type (
	// claimsValidator validates the standard time-based claims against an
	// expected time (e.g. `Claims` and structs embedding it).
	claimsValidator interface {
		ValidateWithLeeway(e jwt.Expected, leeway time.Duration) error
	}
	// claimsAlternativeValidator is a context-free custom validator.
	claimsAlternativeValidator interface { // to keep iris-contrib/jwt MapClaims compatible.
		Validate() error
	}
	// claimsContextValidator is a custom validator that receives the request context.
	claimsContextValidator interface {
		Validate(ctx *context.Context) error
	}
)

// IsValidated reports whether the claims of this request need no further
// validation. It returns false only when a previous validation attempt
// (e.g. in `VerifyToken`) was deferred because the claims type implements
// none of the supported validator interfaces; `ReadClaims` and `Get` then
// validate again against the type they are given.
func IsValidated(ctx *context.Context) bool { // see the `ReadClaims`.
	marker := ctx.Values().Get(needsValidationContextKey)
	return marker == nil
}

// validateClaims validates "claims" with the first validation interface it implements:
//   - `ValidateWithLeeway(jwt.Expected, time.Duration) error`, checked against
//     the current time with zero leeway;
//   - `Validate() error`;
//   - `Validate(*context.Context) error`.
//
// Raw JSON payloads (`json.RawMessage` or `*json.RawMessage`) are decoded into
// standard `Claims` and validated as such. Any other type cannot be validated
// here: the request is marked as needing validation (see `IsValidated`) and
// nil is returned.
//
// go-jose's expiry/not-before/issued-at errors are translated to this
// package's `ErrExpired`, `ErrNotValidYet` and `ErrIssuedInTheFuture`.
func validateClaims(ctx *context.Context, claims interface{}) error {
	var err error

	switch c := claims.(type) {
	case claimsValidator:
		err = c.ValidateWithLeeway(jwt.Expected{Time: time.Now()}, 0)
	case claimsAlternativeValidator:
		err = c.Validate()
	case claimsContextValidator:
		err = c.Validate(ctx)
	case *json.RawMessage:
		if c == nil {
			// Nothing to decode; previously this dereference panicked.
			return ErrMissing
		}
		return validateRawJSONClaims(ctx, *c)
	case json.RawMessage:
		// A raw message stored by value (e.g. by a custom middleware) must
		// be validated too, not silently left unvalidated.
		return validateRawJSONClaims(ctx, c)
	default:
		ctx.Values().Set(needsValidationContextKey, struct{}{})
		return nil
	}

	return mapValidationError(err)
}

// validateRawJSONClaims decodes the standard claims (exp, nbf, iat, ...) out of
// a raw JSON payload and validates them through `validateClaims`.
func validateRawJSONClaims(ctx *context.Context, raw json.RawMessage) error {
	std := new(Claims)
	if err := json.Unmarshal(raw, std); err != nil {
		return err
	}

	return validateClaims(ctx, std)
}

// mapValidationError translates go-jose's time-based validation errors into
// this package's exported sentinels; any other error (or nil) is returned as-is.
func mapValidationError(err error) error {
	switch {
	case err == nil:
		return nil
	case errors.Is(err, jwt.ErrExpired):
		return ErrExpired
	case errors.Is(err, jwt.ErrNotValidYet):
		return ErrNotValidYet
	case errors.Is(err, jwt.ErrIssuedInTheFuture):
		return ErrIssuedInTheFuture
	default:
		return err
	}
}

// VerifyToken extracts the request token with the configured `Extractors`,
// decrypts it when a `DecriptionKey` is set, verifies its signature with
// `VerificationKey` and binds its claims to "claimsPtr" (the destination).
// The bound claims are then validated; types that cannot be validated here
// are marked for later validation (see `IsValidated`).
//
// It returns `ErrMissing` when no extractor yields a token, any parse,
// decryption, signature or binding error as-is, or a validation error
// such as `ErrExpired`. It returns a nil error on success.
func (j *JWT) VerifyToken(ctx *context.Context, claimsPtr interface{}) error {
	token := ""
	for i := 0; token == "" && i < len(j.Extractors); i++ {
		token = j.Extractors[i](ctx)
	}
	if token == "" {
		return ErrMissing
	}

	var (
		parsed *jwt.JSONWebToken
		err    error
	)
	if j.DecriptionKey == nil {
		parsed, err = jwt.ParseSigned(token)
	} else {
		nested, parseErr := jwt.ParseSignedAndEncrypted(token)
		if parseErr != nil {
			return parseErr
		}
		parsed, err = nested.Decrypt(j.DecriptionKey)
	}
	if err != nil {
		return err
	}

	if err = parsed.Claims(j.VerificationKey, claimsPtr); err != nil {
		return err
	}

	return validateClaims(ctx, claimsPtr)
}

const (
	// ClaimsContextKey is the request context key under which the `Verify`
	// method stores the verified raw (JSON) claims; `ReadClaims` and `Get`
	// read them from there.
	ClaimsContextKey          = "iris.jwt.claims"
	// needsValidationContextKey marks a request whose claims could not be
	// validated yet (see `IsValidated`).
	needsValidationContextKey = "iris.jwt.claims.unvalidated"
)

// Verify is a middleware. It verifies (and, when configured, decrypts) the
// incoming request token through `VerifyToken`, validating its standard
// claims (exp, nbf, iat).
//
// On failure, including a missing token, it stops the request with a
// 401 Unauthorized status code. On success it stores the raw JSON claims
// under `ClaimsContextKey` and calls `ctx.Next`.
//
// Call `ReadClaims` in the next handlers to bind the stored claims to a
// typed value. See `VerifyToken` to verify, decrypt, validate and acquire
// the claims at once.
func (j *JWT) Verify(ctx *context.Context) {
	var raw json.RawMessage
	err := j.VerifyToken(ctx, &raw)
	if err == nil {
		ctx.Values().Set(ClaimsContextKey, raw)
		ctx.Next()
		return
	}

	ctx.StopWithStatus(401)
}

// ReadClaims binds the "claimsPtr" (destination) to the verified (and
// decrypted) raw JSON claims stored by the `Verify` middleware, which
// should be registered first.
//
// It returns `ErrMissing` when no raw JSON claims are stored on the request
// (use `Get` when a custom middleware stored typed claims instead).
// If validation was deferred (see `IsValidated`), "claimsPtr" is validated
// after binding; otherwise it is not validated again.
func ReadClaims(ctx *context.Context, claimsPtr interface{}) error {
	// A missing value fails the assertion too, so one check covers both cases.
	raw, ok := ctx.Values().Get(ClaimsContextKey).(json.RawMessage)
	if !ok {
		return ErrMissing
	}

	if err := json.Unmarshal(raw, claimsPtr); err != nil {
		return err
	}

	if IsValidated(ctx) {
		return nil
	}

	ctx.Values().Remove(needsValidationContextKey)
	return validateClaims(ctx, claimsPtr)
}

// Get returns the claims stored under `ClaimsContextKey` on the request
// context's values storage, validating them first when validation was
// deferred (see `IsValidated`). It returns `ErrMissing` when nothing is stored.
//
// Should be used instead of the `ReadClaims` function when a custom
// verification middleware stored typed claims (see the `Verify` method for
// the builtin one, which stores raw JSON).
//
// Usage:
//  j, err := jwt.New(...)
//  [...]
//  app.Use(func(ctx iris.Context) {
//  	var claims CustomClaims_or_jwt.Claims
//  	if err := j.VerifyToken(ctx, &claims); err != nil {
//  		ctx.StopWithStatus(iris.StatusUnauthorized)
//  		return
//  	}
//
//  	ctx.Values().Set(jwt.ClaimsContextKey, claims)
//  	ctx.Next()
//  })
//  [...]
//  app.Post("/restricted", func(ctx iris.Context) {
//  	v, err := jwt.Get(ctx)
//  	[handle error...]
//  	claims, ok := v.(CustomClaims_or_jwt.Claims)
//  	if !ok {
//  		[do you support more than one type of claims? Handle here]
//  	}
//  	[use claims...]
//  })
func Get(ctx *context.Context) (interface{}, error) {
	claims := ctx.Values().Get(ClaimsContextKey)

	switch {
	case claims == nil:
		return nil, ErrMissing
	case IsValidated(ctx):
		return claims, nil
	}

	ctx.Values().Remove(needsValidationContextKey)
	if err := validateClaims(ctx, claims); err != nil {
		return nil, err
	}

	return claims, nil
}