package redis

import (
	"crypto/tls"
	"errors"
	"fmt"
	"time"

	"github.com/kataras/iris/v12/sessions"

	"github.com/kataras/golog"
)

const (
	// DefaultRedisNetwork is the default Config.Network value: "tcp".
	DefaultRedisNetwork = "tcp"
	// DefaultRedisAddr is the default Config.Addr value: "127.0.0.1:6379".
	DefaultRedisAddr = "127.0.0.1:6379"
	// DefaultRedisTimeout is the default Config.Timeout value (the connect,
	// read and write timeout): 30 seconds.
	DefaultRedisTimeout = 30 * time.Second
)

// Config is the redis configuration used by the sessions database.
// Obtain a ready-to-use value with DefaultConfig and pass it to New.
type Config struct {
	// Network protocol. Defaults to "tcp".
	Network string
	// Addr of a single redis server instance.
	// See the "Clusters" field for clusters support.
	// Defaults to "127.0.0.1:6379".
	Addr string
	// Clusters is a list of network addresses of the cluster nodes.
	// If not empty, "Addr" is ignored and the Redis clusters feature is used instead.
	Clusters []string
	// Username authenticates the current connection against one of the
	// users defined in the ACL list when connecting to a Redis 6.0 instance,
	// or greater, that uses the Redis ACL system.
	Username string
	// Password is optional. It must match the password specified in the
	// requirepass server configuration option (when connecting to a Redis 5.0 instance, or lower),
	// or the user's password when connecting to a Redis 6.0 instance, or greater,
	// that uses the Redis ACL system.
	Password string
	// Database is the redis database to select. If it is empty "", no 'SELECT'
	// is issued. Defaults to "".
	Database string
	// MaxActive is the maximum number of socket connections.
	// DefaultConfig sets it to 10.
	// NOTE(review): a default of 10 connections per CPU (runtime.NumCPU) may
	// apply inside the driver when this is 0; confirm against the Driver implementation.
	MaxActive int
	// Timeout for connect, write and read. Defaults to 30 seconds; 0 means no timeout.
	// A negative value passed to New is replaced with DefaultRedisTimeout.
	Timeout time.Duration
	// Prefix is prepended to every session ID to build its redis key,
	// e.g. "myprefix-for-this-website". Defaults to "".
	Prefix string

	// TLSConfig makes Dial perform a TLS handshake using the provided
	// config. If it is nil, no TLS is used.
	// See https://golang.org/pkg/crypto/tls/#Config
	TLSConfig *tls.Config

	// Driver is the Go client used for redis communication.
	// It can be set to a custom one or a mock one (for testing).
	//
	// Defaults to `GoRedis()`.
	Driver Driver
}

// DefaultConfig returns the default Config for the redis sessions database:
// "tcp" to 127.0.0.1:6379, at most 10 active connections, a 30s timeout and
// the GoRedis() driver.
func DefaultConfig() Config {
	cfg := Config{
		Network:   DefaultRedisNetwork,
		Addr:      DefaultRedisAddr,
		MaxActive: 10,
		Timeout:   DefaultRedisTimeout,
	}
	// Username, Password, Database and Prefix stay empty and TLSConfig stays
	// nil: no credentials, no SELECT, no key prefix and no TLS by default.
	cfg.Driver = GoRedis()
	return cfg
}

// Database is the redis back-end session database for the sessions manager.
// Build it with New. Field c holds the resolved Config, including the Driver
// that every method delegates to. Field logger is used to report errors
// (see SetLogger).
type Database struct {
	c      Config
	logger *golog.Logger
}

// Compile-time assertion that *Database implements sessions.Database.
var _ sessions.Database = (*Database)(nil)

