diff --git a/core/router/path.go b/core/router/path.go index 9872fd42..93a20e1a 100644 --- a/core/router/path.go +++ b/core/router/path.go @@ -258,6 +258,19 @@ func splitSubdomainAndPath(fullUnparsedPath string) (subdomain string, path stri return // return subdomain without slash, path with slash } +func staticPath(src string) string { + bidx := strings.IndexByte(src, '{') + if bidx == -1 || len(src) <= bidx { + return src // no dynamic part found + } + if bidx <= 1 { // found at first{...} or second index (/{...}), + // although first index should never happen because of the prepended slash. + return "/" + } + + return src[:bidx-1] // (/static/{...} -> /static) +} + // RoutePathReverserOption option signature for the RoutePathReverser. type RoutePathReverserOption func(*RoutePathReverser) diff --git a/core/router/route.go b/core/router/route.go index cefbfa12..ec377cce 100644 --- a/core/router/route.go +++ b/core/router/route.go @@ -325,16 +325,7 @@ func (r *Route) IsStatic() bool { // if /assets/{filepath:path} it will return /assets. func (r *Route) StaticPath() string { src := r.tmpl.Src - bidx := strings.IndexByte(src, '{') - if bidx == -1 || len(src) <= bidx { - return src // no dynamic part found - } - if bidx <= 1 { // found at first{...} or second index (/{...}), - // although first index should never happen because of the prepended slash. - return "/" - } - - return src[:bidx-1] // (/static/{...} -> /static) + return staticPath(src) } // ResolvePath returns the formatted path's %v replaced with the args. diff --git a/core/router/router.go b/core/router/router.go index e6ddf8b4..60ad0201 100644 --- a/core/router/router.go +++ b/core/router/router.go @@ -21,7 +21,7 @@ type Router struct { // not indeed but we don't to risk its usage by third-parties. requestHandler RequestHandler // build-accessible, can be changed to define a custom router or proxy, used on RefreshRouter too. mainHandler http.HandlerFunc // init-accessible - wrapperFunc func(http.ResponseWriter, *http.Request, http.HandlerFunc) + wrapperFunc WrapperFunc cPool *context.Pool // used on RefreshRouter routesProvider RoutesProvider @@ -139,7 +139,7 @@ func (router *Router) BuildRouter(cPool *context.Pool, requestHandler RequestHan } if router.wrapperFunc != nil { // if wrapper used then attach that as the router service - router.mainHandler = NewWrapper(router.wrapperFunc, router.mainHandler).ServeHTTP + router.mainHandler = newWrapper(router.wrapperFunc, router.mainHandler).ServeHTTP } // build closest. @@ -180,12 +180,6 @@ func (router *Router) Downgraded() bool { return router.mainHandler != nil && router.requestHandler == nil } -// WrapperFunc is used as an expected input parameter signature -// for the WrapRouter. It's a "low-level" signature which is compatible -// with the net/http. -// It's being used to run or no run the router based on a custom logic. -type WrapperFunc func(w http.ResponseWriter, r *http.Request, firstNextIsTheRouter http.HandlerFunc) - // WrapRouter adds a wrapper on the top of the main router. // Usually it's useful for third-party middleware // when need to wrap the entire application with a middleware like CORS. @@ -196,28 +190,7 @@ type WrapperFunc func(w http.ResponseWriter, r *http.Request, firstNextIsTheRout // // Before build. func (router *Router) WrapRouter(wrapperFunc WrapperFunc) { - if wrapperFunc == nil { - return - } - - router.mu.Lock() - defer router.mu.Unlock() - - if router.wrapperFunc != nil { - // wrap into one function, from bottom to top, end to begin. - nextWrapper := wrapperFunc - prevWrapper := router.wrapperFunc - wrapperFunc = func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { - if next != nil { - nexthttpFunc := http.HandlerFunc(func(_w http.ResponseWriter, _r *http.Request) { - prevWrapper(_w, _r, next) - }) - nextWrapper(w, r, nexthttpFunc) - } - } - } - - router.wrapperFunc = wrapperFunc + router.wrapperFunc = makeWrapperFunc(router.wrapperFunc, wrapperFunc) } // ServeHTTPC serves the raw context, useful if we have already a context, it by-pass the wrapper. @@ -234,24 +207,3 @@ func (router *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (router *Router) RouteExists(ctx context.Context, method, path string) bool { return router.requestHandler.RouteExists(ctx, method, path) } - -type wrapper struct { - router http.HandlerFunc // http.HandlerFunc to catch the CURRENT state of its .ServeHTTP on case of future change. - wrapperFunc func(http.ResponseWriter, *http.Request, http.HandlerFunc) -} - -// NewWrapper returns a new http.Handler wrapped by the 'wrapperFunc' -// the "next" is the final "wrapped" input parameter. -// -// Application is responsible to make it to work on more than one wrappers -// via composition or func clojure. -func NewWrapper(wrapperFunc func(w http.ResponseWriter, r *http.Request, routerNext http.HandlerFunc), wrapped http.HandlerFunc) http.Handler { - return &wrapper{ - wrapperFunc: wrapperFunc, - router: wrapped, - } -} - -func (wr *wrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) { - wr.wrapperFunc(w, r, wr.router) -} diff --git a/core/router/router_wrapper.go b/core/router/router_wrapper.go new file mode 100644 index 00000000..65400d54 --- /dev/null +++ b/core/router/router_wrapper.go @@ -0,0 +1,52 @@ +package router + +import "net/http" + +// WrapperFunc is used as an expected input parameter signature +// for the WrapRouter. It's a "low-level" signature which is compatible +// with the net/http. +// It's being used to run or no run the router based on a custom logic. +type WrapperFunc func(w http.ResponseWriter, r *http.Request, router http.HandlerFunc) + +func makeWrapperFunc(original WrapperFunc, wrapperFunc WrapperFunc) WrapperFunc { + if wrapperFunc == nil { + return original + } + + if original != nil { + // wrap into one function, from bottom to top, end to begin. + nextWrapper := wrapperFunc + prevWrapper := original + wrapperFunc = func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + if next != nil { + nexthttpFunc := http.HandlerFunc(func(_w http.ResponseWriter, _r *http.Request) { + prevWrapper(_w, _r, next) + }) + nextWrapper(w, r, nexthttpFunc) + } + } + } + + return wrapperFunc +} + +type wrapper struct { + router http.HandlerFunc // http.HandlerFunc to catch the CURRENT state of its .ServeHTTP on case of future change. + wrapperFunc WrapperFunc +} + +// newWrapper returns a new http.Handler wrapped by the 'wrapperFunc' +// the "next" is the final "wrapped" input parameter. +// +// Application is responsible to make it to work on more than one wrappers +// via composition or func clojure. +func newWrapper(wrapperFunc WrapperFunc, wrapped http.HandlerFunc) http.Handler { + return &wrapper{ + wrapperFunc: wrapperFunc, + router: wrapped, + } +} + +func (wr *wrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) { + wr.wrapperFunc(w, r, wr.router) +} diff --git a/core/router/router_wrapper_test.go b/core/router/router_wrapper_test.go new file mode 100644 index 00000000..a6d6d2d3 --- /dev/null +++ b/core/router/router_wrapper_test.go @@ -0,0 +1,51 @@ +package router + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" +) + +func TestMakeWrapperFunc(t *testing.T) { + var ( + firstBody = []byte("1") + secondBody = []byte("2") + mainBody = []byte("3") + expectedBody = append(firstBody, append(secondBody, mainBody...)...) + ) + + pre := func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + w.Header().Set("X-Custom", "data") + next(w, r) + } + + first := func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + w.Write(firstBody) + next(w, r) + } + + second := func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + w.Write(secondBody) + next(w, r) + } + + mainHandler := func(w http.ResponseWriter, r *http.Request) { + w.Write(mainBody) + } + + wrapper := makeWrapperFunc(second, first) + wrapper = makeWrapperFunc(wrapper, pre) + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "https://iris-go.com", nil) + wrapper(w, r, mainHandler) + + if got := w.Body.Bytes(); !bytes.Equal(expectedBody, got) { + t.Fatalf("expected boy: %s but got: %s", string(expectedBody), string(got)) + } + + if expected, got := "data", w.Header().Get("X-Custom"); expected != got { + t.Fatalf("expected x-custom header: %s but got: %s", expected, got) + } +} diff --git a/iris.go b/iris.go index 33f1ac82..0ef15555 100644 --- a/iris.go +++ b/iris.go @@ -248,7 +248,7 @@ func (app *Application) WWW() router.Party { // Example: https://github.com/kataras/iris/tree/master/_examples/routing/subdomains/redirect func (app *Application) SubdomainRedirect(from, to router.Party) router.Party { sd := router.NewSubdomainRedirectWrapper(app.ConfigurationReadOnly().GetVHost, from.GetRelPath(), to.GetRelPath()) - app.WrapRouter(sd) + app.Router.WrapRouter(sd) return to } @@ -392,13 +392,13 @@ func (app *Application) RegisterView(viewEngine view.Engine) { func (app *Application) View(writer io.Writer, filename string, layout string, bindingData interface{}) error { if app.view.Len() == 0 { err := errors.New("view engine is missing, use `RegisterView`") - app.Logger().Error(err) + app.logger.Error(err) return err } err := app.view.ExecuteWriter(writer, filename, layout, bindingData) if err != nil { - app.Logger().Error(err) + app.logger.Error(err) } return err } @@ -1103,7 +1103,7 @@ func (app *Application) Run(serve Runner, withOrWithout ...Configurator) error { // this will block until an error(unless supervisor's DeferFlow called from a Task). err := serve(app) if err != nil { - app.Logger().Error(err) + app.logger.Error(err) } return err @@ -1179,7 +1179,7 @@ func (app *Application) tryInjectLiveReload() error { bodyCloseTag := []byte("") - app.WrapRouter(func(w http.ResponseWriter, r *http.Request, _ http.HandlerFunc) { + app.Router.WrapRouter(func(w http.ResponseWriter, r *http.Request, _ http.HandlerFunc) { ctx := app.ContextPool.Acquire(w, r) rec := ctx.Recorder() // Record everything and write all in once at the Context release. app.ServeHTTPC(ctx) // We directly call request handler with Context. @@ -1234,7 +1234,7 @@ func (app *Application) tryStartTunneling() { var publicAddr string err := tc.startTunnel(t, &publicAddr) if err != nil { - app.Logger().Errorf("Host: tunneling error: %v", err) + app.logger.Errorf("Host: tunneling error: %v", err) return } @@ -1242,7 +1242,7 @@ func (app *Application) tryStartTunneling() { app.config.vhost = publicAddr[strings.Index(publicAddr, "://")+3:] directLog := []byte(fmt.Sprintf("• Public Address: %s\n", publicAddr)) - app.Logger().Printer.Write(directLog) + app.logger.Printer.Write(directLog) } }) })