iris/sessions/sessiondb/redis/driver_goredis.go
Phước Trung 65bb98f055
Upgrade Redis client to v9 (#2102)
Signed-off-by: Gerasimos (Makis) Maropoulos <kataras2006@hotmail.com>
Co-authored-by: Gerasimos (Makis) Maropoulos <kataras2006@hotmail.com>
2023-03-17 11:03:18 +02:00

210 lines
4.9 KiB
Go

package redis
import (
stdContext "context"
"io"
"strconv"
"time"
"github.com/redis/go-redis/v9"
)
type (
// Options is just a type alias for the go-redis Client Options.
Options = redis.Options
// ClusterOptions is just a type alias for the go-redis Cluster Client Options.
ClusterOptions = redis.ClusterOptions
)
// GoRedisClient is the interface which both
// go-redis's Client and Cluster Client implements.
type GoRedisClient interface {
redis.Cmdable // Commands.
io.Closer // CloseConnection.
}
// GoRedisDriver implements the Sessions Database Driver
// for the go-redis redis driver. See driver.go file.
type GoRedisDriver struct {
// Both Client and ClusterClient implements this interface.
// Custom one can be directly passed but if so, the
// Connect method does nothing (so all connection and client settings are ignored).
Client GoRedisClient
// Customize any go-redis fields manually
// before Connect.
ClientOptions Options
ClusterOptions ClusterOptions
}
var defaultContext = stdContext.Background()
func (r *GoRedisDriver) mergeClientOptions(c Config) *Options {
opts := r.ClientOptions
if opts.Addr == "" {
opts.Addr = c.Addr
}
if opts.Username == "" {
opts.Username = c.Username
}
if opts.Password == "" {
opts.Password = c.Password
}
if opts.DB == 0 {
opts.DB, _ = strconv.Atoi(c.Database)
}
if opts.ReadTimeout == 0 {
opts.ReadTimeout = c.Timeout
}
if opts.WriteTimeout == 0 {
opts.WriteTimeout = c.Timeout
}
if opts.Network == "" {
opts.Network = c.Network
}
if opts.TLSConfig == nil {
opts.TLSConfig = c.TLSConfig
}
if opts.PoolSize == 0 {
opts.PoolSize = c.MaxActive
}
return &opts
}
func (r *GoRedisDriver) mergeClusterOptions(c Config) *ClusterOptions {
opts := r.ClusterOptions
if opts.Username == "" {
opts.Username = c.Username
}
if opts.Password == "" {
opts.Password = c.Password
}
if opts.ReadTimeout == 0 {
opts.ReadTimeout = c.Timeout
}
if opts.WriteTimeout == 0 {
opts.WriteTimeout = c.Timeout
}
if opts.TLSConfig == nil {
opts.TLSConfig = c.TLSConfig
}
if opts.PoolSize == 0 {
opts.PoolSize = c.MaxActive
}
if len(opts.Addrs) == 0 {
opts.Addrs = c.Clusters
}
return &opts
}
// SetClient sets an existing go redis client to the sessions redis driver.
//
// Returns itself.
func (r *GoRedisDriver) SetClient(goRedisClient GoRedisClient) *GoRedisDriver {
r.Client = goRedisClient
return r
}
// Connect initializes the redis client.
func (r *GoRedisDriver) Connect(c Config) error {
if r.Client != nil { // if a custom one was given through SetClient.
return nil
}
if len(c.Clusters) > 0 {
r.Client = redis.NewClusterClient(r.mergeClusterOptions(c))
} else {
r.Client = redis.NewClient(r.mergeClientOptions(c))
}
return nil
}
// PingPong sends a ping message and reports whether
// the PONG message received successfully.
func (r *GoRedisDriver) PingPong() (bool, error) {
pong, err := r.Client.Ping(defaultContext).Result()
return pong == "PONG", err
}
// CloseConnection terminates the underline redis connection.
func (r *GoRedisDriver) CloseConnection() error {
return r.Client.Close()
}
// Set stores a "value" based on the session's "key".
// The value should be type of []byte, so unmarshal can happen.
func (r *GoRedisDriver) Set(sid, key string, value interface{}) error {
return r.Client.HSet(defaultContext, sid, key, value).Err()
}
// Get returns the associated value of the session's given "key".
func (r *GoRedisDriver) Get(sid, key string) (interface{}, error) {
return r.Client.HGet(defaultContext, sid, key).Bytes()
}
// Exists reports whether a session exists or not.
func (r *GoRedisDriver) Exists(sid string) bool {
n, err := r.Client.Exists(defaultContext, sid).Result()
if err != nil {
return false
}
return n > 0
}
// TTL returns any TTL value of the session.
func (r *GoRedisDriver) TTL(sid string) time.Duration {
dur, err := r.Client.TTL(defaultContext, sid).Result()
if err != nil {
return 0
}
return dur
}
// UpdateTTL sets expiration duration of the session.
func (r *GoRedisDriver) UpdateTTL(sid string, newLifetime time.Duration) error {
_, err := r.Client.Expire(defaultContext, sid, newLifetime).Result()
return err
}
// GetAll returns all the key values under the session.
func (r *GoRedisDriver) GetAll(sid string) (map[string]string, error) {
return r.Client.HGetAll(defaultContext, sid).Result()
}
// GetKeys returns all keys under the session.
func (r *GoRedisDriver) GetKeys(sid string) ([]string, error) {
return r.Client.HKeys(defaultContext, sid).Result()
}
// Len returns the total length of key-values of the session.
func (r *GoRedisDriver) Len(sid string) int {
return int(r.Client.HLen(defaultContext, sid).Val())
}
// Delete removes a value from the redis store.
func (r *GoRedisDriver) Delete(sid, key string) error {
if key == "" {
return r.Client.Del(defaultContext, sid).Err()
}
return r.Client.HDel(defaultContext, sid, key).Err()
}