diff --git a/_examples/README.md b/_examples/README.md index 8b4c742e..16035754 100644 --- a/_examples/README.md +++ b/_examples/README.md @@ -256,6 +256,7 @@ You can serve [quicktemplate](https://github.com/valyala/quicktemplate) and [her ### Miscellaneous +- [Method Override](https://github.com/kataras/iris/blob/master/middleware/methodoverride/methodoverride_test.go) **NEW** - [Request Logger](http_request/request-logger/main.go) * [log requests to a file](http_request/request-logger/request-logger-file/main.go) - [Localization and Internationalization](miscellaneous/i18n/main.go) diff --git a/_examples/README_ZH.md b/_examples/README_ZH.md index 8f86ac9f..c5266f2b 100644 --- a/_examples/README_ZH.md +++ b/_examples/README_ZH.md @@ -363,6 +363,7 @@ You can serve [quicktemplate](https://github.com/valyala/quicktemplate) and [her ### 其他 +- [Method Override](https://github.com/kataras/iris/blob/master/middleware/methodoverride/methodoverride_test.go) **更新** - [请求记录器](http_request/request-logger/main.go) * [将请求记录到文件](http_request/request-logger/request-logger-file/main.go) - [本地化和多语言支持](miscellaneous/i18n/main.go) diff --git a/context/context.go b/context/context.go index c334f08c..bbdb1f68 100644 --- a/context/context.go +++ b/context/context.go @@ -2050,6 +2050,16 @@ func (ctx *context) FormValueDefault(name string, def string) string { return def } +// FormValueDefault retruns a single parsed form value. +func FormValueDefault(r *http.Request, name string, def string, postMaxMemory int64, resetBody bool) string { + if form, has := GetForm(r, postMaxMemory, resetBody); has { + if v := form[name]; len(v) > 0 { + return v[0] + } + } + return def +} + // FormValue returns a single parsed form value by its "name", // including both the URL field's query parameters and the POST or PUT form data. func (ctx *context) FormValue(name string) string { @@ -2070,6 +2080,11 @@ func (ctx *context) FormValues() map[string][]string { // Form contains the parsed form data, including both the URL // field's query parameters and the POST or PUT form data. func (ctx *context) form() (form map[string][]string, found bool) { + return GetForm(ctx.request, ctx.Application().ConfigurationReadOnly().GetPostMaxMemory(), ctx.Application().ConfigurationReadOnly().GetDisableBodyConsumptionOnUnmarshal()) +} + +// GetForm returns the request form (url queries, post or multipart) values. +func GetForm(r *http.Request, postMaxMemory int64, resetBody bool) (form map[string][]string, found bool) { /* net/http/request.go#1219 for k, v := range f.Value { @@ -2079,21 +2094,34 @@ func (ctx *context) form() (form map[string][]string, found bool) { } */ + if form := r.Form; len(form) > 0 { + return form, true + } + + if form := r.PostForm; len(form) > 0 { + return form, true + } + + if m := r.MultipartForm; m != nil { + if len(m.Value) > 0 { + return m.Value, true + } + } + var ( - keepBody = ctx.Application().ConfigurationReadOnly().GetDisableBodyConsumptionOnUnmarshal() bodyCopy []byte ) - if keepBody { + if resetBody { // on POST, PUT and PATCH it will read the form values from request body otherwise from URL queries. - if m := ctx.Method(); m == "POST" || m == "PUT" || m == "PATCH" { - bodyCopy, _ = ctx.GetBody() + if m := r.Method; m == "POST" || m == "PUT" || m == "PATCH" { + bodyCopy, _ = GetBody(r, resetBody) if len(bodyCopy) == 0 { return nil, false } - // ctx.request.Body = ioutil.NopCloser(io.TeeReader(ctx.request.Body, buf)) + // r.Body = ioutil.NopCloser(io.TeeReader(r.Body, buf)) } else { - keepBody = false + resetBody = false } } @@ -2101,23 +2129,23 @@ func (ctx *context) form() (form map[string][]string, found bool) { // therefore we don't need to call it here, although it doesn't hurt. // After one call to ParseMultipartForm or ParseForm, // subsequent calls have no effect, are idempotent. - err := ctx.request.ParseMultipartForm(ctx.Application().ConfigurationReadOnly().GetPostMaxMemory()) - if keepBody { - ctx.request.Body = ioutil.NopCloser(bytes.NewBuffer(bodyCopy)) + err := r.ParseMultipartForm(postMaxMemory) + if resetBody { + r.Body = ioutil.NopCloser(bytes.NewBuffer(bodyCopy)) } if err != nil && err != http.ErrNotMultipart { return nil, false } - if form := ctx.request.Form; len(form) > 0 { + if form := r.Form; len(form) > 0 { return form, true } - if form := ctx.request.PostForm; len(form) > 0 { + if form := r.PostForm; len(form) > 0 { return form, true } - if m := ctx.request.MultipartForm; m != nil { + if m := r.MultipartForm; m != nil { if len(m.Value) > 0 { return m.Value, true } @@ -2387,15 +2415,20 @@ func (ctx *context) SetMaxRequestBodySize(limitOverBytes int64) { // // However, whenever you can use the `ctx.Request().Body` instead. func (ctx *context) GetBody() ([]byte, error) { - data, err := ioutil.ReadAll(ctx.request.Body) + return GetBody(ctx.request, ctx.Application().ConfigurationReadOnly().GetDisableBodyConsumptionOnUnmarshal()) +} + +// GetBody reads and returns the request body. +func GetBody(r *http.Request, resetBody bool) ([]byte, error) { + data, err := ioutil.ReadAll(r.Body) if err != nil { return nil, err } - if ctx.Application().ConfigurationReadOnly().GetDisableBodyConsumptionOnUnmarshal() { + if resetBody { // * remember, Request.Body has no Bytes(), we have to consume them first // and after re-set them to the body, this is the only solution. - ctx.request.Body = ioutil.NopCloser(bytes.NewBuffer(data)) + r.Body = ioutil.NopCloser(bytes.NewBuffer(data)) } return data, nil diff --git a/core/router/router.go b/core/router/router.go index 7ff55e86..b180d502 100644 --- a/core/router/router.go +++ b/core/router/router.go @@ -81,6 +81,8 @@ func (router *Router) BuildRouter(cPool *context.Pool, requestHandler RequestHan // the important router.mainHandler = func(w http.ResponseWriter, r *http.Request) { ctx := cPool.Acquire(w, r) + // Note: we can't get all r.Context().Value key-value pairs + // and save them to ctx.values. router.requestHandler.HandleRequest(ctx) cPool.Release(ctx) } diff --git a/middleware/methodoverride/methodoverride.go b/middleware/methodoverride/methodoverride.go new file mode 100644 index 00000000..924f1e94 --- /dev/null +++ b/middleware/methodoverride/methodoverride.go @@ -0,0 +1,212 @@ +package methodoverride + +import ( + stdContext "context" + "net/http" + "strings" + + "github.com/kataras/iris/context" + "github.com/kataras/iris/core/router" +) + +type options struct { + getters []GetterFunc + methods []string + saveOriginalMethodContextKey interface{} // if not nil original value will be saved. +} + +func (o *options) configure(opts ...Option) { + for _, opt := range opts { + opt(o) + } +} + +func (o *options) canOverride(method string) bool { + for _, s := range o.methods { + if s == method { + return true + } + } + + return false +} + +func (o *options) get(w http.ResponseWriter, r *http.Request) string { + for _, getter := range o.getters { + if v := getter(w, r); v != "" { + return v + } + } + + return "" +} + +// Option sets options for a fresh method override wrapper. +// See `New` package-level function for more. +type Option func(*options) + +// Methods can be used to add methods that can be overridden. +// Defaults to "POST". +func Methods(methods ...string) Option { + for i, s := range methods { + methods[i] = strings.ToUpper(s) + } + + return func(opts *options) { + opts.methods = append(opts.methods, methods...) + } +} + +// SaveOriginalMethod will save the original method +// on Context.Request().Context().Value(requestContextKey). +// +// Defaults to nil, don't save it. +func SaveOriginalMethod(requestContextKey interface{}) Option { + return func(opts *options) { + if requestContextKey == nil { + opts.saveOriginalMethodContextKey = nil + } + opts.saveOriginalMethodContextKey = requestContextKey + } +} + +// GetterFunc is the type signature for declaring custom logic +// to extract the method name which a POST request will be replaced with. +type GetterFunc func(http.ResponseWriter, *http.Request) string + +// Getter sets a custom logic to use to extract the method name +// to override the POST method with. +// Defaults to nil. +func Getter(customFunc GetterFunc) Option { + return func(opts *options) { + opts.getters = append(opts.getters, customFunc) + } +} + +// Headers that client can send to specify a method +// to override the POST method with. +// +// Defaults to: +// X-HTTP-Method +// X-HTTP-Method-Override +// X-Method-Override +func Headers(headers ...string) Option { + getter := func(w http.ResponseWriter, r *http.Request) string { + for _, s := range headers { + if v := r.Header.Get(s); v != "" { + w.Header().Add("Vary", s) + return v + } + } + + return "" + } + + return Getter(getter) +} + +// FormField specifies a form field to use to determinate the method +// to override the POST method with. +// +// Example Field: +// +// +// Defaults to: "_method". +func FormField(fieldName string) Option { + return FormFieldWithConf(fieldName, nil) +} + +// FormFieldWithConf same as `FormField` but it accepts the application's +// configuration to parse the form based on the app core configuration. +func FormFieldWithConf(fieldName string, conf context.ConfigurationReadOnly) Option { + var ( + postMaxMemory int64 = 32 << 20 // 32 MB + resetBody = false + ) + + if conf != nil { + postMaxMemory = conf.GetPostMaxMemory() + resetBody = conf.GetDisableBodyConsumptionOnUnmarshal() + } + + getter := func(w http.ResponseWriter, r *http.Request) string { + return context.FormValueDefault(r, fieldName, "", postMaxMemory, resetBody) + } + + return Getter(getter) +} + +// Query specifies a url parameter name to use to determinate the method +// to override the POST methos with. +// +// Example URL Query string: +// http://localhost:8080/path?_method=DELETE +// +// Defaults to: "_method". +func Query(paramName string) Option { + getter := func(w http.ResponseWriter, r *http.Request) string { + return r.URL.Query().Get(paramName) + } + + return Getter(getter) +} + +// Only clears all default or previously registered values +// and uses only the "o" option(s). +// +// The default behavior is to check for all the following by order: +// headers, form field, query string +// and any custom getter (if set). +// Use this method to override that +// behavior and use only the passed option(s) +// to determinate the method to override with. +// +// Use cases: +// 1. When need to check only for headers and ignore other fields: +// New(Only(Headers("X-Custom-Header"))) +// +// 2. When need to check only for (first) form field and (second) custom getter: +// New(Only(FormField("fieldName"), Getter(...))) +func Only(o ...Option) Option { + return func(opts *options) { + opts.getters = opts.getters[0:0] + opts.configure(o...) + } +} + +// New returns a new method override wrapper +// which can be registered with `Application.WrapRouter`. +// +// Use this wrapper when you expecting clients +// that do not support certain HTTP operations such as DELETE or PUT for security reasons. +// This wrapper will accept a method, based on criteria, to override the POST method with. +// +// +// Read more at: +// https://github.com/kataras/iris/issues/1325 +func New(opt ...Option) router.WrapperFunc { + opts := new(options) + // Default values. + opts.configure( + Methods(http.MethodPost), + Headers("X-HTTP-Method", "X-HTTP-Method-Override", "X-Method-Override"), + FormField("_method"), + Query("_method"), + ) + opts.configure(opt...) + + return func(w http.ResponseWriter, r *http.Request, proceed http.HandlerFunc) { + originalMethod := strings.ToUpper(r.Method) + if opts.canOverride(originalMethod) { + newMethod := opts.get(w, r) + if newMethod != "" { + if opts.saveOriginalMethodContextKey != nil { + r = r.WithContext(stdContext.WithValue(r.Context(), opts.saveOriginalMethodContextKey, originalMethod)) + } + r.Method = newMethod + } + } + + proceed(w, r) + } +} diff --git a/middleware/methodoverride/methodoverride_test.go b/middleware/methodoverride/methodoverride_test.go new file mode 100644 index 00000000..7bf9705c --- /dev/null +++ b/middleware/methodoverride/methodoverride_test.go @@ -0,0 +1,73 @@ +package methodoverride_test + +import ( + "testing" + + "github.com/kataras/iris" + "github.com/kataras/iris/httptest" + "github.com/kataras/iris/middleware/methodoverride" +) + +func TestMethodOverrideWrapper(t *testing.T) { + app := iris.New() + + mo := methodoverride.New( + // Defaults to nil. + // + methodoverride.SaveOriginalMethod("_originalMethod"), + // Default values. + // + // methodoverride.Methods(http.MethodPost), + // methodoverride.Headers("X-HTTP-Method", "X-HTTP-Method-Override", "X-Method-Override"), + // methodoverride.FormField("_method"), + // methodoverride.Query("_method"), + ) + // Register it with `WrapRouter`. + app.WrapRouter(mo) + + var ( + expectedDelResponse = "delete resp" + expectedPostResponse = "post resp" + ) + + app.Post("/path", func(ctx iris.Context) { + ctx.WriteString(expectedPostResponse) + }) + + app.Delete("/path", func(ctx iris.Context) { + ctx.WriteString(expectedDelResponse) + }) + + app.Delete("/path2", func(ctx iris.Context) { + ctx.Writef("%s%s", expectedDelResponse, ctx.Request().Context().Value("_originalMethod")) + }) + + e := httptest.New(t, app) + + // Test headers. + e.POST("/path").WithHeader("X-HTTP-Method", iris.MethodDelete).Expect(). + Status(iris.StatusOK).Body().Equal(expectedDelResponse) + e.POST("/path").WithHeader("X-HTTP-Method-Override", iris.MethodDelete).Expect(). + Status(iris.StatusOK).Body().Equal(expectedDelResponse) + e.POST("/path").WithHeader("X-Method-Override", iris.MethodDelete).Expect(). + Status(iris.StatusOK).Body().Equal(expectedDelResponse) + + // Test form field value. + e.POST("/path").WithFormField("_method", iris.MethodDelete).Expect(). + Status(iris.StatusOK).Body().Equal(expectedDelResponse) + + // Test URL Query (although it's the same as form field in this case). + e.POST("/path").WithQuery("_method", iris.MethodDelete).Expect(). + Status(iris.StatusOK).Body().Equal(expectedDelResponse) + + // Test saved original method and + // Test without registered "POST" route. + e.POST("/path2").WithQuery("_method", iris.MethodDelete).Expect(). + Status(iris.StatusOK).Body().Equal(expectedDelResponse + iris.MethodPost) + + // Test simple POST request without method override fields. + e.POST("/path").Expect().Status(iris.StatusOK).Body().Equal(expectedPostResponse) + + // Test simple DELETE request. + e.DELETE("/path").Expect().Status(iris.StatusOK).Body().Equal(expectedDelResponse) +}