package cors

import (
	"errors"
	"net/http"
	"regexp"
	"strconv"
	"strings"
	"time"

	"github.com/kataras/iris/v12/context"
)

func init() {
	context.SetHandlerName("iris/middleware/cors.*", "iris.cors")
}

var (
	// ErrOriginNotAllowed is given to the error handler
	// when the error is caused because an origin was not allowed to pass through.
	ErrOriginNotAllowed = errors.New("origin not allowed")

	// AllowAnyOrigin allows all origins to pass.
	AllowAnyOrigin = func(_ *context.Context, _ string) bool {
		return true
	}

	// DefaultErrorHandler is the default error handler which
	// fires forbidden status (403) on disallowed origins.
	DefaultErrorHandler = func(ctx *context.Context, _ error) {
		ctx.StopWithStatus(http.StatusForbidden)
	}

	// DefaultOriginExtractor is the default method which
	// an origin is extracted. It returns the value of the request's "Origin" header
	// and always true, means that it allows empty origin headers as well.
	DefaultOriginExtractor = func(ctx *context.Context) (string, bool) {
		header := ctx.GetHeader(originRequestHeader)
		return header, true
	}

	// StrictOriginExtractor is an ExtractOriginFunc type
	// which is a bit more strictly than the DefaultOriginExtractor.
	// It allows only non-empty "Origin" header values to be passed.
	// If the header is missing, the middleware will not allow the execution
	// of the next handler(s).
	StrictOriginExtractor = func(ctx *context.Context) (string, bool) {
		header := ctx.GetHeader(originRequestHeader)
		return header, header != ""
	}
)

type (
	// ExtractOriginFunc describes the function which should return the request's origin or false.
	ExtractOriginFunc = func(ctx *context.Context) (string, bool)

	// AllowOriginFunc describes the function which is called when the
	// middleware decides if the request's origin should be allowed or not.
	AllowOriginFunc = func(ctx *context.Context, origin string) bool

	// HandleErrorFunc describes the function which is fired
	// when a request by a specific (or empty) origin was not allowed to pass through.
	HandleErrorFunc = func(ctx *context.Context, err error)

	// CORS holds the customizations developers can
	// do on the cors middleware.
	//
	// Read more at: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS.
	CORS struct {
		extractOriginFunc ExtractOriginFunc
		allowOriginFunc   AllowOriginFunc
		errorHandler      HandleErrorFunc

		allowCredentialsValue string
		exposeHeadersValue    string
		allowHeadersValue     string
		allowMethodsValue     string
		maxAgeSecondsValue    string
		referrerPolicyValue   string
	}
)

// New returns the default CORS middleware.
// For a more advanced type of protection middleware with more options
// please refer to: https://github.com/iris-contrib/middleware repository instead.
//
// Example Code:
//
//		import "github.com/kataras/iris/v12/middleware/cors"
//	 import "github.com/kataras/iris/v12/x/errors"
//
//	 app.UseRouter(cors.New().
//	     HandleErrorFunc(func(ctx iris.Context, err error) {
//	         errors.FailedPrecondition.Err(ctx, err)
//	     }).
//	     ExtractOriginFunc(cors.StrictOriginExtractor).
//	     ReferrerPolicy(cors.NoReferrerWhenDowngrade).
//	     AllowOrigin("domain1.com,domain2.com,domain3.com").
//	     Handler())
func New() *CORS {
	return &CORS{
		extractOriginFunc: DefaultOriginExtractor,
		allowOriginFunc:   AllowAnyOrigin,
		errorHandler:      DefaultErrorHandler,

		allowCredentialsValue: "true",
		exposeHeadersValue:    "*, Authorization, X-Authorization",
		allowHeadersValue:     "*",
		// This field cannot be modified by the end-developer,
		// as we have another type of controlling the HTTP verbs per handler.
		allowMethodsValue:   "*",
		maxAgeSecondsValue:  "86400",
		referrerPolicyValue: NoReferrerWhenDowngrade.String(),
	}
}

