From 4eb7705fae98fd664e838b80a6a622fe563eb104 Mon Sep 17 00:00:00 2001 From: "Gerasimos (Makis) Maropoulos" Date: Sat, 20 Jan 2024 20:32:56 +0200 Subject: [PATCH] minor improvements --- .../routing/http-wire-errors/service/main.go | 11 +- _examples/routing/macros/main.go | 2 +- _proposals/route_builder.md | 178 ++++++++++++++++++ _proposals/xerrors_party.md | 95 ++++++++++ context/context.go | 23 ++- context/handler.go | 14 ++ core/router/api_builder.go | 9 +- iris_guide.go | 6 +- macro/macro.go | 11 +- macro/macros.go | 45 ++--- middleware/accesslog/accesslog.go | 2 +- middleware/recover/recover.go | 120 ++++++------ x/errors/handlers.go | 44 +++-- 13 files changed, 444 insertions(+), 116 deletions(-) create mode 100644 _proposals/route_builder.md create mode 100644 _proposals/xerrors_party.md diff --git a/_examples/routing/http-wire-errors/service/main.go b/_examples/routing/http-wire-errors/service/main.go index f97fcd01..e46f2a9d 100644 --- a/_examples/routing/http-wire-errors/service/main.go +++ b/_examples/routing/http-wire-errors/service/main.go @@ -105,6 +105,7 @@ type ( // // r.Post("/", errors.Validation(validateCreateRequest), createHandler(service)) // [more code here...] +// // func validateCreateRequest(ctx iris.Context, r *CreateRequest) error { // return validation.Join( // validation.String("fullname", r.Fullname).NotEmpty().Fullname().Length(3, 50), @@ -162,7 +163,7 @@ func (r *CreateRequest) HandleResponse(ctx iris.Context, resp *CreateResponse) e } */ -func afterServiceCallButBeforeDataSent(ctx iris.Context, req *CreateRequest, resp *CreateResponse) error { +func afterServiceCallButBeforeDataSent(ctx iris.Context, req CreateRequest, resp *CreateResponse) error { fmt.Printf("intercept: request got: %+v\nresponse sent: %#+v\n", req, resp) return nil } @@ -231,10 +232,18 @@ func (s *myService) ListPaginated(ctx context.Context, opts pagination.ListOptio return filteredResp, len(all), nil // errors.New("list paginated: test error") } +func (s *myService) GetByID(ctx context.Context, id string) (CreateResponse, error) { + return CreateResponse{Firstname: "Gerasimos"}, nil // errors.New("get by id: test error") +} + func (s *myService) Delete(ctx context.Context, id string) error { return nil // errors.New("delete: test error") } +func (s *myService) Update(ctx context.Context, req CreateRequest) (bool, error) { + return true, nil // false, errors.New("update: test error") +} + func (s *myService) DeleteWithFeedback(ctx context.Context, id string) (bool, error) { return true, nil // false, errors.New("delete: test error") } diff --git a/_examples/routing/macros/main.go b/_examples/routing/macros/main.go index 2a1aafca..9c4c9be1 100644 --- a/_examples/routing/macros/main.go +++ b/_examples/routing/macros/main.go @@ -17,7 +17,7 @@ func main() { app := iris.New() app.Logger().SetLevel("debug") - app.Macros().Register("slice", "", false, true, func(paramValue string) (interface{}, bool) { + app.Macros().Register("slice", "", []string{}, false, true, func(paramValue string) (interface{}, bool) { return strings.Split(paramValue, "/"), true }).RegisterFunc("contains", func(expectedItems []string) func(paramValue []string) bool { sort.Strings(expectedItems) diff --git a/_proposals/route_builder.md b/_proposals/route_builder.md new file mode 100644 index 00000000..38b4a47e --- /dev/null +++ b/_proposals/route_builder.md @@ -0,0 +1,178 @@ +```go +package main + +import ( + "fmt" + "strings" + + "github.com/kataras/iris/v12/macro" +) + +func main() { + path := NewRouteBuilder(). + Path("/user"). + String("name", "prefix(ma)", "suffix(kis)"). + Int("age"). + Path("/friends"). + Wildcard("rest"). + Build() + + fmt.Println(path) +} + +type RouteBuilder struct { + path string +} + +func NewRouteBuilder() *RouteBuilder { + return &RouteBuilder{ + path: "/", + } +} + +func (r *RouteBuilder) Path(path string) *RouteBuilder { + if path[0] != '/' { + path = "/" + path + } + + r.path = strings.TrimSuffix(r.path, "/") + path + return r +} + +type StaticPathBuilder interface { + Path(path string) *RouteBuilder +} + +func (r *RouteBuilder) Param(param ParamBuilder) *RouteBuilder { // StaticPathBuilder { + path := "" // keep it here, a single call to r.Path must be done. + if len(r.path) == 0 || r.path[len(r.path)-1] != '/' { + path += "/" // if for some reason no prior Path("/") was called for delimeter between path parameter. + } + + path += fmt.Sprintf("{%s:%s", param.GetName(), param.GetParamType().Indent()) + if funcs := param.GetFuncs(); len(funcs) > 0 { + path += fmt.Sprintf(" %s", strings.Join(funcs, " ")) + } + path += "}" + + return r.Path(path) +} + +func (r *RouteBuilder) String(name string, funcs ...string) *RouteBuilder { + return r.Param(Param(macro.String, name, funcs...)) +} + +func (r *RouteBuilder) Int(name string, funcs ...string) *RouteBuilder { + return r.Param(Param(macro.Int, name, funcs...)) +} + +func (r *RouteBuilder) Int8(name string, funcs ...string) *RouteBuilder { + return r.Param(Param(macro.Int8, name, funcs...)) +} + +func (r *RouteBuilder) Int16(name string, funcs ...string) *RouteBuilder { + return r.Param(Param(macro.Int16, name, funcs...)) +} + +func (r *RouteBuilder) Int32(name string, funcs ...string) *RouteBuilder { + return r.Param(Param(macro.Int32, name, funcs...)) +} + +func (r *RouteBuilder) Int64(name string, funcs ...string) *RouteBuilder { + return r.Param(Param(macro.Int64, name, funcs...)) +} + +func (r *RouteBuilder) Uint(name string, funcs ...string) *RouteBuilder { + return r.Param(Param(macro.Uint, name, funcs...)) +} + +func (r *RouteBuilder) Uint8(name string, funcs ...string) *RouteBuilder { + return r.Param(Param(macro.Uint8, name, funcs...)) +} + +func (r *RouteBuilder) Uint16(name string, funcs ...string) *RouteBuilder { + return r.Param(Param(macro.Uint16, name, funcs...)) +} + +func (r *RouteBuilder) Uint32(name string, funcs ...string) *RouteBuilder { + return r.Param(Param(macro.Uint32, name, funcs...)) +} + +func (r *RouteBuilder) Uint64(name string, funcs ...string) *RouteBuilder { + return r.Param(Param(macro.Uint64, name, funcs...)) +} + +func (r *RouteBuilder) Bool(name string, funcs ...string) *RouteBuilder { + return r.Param(Param(macro.Bool, name, funcs...)) +} + +func (r *RouteBuilder) Alphabetical(name string, funcs ...string) *RouteBuilder { + return r.Param(Param(macro.Alphabetical, name, funcs...)) +} + +func (r *RouteBuilder) File(name string, funcs ...string) *RouteBuilder { + return r.Param(Param(macro.File, name, funcs...)) +} + +func (r *RouteBuilder) Wildcard(name string, funcs ...string) *RouteBuilder { + return r.Param(Param(macro.Path, name, funcs...)) +} + +func (r *RouteBuilder) UUID(name string, funcs ...string) *RouteBuilder { + return r.Param(Param(macro.UUID, name, funcs...)) +} + +func (r *RouteBuilder) Mail(name string, funcs ...string) *RouteBuilder { + return r.Param(Param(macro.Mail, name, funcs...)) +} + +func (r *RouteBuilder) Email(name string, funcs ...string) *RouteBuilder { + return r.Param(Param(macro.Email, name, funcs...)) +} + +func (r *RouteBuilder) Date(name string, funcs ...string) *RouteBuilder { + return r.Param(Param(macro.Date, name, funcs...)) +} + +func (r *RouteBuilder) Weekday(name string, funcs ...string) *RouteBuilder { + return r.Param(Param(macro.Weekday, name, funcs...)) +} + +func (r *RouteBuilder) Build() string { + return r.path +} + +type ParamBuilder interface { + GetName() string + GetFuncs() []string + GetParamType() *macro.Macro +} + +type pathParam struct { + Name string + Funcs []string + ParamType *macro.Macro +} + +var _ ParamBuilder = (*pathParam)(nil) + +func Param(paramType *macro.Macro, name string, funcs ...string) ParamBuilder { + return &pathParam{ + Name: name, + ParamType: paramType, + Funcs: funcs, + } +} + +func (p *pathParam) GetName() string { + return p.Name +} + +func (p *pathParam) GetParamType() *macro.Macro { + return p.ParamType +} + +func (p *pathParam) GetFuncs() []string { + return p.Funcs +} +``` diff --git a/_proposals/xerrors_party.md b/_proposals/xerrors_party.md new file mode 100644 index 00000000..519b1f2c --- /dev/null +++ b/_proposals/xerrors_party.md @@ -0,0 +1,95 @@ +```go +app.PartyConfigure("/api", errors.NewParty[CreateRequest, CreateResponse, ListFilter](). + Create(service.Create). + Update(service.Update). + Delete(service.DeleteWithFeedback). + List(service.ListPaginated). + Get(service.GetByID).Validation(validateCreateRequest)) +``` + +```go +type Party[T, R, F any] struct { + validators []ContextRequestFunc[T] + filterValidators []ContextRequestFunc[F] + filterIntercepters []ContextResponseFunc[F, R] + intercepters []ContextResponseFunc[T, R] + + serviceCreateFunc func(stdContext.Context, T) (R, error) + serviceUpdateFunc func(stdContext.Context, T) (bool, error) + serviceDeleteFunc func(stdContext.Context, string) (bool, error) + serviceListFunc func(stdContext.Context, pagination.ListOptions, F /* filter options */) ([]R, int, error) + serviceGetFunc func(stdContext.Context, string) (R, error) +} + +func (p *Party[T, R, F]) Configure(r router.Party) { + if p.serviceCreateFunc != nil { + r.Post("/", Validation(p.validators...), Intercept(p.intercepters...), CreateHandler(p.serviceCreateFunc)) + } + + if p.serviceUpdateFunc != nil { + r.Put("/{id:string}", Validation(p.validators...), Intercept(p.intercepters...), NoContentOrNotModifiedHandler(p.serviceUpdateFunc)) + } + + if p.serviceListFunc != nil { + r.Post("/list", Validation(p.filterValidators...), Intercept(p.filterIntercepters...), ListHandler(p.serviceListFunc)) + } + + if p.serviceDeleteFunc != nil { + r.Delete("/{id:string}", NoContentOrNotModifiedHandler(p.serviceDeleteFunc, PathParam[string]("id"))) + } + + if p.serviceGetFunc != nil { + r.Get("/{id:string}", Handler(p.serviceGetFunc, PathParam[string]("id"))) + } +} + +func NewParty[T, R, F any]() *Party[T, R, F] { + return &Party[T, R, F]{} +} + +func (p *Party[T, R, F]) Validation(validators ...ContextRequestFunc[T]) *Party[T, R, F] { + p.validators = append(p.validators, validators...) + return p +} + +func (p *Party[T, R, F]) FilterValidation(filterValidators ...ContextRequestFunc[F]) *Party[T, R, F] { + p.filterValidators = append(p.filterValidators, filterValidators...) + return p +} + +func (p *Party[T, R, F]) Intercept(intercepters ...ContextResponseFunc[T, R]) *Party[T, R, F] { + p.intercepters = append(p.intercepters, intercepters...) + return p +} + +func (p *Party[T, R, F]) FilterIntercept(filterIntercepters ...ContextResponseFunc[F, R]) *Party[T, R, F] { + p.filterIntercepters = append(p.filterIntercepters, filterIntercepters...) + return p +} + +func (p *Party[T, R, F]) Create(fn func(stdContext.Context, T) (R, error)) *Party[T, R, F] { + p.serviceCreateFunc = fn + return p +} + +func (p *Party[T, R, F]) Update(fn func(stdContext.Context, T) (bool, error)) *Party[T, R, F] { + p.serviceUpdateFunc = fn + return p +} + +func (p *Party[T, R, F]) Delete(fn func(stdContext.Context, string) (bool, error)) *Party[T, R, F] { + p.serviceDeleteFunc = fn + return p +} + +func (p *Party[T, R, F]) List(fn func(stdContext.Context, pagination.ListOptions, F /* filter options */) ([]R, int, error)) *Party[T, R, F] { + p.serviceListFunc = fn + return p +} + +func (p *Party[T, R, F]) Get(fn func(stdContext.Context, string) (R, error)) *Party[T, R, F] { + p.serviceGetFunc = fn + return p +} + +``` diff --git a/context/context.go b/context/context.go index 0e9546dc..09054d93 100644 --- a/context/context.go +++ b/context/context.go @@ -6333,11 +6333,13 @@ func (ctx *Context) GetErrPublic() (bool, error) { // which recovers from a manual panic. type ErrPanicRecovery struct { ErrPrivate - Cause interface{} - Callers []string // file:line callers. - Stack []byte // the full debug stack. - RegisteredHandlers []string // file:line of all registered handlers. - CurrentHandler string // the handler panic came from. + Cause interface{} + Callers []string // file:line callers. + Stack []byte // the full debug stack. + RegisteredHandlers []string // file:line of all registered handlers. + CurrentHandlerFileLine string // the handler panic came from. + CurrentHandlerName string // the handler name panic came from. + Request string // the http dumped request. } // Error implements the Go standard error type. @@ -6348,7 +6350,7 @@ func (e *ErrPanicRecovery) Error() string { } } - return fmt.Sprintf("%v\n%s", e.Cause, strings.Join(e.Callers, "\n")) + return fmt.Sprintf("%v\n%s\nRequest:\n%s", e.Cause, strings.Join(e.Callers, "\n"), e.Request) } // Is completes the internal errors.Is interface. @@ -6357,6 +6359,15 @@ func (e *ErrPanicRecovery) Is(err error) bool { return ok } +func (e *ErrPanicRecovery) LogMessage() string { + logMessage := fmt.Sprintf("Recovered from a route's Handler('%s')\n", e.CurrentHandlerName) + logMessage += fmt.Sprint(e.Request) + logMessage += fmt.Sprintf("%s\n", e.Cause) + logMessage += fmt.Sprintf("%s\n", strings.Join(e.Callers, "\n")) + + return logMessage +} + // IsErrPanicRecovery reports whether the given "err" is a type of ErrPanicRecovery. func IsErrPanicRecovery(err error) (*ErrPanicRecovery, bool) { if err == nil { diff --git a/context/handler.go b/context/handler.go index 5939b3ce..9e13df90 100644 --- a/context/handler.go +++ b/context/handler.go @@ -111,6 +111,20 @@ type Handler = func(*Context) // See `Handler` for more. type Handlers = []Handler +// CopyHandlers returns a copy of "handlers" Handlers slice. +func CopyHandlers(handlers []Handler) Handlers { + handlersCp := make([]Handler, 0, len(handlers)) + for _, handler := range handlers { + if handler == nil { + continue + } + + handlersCp = append(handlersCp, handler) + } + + return handlersCp +} + func valueOf(v interface{}) reflect.Value { if val, ok := v.(reflect.Value); ok { return val diff --git a/core/router/api_builder.go b/core/router/api_builder.go index 2bbe4174..013b39d6 100644 --- a/core/router/api_builder.go +++ b/core/router/api_builder.go @@ -708,9 +708,11 @@ func (api *APIBuilder) createRoutes(errorCode int, methods []string, relativePat errorCode = 0 } + mainHandlers := context.CopyHandlers(handlers) + if errorCode == 0 { if len(methods) == 0 || methods[0] == "ALL" || methods[0] == "ANY" { // then use like it was .Any - return api.Any(relativePath, handlers...) + return api.Any(relativePath, mainHandlers...) } } @@ -727,7 +729,7 @@ func (api *APIBuilder) createRoutes(errorCode int, methods []string, relativePat filename, line := hero.GetCaller() fullpath := api.relativePath + relativePath // for now, keep the last "/" if any, "/xyz/" - if len(handlers) == 0 { + if len(mainHandlers) == 0 { api.logger.Errorf("missing handlers for route[%s:%d] %s: %s", filename, line, strings.Join(methods, ", "), fullpath) return nil } @@ -751,12 +753,11 @@ func (api *APIBuilder) createRoutes(errorCode int, methods []string, relativePat beginHandlers = context.JoinHandlers(beginHandlers, api.middlewareErrorCode) } - mainHandlers := context.Handlers(handlers) // before join the middleware + handlers + done handlers and apply the execution rules. mainHandlerName, mainHandlerIndex := context.MainHandlerName(mainHandlers) - mainHandlerFileName, mainHandlerFileNumber := context.HandlerFileLineRel(handlers[mainHandlerIndex]) + mainHandlerFileName, mainHandlerFileNumber := context.HandlerFileLineRel(mainHandlers[mainHandlerIndex]) // TODO: think of it. if mainHandlerFileName == "" { diff --git a/iris_guide.go b/iris_guide.go index 5e791f7c..7d39f346 100644 --- a/iris_guide.go +++ b/iris_guide.go @@ -1,6 +1,7 @@ package iris import ( + "strings" "time" "github.com/kataras/iris/v12/core/router" @@ -518,8 +519,9 @@ func (s *step7) Build() *Application { ctx.Next() }) - if allowOrigin := s.step6.step5.step4.step3.step2.step1.originLine; allowOrigin != "" && allowOrigin != "none" { - app.UseRouter(cors.New().AllowOrigin(allowOrigin).Handler()) + if allowOrigin := s.step6.step5.step4.step3.step2.step1.originLine; strings.TrimSpace(allowOrigin) != "" && allowOrigin != "none" { + corsMiddleware := cors.New().HandleErrorFunc(errors.FailedPrecondition.Err).AllowOrigin(allowOrigin).Handler() + app.UseRouter(corsMiddleware) } if s.step6.step5.step4.step3.step2.enableCompression { diff --git a/macro/macro.go b/macro/macro.go index 59e22390..8558e4d2 100644 --- a/macro/macro.go +++ b/macro/macro.go @@ -256,6 +256,8 @@ type ( Evaluator ParamEvaluator handleError interface{} funcs []ParamFunc + + goType reflect.Type } // ParamFuncBuilder is a func @@ -278,12 +280,13 @@ type ( // NewMacro creates and returns a Macro that can be used as a registry for // a new customized parameter type and its functions. -func NewMacro(indent, alias string, master, trailing bool, evaluator ParamEvaluator) *Macro { +func NewMacro(indent, alias string, valueType any, master, trailing bool, evaluator ParamEvaluator) *Macro { return &Macro{ indent: indent, alias: alias, master: master, trailing: trailing, + goType: reflect.TypeOf(valueType), Evaluator: evaluator, } @@ -312,6 +315,12 @@ func (m *Macro) Trailing() bool { return m.trailing } +// GoType returns the type of the parameter type's evaluator. +// string if it's a string evaluator, int if it's an int evaluator etc. +func (m *Macro) GoType() reflect.Type { + return m.goType +} + // HandleError registers a handler which will be executed // when a parameter evaluator returns false and a non nil value which is a type of `error`. // The "fnHandler" value MUST BE a type of `func(iris.Context, paramIndex int, err error)`, diff --git a/macro/macros.go b/macro/macros.go index f3adfc8f..438cfa7e 100644 --- a/macro/macros.go +++ b/macro/macros.go @@ -19,7 +19,7 @@ var ( // Allows anything (single path segment, as everything except the `Path`). // Its functions can be used by the rest of the macros and param types whenever not available function by name is used. // Because of its "master" boolean value to true (third parameter). - String = NewMacro("string", "", true, false, nil). + String = NewMacro("string", "", "", true, false, nil). RegisterFunc("regexp", MustRegexp). // checks if param value starts with the 'prefix' arg RegisterFunc("prefix", func(prefix string) func(string) bool { @@ -81,7 +81,7 @@ var ( // both positive and negative numbers, actual value can be min-max int64 or min-max int32 depends on the arch. // If x64: -9223372036854775808 to 9223372036854775807. // If x32: -2147483648 to 2147483647 and etc.. - Int = NewMacro("int", "number", false, false, func(paramValue string) (interface{}, bool) { + Int = NewMacro("int", "number", 0, false, false, func(paramValue string) (interface{}, bool) { v, err := strconv.Atoi(paramValue) if err != nil { return err, false @@ -113,7 +113,7 @@ var ( // Int8 type // -128 to 127. - Int8 = NewMacro("int8", "", false, false, func(paramValue string) (interface{}, bool) { + Int8 = NewMacro("int8", "", int8(0), false, false, func(paramValue string) (interface{}, bool) { v, err := strconv.ParseInt(paramValue, 10, 8) if err != nil { return err, false @@ -138,7 +138,7 @@ var ( // Int16 type // -32768 to 32767. - Int16 = NewMacro("int16", "", false, false, func(paramValue string) (interface{}, bool) { + Int16 = NewMacro("int16", "", int16(0), false, false, func(paramValue string) (interface{}, bool) { v, err := strconv.ParseInt(paramValue, 10, 16) if err != nil { return err, false @@ -163,7 +163,7 @@ var ( // Int32 type // -2147483648 to 2147483647. - Int32 = NewMacro("int32", "", false, false, func(paramValue string) (interface{}, bool) { + Int32 = NewMacro("int32", "", int32(0), false, false, func(paramValue string) (interface{}, bool) { v, err := strconv.ParseInt(paramValue, 10, 32) if err != nil { return err, false @@ -188,7 +188,7 @@ var ( // Int64 as int64 type // -9223372036854775808 to 9223372036854775807. - Int64 = NewMacro("int64", "long", false, false, func(paramValue string) (interface{}, bool) { + Int64 = NewMacro("int64", "long", int64(0), false, false, func(paramValue string) (interface{}, bool) { v, err := strconv.ParseInt(paramValue, 10, 64) if err != nil { // if err == strconv.ErrRange... return err, false @@ -221,7 +221,7 @@ var ( // actual value can be min-max uint64 or min-max uint32 depends on the arch. // If x64: 0 to 18446744073709551615. // If x32: 0 to 4294967295 and etc. - Uint = NewMacro("uint", "", false, false, func(paramValue string) (interface{}, bool) { + Uint = NewMacro("uint", "", uint(0), false, false, func(paramValue string) (interface{}, bool) { v, err := strconv.ParseUint(paramValue, 10, strconv.IntSize) // 32,64... if err != nil { return err, false @@ -252,7 +252,7 @@ var ( // Uint8 as uint8 type // 0 to 255. - Uint8 = NewMacro("uint8", "", false, false, func(paramValue string) (interface{}, bool) { + Uint8 = NewMacro("uint8", "", uint8(0), false, false, func(paramValue string) (interface{}, bool) { v, err := strconv.ParseUint(paramValue, 10, 8) if err != nil { return err, false @@ -283,7 +283,7 @@ var ( // Uint16 as uint16 type // 0 to 65535. - Uint16 = NewMacro("uint16", "", false, false, func(paramValue string) (interface{}, bool) { + Uint16 = NewMacro("uint16", "", uint16(0), false, false, func(paramValue string) (interface{}, bool) { v, err := strconv.ParseUint(paramValue, 10, 16) if err != nil { return err, false @@ -308,7 +308,7 @@ var ( // Uint32 as uint32 type // 0 to 4294967295. - Uint32 = NewMacro("uint32", "", false, false, func(paramValue string) (interface{}, bool) { + Uint32 = NewMacro("uint32", "", uint32(0), false, false, func(paramValue string) (interface{}, bool) { v, err := strconv.ParseUint(paramValue, 10, 32) if err != nil { return err, false @@ -333,7 +333,7 @@ var ( // Uint64 as uint64 type // 0 to 18446744073709551615. - Uint64 = NewMacro("uint64", "", false, false, func(paramValue string) (interface{}, bool) { + Uint64 = NewMacro("uint64", "", uint64(0), false, false, func(paramValue string) (interface{}, bool) { v, err := strconv.ParseUint(paramValue, 10, 64) if err != nil { return err, false @@ -365,7 +365,7 @@ var ( // Bool or boolean as bool type // a string which is "1" or "t" or "T" or "TRUE" or "true" or "True" // or "0" or "f" or "F" or "FALSE" or "false" or "False". - Bool = NewMacro("bool", "boolean", false, false, func(paramValue string) (interface{}, bool) { + Bool = NewMacro("bool", "boolean", false, false, false, func(paramValue string) (interface{}, bool) { // a simple if statement is faster than regex ^(true|false|True|False|t|0|f|FALSE|TRUE)$ // in this case. v, err := strconv.ParseBool(paramValue) @@ -380,7 +380,7 @@ var ( alphabeticalEval = MustRegexp("^[a-zA-Z ]+$") // Alphabetical letter type // letters only (upper or lowercase) - Alphabetical = NewMacro("alphabetical", "", false, false, func(paramValue string) (interface{}, bool) { + Alphabetical = NewMacro("alphabetical", "", "", false, false, func(paramValue string) (interface{}, bool) { if !alphabeticalEval(paramValue) { return fmt.Errorf("%s: %w", paramValue, ErrParamNotAlphabetical), false } @@ -397,7 +397,7 @@ var ( // dash (-) // point (.) // no spaces! or other character - File = NewMacro("file", "", false, false, func(paramValue string) (interface{}, bool) { + File = NewMacro("file", "", "", false, false, func(paramValue string) (interface{}, bool) { if !fileEval(paramValue) { return fmt.Errorf("%s: %w", paramValue, ErrParamNotFile), false } @@ -410,11 +410,11 @@ var ( // types because I want to give the opportunity to the user // to organise the macro functions based on wildcard or single dynamic named path parameter. // Should be living in the latest path segment of a route path. - Path = NewMacro("path", "", false, true, nil) + Path = NewMacro("path", "", "", false, true, nil) // UUID string type for validating a uuidv4 (and v1) path parameter. // Read more at: https://tools.ietf.org/html/rfc4122. - UUID = NewMacro("uuid", "uuidv4", false, false, func(paramValue string) (interface{}, bool) { + UUID = NewMacro("uuid", "uuidv4", "", false, false, func(paramValue string) (interface{}, bool) { _, err := uuid.Parse(paramValue) // this is x10+ times faster than regexp. if err != nil { return err, false @@ -425,7 +425,7 @@ var ( // Email string type for validating an e-mail path parameter. It returns the address as string, instead of an *mail.Address. // Read more at go std mail.ParseAddress method. See the ':email' path parameter for a more strictly version of validation. - Mail = NewMacro("mail", "", false, false, func(paramValue string) (interface{}, bool) { + Mail = NewMacro("mail", "", "", false, false, func(paramValue string) (interface{}, bool) { _, err := mail.ParseAddress(paramValue) if err != nil { return fmt.Errorf("%s: %w", paramValue, err), false @@ -437,7 +437,7 @@ var ( // Email string type for validating an e-mail path parameter. It returns the address as string, instead of an *mail.Address. // It is a combined validation using mail.ParseAddress and net.LookupMX so only valid domains can be passed. // It's a more strictly version of the ':mail' path parameter. - Email = NewMacro("email", "", false, false, func(paramValue string) (interface{}, bool) { + Email = NewMacro("email", "", "", false, false, func(paramValue string) (interface{}, bool) { _, err := mail.ParseAddress(paramValue) if err != nil { return fmt.Errorf("%s: %w", paramValue, err), false @@ -460,7 +460,7 @@ var ( simpleDateLayout = "2006/01/02" // Date type. - Date = NewMacro("date", "", false, true, func(paramValue string) (interface{}, bool) { + Date = NewMacro("date", "", time.Time{}, false, true, func(paramValue string) (interface{}, bool) { tt, err := time.Parse(simpleDateLayout, paramValue) if err != nil { return fmt.Errorf("%s: %w", paramValue, err), false @@ -492,7 +492,7 @@ var ( // Weekday type, returns a type of time.Weekday. // Valid values: // 0 to 7 (leading zeros don't matter) or "Sunday" to "Monday" or "sunday" to "monday". - Weekday = NewMacro("weekday", "", false, false, func(paramValue string) (interface{}, bool) { + Weekday = NewMacro("weekday", "", time.Weekday(0), false, false, func(paramValue string) (interface{}, bool) { d, ok := longDayNames[paramValue] if !ok { // try parse from integer. @@ -545,13 +545,14 @@ type Macros []*Macro // Register registers a custom Macro. // The "indent" should not be empty and should be unique, it is the parameter type's name, i.e "string". // The "alias" is optionally and it should be unique, it is the alias of the parameter type. +// The "valueType" should be the zero value of the parameter type, i.e "" for string, 0 for int and etc. // "isMaster" and "isTrailing" is for default parameter type and wildcard respectfully. // The "evaluator" is the function that is converted to an Iris handler which is executed every time // before the main chain of a route's handlers that contains this macro of the specific parameter type. // // Read https://github.com/kataras/iris/tree/main/_examples/routing/macros for more details. -func (ms *Macros) Register(indent, alias string, isMaster, isTrailing bool, evaluator ParamEvaluator) *Macro { - macro := NewMacro(indent, alias, isMaster, isTrailing, evaluator) +func (ms *Macros) Register(indent, alias string, valueType any, isMaster, isTrailing bool, evaluator ParamEvaluator) *Macro { + macro := NewMacro(indent, alias, valueType, isMaster, isTrailing, evaluator) if ms.register(macro) { return macro } diff --git a/middleware/accesslog/accesslog.go b/middleware/accesslog/accesslog.go index a9f4d277..35d83f04 100644 --- a/middleware/accesslog/accesslog.go +++ b/middleware/accesslog/accesslog.go @@ -1056,7 +1056,7 @@ func (ac *AccessLog) getErrorText(err error) (text string) { // caller checks fo switch ac.PanicLog { case LogHandler: - text = errPanic.CurrentHandler + text = errPanic.CurrentHandlerFileLine case LogCallers: text = strings.Join(errPanic.Callers, "\n") case LogStack: diff --git a/middleware/recover/recover.go b/middleware/recover/recover.go index 949ad076..f3d06864 100644 --- a/middleware/recover/recover.go +++ b/middleware/recover/recover.go @@ -6,7 +6,6 @@ import ( "net/http/httputil" "runtime" "runtime/debug" - "strings" "github.com/kataras/iris/v12/context" ) @@ -15,65 +14,76 @@ func init() { context.SetHandlerName("iris/middleware/recover.*", "iris.recover") } -func getRequestLogs(ctx *context.Context) string { - rawReq, _ := httputil.DumpRequest(ctx.Request(), false) - return string(rawReq) -} - -// New returns a new recover middleware, -// it recovers from panics and logs -// the panic message to the application's logger "Warn" level. +// New returns a new recovery middleware, +// it recovers from panics and logs the +// panic message to the application's logger "Warn" level. func New() context.Handler { return func(ctx *context.Context) { defer func() { - if err := recover(); err != nil { - if ctx.IsStopped() { // handled by other middleware. - return - } - - var callers []string - for i := 1; ; i++ { - _, file, line, got := runtime.Caller(i) - if !got { - break - } - - callers = append(callers, fmt.Sprintf("%s:%d", file, line)) - } - - // when stack finishes - logMessage := fmt.Sprintf("Recovered from a route's Handler('%s')\n", ctx.HandlerName()) - logMessage += fmt.Sprint(getRequestLogs(ctx)) - logMessage += fmt.Sprintf("%s\n", err) - logMessage += fmt.Sprintf("%s\n", strings.Join(callers, "\n")) - ctx.Application().Logger().Warn(logMessage) - - // get the list of registered handlers and the - // handler which panic derived from. - handlers := ctx.Handlers() - handlersFileLines := make([]string, 0, len(handlers)) - currentHandlerIndex := ctx.HandlerIndex(-1) - currentHandlerFileLine := "???" - for i, h := range ctx.Handlers() { - file, line := context.HandlerFileLine(h) - fileline := fmt.Sprintf("%s:%d", file, line) - handlersFileLines = append(handlersFileLines, fileline) - if i == currentHandlerIndex { - currentHandlerFileLine = fileline - } - } - - // see accesslog.wasRecovered too. - ctx.StopWithPlainError(500, &context.ErrPanicRecovery{ - Cause: err, - Callers: callers, - Stack: debug.Stack(), - RegisteredHandlers: handlersFileLines, - CurrentHandler: currentHandlerFileLine, - }) - } + if err := PanicRecoveryError(ctx, recover()); err != nil { + ctx.StopWithPlainError(500, err) + ctx.Application().Logger().Warn(err.LogMessage()) + } // else it's already handled. }() ctx.Next() } } + +// PanicRecoveryError returns a new ErrPanicRecovery error. +func PanicRecoveryError(ctx *context.Context, err any) *context.ErrPanicRecovery { + if recoveryErr, ok := ctx.IsRecovered(); ok { + // If registered before any other recovery middleware, get its error. + // Because of defer this will be executed last, after the recovery middleware in this case. + return recoveryErr + } + + if err == nil { + return nil + } else if ctx.IsStopped() { + return nil + } + + var callers []string + for i := 2; ; /* 1 for New() 2 for NewPanicRecoveryError */ i++ { + _, file, line, got := runtime.Caller(i) + if !got { + break + } + + callers = append(callers, fmt.Sprintf("%s:%d", file, line)) + } + + // get the list of registered handlers and the + // handler which panic derived from. + handlers := ctx.Handlers() + handlersFileLines := make([]string, 0, len(handlers)) + currentHandlerIndex := ctx.HandlerIndex(-1) + currentHandlerFileLine := "???" + for i, h := range ctx.Handlers() { + file, line := context.HandlerFileLine(h) + fileline := fmt.Sprintf("%s:%d", file, line) + handlersFileLines = append(handlersFileLines, fileline) + if i == currentHandlerIndex { + currentHandlerFileLine = fileline + } + } + + // see accesslog.wasRecovered too. + recoveryErr := &context.ErrPanicRecovery{ + Cause: err, + Callers: callers, + Stack: debug.Stack(), + RegisteredHandlers: handlersFileLines, + CurrentHandlerFileLine: currentHandlerFileLine, + CurrentHandlerName: ctx.HandlerName(), + Request: getRequestLogs(ctx), + } + + return recoveryErr +} + +func getRequestLogs(ctx *context.Context) string { + rawReq, _ := httputil.DumpRequest(ctx.Request(), false) + return string(rawReq) +} diff --git a/x/errors/handlers.go b/x/errors/handlers.go index 53719a89..ee028121 100644 --- a/x/errors/handlers.go +++ b/x/errors/handlers.go @@ -3,11 +3,11 @@ package errors import ( stdContext "context" "errors" - "fmt" "io" "net/http" "github.com/kataras/iris/v12/context" + recovery "github.com/kataras/iris/v12/middleware/recover" "github.com/kataras/iris/v12/x/pagination" "golang.org/x/exp/constraints" @@ -17,17 +17,7 @@ import ( // to the logger and the client. func RecoveryHandler(ctx *context.Context) { defer func() { - if rec := recover(); rec != nil { - var err error - switch v := rec.(type) { - case error: - err = v - case string: - err = New(v) - default: - err = fmt.Errorf("%v", v) - } - + if err := recovery.PanicRecoveryError(ctx, recover()); err != nil { Internal.LogErr(ctx, err) ctx.StopExecution() } @@ -165,6 +155,10 @@ const contextRequestHandlerFuncKey = "iris.errors.ContextRequestHandler" // ) // } func Validation[T any](validators ...ContextRequestFunc[T]) context.Handler { + if len(validators) == 0 { + return nil + } + validator := joinContextRequestFuncs[T](validators) return func(ctx *context.Context) { @@ -227,27 +221,27 @@ func validateRequest[T any](ctx *context.Context, req T) bool { // ResponseHandler is an interface which can be implemented by a request payload struct // in order to handle a response before sending it to the client. -type ResponseHandler[R any, RPointer *R] interface { - HandleResponse(ctx *context.Context, response RPointer) error +type ResponseHandler[R any] interface { + HandleResponse(ctx *context.Context, response *R) error } // ContextResponseFunc is a function which takes a context, a generic type T and a generic type R and returns an error. -type ContextResponseFunc[T, R any, RPointer *R] func(*context.Context, T, RPointer) error +type ContextResponseFunc[T, R any] func(*context.Context, T, *R) error const contextResponseHandlerFuncKey = "iris.errors.ContextResponseHandler" -func validateResponse[T, R any, RPointer *R](ctx *context.Context, req T, resp RPointer) bool { +func validateResponse[T, R any](ctx *context.Context, req T, resp *R) bool { var err error - if contextResponseHandler, ok := any(&req).(ResponseHandler[R, RPointer]); ok { + if contextResponseHandler, ok := any(&req).(ResponseHandler[R]); ok { err = contextResponseHandler.HandleResponse(ctx, resp) } if err == nil { if v := ctx.Values().Get(contextResponseHandlerFuncKey); v != nil { - if contextResponseHandlerFunc, ok := v.(ContextResponseFunc[T, R, RPointer]); ok && contextResponseHandlerFunc != nil { + if contextResponseHandlerFunc, ok := v.(ContextResponseFunc[T, R]); ok && contextResponseHandlerFunc != nil { err = contextResponseHandlerFunc(ctx, req, resp) - } else if contextResponseHandlerFunc, ok := v.(ContextResponseFunc[*T, R, RPointer]); ok && contextResponseHandlerFunc != nil { + } else if contextResponseHandlerFunc, ok := v.(ContextResponseFunc[*T, R]); ok && contextResponseHandlerFunc != nil { err = contextResponseHandlerFunc(ctx, &req, resp) } } @@ -262,8 +256,12 @@ func validateResponse[T, R any, RPointer *R](ctx *context.Context, req T, resp R // Example Code: // // app.Post("/", errors.Intercept(func(ctx iris.Context, req *CreateRequest, resp *CreateResponse) error{ ... }), errors.CreateHandler(service.Create)) -func Intercept[T, R any, RPointer *R](responseHandlers ...ContextResponseFunc[T, R, RPointer]) context.Handler { - responseHandler := joinContextResponseFuncs[T, R, RPointer](responseHandlers) +func Intercept[T, R any](responseHandlers ...ContextResponseFunc[T, R]) context.Handler { + if len(responseHandlers) == 0 { + return nil + } + + responseHandler := joinContextResponseFuncs[T, R](responseHandlers) return func(ctx *context.Context) { ctx.Values().Set(contextResponseHandlerFuncKey, responseHandler) @@ -271,7 +269,7 @@ func Intercept[T, R any, RPointer *R](responseHandlers ...ContextResponseFunc[T, } } -func joinContextResponseFuncs[T, R any, RPointer *R](responseHandlerFuncs []ContextResponseFunc[T, R, RPointer]) ContextResponseFunc[T, R, RPointer] { +func joinContextResponseFuncs[T, R any](responseHandlerFuncs []ContextResponseFunc[T, R]) ContextResponseFunc[T, R] { if len(responseHandlerFuncs) == 0 || responseHandlerFuncs[0] == nil { panic("at least one context response handler function is required") } @@ -280,7 +278,7 @@ func joinContextResponseFuncs[T, R any, RPointer *R](responseHandlerFuncs []Cont return responseHandlerFuncs[0] } - return func(ctx *context.Context, req T, resp RPointer) error { + return func(ctx *context.Context, req T, resp *R) error { for _, handler := range responseHandlerFuncs { if handler == nil { continue