package redis

import (
	"context"
	"errors"
	"io"
	"sync/atomic"

	"github.com/kataras/iris/v12/core/host"
	"github.com/kataras/iris/v12/middleware/jwt"

	"github.com/go-redis/redis/v8"
)

// defaultContext is the context passed to every redis command issued by this
// package. It is context.Background(), so commands are never canceled through
// it and carry no deadline from it.
var defaultContext = context.Background()

// Aliases for the go-redis option types, so callers can configure the
// Blocklist without importing go-redis themselves.
type (
	// Options is just a type alias for the go-redis Client Options.
	Options = redis.Options
	// ClusterOptions is just a type alias for the go-redis Cluster Client Options.
	ClusterOptions = redis.ClusterOptions
)

// Client is the interface which both the go-redis Client and the go-redis
// Cluster Client implement: the full redis command set plus Close.
type Client interface {
	redis.Cmdable // Commands.
	io.Closer     // CloseConnection.
}

// Blocklist is a jwt.Blocklist backed by Redis.
//
// Create it with NewBlocklist, adjust its exported fields and call Connect
// before using it: the redis client is only created by Connect.
type Blocklist struct {
	// GetKey extracts the unique identifier (the redis key) of a token.
	// It must be non-nil by the time tokens are validated or invalidated.
	// NewBlocklist sets it to a function returning claims.ID ("jti").
	GetKey func(token []byte, claims jwt.Claims) string
	// Prefix is prepended to every token key stored in redis
	// (Connect wraps GetKey to apply it).
	// Note that you can also select a different database
	// through ClientOptions (or ClusterOptions).
	// Defaults to empty string (no prefix).
	Prefix string
	// client is created by Connect; both the go-redis Client and
	// ClusterClient implement the Client interface.
	// connected is 1 after a successful Connect and is reset to 0 by the
	// interrupt handler Connect registers; it is only accessed atomically.
	client    Client
	connected uint32
	// Customize any go-redis fields manually before Connect.
	// When ClusterOptions.Addrs is not empty Connect uses a cluster client,
	// otherwise a single-node client configured by ClientOptions.
	ClientOptions  Options
	ClusterOptions ClusterOptions
}

// Compile-time check that *Blocklist implements jwt.Blocklist.
var _ jwt.Blocklist = (*Blocklist)(nil)

// NewBlocklist returns a new redis-based Blocklist that keys tokens by their
// "jti" claim and, by default, targets a single redis server at
// 127.0.0.1:6379 with no key prefix.
//
// Modify its ClientOptions or ClusterOptions depending the application needs
// and call its Connect.
//
// Usage:
//
//	blocklist := NewBlocklist()
//	blocklist.ClientOptions.Addr = ...
//	err := blocklist.Connect()
//
// And register it:
//
//	verifier := jwt.NewVerifier(...)
//	verifier.Blocklist = blocklist
func NewBlocklist() *Blocklist {
	bl := new(Blocklist)
	bl.GetKey = defaultGetKey
	// Only the address is set; go-redis fills in the rest of its defaults.
	bl.ClientOptions.Addr = "127.0.0.1:6379"
	// Prefix and ClusterOptions stay zero-valued. Connect switches to the
	// cluster client only once ClusterOptions.Addrs is populated.
	return bl
}

// defaultGetKey is the GetKey installed by NewBlocklist. It ignores the raw
// token bytes and uses the token's ID claim ("jti") as its key.
func defaultGetKey(_ []byte, c jwt.Claims) string {
	id := c.ID
	return id
}

// Connect prepares the redis client and fires a PING to verify it.
//
// A cluster client is created when ClusterOptions.Addrs is not empty,
// otherwise a single-node client is created from ClientOptions.
//
// When Prefix is not empty, GetKey is wrapped so every key it returns is
// prefixed with Prefix (Prefix is read each time a key is generated).
// A nil GetKey falls back to the default one, which uses claims.ID ("jti").
//
// If the PING fails, the client just created is closed and the PING error is
// returned. On success an interrupt handler is registered through
// host.RegisterOnInterrupt that marks the blocklist as disconnected and closes
// the client, and IsConnected starts reporting true.
//
// Connect is meant to be called once: each call with a non-empty Prefix wraps
// GetKey again, so calling it twice would prefix keys twice.
func (b *Blocklist) Connect() error {
	getKey := b.GetKey
	if getKey == nil {
		// A Blocklist not built with NewBlocklist would otherwise panic on
		// its first ValidateToken/InvalidateToken call.
		getKey = defaultGetKey
	}

	if b.Prefix != "" {
		inner := getKey
		getKey = func(token []byte, claims jwt.Claims) string {
			return b.Prefix + inner(token, claims)
		}
	}
	b.GetKey = getKey

	if len(b.ClusterOptions.Addrs) > 0 {
		// Use cluster client.
		b.client = redis.NewClusterClient(&b.ClusterOptions)
	} else {
		b.client = redis.NewClient(&b.ClientOptions)
	}

	if err := b.client.Ping(defaultContext).Err(); err != nil {
		// Release the resources of the client that could not reach the
		// server, so retrying Connect does not leak a client per attempt.
		// The close error is secondary to the ping error and is ignored.
		_ = b.client.Close()
		return err
	}

	host.RegisterOnInterrupt(func() {
		atomic.StoreUint32(&b.connected, 0)
		// Best-effort during shutdown: there is no caller to report to.
		_ = b.client.Close()
	})
	atomic.StoreUint32(&b.connected, 1)

	return nil
}

