mirror of
https://github.com/kataras/iris.git
synced 2025-01-23 18:51:03 +01:00
315 lines
8.5 KiB
Go
315 lines
8.5 KiB
Go
package redis
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/kataras/iris/v12/sessions"
|
|
|
|
"github.com/kataras/golog"
|
|
)
|
|
|
|
// Defaults used by DefaultConfig, and by New for fields a caller left unset.
const (
	// DefaultRedisNetwork is the default redis network option, "tcp".
	DefaultRedisNetwork = "tcp"
	// DefaultRedisAddr is the default redis address option, "127.0.0.1:6379".
	DefaultRedisAddr = "127.0.0.1:6379"
	// DefaultRedisTimeout is the default redis timeout option, 30 seconds.
	DefaultRedisTimeout = 30 * time.Second
)
|
|
|
|
// Config the redis configuration used inside sessions
|
|
type Config struct {
|
|
// Network protocol. Defaults to "tcp".
|
|
Network string
|
|
// Addr of a single redis server instance.
|
|
// See "Clusters" field for clusters support.
|
|
// Defaults to "127.0.0.1:6379".
|
|
Addr string
|
|
// Clusters a list of network addresses for clusters.
|
|
// If not empty "Addr" is ignored and Redis clusters feature is used instead.
|
|
Clusters []string
|
|
// Use the specified Username to authenticate the current connection
|
|
// with one of the connections defined in the ACL list when connecting
|
|
// to a Redis 6.0 instance, or greater, that is using the Redis ACL system.
|
|
Username string
|
|
// Optional password. Must match the password specified in the
|
|
// requirepass server configuration option (if connecting to a Redis 5.0 instance, or lower),
|
|
// or the User Password when connecting to a Redis 6.0 instance, or greater,
|
|
// that is using the Redis ACL system.
|
|
Password string
|
|
// If Database is empty "" then no 'SELECT'. Defaults to "".
|
|
Database string
|
|
// Maximum number of socket connections.
|
|
// Default is 10 connections per every CPU as reported by runtime.NumCPU.
|
|
MaxActive int
|
|
// Timeout for connect, write and read, defaults to 30 seconds, 0 means no timeout.
|
|
Timeout time.Duration
|
|
// Prefix "myprefix-for-this-website". Defaults to "".
|
|
Prefix string
|
|
|
|
// TLSConfig will cause Dial to perform a TLS handshake using the provided
|
|
// config. If is nil then no TLS is used.
|
|
// See https://golang.org/pkg/crypto/tls/#Config
|
|
TLSConfig *tls.Config
|
|
|
|
// A Driver should support be a go client for redis communication.
|
|
// It can be set to a custom one or a mock one (for testing).
|
|
//
|
|
// Defaults to `GoRedis()`.
|
|
Driver Driver
|
|
}
|
|
|
|
// DefaultConfig returns the default configuration for Redis service.
|
|
func DefaultConfig() Config {
|
|
return Config{
|
|
Network: DefaultRedisNetwork,
|
|
Addr: DefaultRedisAddr,
|
|
Username: "",
|
|
Password: "",
|
|
Database: "",
|
|
MaxActive: 10,
|
|
Timeout: DefaultRedisTimeout,
|
|
Prefix: "",
|
|
TLSConfig: nil,
|
|
Driver: GoRedis(),
|
|
}
|
|
}
|
|
|
|
// Database the redis back-end session database for the sessions.
|
|
type Database struct {
|
|
c Config
|
|
logger *golog.Logger
|
|
}
|
|
|
|
var _ sessions.Database = (*Database)(nil)
|
|
|
|
// New returns a new redis sessions database.
|
|
func New(cfg ...Config) *Database {
|
|
c := DefaultConfig()
|
|
if len(cfg) > 0 {
|
|
c = cfg[0]
|
|
|
|
if c.Timeout < 0 {
|
|
c.Timeout = DefaultRedisTimeout
|
|
}
|
|
|
|
if c.Network == "" {
|
|
c.Network = DefaultRedisNetwork
|
|
}
|
|
|
|
if c.Addr == "" {
|
|
c.Addr = DefaultRedisAddr
|
|
}
|
|
|
|
if c.Driver == nil {
|
|
c.Driver = GoRedis()
|
|
}
|
|
}
|
|
|
|
if err := c.Driver.Connect(c); err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
db := &Database{c: c}
|
|
_, err := db.c.Driver.PingPong()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
// runtime.SetFinalizer(db, closeDB)
|
|
return db
|
|
}
|
|
|
|
// SetLogger sets the logger once before server ran.
|
|
// By default the Iris one is injected.
|
|
func (db *Database) SetLogger(logger *golog.Logger) {
|
|
db.logger = logger
|
|
}
|
|
|
|
func (db *Database) makeSID(sid string) string {
|
|
return db.c.Prefix + sid
|
|
}
|
|
|
|
const (
	// SessionIDKey is the key under which the session ID itself
	// is stored inside its redis session entry.
	SessionIDKey = "session_id"
)
|
|
|
|
// Acquire receives a session's lifetime from the database,
|
|
// if the return value is LifeTime{} then the session manager sets the life time based on the expiration duration lives in configuration.
|
|
func (db *Database) Acquire(sid string, expires time.Duration) sessions.LifeTime {
|
|
sidKey := db.makeSID(sid)
|
|
if !db.c.Driver.Exists(sidKey) {
|
|
if err := db.Set(sid, SessionIDKey, sid, 0, false); err != nil {
|
|
db.logger.Debug(err)
|
|
} else if expires > 0 {
|
|
if err := db.c.Driver.UpdateTTL(sidKey, expires); err != nil {
|
|
db.logger.Debug(err)
|
|
}
|
|
}
|
|
|
|
return sessions.LifeTime{} // session manager will handle the rest.
|
|
}
|
|
|
|
untilExpire := db.c.Driver.TTL(sidKey)
|
|
return sessions.LifeTime{Time: time.Now().Add(untilExpire)}
|
|
}
|
|
|
|
// OnUpdateExpiration will re-set the database's session's entry ttl.
|
|
// https://redis.io/commands/expire#refreshing-expires
|
|
func (db *Database) OnUpdateExpiration(sid string, newExpires time.Duration) error {
|
|
return db.c.Driver.UpdateTTL(db.makeSID(sid), newExpires)
|
|
}
|
|
|
|
// Set sets a key value of a specific session.
|
|
// Ignore the "immutable".
|
|
func (db *Database) Set(sid string, key string, value interface{}, _ time.Duration, _ bool) error {
|
|
valueBytes, err := sessions.DefaultTranscoder.Marshal(value)
|
|
if err != nil {
|
|
db.logger.Error(err)
|
|
return err
|
|
}
|
|
|
|
if err = db.c.Driver.Set(db.makeSID(sid), key, valueBytes); err != nil {
|
|
db.logger.Debug(err)
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Get retrieves a session value based on the key.
|
|
func (db *Database) Get(sid string, key string) (value interface{}) {
|
|
if err := db.Decode(sid, key, &value); err == nil {
|
|
return value
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Decode binds the "outPtr" to the value associated to the provided "key".
|
|
func (db *Database) Decode(sid, key string, outPtr interface{}) error {
|
|
sidKey := db.makeSID(sid)
|
|
data, err := db.c.Driver.Get(sidKey, key)
|
|
if err != nil {
|
|
// not found.
|
|
return err
|
|
}
|
|
|
|
if err = db.decodeValue(data, outPtr); err != nil {
|
|
db.logger.Debugf("unable to unmarshal value of key: '%s%s': %v", sid, key, err)
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (db *Database) decodeValue(val interface{}, outPtr interface{}) error {
|
|
if val == nil {
|
|
return nil
|
|
}
|
|
|
|
switch data := val.(type) {
|
|
case []byte:
|
|
// this is the most common type, as we save all values as []byte,
|
|
// the only exception is where the value is string on HGetAll command.
|
|
return sessions.DefaultTranscoder.Unmarshal(data, outPtr)
|
|
case string:
|
|
return sessions.DefaultTranscoder.Unmarshal([]byte(data), outPtr)
|
|
default:
|
|
return fmt.Errorf("unknown value type of %T", data)
|
|
}
|
|
}
|
|
|
|
func (db *Database) keys(fullSID string) []string {
|
|
keys, err := db.c.Driver.GetKeys(fullSID)
|
|
if err != nil {
|
|
db.logger.Debugf("unable to get all redis keys of session '%s': %v", fullSID, err)
|
|
return nil
|
|
}
|
|
|
|
return keys
|
|
}
|
|
|
|
// Visit loops through all session keys and values.
|
|
func (db *Database) Visit(sid string, cb func(key string, value interface{})) error {
|
|
kv, err := db.c.Driver.GetAll(db.makeSID(sid))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for k, v := range kv {
|
|
var value interface{} // new value each time, we don't know what user will do in "cb".
|
|
if err = db.decodeValue(v, &value); err != nil {
|
|
db.logger.Debugf("unable to decode %s:%s: %v", sid, k, err)
|
|
return err
|
|
}
|
|
|
|
cb(k, value)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Len returns the length of the session's entries (keys).
|
|
func (db *Database) Len(sid string) int {
|
|
return db.c.Driver.Len(sid)
|
|
}
|
|
|
|
// Delete removes a session key value based on its key.
|
|
func (db *Database) Delete(sid string, key string) (deleted bool) {
|
|
err := db.c.Driver.Delete(db.makeSID(sid), key)
|
|
if err != nil {
|
|
db.logger.Error(err)
|
|
}
|
|
return err == nil
|
|
}
|
|
|
|
// Clear removes all session key values but it keeps the session entry.
|
|
func (db *Database) Clear(sid string) error {
|
|
keys := db.keys(db.makeSID(sid))
|
|
for _, key := range keys {
|
|
if key == SessionIDKey {
|
|
continue
|
|
}
|
|
if err := db.c.Driver.Delete(sid, key); err != nil {
|
|
db.logger.Debugf("unable to delete session '%s' value of key: '%s': %v", sid, key, err)
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Release destroys the session, it clears and removes the session entry,
|
|
// session manager will create a new session ID on the next request after this call.
|
|
func (db *Database) Release(sid string) error {
|
|
err := db.c.Driver.Delete(db.makeSID(sid), "")
|
|
if err != nil {
|
|
db.logger.Debugf("Database.Release.Driver.Delete: %s: %v", sid, err)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
// Close terminates the redis connection.
|
|
func (db *Database) Close() error {
|
|
return closeDB(db)
|
|
}
|
|
|
|
func closeDB(db *Database) error {
|
|
return db.c.Driver.CloseConnection()
|
|
}
|
|
|
|
// Sentinel errors of the redis sessions database.
var (
	// ErrRedisClosed is an error with message 'redis: already closed'.
	ErrRedisClosed = errors.New("redis: already closed")

	// ErrKeyNotFound is the error of non-existing redis keys.
	// The producers (the library) of this error will dynamically wrap it
	// (fmt.Errorf) with the key name, so match it with errors.Is.
	// Usage:
	//
	//	if err != nil && errors.Is(err, ErrKeyNotFound) {
	//		[...]
	//	}
	ErrKeyNotFound = errors.New("key not found")
)
|