// ExtractOriginFunc sets the function which should return the request's origin.
func (c *CORS) ExtractOriginFunc(fn ExtractOriginFunc) *CORS {
	c.extractOriginFunc = fn
	return c
}

// AllowOriginFunc sets the function which decides if an origin(domain) is allowed
// to continue or not.
//
// Read more at: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#access-control-allow-origin.
func (c *CORS) AllowOriginFunc(fn AllowOriginFunc) *CORS {
	c.allowOriginFunc = fn
	return c
}

// AllowOrigin calls the "AllowOriginFunc" method
// and registers a function which accepts any incoming
// request with origin of the given "originLine".
// The originLine can contain one or more domains separated by comma.
// See "AllowOrigins" to set a list of strings instead.
func (c *CORS) AllowOrigin(originLine string) *CORS {
	return c.AllowOrigins(strings.Split(originLine, ",")...)
}

// AllowOriginMatcherFunc sets the allow origin func without iris.Context
// as its first parameter, i.e. a regular expression.
func (c *CORS) AllowOriginMatcherFunc(fn func(origin string) bool) *CORS {
	return c.AllowOriginFunc(func(ctx *context.Context, origin string) bool {
		return fn(origin)
	})
}

// AllowOriginRegex calls the "AllowOriginFunc" method
// and registers a function which accepts any incoming
// request with origin that matches at least one of the given "regexpLines".
func (c *CORS) AllowOriginRegex(regexpLines ...string) *CORS {
	matchers := make([]func(string) bool, 0, len(regexpLines))
	for _, line := range regexpLines {
		matcher := regexp.MustCompile(line).MatchString
		matchers = append(matchers, matcher)
	}

	return c.AllowOriginFunc(func(ctx *context.Context, origin string) bool {
		for _, m := range matchers {
			if m(origin) {
				return true
			}
		}

		return false
	})
}

// AllowOrigins calls the "AllowOriginFunc" method
// and registers a function which accepts any incoming
// request with origin of one of the given "origins".
func (c *CORS) AllowOrigins(origins ...string) *CORS {
	allowOrigins := make(map[string]struct{}, len(origins)) // read-only at serve time.
	for _, origin := range origins {
		if origin == "*" {
			// If AllowOrigins called with asterix, it is a missuse of this
			// middleware (set AllowAnyOrigin instead).
			allowOrigins = nil
			return c.AllowOriginFunc(AllowAnyOrigin)
			// panic("wildcard is not allowed, use AllowOriginFunc(AllowAnyOrigin) instead")
			// No ^ let's register a function which allows all and continue.
		}

		origin = strings.TrimSpace(origin)
		allowOrigins[origin] = struct{}{}
	}

	return c.AllowOriginFunc(func(ctx *context.Context, origin string) bool {
		_, allow := allowOrigins[origin]
		return allow
	})
}

// HandleErrorFunc sets the function which is called
// when an error of origin not allowed is fired.
func (c *CORS) HandleErrorFunc(fn HandleErrorFunc) *CORS {
	c.errorHandler = fn
	return c
}

// DisallowCredentials sets the "Access-Control-Allow-Credentials" header to false.
//
// Read more at: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#access-control-allow-credentials.
func (c *CORS) DisallowCredentials() *CORS {
	c.allowCredentialsValue = "false"
	return c
}

// ExposeHeaders sets the "Access-Control-Expose-Headers" header value.
//
// Read more at: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#access-control-expose-headers.
func (c *CORS) ExposeHeaders(headers ...string) *CORS {
	c.exposeHeadersValue = strings.Join(headers, ", ")
	return c
}

// AllowHeaders sets the "Access-Control-Allow-Headers" header value.
//
// Read more at: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#access-control-allow-headers.
func (c *CORS) AllowHeaders(headers ...string) *CORS {
	c.allowHeadersValue = strings.Join(headers, ", ")
	return c
}

// ReferrerPolicy type for referrer-policy header value.
type ReferrerPolicy string

