mirror of
https://github.com/kataras/iris.git
synced 2025-01-23 10:41:03 +01:00
New basic auth middleware and GetRaw on User (godocs missing)
This commit is contained in:
parent
962ffd6772
commit
4d857ac53f
|
@ -1,8 +1,6 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/kataras/iris/v12"
|
||||
"github.com/kataras/iris/v12/middleware/basicauth"
|
||||
)
|
||||
|
@ -10,25 +8,40 @@ import (
|
|||
func newApp() *iris.Application {
|
||||
app := iris.New()
|
||||
|
||||
authConfig := basicauth.Config{
|
||||
Users: map[string]string{"myusername": "mypassword", "mySecondusername": "mySecondpassword"},
|
||||
Realm: "Authorization Required", // defaults to "Authorization Required"
|
||||
Expires: time.Duration(30) * time.Minute,
|
||||
/*
|
||||
opts := basicauth.Options{
|
||||
Realm: "Authorization Required",
|
||||
MaxAge: 30 * time.Minute,
|
||||
GC: basicauth.GC{
|
||||
Every: 2 * time.Hour,
|
||||
},
|
||||
Allow: basicauth.AllowUsers(map[string]string{
|
||||
"myusername": "mypassword",
|
||||
"mySecondusername": "mySecondpassword",
|
||||
}),
|
||||
MaxTries: 2,
|
||||
}
|
||||
auth := basicauth.New(opts)
|
||||
|
||||
authentication := basicauth.New(authConfig)
|
||||
OR simply:
|
||||
*/
|
||||
|
||||
// to global app.Use(authentication) (or app.UseGlobal before the .Run)
|
||||
auth := basicauth.Default(map[string]string{
|
||||
"myusername": "mypassword",
|
||||
"mySecondusername": "mySecondpassword",
|
||||
})
|
||||
|
||||
// to global app.Use(auth) (or app.UseGlobal before the .Run)
|
||||
// to routes
|
||||
/*
|
||||
app.Get("/mysecret", authentication, h)
|
||||
app.Get("/mysecret", auth, h)
|
||||
*/
|
||||
|
||||
app.Get("/", func(ctx iris.Context) { ctx.Redirect("/admin") })
|
||||
|
||||
// to party
|
||||
|
||||
needAuth := app.Party("/admin", authentication)
|
||||
needAuth := app.Party("/admin", auth)
|
||||
{
|
||||
//http://localhost:8080/admin
|
||||
needAuth.Get("/", h)
|
||||
|
|
|
@ -25,5 +25,5 @@ func TestBasicAuth(t *testing.T) {
|
|||
|
||||
// with invalid basic auth
|
||||
e.GET("/admin/settings").WithBasicAuth("invalidusername", "invalidpassword").
|
||||
Expect().Status(httptest.StatusUnauthorized)
|
||||
Expect().Status(httptest.StatusForbidden)
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ go 1.15
|
|||
|
||||
require (
|
||||
github.com/google/uuid v1.1.2
|
||||
github.com/kataras/iris/v12 v12.2.0-alpha.0.20201106220849-7a19cfb2112f
|
||||
github.com/kataras/iris/v12 v12.2.0-alpha.0.20201113181155-4d09475c290d
|
||||
golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897
|
||||
)
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ func main() {
|
|||
|
||||
newtest := app.Subdomain("newtest")
|
||||
newtest.Get("/", newTestIndex)
|
||||
newtest.Get("/", newTestAbout)
|
||||
newtest.Get("/about", newTestAbout)
|
||||
|
||||
redirects := rewrite.Load("redirects.yml")
|
||||
app.WrapRouter(redirects)
|
||||
|
@ -25,14 +25,13 @@ func main() {
|
|||
// http://mydomain.com:8080/seo/about -> http://www.mydomain.com:8080/about
|
||||
// http://test.mydomain.com:8080 -> http://newtest.mydomain.com:8080
|
||||
// http://test.mydomain.com:8080/seo/about -> http://newtest.mydomain.com:8080/about
|
||||
// http://localhost:8080/seo -> http://localhost:8080
|
||||
// http://localhost:8080/about
|
||||
// http://localhost:8080/docs/v12/hello -> http://localhost:8080/docs
|
||||
// http://localhost:8080/docs/v12some -> http://localhost:8080/docs
|
||||
// http://localhost:8080/oldsome -> http://localhost:8080
|
||||
// http://localhost:8080/oldindex/random -> http://localhost:8080
|
||||
// http://localhost:8080/users.json -> http://localhost:8080/users.json
|
||||
// ^ (but with an internal ?format=json, client can't see it)
|
||||
// http://mydomain.com:8080/seo -> http://www.mydomain.com:8080
|
||||
// http://mydomain.com:8080/about
|
||||
// http://mydomain.com:8080/docs/v12/hello -> http://www.mydomain.com:8080/docs
|
||||
// http://mydomain.com:8080/docs/v12some -> http://www.mydomain.com:8080/docs
|
||||
// http://mydomain.com:8080/oldsome -> http://www.mydomain.com:8080
|
||||
// http://mydomain.com:8080/oldindex/random -> http://www.mydomain.com:8080
|
||||
// http://mydomain.com:8080/users.json -> http://www.mydomain.com:8080/users?format=json
|
||||
app.Listen(":8080")
|
||||
}
|
||||
|
||||
|
|
|
@ -4726,7 +4726,7 @@ func (ctx *Context) UpsertCookie(cookie *http.Cookie, options ...CookieOption) b
|
|||
// you can change it or simple, use the SetCookie for more control.
|
||||
//
|
||||
// See `CookieExpires` and `AddCookieOptions` for more.
|
||||
var SetCookieKVExpiration = time.Duration(8760) * time.Hour
|
||||
var SetCookieKVExpiration = 8760 * time.Hour
|
||||
|
||||
// SetCookieKV adds a cookie, requires the name(string) and the value(string).
|
||||
//
|
||||
|
@ -5343,7 +5343,15 @@ const userContextKey = "iris.user"
|
|||
// SetUser sets a value as a User for this request.
|
||||
// It's used by auth middlewares as a common
|
||||
// method to provide user information to the
|
||||
// next handlers in the chain
|
||||
// next handlers in the chain.
|
||||
//
|
||||
// The "i" input argument can be:
|
||||
// - A value which completes the User interface
|
||||
// - A map[string]interface{}.
|
||||
// - A value which does not complete the whole User interface
|
||||
// - A value which does not complete the User interface at all
|
||||
// (only its `User().GetRaw` method is available).
|
||||
//
|
||||
// Look the `User` method to retrieve it.
|
||||
func (ctx *Context) SetUser(i interface{}) error {
|
||||
if i == nil {
|
||||
|
@ -5371,6 +5379,9 @@ func (ctx *Context) SetUser(i interface{}) error {
|
|||
}
|
||||
|
||||
// User returns the registered User of this request.
|
||||
// To get the original value (even if a value set by SetUser does not implement the User interface)
|
||||
// use its GetRaw method.
|
||||
// /
|
||||
// See `SetUser` too.
|
||||
func (ctx *Context) User() User {
|
||||
if v := ctx.values.Get(userContextKey); v != nil {
|
||||
|
|
|
@ -33,6 +33,8 @@ var ErrNotSupported = errors.New("not supported")
|
|||
// - UserMap (a wrapper by SetUser)
|
||||
// - UserPartial (a wrapper by SetUser)
|
||||
type User interface {
|
||||
// GetRaw should return the raw instance of the user, if supported.
|
||||
GetRaw() (interface{}, error)
|
||||
// GetAuthorization should return the authorization method,
|
||||
// e.g. Basic Authentication.
|
||||
GetAuthorization() (string, error)
|
||||
|
@ -92,6 +94,11 @@ type SimpleUser struct {
|
|||
|
||||
var _ User = (*SimpleUser)(nil)
|
||||
|
||||
// GetRaw returns itself.
|
||||
func (u *SimpleUser) GetRaw() (interface{}, error) {
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// GetAuthorization returns the authorization method,
|
||||
// e.g. Basic Authentication.
|
||||
func (u *SimpleUser) GetAuthorization() (string, error) {
|
||||
|
@ -179,6 +186,11 @@ type UserMap Map
|
|||
|
||||
var _ User = UserMap{}
|
||||
|
||||
// GetRaw returns the underline map.
|
||||
func (u UserMap) GetRaw() (interface{}, error) {
|
||||
return Map(u), nil
|
||||
}
|
||||
|
||||
// GetAuthorization returns the authorization or Authorization value of the map.
|
||||
func (u UserMap) GetAuthorization() (string, error) {
|
||||
return u.str("authorization")
|
||||
|
@ -292,11 +304,17 @@ type (
|
|||
GetID() string
|
||||
}
|
||||
|
||||
userGetUsername interface {
|
||||
// UserGetUsername interface which
|
||||
// requires a single method to complete
|
||||
// a User on Context.SetUser.
|
||||
UserGetUsername interface {
|
||||
GetUsername() string
|
||||
}
|
||||
|
||||
userGetPassword interface {
|
||||
// UserGetPassword interface which
|
||||
// requires a single method to complete
|
||||
// a User on Context.SetUser.
|
||||
UserGetPassword interface {
|
||||
GetPassword() string
|
||||
}
|
||||
|
||||
|
@ -319,13 +337,14 @@ type (
|
|||
// UserPartial is a User.
|
||||
// It's a helper which wraps a struct value that
|
||||
// may or may not complete the whole User interface.
|
||||
// See Context.SetUser.
|
||||
UserPartial struct {
|
||||
Raw interface{}
|
||||
userGetAuthorization
|
||||
userGetAuthorizedAt
|
||||
userGetID
|
||||
userGetUsername
|
||||
userGetPassword
|
||||
UserGetUsername
|
||||
UserGetPassword
|
||||
userGetEmail
|
||||
userGetRoles
|
||||
userGetToken
|
||||
|
@ -336,61 +355,64 @@ type (
|
|||
var _ User = (*UserPartial)(nil)
|
||||
|
||||
func newUserPartial(i interface{}) *UserPartial {
|
||||
containsAtLeastOneMethod := false
|
||||
if i == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
p := &UserPartial{Raw: i}
|
||||
|
||||
if u, ok := i.(userGetAuthorization); ok {
|
||||
p.userGetAuthorization = u
|
||||
containsAtLeastOneMethod = true
|
||||
}
|
||||
|
||||
if u, ok := i.(userGetAuthorizedAt); ok {
|
||||
p.userGetAuthorizedAt = u
|
||||
containsAtLeastOneMethod = true
|
||||
}
|
||||
|
||||
if u, ok := i.(userGetID); ok {
|
||||
p.userGetID = u
|
||||
containsAtLeastOneMethod = true
|
||||
}
|
||||
|
||||
if u, ok := i.(userGetUsername); ok {
|
||||
p.userGetUsername = u
|
||||
containsAtLeastOneMethod = true
|
||||
if u, ok := i.(UserGetUsername); ok {
|
||||
p.UserGetUsername = u
|
||||
}
|
||||
|
||||
if u, ok := i.(userGetPassword); ok {
|
||||
p.userGetPassword = u
|
||||
containsAtLeastOneMethod = true
|
||||
if u, ok := i.(UserGetPassword); ok {
|
||||
p.UserGetPassword = u
|
||||
}
|
||||
|
||||
if u, ok := i.(userGetEmail); ok {
|
||||
p.userGetEmail = u
|
||||
containsAtLeastOneMethod = true
|
||||
}
|
||||
|
||||
if u, ok := i.(userGetRoles); ok {
|
||||
p.userGetRoles = u
|
||||
containsAtLeastOneMethod = true
|
||||
}
|
||||
|
||||
if u, ok := i.(userGetToken); ok {
|
||||
p.userGetToken = u
|
||||
containsAtLeastOneMethod = true
|
||||
}
|
||||
|
||||
if u, ok := i.(userGetField); ok {
|
||||
p.userGetField = u
|
||||
containsAtLeastOneMethod = true
|
||||
}
|
||||
|
||||
if !containsAtLeastOneMethod {
|
||||
return nil
|
||||
}
|
||||
// if !containsAtLeastOneMethod {
|
||||
// return nil
|
||||
// }
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// GetRaw returns the original raw instance of the user.
|
||||
func (u *UserPartial) GetRaw() (interface{}, error) {
|
||||
if u == nil {
|
||||
return nil, ErrNotSupported
|
||||
}
|
||||
|
||||
return u.Raw, nil
|
||||
}
|
||||
|
||||
// GetAuthorization should return the authorization method,
|
||||
// e.g. Basic Authentication.
|
||||
func (u *UserPartial) GetAuthorization() (string, error) {
|
||||
|
@ -422,7 +444,7 @@ func (u *UserPartial) GetID() (string, error) {
|
|||
|
||||
// GetUsername should return the name of the User.
|
||||
func (u *UserPartial) GetUsername() (string, error) {
|
||||
if v := u.userGetUsername; v != nil {
|
||||
if v := u.UserGetUsername; v != nil {
|
||||
return v.GetUsername(), nil
|
||||
}
|
||||
|
||||
|
@ -432,7 +454,7 @@ func (u *UserPartial) GetUsername() (string, error) {
|
|||
// GetPassword should return the encoded or raw password
|
||||
// (depends on the implementation) of the User.
|
||||
func (u *UserPartial) GetPassword() (string, error) {
|
||||
if v := u.userGetPassword; v != nil {
|
||||
if v := u.UserGetPassword; v != nil {
|
||||
return v.GetPassword(), nil
|
||||
}
|
||||
|
||||
|
|
|
@ -1,205 +1,402 @@
|
|||
// Package basicauth provides http basic authentication via middleware. See _examples/auth/basicauth
|
||||
package basicauth
|
||||
|
||||
/*
|
||||
Test files:
|
||||
- ../../_examples/auth/basicauth/main_test.go
|
||||
- ./basicauth_test.go
|
||||
*/
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
stdContext "context"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/kataras/iris/v12/context"
|
||||
"github.com/kataras/iris/v12/sessions"
|
||||
)
|
||||
|
||||
func init() {
|
||||
context.SetHandlerName("iris/middleware/basicauth.*", "iris.basicauth")
|
||||
}
|
||||
|
||||
const authorizationType = "Basic Authentication"
|
||||
|
||||
type (
|
||||
encodedUser struct {
|
||||
HeaderValue string
|
||||
Username string
|
||||
Password string
|
||||
logged bool
|
||||
forceLogout bool // in order to be able to invalidate and use a redirect response.
|
||||
authorizedAt time.Time // when from !logged to logged.
|
||||
expires time.Time
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
basicAuthMiddleware struct {
|
||||
config *Config
|
||||
// these are filled from the config.Users map at the startup
|
||||
auth []*encodedUser
|
||||
realmHeaderValue string
|
||||
|
||||
// The below can be removed but they are here because on the future we may add dynamic options for those two fields,
|
||||
// it is a bit faster to check the b.$bool as well.
|
||||
expireEnabled bool // if the config.Expires is a valid date, default is disabled.
|
||||
askHandlerEnabled bool // if the config.OnAsk is not nil, defaults to false.
|
||||
}
|
||||
const (
|
||||
DefaultRealm = "Authorization Required"
|
||||
DefaultMaxTriesCookie = "basicmaxtries"
|
||||
)
|
||||
|
||||
const (
|
||||
authorizationType = "Basic Authentication"
|
||||
authenticateHeaderKey = "WWW-Authenticate"
|
||||
proxyAuthenticateHeaderKey = "Proxy-Authenticate"
|
||||
authorizationHeaderKey = "Authorization"
|
||||
proxyAuthorizationHeaderKey = "Proxy-Authorization"
|
||||
)
|
||||
|
||||
type AuthFunc func(ctx *context.Context, username, password string) (interface{}, bool)
|
||||
|
||||
type Options struct {
|
||||
// Realm http://tools.ietf.org/html/rfc2617#section-1.2.
|
||||
// E.g. "Authorization Required".
|
||||
Realm string
|
||||
// In the case of proxies, the challenging status code is 407 (Proxy Authentication Required),
|
||||
// the Proxy-Authenticate response header contains at least one challenge applicable to the proxy,
|
||||
// and the Proxy-Authorization request header is used for providing the credentials to the proxy server.
|
||||
//
|
||||
// Proxy should be used to gain access to a resource behind a proxy server.
|
||||
// It authenticates the request to the proxy server, allowing it to transmit the request further.
|
||||
Proxy bool
|
||||
// Usage:
|
||||
// - Allow: AllowUsers(iris.Map{"username": "...", "password": "...", "other_field": ...}, [BCRYPT])
|
||||
// - Allow: AllowUsersFile("users.yml", [BCRYPT])
|
||||
Allow AuthFunc
|
||||
// If greater than zero then the server will send 403 forbidden status code afer MaxTries
|
||||
// of invalid credentials of a specific client consumed (session or cookie based, see MaxTriesCookie).
|
||||
// By default the server will re-ask for credentials on any amount of invalid credentials.
|
||||
MaxTries int
|
||||
// If a session manager is register under the current request,
|
||||
// then this value should be the key of the session storage which
|
||||
// the current tries will be stored. Otherwise
|
||||
// it is the raw cookie name.
|
||||
// The cookie is stored up to the configured MaxAge if greater than zero or for 1 year,
|
||||
// so a forbidden client can request for authentication again after the MaxAge expired.
|
||||
//
|
||||
// Note that, the session way is recommended as the current tries
|
||||
// cannot be modified by the client (unless the client removes the session cookie).
|
||||
// However the raw cookie performs faster. You can always set custom logic
|
||||
// on the Allow field as you have access to the current request Context.
|
||||
// To set custom cookie options use the `Context.AddCookieOptions(options ...iris.CookieOption)`
|
||||
// before the basic auth middleware.
|
||||
//
|
||||
// If MaxTries > 0 then it defaults to "basicmaxtries".
|
||||
// The MaxTries should be set to greater than zero.
|
||||
MaxTriesCookie string
|
||||
// If not nil runs after 401 (or 407 if proxy is enabled) status code.
|
||||
// Can be used to set custom response for unauthenticated clients.
|
||||
OnAsk context.Handler
|
||||
// If not nil runs after the 403 forbidden status code (when Allow returned false and MaxTries consumed).
|
||||
// Can be used to set custom response when client tried to access a resource with invalid credentials.
|
||||
OnForbidden context.Handler
|
||||
// MaxAge sets expiration duration for the in-memory credentials map.
|
||||
// By default an old map entry will be removed when the user visits a page.
|
||||
// In order to remove old entries automatically please take a look at the `GC` option too.
|
||||
//
|
||||
// Usage:
|
||||
// MaxAge: 30*time.Minute
|
||||
MaxAge time.Duration
|
||||
// GC automatically clears old entries every x duration.
|
||||
// Note that, by old entries we mean expired credentials therefore
|
||||
// the `MaxAge` option should be already set,
|
||||
// if it's not then all entries will be removed on "every" duration.
|
||||
// The standard context can be used for the internal ticker cancelation, it can be nil.
|
||||
//
|
||||
// Usage:
|
||||
// GC: basicauth.GC{Every: 2*time.Hour}
|
||||
GC GC
|
||||
}
|
||||
|
||||
type GC struct {
|
||||
Context stdContext.Context
|
||||
Every time.Duration
|
||||
}
|
||||
|
||||
// https://tools.ietf.org/html/rfc2617
|
||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Authentication
|
||||
//
|
||||
// As the user ID and password are passed over the network as clear text
|
||||
// (it is base64 encoded, but base64 is a reversible encoding), the basic authentication scheme is not secure.
|
||||
// HTTPS/TLS should be used with basic authentication. Without these additional security enhancements,
|
||||
// basic authentication should not be used to protect sensitive or valuable information.
|
||||
type BasicAuth struct {
|
||||
opts Options
|
||||
// built based on proxy field
|
||||
askCode int
|
||||
authorizationHeader string
|
||||
authenticateHeader string
|
||||
// built based on realm field.
|
||||
authenticateHeaderValue string
|
||||
|
||||
// New accepts basicauth.Config and returns a new Handler
|
||||
// which will ask the client for basic auth (username, password),
|
||||
// validate that and if valid continues to the next handler, otherwise
|
||||
// throws a StatusUnauthorized http error code.
|
||||
//
|
||||
// Use the `Context.User` method to retrieve the stored user.
|
||||
func New(c Config) context.Handler {
|
||||
config := DefaultConfig()
|
||||
if c.Realm != "" {
|
||||
config.Realm = c.Realm
|
||||
}
|
||||
config.Users = c.Users
|
||||
config.Expires = c.Expires
|
||||
config.OnAsk = c.OnAsk
|
||||
|
||||
b := &basicAuthMiddleware{config: &config}
|
||||
b.init()
|
||||
return b.Serve
|
||||
credentials map[string]*time.Time // key = username:password, value = expiration time (if MaxAge > 0).
|
||||
mu sync.RWMutex // protects the credentials as they can modified.
|
||||
}
|
||||
|
||||
// Default accepts only the users and returns a new Handler
|
||||
// which will ask the client for basic auth (username, password),
|
||||
// validate that and if valid continues to the next handler, otherwise
|
||||
// throws a StatusUnauthorized http error code.
|
||||
func Default(users map[string]string) context.Handler {
|
||||
c := DefaultConfig()
|
||||
c.Users = users
|
||||
return New(c)
|
||||
func New(opts Options) context.Handler {
|
||||
var (
|
||||
askCode = 401
|
||||
authorizationHeader = authorizationHeaderKey
|
||||
authenticateHeader = authenticateHeaderKey
|
||||
authenticateHeaderValue = "Basic"
|
||||
)
|
||||
|
||||
if opts.Allow == nil {
|
||||
panic("BasicAuth: Allow field is required")
|
||||
}
|
||||
|
||||
if opts.Realm != "" {
|
||||
authenticateHeaderValue += " realm=" + strconv.Quote(opts.Realm)
|
||||
}
|
||||
|
||||
if opts.Proxy {
|
||||
askCode = 407
|
||||
authenticateHeader = proxyAuthenticateHeaderKey
|
||||
authorizationHeader = proxyAuthorizationHeaderKey
|
||||
}
|
||||
|
||||
if opts.MaxTries > 0 && opts.MaxTriesCookie == "" {
|
||||
opts.MaxTriesCookie = DefaultMaxTriesCookie
|
||||
}
|
||||
|
||||
b := &BasicAuth{
|
||||
opts: opts,
|
||||
askCode: askCode,
|
||||
authorizationHeader: authorizationHeader,
|
||||
authenticateHeader: authenticateHeader,
|
||||
authenticateHeaderValue: authenticateHeaderValue,
|
||||
credentials: make(map[string]*time.Time),
|
||||
}
|
||||
|
||||
if opts.GC.Every > 0 {
|
||||
go b.runGC(opts.GC.Context, opts.GC.Every)
|
||||
}
|
||||
|
||||
return b.serveHTTP
|
||||
}
|
||||
|
||||
func (b *basicAuthMiddleware) init() {
|
||||
// pass the encoded users from the user's config's Users value
|
||||
b.auth = make([]*encodedUser, 0, len(b.config.Users))
|
||||
|
||||
for k, v := range b.config.Users {
|
||||
fullUser := k + ":" + v
|
||||
header := "Basic " + base64.StdEncoding.EncodeToString([]byte(fullUser))
|
||||
b.auth = append(b.auth, &encodedUser{
|
||||
HeaderValue: header,
|
||||
Username: k,
|
||||
Password: v,
|
||||
logged: false,
|
||||
expires: DefaultExpireTime,
|
||||
})
|
||||
// - map[string]string form of: {username:password, ...} form.
|
||||
// - map[string]interface{} form of: []{"username": "...", "password": "...", "other_field": ...}, ...}.
|
||||
// - []T which T completes the User interface.
|
||||
// - []T which T contains at least Username and Password fields.
|
||||
func Default(users interface{}, userOpts ...UserAuthOption) context.Handler {
|
||||
opts := Options{
|
||||
Realm: DefaultRealm,
|
||||
Allow: AllowUsers(users, userOpts...),
|
||||
}
|
||||
|
||||
// set the auth realm header's value
|
||||
b.realmHeaderValue = "Basic realm=" + strconv.Quote(b.config.Realm)
|
||||
|
||||
b.expireEnabled = b.config.Expires > 0
|
||||
b.askHandlerEnabled = b.config.OnAsk != nil
|
||||
return New(opts)
|
||||
}
|
||||
|
||||
func (b *basicAuthMiddleware) findAuth(headerValue string) (*encodedUser, bool) {
|
||||
if headerValue != "" {
|
||||
for _, user := range b.auth {
|
||||
if user.HeaderValue == headerValue {
|
||||
return user, true
|
||||
func Load(jsonOrYamlFilename string, userOpts ...UserAuthOption) context.Handler {
|
||||
opts := Options{
|
||||
Realm: DefaultRealm,
|
||||
Allow: AllowUsersFile(jsonOrYamlFilename, userOpts...),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
return New(opts)
|
||||
}
|
||||
|
||||
func (b *basicAuthMiddleware) askForCredentials(ctx *context.Context) {
|
||||
ctx.Header("WWW-Authenticate", b.realmHeaderValue)
|
||||
ctx.StatusCode(401)
|
||||
if b.askHandlerEnabled {
|
||||
b.config.OnAsk(ctx)
|
||||
// askForCredentials sends a response to the client which client should catch
|
||||
// and ask for username:password credentials.
|
||||
func (b *BasicAuth) askForCredentials(ctx *context.Context) {
|
||||
ctx.Header(b.authenticateHeader, b.authenticateHeaderValue)
|
||||
ctx.StopWithStatus(b.askCode)
|
||||
|
||||
if h := b.opts.OnAsk; h != nil {
|
||||
h(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// Serve the actual basic authentication middleware.
|
||||
// Use the Context.User method to retrieve the stored user.
|
||||
func (b *basicAuthMiddleware) Serve(ctx *context.Context) {
|
||||
auth, found := b.findAuth(ctx.GetHeader("Authorization"))
|
||||
if !found || auth.forceLogout {
|
||||
if auth != nil {
|
||||
auth.mu.Lock()
|
||||
auth.forceLogout = false
|
||||
auth.mu.Unlock()
|
||||
// If a (proxy) server receives valid credentials that are inadequate to access a given resource,
|
||||
// the server should respond with the 403 Forbidden status code.
|
||||
// Unlike 401 Unauthorized or 407 Proxy Authentication Required, authentication is impossible for this user.
|
||||
func (b *BasicAuth) forbidden(ctx *context.Context) {
|
||||
ctx.StopWithStatus(403)
|
||||
|
||||
if h := b.opts.OnForbidden; h != nil {
|
||||
h(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BasicAuth) getCurrentTries(ctx *context.Context) (tries int) {
|
||||
sess := sessions.Get(ctx)
|
||||
if sess != nil {
|
||||
tries = sess.GetIntDefault(b.opts.MaxTriesCookie, 0)
|
||||
} else {
|
||||
if v := ctx.GetCookie(b.opts.MaxTriesCookie); v != "" {
|
||||
tries, _ = strconv.Atoi(v)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (b *BasicAuth) setCurrentTries(ctx *context.Context, tries int) {
|
||||
sess := sessions.Get(ctx)
|
||||
if sess != nil {
|
||||
sess.Set(b.opts.MaxTriesCookie, tries)
|
||||
} else {
|
||||
maxAge := b.opts.MaxAge
|
||||
if maxAge == 0 {
|
||||
maxAge = context.SetCookieKVExpiration // 1 year.
|
||||
}
|
||||
ctx.SetCookieKV(b.opts.MaxTriesCookie, strconv.Itoa(tries), context.CookieExpires(maxAge))
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BasicAuth) resetCurrentTries(ctx *context.Context) {
|
||||
sess := sessions.Get(ctx)
|
||||
if sess != nil {
|
||||
sess.Delete(b.opts.MaxTriesCookie)
|
||||
} else {
|
||||
ctx.RemoveCookie(b.opts.MaxTriesCookie)
|
||||
}
|
||||
}
|
||||
|
||||
// serveHTTP is the main method of this middleware,
|
||||
// checks and verifies the auhorization header for basic authentication,
|
||||
// next handlers will only be executed when the client is allowed to continue.
|
||||
func (b *BasicAuth) serveHTTP(ctx *context.Context) {
|
||||
header := ctx.GetHeader(b.authorizationHeader)
|
||||
fullUser, username, password, ok := decodeHeader(header)
|
||||
if !ok { // Header is malformed or missing.
|
||||
b.askForCredentials(ctx)
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
maxTries = b.opts.MaxTries
|
||||
tries int
|
||||
)
|
||||
|
||||
if maxTries > 0 {
|
||||
tries = b.getCurrentTries(ctx)
|
||||
}
|
||||
|
||||
user, ok := b.opts.Allow(ctx, username, password)
|
||||
if !ok { // This username:password combination was not allowed.
|
||||
if maxTries > 0 {
|
||||
tries++
|
||||
b.setCurrentTries(ctx, tries)
|
||||
if tries >= maxTries { // e.g. if MaxTries == 1 then it should be allowed only once, so we must send forbidden now.
|
||||
b.forbidden(ctx) // a user was forbidden, to reset its status should clear the Authorization header and cookie and request the resource again.
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
b.askForCredentials(ctx)
|
||||
ctx.StopExecution()
|
||||
return
|
||||
// don't continue to the next handler
|
||||
}
|
||||
|
||||
auth.mu.RLock()
|
||||
logged := auth.logged
|
||||
auth.mu.RUnlock()
|
||||
if !logged {
|
||||
auth.mu.Lock()
|
||||
auth.authorizedAt = time.Now()
|
||||
auth.mu.Unlock()
|
||||
}
|
||||
|
||||
// all ok
|
||||
if b.expireEnabled {
|
||||
if !logged {
|
||||
auth.mu.Lock()
|
||||
auth.expires = auth.authorizedAt.Add(b.config.Expires)
|
||||
auth.logged = true
|
||||
auth.mu.Unlock()
|
||||
}
|
||||
|
||||
auth.mu.RLock()
|
||||
expired := time.Now().After(auth.expires)
|
||||
auth.mu.RUnlock()
|
||||
if expired {
|
||||
auth.mu.Lock()
|
||||
auth.logged = false
|
||||
auth.forceLogout = false
|
||||
auth.mu.Unlock()
|
||||
b.askForCredentials(ctx) // ask for authentication again
|
||||
ctx.StopExecution()
|
||||
return
|
||||
}
|
||||
|
||||
if tries > 0 {
|
||||
// had failures but it's ok, reset the tries on success.
|
||||
b.resetCurrentTries(ctx)
|
||||
}
|
||||
|
||||
if !b.config.DisableContextUser {
|
||||
ctx.SetLogoutFunc(b.Logout)
|
||||
b.mu.RLock()
|
||||
expiresAt, ok := b.credentials[fullUser]
|
||||
b.mu.RUnlock()
|
||||
var authorizedAt time.Time
|
||||
if ok {
|
||||
if expiresAt != nil { // Has expiration.
|
||||
if expiresAt.Before(time.Now()) { // Has been expired.
|
||||
b.mu.Lock() // Delete the entry.
|
||||
delete(b.credentials, fullUser)
|
||||
b.mu.Unlock()
|
||||
// Re-ask for new credentials.
|
||||
b.askForCredentials(ctx)
|
||||
return
|
||||
}
|
||||
|
||||
auth.mu.RLock()
|
||||
user := &context.SimpleUser{
|
||||
// It's ok, find the time authorized to fill the user below, if necessary.
|
||||
authorizedAt = expiresAt.Add(-b.opts.MaxAge)
|
||||
}
|
||||
} else {
|
||||
// Saved credential not found, first login.
|
||||
if b.opts.MaxAge > 0 { // Expiration is enabled, set the value.
|
||||
authorizedAt = time.Now()
|
||||
t := authorizedAt.Add(b.opts.MaxAge)
|
||||
expiresAt = &t
|
||||
}
|
||||
b.mu.Lock()
|
||||
b.credentials[fullUser] = expiresAt
|
||||
b.mu.Unlock()
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
// No custom uset was set by the auth func,
|
||||
// it is passed though, set a simple user here:
|
||||
user = &context.SimpleUser{
|
||||
Authorization: authorizationType,
|
||||
AuthorizedAt: auth.authorizedAt,
|
||||
Username: auth.Username,
|
||||
Password: auth.Password,
|
||||
AuthorizedAt: authorizedAt,
|
||||
Username: username,
|
||||
Password: password,
|
||||
}
|
||||
auth.mu.RUnlock()
|
||||
}
|
||||
|
||||
ctx.SetUser(user)
|
||||
}
|
||||
ctx.SetLogoutFunc(b.logout)
|
||||
|
||||
ctx.Next() // continue
|
||||
ctx.Next()
|
||||
}
|
||||
|
||||
// Logout sends a 401 so the browser/client can invalidate the
|
||||
// Basic Authentication and also sets the underline user's logged field to false,
|
||||
// so its expiration resets when re-ask for credentials.
|
||||
//
|
||||
// End-developers should call the `Context.Logout()` method
|
||||
// to fire this method as this structure is hidden.
|
||||
func (b *basicAuthMiddleware) Logout(ctx *context.Context) {
|
||||
ctx.StatusCode(401)
|
||||
if auth, found := b.findAuth(ctx.GetHeader("Authorization")); found {
|
||||
auth.mu.Lock()
|
||||
auth.logged = false
|
||||
auth.forceLogout = true
|
||||
auth.mu.Unlock()
|
||||
// logout clears the current user's credentials.
|
||||
func (b *BasicAuth) logout(ctx *context.Context) {
|
||||
var (
|
||||
fullUser, username, password string
|
||||
ok bool
|
||||
)
|
||||
|
||||
if u := ctx.User(); u != nil { // Get the saved ones, if any.
|
||||
username, _ = u.GetUsername()
|
||||
password, _ = u.GetPassword()
|
||||
fullUser = username + colonLiteral + password
|
||||
ok = username != "" && password != ""
|
||||
}
|
||||
|
||||
if !ok {
|
||||
// If the custom user does
|
||||
// not implement those two, then extract from the request header:
|
||||
header := ctx.GetHeader(b.authorizationHeader)
|
||||
fullUser, username, password, ok = decodeHeader(header)
|
||||
}
|
||||
|
||||
if ok { // If it's authorized then try to lock and delete.
|
||||
if b.opts.Proxy {
|
||||
ctx.Request().Header.Del(proxyAuthorizationHeaderKey)
|
||||
}
|
||||
// delete the request header so future Request().BasicAuth are empty.
|
||||
ctx.Request().Header.Del(authorizationHeaderKey)
|
||||
|
||||
b.mu.Lock()
|
||||
delete(b.credentials, fullUser)
|
||||
b.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// runGC runs a function in a separate go routine
|
||||
// every x duration to clear in-memory expired credential entries.
|
||||
func (b *BasicAuth) runGC(ctx stdContext.Context, every time.Duration) {
|
||||
if ctx == nil {
|
||||
ctx = stdContext.Background()
|
||||
}
|
||||
|
||||
t := time.NewTicker(every)
|
||||
defer t.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
b.gc()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// gc removes all entries expired based on the max age or all entries (if max age is missing).
|
||||
func (b *BasicAuth) gc() int {
|
||||
now := time.Now()
|
||||
var markedForDeletion []string
|
||||
|
||||
b.mu.RLock()
|
||||
for fullUser, expiresAt := range b.credentials {
|
||||
if expiresAt == nil {
|
||||
markedForDeletion = append(markedForDeletion, fullUser)
|
||||
} else if expiresAt.Before(now) {
|
||||
markedForDeletion = append(markedForDeletion, fullUser)
|
||||
}
|
||||
}
|
||||
b.mu.RUnlock()
|
||||
|
||||
n := len(markedForDeletion)
|
||||
if n > 0 {
|
||||
for _, fullUser := range markedForDeletion {
|
||||
b.mu.Lock()
|
||||
delete(b.credentials, fullUser)
|
||||
b.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
|
||||
"github.com/kataras/iris/v12"
|
||||
"github.com/kataras/iris/v12/httptest"
|
||||
"github.com/kataras/iris/v12/middleware/basicauth"
|
||||
basicauth "github.com/kataras/iris/v12/middleware/basicauth"
|
||||
)
|
||||
|
||||
func TestBasicAuthUseRouter(t *testing.T) {
|
||||
|
@ -18,7 +18,13 @@ func TestBasicAuthUseRouter(t *testing.T) {
|
|||
"admin": "admin",
|
||||
}
|
||||
|
||||
app.UseRouter(basicauth.Default(users))
|
||||
auth := basicauth.New(basicauth.Options{
|
||||
Allow: basicauth.AllowUsers(users),
|
||||
Realm: basicauth.DefaultRealm,
|
||||
MaxTries: 1,
|
||||
})
|
||||
|
||||
app.UseRouter(auth)
|
||||
|
||||
app.Get("/user_json", func(ctx iris.Context) {
|
||||
ctx.JSON(ctx.User())
|
||||
|
@ -80,9 +86,9 @@ func TestBasicAuthUseRouter(t *testing.T) {
|
|||
e.GET("/").Expect().Status(httptest.StatusUnauthorized).Body().Equal("Unauthorized")
|
||||
// Test invalid auth.
|
||||
e.GET("/").WithBasicAuth(username, "invalid_password").Expect().
|
||||
Status(httptest.StatusUnauthorized).Body().Equal("Unauthorized")
|
||||
Status(httptest.StatusForbidden)
|
||||
e.GET("/").WithBasicAuth("invaid_username", password).Expect().
|
||||
Status(httptest.StatusUnauthorized).Body().Equal("Unauthorized")
|
||||
Status(httptest.StatusForbidden)
|
||||
|
||||
// Test different method, it should pass the authentication (no stop on 401)
|
||||
// but it doesn't fire the GET route, instead it gives 405.
|
||||
|
@ -97,9 +103,9 @@ func TestBasicAuthUseRouter(t *testing.T) {
|
|||
e.GET("/notfound").Expect().Status(httptest.StatusUnauthorized).Body().Equal("Unauthorized")
|
||||
// Test invalid auth.
|
||||
e.GET("/notfound").WithBasicAuth(username, "invalid_password").Expect().
|
||||
Status(httptest.StatusUnauthorized).Body().Equal("Unauthorized")
|
||||
Status(httptest.StatusForbidden)
|
||||
e.GET("/notfound").WithBasicAuth("invaid_username", password).Expect().
|
||||
Status(httptest.StatusUnauthorized).Body().Equal("Unauthorized")
|
||||
Status(httptest.StatusForbidden)
|
||||
|
||||
// Test subdomain inherited.
|
||||
sub := e.Builder(func(req *httptest.Request) {
|
||||
|
@ -114,9 +120,9 @@ func TestBasicAuthUseRouter(t *testing.T) {
|
|||
sub.GET("/").Expect().Status(httptest.StatusUnauthorized)
|
||||
// Test invalid auth.
|
||||
sub.GET("/").WithBasicAuth(username, "invalid_password").Expect().
|
||||
Status(httptest.StatusUnauthorized).Body().Equal("Unauthorized")
|
||||
Status(httptest.StatusForbidden)
|
||||
sub.GET("/").WithBasicAuth("invaid_username", password).Expect().
|
||||
Status(httptest.StatusUnauthorized).Body().Equal("Unauthorized")
|
||||
Status(httptest.StatusForbidden)
|
||||
|
||||
// Test pass the authentication but route not found.
|
||||
sub.GET("/notfound").WithBasicAuth(username, password).Expect().
|
||||
|
@ -126,9 +132,9 @@ func TestBasicAuthUseRouter(t *testing.T) {
|
|||
sub.GET("/notfound").Expect().Status(httptest.StatusUnauthorized).Body().Equal("Unauthorized")
|
||||
// Test invalid auth.
|
||||
sub.GET("/notfound").WithBasicAuth(username, "invalid_password").Expect().
|
||||
Status(httptest.StatusUnauthorized).Body().Equal("Unauthorized")
|
||||
Status(httptest.StatusForbidden)
|
||||
sub.GET("/notfound").WithBasicAuth("invaid_username", password).Expect().
|
||||
Status(httptest.StatusUnauthorized).Body().Equal("Unauthorized")
|
||||
Status(httptest.StatusForbidden)
|
||||
|
||||
// Test a reset-ed Party with a single one UseRouter
|
||||
// which writes on matched routes and reset and send the error on errors.
|
||||
|
|
|
@ -1,59 +0,0 @@
|
|||
package basicauth
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/kataras/iris/v12/context"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultBasicAuthRealm is "Authorization Required"
|
||||
DefaultBasicAuthRealm = "Authorization Required"
|
||||
)
|
||||
|
||||
// DefaultExpireTime zero time
|
||||
var DefaultExpireTime time.Time // 0001-01-01 00:00:00 +0000 UTC
|
||||
|
||||
// Config the configs for the basicauth middleware
|
||||
type Config struct {
|
||||
// Users a map of login and the value (username/password)
|
||||
Users map[string]string
|
||||
// Realm http://tools.ietf.org/html/rfc2617#section-1.2. Default is "Authorization Required"
|
||||
Realm string
|
||||
// Expires expiration duration, default is 0 never expires.
|
||||
Expires time.Duration
|
||||
|
||||
// OnAsk fires each time the server asks to the client for credentials in order to gain access and continue to the next handler.
|
||||
//
|
||||
// You could also ignore this option and
|
||||
// - just add a listener for unauthorized status codes with:
|
||||
// `app.OnErrorCode(iris.StatusUnauthorized, unauthorizedWantsAccessHandler)`
|
||||
// - or register a middleware which will force `ctx.Next/or direct call`
|
||||
// the basicauth middleware and check its `ctx.GetStatusCode()`.
|
||||
//
|
||||
// However, this option is very useful when you want the framework to fire a handler
|
||||
// ONLY when the Basic Authentication sends an `iris.StatusUnauthorized`,
|
||||
// and free the error code listener to catch other types of unauthorized access, i.e Kerberos.
|
||||
// Also with this one, not recommended at all but, you are able to "force-allow" other users by calling the `ctx.StatusCode` inside this handler;
|
||||
// i.e when it is possible to create authorized users dynamically but
|
||||
// if that is the case then you should go with something like sessions instead of basic authentication.
|
||||
//
|
||||
// Usage: basicauth.New(basicauth.Config{..., OnAsk: unauthorizedWantsAccessViaBasicAuthHandler})
|
||||
//
|
||||
// Defaults to nil.
|
||||
OnAsk context.Handler
|
||||
|
||||
// DisableContextUser disables the registration of the custom basicauth Context.Logout
|
||||
// and the User.
|
||||
DisableContextUser bool
|
||||
}
|
||||
|
||||
// DefaultConfig returns the default configs for the BasicAuth middleware
|
||||
func DefaultConfig() Config {
|
||||
return Config{make(map[string]string), DefaultBasicAuthRealm, 0, nil, false}
|
||||
}
|
||||
|
||||
// User returns the user from context key same as ctx.Request().BasicAuth().
|
||||
func (c Config) User(ctx *context.Context) (string, string, bool) {
|
||||
return ctx.Request().BasicAuth()
|
||||
}
|
88
middleware/basicauth/header.go
Normal file
88
middleware/basicauth/header.go
Normal file
|
@ -0,0 +1,88 @@
|
|||
package basicauth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
spaceChar = ' '
|
||||
colonChar = ':'
|
||||
colonLiteral = string(colonChar)
|
||||
basicLiteral = "Basic"
|
||||
basicSpaceLiteral = "Basic "
|
||||
basicSpaceLiteralLen = len(basicSpaceLiteral)
|
||||
)
|
||||
|
||||
// The username and password are combined with a single colon (:).
|
||||
// This means that the username itself cannot contain a colon.
|
||||
// URL encoding (e.g. https://Aladdin:OpenSesame@www.example.com/index.html)
|
||||
// has been deprecated by rfc3986.
|
||||
func encodeHeader(username, password string) (string, bool) {
|
||||
if strings.Contains(username, colonLiteral) || strings.Contains(password, colonLiteral) {
|
||||
return "", false
|
||||
}
|
||||
fullUser := []byte(username + colonLiteral + password)
|
||||
header := basicSpaceLiteral + base64.StdEncoding.EncodeToString(fullUser)
|
||||
|
||||
return header, true
|
||||
}
|
||||
|
||||
// Like net/http.parseBasicAuth
|
||||
func decodeHeader(header string) (fullUser, username, password string, ok bool) {
|
||||
if len(header) < basicSpaceLiteralLen || !strings.EqualFold(header[:basicSpaceLiteralLen], basicSpaceLiteral) {
|
||||
return
|
||||
}
|
||||
|
||||
c, err := base64.StdEncoding.DecodeString(header[basicSpaceLiteralLen:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cs := string(c)
|
||||
s := strings.IndexByte(cs, colonChar)
|
||||
if s < 0 {
|
||||
return
|
||||
}
|
||||
return cs, cs[:s], cs[s+1:], true
|
||||
|
||||
/*
|
||||
for i := 0; i < n; i++ {
|
||||
if header[i] == spaceChar {
|
||||
prefix := header[:i]
|
||||
if prefix != basicLiteral {
|
||||
return
|
||||
}
|
||||
|
||||
if n <= i+1 {
|
||||
return
|
||||
}
|
||||
|
||||
decodedFullUser, err := base64.RawStdEncoding.DecodeString(header[i+1:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
fullUser = string(decodedFullUser)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
n = len(fullUser)
|
||||
for i := n - 1; i > -1; i-- {
|
||||
if fullUser[i] == colonChar {
|
||||
username = fullUser[:i]
|
||||
password = fullUser[i+1:]
|
||||
|
||||
if strings.TrimSpace(username) == "" || strings.TrimSpace(password) == "" {
|
||||
ok = false
|
||||
} else {
|
||||
ok = true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return*/
|
||||
}
|
103
middleware/basicauth/header_test.go
Normal file
103
middleware/basicauth/header_test.go
Normal file
|
@ -0,0 +1,103 @@
|
|||
package basicauth
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestHeaderEncode(t *testing.T) {
|
||||
var tests = []struct {
|
||||
username string
|
||||
password string
|
||||
header string
|
||||
ok bool
|
||||
}{
|
||||
{
|
||||
username: "user",
|
||||
password: "pass",
|
||||
header: "Basic dXNlcjpwYXNz",
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
username: "user",
|
||||
password: "p:(notallowed)ass",
|
||||
header: "",
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
username: "123u%ser",
|
||||
password: "pass132$",
|
||||
header: "Basic MTIzdSVzZXI6cGFzczEzMiQ=",
|
||||
ok: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
got, ok := encodeHeader(tt.username, tt.password)
|
||||
if tt.ok != ok {
|
||||
t.Fatalf("[%d] expected: %v but got: %v (username=%s,password=%s)", i, tt.ok, ok, tt.username, tt.password)
|
||||
}
|
||||
if tt.header != got {
|
||||
t.Fatalf("[%d] expected result header: %q but got: %q", i, tt.header, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderDecode(t *testing.T) {
|
||||
var tests = []struct {
|
||||
header string
|
||||
ok bool
|
||||
username string
|
||||
password string
|
||||
}{
|
||||
{
|
||||
header: "Basic dXNlcjpwYXNz",
|
||||
ok: true,
|
||||
username: "user",
|
||||
password: "pass",
|
||||
},
|
||||
{
|
||||
header: "dXNlcjpwYXNz",
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
header: "Basic ",
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
header: "Basic dXNlcjp",
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
header: "dXNlcjpwYXNz Basic",
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
header: "dXNlcjpwYXNzBasic",
|
||||
ok: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
fullUser, username, password, ok := decodeHeader(tt.header)
|
||||
if expected, got := tt.ok, ok; expected != got {
|
||||
t.Fatalf("[%d] expected: %v but got: %v (header=%s)", i, expected, got, tt.header)
|
||||
}
|
||||
|
||||
if expected, got := tt.username, username; expected != got {
|
||||
t.Fatalf("[%d] expected username: %q but got: %q", i, expected, got)
|
||||
}
|
||||
|
||||
if expected, got := tt.password, password; expected != got {
|
||||
t.Fatalf("[%d] expected password: %q but got: %q", i, expected, got)
|
||||
}
|
||||
|
||||
if tt.username != "" || tt.password != "" {
|
||||
if expected, got := tt.username+colonLiteral+tt.password, fullUser; expected != got {
|
||||
t.Fatalf("[%d] expected username:password to be: %q but got: %q", i, expected, got)
|
||||
}
|
||||
} else {
|
||||
if fullUser != "" {
|
||||
t.Fatalf("[%d] expected username:password to be empty but got: %q", i, fullUser)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
256
middleware/basicauth/user_auth.go
Normal file
256
middleware/basicauth/user_auth.go
Normal file
|
@ -0,0 +1,256 @@
|
|||
package basicauth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/kataras/iris/v12/context"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type UserAuthOptions struct {
|
||||
// Defaults to plain check, can be modified for encrypted passwords, see `BCRYPT`.
|
||||
ComparePassword func(stored, userPassword string) bool
|
||||
}
|
||||
|
||||
type UserAuthOption func(*UserAuthOptions)
|
||||
|
||||
func BCRYPT(opts *UserAuthOptions) {
|
||||
opts.ComparePassword = func(stored, userPassword string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(stored), []byte(userPassword))
|
||||
return err == nil
|
||||
}
|
||||
}
|
||||
|
||||
func toUserAuthOptions(opts []UserAuthOption) (options UserAuthOptions) {
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
if options.ComparePassword == nil {
|
||||
options.ComparePassword = func(stored, userPassword string) bool {
|
||||
return stored == userPassword
|
||||
}
|
||||
}
|
||||
|
||||
return options
|
||||
}
|
||||
|
||||
type User interface {
|
||||
context.UserGetUsername
|
||||
context.UserGetPassword
|
||||
}
|
||||
|
||||
// Users
|
||||
// - map[string]string form of: {username:password, ...} form.
|
||||
// - map[string]interface{} form of: []{"username": "...", "password": "...", "other_field": ...}, ...}.
|
||||
// - []T which T completes the User interface.
|
||||
// - []T which T contains at least Username and Password fields.
|
||||
func AllowUsers(users interface{}, opts ...UserAuthOption) AuthFunc {
|
||||
// create a local user structure to be used in the map copy,
|
||||
// takes longer to initialize but faster to serve.
|
||||
type user struct {
|
||||
password string
|
||||
ref interface{}
|
||||
}
|
||||
cp := make(map[string]*user)
|
||||
|
||||
v := reflect.Indirect(reflect.ValueOf(users))
|
||||
switch v.Kind() {
|
||||
case reflect.Slice:
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
elem := v.Index(i).Interface()
|
||||
// MUST contain a username and password.
|
||||
username, password, ok := extractUsernameAndPassword(elem)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
cp[username] = &user{
|
||||
password: password,
|
||||
ref: elem,
|
||||
}
|
||||
}
|
||||
case reflect.Map:
|
||||
elem := v.Interface()
|
||||
switch m := elem.(type) {
|
||||
case map[string]string:
|
||||
return userMap(m, opts...)
|
||||
case map[string]interface{}:
|
||||
username, password, ok := mapUsernameAndPassword(m)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
cp[username] = &user{
|
||||
password: password,
|
||||
ref: m,
|
||||
}
|
||||
default:
|
||||
panic(fmt.Sprintf("unsupported type of map: %T", users))
|
||||
}
|
||||
default:
|
||||
panic(fmt.Sprintf("unsupported type: %T", users))
|
||||
}
|
||||
|
||||
options := toUserAuthOptions(opts)
|
||||
|
||||
return func(_ *context.Context, username, password string) (interface{}, bool) {
|
||||
if u, ok := cp[username]; ok { // fast map access,
|
||||
if options.ComparePassword(u.password, password) {
|
||||
return u.ref, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func userMap(usernamePassword map[string]string, opts ...UserAuthOption) AuthFunc {
|
||||
options := toUserAuthOptions(opts)
|
||||
|
||||
return func(_ *context.Context, username, password string) (interface{}, bool) {
|
||||
pass, ok := usernamePassword[username]
|
||||
return nil, ok && options.ComparePassword(pass, password)
|
||||
}
|
||||
}
|
||||
|
||||
func AllowUsersFile(jsonOrYamlFilename string, opts ...UserAuthOption) AuthFunc {
|
||||
var (
|
||||
usernamePassword map[string]string
|
||||
// no need to support too much forms, this would be for:
|
||||
// "$username": { "password": "$pass", "other_field": ...}
|
||||
// users map[string]map[string]interface{}
|
||||
userList []map[string]interface{}
|
||||
)
|
||||
|
||||
if err := decodeFile(jsonOrYamlFilename, &usernamePassword, &userList); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if len(usernamePassword) > 0 {
|
||||
// JSON Form: { "$username":"$pass", "$username": "$pass" }
|
||||
// YAML Form: $username: $pass
|
||||
// $username: $pass
|
||||
return userMap(usernamePassword, opts...)
|
||||
}
|
||||
|
||||
if len(userList) > 0 {
|
||||
// JSON Form: [{"username": "$username", "password": "$pass", "other_field": ...}, {"username": ...}, ... ]
|
||||
// YAML Form:
|
||||
// - username: $username
|
||||
// password: $password
|
||||
// other_field: ...
|
||||
return AllowUsers(userList, opts...)
|
||||
}
|
||||
|
||||
panic("malformed document file: " + jsonOrYamlFilename)
|
||||
}
|
||||
|
||||
func decodeFile(src string, dest ...interface{}) error {
|
||||
data, err := ioutil.ReadFile(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// We use unmarshal instead of file decoder
|
||||
// as we may need to read it more than once (dests, see below).
|
||||
var (
|
||||
unmarshal func(data []byte, v interface{}) error
|
||||
ext string
|
||||
)
|
||||
|
||||
if idx := strings.LastIndexByte(src, '.'); idx > 0 {
|
||||
ext = src[idx:]
|
||||
}
|
||||
|
||||
switch ext {
|
||||
case "", ".json":
|
||||
unmarshal = json.Unmarshal
|
||||
case ".yml", ".yaml":
|
||||
unmarshal = yaml.Unmarshal
|
||||
default:
|
||||
return fmt.Errorf("unexpected file extension: %s", ext)
|
||||
}
|
||||
|
||||
var (
|
||||
ok bool
|
||||
lastErr error
|
||||
)
|
||||
|
||||
for _, d := range dest {
|
||||
if err = unmarshal(data, d); err == nil {
|
||||
ok = true
|
||||
} else {
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
|
||||
if !ok {
|
||||
return lastErr
|
||||
}
|
||||
|
||||
return nil // if at least one is succeed we are ok.
|
||||
}
|
||||
|
||||
func extractUsernameAndPassword(s interface{}) (username, password string, ok bool) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch u := s.(type) {
|
||||
case User:
|
||||
username = u.GetUsername()
|
||||
password = u.GetPassword()
|
||||
ok = username != "" && password != ""
|
||||
return
|
||||
case map[string]interface{}:
|
||||
return mapUsernameAndPassword(u)
|
||||
default:
|
||||
b, err := json.Marshal(u)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var m map[string]interface{}
|
||||
if err = json.Unmarshal(b, &m); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return mapUsernameAndPassword(m)
|
||||
}
|
||||
}
|
||||
|
||||
func mapUsernameAndPassword(m map[string]interface{}) (username, password string, ok bool) {
|
||||
// type of username: password.
|
||||
if len(m) == 1 {
|
||||
for username, v := range m {
|
||||
if password, ok := v.(string); ok {
|
||||
ok := username != "" && password != ""
|
||||
return username, password, ok
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var usernameFound, passwordFound bool
|
||||
|
||||
for k, v := range m {
|
||||
switch k {
|
||||
case "username", "Username":
|
||||
username, usernameFound = v.(string)
|
||||
case "password", "Password":
|
||||
password, passwordFound = v.(string)
|
||||
}
|
||||
|
||||
if usernameFound && passwordFound {
|
||||
ok = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
287
middleware/basicauth/user_auth_test.go
Normal file
287
middleware/basicauth/user_auth_test.go
Normal file
|
@ -0,0 +1,287 @@
|
|||
package basicauth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/kataras/iris/v12/context"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type IUserRepository interface {
|
||||
GetByUsernameAndPassword(dest interface{}, username, password string) error
|
||||
}
|
||||
|
||||
// Test a custom implementation of AuthFunc with a user repository.
|
||||
// This is a usage example of custom AuthFunc implementation.
|
||||
func UserRepository(repo IUserRepository, newUserPtr func() interface{}) AuthFunc {
|
||||
return func(ctx *context.Context, username, password string) (interface{}, bool) {
|
||||
dest := newUserPtr()
|
||||
err := repo.GetByUsernameAndPassword(dest, username, password)
|
||||
if err == nil {
|
||||
return dest, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
type testUser struct {
|
||||
username string
|
||||
password string
|
||||
email string // custom field.
|
||||
}
|
||||
|
||||
// GetUsername & Getpassword complete the User interface (optional but useful on Context.User()).
|
||||
func (u *testUser) GetUsername() string {
|
||||
return u.username
|
||||
}
|
||||
|
||||
func (u *testUser) GetPassword() string {
|
||||
return u.password
|
||||
}
|
||||
|
||||
type testRepo struct {
|
||||
entries []testUser
|
||||
}
|
||||
|
||||
// Implements IUserRepository interface.
|
||||
func (r *testRepo) GetByUsernameAndPassword(dest interface{}, username, password string) error {
|
||||
for _, e := range r.entries {
|
||||
if e.username == username && e.password == password {
|
||||
*dest.(*testUser) = e
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return errors.New("invalid credentials")
|
||||
}
|
||||
|
||||
func TestAllowUserRepository(t *testing.T) {
|
||||
repo := &testRepo{
|
||||
entries: []testUser{
|
||||
{username: "kataras", password: "kataras_pass", email: "kataras2006@hotmail.com"},
|
||||
},
|
||||
}
|
||||
|
||||
allow := UserRepository(repo, func() interface{} {
|
||||
return new(testUser)
|
||||
})
|
||||
|
||||
var tests = []struct {
|
||||
username string
|
||||
password string
|
||||
ok bool
|
||||
user *testUser
|
||||
}{
|
||||
{
|
||||
username: "kataras",
|
||||
password: "kataras_pass",
|
||||
ok: true,
|
||||
user: &testUser{username: "kataras", password: "kataras_pass", email: "kataras2006@hotmail.com"},
|
||||
},
|
||||
{
|
||||
username: "makis",
|
||||
password: "makis_password",
|
||||
ok: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
v, ok := allow(nil, tt.username, tt.password)
|
||||
|
||||
if tt.ok != ok {
|
||||
t.Fatalf("[%d] expected: %v but got: %v (username=%s,password=%s)", i, tt.ok, ok, tt.username, tt.password)
|
||||
}
|
||||
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
u, ok := v.(*testUser)
|
||||
if !ok {
|
||||
t.Fatalf("[%d] a user should be type of *testUser but got: %#+v (%T)", i, v, v)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tt.user, u) {
|
||||
t.Fatalf("[%d] expected user:\n%#+v\nbut got:\n%#+v", i, tt.user, u)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllowUsers(t *testing.T) {
|
||||
users := []User{
|
||||
&testUser{username: "kataras", password: "kataras_pass", email: "kataras2006@hotmail.com"},
|
||||
}
|
||||
|
||||
allow := AllowUsers(users)
|
||||
|
||||
var tests = []struct {
|
||||
username string
|
||||
password string
|
||||
ok bool
|
||||
user *testUser
|
||||
}{
|
||||
{
|
||||
username: "kataras",
|
||||
password: "kataras_pass",
|
||||
ok: true,
|
||||
user: &testUser{username: "kataras", password: "kataras_pass", email: "kataras2006@hotmail.com"},
|
||||
},
|
||||
{
|
||||
username: "makis",
|
||||
password: "makis_password",
|
||||
ok: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
v, ok := allow(nil, tt.username, tt.password)
|
||||
|
||||
if tt.ok != ok {
|
||||
t.Fatalf("[%d] expected: %v but got: %v (username=%s,password=%s)", i, tt.ok, ok, tt.username, tt.password)
|
||||
}
|
||||
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
u, ok := v.(*testUser)
|
||||
if !ok {
|
||||
t.Fatalf("[%d] a user should be type of *testUser but got: %#+v (%T)", i, v, v)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tt.user, u) {
|
||||
t.Fatalf("[%d] expected user:\n%#+v\nbut got:\n%#+v", i, tt.user, u)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test YAML user loading with b-encrypted passwords.
|
||||
func TestAllowUsersFile(t *testing.T) {
|
||||
f, err := ioutil.TempFile("", "*users.yml")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
f.Close()
|
||||
os.Remove(f.Name())
|
||||
}()
|
||||
|
||||
// f.WriteString(`
|
||||
// - username: kataras
|
||||
// password: kataras_pass
|
||||
// age: 27
|
||||
// role: admin
|
||||
// - username: makis
|
||||
// password: makis_password
|
||||
// `)
|
||||
// This form is supported too, although its features are limited (no custom fields):
|
||||
// f.WriteString(`
|
||||
// kataras: kataras_pass
|
||||
// makis: makis_password
|
||||
// `)
|
||||
|
||||
var tests = []struct {
|
||||
username string
|
||||
password string // hashed, auto-filled later on.
|
||||
inputPassword string
|
||||
ok bool
|
||||
user context.Map
|
||||
}{
|
||||
{
|
||||
username: "kataras",
|
||||
inputPassword: "kataras_pass",
|
||||
ok: true,
|
||||
user: context.Map{"age": 27, "role": "admin"}, // username and password are auto-filled in our tests below.
|
||||
},
|
||||
{
|
||||
username: "makis",
|
||||
inputPassword: "makis_password",
|
||||
ok: true,
|
||||
user: context.Map{},
|
||||
},
|
||||
{
|
||||
username: "invalid",
|
||||
password: "invalid_pass",
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
username: "notvalid",
|
||||
password: "",
|
||||
ok: false,
|
||||
},
|
||||
}
|
||||
|
||||
// Write the tests to the users YAML file.
|
||||
var usersToWrite []context.Map
|
||||
for _, tt := range tests {
|
||||
if tt.ok {
|
||||
// store the hashed password.
|
||||
tt.password = mustGeneratePassword(t, tt.inputPassword)
|
||||
// store and write the username and hashed password.
|
||||
tt.user["username"] = tt.username
|
||||
tt.user["password"] = tt.password
|
||||
|
||||
// cannot write it as a stream, write it as a slice.
|
||||
// enc.Encode(tt.user)
|
||||
usersToWrite = append(usersToWrite, tt.user)
|
||||
}
|
||||
// bcrypt.GenerateFromPassword([]byte("kataras_pass"), bcrypt.DefaultCost)
|
||||
}
|
||||
|
||||
fileContents, err := yaml.Marshal(usersToWrite)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
f.Write(fileContents)
|
||||
|
||||
// Build the authentication func.
|
||||
allow := AllowUsersFile(f.Name(), BCRYPT)
|
||||
for i, tt := range tests {
|
||||
v, ok := allow(nil, tt.username, tt.inputPassword)
|
||||
|
||||
if tt.ok != ok {
|
||||
t.Fatalf("[%d] expected: %v but got: %v (username=%s,password=%s,user=%#+v)", i, tt.ok, ok, tt.username, tt.inputPassword, v)
|
||||
}
|
||||
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(tt.user) == 0 { // when username: password form.
|
||||
continue
|
||||
}
|
||||
|
||||
u, ok := v.(context.Map)
|
||||
if !ok {
|
||||
t.Fatalf("[%d] a user loaded from external source or file should be alway type of map[string]interface{} but got: %#+v (%T)", i, v, v)
|
||||
}
|
||||
|
||||
if expected, got := len(tt.user), len(u); expected != got {
|
||||
t.Fatalf("[%d] expected user map length to be equal, expected: %d but got: %d\n%#+v\n%#+v", i, expected, got, tt.user, u)
|
||||
}
|
||||
|
||||
for k, v := range tt.user {
|
||||
if u[k] != v {
|
||||
t.Fatalf("[%d] expected user map %q to be %q but got: %q", i, k, v, u[k])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func mustGeneratePassword(t *testing.T, userPassword string) string {
|
||||
t.Helper()
|
||||
hashed, err := bcrypt.GenerateFromPassword([]byte(userPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return string(hashed)
|
||||
}
|
Loading…
Reference in New Issue
Block a user