package jwt import ( "crypto" "encoding/json" "errors" "os" "strings" "time" "github.com/kataras/iris/v12/context" "github.com/square/go-jose/v3" "github.com/square/go-jose/v3/jwt" ) func init() { context.SetHandlerName("iris/middleware/jwt.*", "iris.jwt") } // TokenExtractor is a function that takes a context as input and returns // a token. An empty string should be returned if no token found // without additional information. type TokenExtractor func(context.Context) string // FromHeader is a token extractor. // It reads the token from the Authorization request header of form: // Authorization: "Bearer {token}". func FromHeader(ctx context.Context) string { authHeader := ctx.GetHeader("Authorization") if authHeader == "" { return "" } // pure check: authorization header format must be Bearer {token} authHeaderParts := strings.Split(authHeader, " ") if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { return "" } return authHeaderParts[1] } // FromQuery is a token extractor. // It reads the token from the "token" url query parameter. func FromQuery(ctx context.Context) string { return ctx.URLParam("token") } // FromJSON is a token extractor. // Reads a json request body and extracts the json based on the given field. // The request content-type should contain the: application/json header value, otherwise // this method will not try to read and consume the body. func FromJSON(jsonKey string) TokenExtractor { return func(ctx context.Context) string { if ctx.GetContentTypeRequested() != context.ContentJSONHeaderValue { return "" } var m context.Map if err := ctx.ReadJSON(&m); err != nil { return "" } if m == nil { return "" } v, ok := m[jsonKey] if !ok { return "" } tok, ok := v.(string) if !ok { return "" } return tok } } // JWT holds the necessary information the middleware need // to sign and verify tokens. // // The `RSA(privateFile, publicFile, password)` package-level helper function // can be used to decode the SignKey and VerifyKey. type JWT struct { // MaxAge is the expiration duration of the generated tokens. MaxAge time.Duration // Extractors are used to extract a raw token string value // from the request. // Builtin extractors: // * FromHeader // * FromQuery // * FromJSON // Defaults to a slice of `FromHeader` and `FromQuery`. Extractors []TokenExtractor // Signer is used to sign the token. // It is set on `New` and `Default` package-level functions. Signer jose.Signer // VerificationKey is used to verify the token (public key). VerificationKey interface{} // Encrypter is used to, optionally, encrypt the token. // It is set on `WithEncryption` method. Encrypter jose.Encrypter // DecriptionKey is used to decrypt the token (private key) DecriptionKey interface{} } type privateKey interface{ Public() crypto.PublicKey } // New returns a new JWT instance. // It accepts a maximum time duration for token expiration // and the algorithm among with its key for signing and verification. // // See `WithEncryption` method to add token encryption too. // Use `Token` method to generate a new token string // and `VerifyToken` method to decrypt, verify and bind claims of an incoming request token. // Token, by default, is extracted by "Authorization: Bearer {token}" request header and // url query parameter of "token". Token extractors can be modified through the `Extractors` field. // // For example, if you want to sign and verify using RSA-256 key: // 1. Generate key file, e.g: // $ openssl genrsa -des3 -out private.pem 2048 // 2. Read file contents with io.ReadFile("./private.pem") // 3. Pass the []byte result to the `ParseRSAPrivateKey(contents, password)` package-level helper // 4. Use the result *rsa.PrivateKey as "key" input parameter of this `New` function. // // See aliases.go file for available algorithms. func New(maxAge time.Duration, alg SignatureAlgorithm, key interface{}) (*JWT, error) { sig, err := jose.NewSigner(jose.SigningKey{ Algorithm: alg, Key: key, }, (&jose.SignerOptions{}).WithType("JWT")) if err != nil { return nil, err } j := &JWT{ Signer: sig, VerificationKey: key, MaxAge: maxAge, Extractors: []TokenExtractor{FromHeader, FromQuery}, } if s, ok := key.(privateKey); ok { j.VerificationKey = s.Public() } return j, nil } // Default key filenames for `RSA`. const ( DefaultSignFilename = "jwt_sign.key" DefaultEncFilename = "jwt_enc.key" ) // RSA returns a new `JWT` instance. // It tries to parse RSA256 keys from "filenames[0]" (defaults to "jwt_sign.key") and // "filenames[1]" (defaults to "jwt_enc.key") files or generates and exports new random keys. // // It panics on errors. // Use the `New` package-level function instead for more options. func RSA(maxAge time.Duration, filenames ...string) *JWT { var ( signFilename = DefaultSignFilename encFilename = DefaultEncFilename ) switch len(filenames) { case 1: signFilename = filenames[0] case 2: encFilename = filenames[1] } // Do not try to create or load enc key if only sign key already exists. withEncryption := true if fileExists(signFilename) { withEncryption = fileExists(encFilename) } sigKey, err := LoadRSA(signFilename, 2048) if err != nil { panic(err) } j, err := New(maxAge, RS256, sigKey) if err != nil { panic(err) } if withEncryption { encKey, err := LoadRSA(encFilename, 2048) if err != nil { panic(err) } err = j.WithEncryption(A128CBCHS256, RSA15, encKey) if err != nil { panic(err) } } return j } const ( signEnv = "JWT_SECRET" encEnv = "JWT_SECRET_ENC" ) func getenv(key string, def string) string { v := os.Getenv(key) if v == "" { return def } return v } // HMAC returns a new `JWT` instance. // It tries to read hmac256 secret keys from system environment variables: // * JWT_SECRET for signing and verification key and // * JWT_SECRET_ENC for encryption and decryption key // and defaults them to the given "keys" respectfully. // // It panics on errors. // Use the `New` package-level function instead for more options. func HMAC(maxAge time.Duration, keys ...string) *JWT { var defaultSignSecret, defaultEncSecret string switch len(keys) { case 1: defaultSignSecret = keys[0] case 2: defaultEncSecret = keys[1] } signSecret := getenv(signEnv, defaultSignSecret) encSecret := getenv(encEnv, defaultEncSecret) j, err := New(maxAge, HS256, []byte(signSecret)) if err != nil { panic(err) } if encSecret != "" { err = j.WithEncryption(A128GCM, DIRECT, []byte(encSecret)) if err != nil { panic(err) } } return j } // WithEncryption method enables encryption and decryption of the token. // It sets an appropriate encrypter(`Encrypter` and the `DecriptionKey` fields) based on the key type. func (j *JWT) WithEncryption(contentEncryption ContentEncryption, alg KeyAlgorithm, key interface{}) error { var publicKey interface{} = key if s, ok := key.(privateKey); ok { publicKey = s.Public() } enc, err := jose.NewEncrypter(contentEncryption, jose.Recipient{ Algorithm: alg, Key: publicKey, }, (&jose.EncrypterOptions{}).WithType("JWT").WithContentType("JWT"), ) if err != nil { return err } j.Encrypter = enc j.DecriptionKey = key return nil } // Expiry returns a new standard Claims with // the `Expiry` and `IssuedAt` fields of the "claims" filled // based on the given "maxAge" duration. // // See the `JWT.Expiry` method too. func Expiry(maxAge time.Duration, claims Claims) Claims { now := time.Now() claims.Expiry = NewNumericDate(now.Add(maxAge)) claims.IssuedAt = NewNumericDate(now) return claims } // Expiry method same as `Expiry` package-level function, // it returns a Claims with the expiration fields of the "claims" // filled based on the JWT's `MaxAge` field. // Only use it when this standard "claims" // is embedded on a custom claims structure. // Usage: // type UserClaims struct { // jwt.Claims // Username string // } // [...] // standardClaims := j.Expiry(jwt.Claims{...}) // customClaims := UserClaims{ // Claims: standardClaims, // Username: "kataras", // } // j.WriteToken(ctx, customClaims) func (j *JWT) Expiry(claims Claims) Claims { return Expiry(j.MaxAge, claims) } // Token generates and returns a new token string. // See `VerifyToken` too. func (j *JWT) Token(claims interface{}) (string, error) { // switch c := claims.(type) { // case Claims: // claims = Expiry(j.MaxAge, c) // case map[string]interface{}: let's not support map. // now := time.Now() // c["iat"] = now.Unix() // c["exp"] = now.Add(j.MaxAge).Unix() // } if c, ok := claims.(Claims); ok { claims = Expiry(j.MaxAge, c) } var ( token string err error ) // jwt.Builder and jwt.NestedBuilder contain same methods but they are not the same. if j.DecriptionKey != nil { token, err = jwt.SignedAndEncrypted(j.Signer, j.Encrypter).Claims(claims).CompactSerialize() } else { token, err = jwt.Signed(j.Signer).Claims(claims).CompactSerialize() } if err != nil { return "", err } return token, nil } /* Let's no support maps, typed claim is the way to go. // validateMapClaims validates claims of map type. func validateMapClaims(m map[string]interface{}, e jwt.Expected, leeway time.Duration) error { if !e.Time.IsZero() { if v, ok := m["nbf"]; ok { if notBefore, ok := v.(NumericDate); ok { if e.Time.Add(leeway).Before(notBefore.Time()) { return ErrNotValidYet } } } if v, ok := m["exp"]; ok { if exp, ok := v.(int64); ok { if e.Time.Add(-leeway).Before(time.Unix(exp, 0)) { return ErrExpired } } } if v, ok := m["iat"]; ok { if issuedAt, ok := v.(int64); ok { if e.Time.Add(leeway).Before(time.Unix(issuedAt, 0)) { return ErrIssuedInTheFuture } } } } return nil } */ // WriteToken is a helper which just generates(calls the `Token` method) and writes // a new token to the client in plain text format. // // Use the `Token` method to get a new generated token raw string value. func (j *JWT) WriteToken(ctx context.Context, claims interface{}) error { token, err := j.Token(claims) if err != nil { ctx.StatusCode(500) return err } _, err = ctx.WriteString(token) return err } var ( // ErrMissing when token cannot be extracted from the request. ErrMissing = errors.New("token is missing") // ErrExpired indicates that token is used after expiry time indicated in exp claim. ErrExpired = errors.New("token is expired (exp)") // ErrNotValidYet indicates that token is used before time indicated in nbf claim. ErrNotValidYet = errors.New("token not valid yet (nbf)") // ErrIssuedInTheFuture indicates that the iat field is in the future. ErrIssuedInTheFuture = errors.New("token issued in the future (iat)") ) type ( claimsValidator interface { ValidateWithLeeway(e jwt.Expected, leeway time.Duration) error } claimsAlternativeValidator interface { // to keep iris-contrib/jwt MapClaims compatible. Validate() error } claimsContextValidator interface { Validate(ctx context.Context) error } ) // IsValidated reports whether a token is already validated through // `VerifyToken`. It returns true when the claims are compatible // validators: a `Claims` value or a value that implements the `Validate() error` method. func IsValidated(ctx context.Context) bool { // see the `ReadClaims`. return ctx.Values().Get(needsValidationContextKey) == nil } func validateClaims(ctx context.Context, claims interface{}) (err error) { switch c := claims.(type) { case claimsValidator: err = c.ValidateWithLeeway(jwt.Expected{Time: time.Now()}, 0) case claimsAlternativeValidator: err = c.Validate() case claimsContextValidator: err = c.Validate(ctx) case *json.RawMessage: // if the data type is raw message (json []byte) // then it should contain exp (and iat and nbf) keys. // Unmarshal raw message to validate against. v := new(Claims) err = json.Unmarshal(*c, v) if err == nil { return validateClaims(ctx, v) } default: ctx.Values().Set(needsValidationContextKey, struct{}{}) } if err != nil { switch err { case jwt.ErrExpired: return ErrExpired case jwt.ErrNotValidYet: return ErrNotValidYet case jwt.ErrIssuedInTheFuture: return ErrIssuedInTheFuture } } return err } // VerifyToken verifies (and decrypts) the request token, // it also validates and binds the parsed token's claims to the "claimsPtr" (destination). // It does return a nil error on success. func (j *JWT) VerifyToken(ctx context.Context, claimsPtr interface{}) error { var token string for _, extract := range j.Extractors { if token = extract(ctx); token != "" { break // ok we found it. } } if token == "" { return ErrMissing } var ( parsedToken *jwt.JSONWebToken err error ) if j.DecriptionKey != nil { t, cerr := jwt.ParseSignedAndEncrypted(token) if cerr != nil { return cerr } parsedToken, err = t.Decrypt(j.DecriptionKey) } else { parsedToken, err = jwt.ParseSigned(token) } if err != nil { return err } if err = parsedToken.Claims(j.VerificationKey, claimsPtr); err != nil { return err } return validateClaims(ctx, claimsPtr) } const ( // ClaimsContextKey is the context key which the jwt claims are stored from the `Verify` method. ClaimsContextKey = "iris.jwt.claims" needsValidationContextKey = "iris.jwt.claims.unvalidated" ) // Verify is a middleware. It verifies and optionally decrypts an incoming request token. // It does write a 401 unauthorized status code if verification or decryption failed. // It calls the `ctx.Next` on verified requests. // // See `VerifyToken` instead to verify, decrypt, validate and acquire the claims at once. // // A call of `ReadClaims` is required to validate and acquire the jwt claims // on the next request. func (j *JWT) Verify(ctx context.Context) { var raw json.RawMessage if err := j.VerifyToken(ctx, &raw); err != nil { ctx.StopWithStatus(401) return } ctx.Values().Set(ClaimsContextKey, raw) ctx.Next() } // ReadClaims binds the "claimsPtr" (destination) // to the verified (and decrypted) claims. // The `Verify` method should be called first (registered as middleware). func ReadClaims(ctx context.Context, claimsPtr interface{}) error { v := ctx.Values().Get(ClaimsContextKey) if v == nil { return ErrMissing } raw, ok := v.(json.RawMessage) if !ok { return ErrMissing } err := json.Unmarshal(raw, claimsPtr) if err != nil { return err } if !IsValidated(ctx) { // If already validated on `Verify/VerifyToken` // then no need to perform the check again. ctx.Values().Remove(needsValidationContextKey) return validateClaims(ctx, claimsPtr) } return nil } // Get returns and validates (if not already) the claims // stored on request context's values storage. // // Should be used instead of the `ReadClaims` method when // a custom verification middleware was registered (see the `Verify` method for an example). // // Usage: // j := jwt.New(...) // [...] // app.Use(func(ctx iris.Context) { // var claims CustomClaims_or_jwt.Claims // if err := j.VerifyToken(ctx, &claims); err != nil { // ctx.StopWithStatus(iris.StatusUnauthorized) // return // } // // ctx.Values().Set(jwt.ClaimsContextKey, claims) // ctx.Next() // }) // [...] // app.Post("/restricted", func(ctx iris.Context){ // v, err := jwt.Get(ctx) // [handle error...] // claims,ok := v.(CustomClaims_or_jwt.Claims) // if !ok { // [do you support more than one type of claims? Handle here] // } // [use claims...] // }) func Get(ctx context.Context) (interface{}, error) { claims := ctx.Values().Get(ClaimsContextKey) if claims == nil { return nil, ErrMissing } if !IsValidated(ctx) { ctx.Values().Remove(needsValidationContextKey) err := validateClaims(ctx, claims) if err != nil { return nil, err } } return claims, nil }