diff --git a/_examples/sessions/database/badger/main.go b/_examples/sessions/database/badger/main.go index 400d2969..4f24b317 100644 --- a/_examples/sessions/database/badger/main.go +++ b/_examples/sessions/database/badger/main.go @@ -88,8 +88,19 @@ func main() { }) app.Get("/update", func(ctx iris.Context) { - // updates expire date with a new date - sess.ShiftExpiration(ctx) + // updates resets the expiration based on the session's `Expires` field. + if err := sess.ShiftExpiration(ctx); err != nil { + if sessions.ErrNotFound.Equal(err) { + ctx.StatusCode(iris.StatusNotFound) + } else if sessions.ErrNotImplemented.Equal(err) { + ctx.StatusCode(iris.StatusNotImplemented) + } else { + ctx.StatusCode(iris.StatusNotModified) + } + + ctx.Writef("%v", err) + ctx.Application().Logger().Error(err) + } }) app.Run(iris.Addr(":8080"), iris.WithoutServerError(iris.ErrServerClosed)) diff --git a/_examples/sessions/database/boltdb/main.go b/_examples/sessions/database/boltdb/main.go index 9e7d6d41..3e2d894f 100644 --- a/_examples/sessions/database/boltdb/main.go +++ b/_examples/sessions/database/boltdb/main.go @@ -89,8 +89,19 @@ func main() { }) app.Get("/update", func(ctx iris.Context) { - // updates expire date with a new date - sess.ShiftExpiration(ctx) + // updates resets the expiration based on the session's `Expires` field. + if err := sess.ShiftExpiration(ctx); err != nil { + if sessions.ErrNotFound.Equal(err) { + ctx.StatusCode(iris.StatusNotFound) + } else if sessions.ErrNotImplemented.Equal(err) { + ctx.StatusCode(iris.StatusNotImplemented) + } else { + ctx.StatusCode(iris.StatusNotModified) + } + + ctx.Writef("%v", err) + ctx.Application().Logger().Error(err) + } }) app.Run(iris.Addr(":8080"), iris.WithoutServerError(iris.ErrServerClosed)) diff --git a/_examples/sessions/database/redis/main.go b/_examples/sessions/database/redis/main.go index 88257c8d..584451a1 100644 --- a/_examples/sessions/database/redis/main.go +++ b/_examples/sessions/database/redis/main.go @@ -116,8 +116,19 @@ func main() { }) app.Get("/update", func(ctx iris.Context) { - // updates expire date with a new date - sess.ShiftExpiration(ctx) + // updates resets the expiration based on the session's `Expires` field. + if err := sess.ShiftExpiration(ctx); err != nil { + if sessions.ErrNotFound.Equal(err) { + ctx.StatusCode(iris.StatusNotFound) + } else if sessions.ErrNotImplemented.Equal(err) { + ctx.StatusCode(iris.StatusNotImplemented) + } else { + ctx.StatusCode(iris.StatusNotModified) + } + + ctx.Writef("%v", err) + ctx.Application().Logger().Error(err) + } }) app.Run(iris.Addr(":8080"), iris.WithoutServerError(iris.ErrServerClosed)) diff --git a/core/errors/errors.go b/core/errors/errors.go index c0369693..e06e97a1 100644 --- a/core/errors/errors.go +++ b/core/errors/errors.go @@ -48,11 +48,17 @@ func NewFromErr(err error) *Error { return &errp } -// Equal returns true if "e" and "e2" are matched, by their IDs. -// It will always returns true if the "e2" is a children of "e" +// Equal returns true if "e" and "to" are matched, by their IDs if it's a core/errors type otherwise it tries to match their error messages. +// It will always returns true if the "to" is a children of "e" // or the error messages are exactly the same, otherwise false. -func (e Error) Equal(e2 Error) bool { - return e.ID == e2.ID || e.Error() == e2.Error() +func (e Error) Equal(to error) bool { + if e2, ok := to.(Error); ok { + return e.ID == e2.ID + } else if e2, ok := to.(*Error); ok { + return e.ID == e2.ID + } + + return e.Error() == to.Error() } // Empty returns true if the "e" Error has no message on its stack. diff --git a/sessions/database.go b/sessions/database.go index 2589c611..01830cfb 100644 --- a/sessions/database.go +++ b/sessions/database.go @@ -4,9 +4,14 @@ import ( "sync" "time" + "github.com/kataras/iris/core/errors" "github.com/kataras/iris/core/memstore" ) +// ErrNotImplemented is returned when a particular feature is not yet implemented yet. +// It can be matched directly, i.e: `isNotImplementedError := sessions.ErrNotImplemented.Equal(err)`. +var ErrNotImplemented = errors.New("not implemented yet") + // Database is the interface which all session databases should implement // By design it doesn't support any type of cookie session like other frameworks. // I want to protect you, believe me. @@ -20,6 +25,16 @@ type Database interface { // Acquire receives a session's lifetime from the database, // if the return value is LifeTime{} then the session manager sets the life time based on the expiration duration lives in configuration. Acquire(sid string, expires time.Duration) LifeTime + // OnUpdateExpiration should re-set the expiration (ttl) of the session entry inside the database, + // it is fired on `ShiftExpiration` and `UpdateExpiration`. + // If the database does not support change of ttl then the session entry will be cloned to another one + // and the old one will be removed, it depends on the chosen database storage. + // + // Check of error is required, if error returned then the rest session's keys are not proceed. + // + // Currently only "redis" database is designed to use that event. + // If A database is not support this feature then a `ErrNotImplemented` will be returned instead. + OnUpdateExpiration(sid string, newExpires time.Duration) error // Set sets a key value of a specific session. // The "immutable" input argument depends on the store, it may not implement it at all. Set(sid string, lifetime LifeTime, key string, value interface{}, immutable bool) @@ -54,6 +69,9 @@ func (s *mem) Acquire(sid string, expires time.Duration) LifeTime { return LifeTime{} } +// Do nothing, the `LifeTime` of the Session will be managed by the callers automatically on memory-based storage. +func (s *mem) OnUpdateExpiration(string, time.Duration) error { return nil } + // immutable depends on the store, it may not implement it at all. func (s *mem) Set(sid string, lifetime LifeTime, key string, value interface{}, immutable bool) { s.mu.RLock() diff --git a/sessions/provider.go b/sessions/provider.go index 7d503ec2..0c1f6c4b 100644 --- a/sessions/provider.go +++ b/sessions/provider.go @@ -3,6 +3,8 @@ package sessions import ( "sync" "time" + + "github.com/kataras/iris/core/errors" ) type ( @@ -76,23 +78,32 @@ func (p *provider) Init(sid string, expires time.Duration) *Session { return newSession } -// UpdateExpiration update expire date of a session. -// if expires > 0 then it updates the destroy task. -// if expires <=0 then it does nothing, to destroy a session call the `Destroy` func instead. -func (p *provider) UpdateExpiration(sid string, expires time.Duration) bool { +// ErrNotFound can be returned when calling `UpdateExpiration` on a non-existing or invalid session entry. +// It can be matched directly, i.e: `isErrNotFound := sessions.ErrNotFound.Equal(err)`. +var ErrNotFound = errors.New("not found") + +// UpdateExpiration resets the expiration of a session. +// if expires > 0 then it will try to update the expiration and destroy task is delayed. +// if expires <= 0 then it does nothing it returns nil, to destroy a session call the `Destroy` func instead. +// +// If the session is not found, it returns a `NotFound` error, this can only happen when you restart the server and you used the memory-based storage(default), +// because the call of the provider's `UpdateExpiration` is always called when the client has a valid session cookie. +// +// If a backend database is used then it may return an `ErrNotImplemented` error if the underline database does not support this operation. +func (p *provider) UpdateExpiration(sid string, expires time.Duration) error { if expires <= 0 { - return false + return nil } p.mu.Lock() sess, found := p.sessions[sid] p.mu.Unlock() if !found { - return false + return ErrNotFound } sess.Lifetime.Shift(expires) - return true + return p.db.OnUpdateExpiration(sid, expires) } // Read returns the store which sid parameter belongs diff --git a/sessions/sessiondb/badger/database.go b/sessions/sessiondb/badger/database.go index 339f4076..223deafe 100644 --- a/sessions/sessiondb/badger/database.go +++ b/sessions/sessiondb/badger/database.go @@ -102,6 +102,12 @@ func (db *Database) Acquire(sid string, expires time.Duration) sessions.LifeTime return sessions.LifeTime{} // session manager will handle the rest. } +// OnUpdateExpiration not implemented here, yet. +// Note that this error will not be logged, callers should catch it manually. +func (db *Database) OnUpdateExpiration(sid string, newExpires time.Duration) error { + return sessions.ErrNotImplemented +} + var delim = byte('_') func makePrefix(sid string) []byte { diff --git a/sessions/sessiondb/boltdb/database.go b/sessions/sessiondb/boltdb/database.go index 23edc0b2..fd241cf6 100644 --- a/sessions/sessiondb/boltdb/database.go +++ b/sessions/sessiondb/boltdb/database.go @@ -210,6 +210,12 @@ func (db *Database) Acquire(sid string, expires time.Duration) (lifetime session return } +// OnUpdateExpiration not implemented here, yet. +// Note that this error will not be logged, callers should catch it manually. +func (db *Database) OnUpdateExpiration(sid string, newExpires time.Duration) error { + return sessions.ErrNotImplemented +} + func makeKey(key string) []byte { return []byte(key) } diff --git a/sessions/sessiondb/redis/database.go b/sessions/sessiondb/redis/database.go index 51981b5a..6477af76 100644 --- a/sessions/sessiondb/redis/database.go +++ b/sessions/sessiondb/redis/database.go @@ -55,6 +55,12 @@ func (db *Database) Acquire(sid string, expires time.Duration) sessions.LifeTime return sessions.LifeTime{Time: time.Now().Add(time.Duration(seconds) * time.Second)} } +// OnUpdateExpiration will re-set the database's session's entry ttl. +// https://redis.io/commands/expire#refreshing-expires +func (db *Database) OnUpdateExpiration(sid string, newExpires time.Duration) error { + return db.redis.UpdateTTLMany(sid, int64(newExpires.Seconds())) +} + const delim = "_" func makeKey(sid, key string) string { diff --git a/sessions/sessiondb/redis/service/service.go b/sessions/sessiondb/redis/service/service.go index ecb059ce..186368d3 100644 --- a/sessions/sessiondb/redis/service/service.go +++ b/sessions/sessiondb/redis/service/service.go @@ -99,6 +99,65 @@ func (r *Service) TTL(key string) (seconds int64, hasExpiration bool, ok bool) { return } +func (r *Service) updateTTLConn(c redis.Conn, key string, newSecondsLifeTime int64) error { + reply, err := c.Do("EXPIRE", r.Config.Prefix+key, newSecondsLifeTime) + if err != nil { + return err + } + + // https://redis.io/commands/expire#return-value + // + // 1 if the timeout was set. + // 0 if key does not exist. + if hadTTLOrExists, ok := reply.(int); ok { + if hadTTLOrExists == 1 { + return nil + } else if hadTTLOrExists == 0 { + return fmt.Errorf("unable to update expiration, the key '%s' was stored without ttl", key) + } // do not check for -1. + } + + return nil +} + +// UpdateTTL will update the ttl of a key. +// Using the "EXPIRE" command. +// Read more at: https://redis.io/commands/expire#refreshing-expires +func (r *Service) UpdateTTL(key string, newSecondsLifeTime int64) error { + c := r.pool.Get() + defer c.Close() + err := c.Err() + if err != nil { + return err + } + + return r.updateTTLConn(c, key, newSecondsLifeTime) +} + +// UpdateTTLMany like `UpdateTTL` but for all keys starting with that "prefix", +// it is a bit faster operation if you need to update all sessions keys (although it can be even faster if we used hash but this will limit other features), +// look the `sessions/Database#OnUpdateExpiration` for example. +func (r *Service) UpdateTTLMany(prefix string, newSecondsLifeTime int64) error { + c := r.pool.Get() + defer c.Close() + if err := c.Err(); err != nil { + return err + } + + keys, err := r.getKeysConn(c, prefix) + if err != nil { + return err + } + + for _, key := range keys { + if err = r.updateTTLConn(c, key, newSecondsLifeTime); err != nil { // fail on first error. + return err + } + } + + return err +} + // GetAll returns all redis entries using the "SCAN" command (2.8+). func (r *Service) GetAll() (interface{}, error) { c := r.pool.Get() @@ -120,15 +179,7 @@ func (r *Service) GetAll() (interface{}, error) { return redisVal, nil } -// GetKeys returns all redis keys using the "SCAN" with MATCH command. -// Read more at: https://redis.io/commands/scan#the-match-option. -func (r *Service) GetKeys(prefix string) ([]string, error) { - c := r.pool.Get() - defer c.Close() - if err := c.Err(); err != nil { - return nil, err - } - +func (r *Service) getKeysConn(c redis.Conn, prefix string) ([]string, error) { if err := c.Send("SCAN", 0, "MATCH", r.Config.Prefix+prefix+"*", "COUNT", 9999999999); err != nil { return nil, err } @@ -155,13 +206,24 @@ func (r *Service) GetKeys(prefix string) ([]string, error) { return keys, nil } - } } return nil, nil } +// GetKeys returns all redis keys using the "SCAN" with MATCH command. +// Read more at: https://redis.io/commands/scan#the-match-option. +func (r *Service) GetKeys(prefix string) ([]string, error) { + c := r.pool.Get() + defer c.Close() + if err := c.Err(); err != nil { + return nil, err + } + + return r.getKeysConn(c, prefix) +} + // GetBytes returns value, err by its key // you can use utils.Deserialize((.GetBytes("yourkey"),&theobject{}) //returns nil and a filled error if something wrong happens diff --git a/sessions/sessions.go b/sessions/sessions.go index b7c84fd8..009f7fca 100644 --- a/sessions/sessions.go +++ b/sessions/sessions.go @@ -90,21 +90,28 @@ func (s *Sessions) Start(ctx context.Context) *Session { // ShiftExpiration move the expire date of a session to a new date // by using session default timeout configuration. -func (s *Sessions) ShiftExpiration(ctx context.Context) { - s.UpdateExpiration(ctx, s.config.Expires) +// It will return `ErrNotImplemented` if a database is used and it does not support this feature, yet. +func (s *Sessions) ShiftExpiration(ctx context.Context) error { + return s.UpdateExpiration(ctx, s.config.Expires) } // UpdateExpiration change expire date of a session to a new date // by using timeout value passed by `expires` receiver. -func (s *Sessions) UpdateExpiration(ctx context.Context, expires time.Duration) { +// It will return `ErrNotFound` when trying to update expiration on a non-existence or not valid session entry. +// It will return `ErrNotImplemented` if a database is used and it does not support this feature, yet. +func (s *Sessions) UpdateExpiration(ctx context.Context, expires time.Duration) error { cookieValue := s.decodeCookieValue(GetCookie(ctx, s.config.Cookie)) - - if cookieValue != "" { - // we should also allow it to expire when the browser closed - if s.provider.UpdateExpiration(cookieValue, expires) || expires == -1 { - s.updateCookie(ctx, cookieValue, expires) - } + if cookieValue == "" { + return ErrNotFound } + + // we should also allow it to expire when the browser closed + err := s.provider.UpdateExpiration(cookieValue, expires) + if err == nil || expires == -1 { + s.updateCookie(ctx, cookieValue, expires) + } + + return err } // DestroyListener is the form of a destroy listener.