From 51e1b44c7203c8ace3a33fbf82de8cf77e7a25ff Mon Sep 17 00:00:00 2001 From: "Gerasimos (Makis) Maropoulos" Date: Wed, 8 Aug 2018 13:22:00 +0300 Subject: [PATCH] websocket: use of sync.Map Former-commit-id: 8ecb1e6f70380195ce916d4dfc3fe8d41c851995 --- _examples/websocket/custom-go-client/main.go | 2 + websocket/server.go | 214 +++++++++---------- 2 files changed, 103 insertions(+), 113 deletions(-) diff --git a/_examples/websocket/custom-go-client/main.go b/_examples/websocket/custom-go-client/main.go index dde7560a..da199b5a 100644 --- a/_examples/websocket/custom-go-client/main.go +++ b/_examples/websocket/custom-go-client/main.go @@ -18,6 +18,8 @@ import ( // WS is the current websocket connection var WS *xwebsocket.Conn +// $ go run main.go server +// $ go run main.go client func main() { if len(os.Args) == 2 && strings.ToLower(os.Args[1]) == "server" { ServerLoop() diff --git a/websocket/server.go b/websocket/server.go index b8f0d487..50b71126 100644 --- a/websocket/server.go +++ b/websocket/server.go @@ -8,75 +8,6 @@ import ( "github.com/gorilla/websocket" ) -type connectionKV struct { - key string // the connection ID - value *connection -} - -type connections []connectionKV - -func (cs *connections) add(key string, value *connection) { - args := *cs - n := len(args) - // check if already id/key exist, if yes replace the conn - for i := 0; i < n; i++ { - kv := &args[i] - if kv.key == key { - kv.value = value - return - } - } - - c := cap(args) - // make the connections slice bigger and put the conn - if c > n { - args = args[:n+1] - kv := &args[n] - kv.key = key - kv.value = value - *cs = args - return - } - // append to the connections slice and put the conn - kv := connectionKV{} - kv.key = key - kv.value = value - *cs = append(args, kv) -} - -func (cs *connections) get(key string) *connection { - args := *cs - n := len(args) - for i := 0; i < n; i++ { - kv := &args[i] - if kv.key == key { - return kv.value - } - } - return nil -} - -// returns the connection which removed and a bool value of found or not -// the connection is useful to fire the disconnect events, we use that form in order to -// make work things faster without the need of get-remove, just -remove should do the job. -func (cs *connections) remove(key string) (*connection, bool) { - args := *cs - n := len(args) - for i := 0; i < n; i++ { - kv := &args[i] - if kv.key == key { - conn := kv.value - // we found the index, - // let's remove the item by appending to the temp and - // after set the pointer of the slice to this temp args - args = append(args[:i], args[i+1:]...) - *cs = args - return conn, true - } - } - return nil, false -} - type ( // ConnectionFunc is the callback which fires when a client/connection is connected to the Server. // Receives one parameter which is the Connection @@ -104,9 +35,9 @@ type ( // To serve the built'n javascript client-side library look the `websocket.ClientHandler`. Server struct { config Config - connections connections + connections sync.Map // key = the Connection ID. rooms map[string][]string // by default a connection is joined to a room which has the connection id as its name - mu sync.RWMutex // for rooms + mu sync.RWMutex // for rooms and connections, they should be in-sync between them as well. onConnectionListeners []ConnectionFunc //connectionPool sync.Pool // sadly we can't make this because the websocket connection is live until is closed. upgrader websocket.Upgrader @@ -121,8 +52,9 @@ type ( func New(cfg Config) *Server { cfg = cfg.Validate() return &Server{ - config: cfg, - rooms: make(map[string][]string, 0), + config: cfg, + connections: sync.Map{}, // ready-to-use, this is not necessary. + rooms: make(map[string][]string), onConnectionListeners: make([]ConnectionFunc, 0), upgrader: websocket.Upgrader{ HandshakeTimeout: cfg.HandshakeTimeout, @@ -189,6 +121,22 @@ func (s *Server) Upgrade(ctx context.Context) Connection { return s.handleConnection(ctx, conn) } +func (s *Server) addConnection(c *connection) { + s.connections.Store(c.id, c) +} + +func (s *Server) getConnection(connID string) (*connection, bool) { + if cValue, ok := s.connections.Load(connID); ok { + // this cast is not necessary, + // we know that we always save a connection, but for good or worse let it be here. + if conn, ok := cValue.(*connection); ok { + return conn, ok + } + } + + return nil, false +} + // wrapConnection wraps an underline connection to an iris websocket connection. // It does NOT starts its writer, reader and event mux, the caller is responsible for that. func (s *Server) handleConnection(ctx context.Context, websocketConn UnderlineConnection) *connection { @@ -197,10 +145,10 @@ func (s *Server) handleConnection(ctx context.Context, websocketConn UnderlineCo // create the new connection c := newConnection(ctx, s, websocketConn, cid) // add the connection to the Server's list - s.connections.add(cid, c) + s.addConnection(c) // join to itself - s.Join(c.ID(), c.ID()) + s.Join(c.id, c.id) return c } @@ -237,8 +185,8 @@ func (s *Server) OnConnection(cb ConnectionFunc) { // useful when you have defined a custom connection id generator (based on a database) // and you want to check if that connection is already connected (on multiple tabs) func (s *Server) IsConnected(connID string) bool { - c := s.connections.get(connID) - return c != nil + _, found := s.getConnection(connID) + return found } // Join joins a websocket client to a room, @@ -320,48 +268,76 @@ func (s *Server) leave(roomName string, connID string) (left bool) { } if left { - // fire the on room leave connection's listeners - s.connections.get(connID).fireOnLeave(roomName) + // fire the on room leave connection's listeners, + // the existence check is not necessary here. + if c, ok := s.getConnection(connID); ok { + c.fireOnLeave(roomName) + } } return } // GetTotalConnections returns the number of total connections -func (s *Server) GetTotalConnections() int { - s.mu.RLock() - l := len(s.connections) - s.mu.RUnlock() - return l +func (s *Server) GetTotalConnections() (n int) { + s.connections.Range(func(k, v interface{}) bool { + n++ + return true + }) + + return n } // GetConnections returns all connections func (s *Server) GetConnections() []Connection { - s.mu.RLock() - conns := make([]Connection, len(s.connections), len(s.connections)) - for i, c := range s.connections { - conns[i] = c.value - } - s.mu.RUnlock() + // first call of Range to get the total length, we don't want to use append or manually grow the list here for many reasons. + length := s.GetTotalConnections() + conns := make([]Connection, length, length) + i := 0 + // second call of Range. + s.connections.Range(func(k, v interface{}) bool { + conn, ok := v.(*connection) + if !ok { + // if for some reason (should never happen), the value is not stored as *connection + // then stop the iteration and don't continue insertion of the result connections + // in order to avoid any issues while end-dev will try to iterate a nil entry. + return false + } + conns[i] = conn + i++ + return true + }) + return conns } // GetConnection returns single connection -func (s *Server) GetConnection(key string) Connection { - return s.connections.get(key) +func (s *Server) GetConnection(connID string) Connection { + conn, ok := s.getConnection(connID) + if !ok { + return nil + } + + return conn } // GetConnectionsByRoom returns a list of Connection // which are joined to this room. func (s *Server) GetConnectionsByRoom(roomName string) []Connection { - s.mu.Lock() var conns []Connection + s.mu.RLock() if connIDs, found := s.rooms[roomName]; found { for _, connID := range connIDs { - conns = append(conns, s.connections.get(connID)) + // existence check is not necessary here. + if cValue, ok := s.connections.Load(connID); ok { + if conn, ok := cValue.(*connection); ok { + conns = append(conns, conn) + } + } } - } - s.mu.Unlock() + + s.mu.RUnlock() + return conns } @@ -379,7 +355,7 @@ func (s *Server) emitMessage(from, to string, data []byte) { if s.rooms[to] != nil { // it suppose to send the message to a specific room/or a user inside its own room for _, connectionIDInsideRoom := range s.rooms[to] { - if c := s.connections.get(connectionIDInsideRoom); c != nil { + if c, ok := s.getConnection(connectionIDInsideRoom); ok { c.writeDefault(data) //send the message to the client(s) } else { // the connection is not connected but it's inside the room, we remove it on disconnect but for ANY CASE: @@ -392,19 +368,32 @@ func (s *Server) emitMessage(from, to string, data []byte) { } } } else { - // it suppose to send the message to all opened connections or to all except the sender - for _, cKV := range s.connections { - connID := cKV.key + // it suppose to send the message to all opened connections or to all except the sender. + s.connections.Range(func(k, v interface{}) bool { + connID, ok := k.(string) + if !ok { + // should never happen. + return false + } + if to != All && to != connID { // if it's not suppose to send to all connections (including itself) if to == Broadcast && from == connID { // if broadcast to other connections except this - continue //here we do the opossite of previous block, - // just skip this connection when it's suppose to send the message to all connections except the sender + // here we do the opossite of previous block, + // just skip this connection when it's suppose to send the message to all connections except the sender. + return false } } + + // not necessary cast. + conn, ok := v.(*connection) + if !ok { + return false + } // send to the client(s) when the top validators passed - cKV.value.writeDefault(data) - } + conn.writeDefault(data) + return true + }) } } @@ -421,16 +410,15 @@ func (s *Server) Disconnect(connID string) (err error) { // note: we cannot use that to send data if the client is actually closed. s.LeaveAll(connID) - // remove the connection from the list - if c, ok := s.connections.remove(connID); ok { - if !c.disconnected { - c.disconnected = true + // remove the connection from the list. + if conn, ok := s.getConnection(connID); ok { + conn.disconnected = true + // fire the disconnect callbacks, if any. + conn.fireDisconnect() + // close the underline connection and return its error, if any. + err = conn.underline.Close() - // fire the disconnect callbacks, if any - c.fireDisconnect() - // close the underline connection and return its error, if any. - err = c.underline.Close() - } + s.connections.Delete(connID) } return