diff --git a/_examples/auth/basicauth/main.go b/_examples/auth/basicauth/main.go index 39132d6e..4f06383e 100644 --- a/_examples/auth/basicauth/main.go +++ b/_examples/auth/basicauth/main.go @@ -1,8 +1,6 @@ package main import ( - "time" - "github.com/kataras/iris/v12" "github.com/kataras/iris/v12/middleware/basicauth" ) @@ -10,25 +8,40 @@ import ( func newApp() *iris.Application { app := iris.New() - authConfig := basicauth.Config{ - Users: map[string]string{"myusername": "mypassword", "mySecondusername": "mySecondpassword"}, - Realm: "Authorization Required", // defaults to "Authorization Required" - Expires: time.Duration(30) * time.Minute, - } + /* + opts := basicauth.Options{ + Realm: "Authorization Required", + MaxAge: 30 * time.Minute, + GC: basicauth.GC{ + Every: 2 * time.Hour, + }, + Allow: basicauth.AllowUsers(map[string]string{ + "myusername": "mypassword", + "mySecondusername": "mySecondpassword", + }), + MaxTries: 2, + } + auth := basicauth.New(opts) - authentication := basicauth.New(authConfig) + OR simply: + */ - // to global app.Use(authentication) (or app.UseGlobal before the .Run) + auth := basicauth.Default(map[string]string{ + "myusername": "mypassword", + "mySecondusername": "mySecondpassword", + }) + + // to global app.Use(auth) (or app.UseGlobal before the .Run) // to routes /* - app.Get("/mysecret", authentication, h) + app.Get("/mysecret", auth, h) */ app.Get("/", func(ctx iris.Context) { ctx.Redirect("/admin") }) // to party - needAuth := app.Party("/admin", authentication) + needAuth := app.Party("/admin", auth) { //http://localhost:8080/admin needAuth.Get("/", h) diff --git a/_examples/auth/basicauth/main_test.go b/_examples/auth/basicauth/main_test.go index e3f27486..7a2142e2 100644 --- a/_examples/auth/basicauth/main_test.go +++ b/_examples/auth/basicauth/main_test.go @@ -25,5 +25,5 @@ func TestBasicAuth(t *testing.T) { // with invalid basic auth e.GET("/admin/settings").WithBasicAuth("invalidusername", "invalidpassword"). - Expect().Status(httptest.StatusUnauthorized) + Expect().Status(httptest.StatusForbidden) } diff --git a/_examples/auth/jwt/tutorial/go.mod b/_examples/auth/jwt/tutorial/go.mod index d35aad38..0d8f1db0 100644 --- a/_examples/auth/jwt/tutorial/go.mod +++ b/_examples/auth/jwt/tutorial/go.mod @@ -4,7 +4,7 @@ go 1.15 require ( github.com/google/uuid v1.1.2 - github.com/kataras/iris/v12 v12.2.0-alpha.0.20201106220849-7a19cfb2112f + github.com/kataras/iris/v12 v12.2.0-alpha.0.20201113181155-4d09475c290d golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897 ) diff --git a/_examples/routing/rewrite/main.go b/_examples/routing/rewrite/main.go index 9050e0d3..d6d48ee3 100644 --- a/_examples/routing/rewrite/main.go +++ b/_examples/routing/rewrite/main.go @@ -17,22 +17,21 @@ func main() { newtest := app.Subdomain("newtest") newtest.Get("/", newTestIndex) - newtest.Get("/", newTestAbout) + newtest.Get("/about", newTestAbout) redirects := rewrite.Load("redirects.yml") app.WrapRouter(redirects) - // http://mydomain.com:8080/seo/about -> http://www.mydomain.com:8080/about - // http://test.mydomain.com:8080 -> http://newtest.mydomain.com:8080 - // http://test.mydomain.com:8080/seo/about -> http://newtest.mydomain.com:8080/about - // http://localhost:8080/seo -> http://localhost:8080 - // http://localhost:8080/about - // http://localhost:8080/docs/v12/hello -> http://localhost:8080/docs - // http://localhost:8080/docs/v12some -> http://localhost:8080/docs - // http://localhost:8080/oldsome -> http://localhost:8080 - // http://localhost:8080/oldindex/random -> http://localhost:8080 - // http://localhost:8080/users.json -> http://localhost:8080/users.json - // ^ (but with an internal ?format=json, client can't see it) + // http://mydomain.com:8080/seo/about -> http://www.mydomain.com:8080/about + // http://test.mydomain.com:8080 -> http://newtest.mydomain.com:8080 + // http://test.mydomain.com:8080/seo/about -> http://newtest.mydomain.com:8080/about + // http://mydomain.com:8080/seo -> http://www.mydomain.com:8080 + // http://mydomain.com:8080/about + // http://mydomain.com:8080/docs/v12/hello -> http://www.mydomain.com:8080/docs + // http://mydomain.com:8080/docs/v12some -> http://www.mydomain.com:8080/docs + // http://mydomain.com:8080/oldsome -> http://www.mydomain.com:8080 + // http://mydomain.com:8080/oldindex/random -> http://www.mydomain.com:8080 + // http://mydomain.com:8080/users.json -> http://www.mydomain.com:8080/users?format=json app.Listen(":8080") } diff --git a/_examples/routing/rewrite/redirects.yml b/_examples/routing/rewrite/redirects.yml index 546baccf..389bb522 100644 --- a/_examples/routing/rewrite/redirects.yml +++ b/_examples/routing/rewrite/redirects.yml @@ -21,4 +21,4 @@ RedirectMatch: # REDIRECT_CODE_DIGITS | PATTERN_REGEX | TARGET_REPL # Redirects root domain to www. # Creation of a www subdomain inside the Application is unnecessary, # all requests are handled by the root Application itself. -PrimarySubdomain: www \ No newline at end of file +PrimarySubdomain: www diff --git a/context/context.go b/context/context.go index 71498010..4446f722 100644 --- a/context/context.go +++ b/context/context.go @@ -4726,7 +4726,7 @@ func (ctx *Context) UpsertCookie(cookie *http.Cookie, options ...CookieOption) b // you can change it or simple, use the SetCookie for more control. // // See `CookieExpires` and `AddCookieOptions` for more. -var SetCookieKVExpiration = time.Duration(8760) * time.Hour +var SetCookieKVExpiration = 8760 * time.Hour // SetCookieKV adds a cookie, requires the name(string) and the value(string). // @@ -5343,7 +5343,15 @@ const userContextKey = "iris.user" // SetUser sets a value as a User for this request. // It's used by auth middlewares as a common // method to provide user information to the -// next handlers in the chain +// next handlers in the chain. +// +// The "i" input argument can be: +// - A value which completes the User interface +// - A map[string]interface{}. +// - A value which does not complete the whole User interface +// - A value which does not complete the User interface at all +// (only its `User().GetRaw` method is available). +// // Look the `User` method to retrieve it. func (ctx *Context) SetUser(i interface{}) error { if i == nil { @@ -5371,6 +5379,9 @@ func (ctx *Context) SetUser(i interface{}) error { } // User returns the registered User of this request. +// To get the original value (even if a value set by SetUser does not implement the User interface) +// use its GetRaw method. +// / // See `SetUser` too. func (ctx *Context) User() User { if v := ctx.values.Get(userContextKey); v != nil { diff --git a/context/context_user.go b/context/context_user.go index 893b86d6..009b7e63 100644 --- a/context/context_user.go +++ b/context/context_user.go @@ -33,6 +33,8 @@ var ErrNotSupported = errors.New("not supported") // - UserMap (a wrapper by SetUser) // - UserPartial (a wrapper by SetUser) type User interface { + // GetRaw should return the raw instance of the user, if supported. + GetRaw() (interface{}, error) // GetAuthorization should return the authorization method, // e.g. Basic Authentication. GetAuthorization() (string, error) @@ -92,6 +94,11 @@ type SimpleUser struct { var _ User = (*SimpleUser)(nil) +// GetRaw returns itself. +func (u *SimpleUser) GetRaw() (interface{}, error) { + return u, nil +} + // GetAuthorization returns the authorization method, // e.g. Basic Authentication. func (u *SimpleUser) GetAuthorization() (string, error) { @@ -179,6 +186,11 @@ type UserMap Map var _ User = UserMap{} +// GetRaw returns the underline map. +func (u UserMap) GetRaw() (interface{}, error) { + return Map(u), nil +} + // GetAuthorization returns the authorization or Authorization value of the map. func (u UserMap) GetAuthorization() (string, error) { return u.str("authorization") @@ -292,11 +304,17 @@ type ( GetID() string } - userGetUsername interface { + // UserGetUsername interface which + // requires a single method to complete + // a User on Context.SetUser. + UserGetUsername interface { GetUsername() string } - userGetPassword interface { + // UserGetPassword interface which + // requires a single method to complete + // a User on Context.SetUser. + UserGetPassword interface { GetPassword() string } @@ -319,13 +337,14 @@ type ( // UserPartial is a User. // It's a helper which wraps a struct value that // may or may not complete the whole User interface. + // See Context.SetUser. UserPartial struct { Raw interface{} userGetAuthorization userGetAuthorizedAt userGetID - userGetUsername - userGetPassword + UserGetUsername + UserGetPassword userGetEmail userGetRoles userGetToken @@ -336,61 +355,64 @@ type ( var _ User = (*UserPartial)(nil) func newUserPartial(i interface{}) *UserPartial { - containsAtLeastOneMethod := false + if i == nil { + return nil + } + p := &UserPartial{Raw: i} if u, ok := i.(userGetAuthorization); ok { p.userGetAuthorization = u - containsAtLeastOneMethod = true } if u, ok := i.(userGetAuthorizedAt); ok { p.userGetAuthorizedAt = u - containsAtLeastOneMethod = true } if u, ok := i.(userGetID); ok { p.userGetID = u - containsAtLeastOneMethod = true } - if u, ok := i.(userGetUsername); ok { - p.userGetUsername = u - containsAtLeastOneMethod = true + if u, ok := i.(UserGetUsername); ok { + p.UserGetUsername = u } - if u, ok := i.(userGetPassword); ok { - p.userGetPassword = u - containsAtLeastOneMethod = true + if u, ok := i.(UserGetPassword); ok { + p.UserGetPassword = u } if u, ok := i.(userGetEmail); ok { p.userGetEmail = u - containsAtLeastOneMethod = true } if u, ok := i.(userGetRoles); ok { p.userGetRoles = u - containsAtLeastOneMethod = true } if u, ok := i.(userGetToken); ok { p.userGetToken = u - containsAtLeastOneMethod = true } if u, ok := i.(userGetField); ok { p.userGetField = u - containsAtLeastOneMethod = true } - if !containsAtLeastOneMethod { - return nil - } + // if !containsAtLeastOneMethod { + // return nil + // } return p } +// GetRaw returns the original raw instance of the user. +func (u *UserPartial) GetRaw() (interface{}, error) { + if u == nil { + return nil, ErrNotSupported + } + + return u.Raw, nil +} + // GetAuthorization should return the authorization method, // e.g. Basic Authentication. func (u *UserPartial) GetAuthorization() (string, error) { @@ -422,7 +444,7 @@ func (u *UserPartial) GetID() (string, error) { // GetUsername should return the name of the User. func (u *UserPartial) GetUsername() (string, error) { - if v := u.userGetUsername; v != nil { + if v := u.UserGetUsername; v != nil { return v.GetUsername(), nil } @@ -432,7 +454,7 @@ func (u *UserPartial) GetUsername() (string, error) { // GetPassword should return the encoded or raw password // (depends on the implementation) of the User. func (u *UserPartial) GetPassword() (string, error) { - if v := u.userGetPassword; v != nil { + if v := u.UserGetPassword; v != nil { return v.GetPassword(), nil } diff --git a/middleware/basicauth/basicauth.go b/middleware/basicauth/basicauth.go index a1c83b5c..4aae6d14 100644 --- a/middleware/basicauth/basicauth.go +++ b/middleware/basicauth/basicauth.go @@ -1,205 +1,402 @@ -// Package basicauth provides http basic authentication via middleware. See _examples/auth/basicauth package basicauth -/* -Test files: - - ../../_examples/auth/basicauth/main_test.go - - ./basicauth_test.go -*/ - import ( - "encoding/base64" + stdContext "context" "strconv" "sync" "time" "github.com/kataras/iris/v12/context" + "github.com/kataras/iris/v12/sessions" ) func init() { context.SetHandlerName("iris/middleware/basicauth.*", "iris.basicauth") } -const authorizationType = "Basic Authentication" - -type ( - encodedUser struct { - HeaderValue string - Username string - Password string - logged bool - forceLogout bool // in order to be able to invalidate and use a redirect response. - authorizedAt time.Time // when from !logged to logged. - expires time.Time - mu sync.RWMutex - } - - basicAuthMiddleware struct { - config *Config - // these are filled from the config.Users map at the startup - auth []*encodedUser - realmHeaderValue string - - // The below can be removed but they are here because on the future we may add dynamic options for those two fields, - // it is a bit faster to check the b.$bool as well. - expireEnabled bool // if the config.Expires is a valid date, default is disabled. - askHandlerEnabled bool // if the config.OnAsk is not nil, defaults to false. - } +const ( + DefaultRealm = "Authorization Required" + DefaultMaxTriesCookie = "basicmaxtries" ) -// +const ( + authorizationType = "Basic Authentication" + authenticateHeaderKey = "WWW-Authenticate" + proxyAuthenticateHeaderKey = "Proxy-Authenticate" + authorizationHeaderKey = "Authorization" + proxyAuthorizationHeaderKey = "Proxy-Authorization" +) -// New accepts basicauth.Config and returns a new Handler -// which will ask the client for basic auth (username, password), -// validate that and if valid continues to the next handler, otherwise -// throws a StatusUnauthorized http error code. -// -// Use the `Context.User` method to retrieve the stored user. -func New(c Config) context.Handler { - config := DefaultConfig() - if c.Realm != "" { - config.Realm = c.Realm - } - config.Users = c.Users - config.Expires = c.Expires - config.OnAsk = c.OnAsk +type AuthFunc func(ctx *context.Context, username, password string) (interface{}, bool) - b := &basicAuthMiddleware{config: &config} - b.init() - return b.Serve +type Options struct { + // Realm http://tools.ietf.org/html/rfc2617#section-1.2. + // E.g. "Authorization Required". + Realm string + // In the case of proxies, the challenging status code is 407 (Proxy Authentication Required), + // the Proxy-Authenticate response header contains at least one challenge applicable to the proxy, + // and the Proxy-Authorization request header is used for providing the credentials to the proxy server. + // + // Proxy should be used to gain access to a resource behind a proxy server. + // It authenticates the request to the proxy server, allowing it to transmit the request further. + Proxy bool + // Usage: + // - Allow: AllowUsers(iris.Map{"username": "...", "password": "...", "other_field": ...}, [BCRYPT]) + // - Allow: AllowUsersFile("users.yml", [BCRYPT]) + Allow AuthFunc + // If greater than zero then the server will send 403 forbidden status code afer MaxTries + // of invalid credentials of a specific client consumed (session or cookie based, see MaxTriesCookie). + // By default the server will re-ask for credentials on any amount of invalid credentials. + MaxTries int + // If a session manager is register under the current request, + // then this value should be the key of the session storage which + // the current tries will be stored. Otherwise + // it is the raw cookie name. + // The cookie is stored up to the configured MaxAge if greater than zero or for 1 year, + // so a forbidden client can request for authentication again after the MaxAge expired. + // + // Note that, the session way is recommended as the current tries + // cannot be modified by the client (unless the client removes the session cookie). + // However the raw cookie performs faster. You can always set custom logic + // on the Allow field as you have access to the current request Context. + // To set custom cookie options use the `Context.AddCookieOptions(options ...iris.CookieOption)` + // before the basic auth middleware. + // + // If MaxTries > 0 then it defaults to "basicmaxtries". + // The MaxTries should be set to greater than zero. + MaxTriesCookie string + // If not nil runs after 401 (or 407 if proxy is enabled) status code. + // Can be used to set custom response for unauthenticated clients. + OnAsk context.Handler + // If not nil runs after the 403 forbidden status code (when Allow returned false and MaxTries consumed). + // Can be used to set custom response when client tried to access a resource with invalid credentials. + OnForbidden context.Handler + // MaxAge sets expiration duration for the in-memory credentials map. + // By default an old map entry will be removed when the user visits a page. + // In order to remove old entries automatically please take a look at the `GC` option too. + // + // Usage: + // MaxAge: 30*time.Minute + MaxAge time.Duration + // GC automatically clears old entries every x duration. + // Note that, by old entries we mean expired credentials therefore + // the `MaxAge` option should be already set, + // if it's not then all entries will be removed on "every" duration. + // The standard context can be used for the internal ticker cancelation, it can be nil. + // + // Usage: + // GC: basicauth.GC{Every: 2*time.Hour} + GC GC } -// Default accepts only the users and returns a new Handler -// which will ask the client for basic auth (username, password), -// validate that and if valid continues to the next handler, otherwise -// throws a StatusUnauthorized http error code. -func Default(users map[string]string) context.Handler { - c := DefaultConfig() - c.Users = users - return New(c) +type GC struct { + Context stdContext.Context + Every time.Duration } -func (b *basicAuthMiddleware) init() { - // pass the encoded users from the user's config's Users value - b.auth = make([]*encodedUser, 0, len(b.config.Users)) +// https://tools.ietf.org/html/rfc2617 +// https://developer.mozilla.org/en-US/docs/Web/HTTP/Authentication +// +// As the user ID and password are passed over the network as clear text +// (it is base64 encoded, but base64 is a reversible encoding), the basic authentication scheme is not secure. +// HTTPS/TLS should be used with basic authentication. Without these additional security enhancements, +// basic authentication should not be used to protect sensitive or valuable information. +type BasicAuth struct { + opts Options + // built based on proxy field + askCode int + authorizationHeader string + authenticateHeader string + // built based on realm field. + authenticateHeaderValue string - for k, v := range b.config.Users { - fullUser := k + ":" + v - header := "Basic " + base64.StdEncoding.EncodeToString([]byte(fullUser)) - b.auth = append(b.auth, &encodedUser{ - HeaderValue: header, - Username: k, - Password: v, - logged: false, - expires: DefaultExpireTime, - }) + credentials map[string]*time.Time // key = username:password, value = expiration time (if MaxAge > 0). + mu sync.RWMutex // protects the credentials as they can modified. +} + +func New(opts Options) context.Handler { + var ( + askCode = 401 + authorizationHeader = authorizationHeaderKey + authenticateHeader = authenticateHeaderKey + authenticateHeaderValue = "Basic" + ) + + if opts.Allow == nil { + panic("BasicAuth: Allow field is required") } - // set the auth realm header's value - b.realmHeaderValue = "Basic realm=" + strconv.Quote(b.config.Realm) + if opts.Realm != "" { + authenticateHeaderValue += " realm=" + strconv.Quote(opts.Realm) + } - b.expireEnabled = b.config.Expires > 0 - b.askHandlerEnabled = b.config.OnAsk != nil + if opts.Proxy { + askCode = 407 + authenticateHeader = proxyAuthenticateHeaderKey + authorizationHeader = proxyAuthorizationHeaderKey + } + + if opts.MaxTries > 0 && opts.MaxTriesCookie == "" { + opts.MaxTriesCookie = DefaultMaxTriesCookie + } + + b := &BasicAuth{ + opts: opts, + askCode: askCode, + authorizationHeader: authorizationHeader, + authenticateHeader: authenticateHeader, + authenticateHeaderValue: authenticateHeaderValue, + credentials: make(map[string]*time.Time), + } + + if opts.GC.Every > 0 { + go b.runGC(opts.GC.Context, opts.GC.Every) + } + + return b.serveHTTP } -func (b *basicAuthMiddleware) findAuth(headerValue string) (*encodedUser, bool) { - if headerValue != "" { - for _, user := range b.auth { - if user.HeaderValue == headerValue { - return user, true - } +// - map[string]string form of: {username:password, ...} form. +// - map[string]interface{} form of: []{"username": "...", "password": "...", "other_field": ...}, ...}. +// - []T which T completes the User interface. +// - []T which T contains at least Username and Password fields. +func Default(users interface{}, userOpts ...UserAuthOption) context.Handler { + opts := Options{ + Realm: DefaultRealm, + Allow: AllowUsers(users, userOpts...), + } + return New(opts) +} + +func Load(jsonOrYamlFilename string, userOpts ...UserAuthOption) context.Handler { + opts := Options{ + Realm: DefaultRealm, + Allow: AllowUsersFile(jsonOrYamlFilename, userOpts...), + } + return New(opts) +} + +// askForCredentials sends a response to the client which client should catch +// and ask for username:password credentials. +func (b *BasicAuth) askForCredentials(ctx *context.Context) { + ctx.Header(b.authenticateHeader, b.authenticateHeaderValue) + ctx.StopWithStatus(b.askCode) + + if h := b.opts.OnAsk; h != nil { + h(ctx) + } +} + +// If a (proxy) server receives valid credentials that are inadequate to access a given resource, +// the server should respond with the 403 Forbidden status code. +// Unlike 401 Unauthorized or 407 Proxy Authentication Required, authentication is impossible for this user. +func (b *BasicAuth) forbidden(ctx *context.Context) { + ctx.StopWithStatus(403) + + if h := b.opts.OnForbidden; h != nil { + h(ctx) + } +} + +func (b *BasicAuth) getCurrentTries(ctx *context.Context) (tries int) { + sess := sessions.Get(ctx) + if sess != nil { + tries = sess.GetIntDefault(b.opts.MaxTriesCookie, 0) + } else { + if v := ctx.GetCookie(b.opts.MaxTriesCookie); v != "" { + tries, _ = strconv.Atoi(v) } } - return nil, false + return } -func (b *basicAuthMiddleware) askForCredentials(ctx *context.Context) { - ctx.Header("WWW-Authenticate", b.realmHeaderValue) - ctx.StatusCode(401) - if b.askHandlerEnabled { - b.config.OnAsk(ctx) +func (b *BasicAuth) setCurrentTries(ctx *context.Context, tries int) { + sess := sessions.Get(ctx) + if sess != nil { + sess.Set(b.opts.MaxTriesCookie, tries) + } else { + maxAge := b.opts.MaxAge + if maxAge == 0 { + maxAge = context.SetCookieKVExpiration // 1 year. + } + ctx.SetCookieKV(b.opts.MaxTriesCookie, strconv.Itoa(tries), context.CookieExpires(maxAge)) } } -// Serve the actual basic authentication middleware. -// Use the Context.User method to retrieve the stored user. -func (b *basicAuthMiddleware) Serve(ctx *context.Context) { - auth, found := b.findAuth(ctx.GetHeader("Authorization")) - if !found || auth.forceLogout { - if auth != nil { - auth.mu.Lock() - auth.forceLogout = false - auth.mu.Unlock() +func (b *BasicAuth) resetCurrentTries(ctx *context.Context) { + sess := sessions.Get(ctx) + if sess != nil { + sess.Delete(b.opts.MaxTriesCookie) + } else { + ctx.RemoveCookie(b.opts.MaxTriesCookie) + } +} + +// serveHTTP is the main method of this middleware, +// checks and verifies the auhorization header for basic authentication, +// next handlers will only be executed when the client is allowed to continue. +func (b *BasicAuth) serveHTTP(ctx *context.Context) { + header := ctx.GetHeader(b.authorizationHeader) + fullUser, username, password, ok := decodeHeader(header) + if !ok { // Header is malformed or missing. + b.askForCredentials(ctx) + return + } + + var ( + maxTries = b.opts.MaxTries + tries int + ) + + if maxTries > 0 { + tries = b.getCurrentTries(ctx) + } + + user, ok := b.opts.Allow(ctx, username, password) + if !ok { // This username:password combination was not allowed. + if maxTries > 0 { + tries++ + b.setCurrentTries(ctx, tries) + if tries >= maxTries { // e.g. if MaxTries == 1 then it should be allowed only once, so we must send forbidden now. + b.forbidden(ctx) // a user was forbidden, to reset its status should clear the Authorization header and cookie and request the resource again. + return + } } b.askForCredentials(ctx) - ctx.StopExecution() return - // don't continue to the next handler } - auth.mu.RLock() - logged := auth.logged - auth.mu.RUnlock() - if !logged { - auth.mu.Lock() - auth.authorizedAt = time.Now() - auth.mu.Unlock() + if tries > 0 { + // had failures but it's ok, reset the tries on success. + b.resetCurrentTries(ctx) } - // all ok - if b.expireEnabled { - if !logged { - auth.mu.Lock() - auth.expires = auth.authorizedAt.Add(b.config.Expires) - auth.logged = true - auth.mu.Unlock() + b.mu.RLock() + expiresAt, ok := b.credentials[fullUser] + b.mu.RUnlock() + var authorizedAt time.Time + if ok { + if expiresAt != nil { // Has expiration. + if expiresAt.Before(time.Now()) { // Has been expired. + b.mu.Lock() // Delete the entry. + delete(b.credentials, fullUser) + b.mu.Unlock() + // Re-ask for new credentials. + b.askForCredentials(ctx) + return + } + + // It's ok, find the time authorized to fill the user below, if necessary. + authorizedAt = expiresAt.Add(-b.opts.MaxAge) } - - auth.mu.RLock() - expired := time.Now().After(auth.expires) - auth.mu.RUnlock() - if expired { - auth.mu.Lock() - auth.logged = false - auth.forceLogout = false - auth.mu.Unlock() - b.askForCredentials(ctx) // ask for authentication again - ctx.StopExecution() - return + } else { + // Saved credential not found, first login. + if b.opts.MaxAge > 0 { // Expiration is enabled, set the value. + authorizedAt = time.Now() + t := authorizedAt.Add(b.opts.MaxAge) + expiresAt = &t } + b.mu.Lock() + b.credentials[fullUser] = expiresAt + b.mu.Unlock() } - if !b.config.DisableContextUser { - ctx.SetLogoutFunc(b.Logout) - - auth.mu.RLock() - user := &context.SimpleUser{ + if user == nil { + // No custom uset was set by the auth func, + // it is passed though, set a simple user here: + user = &context.SimpleUser{ Authorization: authorizationType, - AuthorizedAt: auth.authorizedAt, - Username: auth.Username, - Password: auth.Password, + AuthorizedAt: authorizedAt, + Username: username, + Password: password, } - auth.mu.RUnlock() - ctx.SetUser(user) } - ctx.Next() // continue + ctx.SetUser(user) + ctx.SetLogoutFunc(b.logout) + + ctx.Next() } -// Logout sends a 401 so the browser/client can invalidate the -// Basic Authentication and also sets the underline user's logged field to false, -// so its expiration resets when re-ask for credentials. -// -// End-developers should call the `Context.Logout()` method -// to fire this method as this structure is hidden. -func (b *basicAuthMiddleware) Logout(ctx *context.Context) { - ctx.StatusCode(401) - if auth, found := b.findAuth(ctx.GetHeader("Authorization")); found { - auth.mu.Lock() - auth.logged = false - auth.forceLogout = true - auth.mu.Unlock() +// logout clears the current user's credentials. +func (b *BasicAuth) logout(ctx *context.Context) { + var ( + fullUser, username, password string + ok bool + ) + + if u := ctx.User(); u != nil { // Get the saved ones, if any. + username, _ = u.GetUsername() + password, _ = u.GetPassword() + fullUser = username + colonLiteral + password + ok = username != "" && password != "" + } + + if !ok { + // If the custom user does + // not implement those two, then extract from the request header: + header := ctx.GetHeader(b.authorizationHeader) + fullUser, username, password, ok = decodeHeader(header) + } + + if ok { // If it's authorized then try to lock and delete. + if b.opts.Proxy { + ctx.Request().Header.Del(proxyAuthorizationHeaderKey) + } + // delete the request header so future Request().BasicAuth are empty. + ctx.Request().Header.Del(authorizationHeaderKey) + + b.mu.Lock() + delete(b.credentials, fullUser) + b.mu.Unlock() } } + +// runGC runs a function in a separate go routine +// every x duration to clear in-memory expired credential entries. +func (b *BasicAuth) runGC(ctx stdContext.Context, every time.Duration) { + if ctx == nil { + ctx = stdContext.Background() + } + + t := time.NewTicker(every) + defer t.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-t.C: + b.gc() + } + } +} + +// gc removes all entries expired based on the max age or all entries (if max age is missing). +func (b *BasicAuth) gc() int { + now := time.Now() + var markedForDeletion []string + + b.mu.RLock() + for fullUser, expiresAt := range b.credentials { + if expiresAt == nil { + markedForDeletion = append(markedForDeletion, fullUser) + } else if expiresAt.Before(now) { + markedForDeletion = append(markedForDeletion, fullUser) + } + } + b.mu.RUnlock() + + n := len(markedForDeletion) + if n > 0 { + for _, fullUser := range markedForDeletion { + b.mu.Lock() + delete(b.credentials, fullUser) + b.mu.Unlock() + } + } + + return n +} diff --git a/middleware/basicauth/basicauth_test.go b/middleware/basicauth/basicauth_test.go index 2eff61dc..f04478ad 100644 --- a/middleware/basicauth/basicauth_test.go +++ b/middleware/basicauth/basicauth_test.go @@ -8,7 +8,7 @@ import ( "github.com/kataras/iris/v12" "github.com/kataras/iris/v12/httptest" - "github.com/kataras/iris/v12/middleware/basicauth" + basicauth "github.com/kataras/iris/v12/middleware/basicauth" ) func TestBasicAuthUseRouter(t *testing.T) { @@ -18,7 +18,13 @@ func TestBasicAuthUseRouter(t *testing.T) { "admin": "admin", } - app.UseRouter(basicauth.Default(users)) + auth := basicauth.New(basicauth.Options{ + Allow: basicauth.AllowUsers(users), + Realm: basicauth.DefaultRealm, + MaxTries: 1, + }) + + app.UseRouter(auth) app.Get("/user_json", func(ctx iris.Context) { ctx.JSON(ctx.User()) @@ -80,9 +86,9 @@ func TestBasicAuthUseRouter(t *testing.T) { e.GET("/").Expect().Status(httptest.StatusUnauthorized).Body().Equal("Unauthorized") // Test invalid auth. e.GET("/").WithBasicAuth(username, "invalid_password").Expect(). - Status(httptest.StatusUnauthorized).Body().Equal("Unauthorized") + Status(httptest.StatusForbidden) e.GET("/").WithBasicAuth("invaid_username", password).Expect(). - Status(httptest.StatusUnauthorized).Body().Equal("Unauthorized") + Status(httptest.StatusForbidden) // Test different method, it should pass the authentication (no stop on 401) // but it doesn't fire the GET route, instead it gives 405. @@ -97,9 +103,9 @@ func TestBasicAuthUseRouter(t *testing.T) { e.GET("/notfound").Expect().Status(httptest.StatusUnauthorized).Body().Equal("Unauthorized") // Test invalid auth. e.GET("/notfound").WithBasicAuth(username, "invalid_password").Expect(). - Status(httptest.StatusUnauthorized).Body().Equal("Unauthorized") + Status(httptest.StatusForbidden) e.GET("/notfound").WithBasicAuth("invaid_username", password).Expect(). - Status(httptest.StatusUnauthorized).Body().Equal("Unauthorized") + Status(httptest.StatusForbidden) // Test subdomain inherited. sub := e.Builder(func(req *httptest.Request) { @@ -114,9 +120,9 @@ func TestBasicAuthUseRouter(t *testing.T) { sub.GET("/").Expect().Status(httptest.StatusUnauthorized) // Test invalid auth. sub.GET("/").WithBasicAuth(username, "invalid_password").Expect(). - Status(httptest.StatusUnauthorized).Body().Equal("Unauthorized") + Status(httptest.StatusForbidden) sub.GET("/").WithBasicAuth("invaid_username", password).Expect(). - Status(httptest.StatusUnauthorized).Body().Equal("Unauthorized") + Status(httptest.StatusForbidden) // Test pass the authentication but route not found. sub.GET("/notfound").WithBasicAuth(username, password).Expect(). @@ -126,9 +132,9 @@ func TestBasicAuthUseRouter(t *testing.T) { sub.GET("/notfound").Expect().Status(httptest.StatusUnauthorized).Body().Equal("Unauthorized") // Test invalid auth. sub.GET("/notfound").WithBasicAuth(username, "invalid_password").Expect(). - Status(httptest.StatusUnauthorized).Body().Equal("Unauthorized") + Status(httptest.StatusForbidden) sub.GET("/notfound").WithBasicAuth("invaid_username", password).Expect(). - Status(httptest.StatusUnauthorized).Body().Equal("Unauthorized") + Status(httptest.StatusForbidden) // Test a reset-ed Party with a single one UseRouter // which writes on matched routes and reset and send the error on errors. diff --git a/middleware/basicauth/config.go b/middleware/basicauth/config.go deleted file mode 100644 index 70396215..00000000 --- a/middleware/basicauth/config.go +++ /dev/null @@ -1,59 +0,0 @@ -package basicauth - -import ( - "time" - - "github.com/kataras/iris/v12/context" -) - -const ( - // DefaultBasicAuthRealm is "Authorization Required" - DefaultBasicAuthRealm = "Authorization Required" -) - -// DefaultExpireTime zero time -var DefaultExpireTime time.Time // 0001-01-01 00:00:00 +0000 UTC - -// Config the configs for the basicauth middleware -type Config struct { - // Users a map of login and the value (username/password) - Users map[string]string - // Realm http://tools.ietf.org/html/rfc2617#section-1.2. Default is "Authorization Required" - Realm string - // Expires expiration duration, default is 0 never expires. - Expires time.Duration - - // OnAsk fires each time the server asks to the client for credentials in order to gain access and continue to the next handler. - // - // You could also ignore this option and - // - just add a listener for unauthorized status codes with: - // `app.OnErrorCode(iris.StatusUnauthorized, unauthorizedWantsAccessHandler)` - // - or register a middleware which will force `ctx.Next/or direct call` - // the basicauth middleware and check its `ctx.GetStatusCode()`. - // - // However, this option is very useful when you want the framework to fire a handler - // ONLY when the Basic Authentication sends an `iris.StatusUnauthorized`, - // and free the error code listener to catch other types of unauthorized access, i.e Kerberos. - // Also with this one, not recommended at all but, you are able to "force-allow" other users by calling the `ctx.StatusCode` inside this handler; - // i.e when it is possible to create authorized users dynamically but - // if that is the case then you should go with something like sessions instead of basic authentication. - // - // Usage: basicauth.New(basicauth.Config{..., OnAsk: unauthorizedWantsAccessViaBasicAuthHandler}) - // - // Defaults to nil. - OnAsk context.Handler - - // DisableContextUser disables the registration of the custom basicauth Context.Logout - // and the User. - DisableContextUser bool -} - -// DefaultConfig returns the default configs for the BasicAuth middleware -func DefaultConfig() Config { - return Config{make(map[string]string), DefaultBasicAuthRealm, 0, nil, false} -} - -// User returns the user from context key same as ctx.Request().BasicAuth(). -func (c Config) User(ctx *context.Context) (string, string, bool) { - return ctx.Request().BasicAuth() -} diff --git a/middleware/basicauth/header.go b/middleware/basicauth/header.go new file mode 100644 index 00000000..eadb6b8a --- /dev/null +++ b/middleware/basicauth/header.go @@ -0,0 +1,88 @@ +package basicauth + +import ( + "encoding/base64" + "strings" +) + +const ( + spaceChar = ' ' + colonChar = ':' + colonLiteral = string(colonChar) + basicLiteral = "Basic" + basicSpaceLiteral = "Basic " + basicSpaceLiteralLen = len(basicSpaceLiteral) +) + +// The username and password are combined with a single colon (:). +// This means that the username itself cannot contain a colon. +// URL encoding (e.g. https://Aladdin:OpenSesame@www.example.com/index.html) +// has been deprecated by rfc3986. +func encodeHeader(username, password string) (string, bool) { + if strings.Contains(username, colonLiteral) || strings.Contains(password, colonLiteral) { + return "", false + } + fullUser := []byte(username + colonLiteral + password) + header := basicSpaceLiteral + base64.StdEncoding.EncodeToString(fullUser) + + return header, true +} + +// Like net/http.parseBasicAuth +func decodeHeader(header string) (fullUser, username, password string, ok bool) { + if len(header) < basicSpaceLiteralLen || !strings.EqualFold(header[:basicSpaceLiteralLen], basicSpaceLiteral) { + return + } + + c, err := base64.StdEncoding.DecodeString(header[basicSpaceLiteralLen:]) + if err != nil { + return + } + + cs := string(c) + s := strings.IndexByte(cs, colonChar) + if s < 0 { + return + } + return cs, cs[:s], cs[s+1:], true + + /* + for i := 0; i < n; i++ { + if header[i] == spaceChar { + prefix := header[:i] + if prefix != basicLiteral { + return + } + + if n <= i+1 { + return + } + + decodedFullUser, err := base64.RawStdEncoding.DecodeString(header[i+1:]) + if err != nil { + return + } + + fullUser = string(decodedFullUser) + break + } + } + + n = len(fullUser) + for i := n - 1; i > -1; i-- { + if fullUser[i] == colonChar { + username = fullUser[:i] + password = fullUser[i+1:] + + if strings.TrimSpace(username) == "" || strings.TrimSpace(password) == "" { + ok = false + } else { + ok = true + } + + return + } + } + + return*/ +} diff --git a/middleware/basicauth/header_test.go b/middleware/basicauth/header_test.go new file mode 100644 index 00000000..f365da7b --- /dev/null +++ b/middleware/basicauth/header_test.go @@ -0,0 +1,103 @@ +package basicauth + +import "testing" + +func TestHeaderEncode(t *testing.T) { + var tests = []struct { + username string + password string + header string + ok bool + }{ + { + username: "user", + password: "pass", + header: "Basic dXNlcjpwYXNz", + ok: true, + }, + { + username: "user", + password: "p:(notallowed)ass", + header: "", + ok: false, + }, + { + username: "123u%ser", + password: "pass132$", + header: "Basic MTIzdSVzZXI6cGFzczEzMiQ=", + ok: true, + }, + } + + for i, tt := range tests { + got, ok := encodeHeader(tt.username, tt.password) + if tt.ok != ok { + t.Fatalf("[%d] expected: %v but got: %v (username=%s,password=%s)", i, tt.ok, ok, tt.username, tt.password) + } + if tt.header != got { + t.Fatalf("[%d] expected result header: %q but got: %q", i, tt.header, got) + } + } +} + +func TestHeaderDecode(t *testing.T) { + var tests = []struct { + header string + ok bool + username string + password string + }{ + { + header: "Basic dXNlcjpwYXNz", + ok: true, + username: "user", + password: "pass", + }, + { + header: "dXNlcjpwYXNz", + ok: false, + }, + { + header: "Basic ", + ok: false, + }, + { + header: "Basic dXNlcjp", + ok: false, + }, + { + header: "dXNlcjpwYXNz Basic", + ok: false, + }, + { + header: "dXNlcjpwYXNzBasic", + ok: false, + }, + } + + for i, tt := range tests { + fullUser, username, password, ok := decodeHeader(tt.header) + if expected, got := tt.ok, ok; expected != got { + t.Fatalf("[%d] expected: %v but got: %v (header=%s)", i, expected, got, tt.header) + } + + if expected, got := tt.username, username; expected != got { + t.Fatalf("[%d] expected username: %q but got: %q", i, expected, got) + } + + if expected, got := tt.password, password; expected != got { + t.Fatalf("[%d] expected password: %q but got: %q", i, expected, got) + } + + if tt.username != "" || tt.password != "" { + if expected, got := tt.username+colonLiteral+tt.password, fullUser; expected != got { + t.Fatalf("[%d] expected username:password to be: %q but got: %q", i, expected, got) + } + } else { + if fullUser != "" { + t.Fatalf("[%d] expected username:password to be empty but got: %q", i, fullUser) + } + } + + } +} diff --git a/middleware/basicauth/user_auth.go b/middleware/basicauth/user_auth.go new file mode 100644 index 00000000..808981a3 --- /dev/null +++ b/middleware/basicauth/user_auth.go @@ -0,0 +1,256 @@ +package basicauth + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "reflect" + "strings" + + "github.com/kataras/iris/v12/context" + "golang.org/x/crypto/bcrypt" + "gopkg.in/yaml.v3" +) + +type UserAuthOptions struct { + // Defaults to plain check, can be modified for encrypted passwords, see `BCRYPT`. + ComparePassword func(stored, userPassword string) bool +} + +type UserAuthOption func(*UserAuthOptions) + +func BCRYPT(opts *UserAuthOptions) { + opts.ComparePassword = func(stored, userPassword string) bool { + err := bcrypt.CompareHashAndPassword([]byte(stored), []byte(userPassword)) + return err == nil + } +} + +func toUserAuthOptions(opts []UserAuthOption) (options UserAuthOptions) { + for _, opt := range opts { + opt(&options) + } + + if options.ComparePassword == nil { + options.ComparePassword = func(stored, userPassword string) bool { + return stored == userPassword + } + } + + return options +} + +type User interface { + context.UserGetUsername + context.UserGetPassword +} + +// Users +// - map[string]string form of: {username:password, ...} form. +// - map[string]interface{} form of: []{"username": "...", "password": "...", "other_field": ...}, ...}. +// - []T which T completes the User interface. +// - []T which T contains at least Username and Password fields. +func AllowUsers(users interface{}, opts ...UserAuthOption) AuthFunc { + // create a local user structure to be used in the map copy, + // takes longer to initialize but faster to serve. + type user struct { + password string + ref interface{} + } + cp := make(map[string]*user) + + v := reflect.Indirect(reflect.ValueOf(users)) + switch v.Kind() { + case reflect.Slice: + for i := 0; i < v.Len(); i++ { + elem := v.Index(i).Interface() + // MUST contain a username and password. + username, password, ok := extractUsernameAndPassword(elem) + if !ok { + continue + } + + cp[username] = &user{ + password: password, + ref: elem, + } + } + case reflect.Map: + elem := v.Interface() + switch m := elem.(type) { + case map[string]string: + return userMap(m, opts...) + case map[string]interface{}: + username, password, ok := mapUsernameAndPassword(m) + if !ok { + break + } + + cp[username] = &user{ + password: password, + ref: m, + } + default: + panic(fmt.Sprintf("unsupported type of map: %T", users)) + } + default: + panic(fmt.Sprintf("unsupported type: %T", users)) + } + + options := toUserAuthOptions(opts) + + return func(_ *context.Context, username, password string) (interface{}, bool) { + if u, ok := cp[username]; ok { // fast map access, + if options.ComparePassword(u.password, password) { + return u.ref, true + } + } + + return nil, false + } +} + +func userMap(usernamePassword map[string]string, opts ...UserAuthOption) AuthFunc { + options := toUserAuthOptions(opts) + + return func(_ *context.Context, username, password string) (interface{}, bool) { + pass, ok := usernamePassword[username] + return nil, ok && options.ComparePassword(pass, password) + } +} + +func AllowUsersFile(jsonOrYamlFilename string, opts ...UserAuthOption) AuthFunc { + var ( + usernamePassword map[string]string + // no need to support too much forms, this would be for: + // "$username": { "password": "$pass", "other_field": ...} + // users map[string]map[string]interface{} + userList []map[string]interface{} + ) + + if err := decodeFile(jsonOrYamlFilename, &usernamePassword, &userList); err != nil { + panic(err) + } + + if len(usernamePassword) > 0 { + // JSON Form: { "$username":"$pass", "$username": "$pass" } + // YAML Form: $username: $pass + // $username: $pass + return userMap(usernamePassword, opts...) + } + + if len(userList) > 0 { + // JSON Form: [{"username": "$username", "password": "$pass", "other_field": ...}, {"username": ...}, ... ] + // YAML Form: + // - username: $username + // password: $password + // other_field: ... + return AllowUsers(userList, opts...) + } + + panic("malformed document file: " + jsonOrYamlFilename) +} + +func decodeFile(src string, dest ...interface{}) error { + data, err := ioutil.ReadFile(src) + if err != nil { + return err + } + + // We use unmarshal instead of file decoder + // as we may need to read it more than once (dests, see below). + var ( + unmarshal func(data []byte, v interface{}) error + ext string + ) + + if idx := strings.LastIndexByte(src, '.'); idx > 0 { + ext = src[idx:] + } + + switch ext { + case "", ".json": + unmarshal = json.Unmarshal + case ".yml", ".yaml": + unmarshal = yaml.Unmarshal + default: + return fmt.Errorf("unexpected file extension: %s", ext) + } + + var ( + ok bool + lastErr error + ) + + for _, d := range dest { + if err = unmarshal(data, d); err == nil { + ok = true + } else { + lastErr = err + } + } + + if !ok { + return lastErr + } + + return nil // if at least one is succeed we are ok. +} + +func extractUsernameAndPassword(s interface{}) (username, password string, ok bool) { + if s == nil { + return + } + + switch u := s.(type) { + case User: + username = u.GetUsername() + password = u.GetPassword() + ok = username != "" && password != "" + return + case map[string]interface{}: + return mapUsernameAndPassword(u) + default: + b, err := json.Marshal(u) + if err != nil { + return + } + + var m map[string]interface{} + if err = json.Unmarshal(b, &m); err != nil { + return + } + + return mapUsernameAndPassword(m) + } +} + +func mapUsernameAndPassword(m map[string]interface{}) (username, password string, ok bool) { + // type of username: password. + if len(m) == 1 { + for username, v := range m { + if password, ok := v.(string); ok { + ok := username != "" && password != "" + return username, password, ok + } + } + } + + var usernameFound, passwordFound bool + + for k, v := range m { + switch k { + case "username", "Username": + username, usernameFound = v.(string) + case "password", "Password": + password, passwordFound = v.(string) + } + + if usernameFound && passwordFound { + ok = true + break + } + } + + return +} diff --git a/middleware/basicauth/user_auth_test.go b/middleware/basicauth/user_auth_test.go new file mode 100644 index 00000000..1fa96fc5 --- /dev/null +++ b/middleware/basicauth/user_auth_test.go @@ -0,0 +1,287 @@ +package basicauth + +import ( + "errors" + "io/ioutil" + "os" + "reflect" + "testing" + + "github.com/kataras/iris/v12/context" + + "golang.org/x/crypto/bcrypt" + "gopkg.in/yaml.v3" +) + +type IUserRepository interface { + GetByUsernameAndPassword(dest interface{}, username, password string) error +} + +// Test a custom implementation of AuthFunc with a user repository. +// This is a usage example of custom AuthFunc implementation. +func UserRepository(repo IUserRepository, newUserPtr func() interface{}) AuthFunc { + return func(ctx *context.Context, username, password string) (interface{}, bool) { + dest := newUserPtr() + err := repo.GetByUsernameAndPassword(dest, username, password) + if err == nil { + return dest, true + } + + return nil, false + } +} + +type testUser struct { + username string + password string + email string // custom field. +} + +// GetUsername & Getpassword complete the User interface (optional but useful on Context.User()). +func (u *testUser) GetUsername() string { + return u.username +} + +func (u *testUser) GetPassword() string { + return u.password +} + +type testRepo struct { + entries []testUser +} + +// Implements IUserRepository interface. +func (r *testRepo) GetByUsernameAndPassword(dest interface{}, username, password string) error { + for _, e := range r.entries { + if e.username == username && e.password == password { + *dest.(*testUser) = e + return nil + } + } + + return errors.New("invalid credentials") +} + +func TestAllowUserRepository(t *testing.T) { + repo := &testRepo{ + entries: []testUser{ + {username: "kataras", password: "kataras_pass", email: "kataras2006@hotmail.com"}, + }, + } + + allow := UserRepository(repo, func() interface{} { + return new(testUser) + }) + + var tests = []struct { + username string + password string + ok bool + user *testUser + }{ + { + username: "kataras", + password: "kataras_pass", + ok: true, + user: &testUser{username: "kataras", password: "kataras_pass", email: "kataras2006@hotmail.com"}, + }, + { + username: "makis", + password: "makis_password", + ok: false, + }, + } + + for i, tt := range tests { + v, ok := allow(nil, tt.username, tt.password) + + if tt.ok != ok { + t.Fatalf("[%d] expected: %v but got: %v (username=%s,password=%s)", i, tt.ok, ok, tt.username, tt.password) + } + + if !ok { + continue + } + + u, ok := v.(*testUser) + if !ok { + t.Fatalf("[%d] a user should be type of *testUser but got: %#+v (%T)", i, v, v) + } + + if !reflect.DeepEqual(tt.user, u) { + t.Fatalf("[%d] expected user:\n%#+v\nbut got:\n%#+v", i, tt.user, u) + } + } +} + +func TestAllowUsers(t *testing.T) { + users := []User{ + &testUser{username: "kataras", password: "kataras_pass", email: "kataras2006@hotmail.com"}, + } + + allow := AllowUsers(users) + + var tests = []struct { + username string + password string + ok bool + user *testUser + }{ + { + username: "kataras", + password: "kataras_pass", + ok: true, + user: &testUser{username: "kataras", password: "kataras_pass", email: "kataras2006@hotmail.com"}, + }, + { + username: "makis", + password: "makis_password", + ok: false, + }, + } + + for i, tt := range tests { + v, ok := allow(nil, tt.username, tt.password) + + if tt.ok != ok { + t.Fatalf("[%d] expected: %v but got: %v (username=%s,password=%s)", i, tt.ok, ok, tt.username, tt.password) + } + + if !ok { + continue + } + + u, ok := v.(*testUser) + if !ok { + t.Fatalf("[%d] a user should be type of *testUser but got: %#+v (%T)", i, v, v) + } + + if !reflect.DeepEqual(tt.user, u) { + t.Fatalf("[%d] expected user:\n%#+v\nbut got:\n%#+v", i, tt.user, u) + } + } +} + +// Test YAML user loading with b-encrypted passwords. +func TestAllowUsersFile(t *testing.T) { + f, err := ioutil.TempFile("", "*users.yml") + if err != nil { + t.Fatal(err) + } + defer func() { + f.Close() + os.Remove(f.Name()) + }() + + // f.WriteString(` + // - username: kataras + // password: kataras_pass + // age: 27 + // role: admin + // - username: makis + // password: makis_password + // `) + // This form is supported too, although its features are limited (no custom fields): + // f.WriteString(` + // kataras: kataras_pass + // makis: makis_password + // `) + + var tests = []struct { + username string + password string // hashed, auto-filled later on. + inputPassword string + ok bool + user context.Map + }{ + { + username: "kataras", + inputPassword: "kataras_pass", + ok: true, + user: context.Map{"age": 27, "role": "admin"}, // username and password are auto-filled in our tests below. + }, + { + username: "makis", + inputPassword: "makis_password", + ok: true, + user: context.Map{}, + }, + { + username: "invalid", + password: "invalid_pass", + ok: false, + }, + { + username: "notvalid", + password: "", + ok: false, + }, + } + + // Write the tests to the users YAML file. + var usersToWrite []context.Map + for _, tt := range tests { + if tt.ok { + // store the hashed password. + tt.password = mustGeneratePassword(t, tt.inputPassword) + // store and write the username and hashed password. + tt.user["username"] = tt.username + tt.user["password"] = tt.password + + // cannot write it as a stream, write it as a slice. + // enc.Encode(tt.user) + usersToWrite = append(usersToWrite, tt.user) + } + // bcrypt.GenerateFromPassword([]byte("kataras_pass"), bcrypt.DefaultCost) + } + + fileContents, err := yaml.Marshal(usersToWrite) + if err != nil { + t.Fatal(err) + } + f.Write(fileContents) + + // Build the authentication func. + allow := AllowUsersFile(f.Name(), BCRYPT) + for i, tt := range tests { + v, ok := allow(nil, tt.username, tt.inputPassword) + + if tt.ok != ok { + t.Fatalf("[%d] expected: %v but got: %v (username=%s,password=%s,user=%#+v)", i, tt.ok, ok, tt.username, tt.inputPassword, v) + } + + if !ok { + continue + } + + if len(tt.user) == 0 { // when username: password form. + continue + } + + u, ok := v.(context.Map) + if !ok { + t.Fatalf("[%d] a user loaded from external source or file should be alway type of map[string]interface{} but got: %#+v (%T)", i, v, v) + } + + if expected, got := len(tt.user), len(u); expected != got { + t.Fatalf("[%d] expected user map length to be equal, expected: %d but got: %d\n%#+v\n%#+v", i, expected, got, tt.user, u) + } + + for k, v := range tt.user { + if u[k] != v { + t.Fatalf("[%d] expected user map %q to be %q but got: %q", i, k, v, u[k]) + } + } + } + +} + +func mustGeneratePassword(t *testing.T, userPassword string) string { + t.Helper() + hashed, err := bcrypt.GenerateFromPassword([]byte(userPassword), bcrypt.DefaultCost) + if err != nil { + t.Fatal(err) + } + + return string(hashed) +}