From d556cfc39a08f1e408cf36b619856f7eaa0a4a6c Mon Sep 17 00:00:00 2001 From: "Gerasimos (Makis) Maropoulos" Date: Wed, 27 May 2020 12:02:17 +0300 Subject: [PATCH] New builtin JWT middleware - this one supports encryption and ed25519 Former-commit-id: ca20d256b766e3e8717e91de7a3f3b5f213af0bc --- HISTORY.md | 2 + NOTICE | 5 +- _examples/README.md | 4 +- _examples/authentication/README.md | 5 +- .../from-yaml-file/configs/iris.yml | 4 +- .../configuration/from-yaml-file/main.go | 4 +- _examples/miscellaneous/jwt/README.md | 29 ++ _examples/miscellaneous/jwt/main.go | 117 +++++ _examples/miscellaneous/jwt/private_rsa.pem | 30 ++ core/router/api_builder.go | 34 +- go.mod | 1 + middleware/jwt/alises.go | 73 +++ middleware/jwt/jwt.go | 424 ++++++++++++++++++ middleware/jwt/jwt_test.go | 119 +++++ middleware/jwt/util.go | 98 ++++ 15 files changed, 930 insertions(+), 19 deletions(-) create mode 100644 _examples/miscellaneous/jwt/README.md create mode 100644 _examples/miscellaneous/jwt/main.go create mode 100644 _examples/miscellaneous/jwt/private_rsa.pem create mode 100644 middleware/jwt/alises.go create mode 100644 middleware/jwt/jwt.go create mode 100644 middleware/jwt/jwt_test.go create mode 100644 middleware/jwt/util.go diff --git a/HISTORY.md b/HISTORY.md index e83fe684..10a169a2 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -371,6 +371,8 @@ Other Improvements: ![DBUG routes](https://iris-go.com/images/v12.2.0-dbug2.png?v=0) +- New builtin [JWT](https://github.com/kataras/iris/tree/master/jwt) middleware based on [square/go-jose](https://github.com/square/go-jose) featured with optional encryption to set claims with sensitive data when necessary. + - `Context.ReadForm` now can return an `iris.ErrEmptyForm` instead of `nil` when the new `Configuration.FireEmptyFormError` is true (or `iris.WithEmptyFormError`) on missing form body to read from. - `Configuration.EnablePathIntelligence | iris.WithPathIntelligence` to enable path intelligence automatic path redirection on the most closest path (if any), [example]((https://github.com/kataras/iris/blob/master/_examples/routing/intelligence/main.go) diff --git a/NOTICE b/NOTICE index d653e1cd..9e55f3a0 100644 --- a/NOTICE +++ b/NOTICE @@ -94,4 +94,7 @@ Revision ID: d1c07411df0bb21f6b21f5b5d9325fac6f29c911 b00f470d46 toml 3012a1dbe2e4bd1 https://github.com/BurntSushi/toml 391d42b32f0577c - b7bbc7f005 \ No newline at end of file + b7bbc7f005 + jose d84c719419c2a90 https://github.com/square/go-jose + 8d188ea67e09652 + f5c1929ae8 \ No newline at end of file diff --git a/_examples/README.md b/_examples/README.md index 1b319f77..c12ec938 100644 --- a/_examples/README.md +++ b/_examples/README.md @@ -133,8 +133,9 @@ * [Sitemap](sitemap/main.go) * Authentication * [Basic Authentication](authentication/basicauth/main.go) + * [JWT](miscellaneous/jwt/main.go) + * [JWT (community edition)](experimental-handlers/jwt/main.go) * [OAUth2](authentication/oauth2/main.go) - * [Request Auth(JWT)](experimental-handlers/jwt/main.go) * [Manage Permissions](permissions/main.go) * Cookies * [Basic](cookies/basic/main.go) @@ -190,6 +191,7 @@ * [The lorca package](desktop-app/lorca) * [The webview package](desktop-app/webview) * Middlewares (Builtin) + * [JWT](miscellaneous/jwt/main.go) * [Rate Limit](miscellaneous/ratelimit/main.go) * [HTTP Method Override](https://github.com/kataras/iris/blob/master/middleware/methodoverride/methodoverride_test.go) * [Request Logger](http_request/request-logger/main.go) diff --git a/_examples/authentication/README.md b/_examples/authentication/README.md index 08d1427d..47939fef 100644 --- a/_examples/authentication/README.md +++ b/_examples/authentication/README.md @@ -2,5 +2,6 @@ - [Basic Authentication](basicauth/main.go) - [OAUth2](oauth2/main.go) -- [Request Auth(JWT)](https://github.com/iris-contrib/middleware/blob/master/jwt) -- [Sessions](https://github.com/kataras/iris/tree/master/_examples/#sessions) \ No newline at end of file +- [JWT](https://github.com/kataras/iris/tree/master/_examples/miscellaneous/jwt) +- [JWT (community edition)](https://github.com/iris-contrib/middleware/blob/master/jwt) +- [Sessions](https://github.com/kataras/iris/tree/master/_examples/sessions) diff --git a/_examples/configuration/from-yaml-file/configs/iris.yml b/_examples/configuration/from-yaml-file/configs/iris.yml index c8cb373b..78113e54 100644 --- a/_examples/configuration/from-yaml-file/configs/iris.yml +++ b/_examples/configuration/from-yaml-file/configs/iris.yml @@ -3,4 +3,6 @@ EnablePathEscape: false FireMethodNotAllowed: true DisableBodyConsumptionOnUnmarshal: true TimeFormat: Mon, 01 Jan 2006 15:04:05 GMT -Charset: UTF-8 \ No newline at end of file +Charset: UTF-8 +Other: + Addr: :8080 \ No newline at end of file diff --git a/_examples/configuration/from-yaml-file/main.go b/_examples/configuration/from-yaml-file/main.go index 452bf2c7..a872d5bd 100644 --- a/_examples/configuration/from-yaml-file/main.go +++ b/_examples/configuration/from-yaml-file/main.go @@ -14,7 +14,9 @@ func main() { // Good when you have two configurations, one for development and a different one for production use. // If iris.YAML's input string argument is "~" then it loads the configuration from the home directory // and can be shared between many iris instances. - app.Listen(":8080", iris.WithConfiguration(iris.YAML("./configs/iris.yml"))) + cfg := iris.YAML("./configs/iris.yml") + addr := cfg.Other["Addr"].(string) + app.Listen(addr, iris.WithConfiguration(cfg)) // or before run: // app.Configure(iris.WithConfiguration(iris.YAML("./configs/iris.yml"))) diff --git a/_examples/miscellaneous/jwt/README.md b/_examples/miscellaneous/jwt/README.md new file mode 100644 index 00000000..1d2e0ae5 --- /dev/null +++ b/_examples/miscellaneous/jwt/README.md @@ -0,0 +1,29 @@ +# Generate RSA + +```sh +$ openssl genrsa -des3 -out private_rsa.pem 2048 +``` + +```go +b, err := ioutil.ReadFile("./private_rsa.pem") +if err != nil { + panic(err) +} +key := jwt.MustParseRSAPrivateKey(b, []byte("pass")) +``` + +OR + +```go +import "crypto/rand" +import "crypto/rsa" + +key, err := rsa.GenerateKey(rand.Reader, 2048) +``` + +# Generate Ed25519 + +```sh +$ openssl genpkey -algorithm Ed25519 -out private_ed25519.pem +$ openssl req -x509 -key private_ed25519.pem -out cert_ed25519.pem -days 365 +``` diff --git a/_examples/miscellaneous/jwt/main.go b/_examples/miscellaneous/jwt/main.go new file mode 100644 index 00000000..930a86f5 --- /dev/null +++ b/_examples/miscellaneous/jwt/main.go @@ -0,0 +1,117 @@ +package main + +import ( + "time" + + "github.com/kataras/iris/v12" + "github.com/kataras/iris/v12/middleware/jwt" +) + +// UserClaims a custom claims structure. You can just use jwt.Claims too. +type UserClaims struct { + jwt.Claims + Username string +} + +func main() { + // hmac + key := []byte("secret") + j, err := jwt.New(1*time.Minute, jwt.HS256, key) + if err != nil { + panic(err) + } + + // OPTIONAL encryption: + encryptionKey := []byte("itsa16bytesecret") + err = j.WithEncryption(jwt.A128GCM, jwt.DIRECT, encryptionKey) + if err != nil { + panic(err) + } + + app := iris.New() + app.Logger().SetLevel("debug") + + app.Get("/authenticate", func(ctx iris.Context) { + standardClaims := jwt.Claims{Issuer: "an-issuer", Audience: jwt.Audience{"an-audience"}} + // NOTE: if custom claims then the `j.Expiry(claims)` (or jwt.Expiry(duration, claims)) + // MUST be called in order to set the expiration time. + customClaims := UserClaims{ + Claims: j.Expiry(standardClaims), + Username: "kataras", + } + + j.WriteToken(ctx, customClaims) + }) + + userRouter := app.Party("/user") + { + // userRouter.Use(j.Verify) + // userRouter.Get("/", func(ctx iris.Context) { + // var claims UserClaims + // if err := jwt.ReadClaims(ctx, &claims); err != nil { + // // Validation-only errors, the rest are already + // // checked on `j.Verify` middleware. + // ctx.StopWithStatus(iris.StatusUnauthorized) + // return + // } + // + // ctx.Writef("Claims: %#+v\n", claims) + // }) + // + // OR: + userRouter.Get("/", func(ctx iris.Context) { + var claims UserClaims + if err := j.VerifyToken(ctx, &claims); err != nil { + ctx.StopWithStatus(iris.StatusUnauthorized) + return + } + + ctx.Writef("Claims: %#+v\n", claims) + }) + } + + app.Listen(":8080") +} + +/* +func load_From_File_Example() { + b, err := ioutil.ReadFile("./private_rsa.pem") + if err != nil { + panic(err) + } + signKey := jwt.MustParseRSAPrivateKey(b, []byte("pass")) + + j, err := jwt.New(15*time.Minute, jwt.RS256, signKey) + if err != nil { + panic(err) + } +} +*/ + +/* +func random_RSA_Sign_And_Encrypt_Example() { + j := jwt.Random(1 * time.Minute) +} +*/ + +/* +func random_manually_generate_RSA_Example() { + signey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + panic(err) + } + encryptionKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + panic(err) + } + + j, err := jwt.New(1*time.Minute, jwt.RS256, signey) + if err != nil { + panic(err) + } + err = j.WithEncryption(jwt.A128CBCHS256, jwt.RSA15, encryptionKey) + if err != nil { + panic(err) + } +} +*/ diff --git a/_examples/miscellaneous/jwt/private_rsa.pem b/_examples/miscellaneous/jwt/private_rsa.pem new file mode 100644 index 00000000..e93fff77 --- /dev/null +++ b/_examples/miscellaneous/jwt/private_rsa.pem @@ -0,0 +1,30 @@ +-----BEGIN RSA PRIVATE KEY----- +Proc-Type: 4,ENCRYPTED +DEK-Info: DES-EDE3-CBC,6B0BC214C94124FE + +lAM48DEM/GdCDimr9Vhi+fSHLgduDb0l2BA4uhILgNby51jxY/4X3IqM6f3ImKX7 +cEd9OBug+pwIugB0UW0L0f5Pd59Ovpiaz3xLci1/19ehYnMqsuP3YAnJm40hT5VP +p0gWRiR415PJ0fPeeJPFx5IsqvkTJ30LWZHUZX4EkdcL5L8PrVbmthGDbLh+OcMc +LzoP8eTglzlZF03nyvAol6+p2eZtvOJLu8nWG25q17kyBx6kEiCsWFcUBTX9G7sH +CM3naByDijqZXE/XXtmTMLSRRnlk7Q5WLxClroHlUP9y8BQFMo2TW4Z+vNjHUkc1 +77ghabX1704bAlIE8LLZJKrm/C5+VKyV6117SVG/2bc4036Y5rStXpANbk1j4K0x +ADvpRhuTpifaogdvJP+8eXBdl841MQMRzWuZHp6UNYYQegoV9C+KHyJx4UPjZyzd +gblZmKgU+BsX3mV6MLhJtd6dheLZtpBsAlSstJxzmgwqz9faONYEGeItXO+NnxbA +mxAp/mI+Fz2jfgYlWjwkyPPzD4k/ZMMzB4XLkKKs9XaxUtTomiDkuYZfnABhxt73 +xBy40V1rb/NyeW80pk1zEHM6Iy/48ETSp9U3k9sSOXjMhYbPXgxDtimV8w0qGFAo +2Tif7ZuaiuC38rOkoHK9C6vy2Dp8lQZ+QBnUKLeFsyhq9CaqSdnyUTMj3oEZXXf+ +TqqeO+PTtl7JaNfGRq6/aMZqxACHkyVUvYvjZzx07CJ2fr+OtNqxallM6Oc/o9NZ +5u7lpgrYaKM/b67q0d2X/AoxR5zrZuM8eam3acD1PwHFQKbJWuFNmjWtnlZNuR3X +fZEmxIKwDlup8TxFcqbbZtPHuQA2mTMTqfRkf8oPSO+N6NNaUpb0ignYyA7Eu5GT +b02d/oNLETMikxUxntMSH7GhuOpfJyELz8krYTttbJ+a93h4wBeYW2+LyAr/cRLB +mbtKLtaN7f3FaOSnu8e0+zlJ7xglHPXqblRL9q6ZDM5UJtJD4rA7LPZHk/0Y1Kb6 +hBh1qMDu0r3IV4X7MDacvxw7aa7D8TyXJiFSvxykVhds+ndjIe51Ics5908+lev3 +nwE69PLMwyqe2vvE2oDwao4XJuBLCHjcv/VagRSz/XQGMbZqb3L6unyd3UPl8JjP +ovipNwM4rFnE54uiUUeki7TZGDYO72vQcSaLrmbeAWc2m202+rqLz0WMm6HpPmCv +IgexpX2MnIeHJ3+BlEjA2u+S6xNSD7qHGk2pb7DD8nRvUdSHAHeaQbrkEfEhhR2Q +Dw5gdw1JyQ0UKBl5ndn/1Ub2Asl016lZjpqHyMIVS4tFixACDsihEYMmq/zQmTj4 +8oBZTU+fycN/KiGKZBsqxIwgYIeMz/GfvoyN5m57l6fwEZALVpveI1pP4fiZB/Z8 +xLKa5JK6L10lAD1YHWc1dPhamf9Sb3JwN2CFtGvjOJ/YjAZu3jJoxi40DtRkE3Rh +HI8Cbx1OORzoo0kO0vy42rz5qunYyVmEzPKtOj+YjVEhVJ85yJZ9bTZtuyqMv8mH +cnwEeIFK8cmm9asbVzQGDwN/UGB4cO3LrMX1RYk4GRttTGlp0729BbmZmu00RnD/ +-----END RSA PRIVATE KEY----- diff --git a/core/router/api_builder.go b/core/router/api_builder.go index 8f9c81c4..7453a031 100644 --- a/core/router/api_builder.go +++ b/core/router/api_builder.go @@ -1135,26 +1135,34 @@ func getCaller() (string, int) { n := runtime.Callers(1, pcs[:]) frames := runtime.CallersFrames(pcs[:n]) wd, _ := os.Getwd() - for { - frame, more := frames.Next() - file := filepath.ToSlash(frame.File) - if !strings.Contains(file, "_test.go") { - if strings.Contains(file, "/kataras/iris") && !strings.Contains(file, "kataras/iris/_examples") && !strings.Contains(file, "iris-contrib/examples") { - if !more { - break - } - continue - } + var ( + frame runtime.Frame + more = true + ) + + for { + if !more { + break } + frame, more = frames.Next() + file := filepath.ToSlash(frame.File) + // fmt.Printf("%s:%d | %s\n", file, frame.Line, frame.Function) + if strings.Contains(file, "go/src/runtime/") { - if !more { - break - } continue } + if !strings.Contains(file, "_test.go") { + if strings.Contains(file, "/kataras/iris") && + !strings.Contains(file, "kataras/iris/_examples") && + !strings.Contains(file, "kataras/iris/middleware") && + !strings.Contains(file, "iris-contrib/examples") { + continue + } + } + if relFile, err := filepath.Rel(wd, file); err == nil { if !strings.HasPrefix(relFile, "..") { // Only if it's relative to this path, not parent. diff --git a/go.mod b/go.mod index 7e63b44d..87286229 100644 --- a/go.mod +++ b/go.mod @@ -29,6 +29,7 @@ require ( github.com/microcosm-cc/bluemonday v1.0.2 github.com/ryanuber/columnize v2.1.0+incompatible github.com/schollz/closestmatch v2.1.0+incompatible + github.com/square/go-jose/v3 v3.0.0-20200430180204-d84c719419c2 github.com/vmihailenco/msgpack/v5 v5.0.0-alpha.2 go.etcd.io/bbolt v1.3.4 golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37 diff --git a/middleware/jwt/alises.go b/middleware/jwt/alises.go new file mode 100644 index 00000000..9cf35009 --- /dev/null +++ b/middleware/jwt/alises.go @@ -0,0 +1,73 @@ +package jwt + +import ( + "github.com/square/go-jose/v3" + "github.com/square/go-jose/v3/jwt" +) + +type ( + // Claims represents public claim values (as specified in RFC 7519). + Claims = jwt.Claims + // Audience represents the recipients that the token is intended for. + Audience = jwt.Audience +) + +type ( + // KeyAlgorithm represents a key management algorithm. + KeyAlgorithm = jose.KeyAlgorithm + + // SignatureAlgorithm represents a signature (or MAC) algorithm. + SignatureAlgorithm = jose.SignatureAlgorithm + + // ContentEncryption represents a content encryption algorithm. + ContentEncryption = jose.ContentEncryption +) + +// Key management algorithms. +const ( + ED25519 = jose.ED25519 + RSA15 = jose.RSA1_5 + RSAOAEP = jose.RSA_OAEP + RSAOAEP256 = jose.RSA_OAEP_256 + A128KW = jose.A128KW + A192KW = jose.A192KW + A256KW = jose.A256KW + DIRECT = jose.DIRECT + ECDHES = jose.ECDH_ES + ECDHESA128KW = jose.ECDH_ES_A128KW + ECDHESA192KW = jose.ECDH_ES_A192KW + ECDHESA256KW = jose.ECDH_ES_A256KW + A128GCMKW = jose.A128GCMKW + A192GCMKW = jose.A192GCMKW + A256GCMKW = jose.A256GCMKW + PBES2HS256A128KW = jose.PBES2_HS256_A128KW + PBES2HS384A192KW = jose.PBES2_HS384_A192KW + PBES2HS512A256KW = jose.PBES2_HS512_A256KW +) + +// Signature algorithms. +const ( + EdDSA = jose.EdDSA + HS256 = jose.HS256 + HS384 = jose.HS384 + HS512 = jose.HS512 + RS256 = jose.RS256 + RS384 = jose.RS384 + RS512 = jose.RS512 + ES256 = jose.ES256 + ES384 = jose.ES384 + ES512 = jose.ES512 + PS256 = jose.PS256 + PS384 = jose.PS384 + PS512 = jose.PS512 +) + +// Content encryption algorithms. +const ( + A128CBCHS256 = jose.A128CBC_HS256 + A192CBCHS384 = jose.A192CBC_HS384 + A256CBCHS512 = jose.A256CBC_HS512 + A128GCM = jose.A128GCM + A192GCM = jose.A192GCM + A256GCM = jose.A256GCM +) diff --git a/middleware/jwt/jwt.go b/middleware/jwt/jwt.go new file mode 100644 index 00000000..ce6857ab --- /dev/null +++ b/middleware/jwt/jwt.go @@ -0,0 +1,424 @@ +package jwt + +import ( + "crypto" + "crypto/rand" + "crypto/rsa" + "encoding/json" + "errors" + "strings" + "time" + + "github.com/kataras/iris/v12/context" + + "github.com/square/go-jose/v3" + "github.com/square/go-jose/v3/jwt" +) + +func init() { + context.SetHandlerName("iris/middleware/jwt.*", "iris.jwt") +} + +// TokenExtractor is a function that takes a context as input and returns +// a token. An empty string should be returned if no token found +// without additional information. +type TokenExtractor func(context.Context) string + +// FromHeader is a token extractor. +// It reads the token from the Authorization request header of form: +// Authorization: "Bearer {token}". +func FromHeader(ctx context.Context) string { + authHeader := ctx.GetHeader("Authorization") + if authHeader == "" { + return "" + } + + // pure check: authorization header format must be Bearer {token} + authHeaderParts := strings.Split(authHeader, " ") + if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { + return "" + } + + return authHeaderParts[1] +} + +// FromQuery is a token extractor. +// It reads the token from the "token" url query parameter. +func FromQuery(ctx context.Context) string { + return ctx.URLParam("token") +} + +// FromJSON is a token extractor. +// Reads a json request body and extracts the json based on the given field. +// The request content-type should contain the: application/json header value, otherwise +// this method will not try to read and consume the body. +func FromJSON(jsonKey string) TokenExtractor { + return func(ctx context.Context) string { + if ctx.GetContentTypeRequested() != context.ContentJSONHeaderValue { + return "" + } + + var m context.Map + if err := ctx.ReadJSON(&m); err != nil { + return "" + } + + if m == nil { + return "" + } + + v, ok := m[jsonKey] + if !ok { + return "" + } + + tok, ok := v.(string) + if !ok { + return "" + } + + return tok + } +} + +// JWT holds the necessary information the middleware need +// to sign and verify tokens. +// +// The `RSA(privateFile, publicFile, password)` package-level helper function +// can be used to decode the SignKey and VerifyKey. +type JWT struct { + // MaxAge is the expiration duration of the generated tokens. + MaxAge time.Duration + + // Extractors are used to extract a raw token string value + // from the request. + // Builtin extractors: + // * FromHeader + // * FromQuery + // * FromJSON + // Defaults to a slice of `FromHeader` and `FromQuery`. + Extractors []TokenExtractor + + // Signer is used to sign the token. + // It is set on `New` and `Default` package-level functions. + Signer jose.Signer + // VerificationKey is used to verify the token (public key). + VerificationKey interface{} + + // Encrypter is used to, optionally, encrypt the token. + // It is set on `WithExpiration` method. + Encrypter jose.Encrypter + // DecriptionKey is used to decrypt the token (private key) + DecriptionKey interface{} +} + +// Random returns a new `JWT` instance +// with in-memory generated rsa256 signing and encryption keys (development). +// It panics on errors. Next server ran will invalidate all request tokens. +// +// Use the `New` package-level function for production use. +func Random(maxAge time.Duration) *JWT { + sigKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + panic(err) + } + + j, err := New(maxAge, RS256, sigKey) + if err != nil { + panic(err) + } + + encKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + panic(err) + } + + err = j.WithEncryption(A128CBCHS256, RSA15, encKey) + if err != nil { + panic(err) + } + + return j +} + +type privateKey interface{ Public() crypto.PublicKey } + +// New returns a new JWT instance. +// It accepts a maximum time duration for token expiration +// and the algorithm among with its key for signing and verification. +// +// See `WithEncryption` method to add token encryption too. +// Use `Token` method to generate a new token string +// and `VerifyToken` method to decrypt, verify and bind claims of an incoming request token. +// Token, by default, is extracted by "Authorization: Bearer {token}" request header and +// url query parameter of "token". Token extractors can be modified through the `Extractors` field. +// +// For example, if you want to sign and verify using RSA-256 key: +// 1. Generate key file, e.g: +// $ openssl genrsa -des3 -out private.pem 2048 +// 2. Read file contents with io.ReadFile("./private.pem") +// 3. Pass the []byte result to the `MustParseRSAPrivateKey(contents, password)` package-level helper +// 4. Use the result *rsa.PrivateKey as "key" input parameter of this `New` function. +// +// See aliases.go file for available algorithms. +func New(maxAge time.Duration, alg SignatureAlgorithm, key interface{}) (*JWT, error) { + sig, err := jose.NewSigner(jose.SigningKey{ + Algorithm: alg, + Key: key, + }, (&jose.SignerOptions{}).WithType("JWT")) + + if err != nil { + return nil, err + } + + j := &JWT{ + Signer: sig, + VerificationKey: key, + MaxAge: maxAge, + Extractors: []TokenExtractor{FromHeader, FromQuery}, + } + + if s, ok := key.(privateKey); ok { + j.VerificationKey = s.Public() + } + + return j, nil +} + +// WithEncryption method enables encryption and decryption of the token. +// It sets an appropriate encrypter(`Encrypter` and the `DecriptionKey` fields) based on the key type. +func (j *JWT) WithEncryption(contentEncryption ContentEncryption, alg KeyAlgorithm, key interface{}) error { + var publicKey interface{} = key + if s, ok := key.(privateKey); ok { + publicKey = s.Public() + } + + enc, err := jose.NewEncrypter(contentEncryption, jose.Recipient{ + Algorithm: alg, + Key: publicKey, + }, + (&jose.EncrypterOptions{}).WithType("JWT").WithContentType("JWT"), + ) + + if err != nil { + return err + } + + j.Encrypter = enc + j.DecriptionKey = key + return nil +} + +// Expiry returns a new standard Claims with +// the `Expiry` and `IssuedAt` fields of the "claims" filled +// based on the given "maxAge" duration. +// +// See the `JWT.Expiry` method too. +func Expiry(maxAge time.Duration, claims Claims) Claims { + now := time.Now() + claims.Expiry = jwt.NewNumericDate(now.Add(maxAge)) + claims.IssuedAt = jwt.NewNumericDate(now) + return claims +} + +// Expiry method same as `Expiry` package-level function, +// it returns a Claims with the expiration fields of the "claims" +// filled based on the JWT's `MaxAge` field. +// Only use it when this standard "claims" +// is embedded on a custom claims structure. +// Usage: +// type UserClaims struct { +// jwt.Claims +// Username string +// } +// [...] +// standardClaims := j.Expiry(jwt.Claims{...}) +// customClaims := UserClaims{ +// Claims: standardClaims, +// Username: "kataras", +// } +// j.WriteToken(ctx, customClaims) +func (j *JWT) Expiry(claims Claims) Claims { + return Expiry(j.MaxAge, claims) +} + +// Token generates and returns a new token string. +// See `VerifyToken` too. +func (j *JWT) Token(claims interface{}) (string, error) { + if c, ok := claims.(Claims); ok { + claims = Expiry(j.MaxAge, c) + } + + var ( + token string + err error + ) + + // jwt.Builder and jwt.NestedBuilder contain same methods but they are not the same. + if j.DecriptionKey != nil { + token, err = jwt.SignedAndEncrypted(j.Signer, j.Encrypter).Claims(claims).CompactSerialize() + } else { + token, err = jwt.Signed(j.Signer).Claims(claims).CompactSerialize() + } + + if err != nil { + return "", err + } + + return token, nil +} + +// WriteToken is a helper which just generates(calls the `Token` method) and writes +// a new token to the client in plain text format. +// +// Use the `Token` method to get a new generated token raw string value. +func (j *JWT) WriteToken(ctx context.Context, claims interface{}) error { + token, err := j.Token(claims) + if err != nil { + ctx.StatusCode(500) + return err + } + + _, err = ctx.WriteString(token) + return err +} + +var ( + // ErrTokenMissing when token cannot be extracted from the request. + ErrTokenMissing = errors.New("token is missing") + // ErrTokenInvalid when incoming token is invalid. + ErrTokenInvalid = errors.New("token is invalid") + // ErrTokenExpired when incoming token has expired. + ErrTokenExpired = errors.New("token has expired") +) + +type ( + claimsValidator interface { + ValidateWithLeeway(e jwt.Expected, leeway time.Duration) error + } + claimsAlternativeValidator interface { + Validate() error + } +) + +// IsValidated reports whether a token is already validated through +// `VerifyToken`. It returns true when the claims are compatible +// validators: a `Claims` value or a value that implements the `Validate() error` method. +func IsValidated(ctx context.Context) bool { // see the `ReadClaims`. + return ctx.Values().Get(needsValidationContextKey) == nil +} + +func validateClaims(ctx context.Context, claimsPtr interface{}) (err error) { + switch claims := claimsPtr.(type) { + case claimsValidator: + err = claims.ValidateWithLeeway(jwt.Expected{Time: time.Now()}, 0) + case claimsAlternativeValidator: + err = claims.Validate() + default: + ctx.Values().Set(needsValidationContextKey, struct{}{}) + } + + if err != nil { + if err == jwt.ErrExpired { + return ErrTokenExpired + } + } + + return err +} + +// VerifyToken verifies (and decrypts) the request token, +// it also validates and binds the parsed token's claims to the "claimsPtr" (destination). +// It does return a nil error on success. +func (j *JWT) VerifyToken(ctx context.Context, claimsPtr interface{}) error { + var token string + + for _, extract := range j.Extractors { + if token = extract(ctx); token != "" { + break // ok we found it. + } + } + + if token == "" { + return ErrTokenMissing + } + + var ( + parsedToken *jwt.JSONWebToken + err error + ) + + if j.DecriptionKey != nil { + t, cerr := jwt.ParseSignedAndEncrypted(token) + if cerr != nil { + return cerr + } + + parsedToken, err = t.Decrypt(j.DecriptionKey) + } else { + parsedToken, err = jwt.ParseSigned(token) + } + if err != nil { + return ErrTokenInvalid + } + + if err = parsedToken.Claims(j.VerificationKey, claimsPtr); err != nil { + return ErrTokenInvalid + } + + return validateClaims(ctx, claimsPtr) +} + +const ( + // ClaimsContextKey is the context key which the jwt claims are stored from the `Verify` method. + ClaimsContextKey = "iris.jwt.claims" + needsValidationContextKey = "iris.jwt.claims.unvalidated" +) + +// Verify is a middleware. It verifies and optionally decrypts an incoming request token. +// It does write a 401 unauthorized status code if verification or decryption failed. +// It calls the `ctx.Next` on verified requests. +// +// See `VerifyToken` instead to verify, decrypt, validate and acquire the claims at once. +// +// A call of `ReadClaims` is required to validate and acquire the jwt claims +// on the next request. +func (j *JWT) Verify(ctx context.Context) { + var raw json.RawMessage + if err := j.VerifyToken(ctx, &raw); err != nil { + ctx.StopWithStatus(401) + return + } + + ctx.Values().Set(ClaimsContextKey, raw) + ctx.Next() +} + +// ReadClaims binds the "claimsPtr" (destination) +// to the verified (and decrypted) claims. +// The `Verify` method should be called first (registered as middleware). +func ReadClaims(ctx context.Context, claimsPtr interface{}) error { + v := ctx.Values().Get(ClaimsContextKey) + if v == nil { + return ErrTokenMissing + } + + raw, ok := v.(json.RawMessage) + if !ok { + return ErrTokenMissing + } + + err := json.Unmarshal(raw, claimsPtr) + if err != nil { + return err + } + + // If already validated on VerifyToken (a claimsValidator/claimsAlternativeValidator) + // then no need to perform the check again. + if !IsValidated(ctx) { + ctx.Values().Remove(needsValidationContextKey) + return validateClaims(ctx, claimsPtr) + } + + return nil +} diff --git a/middleware/jwt/jwt_test.go b/middleware/jwt/jwt_test.go new file mode 100644 index 00000000..e580db94 --- /dev/null +++ b/middleware/jwt/jwt_test.go @@ -0,0 +1,119 @@ +// Package jwt_test contains simple Iris jwt tests. Most of the jwt functionality is already tested inside the jose package itself. +package jwt_test + +import ( + "testing" + "time" + + "github.com/kataras/iris/v12" + "github.com/kataras/iris/v12/httptest" + "github.com/kataras/iris/v12/middleware/jwt" +) + +type userClaims struct { + jwt.Claims + Username string +} + +const testMaxAge = 3 * time.Second + +// Random RSA verification and encryption. +func TestRSA(t *testing.T) { + j := jwt.Random(testMaxAge) + testWriteVerifyToken(t, j) +} + +// HMAC verification and encryption. +func TestHMAC(t *testing.T) { + j, err := jwt.New(testMaxAge, jwt.HS256, []byte("secret")) + if err != nil { + t.Fatal(err) + } + err = j.WithEncryption(jwt.A128GCM, jwt.DIRECT, []byte("itsa16bytesecret")) + if err != nil { + t.Fatal(err) + } + + testWriteVerifyToken(t, j) +} + +// HMAC verification only (unecrypted). +func TestVerify(t *testing.T) { + j, err := jwt.New(testMaxAge, jwt.HS256, []byte("another secret")) + if err != nil { + t.Fatal(err) + } + testWriteVerifyToken(t, j) +} + +func testWriteVerifyToken(t *testing.T, j *jwt.JWT) { + t.Helper() + + j.Extractors = append(j.Extractors, jwt.FromJSON("access_token")) + standardClaims := jwt.Claims{Issuer: "an-issuer", Audience: jwt.Audience{"an-audience"}} + expectedClaims := userClaims{ + Claims: j.Expiry(standardClaims), + Username: "kataras", + } + + app := iris.New() + app.Get("/auth", func(ctx iris.Context) { + j.WriteToken(ctx, expectedClaims) + }) + + app.Post("/restricted", func(ctx iris.Context) { + var claims userClaims + if err := j.VerifyToken(ctx, &claims); err != nil { + ctx.StopWithStatus(iris.StatusUnauthorized) + return + } + + ctx.JSON(claims) + }) + + app.Post("/restricted_middleware", j.Verify, func(ctx iris.Context) { + var claims userClaims + if err := jwt.ReadClaims(ctx, &claims); err != nil { + ctx.StopWithStatus(iris.StatusUnauthorized) + return + } + + ctx.JSON(claims) + }) + + e := httptest.New(t, app) + + // Get token. + rawToken := e.GET("/auth").Expect().Status(httptest.StatusOK).Body().Raw() + if rawToken == "" { + t.Fatalf("empty token") + } + + restrictedPaths := [...]string{"/restricted", "/restricted_middleware"} + + now := time.Now() + for _, path := range restrictedPaths { + // Authorization Header. + e.POST(path).WithHeader("Authorization", "Bearer "+rawToken).Expect(). + Status(httptest.StatusOK).JSON().Equal(expectedClaims) + + // URL Query. + e.POST(path).WithQuery("token", rawToken).Expect(). + Status(httptest.StatusOK).JSON().Equal(expectedClaims) + + // JSON Body. + e.POST(path).WithJSON(iris.Map{"access_token": rawToken}).Expect(). + Status(httptest.StatusOK).JSON().Equal(expectedClaims) + + // Missing "Bearer". + e.POST(path).WithHeader("Authorization", rawToken).Expect(). + Status(httptest.StatusUnauthorized) + } + expireRemDur := testMaxAge - time.Since(now) + + // Expiration. + time.Sleep(expireRemDur /* -end */) + for _, path := range restrictedPaths { + e.POST(path).WithQuery("token", rawToken).Expect().Status(httptest.StatusUnauthorized) + } +} diff --git a/middleware/jwt/util.go b/middleware/jwt/util.go new file mode 100644 index 00000000..87159ad6 --- /dev/null +++ b/middleware/jwt/util.go @@ -0,0 +1,98 @@ +package jwt + +import ( + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" +) + +// ErrNotPEM is a panic error of the MustParseXXX functions when the data are not PEM-encoded. +var ErrNotPEM = errors.New("key must be PEM encoded") + +// MustParseRSAPrivateKey encodes a PEM-encoded PKCS1 or PKCS8 private key protected with a password. +func MustParseRSAPrivateKey(key, password []byte) *rsa.PrivateKey { + block, _ := pem.Decode(key) + if block == nil { + panic(ErrNotPEM) + } + + var ( + parsedKey interface{} + err error + ) + + var blockDecrypted []byte + if blockDecrypted, err = x509.DecryptPEMBlock(block, password); err != nil { + panic(err) + } + + if parsedKey, err = x509.ParsePKCS1PrivateKey(blockDecrypted); err != nil { + if parsedKey, err = x509.ParsePKCS8PrivateKey(blockDecrypted); err != nil { + panic(err) + } + } + + privateKey, ok := parsedKey.(*rsa.PrivateKey) + if !ok { + panic("key is not of type *rsa.PrivateKey") + } + + return privateKey +} + +// MustParseRSAPublicKey encodes a PEM encoded PKCS1 or PKCS8 public key. +func MustParseRSAPublicKey(key []byte) *rsa.PublicKey { + var err error + + // Parse PEM block + var block *pem.Block + if block, _ = pem.Decode(key); block == nil { + panic(ErrNotPEM) + } + + // Parse the key + var parsedKey interface{} + if parsedKey, err = x509.ParsePKIXPublicKey(block.Bytes); err != nil { + if cert, err := x509.ParseCertificate(block.Bytes); err == nil { + parsedKey = cert.PublicKey + } else { + panic(err) + } + } + + var pkey *rsa.PublicKey + var ok bool + if pkey, ok = parsedKey.(*rsa.PublicKey); !ok { + panic("key is not of type *rsa.PublicKey") + } + + return pkey +} + +/* +// MustParseEd25519 PEM encoded Ed25519. +func MustParseEd25519(key []byte) ed25519.PrivateKey { + // Parse PEM block + block, _ := pem.Decode(key) + if block == nil { + panic(ErrNotPEM) + } + + type ed25519PrivKey struct { + Version int + ObjectIdentifier struct { + ObjectIdentifier asn1.ObjectIdentifier + } + PrivateKey []byte + } + + var asn1PrivKey ed25519PrivKey + if _, err := asn1.Unmarshal(block.Bytes, &asn1PrivKey); err != nil { + panic(err) + } + + privateKey := ed25519.NewKeyFromSeed(asn1PrivKey.PrivateKey[2:]) + return privateKey +} +*/