From 057fa462f496303e3ad05355d6e9bacbc27c5f9e Mon Sep 17 00:00:00 2001 From: "Gerasimos (Makis) Maropoulos" Date: Mon, 15 Aug 2022 00:58:49 +0300 Subject: [PATCH] allow setting a custom go-redis client at Iris go redis sessions driver as request at https://chat.iris-go.com --- HISTORY.md | 1 + _examples/sessions/database/redis/main.go | 7 +++- sessions/sessiondb/redis/database.go | 1 + sessions/sessiondb/redis/driver_goredis.go | 44 ++++++++++++++-------- 4 files changed, 37 insertions(+), 16 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 3c88a5b6..05cee57e 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -28,6 +28,7 @@ The codebase for Dependency Injection, Internationalization and localization and ## Fixes and Improvements +- Enable setting a custom "go-redis" client through `SetClient` go redis driver method or `Client` struct field on sessions/database/redis driver as requested at [chat](https://chat.iris-go.com). - Ignore `"csrf.token"` form data key when missing on `ctx.ReadForm` by default as requested at [#1941](https://github.com/kataras/iris/issues/1941). - Fix [CVE-2020-5398](https://github.com/advisories/GHSA-8wx2-9q48-vm9r). diff --git a/_examples/sessions/database/redis/main.go b/_examples/sessions/database/redis/main.go index 81a15c12..e7f7cea5 100644 --- a/_examples/sessions/database/redis/main.go +++ b/_examples/sessions/database/redis/main.go @@ -30,11 +30,16 @@ func main() { Password: "", Database: "", Prefix: "myapp-", - Driver: redis.GoRedis(), // defaults. + Driver: redis.GoRedis(), // defaults to this driver. + // To set a custom, existing go-redis client, use the "SetClient" method: + // Driver: redis.GoRedis().SetClient(customGoRedisClient) }) // Optionally configure the underline driver: // driver := redis.GoRedis() + // // To set a custom client: + // driver.SetClient(customGoRedisClient) + // OR: // driver.ClientOptions = redis.Options{...} // driver.ClusterOptions = redis.ClusterOptions{...} // redis.New(redis.Config{Driver: driver, ...}) diff --git a/sessions/sessiondb/redis/database.go b/sessions/sessiondb/redis/database.go index ed40b257..3307b1af 100644 --- a/sessions/sessiondb/redis/database.go +++ b/sessions/sessiondb/redis/database.go @@ -30,6 +30,7 @@ type Config struct { Addr string // Clusters a list of network addresses for clusters. // If not empty "Addr" is ignored and Redis clusters feature is used instead. + // Note that this field is ignored when setgging a custom `GoRedisClient`. Clusters []string // Use the specified Username to authenticate the current connection // with one of the connections defined in the ACL list when connecting diff --git a/sessions/sessiondb/redis/driver_goredis.go b/sessions/sessiondb/redis/driver_goredis.go index 6477cd75..50e994d0 100644 --- a/sessions/sessiondb/redis/driver_goredis.go +++ b/sessions/sessiondb/redis/driver_goredis.go @@ -27,7 +27,9 @@ type GoRedisClient interface { // for the go-redis redis driver. See driver.go file. type GoRedisDriver struct { // Both Client and ClusterClient implements this interface. - client GoRedisClient + // Custom one can be directly passed but if so, the + // Connect method does nothing (so all connection and client settings are ignored). + Client GoRedisClient // Customize any go-redis fields manually // before Connect. ClientOptions Options @@ -111,12 +113,24 @@ func (r *GoRedisDriver) mergeClusterOptions(c Config) *ClusterOptions { return &opts } +// SetClient sets an existing go redis client to the sessions redis driver. +// +// Returns itself. +func (r *GoRedisDriver) SetClient(goRedisClient GoRedisClient) *GoRedisDriver { + r.Client = goRedisClient + return r +} + // Connect initializes the redis client. func (r *GoRedisDriver) Connect(c Config) error { + if r.Client != nil { // if a custom one was given through SetClient. + return nil + } + if len(c.Clusters) > 0 { - r.client = redis.NewClusterClient(r.mergeClusterOptions(c)) + r.Client = redis.NewClusterClient(r.mergeClusterOptions(c)) } else { - r.client = redis.NewClient(r.mergeClientOptions(c)) + r.Client = redis.NewClient(r.mergeClientOptions(c)) } return nil @@ -125,29 +139,29 @@ func (r *GoRedisDriver) Connect(c Config) error { // PingPong sends a ping message and reports whether // the PONG message received successfully. func (r *GoRedisDriver) PingPong() (bool, error) { - pong, err := r.client.Ping(defaultContext).Result() + pong, err := r.Client.Ping(defaultContext).Result() return pong == "PONG", err } // CloseConnection terminates the underline redis connection. func (r *GoRedisDriver) CloseConnection() error { - return r.client.Close() + return r.Client.Close() } // Set stores a "value" based on the session's "key". // The value should be type of []byte, so unmarshal can happen. func (r *GoRedisDriver) Set(sid, key string, value interface{}) error { - return r.client.HSet(defaultContext, sid, key, value).Err() + return r.Client.HSet(defaultContext, sid, key, value).Err() } // Get returns the associated value of the session's given "key". func (r *GoRedisDriver) Get(sid, key string) (interface{}, error) { - return r.client.HGet(defaultContext, sid, key).Bytes() + return r.Client.HGet(defaultContext, sid, key).Bytes() } // Exists reports whether a session exists or not. func (r *GoRedisDriver) Exists(sid string) bool { - n, err := r.client.Exists(defaultContext, sid).Result() + n, err := r.Client.Exists(defaultContext, sid).Result() if err != nil { return false } @@ -157,7 +171,7 @@ func (r *GoRedisDriver) Exists(sid string) bool { // TTL returns any TTL value of the session. func (r *GoRedisDriver) TTL(sid string) time.Duration { - dur, err := r.client.TTL(defaultContext, sid).Result() + dur, err := r.Client.TTL(defaultContext, sid).Result() if err != nil { return 0 } @@ -167,29 +181,29 @@ func (r *GoRedisDriver) TTL(sid string) time.Duration { // UpdateTTL sets expiration duration of the session. func (r *GoRedisDriver) UpdateTTL(sid string, newLifetime time.Duration) error { - _, err := r.client.Expire(defaultContext, sid, newLifetime).Result() + _, err := r.Client.Expire(defaultContext, sid, newLifetime).Result() return err } // GetAll returns all the key values under the session. func (r *GoRedisDriver) GetAll(sid string) (map[string]string, error) { - return r.client.HGetAll(defaultContext, sid).Result() + return r.Client.HGetAll(defaultContext, sid).Result() } // GetKeys returns all keys under the session. func (r *GoRedisDriver) GetKeys(sid string) ([]string, error) { - return r.client.HKeys(defaultContext, sid).Result() + return r.Client.HKeys(defaultContext, sid).Result() } // Len returns the total length of key-values of the session. func (r *GoRedisDriver) Len(sid string) int { - return int(r.client.HLen(defaultContext, sid).Val()) + return int(r.Client.HLen(defaultContext, sid).Val()) } // Delete removes a value from the redis store. func (r *GoRedisDriver) Delete(sid, key string) error { if key == "" { - return r.client.Del(defaultContext, sid).Err() + return r.Client.Del(defaultContext, sid).Err() } - return r.client.HDel(defaultContext, sid, key).Err() + return r.Client.HDel(defaultContext, sid, key).Err() }