add TLSConfig on redis configuration #1515

Former-commit-id: 3ce4a43185c7b6b5250f49483d7d229ea9dd1670
This commit is contained in:
Gerasimos (Makis) Maropoulos 2020-05-17 03:25:32 +03:00
parent 571322f595
commit 21a013569f
6 changed files with 78 additions and 62 deletions

View File

@ -974,12 +974,14 @@ func (api *APIBuilder) Favicon(favPath string, requestPath ...string) *Route {
// OnErrorCode registers a handlers chain for this `Party` for a specific HTTP status code. // OnErrorCode registers a handlers chain for this `Party` for a specific HTTP status code.
// Read more at: http://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml // Read more at: http://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml
// Look `OnAnyErrorCode` too. // Look `OnAnyErrorCode` too.
func (api *APIBuilder) OnErrorCode(statusCode int, handlers ...context.Handler) { func (api *APIBuilder) OnErrorCode(statusCode int, handlers ...context.Handler) (routes []*Route) {
api.handle(statusCode, "", "/", handlers...) routes = append(routes, api.handle(statusCode, "", "/", handlers...))
if api.relativePath != "/" { if api.relativePath != "/" {
api.handle(statusCode, "", "/{tail:path}", handlers...) routes = append(routes, api.handle(statusCode, "", "/{tail:path}", handlers...))
} }
return
} }
// ClientErrorCodes holds the 4xx Client errors. // ClientErrorCodes holds the 4xx Client errors.
@ -1083,10 +1085,12 @@ func StatusText(code int) string {
// OnAnyErrorCode registers a handlers chain for all error codes // OnAnyErrorCode registers a handlers chain for all error codes
// (4xxx and 5xxx, change the `ClientErrorCodes` and `ServerErrorCodes` variables to modify those) // (4xxx and 5xxx, change the `ClientErrorCodes` and `ServerErrorCodes` variables to modify those)
// Look `OnErrorCode` too. // Look `OnErrorCode` too.
func (api *APIBuilder) OnAnyErrorCode(handlers ...context.Handler) { func (api *APIBuilder) OnAnyErrorCode(handlers ...context.Handler) (routes []*Route) {
for _, statusCode := range append(ClientErrorCodes, ServerErrorCodes...) { for _, statusCode := range append(ClientErrorCodes, ServerErrorCodes...) {
api.OnErrorCode(statusCode, handlers...) routes = append(routes, api.OnErrorCode(statusCode, handlers...)...)
} }
return
} }
// Layout overrides the parent template layout with a more specific layout for this Party. // Layout overrides the parent template layout with a more specific layout for this Party.

View File

@ -97,6 +97,14 @@ func (api *APIContainer) convertHandlerFuncs(relativePath string, handlersFn ...
return handlers return handlers
} }
func fixRouteInfo(route *Route, handlersFn []interface{}) {
// Fix main handler name and source modified by execution rules wrapper.
route.MainHandlerName, route.MainHandlerIndex = context.MainHandlerName(handlersFn...)
if len(handlersFn) > route.MainHandlerIndex {
route.SourceFileName, route.SourceLineNumber = context.HandlerFileLineRel(handlersFn[route.MainHandlerIndex])
}
}
// Handler receives a function which can receive dependencies and output result // Handler receives a function which can receive dependencies and output result
// and returns a common Iris Handler, useful for Versioning API integration otherwise // and returns a common Iris Handler, useful for Versioning API integration otherwise
// the `Handle/Get/Post...` methods are preferable. // the `Handle/Get/Post...` methods are preferable.
@ -134,13 +142,7 @@ func (api *APIContainer) Done(handlersFn ...interface{}) {
func (api *APIContainer) Handle(method, relativePath string, handlersFn ...interface{}) *Route { func (api *APIContainer) Handle(method, relativePath string, handlersFn ...interface{}) *Route {
handlers := api.convertHandlerFuncs(relativePath, handlersFn...) handlers := api.convertHandlerFuncs(relativePath, handlersFn...)
route := api.Self.Handle(method, relativePath, handlers...) route := api.Self.Handle(method, relativePath, handlers...)
fixRouteInfo(route, handlersFn)
// Fix main handler name and source modified by execution rules wrapper.
route.MainHandlerName, route.MainHandlerIndex = context.MainHandlerName(handlersFn...)
if len(handlersFn) > route.MainHandlerIndex {
route.SourceFileName, route.SourceLineNumber = context.HandlerFileLineRel(handlersFn[route.MainHandlerIndex])
}
return route return route
} }

View File

@ -36,11 +36,11 @@ type Party interface {
// OnErrorCode registers a handlers chain for this `Party` for a specific HTTP status code. // OnErrorCode registers a handlers chain for this `Party` for a specific HTTP status code.
// Read more at: http://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml // Read more at: http://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml
// Look `OnAnyErrorCode` too. // Look `OnAnyErrorCode` too.
OnErrorCode(statusCode int, handlers ...context.Handler) OnErrorCode(statusCode int, handlers ...context.Handler) []*Route
// OnAnyErrorCode registers a handlers chain for all error codes // OnAnyErrorCode registers a handlers chain for all error codes
// (4xxx and 5xxx, change the `ClientErrorCodes` and `ServerErrorCodes` variables to modify those) // (4xxx and 5xxx, change the `ClientErrorCodes` and `ServerErrorCodes` variables to modify those)
// Look `OnErrorCode` too. // Look `OnErrorCode` too.
OnAnyErrorCode(handlers ...context.Handler) OnAnyErrorCode(handlers ...context.Handler) []*Route
// Party groups routes which may have the same prefix and share same handlers, // Party groups routes which may have the same prefix and share same handlers,
// returns that new rich subrouter. // returns that new rich subrouter.

View File

@ -1,6 +1,7 @@
package redis package redis
import ( import (
"crypto/tls"
"errors" "errors"
"time" "time"
@ -45,6 +46,11 @@ type Config struct {
// Delim the delimeter for the keys on the sessiondb. Defaults to "-". // Delim the delimeter for the keys on the sessiondb. Defaults to "-".
Delim string Delim string
// TLSConfig will cause Dial to perform a TLS handshake using the provided
// config. If is nil then no TLS is used.
// See https://golang.org/pkg/crypto/tls/#Config
TLSConfig *tls.Config
// Driver supports `Redigo()` or `Radix()` go clients for redis. // Driver supports `Redigo()` or `Radix()` go clients for redis.
// Configure each driver by the return value of their constructors. // Configure each driver by the return value of their constructors.
// //
@ -63,6 +69,7 @@ func DefaultConfig() Config {
Timeout: DefaultRedisTimeout, Timeout: DefaultRedisTimeout,
Prefix: "", Prefix: "",
Delim: DefaultDelim, Delim: DefaultDelim,
TLSConfig: nil,
Driver: Redigo(), Driver: Redigo(),
} }
} }

View File

@ -45,6 +45,10 @@ func (r *RadixDriver) Connect(c Config) error {
var options []radix.DialOpt var options []radix.DialOpt
if c.TLSConfig != nil {
options = append(options, radix.DialUseTLS(c.TLSConfig))
}
if c.Password != "" { if c.Password != "" {
options = append(options, radix.DialAuthPass(c.Password)) options = append(options, radix.DialAuthPass(c.Password))
} }

View File

@ -271,62 +271,61 @@ func (r *RedigoDriver) Delete(key string) error {
return err return err
} }
func dial(network string, addr string, pass string, timeout time.Duration) (redis.Conn, error) { // Connect connects to the redis, called only once.
if network == "" {
network = DefaultRedisNetwork
}
if addr == "" {
addr = DefaultRedisAddr
}
var options []redis.DialOption
if timeout > 0 {
options = append(options,
redis.DialConnectTimeout(timeout),
redis.DialReadTimeout(timeout),
redis.DialWriteTimeout(timeout))
}
c, err := redis.Dial(network, addr, options...)
if err != nil {
return nil, err
}
if pass != "" {
if _, err = c.Do("AUTH", pass); err != nil {
c.Close()
return nil, err
}
}
return c, err
}
// Connect connects to the redis, called only once
func (r *RedigoDriver) Connect(c Config) error { func (r *RedigoDriver) Connect(c Config) error {
if c.Network == "" {
c.Network = DefaultRedisNetwork
}
if c.Addr == "" {
c.Addr = DefaultRedisAddr
}
pool := &redis.Pool{IdleTimeout: r.IdleTimeout, MaxIdle: r.MaxIdle, Wait: r.Wait, MaxActive: c.MaxActive} pool := &redis.Pool{IdleTimeout: r.IdleTimeout, MaxIdle: r.MaxIdle, Wait: r.Wait, MaxActive: c.MaxActive}
pool.TestOnBorrow = func(c redis.Conn, t time.Time) error { pool.TestOnBorrow = func(c redis.Conn, t time.Time) error {
_, err := c.Do("PING") _, err := c.Do("PING")
return err return err
} }
if c.Database != "" { var options []redis.DialOption
pool.Dial = func() (redis.Conn, error) {
red, err := dial(c.Network, c.Addr, c.Password, c.Timeout) if c.Timeout > 0 {
if err != nil { options = append(options,
return nil, err redis.DialConnectTimeout(c.Timeout),
} redis.DialReadTimeout(c.Timeout),
if _, err = red.Do("SELECT", c.Database); err != nil { redis.DialWriteTimeout(c.Timeout))
red.Close()
return nil, err
}
return red, err
}
} else {
pool.Dial = func() (redis.Conn, error) {
return dial(c.Network, c.Addr, c.Password, c.Timeout)
}
} }
if c.TLSConfig != nil {
options = append(options,
redis.DialTLSConfig(c.TLSConfig),
redis.DialUseTLS(true),
)
}
pool.Dial = func() (redis.Conn, error) {
conn, err := redis.Dial(c.Network, c.Addr, options...)
if err != nil {
return nil, err
}
if c.Password != "" {
if _, err = conn.Do("AUTH", c.Password); err != nil {
conn.Close()
return nil, err
}
}
if c.Database != "" {
if _, err = conn.Do("SELECT", c.Database); err != nil {
conn.Close()
return nil, err
}
}
return conn, err
}
r.Connected = true r.Connected = true
r.pool = pool r.pool = pool
r.Config = c r.Config = c