// All available referrer policies.
// Read more at: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Referrer-Policy.
const (
	NoReferrer                  ReferrerPolicy = "no-referrer"
	NoReferrerWhenDowngrade     ReferrerPolicy = "no-referrer-when-downgrade"
	Origin                      ReferrerPolicy = "origin"
	OriginWhenCrossOrigin       ReferrerPolicy = "origin-when-cross-origin"
	SameOrigin                  ReferrerPolicy = "same-origin"
	StrictOrigin                ReferrerPolicy = "strict-origin"
	StrictOriginWhenCrossOrigin ReferrerPolicy = "strict-origin-when-cross-origin"
	UnsafeURL                   ReferrerPolicy = "unsafe-url"
)

// String returns the text representation of the "r" ReferrerPolicy.
func (r ReferrerPolicy) String() string {
	return string(r)
}

// ReferrerPolicy sets the "Referrer-Policy" header value.
// Defaults to "no-referrer-when-downgrade".
//
// Read more at: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Referrer-Policy
// and https://developer.mozilla.org/en-US/docs/Web/Security/Referer_header:_privacy_and_security_concerns.
func (c *CORS) ReferrerPolicy(referrerPolicy ReferrerPolicy) *CORS {
	c.referrerPolicyValue = referrerPolicy.String()
	return c
}

// MaxAge sets the "Access-Control-Max-Age" header value.
//
// Read more at: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#access-control-max-age.
func (c *CORS) MaxAge(d time.Duration) *CORS {
	c.maxAgeSecondsValue = strconv.FormatFloat(d.Seconds(), 'E', -1, 64)
	return c
}

const (
	originRequestHeader    = "Origin"
	allowOriginHeader      = "Access-Control-Allow-Origin"
	allowCredentialsHeader = "Access-Control-Allow-Credentials"
	referrerPolicyHeader   = "Referrer-Policy"
	exposeHeadersHeader    = "Access-Control-Expose-Headers"
	requestMethodHeader    = "Access-Control-Request-Method"
	requestHeadersHeader   = "Access-Control-Request-Headers"
	allowMethodsHeader     = "Access-Control-Allow-Methods"
	allowAllMethodsValue   = "*"
	allowHeadersHeader     = "Access-Control-Allow-Headers"
	maxAgeHeader           = "Access-Control-Max-Age"
	varyHeader             = "Vary"
)

func (c *CORS) addVaryHeaders(ctx *context.Context) {
	ctx.Header(varyHeader, originRequestHeader)

	if ctx.Method() == http.MethodOptions {
		ctx.Header(varyHeader, requestMethodHeader)
		ctx.Header(varyHeader, requestHeadersHeader)
	}
}

// Handler method returns the Iris CORS Handler with basic features.
// Note that the caller should NOT modify any of the CORS instance fields afterwards.
func (c *CORS) Handler() context.Handler {
	return func(ctx *context.Context) {
		c.addVaryHeaders(ctx) // add vary headers at any case.

		origin, ok := c.extractOriginFunc(ctx)
		if !ok || !c.allowOriginFunc(ctx, origin) {
			c.errorHandler(ctx, ErrOriginNotAllowed)
			return
		}

		if origin == "" { // if we allow empty origins, set it to wildcard.
			origin = "*"
		}

		ctx.Header(allowOriginHeader, origin)
		ctx.Header(allowCredentialsHeader, c.allowCredentialsValue)
		// 08 July 2021 Mozzila updated the following document: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Referrer-Policy
		ctx.Header(referrerPolicyHeader, c.referrerPolicyValue)
		ctx.Header(exposeHeadersHeader, c.exposeHeadersValue)
		if ctx.Method() == http.MethodOptions {
			ctx.Header(allowMethodsHeader, allowAllMethodsValue)
			ctx.Header(allowHeadersHeader, c.allowHeadersValue)
			ctx.Header(maxAgeHeader, c.maxAgeSecondsValue)
			ctx.StatusCode(http.StatusNoContent)
			return
		}

		ctx.Next()
	}
}