package methodoverride import ( stdContext "context" "net/http" "strings" "github.com/kataras/iris/v12/context" "github.com/kataras/iris/v12/core/router" ) func init() { context.SetHandlerName("iris/middleware/methodoverride.*", "iris.methodoverride") } type options struct { getters []GetterFunc methods []string saveOriginalMethodContextKey interface{} // if not nil original value will be saved. } func (o *options) configure(opts ...Option) { for _, opt := range opts { opt(o) } } func (o *options) canOverride(method string) bool { for _, s := range o.methods { if s == method { return true } } return false } func (o *options) get(w http.ResponseWriter, r *http.Request) string { for _, getter := range o.getters { if v := getter(w, r); v != "" { return strings.ToUpper(v) } } return "" } // Option sets options for a fresh method override wrapper. // See `New` package-level function for more. type Option func(*options) // Methods can be used to add methods that can be overridden. // Defaults to "POST". func Methods(methods ...string) Option { for i, s := range methods { methods[i] = strings.ToUpper(s) } return func(opts *options) { opts.methods = append(opts.methods, methods...) } } // SaveOriginalMethod will save the original method // on Context.Request().Context().Value(requestContextKey). // // Defaults to nil, don't save it. func SaveOriginalMethod(requestContextKey interface{}) Option { return func(opts *options) { if requestContextKey == nil { opts.saveOriginalMethodContextKey = nil } opts.saveOriginalMethodContextKey = requestContextKey } } // GetterFunc is the type signature for declaring custom logic // to extract the method name which a POST request will be replaced with. type GetterFunc func(http.ResponseWriter, *http.Request) string // Getter sets a custom logic to use to extract the method name // to override the POST method with. // Defaults to nil. 
func Getter(customFunc GetterFunc) Option {
	register := func(o *options) {
		o.getters = append(o.getters, customFunc)
	}

	return register
}

// Headers registers a getter that reads the override method from the
// first of the given request headers that carries a non-empty value.
// The matching header name is added to the response "Vary" header.
//
// Defaults to:
// X-HTTP-Method
// X-HTTP-Method-Override
// X-Method-Override
func Headers(headers ...string) Option {
	return Getter(func(w http.ResponseWriter, r *http.Request) string {
		for _, name := range headers {
			value := r.Header.Get(name)
			if value == "" {
				continue
			}

			// The response now depends on this request header.
			w.Header().Add("Vary", name)
			return value
		}

		return ""
	})
}

// FormField specifies a form field to use to determine the method
// to override the POST method with. It is a shortcut of
// `FormFieldWithConf(fieldName, nil)`.
//
// Example Field:
// <input type="hidden" name="_method" value="DELETE" />
//
// Defaults to: "_method".
func FormField(fieldName string) Option {
	return FormFieldWithConf(fieldName, nil)
}

// FormFieldWithConf same as `FormField` but it accepts the application's
// configuration to parse the form based on the app core configuration.
// When "conf" is nil a post max memory of 32 MB is used and the
// reset-body flag is false.
func FormFieldWithConf(fieldName string, conf context.ConfigurationReadOnly) Option {
	// Fallback values, used when no configuration is supplied.
	maxMemory := int64(32 << 20) // 32 MB
	restoreBody := false

	if conf != nil {
		maxMemory = conf.GetPostMaxMemory()
		restoreBody = conf.GetDisableBodyConsumptionOnUnmarshal()
	}

	return Getter(func(_ http.ResponseWriter, r *http.Request) string {
		return context.FormValueDefault(r, fieldName, "", maxMemory, restoreBody)
	})
}

// Query specifies a url parameter name to use to determine the method
// to override the POST method with.
//
// Example URL Query string:
// http://localhost:8080/path?_method=DELETE
//
// Defaults to: "_method".
func Query(paramName string) Option {
	return Getter(func(_ http.ResponseWriter, r *http.Request) string {
		values := r.URL.Query()
		return values.Get(paramName)
	})
}

// Only clears all default or previously registered values
// and uses only the "o" option(s).
//
// The default behavior is to check for all the following by order:
// headers, form field, query string
// and any custom getter (if set).
// Use this method to override that // behavior and use only the passed option(s) // to determinate the method to override with. // // Use cases: // // 1. When need to check only for headers and ignore other fields: // New(Only(Headers("X-Custom-Header"))) // // 2. When need to check only for (first) form field and (second) custom getter: // New(Only(FormField("fieldName"), Getter(...))) func Only(o ...Option) Option { return func(opts *options) { opts.getters = opts.getters[0:0] opts.configure(o...) } } // New returns a new method override wrapper // which can be registered with `Application.WrapRouter`. // // Use this wrapper when you expecting clients // that do not support certain HTTP operations such as DELETE or PUT for security reasons. // This wrapper will accept a method, based on criteria, to override the POST method with. // // Read more at: // https://github.com/kataras/iris/issues/1325 func New(opt ...Option) router.WrapperFunc { opts := new(options) // Default values. opts.configure( Methods(http.MethodPost), Headers("X-HTTP-Method", "X-HTTP-Method-Override", "X-Method-Override"), FormField("_method"), Query("_method"), ) opts.configure(opt...) return func(w http.ResponseWriter, r *http.Request, proceed http.HandlerFunc) { originalMethod := strings.ToUpper(r.Method) if opts.canOverride(originalMethod) { newMethod := opts.get(w, r) if newMethod != "" { if opts.saveOriginalMethodContextKey != nil { r = r.WithContext(stdContext.WithValue(r.Context(), opts.saveOriginalMethodContextKey, originalMethod)) } r.Method = newMethod } } proceed(w, r) } }