From 5cfe19c79377431ee86b4bca3d2cdae4222bc39a Mon Sep 17 00:00:00 2001 From: Makis Maropoulos Date: Wed, 13 Jul 2016 05:02:43 +0300 Subject: [PATCH] Check for cors middleware conflicts on mux --- http.go | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/http.go b/http.go index 9f8b02f8..02e44f9c 100644 --- a/http.go +++ b/http.go @@ -1260,6 +1260,24 @@ func (r *route) SetMiddleware(m Middleware) { r.middleware = m } +// RouteConflicts checks for route's middleware conflicts +func RouteConflicts(r *route, with string) bool { + for _, h := range r.middleware { + if m, ok := h.(interface { + Conflicts() string + }); ok { + if c := m.Conflicts(); c == with { + return true + } + } + } + return false +} + +func (r *route) hasCors() bool { + return RouteConflicts(r, "httpmethod") +} + const ( // subdomainIndicator where './' exists in a registed path then it contains subdomain subdomainIndicator = "./" @@ -1443,6 +1461,20 @@ func (mux *serveMux) ServeRequest() fasthttp.RequestHandler { getRequestPath = func(reqCtx *fasthttp.RequestCtx) string { return utils.BytesToString(reqCtx.RequestURI()) } } + methodEqual := func(treeMethod []byte, reqMethod []byte) bool { + return bytes.Equal(treeMethod, reqMethod) + } + + // check for cors conflicts + for _, r := range mux.lookups { + if r.hasCors() { + methodEqual = func(treeMethod []byte, reqMethod []byte) bool { + return bytes.Equal(treeMethod, reqMethod) || bytes.Equal(reqMethod, methodOptionsBytes) + } + break + } + } + return func(reqCtx *fasthttp.RequestCtx) { context := mux.cPool.Get().(*Context) context.Reset(reqCtx) @@ -1450,7 +1482,7 @@ func (mux *serveMux) ServeRequest() fasthttp.RequestHandler { routePath := getRequestPath(reqCtx) tree := mux.tree for tree != nil { - if !bytes.Equal(tree.method, reqCtx.Method()) { + if !methodEqual(tree.method, reqCtx.Method()) { // we break any CORS OPTIONS method // but for performance reasons if user wants http method OPTIONS to be served // then must register it with .Options(...)