jwt: make the Blocklist an interface, so end-developers can implement their own storage (e.g. redis)

This commit is contained in:
Gerasimos (Makis) Maropoulos 2020-10-18 21:51:25 +03:00
parent 3db77684ec
commit 02bfc83f2a
No known key found for this signature in database
GPG Key ID: 5DBE766BD26A54E7
3 changed files with 36 additions and 17 deletions

View File

@ -6,27 +6,46 @@ import (
"time" "time"
) )
// Blocklist is an in-memory storage of tokens that should be // Blocklist should hold and manage invalidated-by-server tokens.
// The `NewBlocklist` and `NewBlocklistContext` functions
// returns a memory storage of tokens,
// it is the internal "blocklist" struct.
//
// The end-developer can implement her/his own blocklist,
// e.g. a redis one to keep persistence of invalidated tokens on server restarts.
// and bind to the JWT middleware's Blocklist field.
type Blocklist interface {
// Set should upsert a token to the storage.
Set(token string, expiresAt time.Time)
// Del should remove a token from the storage.
Del(token string)
// Count should return the total amount of tokens stored.
Count() int
// Has should report whether a specific token exists in the storage.
Has(token string) bool
}
// blocklist is an in-memory storage of tokens that should be
// immediately invalidated by the server-side. // immediately invalidated by the server-side.
// The most common way to invalidate a token, e.g. on user logout, // The most common way to invalidate a token, e.g. on user logout,
// is to make the client-side remove the token itself. // is to make the client-side remove the token itself.
// However, if someone else has access to that token, // However, if someone else has access to that token,
// it could be still valid for new requests until its expiration. // it could be still valid for new requests until its expiration.
type Blocklist struct { type blocklist struct {
entries map[string]time.Time // key = token | value = expiration time (to remove expired). entries map[string]time.Time // key = token | value = expiration time (to remove expired).
mu sync.RWMutex mu sync.RWMutex
} }
// NewBlocklist returns a new up and running in-memory Token Blocklist. // NewBlocklist returns a new up and running in-memory Token Blocklist.
// The returned value can be set to the JWT instance's Blocklist field. // The returned value can be set to the JWT instance's Blocklist field.
func NewBlocklist(gcEvery time.Duration) *Blocklist { func NewBlocklist(gcEvery time.Duration) Blocklist {
return NewBlocklistContext(stdContext.Background(), gcEvery) return NewBlocklistContext(stdContext.Background(), gcEvery)
} }
// NewBlocklistContext same as `NewBlocklist` // NewBlocklistContext same as `NewBlocklist`
// but it also accepts a standard Go Context for GC cancelation. // but it also accepts a standard Go Context for GC cancelation.
func NewBlocklistContext(ctx stdContext.Context, gcEvery time.Duration) *Blocklist { func NewBlocklistContext(ctx stdContext.Context, gcEvery time.Duration) Blocklist {
b := &Blocklist{ b := &blocklist{
entries: make(map[string]time.Time), entries: make(map[string]time.Time),
} }
@ -39,21 +58,21 @@ func NewBlocklistContext(ctx stdContext.Context, gcEvery time.Duration) *Blockli
// Set upserts a given token, with its expiration time, // Set upserts a given token, with its expiration time,
// to the block list, so it's immediately invalidated by the server-side. // to the block list, so it's immediately invalidated by the server-side.
func (b *Blocklist) Set(token string, expiresAt time.Time) { func (b *blocklist) Set(token string, expiresAt time.Time) {
b.mu.Lock() b.mu.Lock()
b.entries[token] = expiresAt b.entries[token] = expiresAt
b.mu.Unlock() b.mu.Unlock()
} }
// Del removes a "token" from the block list. // Del removes a "token" from the block list.
func (b *Blocklist) Del(token string) { func (b *blocklist) Del(token string) {
b.mu.Lock() b.mu.Lock()
delete(b.entries, token) delete(b.entries, token)
b.mu.Unlock() b.mu.Unlock()
} }
// Count returns the total amount of blocked tokens. // Count returns the total amount of blocked tokens.
func (b *Blocklist) Count() int { func (b *blocklist) Count() int {
b.mu.RLock() b.mu.RLock()
n := len(b.entries) n := len(b.entries)
b.mu.RUnlock() b.mu.RUnlock()
@ -64,7 +83,7 @@ func (b *Blocklist) Count() int {
// Has reports whether the given "token" is blocked by the server. // Has reports whether the given "token" is blocked by the server.
// This method is called before the token verification, // This method is called before the token verification,
// so even if was expired it is removed from the block list. // so even if was expired it is removed from the block list.
func (b *Blocklist) Has(token string) bool { func (b *blocklist) Has(token string) bool {
if token == "" { if token == "" {
return false return false
} }
@ -94,7 +113,7 @@ func (b *Blocklist) Has(token string) bool {
// Depending on the application, the GC method can be scheduled // Depending on the application, the GC method can be scheduled
// to called every half or a whole hour. // to called every half or a whole hour.
// A good value for a GC cron task is the JWT's max age (default). // A good value for a GC cron task is the JWT's max age (default).
func (b *Blocklist) GC() int { func (b *blocklist) GC() int {
now := time.Now() now := time.Now()
var markedForDeletion []string var markedForDeletion []string
@ -116,7 +135,7 @@ func (b *Blocklist) GC() int {
return n return n
} }
func (b *Blocklist) runGC(ctx stdContext.Context, every time.Duration) { func (b *blocklist) runGC(ctx stdContext.Context, every time.Duration) {
t := time.NewTicker(every) t := time.NewTicker(every)
for { for {

View File

@ -116,13 +116,13 @@ type JWT struct {
// Blocklist holds the invalidated-by-server tokens (that are not yet expired). // Blocklist holds the invalidated-by-server tokens (that are not yet expired).
// It is not initialized by default. // It is not initialized by default.
// Initialization Usage: // Initialization Usage:
// j.UseBlocklist() // j.InitDefaultBlocklist()
// OR // OR
// j.Blocklist = jwt.NewBlocklist(gcEveryDuration) // j.Blocklist = jwt.NewBlocklist(gcEveryDuration)
// Usage: // Usage:
// - ctx.Logout() // - ctx.Logout()
// - j.Invalidate(ctx) // - j.Invalidate(ctx)
Blocklist *Blocklist Blocklist Blocklist
} }
type privateKey interface{ Public() crypto.PublicKey } type privateKey interface{ Public() crypto.PublicKey }
@ -301,11 +301,11 @@ func (j *JWT) WithEncryption(contentEncryption ContentEncryption, alg KeyAlgorit
return nil return nil
} }
// UseBlocklist initializes the Blocklist. // InitDefaultBlocklist initializes the Blocklist field with the default in-memory implementation.
// Should be called on jwt middleware creation-time, // Should be called on jwt middleware creation-time,
// after this, the developer can use the Context.Logout method // after this, the developer can use the Context.Logout method
// to invalidate a verified token by the server-side. // to invalidate a verified token by the server-side.
func (j *JWT) UseBlocklist() { func (j *JWT) InitDefaultBlocklist() {
gcEvery := 30 * time.Minute gcEvery := 30 * time.Minute
if j.MaxAge > 0 { if j.MaxAge > 0 {
gcEvery = j.MaxAge gcEvery = j.MaxAge
@ -515,7 +515,7 @@ func GetTokenInfo(ctx *context.Context) *TokenInfo {
// This method can be used when the client-side does not clear the token // This method can be used when the client-side does not clear the token
// on a user logout operation. // on a user logout operation.
// //
// Note: the Blocklist should be initialized before serve-time: j.UseBlocklist(). // Note: the Blocklist should be initialized before serve-time: j.InitDefaultBlocklist().
func (j *JWT) Invalidate(ctx *context.Context) { func (j *JWT) Invalidate(ctx *context.Context) {
if j.Blocklist == nil { if j.Blocklist == nil {
ctx.Application().Logger().Debug("jwt.Invalidate: Blocklist is nil") ctx.Application().Logger().Debug("jwt.Invalidate: Blocklist is nil")

View File

@ -63,7 +63,7 @@ func TestVerify(t *testing.T) {
func testWriteVerifyBlockToken(t *testing.T, j *jwt.JWT) { func testWriteVerifyBlockToken(t *testing.T, j *jwt.JWT) {
t.Helper() t.Helper()
j.UseBlocklist() j.InitDefaultBlocklist()
j.Extractors = append(j.Extractors, jwt.FromJSON("access_token")) j.Extractors = append(j.Extractors, jwt.FromJSON("access_token"))
customClaims := &userClaims{ customClaims := &userClaims{