Add Party.Container.SetDependencyMatcher, hero.Container.DependencyMatcher and hero.Dependency.Match to fullfil the feature request asked at: #1842

This commit is contained in:
Gerasimos (Makis) Maropoulos 2022-03-01 21:26:02 +02:00
parent 5ce8475f35
commit 61872a1612
No known key found for this signature in database
GPG Key ID: 66FCC29BD385FCA6
9 changed files with 79 additions and 30 deletions

View File

@ -28,6 +28,8 @@ The codebase for Dependency Injection, Internationalization and localization and
## Fixes and Improvements ## Fixes and Improvements
- Add `Container.DependencyMatcher` and `Dependency.Match` to implement the feature requested at [issues/#1842](https://github.com/kataras/iris/issues/1842).
- Register [CORS middleware](middleware/cors) to the Application by default when `iris.Default()` is used instead of `iris.New()`. - Register [CORS middleware](middleware/cors) to the Application by default when `iris.Default()` is used instead of `iris.New()`.
- Add [x/jsonx: DayTime](/x/jsonx/day_time.go) for JSON marshal and unmarshal of "15:04:05" (hour, minute, second). - Add [x/jsonx: DayTime](/x/jsonx/day_time.go) for JSON marshal and unmarshal of "15:04:05" (hour, minute, second).

View File

@ -4,11 +4,11 @@ import (
"time" "time"
"github.com/kataras/iris/v12" "github.com/kataras/iris/v12"
"github.com/r3labs/sse" "github.com/r3labs/sse/v2"
) )
// First of all install the sse third-party package (you can use other if you don't like this approach or go ahead to the "sse" example) // First of all install the sse third-party package (you can use other if you don't like this approach or go ahead to the "sse" example)
// $ go get -u github.com/r3labs/sse // $ go get github.com/r3labs/sse/v2@v2.7.4
func main() { func main() {
app := iris.New() app := iris.New()
s := sse.New() s := sse.New()
@ -22,7 +22,7 @@ func main() {
*/ */
s.CreateStream("messages") s.CreateStream("messages")
app.Any("/events", iris.FromStd(s.HTTPHandler)) app.Any("/events", iris.FromStd(s))
go func() { go func() {
// You design when to send messages to the client, // You design when to send messages to the client,

View File

@ -95,6 +95,19 @@ func (api *APIContainer) EnableStrictMode(strictMode bool) *APIContainer {
return api return api
} }
// SetDependencyMatcher replaces the function that compares equality between
// a dependency and an input (struct field or function parameter).
//
// Defaults to hero.DefaultMatchDependencyFunc.
func (api *APIContainer) SetDependencyMatcher(fn hero.DependencyMatcher) *APIContainer {
if fn == nil {
panic("api container: set dependency matcher: fn cannot be nil")
}
api.Container.DependencyMatcher = fn
return api
}
// convertHandlerFuncs accepts Iris hero handlers and returns a slice of native Iris handlers. // convertHandlerFuncs accepts Iris hero handlers and returns a slice of native Iris handlers.
func (api *APIContainer) convertHandlerFuncs(relativePath string, handlersFn ...interface{}) context.Handlers { func (api *APIContainer) convertHandlerFuncs(relativePath string, handlersFn ...interface{}) context.Handlers {
fullpath := api.Self.GetRelPath() + relativePath fullpath := api.Self.GetRelPath() + relativePath

View File

@ -110,7 +110,12 @@ func (b *binding) Equal(other *binding) bool {
return true return true
} }
func matchDependency(dep *Dependency, in reflect.Type) bool { // DependencyMatcher type alias describes a dependency match function.
type DependencyMatcher = func(*Dependency, reflect.Type) bool
// DefaultDependencyMatcher is the default dependency match function for all DI containers.
// It is used to collect dependencies from struct's fields and function's parameters.
var DefaultDependencyMatcher = func(dep *Dependency, in reflect.Type) bool {
if dep.Explicit { if dep.Explicit {
return dep.DestType == in return dep.DestType == in
} }
@ -118,6 +123,14 @@ func matchDependency(dep *Dependency, in reflect.Type) bool {
return dep.DestType == nil || equalTypes(dep.DestType, in) return dep.DestType == nil || equalTypes(dep.DestType, in)
} }
// ToDependencyMatchFunc converts a DependencyMatcher (generic for all dependencies)
// to a dependency-specific input matcher.
func ToDependencyMatchFunc(d *Dependency, match DependencyMatcher) DependencyMatchFunc {
return func(in reflect.Type) bool {
return match(d, in)
}
}
func getBindingsFor(inputs []reflect.Type, deps []*Dependency, disablePayloadAutoBinding bool, paramsCount int) (bindings []*binding) { func getBindingsFor(inputs []reflect.Type, deps []*Dependency, disablePayloadAutoBinding bool, paramsCount int) (bindings []*binding) {
// Path parameter start index is the result of [total path parameters] - [total func path parameters inputs], // Path parameter start index is the result of [total path parameters] - [total func path parameters inputs],
// moving from last to first path parameters and first to last (probably) available input args. // moving from last to first path parameters and first to last (probably) available input args.
@ -174,7 +187,7 @@ func getBindingsFor(inputs []reflect.Type, deps []*Dependency, disablePayloadAut
continue continue
} }
match := matchDependency(d, in) match := d.Match(in)
if !match { if !match {
continue continue
} }
@ -278,7 +291,7 @@ func getBindingsForFunc(fn reflect.Value, dependencies []*Dependency, disablePay
return bindings return bindings
} }
func getBindingsForStruct(v reflect.Value, dependencies []*Dependency, markExportedFieldsAsRequired bool, disablePayloadAutoBinding bool, paramsCount int, sorter Sorter) (bindings []*binding) { func getBindingsForStruct(v reflect.Value, dependencies []*Dependency, markExportedFieldsAsRequired bool, disablePayloadAutoBinding bool, matchDependency DependencyMatcher, paramsCount int, sorter Sorter) (bindings []*binding) {
typ := indirectType(v.Type()) typ := indirectType(v.Type())
if typ.Kind() != reflect.Struct { if typ.Kind() != reflect.Struct {
panic("bindings: unresolved: no struct type") panic("bindings: unresolved: no struct type")
@ -290,7 +303,7 @@ func getBindingsForStruct(v reflect.Value, dependencies []*Dependency, markExpor
for _, f := range nonZero { for _, f := range nonZero {
// fmt.Printf("Controller [%s] | NonZero | Field Index: %v | Field Type: %s\n", typ, f.Index, f.Type) // fmt.Printf("Controller [%s] | NonZero | Field Index: %v | Field Type: %s\n", typ, f.Index, f.Type)
bindings = append(bindings, &binding{ bindings = append(bindings, &binding{
Dependency: newDependency(elem.FieldByIndex(f.Index).Interface(), disablePayloadAutoBinding), Dependency: newDependency(elem.FieldByIndex(f.Index).Interface(), disablePayloadAutoBinding, nil),
Input: newStructFieldInput(f), Input: newStructFieldInput(f),
}) })
} }

View File

@ -524,7 +524,7 @@ func TestBindingsForStruct(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
bindings := getBindingsForStruct(reflect.ValueOf(tt.Value), tt.Registered, false, false, 0, nil) bindings := getBindingsForStruct(reflect.ValueOf(tt.Value), tt.Registered, false, false, DefaultDependencyMatcher, 0, nil)
if expected, got := len(tt.Expected), len(bindings); expected != got { if expected, got := len(tt.Expected), len(bindings); expected != got {
t.Logf("[%d] expected bindings length to be: %d but got: %d:\n", i, expected, got) t.Logf("[%d] expected bindings length to be: %d but got: %d:\n", i, expected, got)
@ -565,5 +565,5 @@ func TestBindingsForStructMarkExportedFieldsAsRequred(t *testing.T) {
} }
// should panic if fail. // should panic if fail.
_ = getBindingsForStruct(reflect.ValueOf(new(controller)), dependencies, true, true, 0, nil) _ = getBindingsForStruct(reflect.ValueOf(new(controller)), dependencies, true, true, DefaultDependencyMatcher, 0, nil)
} }

View File

@ -53,6 +53,9 @@ type Container struct {
// set to true to disable that kind of behavior. // set to true to disable that kind of behavior.
DisablePayloadAutoBinding bool DisablePayloadAutoBinding bool
// DependencyMatcher holds the function that compares equality between
// a dependency with an input. Defaults to DefaultMatchDependencyFunc.
DependencyMatcher DependencyMatcher
// GetErrorHandler should return a valid `ErrorHandler` to handle bindings AND handler dispatch errors. // GetErrorHandler should return a valid `ErrorHandler` to handle bindings AND handler dispatch errors.
// Defaults to a functon which returns the `DefaultErrorHandler`. // Defaults to a functon which returns the `DefaultErrorHandler`.
GetErrorHandler func(*context.Context) ErrorHandler // cannot be nil. GetErrorHandler func(*context.Context) ErrorHandler // cannot be nil.
@ -139,11 +142,11 @@ func (c *Container) fillReport(fullName string, bindings []*binding) {
// Contains the iris context, standard context, iris sessions and time dependencies. // Contains the iris context, standard context, iris sessions and time dependencies.
var BuiltinDependencies = []*Dependency{ var BuiltinDependencies = []*Dependency{
// iris context dependency. // iris context dependency.
newDependency(func(ctx *context.Context) *context.Context { return ctx }, true).Explicitly(), newDependency(func(ctx *context.Context) *context.Context { return ctx }, true, nil).Explicitly(),
// standard context dependency. // standard context dependency.
newDependency(func(ctx *context.Context) stdContext.Context { newDependency(func(ctx *context.Context) stdContext.Context {
return ctx.Request().Context() return ctx.Request().Context()
}, true).Explicitly(), }, true, nil).Explicitly(),
// iris session dependency. // iris session dependency.
newDependency(func(ctx *context.Context) *sessions.Session { newDependency(func(ctx *context.Context) *sessions.Session {
session := sessions.Get(ctx) session := sessions.Get(ctx)
@ -156,35 +159,35 @@ var BuiltinDependencies = []*Dependency{
} }
return session return session
}, true).Explicitly(), }, true, nil).Explicitly(),
// application's logger. // application's logger.
newDependency(func(ctx *context.Context) *golog.Logger { newDependency(func(ctx *context.Context) *golog.Logger {
return ctx.Application().Logger() return ctx.Application().Logger()
}, true).Explicitly(), }, true, nil).Explicitly(),
// time.Time to time.Now dependency. // time.Time to time.Now dependency.
newDependency(func(ctx *context.Context) time.Time { newDependency(func(ctx *context.Context) time.Time {
return time.Now() return time.Now()
}, true).Explicitly(), }, true, nil).Explicitly(),
// standard http Request dependency. // standard http Request dependency.
newDependency(func(ctx *context.Context) *http.Request { newDependency(func(ctx *context.Context) *http.Request {
return ctx.Request() return ctx.Request()
}, true).Explicitly(), }, true, nil).Explicitly(),
// standard http ResponseWriter dependency. // standard http ResponseWriter dependency.
newDependency(func(ctx *context.Context) http.ResponseWriter { newDependency(func(ctx *context.Context) http.ResponseWriter {
return ctx.ResponseWriter() return ctx.ResponseWriter()
}, true).Explicitly(), }, true, nil).Explicitly(),
// http headers dependency. // http headers dependency.
newDependency(func(ctx *context.Context) http.Header { newDependency(func(ctx *context.Context) http.Header {
return ctx.Request().Header return ctx.Request().Header
}, true).Explicitly(), }, true, nil).Explicitly(),
// Client IP. // Client IP.
newDependency(func(ctx *context.Context) net.IP { newDependency(func(ctx *context.Context) net.IP {
return net.ParseIP(ctx.RemoteAddr()) return net.ParseIP(ctx.RemoteAddr())
}, true).Explicitly(), }, true, nil).Explicitly(),
// Status Code (special type for MVC HTTP Error handler to not conflict with path parameters) // Status Code (special type for MVC HTTP Error handler to not conflict with path parameters)
newDependency(func(ctx *context.Context) Code { newDependency(func(ctx *context.Context) Code {
return Code(ctx.GetStatusCode()) return Code(ctx.GetStatusCode())
}, true).Explicitly(), }, true, nil).Explicitly(),
// Context Error. May be nil // Context Error. May be nil
newDependency(func(ctx *context.Context) Err { newDependency(func(ctx *context.Context) Err {
err := ctx.GetErr() err := ctx.GetErr()
@ -192,7 +195,7 @@ var BuiltinDependencies = []*Dependency{
return nil return nil
} }
return err return err
}, true).Explicitly(), }, true, nil).Explicitly(),
// Context User, e.g. from basic authentication. // Context User, e.g. from basic authentication.
newDependency(func(ctx *context.Context) context.User { newDependency(func(ctx *context.Context) context.User {
u := ctx.User() u := ctx.User()
@ -201,7 +204,7 @@ var BuiltinDependencies = []*Dependency{
} }
return u return u
}, true), }, true, nil),
// payload and param bindings are dynamically allocated and declared at the end of the `binding` source file. // payload and param bindings are dynamically allocated and declared at the end of the `binding` source file.
} }
@ -218,6 +221,7 @@ func New(dependencies ...interface{}) *Container {
GetErrorHandler: func(*context.Context) ErrorHandler { GetErrorHandler: func(*context.Context) ErrorHandler {
return DefaultErrorHandler return DefaultErrorHandler
}, },
DependencyMatcher: DefaultDependencyMatcher,
} }
for _, dependency := range dependencies { for _, dependency := range dependencies {
@ -240,6 +244,7 @@ func (c *Container) Clone() *Container {
cloned.Logger = c.Logger cloned.Logger = c.Logger
cloned.GetErrorHandler = c.GetErrorHandler cloned.GetErrorHandler = c.GetErrorHandler
cloned.Sorter = c.Sorter cloned.Sorter = c.Sorter
cloned.DependencyMatcher = c.DependencyMatcher
clonedDeps := make([]*Dependency, len(c.Dependencies)) clonedDeps := make([]*Dependency, len(c.Dependencies))
copy(clonedDeps, c.Dependencies) copy(clonedDeps, c.Dependencies)
cloned.Dependencies = clonedDeps cloned.Dependencies = clonedDeps
@ -281,7 +286,7 @@ func Register(dependency interface{}) *Dependency {
// - Register(func(ctx iris.Context) User {...}) // - Register(func(ctx iris.Context) User {...})
// - Register(func(User) OtherResponse {...}) // - Register(func(User) OtherResponse {...})
func (c *Container) Register(dependency interface{}) *Dependency { func (c *Container) Register(dependency interface{}) *Dependency {
d := newDependency(dependency, c.DisablePayloadAutoBinding, c.Dependencies...) d := newDependency(dependency, c.DisablePayloadAutoBinding, c.DependencyMatcher, c.Dependencies...)
if d.DestType == nil { if d.DestType == nil {
// prepend the dynamic dependency so it will be tried at the end // prepend the dynamic dependency so it will be tried at the end
// (we don't care about performance here, design-time) // (we don't care about performance here, design-time)
@ -376,7 +381,7 @@ func (c *Container) Inject(toPtr interface{}) error {
typ := val.Type() typ := val.Type()
for _, d := range c.Dependencies { for _, d := range c.Dependencies {
if d.Static && matchDependency(d, typ) { if d.Static && c.DependencyMatcher(d, typ) {
v, err := d.Handle(nil, &Input{Type: typ}) v, err := d.Handle(nil, &Input{Type: typ})
if err != nil { if err != nil {
if err == ErrSeeOther { if err == ErrSeeOther {

View File

@ -10,7 +10,15 @@ import (
type ( type (
// DependencyHandler is the native function declaration which implementors should return a value match to an input. // DependencyHandler is the native function declaration which implementors should return a value match to an input.
DependencyHandler func(ctx *context.Context, input *Input) (reflect.Value, error) DependencyHandler = func(ctx *context.Context, input *Input) (reflect.Value, error)
// DependencyMatchFunc type alias describes dependency
// match function with an input (field or parameter).
//
// See "DependencyMatcher" too, which can be used on a Container to
// change the way dependencies are matched to inputs for all dependencies.
DependencyMatchFunc = func(in reflect.Type) bool
// Dependency describes the design-time dependency to be injected at serve time. // Dependency describes the design-time dependency to be injected at serve time.
// Contains its source location, the dependency handler (provider) itself and information // Contains its source location, the dependency handler (provider) itself and information
// such as static for static struct values or explicit to bind a value to its exact DestType and not if just assignable to it (interfaces). // such as static for static struct values or explicit to bind a value to its exact DestType and not if just assignable to it (interfaces).
@ -26,6 +34,9 @@ type (
// Example of use case: depenendency like time.Time that we want to be bindable // Example of use case: depenendency like time.Time that we want to be bindable
// only to time.Time inputs and not to a service with a `String() string` method that time.Time struct implements too. // only to time.Time inputs and not to a service with a `String() string` method that time.Time struct implements too.
Explicit bool Explicit bool
// Match holds the matcher. Defaults to the Container's one.
Match DependencyMatchFunc
} }
) )
@ -50,11 +61,11 @@ func (d *Dependency) String() string {
// NewDependency converts a function or a function which accepts other dependencies or static struct value to a *Dependency. // NewDependency converts a function or a function which accepts other dependencies or static struct value to a *Dependency.
// //
// See `Container.Handler` for more. // See `Container.Handler` for more.
func NewDependency(dependency interface{}, funcDependencies ...*Dependency) *Dependency { func NewDependency(dependency interface{}, funcDependencies ...*Dependency) *Dependency { // used only on tests.
return newDependency(dependency, false, funcDependencies...) return newDependency(dependency, false, nil, funcDependencies...)
} }
func newDependency(dependency interface{}, disablePayloadAutoBinding bool, funcDependencies ...*Dependency) *Dependency { func newDependency(dependency interface{}, disablePayloadAutoBinding bool, matchDependency DependencyMatcher, funcDependencies ...*Dependency) *Dependency {
if dependency == nil { if dependency == nil {
panic(fmt.Sprintf("bad value: nil: %T", dependency)) panic(fmt.Sprintf("bad value: nil: %T", dependency))
} }
@ -69,10 +80,15 @@ func newDependency(dependency interface{}, disablePayloadAutoBinding bool, funcD
panic(fmt.Sprintf("bad value: %#+v", dependency)) panic(fmt.Sprintf("bad value: %#+v", dependency))
} }
if matchDependency == nil {
matchDependency = DefaultDependencyMatcher
}
dest := &Dependency{ dest := &Dependency{
Source: newSource(v), Source: newSource(v),
OriginalValue: dependency, OriginalValue: dependency,
} }
dest.Match = ToDependencyMatchFunc(dest, matchDependency)
if !resolveDependency(v, disablePayloadAutoBinding, dest, funcDependencies...) { if !resolveDependency(v, disablePayloadAutoBinding, dest, funcDependencies...) {
panic(fmt.Sprintf("bad value: could not resolve a dependency from: %#+v", dependency)) panic(fmt.Sprintf("bad value: could not resolve a dependency from: %#+v", dependency))
@ -223,7 +239,7 @@ func fromDependentFunc(v reflect.Value, disablePayloadAutoBinding bool, dest *De
if numIn == len(bindings) { if numIn == len(bindings) {
static := true static := true
for _, b := range bindings { for _, b := range bindings {
if !b.Dependency.Static && matchDependency(b.Dependency, typ.In(b.Input.Index)) { if !b.Dependency.Static && b.Dependency.Match(typ.In(b.Input.Index)) {
static = false static = false
break break
} }

View File

@ -51,7 +51,7 @@ func makeStruct(structPtr interface{}, c *Container, partyParamsCount int) *Stru
} }
// get struct's fields bindings. // get struct's fields bindings.
bindings := getBindingsForStruct(v, c.Dependencies, c.MarkExportedFieldsAsRequired, c.DisablePayloadAutoBinding, partyParamsCount, c.Sorter) bindings := getBindingsForStruct(v, c.Dependencies, c.MarkExportedFieldsAsRequired, c.DisablePayloadAutoBinding, c.DependencyMatcher, partyParamsCount, c.Sorter)
// length bindings of 0, means that it has no fields or all mapped deps are static. // length bindings of 0, means that it has no fields or all mapped deps are static.
// If static then Struct.Acquire will return the same "value" instance, otherwise it will create a new one. // If static then Struct.Acquire will return the same "value" instance, otherwise it will create a new one.

View File

@ -85,7 +85,7 @@ func (r *GoRedisDriver) mergeClusterOptions(c Config) *ClusterOptions {
} }
if opts.Password == "" { if opts.Password == "" {
opts.Username = c.Password opts.Password = c.Password
} }
if opts.ReadTimeout == 0 { if opts.ReadTimeout == 0 {