// New returns a new redis sessions database.
//
// Without arguments DefaultConfig() is used. Otherwise only the first Config
// is used. Its empty Network, Addr and Driver fields are filled with the
// package defaults, and a negative Timeout is replaced with
// DefaultRedisTimeout. A zero Timeout is kept as is (no timeout).
//
// New panics if the driver cannot connect or does not answer the initial ping.
// The returned Database logs through golog.Default until SetLogger is called.
func New(cfg ...Config) *Database {
	var c Config
	if len(cfg) > 0 {
		c = cfg[0]

		if c.Timeout < 0 {
			c.Timeout = DefaultRedisTimeout
		}

		if c.Network == "" {
			c.Network = DefaultRedisNetwork
		}

		if c.Addr == "" {
			c.Addr = DefaultRedisAddr
		}

		if c.Driver == nil {
			c.Driver = GoRedis()
		}
	} else {
		// Only build the default config (and its driver) when it is actually used.
		c = DefaultConfig()
	}

	if err := c.Driver.Connect(c); err != nil {
		panic(err)
	}

	// Default to the package-level golog logger. Without it, the methods that
	// log (Acquire, Set, Decode, Clear, ...) would call into a nil
	// *golog.Logger when the database is used before, or without, SetLogger.
	db := &Database{c: c, logger: golog.Default}
	if _, err := db.c.Driver.PingPong(); err != nil {
		// The connection was opened above; release it before giving up.
		_ = db.c.Driver.CloseConnection()
		panic(err)
	}
	// runtime.SetFinalizer(db, closeDB)
	return db
}

// SetLogger sets the logger used to report errors. Call it once, before the
// server runs. By default the Iris logger is injected.
func (db *Database) SetLogger(logger *golog.Logger) {
	db.logger = logger
}

// makeSID returns the redis key of the session "sid": the configured Prefix
// followed by sid.
func (db *Database) makeSID(sid string) string {
	prefix := db.c.Prefix
	return prefix + sid
}

// SessionIDKey is the field under which Acquire stores the session ID inside
// the session's own redis entry. Clear skips this field, so the entry survives.
const SessionIDKey = "session_id"

// Acquire returns the lifetime of the session "sid".
//
// For a session already present in redis, the lifetime is "now" plus the
// remaining TTL reported by the driver. For an unknown session, Acquire:
//   - stores the session ID under SessionIDKey, so the redis entry exists;
//   - sets the entry's TTL to "expires" when that is positive;
//   - returns the zero LifeTime{}, which makes the session manager compute
//     the lifetime from its own configured expiration.
func (db *Database) Acquire(sid string, expires time.Duration) sessions.LifeTime {
	redisKey := db.makeSID(sid)

	if db.c.Driver.Exists(redisKey) {
		remaining := db.c.Driver.TTL(redisKey)
		return sessions.LifeTime{Time: time.Now().Add(remaining)}
	}

	err := db.Set(sid, SessionIDKey, sid, 0, false)
	switch {
	case err != nil:
		db.logger.Debug(err)
	case expires > 0:
		if ttlErr := db.c.Driver.UpdateTTL(redisKey, expires); ttlErr != nil {
			db.logger.Debug(ttlErr)
		}
	}

	return sessions.LifeTime{} // the session manager handles the rest.
}

// OnUpdateExpiration resets the TTL of the session's redis entry to
// newExpires and returns the driver's error, if any.
// https://redis.io/commands/expire#refreshing-expires
func (db *Database) OnUpdateExpiration(sid string, newExpires time.Duration) error {
	redisKey := db.makeSID(sid)
	return db.c.Driver.UpdateTTL(redisKey, newExpires)
}

// Set stores "value" under "key" in the session "sid", encoded with
// sessions.DefaultTranscoder. The lifetime and "immutable" arguments are
// ignored by this database. Encoding and driver errors are logged and returned.
func (db *Database) Set(sid string, key string, value interface{}, _ time.Duration, _ bool) error {
	encoded, err := sessions.DefaultTranscoder.Marshal(value)
	if err != nil {
		db.logger.Error(err)
		return err
	}

	err = db.c.Driver.Set(db.makeSID(sid), key, encoded)
	if err != nil {
		db.logger.Debug(err)
	}
	return err
}

// Get returns the decoded value stored under "key" in the session "sid",
// or nil when the value cannot be read or decoded.
func (db *Database) Get(sid string, key string) (value interface{}) {
	if err := db.Decode(sid, key, &value); err != nil {
		return nil
	}
	return value
}

// Decode binds "outPtr" to the value stored under "key" in the session "sid".
//
// It returns the driver's error when the value cannot be read (e.g. it does
// not exist), and the transcoder's error when the value cannot be decoded.
// A nil stored value leaves outPtr untouched and returns nil.
func (db *Database) Decode(sid, key string, outPtr interface{}) error {
	sidKey := db.makeSID(sid)
	data, err := db.c.Driver.Get(sidKey, key)
	if err != nil {
		// not found.
		return err
	}

	if err = db.decodeValue(data, outPtr); err != nil {
		// Separate sid and key (as Visit does) so the log line is unambiguous.
		db.logger.Debugf("unable to unmarshal value of key: '%s:%s': %v", sid, key, err)
		return err
	}

	return nil
}

