mirror of
https://github.com/kataras/iris.git
synced 2025-02-02 15:30:36 +01:00
websocket: use of sync.Map
Former-commit-id: 8ecb1e6f70380195ce916d4dfc3fe8d41c851995
This commit is contained in:
parent
43fd73eab9
commit
51e1b44c72
|
@ -18,6 +18,8 @@ import (
|
|||
// WS is the current websocket connection
|
||||
var WS *xwebsocket.Conn
|
||||
|
||||
// $ go run main.go server
|
||||
// $ go run main.go client
|
||||
func main() {
|
||||
if len(os.Args) == 2 && strings.ToLower(os.Args[1]) == "server" {
|
||||
ServerLoop()
|
||||
|
|
|
@ -8,75 +8,6 @@ import (
|
|||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type connectionKV struct {
|
||||
key string // the connection ID
|
||||
value *connection
|
||||
}
|
||||
|
||||
type connections []connectionKV
|
||||
|
||||
func (cs *connections) add(key string, value *connection) {
|
||||
args := *cs
|
||||
n := len(args)
|
||||
// check if already id/key exist, if yes replace the conn
|
||||
for i := 0; i < n; i++ {
|
||||
kv := &args[i]
|
||||
if kv.key == key {
|
||||
kv.value = value
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c := cap(args)
|
||||
// make the connections slice bigger and put the conn
|
||||
if c > n {
|
||||
args = args[:n+1]
|
||||
kv := &args[n]
|
||||
kv.key = key
|
||||
kv.value = value
|
||||
*cs = args
|
||||
return
|
||||
}
|
||||
// append to the connections slice and put the conn
|
||||
kv := connectionKV{}
|
||||
kv.key = key
|
||||
kv.value = value
|
||||
*cs = append(args, kv)
|
||||
}
|
||||
|
||||
func (cs *connections) get(key string) *connection {
|
||||
args := *cs
|
||||
n := len(args)
|
||||
for i := 0; i < n; i++ {
|
||||
kv := &args[i]
|
||||
if kv.key == key {
|
||||
return kv.value
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// returns the connection which removed and a bool value of found or not
|
||||
// the connection is useful to fire the disconnect events, we use that form in order to
|
||||
// make work things faster without the need of get-remove, just -remove should do the job.
|
||||
func (cs *connections) remove(key string) (*connection, bool) {
|
||||
args := *cs
|
||||
n := len(args)
|
||||
for i := 0; i < n; i++ {
|
||||
kv := &args[i]
|
||||
if kv.key == key {
|
||||
conn := kv.value
|
||||
// we found the index,
|
||||
// let's remove the item by appending to the temp and
|
||||
// after set the pointer of the slice to this temp args
|
||||
args = append(args[:i], args[i+1:]...)
|
||||
*cs = args
|
||||
return conn, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
type (
|
||||
// ConnectionFunc is the callback which fires when a client/connection is connected to the Server.
|
||||
// Receives one parameter which is the Connection
|
||||
|
@ -104,9 +35,9 @@ type (
|
|||
// To serve the built'n javascript client-side library look the `websocket.ClientHandler`.
|
||||
Server struct {
|
||||
config Config
|
||||
connections connections
|
||||
connections sync.Map // key = the Connection ID.
|
||||
rooms map[string][]string // by default a connection is joined to a room which has the connection id as its name
|
||||
mu sync.RWMutex // for rooms
|
||||
mu sync.RWMutex // for rooms and connections, they should be in-sync between them as well.
|
||||
onConnectionListeners []ConnectionFunc
|
||||
//connectionPool sync.Pool // sadly we can't make this because the websocket connection is live until is closed.
|
||||
upgrader websocket.Upgrader
|
||||
|
@ -122,7 +53,8 @@ func New(cfg Config) *Server {
|
|||
cfg = cfg.Validate()
|
||||
return &Server{
|
||||
config: cfg,
|
||||
rooms: make(map[string][]string, 0),
|
||||
connections: sync.Map{}, // ready-to-use, this is not necessary.
|
||||
rooms: make(map[string][]string),
|
||||
onConnectionListeners: make([]ConnectionFunc, 0),
|
||||
upgrader: websocket.Upgrader{
|
||||
HandshakeTimeout: cfg.HandshakeTimeout,
|
||||
|
@ -189,6 +121,22 @@ func (s *Server) Upgrade(ctx context.Context) Connection {
|
|||
return s.handleConnection(ctx, conn)
|
||||
}
|
||||
|
||||
func (s *Server) addConnection(c *connection) {
|
||||
s.connections.Store(c.id, c)
|
||||
}
|
||||
|
||||
func (s *Server) getConnection(connID string) (*connection, bool) {
|
||||
if cValue, ok := s.connections.Load(connID); ok {
|
||||
// this cast is not necessary,
|
||||
// we know that we always save a connection, but for good or worse let it be here.
|
||||
if conn, ok := cValue.(*connection); ok {
|
||||
return conn, ok
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// wrapConnection wraps an underline connection to an iris websocket connection.
|
||||
// It does NOT starts its writer, reader and event mux, the caller is responsible for that.
|
||||
func (s *Server) handleConnection(ctx context.Context, websocketConn UnderlineConnection) *connection {
|
||||
|
@ -197,10 +145,10 @@ func (s *Server) handleConnection(ctx context.Context, websocketConn UnderlineCo
|
|||
// create the new connection
|
||||
c := newConnection(ctx, s, websocketConn, cid)
|
||||
// add the connection to the Server's list
|
||||
s.connections.add(cid, c)
|
||||
s.addConnection(c)
|
||||
|
||||
// join to itself
|
||||
s.Join(c.ID(), c.ID())
|
||||
s.Join(c.id, c.id)
|
||||
|
||||
return c
|
||||
}
|
||||
|
@ -237,8 +185,8 @@ func (s *Server) OnConnection(cb ConnectionFunc) {
|
|||
// useful when you have defined a custom connection id generator (based on a database)
|
||||
// and you want to check if that connection is already connected (on multiple tabs)
|
||||
func (s *Server) IsConnected(connID string) bool {
|
||||
c := s.connections.get(connID)
|
||||
return c != nil
|
||||
_, found := s.getConnection(connID)
|
||||
return found
|
||||
}
|
||||
|
||||
// Join joins a websocket client to a room,
|
||||
|
@ -320,48 +268,76 @@ func (s *Server) leave(roomName string, connID string) (left bool) {
|
|||
}
|
||||
|
||||
if left {
|
||||
// fire the on room leave connection's listeners
|
||||
s.connections.get(connID).fireOnLeave(roomName)
|
||||
// fire the on room leave connection's listeners,
|
||||
// the existence check is not necessary here.
|
||||
if c, ok := s.getConnection(connID); ok {
|
||||
c.fireOnLeave(roomName)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetTotalConnections returns the number of total connections
|
||||
func (s *Server) GetTotalConnections() int {
|
||||
s.mu.RLock()
|
||||
l := len(s.connections)
|
||||
s.mu.RUnlock()
|
||||
return l
|
||||
func (s *Server) GetTotalConnections() (n int) {
|
||||
s.connections.Range(func(k, v interface{}) bool {
|
||||
n++
|
||||
return true
|
||||
})
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// GetConnections returns all connections
|
||||
func (s *Server) GetConnections() []Connection {
|
||||
s.mu.RLock()
|
||||
conns := make([]Connection, len(s.connections), len(s.connections))
|
||||
for i, c := range s.connections {
|
||||
conns[i] = c.value
|
||||
// first call of Range to get the total length, we don't want to use append or manually grow the list here for many reasons.
|
||||
length := s.GetTotalConnections()
|
||||
conns := make([]Connection, length, length)
|
||||
i := 0
|
||||
// second call of Range.
|
||||
s.connections.Range(func(k, v interface{}) bool {
|
||||
conn, ok := v.(*connection)
|
||||
if !ok {
|
||||
// if for some reason (should never happen), the value is not stored as *connection
|
||||
// then stop the iteration and don't continue insertion of the result connections
|
||||
// in order to avoid any issues while end-dev will try to iterate a nil entry.
|
||||
return false
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
conns[i] = conn
|
||||
i++
|
||||
return true
|
||||
})
|
||||
|
||||
return conns
|
||||
}
|
||||
|
||||
// GetConnection returns single connection
|
||||
func (s *Server) GetConnection(key string) Connection {
|
||||
return s.connections.get(key)
|
||||
func (s *Server) GetConnection(connID string) Connection {
|
||||
conn, ok := s.getConnection(connID)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
// GetConnectionsByRoom returns a list of Connection
|
||||
// which are joined to this room.
|
||||
func (s *Server) GetConnectionsByRoom(roomName string) []Connection {
|
||||
s.mu.Lock()
|
||||
var conns []Connection
|
||||
s.mu.RLock()
|
||||
if connIDs, found := s.rooms[roomName]; found {
|
||||
for _, connID := range connIDs {
|
||||
conns = append(conns, s.connections.get(connID))
|
||||
// existence check is not necessary here.
|
||||
if cValue, ok := s.connections.Load(connID); ok {
|
||||
if conn, ok := cValue.(*connection); ok {
|
||||
conns = append(conns, conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
s.mu.Unlock()
|
||||
s.mu.RUnlock()
|
||||
|
||||
return conns
|
||||
}
|
||||
|
||||
|
@ -379,7 +355,7 @@ func (s *Server) emitMessage(from, to string, data []byte) {
|
|||
if s.rooms[to] != nil {
|
||||
// it suppose to send the message to a specific room/or a user inside its own room
|
||||
for _, connectionIDInsideRoom := range s.rooms[to] {
|
||||
if c := s.connections.get(connectionIDInsideRoom); c != nil {
|
||||
if c, ok := s.getConnection(connectionIDInsideRoom); ok {
|
||||
c.writeDefault(data) //send the message to the client(s)
|
||||
} else {
|
||||
// the connection is not connected but it's inside the room, we remove it on disconnect but for ANY CASE:
|
||||
|
@ -392,19 +368,32 @@ func (s *Server) emitMessage(from, to string, data []byte) {
|
|||
}
|
||||
}
|
||||
} else {
|
||||
// it suppose to send the message to all opened connections or to all except the sender
|
||||
for _, cKV := range s.connections {
|
||||
connID := cKV.key
|
||||
// it suppose to send the message to all opened connections or to all except the sender.
|
||||
s.connections.Range(func(k, v interface{}) bool {
|
||||
connID, ok := k.(string)
|
||||
if !ok {
|
||||
// should never happen.
|
||||
return false
|
||||
}
|
||||
|
||||
if to != All && to != connID { // if it's not suppose to send to all connections (including itself)
|
||||
if to == Broadcast && from == connID { // if broadcast to other connections except this
|
||||
continue //here we do the opossite of previous block,
|
||||
// just skip this connection when it's suppose to send the message to all connections except the sender
|
||||
// here we do the opossite of previous block,
|
||||
// just skip this connection when it's suppose to send the message to all connections except the sender.
|
||||
return false
|
||||
}
|
||||
|
||||
}
|
||||
// send to the client(s) when the top validators passed
|
||||
cKV.value.writeDefault(data)
|
||||
|
||||
// not necessary cast.
|
||||
conn, ok := v.(*connection)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
// send to the client(s) when the top validators passed
|
||||
conn.writeDefault(data)
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -421,16 +410,15 @@ func (s *Server) Disconnect(connID string) (err error) {
|
|||
// note: we cannot use that to send data if the client is actually closed.
|
||||
s.LeaveAll(connID)
|
||||
|
||||
// remove the connection from the list
|
||||
if c, ok := s.connections.remove(connID); ok {
|
||||
if !c.disconnected {
|
||||
c.disconnected = true
|
||||
|
||||
// fire the disconnect callbacks, if any
|
||||
c.fireDisconnect()
|
||||
// remove the connection from the list.
|
||||
if conn, ok := s.getConnection(connID); ok {
|
||||
conn.disconnected = true
|
||||
// fire the disconnect callbacks, if any.
|
||||
conn.fireDisconnect()
|
||||
// close the underline connection and return its error, if any.
|
||||
err = c.underline.Close()
|
||||
}
|
||||
err = conn.underline.Close()
|
||||
|
||||
s.connections.Delete(connID)
|
||||
}
|
||||
|
||||
return
|
||||
|
|
Loading…
Reference in New Issue
Block a user