diff --git a/sessions/session.go b/sessions/session.go index 3c1c3518..3d9d46c8 100644 --- a/sessions/session.go +++ b/sessions/session.go @@ -15,17 +15,11 @@ type ( // // This is what will be returned when sess := sessions.Start(). Session struct { - sid string - isNew bool - values memstore.Store // here are the real values - // we could set the flash messages inside values but this will bring us more problems - // because of session databases and because of - // users may want to get all sessions and save them or display them - // but without temp values (flash messages) which are removed after fetching. - // so introduce a new field here. - // NOTE: flashes are not managed by third-party, only inside session struct. + sid string + isNew bool + values memstore.Store // here are the session's values, managed by memstore. flashes map[string]*flashMessage - mu sync.RWMutex + mu sync.RWMutex // for flashes. lifetime LifeTime provider *provider } @@ -42,18 +36,15 @@ func (s *Session) ID() string { return s.sid } -// IsNew returns true if is's a new session +// IsNew returns true if this session is +// created by the current application's process. func (s *Session) IsNew() bool { return s.isNew } // Get returns a value based on its "key". func (s *Session) Get(key string) interface{} { - s.mu.RLock() - value := s.values.Get(key) - s.mu.RUnlock() - - return value + return s.values.Get(key) } // when running on the session manager removes any 'old' flash messages. @@ -69,7 +60,10 @@ func (s *Session) runFlashGC() { // HasFlash returns true if this session has available flash messages. func (s *Session) HasFlash() bool { - return len(s.flashes) > 0 + s.mu.RLock() + has := len(s.flashes) > 0 + s.mu.RUnlock() + return has } // GetFlash returns a stored flash message based on its "key" @@ -102,45 +96,63 @@ func (s *Session) PeekFlash(key string) interface{} { } func (s *Session) peekFlashMessage(key string) (*flashMessage, bool) { - s.mu.Lock() - if fv, found := s.flashes[key]; found { - return fv, true - } - s.mu.Unlock() + s.mu.RLock() + fv, found := s.flashes[key] + s.mu.RUnlock() - return nil, false + if !found { + return nil, false + } + + return fv, true } -// GetString same as Get but returns as string, if nil then returns an empty string. +// GetString same as Get but returns its string representation, +// if key doesn't exist then it returns an empty string. func (s *Session) GetString(key string) string { + return s.GetStringDefault(key, "") +} + +// GetStringDefault same as Get but returns its string representation, +// if key doesn't exist then it returns the "defaultValue". +func (s *Session) GetStringDefault(key string, defaultValue string) string { if value := s.Get(key); value != nil { if v, ok := value.(string); ok { return v } } - return "" + return defaultValue } -// GetFlashString same as GetFlash but returns as string, if nil then returns an empty string. +// GetFlashString same as `GetFlash` but returns its string representation, +// if key doesn't exist then it returns an empty string. func (s *Session) GetFlashString(key string) string { + return s.GetFlashStringDefault(key, "") +} + +// GetFlashStringDefault same as `GetFlash` but returns its string representation, +// if key doesn't exist then it returns the "defaultValue". +func (s *Session) GetFlashStringDefault(key string, defaultValue string) string { if value := s.GetFlash(key); value != nil { if v, ok := value.(string); ok { return v } } - return "" + return defaultValue } var errFindParse = errors.New("Unable to find the %s with key: %s. Found? %#v") -// GetInt same as Get but returns as int, if not found then returns -1 and an error. +// GetInt same as `Get` but returns its int representation, +// if key doesn't exist then it returns -1. func (s *Session) GetInt(key string) (int, error) { return s.GetIntDefault(key, -1) } -// GetIntDefault same as Get but returns as int, if not found then returns the "defaultValue". +// GetIntDefault same as `Get` but returns its int representation, +// if key doesn't exist then it returns the "defaultValue". func (s *Session) GetIntDefault(key string, defaultValue int) (int, error) { v := s.Get(key) @@ -155,8 +167,15 @@ func (s *Session) GetIntDefault(key string, defaultValue int) (int, error) { return defaultValue, errFindParse.Format("int", key, v) } -// GetInt64 same as Get but returns as int64, if not found then returns -1 and an error. +// GetInt64 same as `Get` but returns its int64 representation, +// if key doesn't exist then it returns -1. func (s *Session) GetInt64(key string) (int64, error) { + return s.GetInt64Default(key, -1) +} + +// GetInt64Default same as `Get` but returns its int64 representation, +// if key doesn't exist it returns the "defaultValue". +func (s *Session) GetInt64Default(key string, defaultValue int64) (int64, error) { v := s.Get(key) if vint64, ok := v.(int64); ok { @@ -171,12 +190,18 @@ func (s *Session) GetInt64(key string) (int64, error) { return strconv.ParseInt(vstring, 10, 64) } - return -1, errFindParse.Format("int64", key, v) - + return defaultValue, errFindParse.Format("int64", key, v) } -// GetFloat32 same as Get but returns as float32, if not found then returns -1 and an error. +// GetFloat32 same as `Get` but returns its float32 representation, +// if key doesn't exist then it returns -1. func (s *Session) GetFloat32(key string) (float32, error) { + return s.GetFloat32Default(key, -1) +} + +// GetFloat32Default same as `Get` but returns its float32 representation, +// if key doesn't exist then it returns the "defaultValue". +func (s *Session) GetFloat32Default(key string, defaultValue float32) (float32, error) { v := s.Get(key) if vfloat32, ok := v.(float32); ok { @@ -199,11 +224,18 @@ func (s *Session) GetFloat32(key string) (float32, error) { return float32(vfloat64), nil } - return -1, errFindParse.Format("float32", key, v) + return defaultValue, errFindParse.Format("float32", key, v) } -// GetFloat64 same as Get but returns as float64, if not found then returns -1 and an error. +// GetFloat64 same as `Get` but returns its float64 representation, +// if key doesn't exist then it returns -1. func (s *Session) GetFloat64(key string) (float64, error) { + return s.GetFloat64Default(key, -1) +} + +// GetFloat64Default same as `Get` but returns its float64 representation, +// if key doesn't exist then it returns the "defaultValue". +func (s *Session) GetFloat64Default(key string, defaultValue float64) (float64, error) { v := s.Get(key) if vfloat32, ok := v.(float32); ok { @@ -222,11 +254,18 @@ func (s *Session) GetFloat64(key string) (float64, error) { return strconv.ParseFloat(vstring, 32) } - return -1, errFindParse.Format("float64", key, v) + return defaultValue, errFindParse.Format("float64", key, v) } -// GetBoolean same as Get but returns as boolean, if not found then returns -1 and an error +// GetBoolean same as `Get` but returns its boolean representation, +// if key doesn't exist then it returns false. func (s *Session) GetBoolean(key string) (bool, error) { + return s.GetBooleanDefault(key, false) +} + +// GetBooleanDefault same as `Get` but returns its boolean representation, +// if key doesn't exist then it returns the "defaultValue". +func (s *Session) GetBooleanDefault(key string, defaultValue bool) (bool, error) { v := s.Get(key) // here we could check for "true", "false" and 0 for false and 1 for true // but this may cause unexpected behavior from the developer if they expecting an error @@ -235,7 +274,7 @@ func (s *Session) GetBoolean(key string) (bool, error) { return vb, nil } - return false, errFindParse.Format("bool", key, v) + return defaultValue, errFindParse.Format("bool", key, v) } // GetAll returns a copy of all session's values. @@ -270,11 +309,11 @@ func (s *Session) VisitAll(cb func(k string, v interface{})) { func (s *Session) set(key string, value interface{}, immutable bool) { action := ActionCreate // defaults to create, means the first insert. - s.mu.Lock() isFirst := s.values.Len() == 0 entry, isNew := s.values.Save(key, value, immutable) - s.isNew = false + s.mu.Lock() + s.isNew = false s.mu.Unlock() if !isFirst {