diff --git a/HISTORY.md b/HISTORY.md index 95ad69f3..940c51cd 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -29,7 +29,7 @@ The codebase for Dependency Injection, Internationalization and localization and ## Fixes and Improvements - A generic User interface, see the `Context.SetUser/User` methods in the New Context Methods section for more. In-short, the basicauth middleware's stored user can now be retrieved through `Context.User()` which provides more information than the native `ctx.Request().BasicAuth()` method one. Third-party authentication middleware creators can benefit of these two methods, plus the Logout below. -- A `Context.Logout` method is added, can be used to invalidate [basicauth](https://github.com/kataras/iris/blob/master/_examples/auth/basicauth/main.go) client credentials. +- A `Context.Logout` method is added, can be used to invalidate [basicauth](https://github.com/kataras/iris/blob/master/_examples/auth/basicauth/main.go) or [jwt](https://github.com/kataras/iris/blob/master/_examples/auth/jwt/main.go) client credentials. - Add the ability to [share functions](https://github.com/kataras/iris/tree/master/_examples/routing/writing-a-middleware/share-funcs) between handlers chain and add an [example](https://github.com/kataras/iris/tree/master/_examples/routing/writing-a-middleware/share-services) on sharing Go structures (aka services). - Add the new `Party.UseOnce` method to the `*Route` @@ -315,7 +315,7 @@ var dirOpts = iris.DirOptions{ - `Context.SetFunc(name string, fn interface{}, persistenceArgs ...interface{})` and `Context.CallFunc(name string, args ...interface{}) ([]reflect.Value, error)` to allow middlewares to share functions dynamically when the type of the function is not predictable, see the [example](https://github.com/kataras/iris/tree/master/_examples/routing/writing-a-middleware/share-funcs) for more. - `Context.TextYAML(interface{}) error` same as `Context.YAML` but with set the Content-Type to `text/yaml` instead (Google Chrome renders it as text). - `Context.IsDebug() bool` reports whether the application is running under debug/development mode. It is a shortcut of Application.Logger().Level >= golog.DebugLevel. -- `Context.IsRecovered() bool` reports whether the current request was recovered from the [recover middleware](https://github.com/kataras/iris/tree/master/middleware/recover). Also the `iris.IsErrPrivate` function and `iris.ErrPrivate` interface have been introduced. +- `Context.IsRecovered() bool` reports whether the current request was recovered from the [recover middleware](https://github.com/kataras/iris/tree/master/middleware/recover). Also the `Context.GetErrPublic() (bool, error)`, `Context.SetErrPrivate(err error)` methods and `iris.ErrPrivate` interface have been introduced. - `Context.RecordBody()` same as the Application's `DisableBodyConsumptionOnUnmarshal` configuration field but registers per chain of handlers. It makes the request body readable more than once. - `Context.IsRecordingBody() bool` reports whether the request body can be readen multiple times. - `Context.ReadHeaders(ptr interface{}) error` binds request headers to "ptr". [Example](https://github.com/kataras/iris/blob/master/_examples/request-body/read-headers/main.go). @@ -490,6 +490,7 @@ Prior to this version the `iris.Context` was the only one dependency that has be | [net.IP](https://golang.org/pkg/net/#IP) | `net.ParseIP(ctx.RemoteAddr())` | | [mvc.Code](https://pkg.go.dev/github.com/kataras/iris/v12/mvc?tab=doc#Code) | `ctx.GetStatusCode() int` | | [mvc.Err](https://pkg.go.dev/github.com/kataras/iris/v12/mvc?tab=doc#Err) | `ctx.GetErr() error` | +| [iris/context.User](https://pkg.go.dev/github.com/kataras/iris/v12/context?tab=doc#User) | `ctx.User()` | | `string`, | | | `int, int8, int16, int32, int64`, | | | `uint, uint8, uint16, uint32, uint64`, | | diff --git a/_examples/auth/jwt/main.go b/_examples/auth/jwt/main.go index 2c31091f..ee2ef42f 100644 --- a/_examples/auth/jwt/main.go +++ b/_examples/auth/jwt/main.go @@ -1,3 +1,6 @@ +// Package main shows how you can use the Iris unique JWT middleware. +// The file contains different kind of examples that all do the same job but, +// depending on your code style and your application's requirements, you may choose one over other. package main import ( @@ -7,10 +10,31 @@ import ( "github.com/kataras/iris/v12/middleware/jwt" ) -// UserClaims a custom claims structure. You can just use jwt.Claims too. -type UserClaims struct { - jwt.Claims - Username string +// Claims a custom claims structure. +type Claims struct { + // Optionally define JWT's "iss" (Issuer), + // "sub" (Subject) and "aud" (Audience) for issuer and subject. + // The JWT's "exp" (expiration) and "iat" (issued at) are automatically + // set by the middleware. + Issuer string `json:"iss"` + Subject string `json:"sub"` + Audience []string `json:"aud"` + /* + Note that the above fields can be also extracted via: + jwt.GetTokenInfo(ctx).Claims + But in that example, we just showcase how these info can be embedded + inside your own Go structure. + */ + + // Optionally define a "exp" (Expiry), + // unlike the rest, this is unset on creation + // (unless you want to override the middleware's max age option), + // it's filled automatically by the JWT middleware + // when the request token is verified. + // See the POST /user route. + Expiry *jwt.NumericDate `json:"exp"` + + Username string `json:"username"` } func main() { @@ -20,56 +44,241 @@ func main() { // // Use the `jwt.New` instead for more flexibility, if necessary. j := jwt.HMAC(15*time.Minute, "secret", "itsa16bytesecret") - // By default it extracts the token from url parameter "token={token}" - // and the Authorization Bearer {token} header. - // You can also take token from JSON body: - // j.Extractors = append(j.Extractors, jwt.FromJSON) + + /* + By default it extracts the token from url parameter "token={token}" + and the Authorization Bearer {token} header. + You can also take token from JSON body: + j.Extractors = append(j.Extractors, jwt.FromJSON) + */ + + /* Optionally, enable block list to force-invalidate + verified tokens even before their expiration time. + This is useful when the client doesn't clear + the token on a user logout by itself. + + The duration argument clears any expired token on each every tick. + There is a GC() method that can be manually called to clear expired blocked tokens + from the memory. + + j.Blocklist = jwt.NewBlocklist(30*time.Minute) + OR NewBlocklistContext(stdContext, 30*time.Minute) + + + To invalidate a verified token just call: + j.Invalidate(ctx) inside a route handler. + */ app := iris.New() app.Logger().SetLevel("debug") + app.OnErrorCode(iris.StatusUnauthorized, func(ctx iris.Context) { + // Note that, any error stored by an authentication + // method in Iris is an iris.ErrPrivate. + // Available jwt errors: + // - ErrMissing + // - ErrMissingKey + // - ErrExpired + // - ErrNotValidYet + // - ErrIssuedInTheFuture + // - ErrBlocked + // An iris.ErrPrivate SHOULD never be displayed to the client as it is; + // because it may contain critical security information about the server. + // + // Also keep in mind that JWT middleware logs verification errors to the + // application's logger ("debug") so, normally you don't have to + // bother showing the verification error to the browser/client. + // However, you can retrieve that error and do what ever you feel right: + if err := ctx.GetErr(); err != nil { + // If we have an error stored, + // (JWT middleware stores any verification errors to the Context), + // set the error as response body, + // which is the default behavior if that + // wasn't an authentication error (as explained above) + ctx.WriteString(err.Error()) + } else { + // Else, the default behavior when no error was occured; + // write the status text of the status code: + ctx.WriteString(iris.StatusText(iris.StatusUnauthorized)) + } + }) + app.Get("/authenticate", func(ctx iris.Context) { - standardClaims := jwt.Claims{Issuer: "an-issuer", Audience: jwt.Audience{"an-audience"}} - // NOTE: if custom claims then the `j.Expiry(claims)` (or jwt.Expiry(duration, claims)) - // MUST be called in order to set the expiration time. - customClaims := UserClaims{ - Claims: j.Expiry(standardClaims), + claims := &Claims{ + Issuer: "server", + Audience: []string{"user"}, Username: "kataras", } - j.WriteToken(ctx, customClaims) + // WriteToken generates and sends the token to the client. + // To generate a token use: tok, err := j.Token(claims) + // then you can write it in any form you'd like. + // The expiration JWT fields are automatically + // set by the middleware, that means that your claims value + // only needs to fill fields that your application specifically requires. + j.WriteToken(ctx, claims) }) - userRouter := app.Party("/user") + // Middleware + type-safe method, + // useful in 99% of the cases, when your application + // requires token verification under a whole path prefix, e.g. /protected: + protectedAPI := app.Party("/protected") { - // userRouter.Use(j.Verify) - // userRouter.Get("/", func(ctx iris.Context) { - // var claims UserClaims - // if err := jwt.ReadClaims(ctx, &claims); err != nil { - // // Validation-only errors, the rest are already - // // checked on `j.Verify` middleware. - // ctx.StopWithStatus(iris.StatusUnauthorized) - // return - // } - // - // ctx.Writef("Claims: %#+v\n", claims) - // }) - // - // OR: - userRouter.Get("/", func(ctx iris.Context) { - var claims UserClaims - if err := j.VerifyToken(ctx, &claims); err != nil { - ctx.StopWithStatus(iris.StatusUnauthorized) - return - } + protectedAPI.Use(j.Verify(func() interface{} { + // Must return a pointer to a type. + // + // The Iris JWT implementation is very sophisticated. + // We keep our claims in type-safe form. + // However, you are free to use raw Go maps + // (map[string]interface{} or iris.Map) too (example later on). + // + // Note that you can use the same "j" JWT instance + // to serve different types of claims on other group of routes, + // e.g. postRouter.Use(j.Verify(... return new(Post))). + return new(Claims) + })) - ctx.Writef("Username: %s\nExpires at: %s\n", claims.Username, claims.Expiry.Time()) + protectedAPI.Get("/", func(ctx iris.Context) { + claims := jwt.Get(ctx).(*Claims) + // All fields parsed from token are set to the claims, + // including the Expiry (if defined). + ctx.Writef("Username: %s\nExpires at: %s\nAudience: %s", + claims.Username, claims.Expiry.Time(), claims.Audience) }) } + // Verify token inside a handler method, + // useful when you just need to verify a token on a single spot: + app.Get("/inline", func(ctx iris.Context) { + var claims Claims + _, err := j.VerifyToken(ctx, &claims) + if err != nil { + ctx.StopWithError(iris.StatusUnauthorized, err) + return + } + + ctx.Writef("Username: %s\nExpires at: %s\n", + claims.Username, claims.Expiry.Time()) + }) + + // Use a common map as claims method, + // not recommended, as we support typed claims but + // you can do it: + app.Get("/map/authenticate", func(ctx iris.Context) { + claims := map[string]interface{}{ // or iris.Map for shortcut. + "username": "kataras", + } + + j.WriteToken(ctx, claims) + }) + + app.Get("/map/verify/middleware", j.Verify(func() interface{} { + return &iris.Map{} // or &map[string]interface{}{} + }), func(ctx iris.Context) { + claims := jwt.Get(ctx).(iris.Map) + // The Get method will unwrap the *iris.Map for you, + // so its values are directly accessible: + ctx.Writef("Username: %s\nExpires at: %s\n", + claims["username"], claims["exp"].(*jwt.NumericDate).Time()) + }) + + app.Get("/map/verify", func(ctx iris.Context) { + claims := make(iris.Map) // or make(map[string]interface{}) + + tokenInfo, err := j.VerifyToken(ctx, &claims) + if err != nil { + ctx.StopWithError(iris.StatusUnauthorized, err) + return + } + + ctx.Writef("Username: %s\nExpires at: %s\n", + claims["username"], tokenInfo.Claims.Expiry.Time()) /* the claims["exp"] is also set. */ + }) + + // Use the new Context.User() to retrieve the verified client method: + // 1. Create a go stuct that implements the context.User interface: + app.Get("/users/authenticate", func(ctx iris.Context) { + user := &User{Username: "kataras"} + j.WriteToken(ctx, user) + }) + usersAPI := app.Party("/users") + { + usersAPI.Use(j.Verify(func() interface{} { + return new(User) + })) + + usersAPI.Get("/", func(ctx iris.Context) { + user := ctx.User() + userToken, _ := user.GetToken() + /* + You can also cast it to the underline implementation + and work with its fields: + expires := user.(*User).Expiry.Time() + */ + // OR use the GetTokenInfo to get the parsed token information: + expires := jwt.GetTokenInfo(ctx).Claims.Expiry.Time() + lifetime := expires.Sub(time.Now()) // remeaning time to be expired. + + ctx.Writef("Username: %s\nAuthenticated at: %s\nLifetime: %s\nToken: %s\n", + user.GetUsername(), user.GetAuthorizedAt(), lifetime, userToken) + }) + } + + // http://localhost:8080/authenticate + // http://localhost:8080/protected?token={token} + // http://localhost:8080/inline?token={token} + // + // http://localhost:8080/map/authenticate + // http://localhost:8080/map/verify?token={token} + // http://localhost:8080/map/verify/middleware?token={token} + // + // http://localhost:8080/users/authenticate + // http://localhost:8080/users?token={token} app.Listen(":8080") } +// User is a custom implementation of the Iris Context User interface. +// Optionally, for JWT, you can also implement +// the SetToken(tok string) and +// Validate(ctx iris.Context, claims jwt.Claims, e jwt.Expected) error +// methods to set a token and add custom validation +// to a User value parsed from a token. +type User struct { + iris.User + Username string `json:"username"` + + // Optionally, declare some JWT fields, + // they are automatically filled by the middleware itself. + IssuedAt *jwt.NumericDate `json:"iat"` + Expiry *jwt.NumericDate `json:"exp"` + Token string `json:"-"` +} + +// GetUsername returns the Username. +// Look the iris/context.SimpleUser type +// for all the methods you can implement. +func (u *User) GetUsername() string { + return u.Username +} + +// GetAuthorizedAt returns the IssuedAt time. +// This and the Get/SetToken methods showcase how you can map JWT standard fields +// to an Iris Context User. +func (u *User) GetAuthorizedAt() time.Time { + return u.IssuedAt.Time() +} + +// GetToken is a User interface method. +func (u *User) GetToken() (string, error) { + return u.Token, nil +} + +// SetToken is a special jwt.TokenSetter interface which is +// called automatically when a token is parsed to this User value. +func (u *User) SetToken(tok string) { + u.Token = tok +} + /* func default_RSA_Example() { j := jwt.RSA(15*time.Minute) diff --git a/_examples/auth/jwt/refresh-token/main.go b/_examples/auth/jwt/refresh-token/main.go index 9e3be184..4f41c22d 100644 --- a/_examples/auth/jwt/refresh-token/main.go +++ b/_examples/auth/jwt/refresh-token/main.go @@ -7,133 +7,122 @@ import ( "github.com/kataras/iris/v12/middleware/jwt" ) -// UserClaims a custom claims structure. You can just use jwt.Claims too. +// UserClaims a custom access claims structure. type UserClaims struct { - jwt.Claims - Username string + // We could that JWT field to separate the access and refresh token: + // Issuer string `json:"iss"` + // But let's cover the "required" feature too, see below: + ID string `json:"user_id,required"` + Username string `json:"username,required"` } -// TokenPair holds the access token and refresh token response. -type TokenPair struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` -} +// For refresh token, we will just use the jwt.Claims +// structure which contains the standard JWT fields. func main() { app := iris.New() - // Access token, short-live. - accessJWT := jwt.HMAC(15*time.Minute, "secret", "itsa16bytesecret") - // Refresh token, long-live. Important: Give different secret keys(!) - refreshJWT := jwt.HMAC(1*time.Hour, "other secret", "other16bytesecre") - // On refresh token, we extract it only from a request body - // of JSON, e.g. {"refresh_token": $token }. - // You can also do it manually in the handler level though. - refreshJWT.Extractors = []jwt.TokenExtractor{ - jwt.FromJSON("refresh_token"), + j := jwt.HMAC(15*time.Minute, "secret", "itsa16bytesecret") + + app.Get("/authenticate", func(ctx iris.Context) { + generateTokenPair(ctx, j) + }) + + app.Get("/refresh_json", func(ctx iris.Context) { + refreshTokenFromJSON(ctx, j) + }) + + protectedAPI := app.Party("/protected") + { + protectedAPI.Use(j.Verify(func() interface{} { + return new(UserClaims) + })) // OR j.VerifyToken(ctx, &claims, jwt.MeetRequirements(&UserClaims{})) + + protectedAPI.Get("/", func(ctx iris.Context) { + // Get token info, even if our UserClaims does not embed those + // through GetTokenInfo: + expiresAt := jwt.GetTokenInfo(ctx).Claims.Expiry.Time() + // Get your custom JWT claims through Get, + // which is a shortcut of GetTokenInfo(ctx).Value: + claims := jwt.Get(ctx).(*UserClaims) + + ctx.Writef("Username: %s\nExpires at: %s\n", claims.Username, expiresAt) + }) } - // Generate access and refresh tokens and send to the client. - app.Get("/authenticate", func(ctx iris.Context) { - tokenPair, err := generateTokenPair(accessJWT, refreshJWT) - if err != nil { - ctx.StopWithStatus(iris.StatusInternalServerError) - return - } - - ctx.JSON(tokenPair) - }) - - app.Get("/refresh", func(ctx iris.Context) { - // Manual (if jwt.FromJSON missing): - // var payload = struct { - // RefreshToken string `json:"refresh_token"` - // }{} - // - // err := ctx.ReadJSON(&payload) - // if err != nil { - // ctx.StatusCode(iris.StatusBadRequest) - // return - // } - // - // j.VerifyTokenString(ctx, payload.RefreshToken, &claims) - - var claims jwt.Claims - if err := refreshJWT.VerifyToken(ctx, &claims); err != nil { - ctx.Application().Logger().Warnf("verify refresh token: %v", err) - ctx.StopWithStatus(iris.StatusUnauthorized) - return - } - - userID := claims.Subject - if userID == "" { - ctx.StopWithStatus(iris.StatusUnauthorized) - return - } - - // Simulate a database call against our jwt subject. - if userID != "53afcf05-38a3-43c3-82af-8bbbe0e4a149" { - ctx.StopWithStatus(iris.StatusUnauthorized) - return - } - - // All OK, re-generate the new pair and send to client. - tokenPair, err := generateTokenPair(accessJWT, refreshJWT) - if err != nil { - ctx.StopWithStatus(iris.StatusInternalServerError) - return - } - - ctx.JSON(tokenPair) - }) - - app.Get("/", func(ctx iris.Context) { - var claims UserClaims - if err := accessJWT.VerifyToken(ctx, &claims); err != nil { - ctx.StopWithStatus(iris.StatusUnauthorized) - return - } - - ctx.Writef("Username: %s\nExpires at: %s\n", claims.Username, claims.Expiry.Time()) - }) - - // http://localhost:8080 (401) + // http://localhost:8080/protected (401) // http://localhost:8080/authenticate (200) (response JSON {access_token, refresh_token}) - // http://localhost:8080?token={access_token} (200) - // http://localhost:8080?token={refresh_token} (401) - // http://localhost:8080/refresh (request JSON{refresh_token = {refresh_token}}) (200) (response JSON {access_token, refresh_token}) + // http://localhost:8080/protected?token={access_token} (200) + // http://localhost:8080/protected?token={refresh_token} (401) + // http://localhost:8080/refresh_json (request JSON{refresh_token = {refresh_token}}) (200) (response JSON {access_token, refresh_token}) app.Listen(":8080") } -func generateTokenPair(accessJWT, refreshJWT *jwt.JWT) (TokenPair, error) { - standardClaims := jwt.Claims{Issuer: "an-issuer", Audience: jwt.Audience{"an-audience"}} +func generateTokenPair(ctx iris.Context, j *jwt.JWT) { + // Simulate a user... + userID := "53afcf05-38a3-43c3-82af-8bbbe0e4a149" - customClaims := UserClaims{ - Claims: accessJWT.Expiry(standardClaims), + // Map the current user with the refresh token, + // so we make sure, on refresh route, that this refresh token owns + // to that user before re-generate. + refresh := jwt.Claims{Subject: userID} + + access := UserClaims{ + ID: userID, Username: "kataras", } - accessToken, err := accessJWT.Token(customClaims) + // Generates a Token Pair, long-live for refresh tokens, e.g. 1 hour. + // Second argument is the refresh claims and, + // the last one is the access token's claims. + tokenPair, err := j.TokenPair(1*time.Hour, refresh, access) if err != nil { - return TokenPair{}, err + ctx.Application().Logger().Debugf("token pair: %v", err) + ctx.StopWithStatus(iris.StatusInternalServerError) + return } - // At refresh tokens you don't need any custom claims. - refreshClaims := refreshJWT.Expiry(jwt.Claims{ - ID: "refresh_kataras", - // For example, the User ID, - // this is necessary to check against the database - // if the user still exist or has credentials to access our page. - Subject: "53afcf05-38a3-43c3-82af-8bbbe0e4a149", - }) - - refreshToken, err := refreshJWT.Token(refreshClaims) - if err != nil { - return TokenPair{}, err - } - - return TokenPair{ - AccessToken: accessToken, - RefreshToken: refreshToken, - }, nil + // Send the generated token pair to the client. + // The tokenPair looks like: {"access_token": $token, "refresh_token": $token} + ctx.JSON(tokenPair) +} + +func refreshTokenFromJSON(ctx iris.Context, j *jwt.JWT) { + var tokenPair jwt.TokenPair + + // Grab the refresh token from a JSON body (you can let it fetch by URL parameter too but + // it's common practice that you read it from a json body as + // it may contain the access token too (the same response we sent on generateTokenPair)). + err := ctx.ReadJSON(&tokenPair) + if err != nil { + ctx.StatusCode(iris.StatusBadRequest) + return + } + + var refreshClaims jwt.Claims + err = j.VerifyTokenString(ctx, tokenPair.RefreshToken, &refreshClaims) + if err != nil { + ctx.Application().Logger().Debugf("verify refresh token: %v", err) + ctx.StatusCode(iris.StatusUnauthorized) + return + } + + // Assuming you have access to the current user, e.g. sessions. + // + // Simulate a database call against our jwt subject + // to make sure that this refresh token is a pair generated by this user. + currentUserID := "53afcf05-38a3-43c3-82af-8bbbe0e4a149" + + userID := refreshClaims.Subject + if userID != currentUserID { + ctx.StopWithStatus(iris.StatusUnauthorized) + return + } + // + // Otherwise, the request must contain the (old) access token too, + // even if it's invalid, we can still fetch its fields, such as the user id. + // [...leave it for you] + + // All OK, re-generate the new pair and send to client. + generateTokenPair(ctx, j) } diff --git a/_examples/logging/request-logger/accesslog-simple/main.go b/_examples/logging/request-logger/accesslog-simple/main.go index c966e2be..77880c8b 100644 --- a/_examples/logging/request-logger/accesslog-simple/main.go +++ b/_examples/logging/request-logger/accesslog-simple/main.go @@ -29,7 +29,10 @@ func makeAccessLog() *accesslog.AccessLog { ac.PanicLog = accesslog.LogHandler // Set Custom Formatter: - ac.SetFormatter(&accesslog.JSON{}) + ac.SetFormatter(&accesslog.JSON{ + Indent: " ", + HumanTime: true, + }) // ac.SetFormatter(&accesslog.CSV{}) // ac.SetFormatter(&accesslog.Template{Text: "{{.Code}}"}) diff --git a/_examples/sessions/overview/example/example.go b/_examples/sessions/overview/example/example.go index ed4fb5e4..e47624d7 100644 --- a/_examples/sessions/overview/example/example.go +++ b/_examples/sessions/overview/example/example.go @@ -38,9 +38,9 @@ func NewApp(sess *sessions.Sessions) *iris.Application { session := sessions.Get(ctx) isNew := session.IsNew() - session.Set("name", "iris") + session.Set("username", "iris") - ctx.Writef("All ok session set to: %s [isNew=%t]", session.GetString("name"), isNew) + ctx.Writef("All ok session set to: %s [isNew=%t]", session.GetString("username"), isNew) }) app.Get("/get", func(ctx iris.Context) { @@ -48,9 +48,9 @@ func NewApp(sess *sessions.Sessions) *iris.Application { // get a specific value, as string, // if not found then it returns just an empty string. - name := session.GetString("name") + name := session.GetString("username") - ctx.Writef("The name on the /set was: %s", name) + ctx.Writef("The username on the /set was: %s", name) }) app.Get("/set-struct", func(ctx iris.Context) { diff --git a/aliases.go b/aliases.go index cd49bec7..b21ace34 100644 --- a/aliases.go +++ b/aliases.go @@ -59,6 +59,10 @@ type ( Filter = context.Filter // A Map is an alias of map[string]interface{}. Map = context.Map + // User is a generic view of an authorized client. + // See `Context.User` and `SetUser` methods for more. + // An alias for the `context/User` type. + User = context.User // Problem Details for HTTP APIs. // Pass a Problem value to `context.Problem` to // write an "application/problem+json" response. @@ -475,8 +479,6 @@ var ( // on post data, versioning feature and others. // An alias of `context.ErrNotFound`. ErrNotFound = context.ErrNotFound - // IsErrPrivate reports whether the given "err" is a private one. - IsErrPrivate = context.IsErrPrivate // NewProblem returns a new Problem. // Head over to the `Problem` type godoc for more. // @@ -502,6 +504,9 @@ var ( // // A shortcut for the `context#ErrPushNotSupported`. ErrPushNotSupported = context.ErrPushNotSupported + // PrivateError accepts an error and returns a wrapped private one. + // A shortcut for the `context#PrivateError`. + PrivateError = context.PrivateError ) // HTTP Methods copied from `net/http`. diff --git a/context/context.go b/context/context.go index 42b62393..aa15cbac 100644 --- a/context/context.go +++ b/context/context.go @@ -714,9 +714,10 @@ func (ctx *Context) StopWithError(statusCode int, err error) { } ctx.SetErr(err) - if IsErrPrivate(err) { - // error is private, we can't render it, instead . - // let the error handler render the code text. + if _, ok := err.(ErrPrivate); ok { + // error is private, we SHOULD not render it, + // leave the error handler alone to + // render the code's text instead. ctx.StopWithStatus(statusCode) return } @@ -5065,8 +5066,6 @@ func (ctx *Context) IsDebug() bool { return ctx.app.IsDebug() } -const errorContextKey = "iris.context.error" - // SetErr is just a helper that sets an error value // as a context value, it does nothing more. // Also, by-default this error's value is written to the client @@ -5088,14 +5087,71 @@ func (ctx *Context) SetErr(err error) { // GetErr is a helper which retrieves // the error value stored by `SetErr`. +// +// Note that, if an error was stored by `SetErrPrivate` +// then it returns the underline/original error instead +// of the internal error wrapper. func (ctx *Context) GetErr() error { + _, err := ctx.GetErrPublic() + return err +} + +// ErrPrivate if provided then the error saved in context +// should NOT be visible to the client no matter what. +type ErrPrivate interface { + error + IrisPrivateError() +} + +// An internal wrapper for the `SetErrPrivate` method. +type privateError struct{ error } + +func (e privateError) IrisPrivateError() {} + +// PrivateError accepts an error and returns a wrapped private one. +func PrivateError(err error) ErrPrivate { + if err == nil { + return nil + } + + errPrivate, ok := err.(ErrPrivate) + if !ok { + errPrivate = privateError{err} + } + + return errPrivate +} + +const errorContextKey = "iris.context.error" + +// SetErrPrivate sets an error that it's only accessible through `GetErr` +// and it should never be sent to the client. +// +// Same as ctx.SetErr with an error that completes the `ErrPrivate` interface. +// See `GetErrPublic` too. +func (ctx *Context) SetErrPrivate(err error) { + ctx.SetErr(PrivateError(err)) +} + +// GetErrPublic reports whether the stored error +// can be displayed to the client without risking +// to expose security server implementation to the client. +// +// If the error is not nil, it is always the original one. +func (ctx *Context) GetErrPublic() (bool, error) { if v := ctx.values.Get(errorContextKey); v != nil { - if err, ok := v.(error); ok { - return err + switch err := v.(type) { + case privateError: + // If it's an error set by SetErrPrivate then unwrap it. + return false, err.error + case ErrPrivate: + return false, err + case error: + return true, err } } - return nil + return false, nil } // ErrPanicRecovery may be returned from `Context` actions of a `Handler` @@ -5135,22 +5191,6 @@ func IsErrPanicRecovery(err error) (*ErrPanicRecovery, bool) { return v, ok } -// ErrPrivate if provided then the error saved in context -// should NOT be visible to the client no matter what. -type ErrPrivate interface { - IrisPrivateError() -} - -// IsErrPrivate reports whether the given "err" is a private one. -func IsErrPrivate(err error) bool { - if err == nil { - return false - } - - _, ok := err.(ErrPrivate) - return ok -} - // IsRecovered reports whether this handler has been recovered // by the Iris recover middleware. func (ctx *Context) IsRecovered() (*ErrPanicRecovery, bool) { diff --git a/context/context_user.go b/context/context_user.go index 093f3dcd..1e27eb62 100644 --- a/context/context_user.go +++ b/context/context_user.go @@ -2,7 +2,9 @@ package context import ( "errors" + "strings" "time" + "unicode" ) // ErrNotSupported is fired when a specific method is not implemented @@ -21,6 +23,13 @@ var ErrNotSupported = errors.New("not supported") // // The caller is free to cast this with the implementation directly // when special features are offered by the authorization system. +// +// To make optional some of the fields you can just embed the User interface +// and implement whatever methods you want to support. +// +// There are two builtin implementations of the User interface: +// - SimpleUser (type-safe) +// - UserMap (wraps a map[string]interface{}) type User interface { // GetAuthorization should return the authorization method, // e.g. Basic Authentication. @@ -35,7 +44,33 @@ type User interface { GetPassword() string // GetEmail should return the e-mail of the User. GetEmail() string -} + // GetRoles should optionally return the specific user's roles. + // Returns `ErrNotSupported` if this method is not + // implemented by the User implementation. + GetRoles() ([]string, error) + // GetToken should optionally return a token used + // to authorize this User. + GetToken() (string, error) + // GetField should optionally return a dynamic field + // based on its key. Useful for custom user fields. + // Keep in mind that these fields are encoded as a separate JSON key. + GetField(key string) (interface{}, error) +} /* Notes: +We could use a structure of User wrapper and separate interfaces for each of the methods +so they return ErrNotSupported if the implementation is missing it, so the `Features` +field and HasUserFeature can be omitted and +add a Raw() interface{} to return the underline User implementation too. +The advandages of the above idea is that we don't have to add new methods +for each of the builtin features and we can keep the (assumed) struct small. +But we dont as it has many disadvantages, unless is requested. + +The disadvantage of the current implementation is that the developer MUST +complete the whole interface in order to be a valid User and if we add +new methods in the future their implementation will break +(unless they have a static interface implementation check as we have on SimpleUser). +We kind of by-pass this disadvantage by providing a SimpleUser which can be embedded (as pointer) +to the end-developer's custom implementations. +*/ // FeaturedUser optional interface that a User can implement. type FeaturedUser interface { @@ -55,6 +90,9 @@ const ( UsernameFeature PasswordFeature EmailFeature + RolesFeature + TokenFeature + FieldsFeature ) // HasUserFeature reports whether the "u" User @@ -80,13 +118,16 @@ func HasUserFeature(user User, feature UserFeature) (bool, error) { type SimpleUser struct { Authorization string `json:"authorization"` AuthorizedAt time.Time `json:"authorized_at"` - Username string `json:"username"` + Username string `json:"username,omitempty"` Password string `json:"-"` Email string `json:"email,omitempty"` - Features []UserFeature `json:"-"` + Roles []string `json:"roles,omitempty"` + Features []UserFeature `json:"features,omitempty"` + Token string `json:"token,omitempty"` + Fields Map `json:"fields,omitempty"` } -var _ User = (*SimpleUser)(nil) +var _ FeaturedUser = (*SimpleUser)(nil) // GetAuthorization returns the authorization method, // e.g. Basic Authentication. @@ -115,6 +156,39 @@ func (u *SimpleUser) GetEmail() string { return u.Email } +// GetRoles returns the specific user's roles. +// Returns with `ErrNotSupported` if the Roles field is not initialized. +func (u *SimpleUser) GetRoles() ([]string, error) { + if u.Roles == nil { + return nil, ErrNotSupported + } + + return u.Roles, nil +} + +// GetToken returns the token associated with this User. +// It may return empty if the User is not featured with a Token. +// +// The implementation can change that behavior. +// Returns with `ErrNotSupported` if the Token field is empty. +func (u *SimpleUser) GetToken() (string, error) { + if u.Token == "" { + return "", ErrNotSupported + } + + return u.Token, nil +} + +// GetField optionally returns a dynamic field from the `Fields` field +// based on its key. +func (u *SimpleUser) GetField(key string) (interface{}, error) { + if u.Fields == nil { + return nil, ErrNotSupported + } + + return u.Fields[key], nil +} + // GetFeatures returns a list of features // this User implementation offers. func (u *SimpleUser) GetFeatures() []UserFeature { @@ -140,5 +214,159 @@ func (u *SimpleUser) GetFeatures() []UserFeature { features = append(features, EmailFeature) } + if u.Roles != nil { + features = append(features, RolesFeature) + } + + if u.Fields != nil { + features = append(features, FieldsFeature) + } + return features } + +// UserMap can be used to convert a common map[string]interface{} to a User. +// Usage: +// user := map[string]interface{}{ +// "username": "kataras", +// "age" : 27, +// } +// ctx.SetUser(UserMap(user)) +// OR +// user := UserMap{"key": "value",...} +// ctx.SetUser(user) +// [...] +// username := ctx.User().GetUsername() +// age := ctx.User().GetField("age").(int) +// OR cast it: +// user := ctx.User().(UserMap) +// username := user["username"].(string) +// age := user["age"].(int) +type UserMap Map + +var _ FeaturedUser = UserMap{} + +// GetAuthorization returns the authorization or Authorization value of the map. +func (u UserMap) GetAuthorization() string { + return u.str("authorization") +} + +// GetAuthorizedAt returns the authorized_at or Authorized_At value of the map. +func (u UserMap) GetAuthorizedAt() time.Time { + return u.time("authorized_at") +} + +// GetUsername returns the username or Username value of the map. +func (u UserMap) GetUsername() string { + return u.str("username") +} + +// GetPassword returns the password or Password value of the map. +func (u UserMap) GetPassword() string { + return u.str("password") +} + +// GetEmail returns the email or Email value of the map. +func (u UserMap) GetEmail() string { + return u.str("email") +} + +// GetRoles returns the roles or Roles value of the map. +func (u UserMap) GetRoles() ([]string, error) { + if s := u.strSlice("roles"); s != nil { + return s, nil + } + + return nil, ErrNotSupported +} + +// GetToken returns the roles or Roles value of the map. +func (u UserMap) GetToken() (string, error) { + if s := u.str("token"); s != "" { + return s, nil + } + + return "", ErrNotSupported +} + +// GetField returns the raw map's value based on its "key". +// It's not kind of useful here as you can just use the map. +func (u UserMap) GetField(key string) (interface{}, error) { + return u[key], nil +} + +// GetFeatures returns a list of features +// this map offers. +func (u UserMap) GetFeatures() []UserFeature { + if v := u.val("features"); v != nil { // if already contain features. + if features, ok := v.([]UserFeature); ok { + return features + } + } + + // else try to resolve from map values. + features := []UserFeature{FieldsFeature} + + if !u.GetAuthorizedAt().IsZero() { + features = append(features, AuthorizedAtFeature) + } + + if u.GetUsername() != "" { + features = append(features, UsernameFeature) + } + + if u.GetPassword() != "" { + features = append(features, PasswordFeature) + } + + if u.GetEmail() != "" { + features = append(features, EmailFeature) + } + + if roles, err := u.GetRoles(); err == nil && roles != nil { + features = append(features, RolesFeature) + } + + return features +} + +func (u UserMap) val(key string) interface{} { + isTitle := unicode.IsTitle(rune(key[0])) // if starts with uppercase. + if isTitle { + key = strings.ToLower(key) + } + + return u[key] +} + +func (u UserMap) str(key string) string { + if v := u.val(key); v != nil { + if s, ok := v.(string); ok { + return s + } + + // exists or not we don't care, if it's invalid type we don't fill it. + } + + return "" +} + +func (u UserMap) strSlice(key string) []string { + if v := u.val(key); v != nil { + if s, ok := v.([]string); ok { + return s + } + } + + return nil +} + +func (u UserMap) time(key string) time.Time { + if v := u.val(key); v != nil { + if t, ok := v.(time.Time); ok { + return t + } + } + + return time.Time{} +} diff --git a/core/router/handler.go b/core/router/handler.go index 08195abf..3ec5a483 100644 --- a/core/router/handler.go +++ b/core/router/handler.go @@ -129,15 +129,14 @@ type RoutesProvider interface { // api builder } func defaultErrorHandler(ctx *context.Context) { - if err := ctx.GetErr(); err != nil { - if !context.IsErrPrivate(err) { - ctx.WriteString(err.Error()) - return - } + if ok, err := ctx.GetErrPublic(); ok { + // If an error is stored and it's not a private one + // write it to the response body. + ctx.WriteString(err.Error()) + return } - + // Otherwise, write the code's text instead. ctx.WriteString(context.StatusText(ctx.GetStatusCode())) - } func (h *routerHandler) Build(provider RoutesProvider) error { diff --git a/go.mod b/go.mod index fd3c1ee9..6f9f219c 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.15 require ( github.com/BurntSushi/toml v0.3.1 - github.com/CloudyKit/jet/v5 v5.0.3 + github.com/CloudyKit/jet/v5 v5.1.0 github.com/Shopify/goreferrer v0.0.0-20181106222321-ec9c9a553398 github.com/andybalholm/brotli v1.0.1 github.com/aymerick/raymond v2.0.3-0.20180322193309-b565731e1464+incompatible @@ -12,7 +12,7 @@ require ( github.com/eknkc/amber v0.0.0-20171010120322-cdade1c07385 github.com/fatih/structs v1.1.0 github.com/flosch/pongo2/v4 v4.0.0 - github.com/go-redis/redis/v8 v8.2.3 + github.com/go-redis/redis/v8 v8.3.1 github.com/google/uuid v1.1.2 github.com/hashicorp/go-version v1.2.1 github.com/iris-contrib/httpexpect/v2 v2.0.5 @@ -31,16 +31,16 @@ require ( github.com/russross/blackfriday/v2 v2.0.1 github.com/schollz/closestmatch v2.1.0+incompatible github.com/square/go-jose/v3 v3.0.0-20200630053402-0a67ce9b0693 - github.com/tdewolff/minify/v2 v2.9.7 + github.com/tdewolff/minify/v2 v2.9.9 github.com/vmihailenco/msgpack/v5 v5.0.0-beta.1 github.com/yosssi/ace v0.0.5 go.etcd.io/bbolt v1.3.5 - golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0 - golang.org/x/net v0.0.0-20201002202402-0a1ea396d57c - golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f + golang.org/x/crypto v0.0.0-20201012173705-84dcc777aaee + golang.org/x/net v0.0.0-20201016165138-7b1cca2348c0 + golang.org/x/sys v0.0.0-20201016160150-f659759dc4ca golang.org/x/text v0.3.3 golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e google.golang.org/protobuf v1.25.0 - gopkg.in/ini.v1 v1.61.0 + gopkg.in/ini.v1 v1.62.0 gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 ) diff --git a/hero/container.go b/hero/container.go index a9b142b1..ded50532 100644 --- a/hero/container.go +++ b/hero/container.go @@ -170,9 +170,23 @@ var BuiltinDependencies = []*Dependency{ NewDependency(func(ctx *context.Context) Code { return Code(ctx.GetStatusCode()) }).Explicitly(), + // Context Error. May be nil NewDependency(func(ctx *context.Context) Err { - return Err(ctx.GetErr()) + err := ctx.GetErr() + if err == nil { + return nil + } + return err }).Explicitly(), + // Context User, e.g. from basic authentication. + NewDependency(func(ctx *context.Context) context.User { + u := ctx.User() + if u == nil { + return nil + } + + return u + }), // payload and param bindings are dynamically allocated and declared at the end of the `binding` source file. } diff --git a/httptest/httptest.go b/httptest/httptest.go index cbd7edc3..4a5a6a18 100644 --- a/httptest/httptest.go +++ b/httptest/httptest.go @@ -67,11 +67,11 @@ var ( } } - // LogLevel sets the application's log level "val". + // LogLevel sets the application's log level. // Defaults to disabled when testing. - LogLevel = func(val string) OptionSet { + LogLevel = func(level string) OptionSet { return func(c *Configuration) { - c.LogLevel = val + c.LogLevel = level } } ) diff --git a/middleware/README.md b/middleware/README.md index 3c046932..2f11609e 100644 --- a/middleware/README.md +++ b/middleware/README.md @@ -32,7 +32,6 @@ Most of the experimental handlers are ported to work with _iris_'s handler form, | [casbin](https://github.com/iris-contrib/middleware/tree/master/casbin)| An authorization library that supports access control models like ACL, RBAC, ABAC | [iris-contrib/middleware/casbin/_examples](https://github.com/iris-contrib/middleware/tree/master/casbin/_examples) | | [sentry-go (ex. raven)](https://github.com/getsentry/sentry-go/tree/master/iris)| Sentry client in Go | [sentry-go/example/iris](https://github.com/getsentry/sentry-go/blob/master/example/iris/main.go) | | [csrf](https://github.com/iris-contrib/middleware/tree/master/csrf)| Cross-Site Request Forgery Protection | [iris-contrib/middleware/csrf/_example](https://github.com/iris-contrib/middleware/blob/master/csrf/_example/main.go) | -| [go-i18n](https://github.com/iris-contrib/middleware/tree/master/go-i18n)| i18n Iris Loader for nicksnyder/go-i18n | [iris-contrib/middleware/go-i18n/_example](https://github.com/iris-contrib/middleware/blob/master/go-i18n/_example/main.go) | | [throttler](https://github.com/iris-contrib/middleware/tree/master/throttler)| Rate limiting access to HTTP endpoints | [iris-contrib/middleware/throttler/_example](https://github.com/iris-contrib/middleware/blob/master/throttler/_example/main.go) | Third-Party Handlers diff --git a/middleware/jwt/alises.go b/middleware/jwt/alises.go index c7e4f9c7..da34d9c7 100644 --- a/middleware/jwt/alises.go +++ b/middleware/jwt/alises.go @@ -2,6 +2,7 @@ package jwt import ( "github.com/square/go-jose/v3" + "github.com/square/go-jose/v3/json" "github.com/square/go-jose/v3/jwt" ) @@ -14,11 +15,19 @@ type ( // epoch, including leap seconds. Non-integer values can be represented // in the serialized format, but we round to the nearest second. NumericDate = jwt.NumericDate + // Expected defines values used for protected claims validation. + // If field has zero value then validation is skipped. + Expected = jwt.Expected ) var ( // NewNumericDate constructs NumericDate from time.Time value. NewNumericDate = jwt.NewNumericDate + // Marshal returns the JSON encoding of v. + Marshal = json.Marshal + // Unmarshal parses the JSON-encoded data and stores the result + // in the value pointed to by v. + Unmarshal = json.Unmarshal ) type ( diff --git a/middleware/jwt/blocklist.go b/middleware/jwt/blocklist.go new file mode 100644 index 00000000..86622655 --- /dev/null +++ b/middleware/jwt/blocklist.go @@ -0,0 +1,131 @@ +package jwt + +import ( + stdContext "context" + "sync" + "time" +) + +// Blocklist is an in-memory storage of tokens that should be +// immediately invalidated by the server-side. +// The most common way to invalidate a token, e.g. on user logout, +// is to make the client-side remove the token itself. +// However, if someone else has access to that token, +// it could be still valid for new requests until its expiration. +type Blocklist struct { + entries map[string]time.Time // key = token | value = expiration time (to remove expired). + mu sync.RWMutex +} + +// NewBlocklist returns a new up and running in-memory Token Blocklist. +// The returned value can be set to the JWT instance's Blocklist field. +func NewBlocklist(gcEvery time.Duration) *Blocklist { + return NewBlocklistContext(stdContext.Background(), gcEvery) +} + +// NewBlocklistContext same as `NewBlocklist` +// but it also accepts a standard Go Context for GC cancelation. +func NewBlocklistContext(ctx stdContext.Context, gcEvery time.Duration) *Blocklist { + b := &Blocklist{ + entries: make(map[string]time.Time), + } + + if gcEvery > 0 { + go b.runGC(ctx, gcEvery) + } + + return b +} + +// Set upserts a given token, with its expiration time, +// to the block list, so it's immediately invalidated by the server-side. +func (b *Blocklist) Set(token string, expiresAt time.Time) { + b.mu.Lock() + b.entries[token] = expiresAt + b.mu.Unlock() +} + +// Del removes a "token" from the block list. +func (b *Blocklist) Del(token string) { + b.mu.Lock() + delete(b.entries, token) + b.mu.Unlock() +} + +// Count returns the total amount of blocked tokens. +func (b *Blocklist) Count() int { + b.mu.RLock() + n := len(b.entries) + b.mu.RUnlock() + + return n +} + +// Has reports whether the given "token" is blocked by the server. +// This method is called before the token verification, +// so even if was expired it is removed from the block list. +func (b *Blocklist) Has(token string) bool { + if token == "" { + return false + } + + b.mu.RLock() + _, ok := b.entries[token] + b.mu.RUnlock() + + /* No, the Blocklist will be used after the token is parsed, + there we can call the Del method if err was ErrExpired. + if ok { + // As an extra step, to keep the list size as small as possible, + // we delete it from list if it's going to be expired + // ~in the next `blockedExpireLeeway` seconds.~ + // - Let's keep it easier for testing by not setting a leeway. + // if time.Now().Add(blockedExpireLeeway).After(expiresAt) { + if time.Now().After(expiresAt) { + b.Del(token) + } + }*/ + + return ok +} + +// GC iterates over all entries and removes expired tokens. +// This method is helpful to keep the list size small. +// Depending on the application, the GC method can be scheduled +// to called every half or a whole hour. +// A good value for a GC cron task is the JWT's max age (default). +func (b *Blocklist) GC() int { + now := time.Now() + var markedForDeletion []string + + b.mu.RLock() + for token, expiresAt := range b.entries { + if now.After(expiresAt) { + markedForDeletion = append(markedForDeletion, token) + } + } + b.mu.RUnlock() + + n := len(markedForDeletion) + if n > 0 { + for _, token := range markedForDeletion { + b.Del(token) + } + } + + return n +} + +func (b *Blocklist) runGC(ctx stdContext.Context, every time.Duration) { + t := time.NewTicker(every) + + for { + select { + case <-ctx.Done(): + t.Stop() + return + case <-t.C: + b.GC() + } + } +} diff --git a/middleware/jwt/jwt.go b/middleware/jwt/jwt.go index 196acd11..37df4622 100644 --- a/middleware/jwt/jwt.go +++ b/middleware/jwt/jwt.go @@ -2,8 +2,6 @@ package jwt import ( "crypto" - "encoding/json" - "errors" "os" "strings" "time" @@ -85,6 +83,9 @@ func FromJSON(jsonKey string) TokenExtractor { // // The `RSA(privateFile, publicFile, password)` package-level helper function // can be used to decode the SignKey and VerifyKey. +// +// For an easy use look the `HMAC` package-level function +// and the its `NewUser` and `VerifyUser` methods. type JWT struct { // MaxAge is the expiration duration of the generated tokens. MaxAge time.Duration @@ -109,6 +110,17 @@ type JWT struct { Encrypter jose.Encrypter // DecriptionKey is used to decrypt the token (private key) DecriptionKey interface{} + + // Blocklist holds the invalidated-by-server tokens (that are not yet expired). + // It is not initialized by default. + // Initialization Usage: + // j.UseBlocklist() + // OR + // j.Blocklist = jwt.NewBlocklist(gcEveryDuration) + // Usage: + // - ctx.Logout() + // - j.Invalidate(ctx) + Blocklist *Blocklist } type privateKey interface{ Public() crypto.PublicKey } @@ -284,64 +296,68 @@ func (j *JWT) WithEncryption(contentEncryption ContentEncryption, alg KeyAlgorit return nil } -// Expiry returns a new standard Claims with -// the `Expiry` and `IssuedAt` fields of the "claims" filled -// based on the given "maxAge" duration. -// -// See the `JWT.Expiry` method too. -func Expiry(maxAge time.Duration, claims Claims) Claims { - now := time.Now() - claims.Expiry = NewNumericDate(now.Add(maxAge)) - claims.IssuedAt = NewNumericDate(now) - return claims +// UseBlocklist initializes the Blocklist. +// Should be called on jwt middleware creation-time, +// after this, the developer can use the Context.Logout method +// to invalidate a verified token by the server-side. +func (j *JWT) UseBlocklist() { + gcEvery := 30 * time.Minute + if j.MaxAge > 0 { + gcEvery = j.MaxAge + } + j.Blocklist = NewBlocklist(gcEvery) } -// Expiry method same as `Expiry` package-level function, -// it returns a Claims with the expiration fields of the "claims" -// filled based on the JWT's `MaxAge` field. -// Only use it when this standard "claims" -// is embedded on a custom claims structure. -// Usage: -// type UserClaims struct { -// jwt.Claims -// Username string -// } -// [...] -// standardClaims := j.Expiry(jwt.Claims{...}) -// customClaims := UserClaims{ -// Claims: standardClaims, -// Username: "kataras", -// } -// j.WriteToken(ctx, customClaims) -func (j *JWT) Expiry(claims Claims) Claims { - return Expiry(j.MaxAge, claims) +// ExpiryMap adds the expiration based on the "maxAge" to the "claims" map. +// It's called automatically on `Token` method. +func ExpiryMap(maxAge time.Duration, claims context.Map) { + now := time.Now() + if claims["exp"] == nil { + claims["exp"] = NewNumericDate(now.Add(maxAge)) + } + + if claims["iat"] == nil { + claims["iat"] = NewNumericDate(now) + } } // Token generates and returns a new token string. // See `VerifyToken` too. func (j *JWT) Token(claims interface{}) (string, error) { - // switch c := claims.(type) { - // case Claims: - // claims = Expiry(j.MaxAge, c) - // case map[string]interface{}: let's not support map. - // now := time.Now() - // c["iat"] = now.Unix() - // c["exp"] = now.Add(j.MaxAge).Unix() - // } - if c, ok := claims.(Claims); ok { - claims = Expiry(j.MaxAge, c) + return j.token(j.MaxAge, claims) +} + +func (j *JWT) token(maxAge time.Duration, claims interface{}) (string, error) { + if claims == nil { + return "", ErrInvalidKey } + c, nErr := normalize(claims) + if nErr != nil { + return "", nErr + } + + // Set expiration, if missing. + ExpiryMap(maxAge, c) + var ( token string err error ) - // jwt.Builder and jwt.NestedBuilder contain same methods but they are not the same. + // + // Note that the .Claims method there, converts a Struct to a map under the hoods. + // That means that we will not have any performance cost + // if we do it by ourselves and pass always a Map there. + // That gives us the option to allow user to pass ANY go struct + // and we can add the "exp", "nbf", "iat" map values by ourselves + // based on the j.MaxAge. + // (^ done, see normalize, all methods are + // changed to accept totally custom types, no need to embed the standard Claims anymore). if j.DecriptionKey != nil { - token, err = jwt.SignedAndEncrypted(j.Signer, j.Encrypter).Claims(claims).CompactSerialize() + token, err = jwt.SignedAndEncrypted(j.Signer, j.Encrypter).Claims(c).CompactSerialize() } else { - token, err = jwt.Signed(j.Signer).Claims(claims).CompactSerialize() + token, err = jwt.Signed(j.Signer).Claims(c).CompactSerialize() } if err != nil { @@ -351,39 +367,6 @@ func (j *JWT) Token(claims interface{}) (string, error) { return token, nil } -/* Let's no support maps, typed claim is the way to go. -// validateMapClaims validates claims of map type. -func validateMapClaims(m map[string]interface{}, e jwt.Expected, leeway time.Duration) error { - if !e.Time.IsZero() { - if v, ok := m["nbf"]; ok { - if notBefore, ok := v.(NumericDate); ok { - if e.Time.Add(leeway).Before(notBefore.Time()) { - return ErrNotValidYet - } - } - } - - if v, ok := m["exp"]; ok { - if exp, ok := v.(int64); ok { - if e.Time.Add(-leeway).Before(time.Unix(exp, 0)) { - return ErrExpired - } - } - } - - if v, ok := m["iat"]; ok { - if issuedAt, ok := v.(int64); ok { - if e.Time.Add(leeway).Before(time.Unix(issuedAt, 0)) { - return ErrIssuedInTheFuture - } - } - } - } - - return nil -} -*/ - // WriteToken is a helper which just generates(calls the `Token` method) and writes // a new token to the client in plain text format. // @@ -399,91 +382,122 @@ func (j *JWT) WriteToken(ctx *context.Context, claims interface{}) error { return err } -var ( - // ErrMissing when token cannot be extracted from the request. - ErrMissing = errors.New("token is missing") - // ErrExpired indicates that token is used after expiry time indicated in exp claim. - ErrExpired = errors.New("token is expired (exp)") - // ErrNotValidYet indicates that token is used before time indicated in nbf claim. - ErrNotValidYet = errors.New("token not valid yet (nbf)") - // ErrIssuedInTheFuture indicates that the iat field is in the future. - ErrIssuedInTheFuture = errors.New("token issued in the future (iat)") -) - -type ( - claimsValidator interface { - ValidateWithLeeway(e jwt.Expected, leeway time.Duration) error - } - claimsAlternativeValidator interface { // to keep iris-contrib/jwt MapClaims compatible. - Validate() error - } - claimsContextValidator interface { - Validate(ctx *context.Context) error - } -) - -// IsValidated reports whether a token is already validated through -// `VerifyToken`. It returns true when the claims are compatible -// validators: a `Claims` value or a value that implements the `Validate() error` method. -func IsValidated(ctx *context.Context) bool { // see the `ReadClaims`. - return ctx.Values().Get(needsValidationContextKey) == nil -} - -func validateClaims(ctx *context.Context, claims interface{}) (err error) { - switch c := claims.(type) { - case claimsValidator: - err = c.ValidateWithLeeway(jwt.Expected{Time: time.Now()}, 0) - case claimsAlternativeValidator: - err = c.Validate() - case claimsContextValidator: - err = c.Validate(ctx) - case *json.RawMessage: - // if the data type is raw message (json []byte) - // then it should contain exp (and iat and nbf) keys. - // Unmarshal raw message to validate against. - v := new(Claims) - err = json.Unmarshal(*c, v) - if err == nil { - return validateClaims(ctx, v) - } - default: - ctx.Values().Set(needsValidationContextKey, struct{}{}) - } - - if err != nil { - switch err { - case jwt.ErrExpired: - return ErrExpired - case jwt.ErrNotValidYet: - return ErrNotValidYet - case jwt.ErrIssuedInTheFuture: - return ErrIssuedInTheFuture - } - } - - return err -} - // VerifyToken verifies (and decrypts) the request token, // it also validates and binds the parsed token's claims to the "claimsPtr" (destination). -// It does return a nil error on success. -func (j *JWT) VerifyToken(ctx *context.Context, claimsPtr interface{}) error { - var token string +// +// The last, variadic, input argument is optionally, if provided then the +// parsed claims must match the expectations; +// e.g. Audience, Issuer, ID, Subject. +// See `ExpectXXX` package-functions for details. +func (j *JWT) VerifyToken(ctx *context.Context, claimsPtr interface{}, expectations ...Expectation) (*TokenInfo, error) { + token := j.RequestToken(ctx) + return j.VerifyTokenString(ctx, token, claimsPtr, expectations...) +} +// RequestToken extracts the token from the request. +func (j *JWT) RequestToken(ctx *context.Context) (token string) { for _, extract := range j.Extractors { if token = extract(ctx); token != "" { break // ok we found it. } } - return j.VerifyTokenString(ctx, token, claimsPtr) + return } -// VerifyTokenString verifies and unmarshals an extracted token to "claimsPtr" destination. -// The Context is required when the claims validator needs it, otherwise can be nil. -func (j *JWT) VerifyTokenString(ctx *context.Context, token string, claimsPtr interface{}) error { +// TokenSetter is an interface which if implemented +// the extracted, verified, token is stored to the object. +type TokenSetter interface { + SetToken(token string) +} + +// TokenInfo holds the standard token information may required +// for further actions. +// This structure is mostly useful when the developer's go structure +// does not hold the standard jwt fields (e.g. "exp") +// but want access to the parsed token which contains those fields. +// Inside the middleware, it is used to invalidate tokens through server-side, see `Invalidate`. +type TokenInfo struct { + RequestToken string // The request token. + Claims Claims // The standard JWT parsed fields from the request Token. + Value interface{} // The pointer to the end-developer's custom claims structure (see `Get`). +} + +const tokenInfoContextKey = "iris.jwt.token" + +// Get returns the verified developer token claims. +// +// +// Usage: +// j := jwt.New(...) +// app.Use(j.Verify(func() interface{} { return new(CustomClaims) })) +// app.Post("/restricted", func(ctx iris.Context){ +// claims := jwt.Get(ctx).(*CustomClaims) +// [use claims...] +// }) +// +// Note that there is one exception, if the value was a pointer +// to a map[string]interface{}, it returns the map itself so it can be +// accessible directly without the requirement of unwrapping it, e.g. +// j.Verify(func() interface{} { +// return &iris.Map{} +// } +// [...] +// claims := jwt.Get(ctx).(iris.Map) +func Get(ctx *context.Context) interface{} { + if tok := GetTokenInfo(ctx); tok != nil { + switch v := tok.Value.(type) { + case *context.Map: + return *v + default: + return v + } + } + + return nil +} + +// GetTokenInfo returns the verified token's information. +func GetTokenInfo(ctx *context.Context) *TokenInfo { + if v := ctx.Values().Get(tokenInfoContextKey); v != nil { + if t, ok := v.(*TokenInfo); ok { + return t + } + } + + return nil +} + +// Invalidate invalidates a verified JWT token. +// It adds the request token, retrieved by Verify methods, to the block list. +// Next request will be blocked, even if the token was not yet expired. +// This method can be used when the client-side does not clear the token +// on a user logout operation. +// +// Note: the Blocklist should be initialized before serve-time: j.UseBlocklist(). +func (j *JWT) Invalidate(ctx *context.Context) { + if j.Blocklist == nil { + ctx.Application().Logger().Debug("jwt.Invalidate: Blocklist is nil") + return + } + + tokenInfo := GetTokenInfo(ctx) + if tokenInfo == nil { + return + } + + j.Blocklist.Set(tokenInfo.RequestToken, tokenInfo.Claims.Expiry.Time()) +} + +// VerifyTokenString verifies and unmarshals an extracted request token to "dest" destination. +// The last variadic input indicates any further validations against the verified token claims. +// If the given "dest" is a valid context.User then ctx.User() will return it. +// If the token is missing an `ErrMissing` is returned. +// If the incoming token was expired an `ErrExpired` is returned. +// If the incoming token was blocked by the server an `ErrBlocked` is returned. +func (j *JWT) VerifyTokenString(ctx *context.Context, token string, dest interface{}, expectations ...Expectation) (*TokenInfo, error) { if token == "" { - return ErrMissing + return nil, ErrMissing } var ( @@ -494,7 +508,7 @@ func (j *JWT) VerifyTokenString(ctx *context.Context, token string, claimsPtr in if j.DecriptionKey != nil { t, cerr := jwt.ParseSignedAndEncrypted(token) if cerr != nil { - return cerr + return nil, cerr } parsedToken, err = t.Decrypt(j.DecriptionKey) @@ -502,112 +516,163 @@ func (j *JWT) VerifyTokenString(ctx *context.Context, token string, claimsPtr in parsedToken, err = jwt.ParseSigned(token) } if err != nil { - return err + return nil, err } - if err = parsedToken.Claims(j.VerificationKey, claimsPtr); err != nil { - return err + var claims Claims + if err = parsedToken.Claims(j.VerificationKey, dest, &claims); err != nil { + return nil, err } - return validateClaims(ctx, claimsPtr) -} - -const ( - // ClaimsContextKey is the context key which the jwt claims are stored from the `Verify` method. - ClaimsContextKey = "iris.jwt.claims" - needsValidationContextKey = "iris.jwt.claims.unvalidated" -) - -// Verify is a middleware. It verifies and optionally decrypts an incoming request token. -// It does write a 401 unauthorized status code if verification or decryption failed. -// It calls the `ctx.Next` on verified requests. -// -// See `VerifyToken` instead to verify, decrypt, validate and acquire the claims at once. -// -// A call of `ReadClaims` is required to validate and acquire the jwt claims -// on the next request. -func (j *JWT) Verify(ctx *context.Context) { - var raw json.RawMessage - if err := j.VerifyToken(ctx, &raw); err != nil { - ctx.StopWithStatus(401) - return - } - - ctx.Values().Set(ClaimsContextKey, raw) - ctx.Next() -} - -// ReadClaims binds the "claimsPtr" (destination) -// to the verified (and decrypted) claims. -// The `Verify` method should be called first (registered as middleware). -func ReadClaims(ctx *context.Context, claimsPtr interface{}) error { - v := ctx.Values().Get(ClaimsContextKey) - if v == nil { - return ErrMissing - } - - raw, ok := v.(json.RawMessage) - if !ok { - return ErrMissing - } - - err := json.Unmarshal(raw, claimsPtr) - if err != nil { - return err - } - - if !IsValidated(ctx) { - // If already validated on `Verify/VerifyToken` - // then no need to perform the check again. - ctx.Values().Remove(needsValidationContextKey) - return validateClaims(ctx, claimsPtr) - } - - return nil -} - -// Get returns and validates (if not already) the claims -// stored on request context's values storage. -// -// Should be used instead of the `ReadClaims` method when -// a custom verification middleware was registered (see the `Verify` method for an example). -// -// Usage: -// j := jwt.New(...) -// [...] -// app.Use(func(ctx iris.Context) { -// var claims CustomClaims_or_jwt.Claims -// if err := j.VerifyToken(ctx, &claims); err != nil { -// ctx.StopWithStatus(iris.StatusUnauthorized) -// return -// } -// -// ctx.Values().Set(jwt.ClaimsContextKey, claims) -// ctx.Next() -// }) -// [...] -// app.Post("/restricted", func(ctx iris.Context){ -// v, err := jwt.Get(ctx) -// [handle error...] -// claims,ok := v.(CustomClaims_or_jwt.Claims) -// if !ok { -// [do you support more than one type of claims? Handle here] -// } -// [use claims...] -// }) -func Get(ctx *context.Context) (interface{}, error) { - claims := ctx.Values().Get(ClaimsContextKey) - if claims == nil { - return nil, ErrMissing - } - - if !IsValidated(ctx) { - ctx.Values().Remove(needsValidationContextKey) - err := validateClaims(ctx, claims) - if err != nil { - return nil, err + // Build the Expected value. + expected := Expected{} + for _, e := range expectations { + if e != nil { + // expection can be used as a field validation too (see MeetRequirements). + if err = e(&expected, dest); err != nil { + return nil, err + } } } - return claims, nil + // For other standard JWT claims fields such as "exp" + // The developer can just add a field of Expiry *NumericDate `json:"exp"` + // and will be filled by the parsed token automatically. + // No need for more interfaces. + + err = validateClaims(ctx, dest, claims, expected) + if err != nil { + if err == ErrExpired { + // If token was expired remove it from the block list. + if j.Blocklist != nil { + j.Blocklist.Del(token) + } + } + + return nil, err + } + + if j.Blocklist != nil { + // If token exists in the block list, then stop here. + if j.Blocklist.Has(token) { + return nil, ErrBlocked + } + } + + if ut, ok := dest.(TokenSetter); ok { + // The u.Token is empty even if we set it and export it on JSON structure. + // Set it manually. + ut.SetToken(token) + } + + // Set the information. + tokenInfo := &TokenInfo{ + RequestToken: token, + Claims: claims, + Value: dest, + } + + return tokenInfo, nil +} + +// TokenPair holds the access token and refresh token response. +type TokenPair struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` +} + +// TokenPair generates a token pair of access and refresh tokens. +// The first two arguments required for the refresh token +// and the last one is the claims for the access token one. +func (j *JWT) TokenPair(refreshMaxAge time.Duration, refreshClaims interface{}, accessClaims interface{}) (TokenPair, error) { + accessToken, err := j.Token(accessClaims) + if err != nil { + return TokenPair{}, err + } + + refreshToken, err := j.token(refreshMaxAge, refreshClaims) + if err != nil { + return TokenPair{}, nil + } + + pair := TokenPair{ + AccessToken: accessToken, + RefreshToken: refreshToken, + } + + return pair, nil +} + +// Verify returns a middleware which +// decrypts an incoming request token to the result of the given "newPtr". +// It does write a 401 unauthorized status code if verification or decryption failed. +// It calls the `ctx.Next` on verified requests. +// +// Iit unmarshals the token to the specific type returned from the given "newPtr" function. +// It sets the Context User and User's Token too. So the next handler(s) +// of the same chain can access the User through a `Context.User()` call. +// +// Note unlike `VerifyToken`, this method automatically protects +// the claims with JSON required tags (see `MeetRequirements` Expection). +// +// On verified tokens: +// - The information can be retrieved through `Get` and `GetTokenInfo` functions. +// - User is set if the newPtr returns a valid Context User +// - The Context Logout method is set if Blocklist was initialized +// Any error is captured to the Context, +// which can be retrieved by a `ctx.GetErr()` call. +func (j *JWT) Verify(newPtr func() interface{}, expections ...Expectation) context.Handler { + expections = append(expections, MeetRequirements(newPtr())) + + return func(ctx *context.Context) { + ptr := newPtr() + + tokenInfo, err := j.VerifyToken(ctx, ptr, expections...) + if err != nil { + ctx.Application().Logger().Debugf("iris.jwt.Verify: %v", err) + ctx.StopWithError(401, context.PrivateError(err)) + return + } + + if u, ok := ptr.(context.User); ok { + ctx.SetUser(u) + } + + if j.Blocklist != nil { + ctx.SetLogoutFunc(j.Invalidate) + } + + ctx.Values().Set(tokenInfoContextKey, tokenInfo) + ctx.Next() + } +} + +// NewUser returns a new User based on the given "opts". +// The caller can modify the User until its `GetToken` is called. +func (j *JWT) NewUser(opts ...UserOption) *User { + u := &User{ + j: j, + SimpleUser: &context.SimpleUser{ + Authorization: "IRIS_JWT_USER", // Used to separate a refresh token with a user/access one too. + Features: []context.UserFeature{ + context.TokenFeature, + }, + }, + } + + for _, opt := range opts { + opt(u) + } + + return u +} + +// VerifyUser works like the `Verify` method but instead +// it unmarshals the token to the specific User type. +// It sets the Context User too. So the next handler(s) +// of the same chain can access the User through a `Context.User()` call. +func (j *JWT) VerifyUser() context.Handler { + return j.Verify(func() interface{} { + return new(User) + }) } diff --git a/middleware/jwt/jwt_test.go b/middleware/jwt/jwt_test.go index 0d30d6fd..f42e80c1 100644 --- a/middleware/jwt/jwt_test.go +++ b/middleware/jwt/jwt_test.go @@ -12,11 +12,15 @@ import ( ) type userClaims struct { - jwt.Claims - Username string + // Optionally: + Issuer string `json:"iss"` + Subject string `json:"sub"` + Audience jwt.Audience `json:"aud"` + // + Username string `json:"username"` } -const testMaxAge = 3 * time.Second +const testMaxAge = 7 * time.Second // Random RSA verification and encryption. func TestRSA(t *testing.T) { @@ -25,13 +29,13 @@ func TestRSA(t *testing.T) { os.Remove(jwt.DefaultSignFilename) os.Remove(jwt.DefaultEncFilename) }) - testWriteVerifyToken(t, j) + testWriteVerifyBlockToken(t, j) } // HMAC verification and encryption. func TestHMAC(t *testing.T) { j := jwt.HMAC(testMaxAge, "secret", "itsa16bytesecret") - testWriteVerifyToken(t, j) + testWriteVerifyBlockToken(t, j) } func TestNew_HMAC(t *testing.T) { @@ -44,7 +48,7 @@ func TestNew_HMAC(t *testing.T) { t.Fatal(err) } - testWriteVerifyToken(t, j) + testWriteVerifyBlockToken(t, j) } // HMAC verification only (unecrypted). @@ -53,54 +57,60 @@ func TestVerify(t *testing.T) { if err != nil { t.Fatal(err) } - testWriteVerifyToken(t, j) + testWriteVerifyBlockToken(t, j) } -func testWriteVerifyToken(t *testing.T, j *jwt.JWT) { +func testWriteVerifyBlockToken(t *testing.T, j *jwt.JWT) { t.Helper() + j.UseBlocklist() j.Extractors = append(j.Extractors, jwt.FromJSON("access_token")) - standardClaims := jwt.Claims{Issuer: "an-issuer", Audience: jwt.Audience{"an-audience"}} - expectedClaims := userClaims{ - Claims: j.Expiry(standardClaims), + + customClaims := &userClaims{ + Issuer: "an-issuer", + Audience: jwt.Audience{"an-audience"}, + Subject: "user", Username: "kataras", } app := iris.New() + app.OnErrorCode(iris.StatusUnauthorized, func(ctx iris.Context) { + if err := ctx.GetErr(); err != nil { + // Test accessing the private error and set this as the response body. + ctx.WriteString(err.Error()) + } else { // Else the default behavior + ctx.WriteString(iris.StatusText(iris.StatusUnauthorized)) + } + }) + app.Get("/auth", func(ctx iris.Context) { - j.WriteToken(ctx, expectedClaims) + j.WriteToken(ctx, customClaims) }) - app.Post("/restricted", func(ctx iris.Context) { + app.Post("/protected", func(ctx iris.Context) { var claims userClaims - if err := j.VerifyToken(ctx, &claims); err != nil { - ctx.StopWithStatus(iris.StatusUnauthorized) - return - } - - ctx.JSON(claims) - }) - - app.Post("/restricted_middleware_readclaims", j.Verify, func(ctx iris.Context) { - var claims userClaims - if err := jwt.ReadClaims(ctx, &claims); err != nil { - ctx.StopWithStatus(iris.StatusUnauthorized) - return - } - - ctx.JSON(claims) - }) - - app.Post("/restricted_middleware_get", j.Verify, func(ctx iris.Context) { - claims, err := jwt.Get(ctx) + _, err := j.VerifyToken(ctx, &claims) if err != nil { - ctx.StopWithStatus(iris.StatusUnauthorized) + // t.Logf("%s: %v", ctx.Path(), err) + ctx.StopWithError(iris.StatusUnauthorized, iris.PrivateError(err)) return } ctx.JSON(claims) }) + m := app.Party("/middleware") + m.Use(j.Verify(func() interface{} { + return new(userClaims) + })) + m.Post("/protected", func(ctx iris.Context) { + claims := jwt.Get(ctx) + ctx.JSON(claims) + }) + m.Post("/invalidate", func(ctx iris.Context) { + ctx.Logout() // OR j.Invalidate(ctx) + }) + e := httptest.New(t, app) // Get token. @@ -109,31 +119,186 @@ func testWriteVerifyToken(t *testing.T, j *jwt.JWT) { t.Fatalf("empty token") } - restrictedPaths := [...]string{"/restricted", "/restricted_middleware_readclaims", "/restricted_middleware_get"} + restrictedPaths := [...]string{"/protected", "/middleware/protected"} now := time.Now() for _, path := range restrictedPaths { // Authorization Header. e.POST(path).WithHeader("Authorization", "Bearer "+rawToken).Expect(). - Status(httptest.StatusOK).JSON().Equal(expectedClaims) + Status(httptest.StatusOK).JSON().Equal(customClaims) // URL Query. e.POST(path).WithQuery("token", rawToken).Expect(). - Status(httptest.StatusOK).JSON().Equal(expectedClaims) + Status(httptest.StatusOK).JSON().Equal(customClaims) // JSON Body. e.POST(path).WithJSON(iris.Map{"access_token": rawToken}).Expect(). - Status(httptest.StatusOK).JSON().Equal(expectedClaims) + Status(httptest.StatusOK).JSON().Equal(customClaims) // Missing "Bearer". e.POST(path).WithHeader("Authorization", rawToken).Expect(). - Status(httptest.StatusUnauthorized) + Status(httptest.StatusUnauthorized).Body().Equal("token is missing") } + + // Invalidate the token. + e.POST("/middleware/invalidate").WithQuery("token", rawToken).Expect(). + Status(httptest.StatusOK) + // Token is blocked by server. + e.POST("/middleware/protected").WithQuery("token", rawToken).Expect(). + Status(httptest.StatusUnauthorized).Body().Equal("token is blocked") + expireRemDur := testMaxAge - time.Since(now) // Expiration. time.Sleep(expireRemDur /* -end */) for _, path := range restrictedPaths { - e.POST(path).WithQuery("token", rawToken).Expect().Status(httptest.StatusUnauthorized) + e.POST(path).WithQuery("token", rawToken).Expect(). + Status(httptest.StatusUnauthorized).Body().Equal("token is expired (exp)") } } + +func TestVerifyMap(t *testing.T) { + j := jwt.HMAC(testMaxAge, "secret", "itsa16bytesecret") + expectedClaims := iris.Map{ + "iss": "tester", + "username": "makis", + "roles": []string{"admin"}, + } + + app := iris.New() + app.Get("/user/auth", func(ctx iris.Context) { + err := j.WriteToken(ctx, expectedClaims) + if err != nil { + ctx.StopWithError(iris.StatusUnauthorized, err) + return + } + + if expectedClaims["exp"] == nil || expectedClaims["iat"] == nil { + ctx.StopWithText(iris.StatusBadRequest, + "exp or/and iat is nil - this means that the expiry was not set") + return + } + }) + + userAPI := app.Party("/user") + userAPI.Post("/", func(ctx iris.Context) { + var claims iris.Map + if _, err := j.VerifyToken(ctx, &claims); err != nil { + ctx.StopWithError(iris.StatusUnauthorized, iris.PrivateError(err)) + return + } + + ctx.JSON(claims) + }) + + // Test map + Verify middleware. + userAPI.Post("/middleware", j.Verify(func() interface{} { + return &iris.Map{} // or &map[string]interface{}{} + }), func(ctx iris.Context) { + claims := jwt.Get(ctx) + ctx.JSON(claims) + }) + + e := httptest.New(t, app, httptest.LogLevel("error")) + token := e.GET("/user/auth").Expect().Status(httptest.StatusOK).Body().Raw() + if token == "" { + t.Fatalf("empty token") + } + + e.POST("/user").WithHeader("Authorization", "Bearer "+token).Expect(). + Status(httptest.StatusOK).JSON().Equal(expectedClaims) + + e.POST("/user/middleware").WithHeader("Authorization", "Bearer "+token).Expect(). + Status(httptest.StatusOK).JSON().Equal(expectedClaims) + + e.POST("/user").Expect().Status(httptest.StatusUnauthorized) +} + +type customClaims struct { + Username string `json:"username"` + Token string `json:"token"` +} + +func (c *customClaims) SetToken(tok string) { + c.Token = tok +} + +func TestVerifyStruct(t *testing.T) { + maxAge := testMaxAge / 2 + j := jwt.HMAC(maxAge, "secret", "itsa16bytesecret") + + app := iris.New() + app.Get("/user/auth", func(ctx iris.Context) { + err := j.WriteToken(ctx, customClaims{ + Username: "makis", + }) + if err != nil { + ctx.StopWithError(iris.StatusUnauthorized, err) + return + } + }) + + userAPI := app.Party("/user") + userAPI.Post("/", func(ctx iris.Context) { + var claims customClaims + if _, err := j.VerifyToken(ctx, &claims); err != nil { + ctx.StopWithError(iris.StatusUnauthorized, iris.PrivateError(err)) + return + } + + ctx.JSON(claims) + }) + + e := httptest.New(t, app) + token := e.GET("/user/auth").Expect().Status(httptest.StatusOK).Body().Raw() + if token == "" { + t.Fatalf("empty token") + } + e.POST("/user").WithHeader("Authorization", "Bearer "+token).Expect(). + Status(httptest.StatusOK).JSON().Object().ContainsMap(iris.Map{ + "username": "makis", + "token": token, // Test SetToken. + }) + + e.POST("/user").Expect().Status(httptest.StatusUnauthorized) + time.Sleep(maxAge) + e.POST("/user").WithHeader("Authorization", "Bearer "+token).Expect().Status(httptest.StatusUnauthorized) +} + +func TestVerifyUserAndExpected(t *testing.T) { // Tests the jwt.User struct + context validator + expected. + maxAge := testMaxAge / 2 + j := jwt.HMAC(maxAge, "secret", "itsa16bytesecret") + expectedUser := j.NewUser(jwt.Username("makis"), jwt.Roles("admin"), jwt.Fields(iris.Map{ + "custom": true, + })) // only for the sake of the test, we iniitalize it here. + expectedUser.Issuer = "tester" + + app := iris.New() + app.Get("/user/auth", func(ctx iris.Context) { + tok, err := expectedUser.GetToken() + if err != nil { + ctx.StopWithError(iris.StatusInternalServerError, err) + return + } + ctx.WriteString(tok) + }) + + userAPI := app.Party("/user") + userAPI.Use(jwt.WithExpected(jwt.Expected{Issuer: "tester"}, j.VerifyUser())) + userAPI.Post("/", func(ctx iris.Context) { + user := ctx.User() + ctx.JSON(user) + }) + + e := httptest.New(t, app) + token := e.GET("/user/auth").Expect().Status(httptest.StatusOK).Body().Raw() + if token == "" { + t.Fatalf("empty token") + } + + e.POST("/user").WithHeader("Authorization", "Bearer "+token).Expect(). + Status(httptest.StatusOK).JSON().Equal(expectedUser) + + // Test generic client message if we don't manage the private error by ourselves. + e.POST("/user").Expect().Status(httptest.StatusUnauthorized).Body().Equal("Unauthorized") +} diff --git a/middleware/jwt/user.go b/middleware/jwt/user.go new file mode 100644 index 00000000..3d8af27f --- /dev/null +++ b/middleware/jwt/user.go @@ -0,0 +1,187 @@ +package jwt + +import ( + "time" + + "github.com/kataras/iris/v12/context" +) + +// User a common User structure for JWT. +// However, we're not limited to that one; +// any Go structure can be generated as a JWT token. +// +// Look `NewUser` and `VerifyUser` JWT middleware's methods. +// Use its `GetToken` method to generate the token when +// the User structure is set. +type User struct { + Claims + // Note: we could use a map too as the Token is generated when GetToken is called. + *context.SimpleUser + + j *JWT +} + +var ( + _ context.FeaturedUser = (*User)(nil) + _ TokenSetter = (*User)(nil) + _ ContextValidator = (*User)(nil) +) + +// UserOption sets optional fields for a new User +// See `NewUser` instance function. +type UserOption func(*User) + +// Username sets the Username and the JWT Claim's Subject +// to the given "username". +func Username(username string) UserOption { + return func(u *User) { + u.Username = username + u.Claims.Subject = username + u.Features = append(u.Features, context.UsernameFeature) + } +} + +// Email sets the Email field for the User field. +func Email(email string) UserOption { + return func(u *User) { + u.Email = email + u.Features = append(u.Features, context.EmailFeature) + } +} + +// Roles upserts to the User's Roles field. +func Roles(roles ...string) UserOption { + return func(u *User) { + u.Roles = roles + u.Features = append(u.Features, context.RolesFeature) + } +} + +// MaxAge sets claims expiration and the AuthorizedAt User field. +func MaxAge(maxAge time.Duration) UserOption { + return func(u *User) { + now := time.Now() + u.Claims.Expiry = NewNumericDate(now.Add(maxAge)) + u.Claims.IssuedAt = NewNumericDate(now) + u.AuthorizedAt = now + + u.Features = append(u.Features, context.AuthorizedAtFeature) + } +} + +// Fields copies the "fields" to the user's Fields field. +// This can be used to set custom fields to the User instance. +func Fields(fields context.Map) UserOption { + return func(u *User) { + if len(fields) == 0 { + return + } + + if u.Fields == nil { + u.Fields = make(context.Map, len(fields)) + } + + for k, v := range fields { + u.Fields[k] = v + } + + u.Features = append(u.Features, context.FieldsFeature) + } +} + +// SetToken is called automaticaly on VerifyUser/VerifyObject. +// It sets the extracted from request, and verified from server raw token. +func (u *User) SetToken(token string) { + u.Token = token +} + +// GetToken overrides the SimpleUser's Token +// and returns the jwt generated token, among with +// a generator error, if any. +func (u *User) GetToken() (string, error) { + if u.Token != "" { + return u.Token, nil + } + + if u.j != nil { // it's always not nil. + if u.j.MaxAge > 0 { + // if the MaxAge option was not manually set, resolve it from the JWT instance. + MaxAge(u.j.MaxAge)(u) + } + + // we could generate a token here + // but let's do it on GetToken + // as the user fields may change + // by the caller manually until the token + // sent to the client. + tok, err := u.j.Token(u) + if err != nil { + return "", err + } + + u.Token = tok + } + + if u.Token == "" { + return "", ErrMissing + } + + return u.Token, nil +} + +// Validate validates the current user's claims against +// the request. It's called automatically by the JWT instance. +func (u *User) Validate(ctx *context.Context, claims Claims, e Expected) error { + err := u.Claims.ValidateWithLeeway(e, 0) + if err != nil { + return err + } + + if u.SimpleUser.Authorization != "IRIS_JWT_USER" { + return ErrInvalidKey + } + + // We could add specific User Expectations (new struct and accept an interface{}), + // but for the sake of code simplicity we don't, unless is requested, as the caller + // can validate specific fields by its own at the next step. + return nil +} + +// UnmarshalJSON implements the json unmarshaler interface. +func (u *User) UnmarshalJSON(data []byte) error { + err := Unmarshal(data, &u.Claims) + if err != nil { + return err + } + simpleUser := new(context.SimpleUser) + err = Unmarshal(data, simpleUser) + if err != nil { + return err + } + u.SimpleUser = simpleUser + return nil +} + +// MarshalJSON implements the json marshaler interface. +func (u *User) MarshalJSON() ([]byte, error) { + claimsB, err := Marshal(u.Claims) + if err != nil { + return nil, err + } + + userB, err := Marshal(u.SimpleUser) + if err != nil { + return nil, err + } + + if len(userB) == 0 { + return claimsB, nil + } + + claimsB = claimsB[0 : len(claimsB)-1] // remove last '}' + userB = userB[1:] // remove first '{' + + raw := append(claimsB, ',') + raw = append(raw, userB...) + return raw, nil +} diff --git a/middleware/jwt/validation.go b/middleware/jwt/validation.go new file mode 100644 index 00000000..9e8ffd62 --- /dev/null +++ b/middleware/jwt/validation.go @@ -0,0 +1,212 @@ +package jwt + +import ( + "bytes" + "errors" + "reflect" + "strings" + "time" + + "github.com/kataras/iris/v12/context" + + "github.com/square/go-jose/v3/json" + // Use this package instead of the standard encoding/json + // to marshal the NumericDate as expected by the implementation (see 'normalize`). + "github.com/square/go-jose/v3/jwt" +) + +const ( + claimsExpectedContextKey = "iris.jwt.claims.expected" + needsValidationContextKey = "iris.jwt.claims.unvalidated" +) + +var ( + // ErrMissing when token cannot be extracted from the request. + ErrMissing = errors.New("token is missing") + // ErrMissingKey when token does not contain a required JSON field. + ErrMissingKey = errors.New("token is missing a required field") + // ErrExpired indicates that token is used after expiry time indicated in exp claim. + ErrExpired = errors.New("token is expired (exp)") + // ErrNotValidYet indicates that token is used before time indicated in nbf claim. + ErrNotValidYet = errors.New("token not valid yet (nbf)") + // ErrIssuedInTheFuture indicates that the iat field is in the future. + ErrIssuedInTheFuture = errors.New("token issued in the future (iat)") + // ErrBlocked indicates that the token was not yet expired + // but was blocked by the server's Blocklist. + ErrBlocked = errors.New("token is blocked") +) + +// Expectation option to provide +// an extra layer of token validation, a claims type protection. +// See `VerifyToken` method. +type Expectation func(e *Expected, claims interface{}) error + +// Expect protects the claims with the expected values. +func Expect(expected Expected) Expectation { + return func(e *Expected, _ interface{}) error { + *e = expected + return nil + } +} + +// ExpectID protects the claims with an ID validation. +func ExpectID(id string) Expectation { + return func(e *Expected, _ interface{}) error { + e.ID = id + return nil + } +} + +// ExpectIssuer protects the claims with an issuer validation. +func ExpectIssuer(issuer string) Expectation { + return func(e *Expected, _ interface{}) error { + e.Issuer = issuer + return nil + } +} + +// ExpectSubject protects the claims with a subject validation. +func ExpectSubject(sub string) Expectation { + return func(e *Expected, _ interface{}) error { + e.Subject = sub + return nil + } +} + +// ExpectAudience protects the claims with an audience validation. +func ExpectAudience(audience ...string) Expectation { + return func(e *Expected, _ interface{}) error { + e.Audience = audience + return nil + } +} + +// MeetRequirements protects the custom fields of JWT claims +// based on the json:required tag; `json:"name,required"`. +// It accepts the value type. +// +// Usage: +// Verify/VerifyToken(... MeetRequirements(MyUser{})) +func MeetRequirements(claimsType interface{}) Expectation { + // pre-calculate if we need to use reflection at serve time to check for required fields, + // this can work as an alternative of expections for custom non-standard JWT fields. + requireFieldsIndexes := getRequiredFieldIndexes(claimsType) + + return func(e *Expected, claims interface{}) error { + if len(requireFieldsIndexes) > 0 { + val := reflect.Indirect(reflect.ValueOf(claims)) + for _, idx := range requireFieldsIndexes { + field := val.Field(idx) + if field.IsZero() { + return ErrMissingKey + } + } + } + + return nil + } +} + +// WithExpected is a middleware wrapper. It wraps a VerifyXXX middleware +// with expected claims fields protection. +// Usage: +// jwt.WithExpected(jwt.Expected{Issuer:"app"}, j.VerifyUser) +func WithExpected(e Expected, verifyHandler context.Handler) context.Handler { + return func(ctx *context.Context) { + ctx.Values().Set(claimsExpectedContextKey, e) + verifyHandler(ctx) + } +} + +// ContextValidator validates the object based on the given +// claims and the expected once. The end-developer +// can use this method for advanced validations based on the request Context. +type ContextValidator interface { + Validate(ctx *context.Context, claims Claims, e Expected) error +} + +func validateClaims(ctx *context.Context, dest interface{}, claims Claims, expected Expected) (err error) { + // Get any dynamic expectation set by prior middleware. + // See `WithExpected` middleware. + if v := ctx.Values().Get(claimsExpectedContextKey); v != nil { + if e, ok := v.(Expected); ok { + expected = e + } + } + // Force-set the time, it's important for expiration. + expected.Time = time.Now() + switch c := dest.(type) { + case Claims: + err = c.ValidateWithLeeway(expected, 0) + case ContextValidator: + err = c.Validate(ctx, claims, expected) + case *context.Map: + // if the dest is a map then set automatically the expiration settings here, + // so the caller can work further with it. + err = claims.ValidateWithLeeway(expected, 0) + if err == nil { + (*c)["exp"] = claims.Expiry + (*c)["iat"] = claims.IssuedAt + if claims.NotBefore != nil { + (*c)["nbf"] = claims.NotBefore + } + } + default: + err = claims.ValidateWithLeeway(expected, 0) + } + + if err != nil { + switch err { + case jwt.ErrExpired: + return ErrExpired + case jwt.ErrNotValidYet: + return ErrNotValidYet + case jwt.ErrIssuedInTheFuture: + return ErrIssuedInTheFuture + } + } + + return err +} + +func normalize(i interface{}) (context.Map, error) { + if m, ok := i.(context.Map); ok { + return m, nil + } + + m := make(context.Map) + + raw, err := json.Marshal(i) + if err != nil { + return nil, err + } + + d := json.NewDecoder(bytes.NewReader(raw)) + d.UseNumber() + + if err := d.Decode(&m); err != nil { + return nil, err + } + + return m, nil +} + +func getRequiredFieldIndexes(i interface{}) (v []int) { + val := reflect.Indirect(reflect.ValueOf(i)) + typ := val.Type() + if typ.Kind() != reflect.Struct { + return nil + } + + for i := 0; i < val.NumField(); i++ { + field := typ.Field(i) + // Note: for the sake of simplicity we don't lookup for nested objects (FieldByIndex), + // we could do that as we do in dependency injection feature but unless requirested we don't. + tag := field.Tag.Get("json") + if strings.Contains(tag, ",required") { + v = append(v, i) + } + } + + return +}