Skip to content

Commit

Permalink
check text encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
lixizan committed Apr 21, 2023
1 parent 8f36165 commit 0fcd170
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 17 deletions.
15 changes: 14 additions & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"sync"
"sync/atomic"
"time"
"unicode/utf8"
)

type Conn struct {
Expand Down Expand Up @@ -84,6 +85,18 @@ func (c *Conn) Listen() {
}
}

func (c *Conn) isTextValid(opcode Opcode, payload []byte) bool {
if !c.config.CheckUtf8Enabled {
return true
}
switch opcode {
case OpcodeText, OpcodeCloseConnection:
return utf8.Valid(payload)
default:
return true
}
}

func (c *Conn) isClosed() bool {
return atomic.LoadUint32(&c.closed) == 1
}
Expand Down Expand Up @@ -144,7 +157,7 @@ func (c *Conn) emitClose(buf *bytes.Buffer) error {
responseCode = internal.StatusCode(realCode)
}
}
if c.config.CheckUtf8Enabled && !isTextValid(OpcodeCloseConnection, buf.Bytes()) {
if !c.isTextValid(OpcodeCloseConnection, buf.Bytes()) {
responseCode = internal.CloseUnsupportedData
}
}
Expand Down
8 changes: 5 additions & 3 deletions examples/autobahn/client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"time"
)

const remoteAddr = "localhost:9001"

func main() {
const count = 517
for i := 1; i <= count; i++ {
Expand All @@ -16,7 +18,7 @@ func main() {
}

func testCase(id int) {
var url = fmt.Sprintf("ws://localhost:9001/runCase?case=%d&agent=gws/client", id)
var url = fmt.Sprintf("ws://%s/runCase?case=%d&agent=gws/client", remoteAddr, id)
var handler = &WebSocket{onexit: make(chan struct{})}
socket, _, err := gws.NewClient(handler, &gws.ClientOption{
Addr: url,
Expand Down Expand Up @@ -57,7 +59,7 @@ func (c *WebSocket) OnPing(socket *gws.Conn, payload []byte) {
func (c *WebSocket) OnPong(socket *gws.Conn, payload []byte) {}

func (c *WebSocket) OnMessage(socket *gws.Conn, message *gws.Message) {
_ = socket.WriteMessage(message.Opcode, message.Bytes())
_ = socket.WriteAsync(message.Opcode, message.Bytes())
}

type updateReportsHandler struct {
Expand All @@ -78,7 +80,7 @@ func (c *updateReportsHandler) OnClose(socket *gws.Conn, code uint16, reason []b
}

func updateReports() {
var url = fmt.Sprintf("ws://localhost:9001/updateReports?agent=gws/client")
var url = fmt.Sprintf("ws://%s/updateReports?agent=gws/client", remoteAddr)
var handler = &updateReportsHandler{onexit: make(chan struct{})}
socket, _, err := gws.NewClient(handler, &gws.ClientOption{
Addr: url,
Expand Down
7 changes: 7 additions & 0 deletions internal/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,10 @@ func Split(s string, sep string) []string {
func HttpHeaderEqual(a, b string) bool {
return strings.ToLower(a) == strings.ToLower(b)
}

func SelectInt(ok bool, a, b int) int {
if ok {
return a
}
return b
}
5 changes: 5 additions & 0 deletions internal/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,8 @@ func TestHttpHeaderEqual(t *testing.T) {
assert.Equal(t, true, HttpHeaderEqual("WebSocket", "websocket"))
assert.Equal(t, false, HttpHeaderEqual("WebSocket@", "websocket"))
}

func TestSelectInt(t *testing.T) {
assert.Equal(t, 1, SelectInt(true, 1, 2))
assert.Equal(t, 2, SelectInt(false, 1, 2))
}
8 changes: 0 additions & 8 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"encoding/binary"
"github.com/lxzan/gws/internal"
"io"
"unicode/utf8"
)

type Opcode uint8
Expand Down Expand Up @@ -198,13 +197,6 @@ func (c *Message) Close() {
c.Data = nil
}

func isTextValid(opcode Opcode, p []byte) bool {
if len(p) > 0 && (opcode == OpcodeCloseConnection || opcode == OpcodeText) {
return utf8.Valid(p)
}
return true
}

type continuationFrame struct {
initialized bool
compressed bool
Expand Down
2 changes: 1 addition & 1 deletion reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ func (c *Conn) emitMessage(msg *Message, compressed bool) error {
}
msg.Data = data
}
if c.config.CheckUtf8Enabled && !isTextValid(msg.Opcode, msg.Bytes()) {
if !c.isTextValid(msg.Opcode, msg.Bytes()) {
return internal.NewError(internal.CloseUnsupportedData, internal.ErrTextEncoding)
}

Expand Down
4 changes: 2 additions & 2 deletions writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ func (c *Conn) endWrite(compress bool) {
// 执行写入逻辑, 关闭状态置为1后还能写, 以便发送关闭帧
// Execute the write logic, and write after the close state is set to 1, so that the close frame can be sent
func (c *Conn) doWrite(opcode Opcode, payload []byte) error {
if c.config.CheckUtf8Enabled && !isTextValid(OpcodeCloseConnection, payload) {
return internal.CloseUnsupportedData
if opcode == OpcodeText && !c.isTextValid(opcode, payload) {
return internal.NewError(internal.CloseUnsupportedData, internal.ErrTextEncoding)
}

var compress = c.compressEnabled && opcode.IsDataFrame() && len(payload) >= c.config.CompressThreshold
Expand Down
30 changes: 28 additions & 2 deletions writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,34 @@ func TestConn_WriteInvalidUTF8(t *testing.T) {
var clientHandler = new(webSocketMocker)
var serverOption = &ServerOption{CheckUtf8Enabled: true}
var clientOption = &ClientOption{}
server, _ := newPeer(serverHandler, serverOption, clientHandler, clientOption)
server, client := newPeer(serverHandler, serverOption, clientHandler, clientOption)
go server.Listen()
go client.Listen()
var payload = []byte{1, 2, 255}
as.Error(server.WriteMessage(OpcodeText, payload))
as.Error(server.WriteAsync(OpcodeText, payload))
}

func TestConn_WriteClose(t *testing.T) {
var wg = sync.WaitGroup{}
wg.Add(3)
var serverHandler = new(webSocketMocker)
var clientHandler = new(webSocketMocker)
var serverOption = &ServerOption{CheckUtf8Enabled: true}
var clientOption = &ClientOption{}
server, client := newPeer(serverHandler, serverOption, clientHandler, clientOption)
clientHandler.onClose = func(socket *Conn, code uint16, reason []byte) {
wg.Done()
}
clientHandler.onMessage = func(socket *Conn, message *Message) {
wg.Done()
}
go server.Listen()
go client.Listen()

//var payload = internal.CloseGoingAway.Bytes()
//payload = append(payload, "goodbye"...)
server.WriteMessage(OpcodeText, nil)
server.WriteMessage(OpcodeText, []byte("hello"))
server.WriteMessage(OpcodeCloseConnection, []byte{1})
wg.Wait()
}

0 comments on commit 0fcd170

Please sign in to comment.