// decodeValue unmarshals a raw value read from redis into outPtr using
// sessions.DefaultTranscoder. A nil value is a no-op. []byte (how values are
// written by Set) and string (as HGetAll may return them) are accepted; any
// other type is reported as an error.
func (db *Database) decodeValue(val interface{}, outPtr interface{}) error {
	var raw []byte

	switch data := val.(type) {
	case nil:
		return nil
	case []byte:
		raw = data
	case string:
		raw = []byte(data)
	default:
		return fmt.Errorf("unknown value type of %T", data)
	}

	return sessions.DefaultTranscoder.Unmarshal(raw, outPtr)
}

// keys lists the field names stored under the (already prefixed) redis key
// fullSID. On a driver error it logs at debug level and returns nil.
func (db *Database) keys(fullSID string) []string {
	found, err := db.c.Driver.GetKeys(fullSID)
	if err == nil {
		return found
	}

	db.logger.Debugf("unable to get all redis keys of session '%s': %v", fullSID, err)
	return nil
}

// Visit calls cb with every key and decoded value of the session "sid".
// It stops and returns the first read or decode error.
func (db *Database) Visit(sid string, cb func(key string, value interface{})) error {
	entries, err := db.c.Driver.GetAll(db.makeSID(sid))
	if err != nil {
		return err
	}

	for field, raw := range entries {
		// A fresh variable per entry: cb may retain the value it receives.
		var decoded interface{}
		if decodeErr := db.decodeValue(raw, &decoded); decodeErr != nil {
			db.logger.Debugf("unable to decode %s:%s: %v", sid, field, decodeErr)
			return decodeErr
		}

		cb(field, decoded)
	}

	return nil
}

// Len returns the number of entries (keys) stored in the session "sid".
//
// Like every other method, it applies the configured Prefix to sid, so the
// count refers to the same redis key that Set writes to.
func (db *Database) Len(sid string) int {
	return db.c.Driver.Len(db.makeSID(sid))
}

// Delete removes "key" from the session "sid". It reports whether the driver
// succeeded; a failure is logged at error level.
func (db *Database) Delete(sid string, key string) (deleted bool) {
	if err := db.c.Driver.Delete(db.makeSID(sid), key); err != nil {
		db.logger.Error(err)
		return false
	}
	return true
}

// Clear removes every entry of the session "sid" except SessionIDKey, so the
// session entry itself keeps existing in redis.
//
// Listing and deleting both use the prefixed redis key (Prefix + sid), the
// same key Set writes to. It returns the first delete error, if any. A
// failure to list the keys is logged (inside keys) and treated as "nothing
// to clear".
func (db *Database) Clear(sid string) error {
	sidKey := db.makeSID(sid)

	for _, key := range db.keys(sidKey) {
		if key == SessionIDKey {
			continue
		}

		if err := db.c.Driver.Delete(sidKey, key); err != nil {
			db.logger.Debugf("unable to delete session '%s' value of key: '%s': %v", sid, key, err)
			return err
		}
	}

	return nil
}

// Release destroys the session by removing its redis entry. The session
// manager creates a new session ID on the next request after this call.
// NOTE(review): the empty key passed to Driver.Delete presumably means
// "delete the whole entry"; confirm against the Driver implementations.
func (db *Database) Release(sid string) error {
	if err := db.c.Driver.Delete(db.makeSID(sid), ""); err != nil {
		db.logger.Debugf("Database.Release.Driver.Delete: %s: %v", sid, err)
		return err
	}
	return nil
}

// Close terminates the redis connection held by the driver and returns the
// driver's error, if any.
func (db *Database) Close() error {
	err := closeDB(db)
	return err
}

// closeDB closes the connection of db's driver. It is a plain function,
// presumably so it can also serve as a finalizer (see the commented-out
// runtime.SetFinalizer call in New).
func closeDB(db *Database) error {
	driver := db.c.Driver
	return driver.CloseConnection()
}

// ErrRedisClosed is the error with message "redis: already closed". It reports
// that the redis connection has already been closed.
var ErrRedisClosed = errors.New("redis: already closed")

// ErrKeyNotFound is the error for non-existing redis keys.
// Its producers (this library) wrap it dynamically with the key name via
// fmt.Errorf, so match it with errors.Is:
//
//	if err != nil && errors.Is(err, ErrKeyNotFound) {
//		// [...]
//	}
var ErrKeyNotFound = errors.New("key not found")