// IsConnected reports whether a Connect call has succeeded and the interrupt
// handler registered by Connect has not yet closed the client.
func (b *Blocklist) IsConnected() bool {
	flag := atomic.LoadUint32(&b.connected)
	return flag != 0
}

// ValidateToken checks whether the token's key (see GetKey) exists in redis
// and returns jwt.ErrBlocked if it does, nil otherwise.
//
// err is the result of the verification performed before this check. When it
// is not nil it is returned unchanged. If it is (or wraps) jwt.ErrExpired, the
// token's key is also deleted from redis on a best-effort basis.
// Errors from the redis lookup itself are returned as-is.
func (b *Blocklist) ValidateToken(token []byte, c jwt.Claims, err error) error {
	if err != nil {
		// errors.Is also matches a wrapped jwt.ErrExpired, not only the bare sentinel.
		if errors.Is(err, jwt.ErrExpired) {
			// Best-effort cleanup: the caller needs the original error, so a
			// failed delete is deliberately ignored.
			_ = b.Del(b.GetKey(token, c))
		}

		return err // respect the previous error.
	}

	has, hasErr := b.Has(b.GetKey(token, c))
	if hasErr != nil {
		return hasErr
	}
	if has {
		return jwt.ErrBlocked
	}

	return nil
}

// InvalidateToken invalidates a verified JWT token by storing it in redis
// under GetKey(token, c), with an expiration equal to c.Timeleft(). Once that
// time passes, redis drops the entry on its own.
//
// NOTE(review): redis rejects SETEX with a non-positive expiration. If
// c.Timeleft() can be <= 0 (already expired or no expiry), this returns a
// redis error. Confirm against jwt.Claims.Timeleft.
func (b *Blocklist) InvalidateToken(token []byte, c jwt.Claims) error {
	k := b.GetKey(token, c)
	ttl := c.Timeleft()
	cmd := b.client.SetEX(defaultContext, k, token, ttl)
	return cmd.Err()
}

// Del removes a token from the storage by deleting the given redis key.
// Deleting a key that does not exist is not an error.
func (b *Blocklist) Del(key string) error {
	cmd := b.client.Del(defaultContext, key)
	return cmd.Err()
}

// Has reports whether a specific token key exists in the storage.
// Any redis error is returned alongside the boolean.
func (b *Blocklist) Has(key string) (bool, error) {
	cmd := b.client.Exists(defaultContext, key)
	found, err := cmd.Result()
	return found > 0, err
}

// Count returns the total amount of tokens stored.
//
// With a Prefix, only the keys matching Prefix+"*" are counted. Without one,
// the size of the whole selected database (DBSIZE) is returned, so any
// non-token keys stored in that database are counted too.
func (b *Blocklist) Count() (int64, error) {
	if b.Prefix != "" {
		matched, err := b.getKeys(0)
		if err != nil {
			return 0, err
		}
		return int64(len(matched)), nil
	}

	return b.client.DBSize(defaultContext).Result()
}

// getKeys returns the keys matching Prefix+"*", running SCAN from cursor
// until redis reports a cursor of 0.
//
// SCAN may return the same key more than once during a full iteration, so
// duplicates are dropped. Keys keep the order in which SCAN first returned them.
//
// NOTE(review): with a cluster client a plain SCAN may only cover a single
// node. Confirm whether Count must aggregate across cluster masters.
func (b *Blocklist) getKeys(cursor uint64) ([]string, error) {
	// COUNT hint for each SCAN round trip: how much work redis does per call.
	// It is not a limit on the number of results.
	const scanCount = 300000

	pattern := b.Prefix + "*"
	seen := make(map[string]struct{})
	var keys []string

	for {
		page, next, err := b.client.Scan(defaultContext, cursor, pattern, scanCount).Result()
		if err != nil {
			return nil, err
		}

		for _, k := range page {
			if _, dup := seen[k]; dup {
				continue
			}
			seen[k] = struct{}{}
			keys = append(keys, k)
		}

		if next == 0 {
			return keys, nil
		}
		cursor = next
	}
}