From 02bfc83f2a76d01e0a257e9289f0e49e524b55f0 Mon Sep 17 00:00:00 2001 From: "Gerasimos (Makis) Maropoulos" Date: Sun, 18 Oct 2020 21:51:25 +0300 Subject: [PATCH] jwt: make the Blocklist an interface, so end-developers can implement their own storage (e.g. redis) --- middleware/jwt/blocklist.go | 41 +++++++++++++++++++++++++++---------- middleware/jwt/jwt.go | 10 ++++----- middleware/jwt/jwt_test.go | 2 +- 3 files changed, 36 insertions(+), 17 deletions(-) diff --git a/middleware/jwt/blocklist.go b/middleware/jwt/blocklist.go index 86622655..377ca027 100644 --- a/middleware/jwt/blocklist.go +++ b/middleware/jwt/blocklist.go @@ -6,27 +6,46 @@ import ( "time" ) -// Blocklist is an in-memory storage of tokens that should be +// Blocklist should hold and manage invalidated-by-server tokens. +// The `NewBlocklist` and `NewBlocklistContext` functions +// returns a memory storage of tokens, +// it is the internal "blocklist" struct. +// +// The end-developer can implement her/his own blocklist, +// e.g. a redis one to keep persistence of invalidated tokens on server restarts. +// and bind to the JWT middleware's Blocklist field. +type Blocklist interface { + // Set should upsert a token to the storage. + Set(token string, expiresAt time.Time) + // Del should remove a token from the storage. + Del(token string) + // Count should return the total amount of tokens stored. + Count() int + // Has should report whether a specific token exists in the storage. + Has(token string) bool +} + +// blocklist is an in-memory storage of tokens that should be // immediately invalidated by the server-side. // The most common way to invalidate a token, e.g. on user logout, // is to make the client-side remove the token itself. // However, if someone else has access to that token, // it could be still valid for new requests until its expiration. -type Blocklist struct { +type blocklist struct { entries map[string]time.Time // key = token | value = expiration time (to remove expired). mu sync.RWMutex } // NewBlocklist returns a new up and running in-memory Token Blocklist. // The returned value can be set to the JWT instance's Blocklist field. -func NewBlocklist(gcEvery time.Duration) *Blocklist { +func NewBlocklist(gcEvery time.Duration) Blocklist { return NewBlocklistContext(stdContext.Background(), gcEvery) } // NewBlocklistContext same as `NewBlocklist` // but it also accepts a standard Go Context for GC cancelation. -func NewBlocklistContext(ctx stdContext.Context, gcEvery time.Duration) *Blocklist { - b := &Blocklist{ +func NewBlocklistContext(ctx stdContext.Context, gcEvery time.Duration) Blocklist { + b := &blocklist{ entries: make(map[string]time.Time), } @@ -39,21 +58,21 @@ func NewBlocklistContext(ctx stdContext.Context, gcEvery time.Duration) *Blockli // Set upserts a given token, with its expiration time, // to the block list, so it's immediately invalidated by the server-side. -func (b *Blocklist) Set(token string, expiresAt time.Time) { +func (b *blocklist) Set(token string, expiresAt time.Time) { b.mu.Lock() b.entries[token] = expiresAt b.mu.Unlock() } // Del removes a "token" from the block list. -func (b *Blocklist) Del(token string) { +func (b *blocklist) Del(token string) { b.mu.Lock() delete(b.entries, token) b.mu.Unlock() } // Count returns the total amount of blocked tokens. -func (b *Blocklist) Count() int { +func (b *blocklist) Count() int { b.mu.RLock() n := len(b.entries) b.mu.RUnlock() @@ -64,7 +83,7 @@ func (b *Blocklist) Count() int { // Has reports whether the given "token" is blocked by the server. // This method is called before the token verification, // so even if was expired it is removed from the block list. -func (b *Blocklist) Has(token string) bool { +func (b *blocklist) Has(token string) bool { if token == "" { return false } @@ -94,7 +113,7 @@ func (b *Blocklist) Has(token string) bool { // Depending on the application, the GC method can be scheduled // to called every half or a whole hour. // A good value for a GC cron task is the JWT's max age (default). -func (b *Blocklist) GC() int { +func (b *blocklist) GC() int { now := time.Now() var markedForDeletion []string @@ -116,7 +135,7 @@ func (b *Blocklist) GC() int { return n } -func (b *Blocklist) runGC(ctx stdContext.Context, every time.Duration) { +func (b *blocklist) runGC(ctx stdContext.Context, every time.Duration) { t := time.NewTicker(every) for { diff --git a/middleware/jwt/jwt.go b/middleware/jwt/jwt.go index 34ef36bc..3a69723a 100644 --- a/middleware/jwt/jwt.go +++ b/middleware/jwt/jwt.go @@ -116,13 +116,13 @@ type JWT struct { // Blocklist holds the invalidated-by-server tokens (that are not yet expired). // It is not initialized by default. // Initialization Usage: - // j.UseBlocklist() + // j.InitDefaultBlocklist() // OR // j.Blocklist = jwt.NewBlocklist(gcEveryDuration) // Usage: // - ctx.Logout() // - j.Invalidate(ctx) - Blocklist *Blocklist + Blocklist Blocklist } type privateKey interface{ Public() crypto.PublicKey } @@ -301,11 +301,11 @@ func (j *JWT) WithEncryption(contentEncryption ContentEncryption, alg KeyAlgorit return nil } -// UseBlocklist initializes the Blocklist. +// InitDefaultBlocklist initializes the Blocklist field with the default in-memory implementation. // Should be called on jwt middleware creation-time, // after this, the developer can use the Context.Logout method // to invalidate a verified token by the server-side. -func (j *JWT) UseBlocklist() { +func (j *JWT) InitDefaultBlocklist() { gcEvery := 30 * time.Minute if j.MaxAge > 0 { gcEvery = j.MaxAge @@ -515,7 +515,7 @@ func GetTokenInfo(ctx *context.Context) *TokenInfo { // This method can be used when the client-side does not clear the token // on a user logout operation. // -// Note: the Blocklist should be initialized before serve-time: j.UseBlocklist(). +// Note: the Blocklist should be initialized before serve-time: j.InitDefaultBlocklist(). func (j *JWT) Invalidate(ctx *context.Context) { if j.Blocklist == nil { ctx.Application().Logger().Debug("jwt.Invalidate: Blocklist is nil") diff --git a/middleware/jwt/jwt_test.go b/middleware/jwt/jwt_test.go index b3e8c422..277ce7bb 100644 --- a/middleware/jwt/jwt_test.go +++ b/middleware/jwt/jwt_test.go @@ -63,7 +63,7 @@ func TestVerify(t *testing.T) { func testWriteVerifyBlockToken(t *testing.T, j *jwt.JWT) { t.Helper() - j.UseBlocklist() + j.InitDefaultBlocklist() j.Extractors = append(j.Extractors, jwt.FromJSON("access_token")) customClaims := &userClaims{