mirror of
https://github.com/kataras/iris.git
synced 2025-02-02 15:30:36 +01:00
websocket: use of sync.Map
Former-commit-id: 8ecb1e6f70380195ce916d4dfc3fe8d41c851995
This commit is contained in:
parent
43fd73eab9
commit
51e1b44c72
|
@ -18,6 +18,8 @@ import (
|
||||||
// WS is the current websocket connection
|
// WS is the current websocket connection
|
||||||
var WS *xwebsocket.Conn
|
var WS *xwebsocket.Conn
|
||||||
|
|
||||||
|
// $ go run main.go server
|
||||||
|
// $ go run main.go client
|
||||||
func main() {
|
func main() {
|
||||||
if len(os.Args) == 2 && strings.ToLower(os.Args[1]) == "server" {
|
if len(os.Args) == 2 && strings.ToLower(os.Args[1]) == "server" {
|
||||||
ServerLoop()
|
ServerLoop()
|
||||||
|
|
|
@ -8,75 +8,6 @@ import (
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
type connectionKV struct {
|
|
||||||
key string // the connection ID
|
|
||||||
value *connection
|
|
||||||
}
|
|
||||||
|
|
||||||
type connections []connectionKV
|
|
||||||
|
|
||||||
func (cs *connections) add(key string, value *connection) {
|
|
||||||
args := *cs
|
|
||||||
n := len(args)
|
|
||||||
// check if already id/key exist, if yes replace the conn
|
|
||||||
for i := 0; i < n; i++ {
|
|
||||||
kv := &args[i]
|
|
||||||
if kv.key == key {
|
|
||||||
kv.value = value
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c := cap(args)
|
|
||||||
// make the connections slice bigger and put the conn
|
|
||||||
if c > n {
|
|
||||||
args = args[:n+1]
|
|
||||||
kv := &args[n]
|
|
||||||
kv.key = key
|
|
||||||
kv.value = value
|
|
||||||
*cs = args
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// append to the connections slice and put the conn
|
|
||||||
kv := connectionKV{}
|
|
||||||
kv.key = key
|
|
||||||
kv.value = value
|
|
||||||
*cs = append(args, kv)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cs *connections) get(key string) *connection {
|
|
||||||
args := *cs
|
|
||||||
n := len(args)
|
|
||||||
for i := 0; i < n; i++ {
|
|
||||||
kv := &args[i]
|
|
||||||
if kv.key == key {
|
|
||||||
return kv.value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// returns the connection which removed and a bool value of found or not
|
|
||||||
// the connection is useful to fire the disconnect events, we use that form in order to
|
|
||||||
// make work things faster without the need of get-remove, just -remove should do the job.
|
|
||||||
func (cs *connections) remove(key string) (*connection, bool) {
|
|
||||||
args := *cs
|
|
||||||
n := len(args)
|
|
||||||
for i := 0; i < n; i++ {
|
|
||||||
kv := &args[i]
|
|
||||||
if kv.key == key {
|
|
||||||
conn := kv.value
|
|
||||||
// we found the index,
|
|
||||||
// let's remove the item by appending to the temp and
|
|
||||||
// after set the pointer of the slice to this temp args
|
|
||||||
args = append(args[:i], args[i+1:]...)
|
|
||||||
*cs = args
|
|
||||||
return conn, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
type (
|
type (
|
||||||
// ConnectionFunc is the callback which fires when a client/connection is connected to the Server.
|
// ConnectionFunc is the callback which fires when a client/connection is connected to the Server.
|
||||||
// Receives one parameter which is the Connection
|
// Receives one parameter which is the Connection
|
||||||
|
@ -104,9 +35,9 @@ type (
|
||||||
// To serve the built'n javascript client-side library look the `websocket.ClientHandler`.
|
// To serve the built'n javascript client-side library look the `websocket.ClientHandler`.
|
||||||
Server struct {
|
Server struct {
|
||||||
config Config
|
config Config
|
||||||
connections connections
|
connections sync.Map // key = the Connection ID.
|
||||||
rooms map[string][]string // by default a connection is joined to a room which has the connection id as its name
|
rooms map[string][]string // by default a connection is joined to a room which has the connection id as its name
|
||||||
mu sync.RWMutex // for rooms
|
mu sync.RWMutex // for rooms and connections, they should be in-sync between them as well.
|
||||||
onConnectionListeners []ConnectionFunc
|
onConnectionListeners []ConnectionFunc
|
||||||
//connectionPool sync.Pool // sadly we can't make this because the websocket connection is live until is closed.
|
//connectionPool sync.Pool // sadly we can't make this because the websocket connection is live until is closed.
|
||||||
upgrader websocket.Upgrader
|
upgrader websocket.Upgrader
|
||||||
|
@ -121,8 +52,9 @@ type (
|
||||||
func New(cfg Config) *Server {
|
func New(cfg Config) *Server {
|
||||||
cfg = cfg.Validate()
|
cfg = cfg.Validate()
|
||||||
return &Server{
|
return &Server{
|
||||||
config: cfg,
|
config: cfg,
|
||||||
rooms: make(map[string][]string, 0),
|
connections: sync.Map{}, // ready-to-use, this is not necessary.
|
||||||
|
rooms: make(map[string][]string),
|
||||||
onConnectionListeners: make([]ConnectionFunc, 0),
|
onConnectionListeners: make([]ConnectionFunc, 0),
|
||||||
upgrader: websocket.Upgrader{
|
upgrader: websocket.Upgrader{
|
||||||
HandshakeTimeout: cfg.HandshakeTimeout,
|
HandshakeTimeout: cfg.HandshakeTimeout,
|
||||||
|
@ -189,6 +121,22 @@ func (s *Server) Upgrade(ctx context.Context) Connection {
|
||||||
return s.handleConnection(ctx, conn)
|
return s.handleConnection(ctx, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) addConnection(c *connection) {
|
||||||
|
s.connections.Store(c.id, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) getConnection(connID string) (*connection, bool) {
|
||||||
|
if cValue, ok := s.connections.Load(connID); ok {
|
||||||
|
// this cast is not necessary,
|
||||||
|
// we know that we always save a connection, but for good or worse let it be here.
|
||||||
|
if conn, ok := cValue.(*connection); ok {
|
||||||
|
return conn, ok
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
// wrapConnection wraps an underline connection to an iris websocket connection.
|
// wrapConnection wraps an underline connection to an iris websocket connection.
|
||||||
// It does NOT starts its writer, reader and event mux, the caller is responsible for that.
|
// It does NOT starts its writer, reader and event mux, the caller is responsible for that.
|
||||||
func (s *Server) handleConnection(ctx context.Context, websocketConn UnderlineConnection) *connection {
|
func (s *Server) handleConnection(ctx context.Context, websocketConn UnderlineConnection) *connection {
|
||||||
|
@ -197,10 +145,10 @@ func (s *Server) handleConnection(ctx context.Context, websocketConn UnderlineCo
|
||||||
// create the new connection
|
// create the new connection
|
||||||
c := newConnection(ctx, s, websocketConn, cid)
|
c := newConnection(ctx, s, websocketConn, cid)
|
||||||
// add the connection to the Server's list
|
// add the connection to the Server's list
|
||||||
s.connections.add(cid, c)
|
s.addConnection(c)
|
||||||
|
|
||||||
// join to itself
|
// join to itself
|
||||||
s.Join(c.ID(), c.ID())
|
s.Join(c.id, c.id)
|
||||||
|
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
@ -237,8 +185,8 @@ func (s *Server) OnConnection(cb ConnectionFunc) {
|
||||||
// useful when you have defined a custom connection id generator (based on a database)
|
// useful when you have defined a custom connection id generator (based on a database)
|
||||||
// and you want to check if that connection is already connected (on multiple tabs)
|
// and you want to check if that connection is already connected (on multiple tabs)
|
||||||
func (s *Server) IsConnected(connID string) bool {
|
func (s *Server) IsConnected(connID string) bool {
|
||||||
c := s.connections.get(connID)
|
_, found := s.getConnection(connID)
|
||||||
return c != nil
|
return found
|
||||||
}
|
}
|
||||||
|
|
||||||
// Join joins a websocket client to a room,
|
// Join joins a websocket client to a room,
|
||||||
|
@ -320,48 +268,76 @@ func (s *Server) leave(roomName string, connID string) (left bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if left {
|
if left {
|
||||||
// fire the on room leave connection's listeners
|
// fire the on room leave connection's listeners,
|
||||||
s.connections.get(connID).fireOnLeave(roomName)
|
// the existence check is not necessary here.
|
||||||
|
if c, ok := s.getConnection(connID); ok {
|
||||||
|
c.fireOnLeave(roomName)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetTotalConnections returns the number of total connections
|
// GetTotalConnections returns the number of total connections
|
||||||
func (s *Server) GetTotalConnections() int {
|
func (s *Server) GetTotalConnections() (n int) {
|
||||||
s.mu.RLock()
|
s.connections.Range(func(k, v interface{}) bool {
|
||||||
l := len(s.connections)
|
n++
|
||||||
s.mu.RUnlock()
|
return true
|
||||||
return l
|
})
|
||||||
|
|
||||||
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConnections returns all connections
|
// GetConnections returns all connections
|
||||||
func (s *Server) GetConnections() []Connection {
|
func (s *Server) GetConnections() []Connection {
|
||||||
s.mu.RLock()
|
// first call of Range to get the total length, we don't want to use append or manually grow the list here for many reasons.
|
||||||
conns := make([]Connection, len(s.connections), len(s.connections))
|
length := s.GetTotalConnections()
|
||||||
for i, c := range s.connections {
|
conns := make([]Connection, length, length)
|
||||||
conns[i] = c.value
|
i := 0
|
||||||
}
|
// second call of Range.
|
||||||
s.mu.RUnlock()
|
s.connections.Range(func(k, v interface{}) bool {
|
||||||
|
conn, ok := v.(*connection)
|
||||||
|
if !ok {
|
||||||
|
// if for some reason (should never happen), the value is not stored as *connection
|
||||||
|
// then stop the iteration and don't continue insertion of the result connections
|
||||||
|
// in order to avoid any issues while end-dev will try to iterate a nil entry.
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
conns[i] = conn
|
||||||
|
i++
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
return conns
|
return conns
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConnection returns single connection
|
// GetConnection returns single connection
|
||||||
func (s *Server) GetConnection(key string) Connection {
|
func (s *Server) GetConnection(connID string) Connection {
|
||||||
return s.connections.get(key)
|
conn, ok := s.getConnection(connID)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConnectionsByRoom returns a list of Connection
|
// GetConnectionsByRoom returns a list of Connection
|
||||||
// which are joined to this room.
|
// which are joined to this room.
|
||||||
func (s *Server) GetConnectionsByRoom(roomName string) []Connection {
|
func (s *Server) GetConnectionsByRoom(roomName string) []Connection {
|
||||||
s.mu.Lock()
|
|
||||||
var conns []Connection
|
var conns []Connection
|
||||||
|
s.mu.RLock()
|
||||||
if connIDs, found := s.rooms[roomName]; found {
|
if connIDs, found := s.rooms[roomName]; found {
|
||||||
for _, connID := range connIDs {
|
for _, connID := range connIDs {
|
||||||
conns = append(conns, s.connections.get(connID))
|
// existence check is not necessary here.
|
||||||
|
if cValue, ok := s.connections.Load(connID); ok {
|
||||||
|
if conn, ok := cValue.(*connection); ok {
|
||||||
|
conns = append(conns, conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
s.mu.Unlock()
|
|
||||||
|
s.mu.RUnlock()
|
||||||
|
|
||||||
return conns
|
return conns
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -379,7 +355,7 @@ func (s *Server) emitMessage(from, to string, data []byte) {
|
||||||
if s.rooms[to] != nil {
|
if s.rooms[to] != nil {
|
||||||
// it suppose to send the message to a specific room/or a user inside its own room
|
// it suppose to send the message to a specific room/or a user inside its own room
|
||||||
for _, connectionIDInsideRoom := range s.rooms[to] {
|
for _, connectionIDInsideRoom := range s.rooms[to] {
|
||||||
if c := s.connections.get(connectionIDInsideRoom); c != nil {
|
if c, ok := s.getConnection(connectionIDInsideRoom); ok {
|
||||||
c.writeDefault(data) //send the message to the client(s)
|
c.writeDefault(data) //send the message to the client(s)
|
||||||
} else {
|
} else {
|
||||||
// the connection is not connected but it's inside the room, we remove it on disconnect but for ANY CASE:
|
// the connection is not connected but it's inside the room, we remove it on disconnect but for ANY CASE:
|
||||||
|
@ -392,19 +368,32 @@ func (s *Server) emitMessage(from, to string, data []byte) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// it suppose to send the message to all opened connections or to all except the sender
|
// it suppose to send the message to all opened connections or to all except the sender.
|
||||||
for _, cKV := range s.connections {
|
s.connections.Range(func(k, v interface{}) bool {
|
||||||
connID := cKV.key
|
connID, ok := k.(string)
|
||||||
|
if !ok {
|
||||||
|
// should never happen.
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
if to != All && to != connID { // if it's not suppose to send to all connections (including itself)
|
if to != All && to != connID { // if it's not suppose to send to all connections (including itself)
|
||||||
if to == Broadcast && from == connID { // if broadcast to other connections except this
|
if to == Broadcast && from == connID { // if broadcast to other connections except this
|
||||||
continue //here we do the opossite of previous block,
|
// here we do the opossite of previous block,
|
||||||
// just skip this connection when it's suppose to send the message to all connections except the sender
|
// just skip this connection when it's suppose to send the message to all connections except the sender.
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// not necessary cast.
|
||||||
|
conn, ok := v.(*connection)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
// send to the client(s) when the top validators passed
|
// send to the client(s) when the top validators passed
|
||||||
cKV.value.writeDefault(data)
|
conn.writeDefault(data)
|
||||||
}
|
return true
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -421,16 +410,15 @@ func (s *Server) Disconnect(connID string) (err error) {
|
||||||
// note: we cannot use that to send data if the client is actually closed.
|
// note: we cannot use that to send data if the client is actually closed.
|
||||||
s.LeaveAll(connID)
|
s.LeaveAll(connID)
|
||||||
|
|
||||||
// remove the connection from the list
|
// remove the connection from the list.
|
||||||
if c, ok := s.connections.remove(connID); ok {
|
if conn, ok := s.getConnection(connID); ok {
|
||||||
if !c.disconnected {
|
conn.disconnected = true
|
||||||
c.disconnected = true
|
// fire the disconnect callbacks, if any.
|
||||||
|
conn.fireDisconnect()
|
||||||
|
// close the underline connection and return its error, if any.
|
||||||
|
err = conn.underline.Close()
|
||||||
|
|
||||||
// fire the disconnect callbacks, if any
|
s.connections.Delete(connID)
|
||||||
c.fireDisconnect()
|
|
||||||
// close the underline connection and return its error, if any.
|
|
||||||
err = c.underline.Close()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
Loading…
Reference in New Issue
Block a user