package badger import ( "bytes" "errors" "os" "sync/atomic" "time" "github.com/kataras/iris/v12/context" "github.com/kataras/iris/v12/core/memstore" "github.com/kataras/iris/v12/sessions" "github.com/dgraph-io/badger/v2" "github.com/kataras/golog" ) // DefaultFileMode used as the default database's "fileMode" // for creating the sessions directory path, opening and write the session file. var ( DefaultFileMode = 0755 ) // Database the badger(key-value file-based) session storage. type Database struct { // Service is the underline badger database connection, // it's initialized at `New` or `NewFromDB`. // Can be used to get stats. Service *badger.DB logger *golog.Logger closed uint32 // if 1 is closed. } var _ sessions.Database = (*Database)(nil) // New creates and returns a new badger(key-value file-based) storage // instance based on the "directoryPath". // DirectoryPath should is the directory which the badger database will store the sessions, // i.e ./sessions // // It will remove any old session files. func New(directoryPath string) (*Database, error) { if directoryPath == "" { return nil, errors.New("directoryPath is empty") } lindex := directoryPath[len(directoryPath)-1] if lindex != os.PathSeparator && lindex != '/' { directoryPath += string(os.PathSeparator) } // create directories if necessary if err := os.MkdirAll(directoryPath, os.FileMode(DefaultFileMode)); err != nil { return nil, err } opts := badger.DefaultOptions(directoryPath) badgerLogger := context.DefaultLogger("sessionsdb.badger").DisableNewLine() opts.Logger = badgerLogger service, err := badger.Open(opts) if err != nil { badgerLogger.Errorf("unable to initialize the badger-based session database: %v\n", err) return nil, err } return NewFromDB(service), nil } // NewFromDB same as `New` but accepts an already-created custom badger connection instead. func NewFromDB(service *badger.DB) *Database { db := &Database{Service: service} // runtime.SetFinalizer(db, closeDB) return db } // SetLogger sets the logger once before server ran. // By default the Iris one is injected. func (db *Database) SetLogger(logger *golog.Logger) { db.logger = logger } // Acquire receives a session's lifetime from the database, // if the return value is LifeTime{} then the session manager sets the life time based on the expiration duration lives in configuration. func (db *Database) Acquire(sid string, expires time.Duration) memstore.LifeTime { txn := db.Service.NewTransaction(true) defer txn.Commit() bsid := makePrefix(sid) item, err := txn.Get(bsid) if err == nil { // found, return the expiration. return memstore.LifeTime{Time: time.Unix(int64(item.ExpiresAt()), 0)} } // not found, create an entry with ttl and return an empty lifetime, session manager will do its job. if err != nil { if err == badger.ErrKeyNotFound { // create it and set the expiration, we don't care about the value there. err = txn.SetEntry(badger.NewEntry(bsid, bsid).WithTTL(expires)) } } if err != nil { db.logger.Error(err) } return memstore.LifeTime{} // session manager will handle the rest. } // OnUpdateExpiration not implemented here, yet. // Note that this error will not be logged, callers should catch it manually. func (db *Database) OnUpdateExpiration(sid string, newExpires time.Duration) error { return sessions.ErrNotImplemented } var delim = byte('_') func makePrefix(sid string) []byte { return append([]byte(sid), delim) } func makeKey(sid, key string) []byte { return append(makePrefix(sid), []byte(key)...) } // Set sets a key value of a specific session. // Ignore the "immutable". func (db *Database) Set(sid string, key string, value interface{}, ttl time.Duration, immutable bool) error { valueBytes, err := sessions.DefaultTranscoder.Marshal(value) if err != nil { db.logger.Error(err) return err } err = db.Service.Update(func(txn *badger.Txn) error { return txn.SetEntry(badger.NewEntry(makeKey(sid, key), valueBytes).WithTTL(ttl)) }) if err != nil { db.logger.Error(err) } return err } // Get retrieves a session value based on the key. func (db *Database) Get(sid string, key string) (value interface{}) { if err := db.Decode(sid, key, &value); err == nil { return value } return nil } // Decode binds the "outPtr" to the value associated to the provided "key". func (db *Database) Decode(sid, key string, outPtr interface{}) error { err := db.Service.View(func(txn *badger.Txn) error { item, err := txn.Get(makeKey(sid, key)) if err != nil { return err } return item.Value(func(valueBytes []byte) error { return sessions.DefaultTranscoder.Unmarshal(valueBytes, outPtr) }) }) if err != nil && err != badger.ErrKeyNotFound { db.logger.Error(err) } return err } // validSessionItem reports whether the current iterator's item key // is a value of the session id "prefix". func validSessionItem(key, prefix []byte) bool { return len(key) > len(prefix) && bytes.Equal(key[0:len(prefix)], prefix) } // Visit loops through all session keys and values. func (db *Database) Visit(sid string, cb func(key string, value interface{})) error { prefix := makePrefix(sid) txn := db.Service.NewTransaction(false) defer txn.Discard() iter := txn.NewIterator(badger.DefaultIteratorOptions) defer iter.Close() for iter.Rewind(); ; iter.Next() { if !iter.Valid() { break } item := iter.Item() key := item.Key() if !validSessionItem(key, prefix) { continue } var value interface{} err := item.Value(func(valueBytes []byte) error { return sessions.DefaultTranscoder.Unmarshal(valueBytes, &value) }) if err != nil { db.logger.Errorf("[sessionsdb.badger.Visit] %v", err) return err } cb(string(bytes.TrimPrefix(key, prefix)), value) } return nil } var iterOptionsNoValues = badger.IteratorOptions{ PrefetchValues: false, PrefetchSize: 100, Reverse: false, AllVersions: false, } // Len returns the length of the session's entries (keys). func (db *Database) Len(sid string) (n int) { prefix := makePrefix(sid) txn := db.Service.NewTransaction(false) iter := txn.NewIterator(iterOptionsNoValues) for iter.Rewind(); ; iter.Next() { if !iter.Valid() { break } if validSessionItem(iter.Item().Key(), prefix) { n++ } } iter.Close() txn.Discard() return } // Delete removes a session key value based on its key. func (db *Database) Delete(sid string, key string) (deleted bool) { txn := db.Service.NewTransaction(true) err := txn.Delete(makeKey(sid, key)) if err != nil { db.logger.Error(err) return false } return txn.Commit() == nil } // Clear removes all session key values but it keeps the session entry. func (db *Database) Clear(sid string) error { prefix := makePrefix(sid) txn := db.Service.NewTransaction(true) defer txn.Commit() iter := txn.NewIterator(iterOptionsNoValues) defer iter.Close() for iter.Rewind(); iter.ValidForPrefix(prefix); iter.Next() { key := iter.Item().Key() if err := txn.Delete(key); err != nil { db.logger.Warnf("Database.Clear: %s: %v", key, err) return err } } return nil } // Release destroys the session, it clears and removes the session entry, // session manager will create a new session ID on the next request after this call. func (db *Database) Release(sid string) error { // clear all $sid-$key. err := db.Clear(sid) if err != nil { return err } // and remove the $sid. txn := db.Service.NewTransaction(true) if err = txn.Delete([]byte(sid)); err != nil { db.logger.Warnf("Database.Release.Delete: %s: %v", sid, err) return err } if err = txn.Commit(); err != nil { db.logger.Debugf("Database.Release.Commit: %s: %v", sid, err) return err } return nil } // Close shutdowns the badger connection. func (db *Database) Close() error { return closeDB(db) } func closeDB(db *Database) error { if atomic.LoadUint32(&db.closed) > 0 { return nil } err := db.Service.Close() if err != nil { db.logger.Warnf("closing the badger connection: %v", err) } else { atomic.StoreUint32(&db.closed, 1) } return err }