diff --git a/middleware/cors/cors.go b/middleware/cors/cors.go index 0e137ad4..58c0da23 100644 --- a/middleware/cors/cors.go +++ b/middleware/cors/cors.go @@ -2,6 +2,7 @@ package cors import ( "errors" + "regexp" "strconv" "strings" "time" @@ -11,7 +12,7 @@ import ( ) func init() { - context.SetHandlerName("iris/middleware/basicauth.*", "iris.cors") + context.SetHandlerName("iris/middleware/cors.*", "iris.cors") } var ( @@ -82,6 +83,69 @@ func (c *CORS) AllowOriginFunc(fn AllowOriginFunc) *CORS { return c } +// AllowOrigin calls the "AllowOriginFunc" method +// and registers a function which accepts any incoming +// request with origin of the given "originLine". +// The originLine can contain one or more domains separated by comma. +// See "AllowOrigins" to set a list of strings instead. +func (c *CORS) AllowOrigin(originLine string) *CORS { + return c.AllowOrigins(strings.Split(originLine, ",")...) +} + +// AllowOriginMatcherFunc sets the allow origin func without iris.Context +// as its first parameter, i.e. a regular expression. +func (c *CORS) AllowOriginMatcherFunc(fn func(origin string) bool) *CORS { + return c.AllowOriginFunc(func(ctx iris.Context, origin string) bool { + return fn(origin) + }) +} + +// AllowOriginRegex calls the "AllowOriginFunc" method +// and registers a function which accepts any incoming +// request with origin that matches at least one of the given "regexpLines". +func (c *CORS) AllowOriginRegex(regexpLines ...string) *CORS { + matchers := make([]func(string) bool, 0, len(regexpLines)) + for _, line := range regexpLines { + matcher := regexp.MustCompile(line).MatchString + matchers = append(matchers, matcher) + } + + return c.AllowOriginFunc(func(ctx iris.Context, origin string) bool { + for _, m := range matchers { + if m(origin) { + return true + } + } + + return false + }) +} + +// AllowOrigins calls the "AllowOriginFunc" method +// and registers a function which accepts any incoming +// request with origin of one of the given "origins". +func (c *CORS) AllowOrigins(origins ...string) *CORS { + allowOrigins := make(map[string]struct{}, len(origins)) // read-only at serve time. + for _, origin := range origins { + if origin == "*" { + // If AllowOrigins called with asterix, it is a missuse of this + // middleware (set AllowAnyOrigin instead). + allowOrigins = nil + return c.AllowOriginFunc(AllowAnyOrigin) + // panic("wildcard is not allowed, use AllowOriginFunc(AllowAnyOrigin) instead") + // No ^ let's register a function which allows all and continue. + } + + origin = strings.TrimSpace(origin) + allowOrigins[origin] = struct{}{} + } + + return c.AllowOriginFunc(func(ctx iris.Context, origin string) bool { + _, allow := allowOrigins[origin] + return allow + }) +} + // HandleErrorFunc sets the function which is called // when an error of origin not allowed is fired. func (c *CORS) HandleErrorFunc(fn HandleErrorFunc) *CORS {