From 61872a1612ef5052ade2d727e9b0a335265fc74d Mon Sep 17 00:00:00 2001 From: "Gerasimos (Makis) Maropoulos" Date: Tue, 1 Mar 2022 21:26:02 +0200 Subject: [PATCH] Add Party.Container.SetDependencyMatcher, hero.Container.DependencyMatcher and hero.Dependency.Match to fullfil the feature request asked at: #1842 --- HISTORY.md | 2 ++ .../response-writer/sse-third-party/main.go | 6 ++-- core/router/api_container.go | 13 ++++++++ hero/binding.go | 21 +++++++++--- hero/binding_test.go | 4 +-- hero/container.go | 33 +++++++++++-------- hero/dependency.go | 26 ++++++++++++--- hero/struct.go | 2 +- sessions/sessiondb/redis/driver_goredis.go | 2 +- 9 files changed, 79 insertions(+), 30 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 8b515181..795858cb 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -28,6 +28,8 @@ The codebase for Dependency Injection, Internationalization and localization and ## Fixes and Improvements +- Add `Container.DependencyMatcher` and `Dependency.Match` to implement the feature requested at [issues/#1842](https://github.com/kataras/iris/issues/1842). + - Register [CORS middleware](middleware/cors) to the Application by default when `iris.Default()` is used instead of `iris.New()`. - Add [x/jsonx: DayTime](/x/jsonx/day_time.go) for JSON marshal and unmarshal of "15:04:05" (hour, minute, second). diff --git a/_examples/response-writer/sse-third-party/main.go b/_examples/response-writer/sse-third-party/main.go index ed87b9f0..af378232 100644 --- a/_examples/response-writer/sse-third-party/main.go +++ b/_examples/response-writer/sse-third-party/main.go @@ -4,11 +4,11 @@ import ( "time" "github.com/kataras/iris/v12" - "github.com/r3labs/sse" + "github.com/r3labs/sse/v2" ) // First of all install the sse third-party package (you can use other if you don't like this approach or go ahead to the "sse" example) -// $ go get -u github.com/r3labs/sse +// $ go get github.com/r3labs/sse/v2@v2.7.4 func main() { app := iris.New() s := sse.New() @@ -22,7 +22,7 @@ func main() { */ s.CreateStream("messages") - app.Any("/events", iris.FromStd(s.HTTPHandler)) + app.Any("/events", iris.FromStd(s)) go func() { // You design when to send messages to the client, diff --git a/core/router/api_container.go b/core/router/api_container.go index 4bf7fe36..c1d9d9e9 100644 --- a/core/router/api_container.go +++ b/core/router/api_container.go @@ -95,6 +95,19 @@ func (api *APIContainer) EnableStrictMode(strictMode bool) *APIContainer { return api } +// SetDependencyMatcher replaces the function that compares equality between +// a dependency and an input (struct field or function parameter). +// +// Defaults to hero.DefaultMatchDependencyFunc. +func (api *APIContainer) SetDependencyMatcher(fn hero.DependencyMatcher) *APIContainer { + if fn == nil { + panic("api container: set dependency matcher: fn cannot be nil") + } + + api.Container.DependencyMatcher = fn + return api +} + // convertHandlerFuncs accepts Iris hero handlers and returns a slice of native Iris handlers. func (api *APIContainer) convertHandlerFuncs(relativePath string, handlersFn ...interface{}) context.Handlers { fullpath := api.Self.GetRelPath() + relativePath diff --git a/hero/binding.go b/hero/binding.go index 63d9287f..fbe48fe3 100644 --- a/hero/binding.go +++ b/hero/binding.go @@ -110,7 +110,12 @@ func (b *binding) Equal(other *binding) bool { return true } -func matchDependency(dep *Dependency, in reflect.Type) bool { +// DependencyMatcher type alias describes a dependency match function. +type DependencyMatcher = func(*Dependency, reflect.Type) bool + +// DefaultDependencyMatcher is the default dependency match function for all DI containers. +// It is used to collect dependencies from struct's fields and function's parameters. +var DefaultDependencyMatcher = func(dep *Dependency, in reflect.Type) bool { if dep.Explicit { return dep.DestType == in } @@ -118,6 +123,14 @@ func matchDependency(dep *Dependency, in reflect.Type) bool { return dep.DestType == nil || equalTypes(dep.DestType, in) } +// ToDependencyMatchFunc converts a DependencyMatcher (generic for all dependencies) +// to a dependency-specific input matcher. +func ToDependencyMatchFunc(d *Dependency, match DependencyMatcher) DependencyMatchFunc { + return func(in reflect.Type) bool { + return match(d, in) + } +} + func getBindingsFor(inputs []reflect.Type, deps []*Dependency, disablePayloadAutoBinding bool, paramsCount int) (bindings []*binding) { // Path parameter start index is the result of [total path parameters] - [total func path parameters inputs], // moving from last to first path parameters and first to last (probably) available input args. @@ -174,7 +187,7 @@ func getBindingsFor(inputs []reflect.Type, deps []*Dependency, disablePayloadAut continue } - match := matchDependency(d, in) + match := d.Match(in) if !match { continue } @@ -278,7 +291,7 @@ func getBindingsForFunc(fn reflect.Value, dependencies []*Dependency, disablePay return bindings } -func getBindingsForStruct(v reflect.Value, dependencies []*Dependency, markExportedFieldsAsRequired bool, disablePayloadAutoBinding bool, paramsCount int, sorter Sorter) (bindings []*binding) { +func getBindingsForStruct(v reflect.Value, dependencies []*Dependency, markExportedFieldsAsRequired bool, disablePayloadAutoBinding bool, matchDependency DependencyMatcher, paramsCount int, sorter Sorter) (bindings []*binding) { typ := indirectType(v.Type()) if typ.Kind() != reflect.Struct { panic("bindings: unresolved: no struct type") @@ -290,7 +303,7 @@ func getBindingsForStruct(v reflect.Value, dependencies []*Dependency, markExpor for _, f := range nonZero { // fmt.Printf("Controller [%s] | NonZero | Field Index: %v | Field Type: %s\n", typ, f.Index, f.Type) bindings = append(bindings, &binding{ - Dependency: newDependency(elem.FieldByIndex(f.Index).Interface(), disablePayloadAutoBinding), + Dependency: newDependency(elem.FieldByIndex(f.Index).Interface(), disablePayloadAutoBinding, nil), Input: newStructFieldInput(f), }) } diff --git a/hero/binding_test.go b/hero/binding_test.go index 95a736e6..6781f807 100644 --- a/hero/binding_test.go +++ b/hero/binding_test.go @@ -524,7 +524,7 @@ func TestBindingsForStruct(t *testing.T) { } for i, tt := range tests { - bindings := getBindingsForStruct(reflect.ValueOf(tt.Value), tt.Registered, false, false, 0, nil) + bindings := getBindingsForStruct(reflect.ValueOf(tt.Value), tt.Registered, false, false, DefaultDependencyMatcher, 0, nil) if expected, got := len(tt.Expected), len(bindings); expected != got { t.Logf("[%d] expected bindings length to be: %d but got: %d:\n", i, expected, got) @@ -565,5 +565,5 @@ func TestBindingsForStructMarkExportedFieldsAsRequred(t *testing.T) { } // should panic if fail. - _ = getBindingsForStruct(reflect.ValueOf(new(controller)), dependencies, true, true, 0, nil) + _ = getBindingsForStruct(reflect.ValueOf(new(controller)), dependencies, true, true, DefaultDependencyMatcher, 0, nil) } diff --git a/hero/container.go b/hero/container.go index b34a5c12..1e53d005 100644 --- a/hero/container.go +++ b/hero/container.go @@ -53,6 +53,9 @@ type Container struct { // set to true to disable that kind of behavior. DisablePayloadAutoBinding bool + // DependencyMatcher holds the function that compares equality between + // a dependency with an input. Defaults to DefaultMatchDependencyFunc. + DependencyMatcher DependencyMatcher // GetErrorHandler should return a valid `ErrorHandler` to handle bindings AND handler dispatch errors. // Defaults to a functon which returns the `DefaultErrorHandler`. GetErrorHandler func(*context.Context) ErrorHandler // cannot be nil. @@ -139,11 +142,11 @@ func (c *Container) fillReport(fullName string, bindings []*binding) { // Contains the iris context, standard context, iris sessions and time dependencies. var BuiltinDependencies = []*Dependency{ // iris context dependency. - newDependency(func(ctx *context.Context) *context.Context { return ctx }, true).Explicitly(), + newDependency(func(ctx *context.Context) *context.Context { return ctx }, true, nil).Explicitly(), // standard context dependency. newDependency(func(ctx *context.Context) stdContext.Context { return ctx.Request().Context() - }, true).Explicitly(), + }, true, nil).Explicitly(), // iris session dependency. newDependency(func(ctx *context.Context) *sessions.Session { session := sessions.Get(ctx) @@ -156,35 +159,35 @@ var BuiltinDependencies = []*Dependency{ } return session - }, true).Explicitly(), + }, true, nil).Explicitly(), // application's logger. newDependency(func(ctx *context.Context) *golog.Logger { return ctx.Application().Logger() - }, true).Explicitly(), + }, true, nil).Explicitly(), // time.Time to time.Now dependency. newDependency(func(ctx *context.Context) time.Time { return time.Now() - }, true).Explicitly(), + }, true, nil).Explicitly(), // standard http Request dependency. newDependency(func(ctx *context.Context) *http.Request { return ctx.Request() - }, true).Explicitly(), + }, true, nil).Explicitly(), // standard http ResponseWriter dependency. newDependency(func(ctx *context.Context) http.ResponseWriter { return ctx.ResponseWriter() - }, true).Explicitly(), + }, true, nil).Explicitly(), // http headers dependency. newDependency(func(ctx *context.Context) http.Header { return ctx.Request().Header - }, true).Explicitly(), + }, true, nil).Explicitly(), // Client IP. newDependency(func(ctx *context.Context) net.IP { return net.ParseIP(ctx.RemoteAddr()) - }, true).Explicitly(), + }, true, nil).Explicitly(), // Status Code (special type for MVC HTTP Error handler to not conflict with path parameters) newDependency(func(ctx *context.Context) Code { return Code(ctx.GetStatusCode()) - }, true).Explicitly(), + }, true, nil).Explicitly(), // Context Error. May be nil newDependency(func(ctx *context.Context) Err { err := ctx.GetErr() @@ -192,7 +195,7 @@ var BuiltinDependencies = []*Dependency{ return nil } return err - }, true).Explicitly(), + }, true, nil).Explicitly(), // Context User, e.g. from basic authentication. newDependency(func(ctx *context.Context) context.User { u := ctx.User() @@ -201,7 +204,7 @@ var BuiltinDependencies = []*Dependency{ } return u - }, true), + }, true, nil), // payload and param bindings are dynamically allocated and declared at the end of the `binding` source file. } @@ -218,6 +221,7 @@ func New(dependencies ...interface{}) *Container { GetErrorHandler: func(*context.Context) ErrorHandler { return DefaultErrorHandler }, + DependencyMatcher: DefaultDependencyMatcher, } for _, dependency := range dependencies { @@ -240,6 +244,7 @@ func (c *Container) Clone() *Container { cloned.Logger = c.Logger cloned.GetErrorHandler = c.GetErrorHandler cloned.Sorter = c.Sorter + cloned.DependencyMatcher = c.DependencyMatcher clonedDeps := make([]*Dependency, len(c.Dependencies)) copy(clonedDeps, c.Dependencies) cloned.Dependencies = clonedDeps @@ -281,7 +286,7 @@ func Register(dependency interface{}) *Dependency { // - Register(func(ctx iris.Context) User {...}) // - Register(func(User) OtherResponse {...}) func (c *Container) Register(dependency interface{}) *Dependency { - d := newDependency(dependency, c.DisablePayloadAutoBinding, c.Dependencies...) + d := newDependency(dependency, c.DisablePayloadAutoBinding, c.DependencyMatcher, c.Dependencies...) if d.DestType == nil { // prepend the dynamic dependency so it will be tried at the end // (we don't care about performance here, design-time) @@ -376,7 +381,7 @@ func (c *Container) Inject(toPtr interface{}) error { typ := val.Type() for _, d := range c.Dependencies { - if d.Static && matchDependency(d, typ) { + if d.Static && c.DependencyMatcher(d, typ) { v, err := d.Handle(nil, &Input{Type: typ}) if err != nil { if err == ErrSeeOther { diff --git a/hero/dependency.go b/hero/dependency.go index e87689de..48527832 100644 --- a/hero/dependency.go +++ b/hero/dependency.go @@ -10,7 +10,15 @@ import ( type ( // DependencyHandler is the native function declaration which implementors should return a value match to an input. - DependencyHandler func(ctx *context.Context, input *Input) (reflect.Value, error) + DependencyHandler = func(ctx *context.Context, input *Input) (reflect.Value, error) + + // DependencyMatchFunc type alias describes dependency + // match function with an input (field or parameter). + // + // See "DependencyMatcher" too, which can be used on a Container to + // change the way dependencies are matched to inputs for all dependencies. + DependencyMatchFunc = func(in reflect.Type) bool + // Dependency describes the design-time dependency to be injected at serve time. // Contains its source location, the dependency handler (provider) itself and information // such as static for static struct values or explicit to bind a value to its exact DestType and not if just assignable to it (interfaces). @@ -26,6 +34,9 @@ type ( // Example of use case: depenendency like time.Time that we want to be bindable // only to time.Time inputs and not to a service with a `String() string` method that time.Time struct implements too. Explicit bool + + // Match holds the matcher. Defaults to the Container's one. + Match DependencyMatchFunc } ) @@ -50,11 +61,11 @@ func (d *Dependency) String() string { // NewDependency converts a function or a function which accepts other dependencies or static struct value to a *Dependency. // // See `Container.Handler` for more. -func NewDependency(dependency interface{}, funcDependencies ...*Dependency) *Dependency { - return newDependency(dependency, false, funcDependencies...) +func NewDependency(dependency interface{}, funcDependencies ...*Dependency) *Dependency { // used only on tests. + return newDependency(dependency, false, nil, funcDependencies...) } -func newDependency(dependency interface{}, disablePayloadAutoBinding bool, funcDependencies ...*Dependency) *Dependency { +func newDependency(dependency interface{}, disablePayloadAutoBinding bool, matchDependency DependencyMatcher, funcDependencies ...*Dependency) *Dependency { if dependency == nil { panic(fmt.Sprintf("bad value: nil: %T", dependency)) } @@ -69,10 +80,15 @@ func newDependency(dependency interface{}, disablePayloadAutoBinding bool, funcD panic(fmt.Sprintf("bad value: %#+v", dependency)) } + if matchDependency == nil { + matchDependency = DefaultDependencyMatcher + } + dest := &Dependency{ Source: newSource(v), OriginalValue: dependency, } + dest.Match = ToDependencyMatchFunc(dest, matchDependency) if !resolveDependency(v, disablePayloadAutoBinding, dest, funcDependencies...) { panic(fmt.Sprintf("bad value: could not resolve a dependency from: %#+v", dependency)) @@ -223,7 +239,7 @@ func fromDependentFunc(v reflect.Value, disablePayloadAutoBinding bool, dest *De if numIn == len(bindings) { static := true for _, b := range bindings { - if !b.Dependency.Static && matchDependency(b.Dependency, typ.In(b.Input.Index)) { + if !b.Dependency.Static && b.Dependency.Match(typ.In(b.Input.Index)) { static = false break } diff --git a/hero/struct.go b/hero/struct.go index c3ae376b..496463a9 100644 --- a/hero/struct.go +++ b/hero/struct.go @@ -51,7 +51,7 @@ func makeStruct(structPtr interface{}, c *Container, partyParamsCount int) *Stru } // get struct's fields bindings. - bindings := getBindingsForStruct(v, c.Dependencies, c.MarkExportedFieldsAsRequired, c.DisablePayloadAutoBinding, partyParamsCount, c.Sorter) + bindings := getBindingsForStruct(v, c.Dependencies, c.MarkExportedFieldsAsRequired, c.DisablePayloadAutoBinding, c.DependencyMatcher, partyParamsCount, c.Sorter) // length bindings of 0, means that it has no fields or all mapped deps are static. // If static then Struct.Acquire will return the same "value" instance, otherwise it will create a new one. diff --git a/sessions/sessiondb/redis/driver_goredis.go b/sessions/sessiondb/redis/driver_goredis.go index 7b020d8a..6477cd75 100644 --- a/sessions/sessiondb/redis/driver_goredis.go +++ b/sessions/sessiondb/redis/driver_goredis.go @@ -85,7 +85,7 @@ func (r *GoRedisDriver) mergeClusterOptions(c Config) *ClusterOptions { } if opts.Password == "" { - opts.Username = c.Password + opts.Password = c.Password } if opts.ReadTimeout == 0 {