From c477251d1f9a63c95e29b5005f9d35cd38eb27ae Mon Sep 17 00:00:00 2001
From: "Gerasimos (Makis) Maropoulos" <kataras2006@hotmail.com>
Date: Tue, 19 Feb 2019 22:49:16 +0200
Subject: [PATCH] improve client test, I think we are OK, both
 gorilla(websocket) and ws(websocket2) have the same API, it's time to combine
 them but first let's give a lower level of api available for users if they
 want to manage the routines by theirselves (i.e on unix they can use netpolls
 manually)

Former-commit-id: 3209a7490939bce913732c1375190b0771ba63ae
---
 _examples/mvc/session-controller/main.go      |  4 +-
 .../go-client-stress-test/client/main.go      | 50 ++++++-----
 .../go-client-stress-test/server/main.go      | 85 ++++++++++++++-----
 _examples/websocket/go-client/client/main.go  |  3 +-
 websocket/connection.go                       | 19 ++---
 websocket2/connection.go                      | 36 ++++----
 websocket2/server.go                          | 51 ++++++++---
 7 files changed, 160 insertions(+), 88 deletions(-)

diff --git a/_examples/mvc/session-controller/main.go b/_examples/mvc/session-controller/main.go
index aa3ea051..272b0abf 100644
--- a/_examples/mvc/session-controller/main.go
+++ b/_examples/mvc/session-controller/main.go
@@ -63,10 +63,10 @@ func newApp() *iris.Application {
 func main() {
 	app := newApp()
 
-	// 1. open the browser (no in private mode)
+	// 1. open the browser
 	// 2. navigate to http://localhost:8080
 	// 3. refresh the page some times
 	// 4. close the browser
-	// 5. re-open the browser and re-play 2.
+	// 5. re-open the browser (if it wasn't in private mode) and re-play 2.
 	app.Run(iris.Addr(":8080"))
 }
diff --git a/_examples/websocket/go-client-stress-test/client/main.go b/_examples/websocket/go-client-stress-test/client/main.go
index 4747feb6..0be23e13 100644
--- a/_examples/websocket/go-client-stress-test/client/main.go
+++ b/_examples/websocket/go-client-stress-test/client/main.go
@@ -2,8 +2,9 @@ package main
 
 import (
 	"bufio"
-	"fmt"
+	"log"
 	"math/rand"
+	"net"
 	"os"
 	"sync"
 	"sync/atomic"
@@ -13,11 +14,11 @@ import (
 )
 
 var (
-	url = "ws://localhost:8080/socket"
+	url = "ws://localhost:8080"
 	f   *os.File
 )
 
-const totalClients = 100000
+const totalClients = 16000 // max depends on the OS.
 
 var connectionFailures uint64
 
@@ -41,6 +42,7 @@ func collectError(op string, err error) {
 }
 
 func main() {
+	log.Println("--------======Running tests...==========--------------")
 	var err error
 	f, err = os.Open("./test.data")
 	if err != nil {
@@ -67,27 +69,27 @@ func main() {
 		wg.Add(1)
 		waitTime := time.Duration(rand.Intn(10)) * time.Millisecond
 		time.Sleep(waitTime)
-		go connect(wg, 10*time.Second+waitTime)
+		go connect(wg, 9*time.Second+waitTime)
 	}
 
 	for i := 0; i < totalClients/4; i++ {
 		wg.Add(1)
-		waitTime := time.Duration(rand.Intn(20)) * time.Millisecond
+		waitTime := time.Duration(rand.Intn(5)) * time.Millisecond
 		time.Sleep(waitTime)
-		go connect(wg, 25*time.Second+waitTime)
+		go connect(wg, 14*time.Second+waitTime)
 	}
 
 	wg.Wait()
-	fmt.Println("--------================--------------")
-	fmt.Printf("execution time [%s]", time.Since(start))
-	fmt.Println()
+
+	log.Printf("execution time [%s]", time.Since(start))
+	log.Println()
 
 	if connectionFailures > 0 {
-		fmt.Printf("Finished with %d/%d connection failures. Please close the server-side manually.\n", connectionFailures, totalClients)
+		log.Printf("Finished with %d/%d connection failures. Please close the server-side manually.\n", connectionFailures, totalClients)
 	}
 
 	if n := len(connectErrors); n > 0 {
-		fmt.Printf("Finished with %d connect errors:\n", n)
+		log.Printf("Finished with %d connect errors:\n", n)
 		var lastErr error
 		var sameC int
 
@@ -96,36 +98,44 @@ func main() {
 				if lastErr.Error() == err.Error() {
 					sameC++
 					continue
+				} else {
+					_, ok := lastErr.(*net.OpError)
+					if ok {
+						if _, ok = err.(*net.OpError); ok {
+							sameC++
+							continue
+						}
+					}
 				}
 			}
 
 			if sameC > 0 {
-				fmt.Printf("and %d more like this...\n", sameC)
+				log.Printf("and %d more like this...\n", sameC)
 				sameC = 0
 				continue
 			}
 
-			fmt.Printf("[%d] - %v\n", i+1, err)
+			log.Printf("[%d] - %v\n", i+1, err)
 			lastErr = err
 		}
 	}
 
 	if n := len(disconnectErrors); n > 0 {
-		fmt.Printf("Finished with %d disconnect errors\n", n)
+		log.Printf("Finished with %d disconnect errors\n", n)
 		for i, err := range disconnectErrors {
 			if err == websocket.ErrAlreadyDisconnected {
 				continue
 			}
 
-			fmt.Printf("[%d] - %v\n", i+1, err)
+			log.Printf("[%d] - %v\n", i+1, err)
 		}
 	}
 
 	if connectionFailures == 0 && len(connectErrors) == 0 && len(disconnectErrors) == 0 {
-		fmt.Println("ALL OK.")
+		log.Println("ALL OK.")
 	}
 
-	fmt.Println("--------================--------------")
+	log.Println("--------================--------------")
 }
 
 func connect(wg *sync.WaitGroup, alive time.Duration) {
@@ -138,17 +148,17 @@ func connect(wg *sync.WaitGroup, alive time.Duration) {
 	}
 
 	c.OnError(func(err error) {
-		fmt.Printf("error: %v", err)
+		log.Printf("error: %v", err)
 	})
 
 	disconnected := false
 	c.OnDisconnect(func() {
-		fmt.Printf("I am disconnected after [%s].\n", alive)
+		// log.Printf("I am disconnected after [%s].\n", alive)
 		disconnected = true
 	})
 
 	c.On("chat", func(message string) {
-		fmt.Printf("\n%s\n", message)
+		// log.Printf("\n%s\n", message)
 	})
 
 	go func() {
diff --git a/_examples/websocket/go-client-stress-test/server/main.go b/_examples/websocket/go-client-stress-test/server/main.go
index 07ea3014..a0cf6668 100644
--- a/_examples/websocket/go-client-stress-test/server/main.go
+++ b/_examples/websocket/go-client-stress-test/server/main.go
@@ -1,8 +1,9 @@
 package main
 
 import (
-	"fmt"
+	"log"
 	"os"
+	"runtime"
 	"sync/atomic"
 	"time"
 
@@ -10,38 +11,84 @@ import (
 	"github.com/kataras/iris/websocket2"
 )
 
-const totalClients = 100000
+const totalClients = 16000 // max depends on the OS.
+const http = true
 
 func main() {
-	app := iris.New()
 
-	// websocket.Config{PingPeriod: ((60 * time.Second) * 9) / 10}
 	ws := websocket.New(websocket.Config{})
 	ws.OnConnection(handleConnection)
-	app.Get("/socket", ws.Handler())
+
+	// websocket.Config{PingPeriod: ((60 * time.Second) * 9) / 10}
 
 	go func() {
-		t := time.NewTicker(2 * time.Second)
+		dur := 8 * time.Second
+		if totalClients >= 64000 {
+			// if more than 64000 then let's no check every 8 seconds, let's do it every 24 seconds,
+			// just for simplicity, either way works.
+			dur = 24 * time.Second
+		}
+		t := time.NewTicker(dur)
+		defer t.Stop()
+		defer os.Exit(0)
+		defer runtime.Goexit()
+
+		var started bool
 		for {
 			<-t.C
 
-			conns := ws.GetConnections()
-			for _, conn := range conns {
-				// fmt.Println(conn.ID())
-				// Do nothing.
-				_ = conn
+			n := ws.GetTotalConnections()
+			if n > 0 {
+				started = true
 			}
 
-			if atomic.LoadUint64(&count) == totalClients {
-				fmt.Println("ALL CLIENTS DISCONNECTED SUCCESSFULLY.")
-				t.Stop()
-				os.Exit(0)
-				return
+			if started {
+				totalConnected := atomic.LoadUint64(&count)
+
+				if totalConnected == totalClients {
+					if n != 0 {
+						log.Println("ALL CLIENTS DISCONNECTED BUT LEFTOVERS ON CONNECTIONS LIST.")
+					} else {
+						log.Println("ALL CLIENTS DISCONNECTED SUCCESSFULLY.")
+					}
+					return
+				} else if n == 0 {
+					log.Printf("%d/%d CLIENTS WERE NOT CONNECTED AT ALL. CHECK YOUR OS NET SETTINGS. ALL OTHER CONNECTED CLIENTS DISCONNECTED SUCCESSFULLY.\n",
+						totalClients-totalConnected, totalClients)
+
+					return
+				}
 			}
 		}
 	}()
 
-	app.Run(iris.Addr(":8080"))
+	if http {
+		app := iris.New()
+		app.Get("/", ws.Handler())
+		app.Run(iris.Addr(":8080"))
+		return
+	}
+
+	// ln, err := net.Listen("tcp", ":8080")
+	// if err != nil {
+	// 	panic(err)
+	// }
+
+	// defer ln.Close()
+	// for {
+	// 	conn, err := ln.Accept()
+	// 	if err != nil {
+	// 		panic(err)
+	// 	}
+
+	// 	go func() {
+	// 		err = ws.HandleConn(conn)
+	// 		if err != nil {
+	// 			panic(err)
+	// 		}
+	// 	}()
+	// }
+
 }
 
 func handleConnection(c websocket.Connection) {
@@ -56,9 +103,9 @@ var count uint64
 
 func handleDisconnect(c websocket.Connection) {
 	atomic.AddUint64(&count, 1)
-	fmt.Printf("client [%s] disconnected!\n", c.ID())
+	// log.Printf("client [%s] disconnected!\n", c.ID())
 }
 
 func handleErr(c websocket.Connection, err error) {
-	fmt.Printf("client [%s] errored: %v\n", c.ID(), err)
+	log.Printf("client [%s] errored: %v\n", c.ID(), err)
 }
diff --git a/_examples/websocket/go-client/client/main.go b/_examples/websocket/go-client/client/main.go
index 2d1ded6d..6d3f744d 100644
--- a/_examples/websocket/go-client/client/main.go
+++ b/_examples/websocket/go-client/client/main.go
@@ -21,8 +21,7 @@ $ go run main.go
 >> hi!
 */
 func main() {
-	// `websocket.DialContext` is also available.
-	c, err := websocket.Dial(url, websocket.ConnectionConfig{})
+	c, err := websocket.Dial(nil, url, websocket.ConnectionConfig{})
 	if err != nil {
 		panic(err)
 	}
diff --git a/websocket/connection.go b/websocket/connection.go
index 983fe209..8987ec3e 100644
--- a/websocket/connection.go
+++ b/websocket/connection.go
@@ -197,11 +197,6 @@ type (
 		// Note: the callback(s) called right before the server deletes the connection from the room
 		// so the connection theoretical can still send messages to its room right before it is being disconnected.
 		OnLeave(roomLeaveCb LeaveRoomFunc)
-		// Wait starts the pinger and the messages reader,
-		// it's named as "Wait" because it should be called LAST,
-		// after the "On" events IF server's `Upgrade` is used,
-		// otherise you don't have to call it because the `Handler()` does it automatically.
-		Wait()
 		// SetValue sets a key-value pair on the connection's mem store.
 		SetValue(key string, value interface{})
 		// GetValue gets a value by its key from the connection's mem store.
@@ -239,6 +234,11 @@ type (
 		// Disconnect disconnects the client, close the underline websocket conn and removes it from the conn list
 		// returns the error, if any, from the underline connection
 		Disconnect() error
+		// Wait starts the pinger and the messages reader,
+		// it's named as "Wait" because it should be called LAST,
+		// after the "On" events IF server's `Upgrade` is used,
+		// otherise you don't have to call it because the `Handler()` does it automatically.
+		Wait()
 	}
 
 	connection struct {
@@ -792,7 +792,7 @@ func (c ConnectionConfig) Validate() ConnectionConfig {
 // invalid.
 var ErrBadHandshake = websocket.ErrBadHandshake
 
-// DialContext creates a new client connection.
+// Dial creates a new client connection.
 //
 // The context will be used in the request and in the Dialer.
 //
@@ -803,7 +803,7 @@ var ErrBadHandshake = websocket.ErrBadHandshake
 // open socket of the server, i.e ws://localhost:8080/my_websocket_endpoint.
 //
 // Custom dialers can be used by wrapping the iris websocket connection via `websocket.WrapConnection`.
-func DialContext(ctx stdContext.Context, url string, cfg ConnectionConfig) (ClientConnection, error) {
+func Dial(ctx stdContext.Context, url string, cfg ConnectionConfig) (ClientConnection, error) {
 	if ctx == nil {
 		ctx = stdContext.Background()
 	}
@@ -822,8 +822,3 @@ func DialContext(ctx stdContext.Context, url string, cfg ConnectionConfig) (Clie
 
 	return clientConn, nil
 }
-
-// Dial creates a new client connection by calling `DialContext` with a background context.
-func Dial(url string, cfg ConnectionConfig) (ClientConnection, error) {
-	return DialContext(stdContext.Background(), url, cfg)
-}
diff --git a/websocket2/connection.go b/websocket2/connection.go
index 840d36e4..2547cf5a 100644
--- a/websocket2/connection.go
+++ b/websocket2/connection.go
@@ -197,11 +197,7 @@ type (
 		// Note: the callback(s) called right before the server deletes the connection from the room
 		// so the connection theoretical can still send messages to its room right before it is being disconnected.
 		OnLeave(roomLeaveCb LeaveRoomFunc)
-		// Wait starts the pinger and the messages reader,
-		// it's named as "Wait" because it should be called LAST,
-		// after the "On" events IF server's `Upgrade` is used,
-		// otherise you don't have to call it because the `Handler()` does it automatically.
-		Wait()
+
 		// SetValue sets a key-value pair on the connection's mem store.
 		SetValue(key string, value interface{})
 		// GetValue gets a value by its key from the connection's mem store.
@@ -239,6 +235,11 @@ type (
 		// Disconnect disconnects the client, close the underline websocket conn and removes it from the conn list
 		// returns the error, if any, from the underline connection
 		Disconnect() error
+		// Wait starts the pinger and the messages reader,
+		// it's named as "Wait" because it should be called LAST,
+		// after the "On" events IF server's `Upgrade` is used,
+		// otherise you don't have to call it because the `Handler()` does it automatically.
+		Wait() error
 	}
 
 	connection struct {
@@ -454,7 +455,7 @@ func (c *connection) isErrClosed(err error) bool {
 	return err != io.EOF
 }
 
-func (c *connection) startReader() {
+func (c *connection) startReader() error {
 	defer c.Disconnect()
 
 	hasReadTimeout := c.config.ReadTimeout > 0
@@ -476,25 +477,25 @@ func (c *connection) startReader() {
 
 		hdr, err := rd.NextFrame()
 		if err != nil {
-			return
+			return err
 		}
 		if hdr.OpCode.IsControl() {
 			if err := controlHandler(hdr, &rd); err != nil {
-				return
+				return err
 			}
 			continue
 		}
 
 		if hdr.OpCode&TextMessage == 0 && hdr.OpCode&BinaryMessage == 0 {
 			if err := rd.Discard(); err != nil {
-				return
+				return err
 			}
 			continue
 		}
 
 		data, err := ioutil.ReadAll(&rd)
 		if err != nil {
-			return
+			return err
 		}
 
 		c.messageReceived(data)
@@ -575,7 +576,6 @@ func (c *connection) startReader() {
 
 		// c.messageReceived(data)
 	}
-
 }
 
 // messageReceived checks the incoming message and fire the nativeMessage listeners or the event listeners (ws custom message)
@@ -747,16 +747,16 @@ func (c *connection) fireOnLeave(roomName string) {
 // it's named as "Wait" because it should be called LAST,
 // after the "On" events IF server's `Upgrade` is used,
 // otherise you don't have to call it because the `Handler()` does it automatically.
-func (c *connection) Wait() {
+func (c *connection) Wait() error {
 	if c.started {
-		return
+		return nil
 	}
 	c.started = true
 	// start the ping
 	c.startPinger()
 
 	// start the messages reader
-	c.startReader()
+	return c.startReader()
 }
 
 // ErrAlreadyDisconnected can be reported on the `Connection#Disconnect` function whenever the caller tries to close the
@@ -912,13 +912,7 @@ var ErrBadHandshake = ws.ErrHandshakeBadConnection
 //
 // Custom dialers can be used by wrapping the iris websocket connection via `websocket.WrapConnection`.
 func Dial(ctx stdContext.Context, url string, cfg ConnectionConfig) (ClientConnection, error) {
-	c, err := dial(ctx, url, cfg)
-	if err != nil {
-		time.Sleep(1 * time.Second)
-		c, err = dial(ctx, url, cfg)
-	}
-
-	return c, err
+	return dial(ctx, url, cfg)
 }
 
 func dial(ctx stdContext.Context, url string, cfg ConnectionConfig) (ClientConnection, error) {
diff --git a/websocket2/server.go b/websocket2/server.go
index e4a2df8d..8adfd32e 100644
--- a/websocket2/server.go
+++ b/websocket2/server.go
@@ -50,7 +50,8 @@ type (
 		mu                    sync.RWMutex        // for rooms.
 		onConnectionListeners []ConnectionFunc
 		//connectionPool        sync.Pool // sadly we can't make this because the websocket connection is live until is closed.
-		upgrader ws.HTTPUpgrader
+		httpUpgrader ws.HTTPUpgrader
+		tcpUpgrader  ws.Upgrader
 	}
 )
 
@@ -67,7 +68,8 @@ func New(cfg Config) *Server {
 		connections:           sync.Map{}, // ready-to-use, this is not necessary.
 		rooms:                 make(map[string][]string),
 		onConnectionListeners: make([]ConnectionFunc, 0),
-		upgrader:              ws.DefaultHTTPUpgrader, // ws.DefaultUpgrader,
+		httpUpgrader:          ws.DefaultHTTPUpgrader, // ws.DefaultUpgrader,
+		tcpUpgrader:           ws.DefaultUpgrader,
 	}
 }
 
@@ -115,7 +117,7 @@ func (s *Server) Handler() context.Handler {
 // This one does not starts the connection's writer and reader, so after your `On/OnMessage` events registration
 // the caller has to call the `Connection#Wait` function, otherwise the connection will be not handled.
 func (s *Server) Upgrade(ctx context.Context) Connection {
-	conn, _, _, err := s.upgrader.Upgrade(ctx.Request(), ctx.ResponseWriter())
+	conn, _, _, err := s.httpUpgrader.Upgrade(ctx.Request(), ctx.ResponseWriter())
 	if err != nil {
 		ctx.Application().Logger().Warnf("websocket error: %v\n", err)
 		ctx.StatusCode(503) // Status Service Unavailable
@@ -125,6 +127,37 @@ func (s *Server) Upgrade(ctx context.Context) Connection {
 	return s.handleConnection(ctx, conn)
 }
 
+func (s *Server) ZeroUpgrade(conn net.Conn) Connection {
+	_, err := s.tcpUpgrader.Upgrade(conn)
+	if err != nil {
+		return &connection{err: err}
+	}
+
+	return s.handleConnection(nil, conn)
+}
+
+func (s *Server) HandleConn(conn net.Conn) error {
+	c := s.ZeroUpgrade(conn)
+	if c.Err() != nil {
+		return c.Err()
+	}
+
+	// NOTE TO ME: fire these first BEFORE startReader and startPinger
+	// in order to set the events and any messages to send
+	// the startPinger will send the OK to the client and only
+	// then the client is able to send and receive from Server
+	// when all things are ready and only then. DO NOT change this order.
+
+	// fire the on connection event callbacks, if any
+	for i := range s.onConnectionListeners {
+		s.onConnectionListeners[i](c)
+	}
+
+	// start the ping and the messages reader
+	c.Wait()
+	return nil
+}
+
 func (s *Server) addConnection(c *connection) {
 	s.connections.Store(c.id, c)
 }
@@ -292,12 +325,7 @@ func (s *Server) GetTotalConnections() (n int) {
 }
 
 // GetConnections returns all connections
-func (s *Server) GetConnections() []Connection {
-	// first call of Range to get the total length, we don't want to use append or manually grow the list here for many reasons.
-	length := s.GetTotalConnections()
-	conns := make([]Connection, length, length)
-	i := 0
-	// second call of Range.
+func (s *Server) GetConnections() (conns []Connection) {
 	s.connections.Range(func(k, v interface{}) bool {
 		conn, ok := v.(*connection)
 		if !ok {
@@ -306,12 +334,11 @@ func (s *Server) GetConnections() []Connection {
 			// in order to avoid any issues while end-dev will try to iterate a nil entry.
 			return false
 		}
-		conns[i] = conn
-		i++
+		conns = append(conns, conn)
 		return true
 	})
 
-	return conns
+	return
 }
 
 // GetConnection returns single connection