From 8c4d52e33bd115d765b211d827723fe0bcfa55a1 Mon Sep 17 00:00:00 2001 From: lixizan Date: Thu, 27 Apr 2023 16:47:03 +0800 Subject: [PATCH 1/7] add onconnect event for server --- updrader.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/updrader.go b/updrader.go index e10a0319..53b7c822 100644 --- a/updrader.go +++ b/updrader.go @@ -138,6 +138,10 @@ func (c *Upgrader) doAccept(r *http.Request, netConn net.Conn, br *bufio.Reader) type Server struct { upgrader *Upgrader + // OnConnect 建立连接事件, 用于处理限流, 熔断和安全问题; 返回错误将会断开连接. + // Creates connection events for current limit, fuse and security issues; returning an error will disconnect. + OnConnect func(conn net.Conn) error + // OnError 接收握手过程中产生的错误回调 // Receive error callbacks generated during the handshake OnError func(conn net.Conn, err error) @@ -147,6 +151,7 @@ type Server struct { // create a websocket server func NewServer(eventHandler Event, option *ServerOption) *Server { var c = &Server{upgrader: NewUpgrader(eventHandler, option)} + c.OnConnect = func(conn net.Conn) error { return nil } c.OnError = func(conn net.Conn, err error) {} return c } @@ -236,6 +241,12 @@ func (c *Server) serve(listener net.Listener) error { } go func() { + if err := c.OnConnect(conn); err != nil { + _ = conn.Close() + c.OnError(conn, err) + return + } + br := bufio.NewReaderSize(conn, c.upgrader.option.ReadBufferSize) r, err := c.parseRequest(conn, br) if err != nil { From 9deff18f7a862100a16c50387d8f5b9366b63e3d Mon Sep 17 00:00:00 2001 From: lixizan Date: Thu, 27 Apr 2023 19:03:52 +0800 Subject: [PATCH 2/7] rename Accept => Upgrade Listen => ReadLoop --- README.md | 12 ++++++------ conn.go | 11 +++++++++-- examples/autobahn/client/main.go | 2 +- examples/autobahn/server/main.go | 4 ++-- examples/client/client.go | 2 +- updrader.go | 12 +++++++++--- 6 files changed, 28 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 6f2f804c..9d92e16e 100644 --- a/README.md +++ b/README.md @@ -92,12 +92,12 @@ func main() { }) http.HandleFunc("/connect", func(writer http.ResponseWriter, request *http.Request) { - socket, err := upgrader.Accept(writer, request) + socket, err := upgrader.Upgrade(writer, request) if err != nil { log.Printf("Accept: " + err.Error()) return } - go socket.Listen() + socket.ReadLoop() }) if err := http.ListenAndServe(":3000", nil); err != nil { @@ -135,7 +135,7 @@ func main() { log.Printf(err.Error()) return } - socket.Listen() + socket.ReadLoop() } type WebSocket struct { @@ -175,11 +175,11 @@ func main() { app := gin.New() upgrader := gws.NewUpgrader(new(WebSocket), nil) app.GET("/connect", func(ctx *gin.Context) { - socket, err := upgrader.Accept(ctx.Writer, ctx.Request) + socket, err := upgrader.Upgrade(ctx.Writer, ctx.Request) if err != nil { return } - go upgrader.Listen(socket) + go upgrader.ReadLoop(socket) }) if err := app.Run(":8080"); err != nil { panic(err) @@ -190,7 +190,7 @@ func main() { - HeartBeat ```go -const PingInterval = 5 * time.Second +const PingInterval = 10 * time.Second type Websocket struct { gws.BuiltinEventHandler diff --git a/conn.go b/conn.go index 6a6d48a4..2cdb297c 100644 --- a/conn.go +++ b/conn.go @@ -71,12 +71,19 @@ func serveWebSocket(isServer bool, config *Config, session SessionStorage, netCo return c } -// Listen listening to websocket messages through a dead loop -// 监听websocket消息 +// Listen 监听websocket消息 +// Deprecated: Listen will be deprecated in future versions, please use ReadLoop instead. func (c *Conn) Listen() { + c.ReadLoop() +} + +// ReadLoop start a read message loop +// 启动一个读消息的死循环 +func (c *Conn) ReadLoop() { defer c.conn.Close() c.handler.OnOpen(c) + for { if err := c.readMessage(); err != nil { c.emitError(err) diff --git a/examples/autobahn/client/main.go b/examples/autobahn/client/main.go index d394e2fd..a8a19074 100644 --- a/examples/autobahn/client/main.go +++ b/examples/autobahn/client/main.go @@ -32,7 +32,7 @@ func testCase(id int) { log.Println(err.Error()) return } - go socket.Listen() + go socket.ReadLoop() <-handler.onexit } diff --git a/examples/autobahn/server/main.go b/examples/autobahn/server/main.go index dbb00ff6..7b205389 100644 --- a/examples/autobahn/server/main.go +++ b/examples/autobahn/server/main.go @@ -15,11 +15,11 @@ func main() { }) http.HandleFunc("/connect", func(writer http.ResponseWriter, request *http.Request) { - socket, err := upgrader.Accept(writer, request) + socket, err := upgrader.Upgrade(writer, request) if err != nil { return } - go socket.Listen() + socket.ReadLoop() }) _ = http.ListenAndServe(":3000", nil) diff --git a/examples/client/client.go b/examples/client/client.go index a42a79a9..6e34ddd4 100644 --- a/examples/client/client.go +++ b/examples/client/client.go @@ -15,7 +15,7 @@ func main() { log.Printf(err.Error()) return } - go socket.Listen() + go socket.ReadLoop() for { var text = "" diff --git a/updrader.go b/updrader.go index 53b7c822..84018ed4 100644 --- a/updrader.go +++ b/updrader.go @@ -62,13 +62,19 @@ func (c *Upgrader) connectHandshake(r *http.Request, responseHeader http.Header, } // Accept http upgrade to websocket protocol +// Deprecated: Accept will be deprecated in future versions, please use Upgrade instead. func (c *Upgrader) Accept(w http.ResponseWriter, r *http.Request) (*Conn, error) { + return c.Upgrade(w, r) +} + +// Upgrade http upgrade to websocket protocol +func (c *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request) (*Conn, error) { netConn, br, err := c.hijack(w) if err != nil { return nil, err } - socket, err := c.doAccept(r, netConn, br) + socket, err := c.doUpgrade(r, netConn, br) if err != nil { _ = netConn.Close() return nil, err @@ -93,7 +99,7 @@ func (c *Upgrader) hijack(w http.ResponseWriter) (net.Conn, *bufio.Reader, error return netConn, brw.Reader, nil } -func (c *Upgrader) doAccept(r *http.Request, netConn net.Conn, br *bufio.Reader) (*Conn, error) { +func (c *Upgrader) doUpgrade(r *http.Request, netConn net.Conn, br *bufio.Reader) (*Conn, error) { var session = new(sliceMap) var header = c.option.ResponseHeader.Clone() if !c.option.CheckOrigin(r, session) { @@ -255,7 +261,7 @@ func (c *Server) serve(listener net.Listener) error { return } - socket, err := c.upgrader.doAccept(r, conn, br) + socket, err := c.upgrader.doUpgrade(r, conn, br) if err != nil { _ = conn.Close() c.OnError(conn, err) From d43fb67f4b6345732c0532132f664867340a3591 Mon Sep 17 00:00:00 2001 From: lixizan Date: Fri, 28 Apr 2023 09:09:42 +0800 Subject: [PATCH 3/7] add echo demo --- README.md | 2 +- examples/echo/main.go | 27 +++++++++++++++++++ examples/{wss-server => wss}/cert/server.crt | 0 examples/{wss-server => wss}/cert/server.pem | 0 .../{wss-server/server.go => wss/main.go} | 0 5 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 examples/echo/main.go rename examples/{wss-server => wss}/cert/server.crt (100%) rename examples/{wss-server => wss}/cert/server.pem (100%) rename examples/{wss-server/server.go => wss/main.go} (100%) diff --git a/README.md b/README.md index 9d92e16e..c864cd97 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,7 @@ type Event interface { #### Examples - [chat room](examples/chatroom/main.go) -- [echo](examples/wss-server/server.go) +- [echo](examples/echo/main.go) #### Quick Start diff --git a/examples/echo/main.go b/examples/echo/main.go new file mode 100644 index 00000000..1eef31e5 --- /dev/null +++ b/examples/echo/main.go @@ -0,0 +1,27 @@ +package main + +import ( + "github.com/lxzan/gws" + "log" +) + +func main() { + var app = gws.NewServer(new(Handler), &gws.ServerOption{ + CompressEnabled: true, + CheckUtf8Enabled: true, + }) + log.Fatalf("%v", app.Run(":3000")) +} + +type Handler struct { + gws.BuiltinEventHandler +} + +func (c *Handler) OnPing(socket *gws.Conn, payload []byte) { + socket.WritePong(payload) +} + +func (c *Handler) OnMessage(socket *gws.Conn, message *gws.Message) { + defer message.Close() + _ = socket.WriteMessage(message.Opcode, message.Bytes()) +} diff --git a/examples/wss-server/cert/server.crt b/examples/wss/cert/server.crt similarity index 100% rename from examples/wss-server/cert/server.crt rename to examples/wss/cert/server.crt diff --git a/examples/wss-server/cert/server.pem b/examples/wss/cert/server.pem similarity index 100% rename from examples/wss-server/cert/server.pem rename to examples/wss/cert/server.pem diff --git a/examples/wss-server/server.go b/examples/wss/main.go similarity index 100% rename from examples/wss-server/server.go rename to examples/wss/main.go From 687cad69acd63c46f5a85a798c7ed7bb9e851e3d Mon Sep 17 00:00:00 2001 From: lixizan Date: Fri, 28 Apr 2023 09:58:15 +0800 Subject: [PATCH 4/7] optimize compress write --- compress.go | 28 ++++++++-------------- protocol.go | 6 ++++- writer.go | 69 ++++++++++++++++++++++++++++++----------------------- 3 files changed, 54 insertions(+), 49 deletions(-) diff --git a/compress.go b/compress.go index 72d8290c..2d2d115b 100644 --- a/compress.go +++ b/compress.go @@ -16,33 +16,25 @@ func newCompressor(level int) *compressor { // 压缩器 type compressor struct { - buffer *bytes.Buffer - fw *flate.Writer -} - -func (c *compressor) Close() { - _bpool.Put(c.buffer) - c.buffer = nil + fw *flate.Writer } // Compress 压缩 -func (c *compressor) Compress(content *bytes.Buffer) (*bytes.Buffer, error) { - c.buffer = _bpool.Get(content.Len() / 3) - c.fw.Reset(c.buffer) - if err := internal.WriteN(c.fw, content.Bytes(), content.Len()); err != nil { - return nil, err +func (c *compressor) Compress(content []byte, buf *bytes.Buffer) error { + c.fw.Reset(buf) + if err := internal.WriteN(c.fw, content, len(content)); err != nil { + return err } if err := c.fw.Flush(); err != nil { - return nil, err + return err } - - if n := c.buffer.Len(); n >= 4 { - compressedContent := c.buffer.Bytes() + if n := buf.Len(); n >= 4 { + compressedContent := buf.Bytes() if tail := compressedContent[n-4:]; binary.BigEndian.Uint32(tail) == math.MaxUint16 { - c.buffer.Truncate(n - 4) + buf.Truncate(n - 4) } } - return c.buffer, nil + return nil } func newDecompressor() *decompressor { return &decompressor{fr: flate.NewReader(nil)} } diff --git a/protocol.go b/protocol.go index 7ed71ffb..3381718b 100644 --- a/protocol.go +++ b/protocol.go @@ -7,6 +7,10 @@ import ( "io" ) +const frameHeaderSize = 14 + +var frameHeaderPadding = frameHeader{} + type Opcode uint8 const ( @@ -62,7 +66,7 @@ func (b BuiltinEventHandler) OnPong(socket *Conn, payload []byte) {} func (b BuiltinEventHandler) OnMessage(socket *Conn, message *Message) {} -type frameHeader [internal.FrameHeaderSize]byte +type frameHeader [frameHeaderSize]byte func (c *frameHeader) GetFIN() bool { return ((*c)[0] >> 7) == 1 diff --git a/writer.go b/writer.go index becc5b5b..e3b517fe 100644 --- a/writer.go +++ b/writer.go @@ -1,10 +1,8 @@ package gws import ( - "bytes" "errors" "github.com/lxzan/gws/internal" - "net" ) // WriteClose proactively close the connection @@ -47,50 +45,61 @@ func (c *Conn) WriteMessage(opcode Opcode, payload []byte) error { return err } -// 解锁并回收资源 -// Unlock and recover resources -func (c *Conn) endWrite(compress bool) { - c.wmu.Unlock() - if compress { - c.compressor.Close() - } -} - // 执行写入逻辑, 关闭状态置为1后还能写, 以便发送关闭帧 // Execute the write logic, and write after the close state is set to 1, so that the close frame can be sent func (c *Conn) doWrite(opcode Opcode, payload []byte) error { + c.wmu.Lock() + defer c.wmu.Unlock() + if opcode == OpcodeText && !c.isTextValid(opcode, payload) { return internal.NewError(internal.CloseUnsupportedData, internal.ErrTextEncoding) } var compress = c.compressEnabled && opcode.IsDataFrame() && len(payload) >= c.config.CompressThreshold - c.wmu.Lock() - defer c.endWrite(compress) - - if compress { - compressedContent, err := c.compressor.Compress(bytes.NewBuffer(payload)) - if err != nil { - return internal.NewError(internal.CloseInternalServerErr, err) + if !compress { + var n = len(payload) + if n > c.config.WriteMaxPayloadSize { + return internal.CloseMessageTooLarge + } + var header = frameHeader{} + headerLength, maskBytes := header.GenerateHeader(c.isServer, true, compress, opcode, n) + if !c.isServer { + internal.MaskXOR(payload, maskBytes) } - payload = compressedContent.Bytes() + var totalSize = n + headerLength + var buf = _bpool.Get(totalSize) + buf.Write(header[:headerLength]) + buf.Write(payload) + var err = internal.WriteN(c.conn, buf.Bytes(), totalSize) + _bpool.Put(buf) + return err } - if len(payload) > c.config.WriteMaxPayloadSize { + return c.writeCompressedContents(opcode, payload) +} + +func (c *Conn) writeCompressedContents(opcode Opcode, payload []byte) error { + var buf = _bpool.Get(len(payload) / 3) + defer _bpool.Put(buf) + + buf.Write(frameHeaderPadding[0:]) + if err := c.compressor.Compress(payload, buf); err != nil { + return err + } + + var contents = buf.Bytes() + var payloadSize = buf.Len() - frameHeaderSize + if payloadSize > c.config.WriteMaxPayloadSize { return internal.CloseMessageTooLarge } - var n = len(payload) var header = frameHeader{} - headerLength, maskBytes := header.GenerateHeader(c.isServer, true, compress, opcode, n) + headerLength, maskBytes := header.GenerateHeader(c.isServer, true, true, opcode, payloadSize) + var offset = frameHeaderSize - headerLength if !c.isServer { - internal.MaskXOR(payload, maskBytes) - } - - var buffers = net.Buffers{header[:headerLength], payload} - if n == 0 { - buffers = buffers[:1] + internal.MaskXOR(contents[frameHeaderSize:], maskBytes) } - num, err := buffers.WriteTo(c.conn) - return internal.CheckIOError(headerLength+n, int(num), err) + copy(contents[offset:frameHeaderSize], header[:headerLength]) + return internal.WriteN(c.conn, contents[offset:], payloadSize+headerLength) } // WriteAsync 异步非阻塞地写入消息 From 8093a3259db01d6dd9a40f7336bccb87e68c2498 Mon Sep 17 00:00:00 2001 From: lixizan Date: Fri, 28 Apr 2023 10:23:19 +0800 Subject: [PATCH 5/7] optimize compress write --- protocol.go | 2 -- writer.go | 42 +++++++++++++++++++++--------------------- 2 files changed, 21 insertions(+), 23 deletions(-) diff --git a/protocol.go b/protocol.go index 3381718b..86f67a70 100644 --- a/protocol.go +++ b/protocol.go @@ -9,8 +9,6 @@ import ( const frameHeaderSize = 14 -var frameHeaderPadding = frameHeader{} - type Opcode uint8 const ( diff --git a/writer.go b/writer.go index e3b517fe..c5238301 100644 --- a/writer.go +++ b/writer.go @@ -55,33 +55,34 @@ func (c *Conn) doWrite(opcode Opcode, payload []byte) error { return internal.NewError(internal.CloseUnsupportedData, internal.ErrTextEncoding) } - var compress = c.compressEnabled && opcode.IsDataFrame() && len(payload) >= c.config.CompressThreshold - if !compress { - var n = len(payload) - if n > c.config.WriteMaxPayloadSize { - return internal.CloseMessageTooLarge - } - var header = frameHeader{} - headerLength, maskBytes := header.GenerateHeader(c.isServer, true, compress, opcode, n) - if !c.isServer { - internal.MaskXOR(payload, maskBytes) - } - var totalSize = n + headerLength - var buf = _bpool.Get(totalSize) - buf.Write(header[:headerLength]) - buf.Write(payload) - var err = internal.WriteN(c.conn, buf.Bytes(), totalSize) - _bpool.Put(buf) - return err + if c.compressEnabled && opcode.IsDataFrame() && len(payload) >= c.config.CompressThreshold { + return c.writeCompressedContents(opcode, payload) + } + + var n = len(payload) + if n > c.config.WriteMaxPayloadSize { + return internal.CloseMessageTooLarge + } + var header = frameHeader{} + headerLength, maskBytes := header.GenerateHeader(c.isServer, true, false, opcode, n) + if !c.isServer { + internal.MaskXOR(payload, maskBytes) } - return c.writeCompressedContents(opcode, payload) + var totalSize = n + headerLength + var buf = _bpool.Get(totalSize) + buf.Write(header[:headerLength]) + buf.Write(payload) + var err = internal.WriteN(c.conn, buf.Bytes(), totalSize) + _bpool.Put(buf) + return err } func (c *Conn) writeCompressedContents(opcode Opcode, payload []byte) error { var buf = _bpool.Get(len(payload) / 3) defer _bpool.Put(buf) - buf.Write(frameHeaderPadding[0:]) + var header = frameHeader{} + buf.Write(header[0:]) if err := c.compressor.Compress(payload, buf); err != nil { return err } @@ -92,7 +93,6 @@ func (c *Conn) writeCompressedContents(opcode Opcode, payload []byte) error { return internal.CloseMessageTooLarge } - var header = frameHeader{} headerLength, maskBytes := header.GenerateHeader(c.isServer, true, true, opcode, payloadSize) var offset = frameHeaderSize - headerLength if !c.isServer { From ee79fe16542393f4faed5d6dfb00da597cca843f Mon Sep 17 00:00:00 2001 From: lixizan Date: Fri, 28 Apr 2023 11:13:48 +0800 Subject: [PATCH 6/7] update unit tests --- compress_test.go | 22 +++++++++------------- updrader.go | 2 +- writer_test.go | 5 +++-- 3 files changed, 13 insertions(+), 16 deletions(-) diff --git a/compress_test.go b/compress_test.go index 09d6d8c0..6785c33e 100644 --- a/compress_test.go +++ b/compress_test.go @@ -17,16 +17,14 @@ func TestFlate(t *testing.T) { var dps = newDecompressor() var n = internal.AlphabetNumeric.Intn(1024) var rawText = internal.AlphabetNumeric.Generate(n) - var buf = bytes.NewBufferString("") - buf.Write(rawText) - compressedText, err := cps.Compress(buf) - if err != nil { + var compressedBuf = bytes.NewBufferString("") + if err := cps.Compress(rawText, compressedBuf); err != nil { as.NoError(err) return } - buf.Reset() - buf.Write(compressedText.Bytes()) + var buf = bytes.NewBufferString("") + buf.Write(compressedBuf.Bytes()) plainText, err := dps.Decompress(buf) if err != nil { as.NoError(err) @@ -41,18 +39,16 @@ func TestFlate(t *testing.T) { var dps = newDecompressor() var n = internal.AlphabetNumeric.Intn(1024) var rawText = internal.AlphabetNumeric.Generate(n) - var buf = bytes.NewBufferString("") - buf.Write(rawText) - compressedText, err := cps.Compress(buf) - if err != nil { + var compressedBuf = bytes.NewBufferString("") + if err := cps.Compress(rawText, compressedBuf); err != nil { as.NoError(err) return } - buf.Reset() - buf.Write(compressedText.Bytes()) + var buf = bytes.NewBufferString("") + buf.Write(compressedBuf.Bytes()) buf.WriteString("1234") - _, err = dps.Decompress(buf) + _, err := dps.Decompress(buf) as.Error(err) }) } diff --git a/updrader.go b/updrader.go index 84018ed4..a3600b4a 100644 --- a/updrader.go +++ b/updrader.go @@ -267,7 +267,7 @@ func (c *Server) serve(listener net.Listener) error { c.OnError(conn, err) return } - socket.Listen() + socket.ReadLoop() }() } } diff --git a/writer_test.go b/writer_test.go index 2f84e422..4d7d259e 100644 --- a/writer_test.go +++ b/writer_test.go @@ -15,11 +15,12 @@ func testWrite(c *Conn, fin bool, opcode Opcode, payload []byte) error { var useCompress = c.compressEnabled && opcode.IsDataFrame() && len(payload) >= c.config.CompressThreshold if useCompress { - compressedContent, err := c.compressor.Compress(bytes.NewBuffer(payload)) + var buf = bytes.NewBufferString("") + err := c.compressor.Compress(payload, buf) if err != nil { return internal.NewError(internal.CloseInternalServerErr, err) } - payload = compressedContent.Bytes() + payload = buf.Bytes() } if len(payload) > c.config.WriteMaxPayloadSize { return internal.CloseMessageTooLarge From e03ec8336788aacab3b0535eddbc1287a3d3d09d Mon Sep 17 00:00:00 2001 From: lixizan Date: Fri, 28 Apr 2023 14:30:30 +0800 Subject: [PATCH 7/7] improve xor --- writer.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/writer.go b/writer.go index c5238301..8e402f48 100644 --- a/writer.go +++ b/writer.go @@ -34,8 +34,6 @@ func (c *Conn) WriteString(s string) error { } // WriteMessage 发送消息 -// 如果是客户端, payload内容会因为异或计算而被改变 -// If it is a client, the payload content will be changed due to heterogeneous computation func (c *Conn) WriteMessage(opcode Opcode, payload []byte) error { if c.isClosed() { return internal.ErrConnClosed @@ -63,16 +61,18 @@ func (c *Conn) doWrite(opcode Opcode, payload []byte) error { if n > c.config.WriteMaxPayloadSize { return internal.CloseMessageTooLarge } + var header = frameHeader{} headerLength, maskBytes := header.GenerateHeader(c.isServer, true, false, opcode, n) - if !c.isServer { - internal.MaskXOR(payload, maskBytes) - } var totalSize = n + headerLength var buf = _bpool.Get(totalSize) buf.Write(header[:headerLength]) buf.Write(payload) - var err = internal.WriteN(c.conn, buf.Bytes(), totalSize) + var contents = buf.Bytes() + if !c.isServer { + internal.MaskXOR(contents[headerLength:], maskBytes) + } + var err = internal.WriteN(c.conn, contents, totalSize) _bpool.Put(buf) return err } @@ -94,12 +94,12 @@ func (c *Conn) writeCompressedContents(opcode Opcode, payload []byte) error { } headerLength, maskBytes := header.GenerateHeader(c.isServer, true, true, opcode, payloadSize) - var offset = frameHeaderSize - headerLength if !c.isServer { internal.MaskXOR(contents[frameHeaderSize:], maskBytes) } - copy(contents[offset:frameHeaderSize], header[:headerLength]) - return internal.WriteN(c.conn, contents[offset:], payloadSize+headerLength) + contents = contents[frameHeaderSize-headerLength:] + copy(contents[:headerLength], header[:headerLength]) + return internal.WriteN(c.conn, contents, payloadSize+headerLength) } // WriteAsync 异步非阻塞地写入消息