iris/sessions/sessiondb/redis/database.go
Gerasimos (Makis) Maropoulos f1ebddb6d9
jwt: add redis blocklist
2020-11-02 06:31:28 +02:00

315 lines
8.5 KiB
Go

package redis
import (
"crypto/tls"
"errors"
"fmt"
"time"
"github.com/kataras/iris/v12/sessions"
"github.com/kataras/golog"
)
// Default values applied when the corresponding Config field is left unset.
const (
	// DefaultRedisNetwork is the default network type: "tcp".
	DefaultRedisNetwork = "tcp"
	// DefaultRedisAddr is the default address of the redis server: "127.0.0.1:6379".
	DefaultRedisAddr = "127.0.0.1:6379"
	// DefaultRedisTimeout is the default timeout: 30 seconds.
	DefaultRedisTimeout = 30 * time.Second
)
// Config is the redis configuration used by the sessions database.
// See DefaultConfig for the values used when no Config is passed to New.
type Config struct {
// Network protocol. Defaults to "tcp".
Network string
// Addr of a single redis server instance.
// See "Clusters" field for clusters support.
// Defaults to "127.0.0.1:6379".
Addr string
// Clusters a list of network addresses for clusters.
// If not empty "Addr" is ignored and Redis clusters feature is used instead.
Clusters []string
// Use the specified Username to authenticate the current connection
// with one of the connections defined in the ACL list when connecting
// to a Redis 6.0 instance, or greater, that is using the Redis ACL system.
Username string
// Optional password. Must match the password specified in the
// requirepass server configuration option (if connecting to a Redis 5.0 instance, or lower),
// or the User Password when connecting to a Redis 6.0 instance, or greater,
// that is using the Redis ACL system.
Password string
// If Database is empty "" then no 'SELECT'. Defaults to "".
Database string
// Maximum number of socket connections.
// Default is 10 connections per every CPU as reported by runtime.NumCPU.
MaxActive int
// Timeout for connect, write and read, defaults to 30 seconds, 0 means no timeout.
Timeout time.Duration
// Prefix is prepended to every session ID to build the redis key,
// e.g. "myprefix-for-this-website". Defaults to "".
Prefix string
// TLSConfig will cause Dial to perform a TLS handshake using the provided
// config. If it is nil then no TLS is used.
// See https://golang.org/pkg/crypto/tls/#Config
TLSConfig *tls.Config
// Driver is the Go client used for the redis communication.
// It can be set to a custom one or a mock one (for testing).
//
// Defaults to `GoRedis()`.
Driver Driver
}
// DefaultConfig returns the configuration used by New when the caller
// does not supply one: the default network, address and timeout,
// MaxActive set to 10 and the GoRedis driver. Every other field
// (Clusters, Username, Password, Database, Prefix, TLSConfig) is left
// at its zero value.
func DefaultConfig() Config {
	cfg := Config{
		Network:   DefaultRedisNetwork,
		Addr:      DefaultRedisAddr,
		MaxActive: 10,
		Timeout:   DefaultRedisTimeout,
	}
	cfg.Driver = GoRedis()
	return cfg
}
// Database is the redis back-end session database for the sessions.
// Values are created through New.
type Database struct {
// c is the configuration this database was created with; its Driver
// performs all of the redis calls and its Prefix is prepended to session IDs.
c Config
// logger receives the debug and error messages of the methods below.
// It is assigned by SetLogger only; New does not initialize it.
logger *golog.Logger
}
// Compile-time check that *Database implements sessions.Database.
var _ sessions.Database = (*Database)(nil)
// New creates a redis sessions database and connects it to the server.
//
// Only the first element of cfg is used; without it DefaultConfig applies.
// For a caller-supplied configuration, a negative Timeout, an empty Network,
// an empty Addr and a nil Driver are replaced by their defaults.
//
// New panics when the driver fails to connect or when the initial
// ping to the server fails.
func New(cfg ...Config) *Database {
	conf := DefaultConfig()

	if len(cfg) > 0 {
		conf = cfg[0]

		// Fill in whatever the caller left unset (or set to an invalid value).
		if conf.Timeout < 0 {
			conf.Timeout = DefaultRedisTimeout
		}
		if conf.Network == "" {
			conf.Network = DefaultRedisNetwork
		}
		if conf.Addr == "" {
			conf.Addr = DefaultRedisAddr
		}
		if conf.Driver == nil {
			conf.Driver = GoRedis()
		}
	}

	if err := conf.Driver.Connect(conf); err != nil {
		panic(err)
	}

	db := &Database{c: conf}

	// Make sure the server actually answers before handing the database out.
	if _, err := db.c.Driver.PingPong(); err != nil {
		panic(err)
	}

	// runtime.SetFinalizer(db, closeDB)
	return db
}
// SetLogger assigns the logger used by this database for its debug and
// error messages. It is meant to be called once, before the server runs;
// by default the Iris one is injected.
func (db *Database) SetLogger(logger *golog.Logger) {
	db.logger = logger
}
// makeSID builds the redis key of a session by prepending
// the configured Prefix to the session ID.
func (db *Database) makeSID(sid string) string {
	prefix := db.c.Prefix
	return prefix + sid
}
// SessionIDKey is the entry name under which a session's own ID
// is stored inside its redis session entry.
const SessionIDKey = "session_id"
// Acquire returns the lifetime of the session identified by sid.
//
// When the session already exists in redis, the returned LifeTime ends
// at the current time plus the remaining TTL reported by the driver.
//
// Otherwise the session entry is created by storing sid under SessionIDKey
// and, when expires is positive, its TTL is set to expires. Failures of
// either step are only logged. In that case the zero LifeTime is returned,
// so the session manager sets the life time based on its own configuration.
func (db *Database) Acquire(sid string, expires time.Duration) sessions.LifeTime {
	key := db.makeSID(sid)

	if db.c.Driver.Exists(key) {
		remaining := db.c.Driver.TTL(key)
		return sessions.LifeTime{Time: time.Now().Add(remaining)}
	}

	err := db.Set(sid, SessionIDKey, sid, 0, false)
	switch {
	case err != nil:
		db.logger.Debug(err)
	case expires > 0:
		if ttlErr := db.c.Driver.UpdateTTL(key, expires); ttlErr != nil {
			db.logger.Debug(ttlErr)
		}
	}

	return sessions.LifeTime{} // session manager will handle the rest.
}
// OnUpdateExpiration re-sets the TTL of the session's redis entry to newExpires
// and returns the driver's error, if any.
// See https://redis.io/commands/expire#refreshing-expires
func (db *Database) OnUpdateExpiration(sid string, newExpires time.Duration) error {
	key := db.makeSID(sid)
	return db.c.Driver.UpdateTTL(key, newExpires)
}
// Set stores value under key inside the session identified by sid.
// The value is serialized with sessions.DefaultTranscoder first.
// The duration and "immutable" arguments are ignored.
//
// A serialization failure is logged as an error, a driver failure as
// a debug message; in both cases the error is returned to the caller.
func (db *Database) Set(sid string, key string, value interface{}, _ time.Duration, _ bool) error {
	encoded, err := sessions.DefaultTranscoder.Marshal(value)
	if err != nil {
		db.logger.Error(err)
		return err
	}

	err = db.c.Driver.Set(db.makeSID(sid), key, encoded)
	if err != nil {
		db.logger.Debug(err)
	}
	return err
}
// Get returns the value stored under key in the session identified by sid,
// or nil when it cannot be fetched or decoded.
func (db *Database) Get(sid string, key string) (value interface{}) {
	if err := db.Decode(sid, key, &value); err != nil {
		return nil
	}
	return value
}
// Decode fetches the value associated with key in the session identified by
// sid and unmarshals it into outPtr.
// It returns the driver's error when the value cannot be fetched (e.g. not found)
// and the decoding error, after logging it, when unmarshaling fails.
func (db *Database) Decode(sid, key string, outPtr interface{}) error {
	raw, err := db.c.Driver.Get(db.makeSID(sid), key)
	if err != nil {
		// not found.
		return err
	}

	decodeErr := db.decodeValue(raw, outPtr)
	if decodeErr != nil {
		db.logger.Debugf("unable to unmarshal value of key: '%s%s': %v", sid, key, decodeErr)
	}
	return decodeErr
}
// decodeValue unmarshals a raw value received from the driver into outPtr
// using sessions.DefaultTranscoder.
// A nil val is a no-op. Supported raw types are []byte and string,
// any other type results in an error.
func (db *Database) decodeValue(val interface{}, outPtr interface{}) error {
	var raw []byte

	switch data := val.(type) {
	case nil:
		return nil
	case []byte:
		// this is the most common type, as we save all values as []byte,
		// the only exception is where the value is string on HGetAll command.
		raw = data
	case string:
		raw = []byte(data)
	default:
		return fmt.Errorf("unknown value type of %T", data)
	}

	return sessions.DefaultTranscoder.Unmarshal(raw, outPtr)
}
// keys returns the entry names of the session stored under fullSID
// (the already prefixed session ID).
// A driver failure is logged as a debug message and nil is returned.
func (db *Database) keys(fullSID string) []string {
	found, err := db.c.Driver.GetKeys(fullSID)
	if err == nil {
		return found
	}

	db.logger.Debugf("unable to get all redis keys of session '%s': %v", fullSID, err)
	return nil
}
// Visit calls cb once for every key-value pair of the session identified by sid.
// It stops and returns the error when the entries cannot be fetched
// or when one of the values fails to decode; entries visited before
// a decoding failure have already been passed to cb.
func (db *Database) Visit(sid string, cb func(key string, value interface{})) error {
	entries, err := db.c.Driver.GetAll(db.makeSID(sid))
	if err != nil {
		return err
	}

	for field, raw := range entries {
		// A fresh destination per entry, we don't know what the user will do in "cb".
		var decoded interface{}
		if decodeErr := db.decodeValue(raw, &decoded); decodeErr != nil {
			db.logger.Debugf("unable to decode %s:%s: %v", sid, field, decodeErr)
			return decodeErr
		}
		cb(field, decoded)
	}

	return nil
}
// Len returns the number of entries (keys) of the session identified by sid.
//
// The driver is queried with the prefixed session key, the same key
// every other method of Database reads from and writes to, so the result
// is correct when a non-empty Prefix is configured.
func (db *Database) Len(sid string) int {
	return db.c.Driver.Len(db.makeSID(sid))
}
// Delete removes the value stored under key from the session identified by sid.
// It reports whether the driver call succeeded; a failure is logged as an error.
func (db *Database) Delete(sid string, key string) (deleted bool) {
	if err := db.c.Driver.Delete(db.makeSID(sid), key); err != nil {
		db.logger.Error(err)
		return false
	}
	return true
}
// Clear removes all key-value pairs of the session identified by sid
// but keeps the session entry itself (the SessionIDKey entry is skipped).
//
// Both listing and deleting use the prefixed session key, so the entries
// that are removed are the ones that were listed, also when a non-empty
// Prefix is configured.
//
// It stops at, logs and returns the first deletion error. A failure to list
// the keys is only logged (by keys) and results in a nil error.
func (db *Database) Clear(sid string) error {
	fullSID := db.makeSID(sid)

	for _, key := range db.keys(fullSID) {
		if key == SessionIDKey {
			continue
		}

		if err := db.c.Driver.Delete(fullSID, key); err != nil {
			db.logger.Debugf("unable to delete session '%s' value of key: '%s': %v", sid, key, err)
			return err
		}
	}

	return nil
}
// Release destroys the session identified by sid by asking the driver
// to delete its entry with an empty key.
// The session manager will create a new session ID on the next request after this call.
// A driver failure is logged as a debug message and returned.
func (db *Database) Release(sid string) error {
	if err := db.c.Driver.Delete(db.makeSID(sid), ""); err != nil {
		db.logger.Debugf("Database.Release.Driver.Delete: %s: %v", sid, err)
		return err
	}
	return nil
}
// Close terminates the redis connection by closing the driver's connection
// and returns the driver's error, if any.
func (db *Database) Close() error {
	return db.c.Driver.CloseConnection()
}
// closeDB closes the connection held by the driver of db.
func closeDB(db *Database) error {
	driver := db.c.Driver
	return driver.CloseConnection()
}
// Sentinel errors of the redis session database.
var (
	// ErrRedisClosed is the error with message 'redis: already closed'.
	ErrRedisClosed = errors.New("redis: already closed")

	// ErrKeyNotFound marks a non-existing redis key.
	// The producers (the library) of this error dynamically wrap it
	// (fmt.Errorf) with the key name, so match it with errors.Is:
	//
	//	if err != nil && errors.Is(err, ErrKeyNotFound) {
	//		[...]
	//	}
	ErrKeyNotFound = errors.New("key not found")
)