diff --git a/context/context.go b/context/context.go index 29907a07..600c3227 100644 --- a/context/context.go +++ b/context/context.go @@ -135,7 +135,10 @@ type Context struct { // to true on `Next` call when its called on the last handler in the chain. // Reports whether a `Next` is called, // even if the handler index remains the same (last handler). - proceeded uint32 + // + // Also it's responsible to keep the old value of the last known handler index + // before StopExecution. See ResumeExecution. + proceeded int } // NewContext returns a new Context instance. @@ -178,7 +181,7 @@ func (ctx *Context) Clone() *Context { writer: ctx.writer.Clone(), request: req, currentHandlerIndex: stopExecutionIndex, - proceeded: atomic.LoadUint32(&ctx.proceeded), + proceeded: ctx.proceeded, currentRoute: ctx.currentRoute, } } @@ -429,8 +432,12 @@ func (ctx *Context) ResponseWriter() ResponseWriter { // ResetResponseWriter sets a new ResponseWriter implementation // to this Context to use as its writer. // Note, to change the underline http.ResponseWriter use -// ctx.ResponseWriter().SetWriter(http.ResponseWRiter) instead. +// ctx.ResponseWriter().SetWriter(http.ResponseWriter) instead. func (ctx *Context) ResetResponseWriter(newResponseWriter ResponseWriter) { + if rec, ok := ctx.IsRecording(); ok { + releaseResponseRecorder(rec) + } + ctx.writer = newResponseWriter } @@ -560,24 +567,30 @@ func (ctx *Context) HandlerIndex(n int) (currentIndex int) { //} // Alternative way is `!ctx.IsStopped()` if middleware make use of the `ctx.StopExecution()` on failure. func (ctx *Context) Proceed(h Handler) bool { - beforeIdx := ctx.currentHandlerIndex - atomic.StoreUint32(&ctx.proceeded, 0) - h(ctx) + ctx.proceeded = internalPauseExecutionIndex - if ctx.currentHandlerIndex == stopExecutionIndex { + // Store the current index. + beforeIdx := ctx.currentHandlerIndex + h(ctx) + // Retrieve the next one, if Next is called this is beforeIdx + 1 and so on. + afterIdx := ctx.currentHandlerIndex + // Restore prev index, no matter what. + ctx.currentHandlerIndex = beforeIdx + + proceededByNext := ctx.proceeded == internalProceededHandlerIndex + ctx.proceeded = beforeIdx + + // Stop called, return false but keep the handlers index. + if afterIdx == stopExecutionIndex { return false } - if ctx.currentHandlerIndex <= beforeIdx { - // If "h" didn't call its Next - // or it doesn't have a next handler, - // that index will be the same, - // so we check if at least once the - // Next is called on the last handler. - return atomic.CompareAndSwapUint32(&ctx.proceeded, 1, 0) + if proceededByNext { + return true } - return true + // Next called or not. + return afterIdx > beforeIdx } // HandlerName returns the current handler's name, helpful for debugging. @@ -608,14 +621,15 @@ func (ctx *Context) Next() { return } - nextIndex := ctx.currentHandlerIndex + 1 - handlers := ctx.handlers + if ctx.proceeded <= internalPauseExecutionIndex /* pause and proceeded */ { + ctx.proceeded = internalProceededHandlerIndex + return + } - if n := len(handlers); nextIndex == n { - atomic.StoreUint32(&ctx.proceeded, 1) // last handler but Next is called. - } else if nextIndex < n { + nextIndex, n := ctx.currentHandlerIndex+1, len(ctx.handlers) + if nextIndex < n { ctx.currentHandlerIndex = nextIndex - handlers[nextIndex](ctx) + ctx.handlers[nextIndex](ctx) } } @@ -672,13 +686,27 @@ func (ctx *Context) Skip() { ctx.HandlerIndex(ctx.currentHandlerIndex + 1) } -const stopExecutionIndex = -1 // I don't set to a max value because we want to be able to reuse the handlers even if stopped with .Skip +const ( + stopExecutionIndex = -1 + internalPauseExecutionIndex = -2 + internalProceededHandlerIndex = -3 +) // StopExecution stops the handlers chain of this request. // Meaning that any following `Next` calls are ignored, // as a result the next handlers in the chain will not be fire. +// +// See ResumeExecution too. func (ctx *Context) StopExecution() { - ctx.currentHandlerIndex = stopExecutionIndex + if curIdx := ctx.currentHandlerIndex; curIdx != stopExecutionIndex { + // Protect against multiple calls of StopExecution. + // Resume should set the last proceeded handler index. + // Store the current index. + ctx.proceeded = curIdx + // And stop. + ctx.currentHandlerIndex = stopExecutionIndex + } + } // IsStopped reports whether the current position of the context's handlers is -1, @@ -687,6 +715,19 @@ func (ctx *Context) IsStopped() bool { return ctx.currentHandlerIndex == stopExecutionIndex } +// ResumeExecution sets the current handler index to the last +// index of the executed handler before StopExecution method was fired. +// +// Reports whether it's restored after a StopExecution call. +func (ctx *Context) ResumeExecution() bool { + if ctx.IsStopped() { + ctx.currentHandlerIndex = ctx.proceeded + return true + } + + return false +} + // StopWithStatus stops the handlers chain and writes the "statusCode". // // If the status code is a failure one then @@ -1377,7 +1418,7 @@ func (ctx *Context) StatusCode(statusCode int) { // to be executed. Next handlers are being executed on iris because you can alt the // error code and change it to a more specific one, i.e // users := app.Party("/users") -// users.Done(func(ctx iris.Context){ if ctx.StatusCode() == 400 { /* custom error code for /users */ }}) +// users.Done(func(ctx iris.Context){ if ctx.GetStatusCode() == 400 { /* custom error code for /users */ }}) func (ctx *Context) NotFound() { ctx.StatusCode(http.StatusNotFound) } @@ -5570,7 +5611,7 @@ func IsErrPanicRecovery(err error) (*ErrPanicRecovery, bool) { // IsRecovered reports whether this handler has been recovered // by the Iris recover middleware. func (ctx *Context) IsRecovered() (*ErrPanicRecovery, bool) { - if ctx.GetStatusCode() == 500 { + if ctx.GetStatusCode() == http.StatusInternalServerError { // Panic error from recovery middleware is private. return IsErrPanicRecovery(ctx.GetErr()) } diff --git a/core/router/api_container.go b/core/router/api_container.go index d566b0c5..4bf7fe36 100644 --- a/core/router/api_container.go +++ b/core/router/api_container.go @@ -105,10 +105,12 @@ func (api *APIContainer) convertHandlerFuncs(relativePath string, handlersFn ... handlers = append(handlers, api.Container.HandlerWithParams(h, paramsCount)) } + // Note: let end-developer to decide that through Party.SetExecutionRules. // On that type of handlers the end-developer does not have to include the Context in the handler, // so the ctx.Next is automatically called unless an `ErrStopExecution` returned (implementation inside hero pkg). - o := ExecutionOptions{Force: true} - o.apply(&handlers) + // + // o := ExecutionOptions{Force: true} + // o.apply(&handlers) return handlers } diff --git a/core/router/router_handlers_order_test.go b/core/router/router_handlers_order_test.go index 49f97200..0c9c96ff 100644 --- a/core/router/router_handlers_order_test.go +++ b/core/router/router_handlers_order_test.go @@ -362,3 +362,57 @@ func TestUseWrapOrder(t *testing.T) { e.GET("/NotFound").Expect().Status(iris.StatusNotFound).Body().Equal(expectedNotFoundBody) e.GET("/").Expect().Status(iris.StatusOK).Body().Equal(expectedBody) } + +func TestResumeExecution(t *testing.T) { + before := func(ctx iris.Context) { + ctx.WriteString("1") + + curIdx := ctx.HandlerIndex(-1) + + ctx.StopExecution() + ctx.Next() + ctx.StopExecution() + ctx.Next() + ctx.ResumeExecution() + + if ctx.HandlerIndex(-1) != curIdx { + ctx.WriteString("| 1. NOT OK") + } + + ctx.StopExecution() + ctx.ResumeExecution() + + if ctx.HandlerIndex(-1) != curIdx { + ctx.WriteString("| 2. NOT OK") + } + + ctx.Next() + + if ctx.HandlerIndex(-1) != curIdx+2 /* 2 and 3 */ { + ctx.WriteString("| 3. NOT OK") + } + } + + handler := func(ctx iris.Context) { + ctx.WriteString("2") + ctx.Next() + } + + after := func(ctx iris.Context) { + ctx.WriteString("3") + + if !ctx.Proceed(func(ctx iris.Context) { + ctx.Next() + }) { + ctx.WriteString(" | 4. NOT OK") + } + } + + expectedBody := "123" + + app := iris.New() + app.Get("/", before, handler, after) + + e := httptest.New(t, app) + e.GET("/").Expect().Status(iris.StatusOK).Body().Equal(expectedBody) +} diff --git a/middleware/basicauth/basicauth.go b/middleware/basicauth/basicauth.go index ff809880..a044c76a 100644 --- a/middleware/basicauth/basicauth.go +++ b/middleware/basicauth/basicauth.go @@ -443,6 +443,7 @@ func (b *BasicAuth) serveHTTP(ctx *context.Context) { user = &context.SimpleUser{ Authorization: authorizationType, AuthorizedAt: authorizedAt, + ID: username, Username: username, Password: password, }