mirror of
https://github.com/kataras/iris.git
synced 2025-01-23 10:41:03 +01:00
321 lines
11 KiB
Go
321 lines
11 KiB
Go
package cors
|
|
|
|
import (
|
|
"errors"
|
|
"net/http"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/kataras/iris/v12/context"
|
|
)
|
|
|
|
func init() {
|
|
context.SetHandlerName("iris/middleware/cors.*", "iris.cors")
|
|
}
|
|
|
|
var (
|
|
// ErrOriginNotAllowed is given to the error handler
|
|
// when the error is caused because an origin was not allowed to pass through.
|
|
ErrOriginNotAllowed = errors.New("origin not allowed")
|
|
|
|
// AllowAnyOrigin allows all origins to pass.
|
|
AllowAnyOrigin = func(_ *context.Context, _ string) bool {
|
|
return true
|
|
}
|
|
|
|
// DefaultErrorHandler is the default error handler which
|
|
// fires forbidden status (403) on disallowed origins.
|
|
DefaultErrorHandler = func(ctx *context.Context, _ error) {
|
|
ctx.StopWithStatus(http.StatusForbidden)
|
|
}
|
|
|
|
// DefaultOriginExtractor is the default method which
|
|
// an origin is extracted. It returns the value of the request's "Origin" header
|
|
// and always true, means that it allows empty origin headers as well.
|
|
DefaultOriginExtractor = func(ctx *context.Context) (string, bool) {
|
|
header := ctx.GetHeader(originRequestHeader)
|
|
return header, true
|
|
}
|
|
|
|
// StrictOriginExtractor is an ExtractOriginFunc type
|
|
// which is a bit more strictly than the DefaultOriginExtractor.
|
|
// It allows only non-empty "Origin" header values to be passed.
|
|
// If the header is missing, the middleware will not allow the execution
|
|
// of the next handler(s).
|
|
StrictOriginExtractor = func(ctx *context.Context) (string, bool) {
|
|
header := ctx.GetHeader(originRequestHeader)
|
|
return header, header != ""
|
|
}
|
|
)
|
|
|
|
type (
|
|
// ExtractOriginFunc describes the function which should return the request's origin or false.
|
|
ExtractOriginFunc = func(ctx *context.Context) (string, bool)
|
|
|
|
// AllowOriginFunc describes the function which is called when the
|
|
// middleware decides if the request's origin should be allowed or not.
|
|
AllowOriginFunc = func(ctx *context.Context, origin string) bool
|
|
|
|
// HandleErrorFunc describes the function which is fired
|
|
// when a request by a specific (or empty) origin was not allowed to pass through.
|
|
HandleErrorFunc = func(ctx *context.Context, err error)
|
|
|
|
// CORS holds the customizations developers can
|
|
// do on the cors middleware.
|
|
//
|
|
// Read more at: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS.
|
|
CORS struct {
|
|
extractOriginFunc ExtractOriginFunc
|
|
allowOriginFunc AllowOriginFunc
|
|
errorHandler HandleErrorFunc
|
|
|
|
allowCredentialsValue string
|
|
exposeHeadersValue string
|
|
allowHeadersValue string
|
|
allowMethodsValue string
|
|
maxAgeSecondsValue string
|
|
referrerPolicyValue string
|
|
}
|
|
)
|
|
|
|
// New returns the default CORS middleware.
|
|
// For a more advanced type of protection middleware with more options
|
|
// please refer to: https://github.com/iris-contrib/middleware repository instead.
|
|
//
|
|
// Example Code:
|
|
//
|
|
// import "github.com/kataras/iris/v12/middleware/cors"
|
|
// import "github.com/kataras/iris/v12/x/errors"
|
|
//
|
|
// app.UseRouter(cors.New().
|
|
// HandleErrorFunc(func(ctx iris.Context, err error) {
|
|
// errors.FailedPrecondition.Err(ctx, err)
|
|
// }).
|
|
// ExtractOriginFunc(cors.StrictOriginExtractor).
|
|
// ReferrerPolicy(cors.NoReferrerWhenDowngrade).
|
|
// AllowOrigin("domain1.com,domain2.com,domain3.com").
|
|
// Handler())
|
|
func New() *CORS {
|
|
return &CORS{
|
|
extractOriginFunc: DefaultOriginExtractor,
|
|
allowOriginFunc: AllowAnyOrigin,
|
|
errorHandler: DefaultErrorHandler,
|
|
|
|
allowCredentialsValue: "true",
|
|
exposeHeadersValue: "*, Authorization, X-Authorization",
|
|
allowHeadersValue: "*",
|
|
// This field cannot be modified by the end-developer,
|
|
// as we have another type of controlling the HTTP verbs per handler.
|
|
allowMethodsValue: "*",
|
|
maxAgeSecondsValue: "86400",
|
|
referrerPolicyValue: NoReferrerWhenDowngrade.String(),
|
|
}
|
|
}
|
|
|
|
// ExtractOriginFunc sets the function which should return the request's origin.
|
|
func (c *CORS) ExtractOriginFunc(fn ExtractOriginFunc) *CORS {
|
|
c.extractOriginFunc = fn
|
|
return c
|
|
}
|
|
|
|
// AllowOriginFunc sets the function which decides if an origin(domain) is allowed
|
|
// to continue or not.
|
|
//
|
|
// Read more at: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#access-control-allow-origin.
|
|
func (c *CORS) AllowOriginFunc(fn AllowOriginFunc) *CORS {
|
|
c.allowOriginFunc = fn
|
|
return c
|
|
}
|
|
|
|
// AllowOrigin calls the "AllowOriginFunc" method
|
|
// and registers a function which accepts any incoming
|
|
// request with origin of the given "originLine".
|
|
// The originLine can contain one or more domains separated by comma.
|
|
// See "AllowOrigins" to set a list of strings instead.
|
|
func (c *CORS) AllowOrigin(originLine string) *CORS {
|
|
return c.AllowOrigins(strings.Split(originLine, ",")...)
|
|
}
|
|
|
|
// AllowOriginMatcherFunc sets the allow origin func without iris.Context
|
|
// as its first parameter, i.e. a regular expression.
|
|
func (c *CORS) AllowOriginMatcherFunc(fn func(origin string) bool) *CORS {
|
|
return c.AllowOriginFunc(func(ctx *context.Context, origin string) bool {
|
|
return fn(origin)
|
|
})
|
|
}
|
|
|
|
// AllowOriginRegex calls the "AllowOriginFunc" method
|
|
// and registers a function which accepts any incoming
|
|
// request with origin that matches at least one of the given "regexpLines".
|
|
func (c *CORS) AllowOriginRegex(regexpLines ...string) *CORS {
|
|
matchers := make([]func(string) bool, 0, len(regexpLines))
|
|
for _, line := range regexpLines {
|
|
matcher := regexp.MustCompile(line).MatchString
|
|
matchers = append(matchers, matcher)
|
|
}
|
|
|
|
return c.AllowOriginFunc(func(ctx *context.Context, origin string) bool {
|
|
for _, m := range matchers {
|
|
if m(origin) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
})
|
|
}
|
|
|
|
// AllowOrigins calls the "AllowOriginFunc" method
|
|
// and registers a function which accepts any incoming
|
|
// request with origin of one of the given "origins".
|
|
func (c *CORS) AllowOrigins(origins ...string) *CORS {
|
|
allowOrigins := make(map[string]struct{}, len(origins)) // read-only at serve time.
|
|
for _, origin := range origins {
|
|
if origin == "*" {
|
|
// If AllowOrigins called with asterix, it is a missuse of this
|
|
// middleware (set AllowAnyOrigin instead).
|
|
allowOrigins = nil
|
|
return c.AllowOriginFunc(AllowAnyOrigin)
|
|
// panic("wildcard is not allowed, use AllowOriginFunc(AllowAnyOrigin) instead")
|
|
// No ^ let's register a function which allows all and continue.
|
|
}
|
|
|
|
origin = strings.TrimSpace(origin)
|
|
allowOrigins[origin] = struct{}{}
|
|
}
|
|
|
|
return c.AllowOriginFunc(func(ctx *context.Context, origin string) bool {
|
|
_, allow := allowOrigins[origin]
|
|
return allow
|
|
})
|
|
}
|
|
|
|
// HandleErrorFunc sets the function which is called
|
|
// when an error of origin not allowed is fired.
|
|
func (c *CORS) HandleErrorFunc(fn HandleErrorFunc) *CORS {
|
|
c.errorHandler = fn
|
|
return c
|
|
}
|
|
|
|
// DisallowCredentials sets the "Access-Control-Allow-Credentials" header to false.
|
|
//
|
|
// Read more at: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#access-control-allow-credentials.
|
|
func (c *CORS) DisallowCredentials() *CORS {
|
|
c.allowCredentialsValue = "false"
|
|
return c
|
|
}
|
|
|
|
// ExposeHeaders sets the "Access-Control-Expose-Headers" header value.
|
|
//
|
|
// Read more at: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#access-control-expose-headers.
|
|
func (c *CORS) ExposeHeaders(headers ...string) *CORS {
|
|
c.exposeHeadersValue = strings.Join(headers, ", ")
|
|
return c
|
|
}
|
|
|
|
// AllowHeaders sets the "Access-Control-Allow-Headers" header value.
|
|
//
|
|
// Read more at: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#access-control-allow-headers.
|
|
func (c *CORS) AllowHeaders(headers ...string) *CORS {
|
|
c.allowHeadersValue = strings.Join(headers, ", ")
|
|
return c
|
|
}
|
|
|
|
// ReferrerPolicy is the type of the values accepted
// by the "Referrer-Policy" header.
type ReferrerPolicy string

// All available referrer policies.
// Read more at: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Referrer-Policy.
const (
	// NoReferrer is the "no-referrer" policy value.
	NoReferrer ReferrerPolicy = "no-referrer"
	// NoReferrerWhenDowngrade is the "no-referrer-when-downgrade" policy value.
	NoReferrerWhenDowngrade ReferrerPolicy = "no-referrer-when-downgrade"
	// Origin is the "origin" policy value.
	Origin ReferrerPolicy = "origin"
	// OriginWhenCrossOrigin is the "origin-when-cross-origin" policy value.
	OriginWhenCrossOrigin ReferrerPolicy = "origin-when-cross-origin"
	// SameOrigin is the "same-origin" policy value.
	SameOrigin ReferrerPolicy = "same-origin"
	// StrictOrigin is the "strict-origin" policy value.
	StrictOrigin ReferrerPolicy = "strict-origin"
	// StrictOriginWhenCrossOrigin is the "strict-origin-when-cross-origin" policy value.
	StrictOriginWhenCrossOrigin ReferrerPolicy = "strict-origin-when-cross-origin"
	// UnsafeURL is the "unsafe-url" policy value.
	UnsafeURL ReferrerPolicy = "unsafe-url"
)

// String returns the header text of the "r" ReferrerPolicy,
// which is its underlying string value unchanged.
func (r ReferrerPolicy) String() string {
	return string(r)
}
|
|
|
|
// ReferrerPolicy sets the "Referrer-Policy" header value.
|
|
// Defaults to "no-referrer-when-downgrade".
|
|
//
|
|
// Read more at: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Referrer-Policy
|
|
// and https://developer.mozilla.org/en-US/docs/Web/Security/Referer_header:_privacy_and_security_concerns.
|
|
func (c *CORS) ReferrerPolicy(referrerPolicy ReferrerPolicy) *CORS {
|
|
c.referrerPolicyValue = referrerPolicy.String()
|
|
return c
|
|
}
|
|
|
|
// MaxAge sets the "Access-Control-Max-Age" header value.
|
|
//
|
|
// Read more at: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#access-control-max-age.
|
|
func (c *CORS) MaxAge(d time.Duration) *CORS {
|
|
c.maxAgeSecondsValue = strconv.FormatFloat(d.Seconds(), 'E', -1, 64)
|
|
return c
|
|
}
|
|
|
|
// Header names (and the one fixed header value) used by this middleware.
const (
	// Request headers.
	originRequestHeader  = "Origin"
	requestMethodHeader  = "Access-Control-Request-Method"
	requestHeadersHeader = "Access-Control-Request-Headers"

	// Response headers.
	allowOriginHeader      = "Access-Control-Allow-Origin"
	allowCredentialsHeader = "Access-Control-Allow-Credentials"
	allowMethodsHeader     = "Access-Control-Allow-Methods"
	allowHeadersHeader     = "Access-Control-Allow-Headers"
	exposeHeadersHeader    = "Access-Control-Expose-Headers"
	maxAgeHeader           = "Access-Control-Max-Age"
	referrerPolicyHeader   = "Referrer-Policy"
	varyHeader             = "Vary"

	// allowAllMethodsValue is the "Access-Control-Allow-Methods" value
	// sent on OPTIONS requests.
	allowAllMethodsValue = "*"
)
|
|
|
|
func (c *CORS) addVaryHeaders(ctx *context.Context) {
|
|
ctx.Header(varyHeader, originRequestHeader)
|
|
|
|
if ctx.Method() == http.MethodOptions {
|
|
ctx.Header(varyHeader, requestMethodHeader)
|
|
ctx.Header(varyHeader, requestHeadersHeader)
|
|
}
|
|
}
|
|
|
|
// Handler method returns the Iris CORS Handler with basic features.
|
|
// Note that the caller should NOT modify any of the CORS instance fields afterwards.
|
|
func (c *CORS) Handler() context.Handler {
|
|
return func(ctx *context.Context) {
|
|
c.addVaryHeaders(ctx) // add vary headers at any case.
|
|
|
|
origin, ok := c.extractOriginFunc(ctx)
|
|
if !ok || !c.allowOriginFunc(ctx, origin) {
|
|
c.errorHandler(ctx, ErrOriginNotAllowed)
|
|
return
|
|
}
|
|
|
|
if origin == "" { // if we allow empty origins, set it to wildcard.
|
|
origin = "*"
|
|
}
|
|
|
|
ctx.Header(allowOriginHeader, origin)
|
|
ctx.Header(allowCredentialsHeader, c.allowCredentialsValue)
|
|
// 08 July 2021 Mozzila updated the following document: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Referrer-Policy
|
|
ctx.Header(referrerPolicyHeader, c.referrerPolicyValue)
|
|
ctx.Header(exposeHeadersHeader, c.exposeHeadersValue)
|
|
if ctx.Method() == http.MethodOptions {
|
|
ctx.Header(allowMethodsHeader, allowAllMethodsValue)
|
|
ctx.Header(allowHeadersHeader, c.allowHeadersValue)
|
|
ctx.Header(maxAgeHeader, c.maxAgeSecondsValue)
|
|
ctx.StatusCode(http.StatusNoContent)
|
|
return
|
|
}
|
|
|
|
ctx.Next()
|
|
}
|
|
}
|