Skip to content

Commit

Permalink
Merge pull request #12 from lxzan/dev
Browse files Browse the repository at this point in the history
v1.4.9
  • Loading branch information
lxzan committed May 2, 2023
2 parents ce10c4d + 900fbfa commit ddc71e5
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 210 deletions.
32 changes: 28 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
- [Best Practice](#best-practice)
- [Usage](#usage)
- [Upgrade from HTTP](#upgrade-from-http)
- [Client](#client)
- [Unix Domain Socket](#unix-domain-socket)
- [Broadcast](#broadcast)
- [Autobahn Test](#autobahn-test)
- [Benchmark](#benchmark)
Expand Down Expand Up @@ -151,22 +151,46 @@ func main() {
}
```

#### Client
#### Unix Domain Socket
- server
```go
package main

import (
"github.com/lxzan/gws"
"log"
"net"
)

func main() {
listener, err := net.Listen("unix", "/run/gws.sock")
if err != nil {
log.Println(err.Error())
return
}
var app = gws.NewServer(new(gws.BuiltinEventHandler), nil)
if err := app.RunListener(listener); err != nil {
log.Println(err.Error())
}
}
```

- client
```go
package main

import (
"fmt"
"github.com/lxzan/gws"
"log"
)

func main() {
socket, _, err := gws.NewClient(new(gws.BuiltinEventHandler), &gws.ClientOption{
Addr: "ws://127.0.0.1:6666/connect",
Addr: "unix://localhost/run/gws.sock",
})
if err != nil {
log.Printf(err.Error())
log.Println(err.Error())
return
}
socket.ReadLoop()
Expand Down
186 changes: 66 additions & 120 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,46 @@ package gws

import (
"bufio"
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"encoding/binary"
"github.com/lxzan/gws/internal"
"net"
"net/http"
"net/url"
"strconv"
"reflect"
"strings"
"time"
)

// NewClient 创建WebSocket客户端
type dialer struct {
option *ClientOption
conn net.Conn
eventHandler Event
resp *http.Response
secWebsocketKey string
}

// NewClient 创建WebSocket客户端; 支持ws, wss, unix三种协议
// Create WebSocket client, support ws, wss, unix three protocols
func NewClient(handler Event, option *ClientOption) (client *Conn, resp *http.Response, e error) {
var d = &dialer{
eventHandler: handler,
resp: &http.Response{
Header: http.Header{},
},
}
if option == nil {
option = new(ClientOption)
}
option.initialize()

var d = &dialer{eventHandler: handler, option: option}
defer func() {
if e != nil && !d.isNil(d.conn) {
_ = d.conn.Close()
}
}()

URL, err := url.Parse(option.Addr)
if err != nil {
return nil, d.resp, err
}

var conn net.Conn
var dialError error
var hostname = URL.Hostname()
var port = URL.Port()
Expand All @@ -45,157 +52,96 @@ func NewClient(handler Event, option *ClientOption) (client *Conn, resp *http.Re
port = "80"
}
host = hostname + ":" + port
conn, dialError = net.DialTimeout("tcp", host, option.DialTimeout)
d.conn, dialError = net.DialTimeout("tcp", host, option.DialTimeout)
case "wss":
if port == "" {
port = "443"
}
host = hostname + ":" + port
var tlsDialer = &net.Dialer{Timeout: option.DialTimeout}
conn, dialError = tls.DialWithDialer(tlsDialer, "tcp", host, option.TlsConfig)
d.conn, dialError = tls.DialWithDialer(tlsDialer, "tcp", host, option.TlsConfig)
case "unix":
d.conn, dialError = net.DialTimeout("unix", URL.Path, option.DialTimeout)
default:
return nil, d.resp, internal.ErrSchema
}

if dialError != nil {
return nil, d.resp, dialError
}
if err := conn.SetDeadline(time.Now().Add(option.DialTimeout)); err != nil {
if err := d.conn.SetDeadline(time.Now().Add(option.DialTimeout)); err != nil {
return nil, d.resp, err
}

d.host = host
d.option = option
d.conn = conn
return d.handshake()
}

type dialer struct {
option *ClientOption
conn net.Conn
host string
eventHandler Event
resp *http.Response
}

func (c *dialer) stradd(ss ...string) string {
var b []byte
for _, item := range ss {
b = append(b, item...)
func (c *dialer) isNil(v interface{}) bool {
if v == nil {
return true
}
return string(b)
return reflect.ValueOf(v).IsNil()
}

// 生成报文
func (c *dialer) generateTelegram() []byte {
if c.option.RequestHeader.Get(internal.SecWebSocketKey.Key) == "" {
func (c *dialer) writeRequest() (*http.Request, error) {
r, err := http.NewRequest(http.MethodGet, c.option.Addr, nil)
if err != nil {
return nil, err
}
r.Header = c.option.RequestHeader.Clone()
r.Header.Set(internal.Connection.Key, internal.Connection.Val)
r.Header.Set(internal.Upgrade.Key, internal.Upgrade.Val)
r.Header.Set(internal.SecWebSocketVersion.Key, internal.SecWebSocketVersion.Val)
if c.option.CompressEnabled {
r.Header.Set(internal.SecWebSocketExtensions.Key, internal.SecWebSocketExtensions.Val)
}
if c.secWebsocketKey == "" {
var key [16]byte
binary.BigEndian.PutUint64(key[0:8], internal.AlphabetNumeric.Uint64())
binary.BigEndian.PutUint64(key[8:16], internal.AlphabetNumeric.Uint64())
c.option.RequestHeader.Set(internal.SecWebSocketKey.Key, base64.StdEncoding.EncodeToString(key[0:]))
}
if c.option.CompressEnabled {
c.option.RequestHeader.Set(internal.SecWebSocketExtensions.Key, internal.SecWebSocketExtensions.Val)
c.secWebsocketKey = base64.StdEncoding.EncodeToString(key[0:])
r.Header.Set(internal.SecWebSocketKey.Key, c.secWebsocketKey)
}

var buf []byte
buf = append(buf, c.stradd("GET ", c.option.Addr, " HTTP/1.1\r\n")...)
buf = append(buf, c.stradd("Host: ", c.host, "\r\n")...)
buf = append(buf, "Connection: Upgrade\r\n"...)
buf = append(buf, "Upgrade: websocket\r\n"...)
buf = append(buf, "Sec-WebSocket-Version: 13\r\n"...)
for k, _ := range c.option.RequestHeader {
buf = append(buf, c.stradd(k, ": ", c.option.RequestHeader.Get(k), "\r\n")...)
}
buf = append(buf, "\r\n"...)
return buf
return r, r.Write(c.conn)
}

func (c *dialer) handshake() (*Conn, *http.Response, error) {
br := bufio.NewReaderSize(c.conn, c.option.ReadBufferSize)
telegram := c.generateTelegram()
if err := internal.WriteN(c.conn, telegram, len(telegram)); err != nil {
return nil, c.resp, err
request, err := c.writeRequest()
if err != nil {
return nil, nil, err
}

var ch = make(chan error)
ctx, cancel := context.WithTimeout(context.Background(), c.option.DialTimeout)
defer cancel()

var channel = make(chan error)
go func() {
var index = 0
for {
line, isPrefix, err := br.ReadLine()
if err != nil {
ch <- err
return
}
if isPrefix {
ch <- internal.ErrLongLine
return
}
if index == 0 {
arr := bytes.Split(line, []byte(" "))
if len(arr) >= 2 {
code, _ := strconv.Atoi(string(arr[1]))
c.resp.StatusCode = code
}
if len(arr) != 4 || c.resp.StatusCode != 101 {
ch <- internal.ErrStatusCode
return
}
} else {
if len(line) == 0 {
ch <- nil
return
}
arr := strings.Split(string(line), ": ")
if len(arr) != 2 {
ch <- internal.ErrHandshake
return
}
c.resp.Header.Set(arr[0], arr[1])
}
index++
}
c.resp, err = http.ReadResponse(br, request)
channel <- err
}()

for {
select {
case <-ctx.Done():
return nil, c.resp, internal.ErrDialTimeout
case err := <-ch:
if err != nil {
return nil, c.resp, err
}
if err := c.checkHeaders(); err != nil {
return nil, c.resp, err
}
var compressEnabled = false
if c.option.CompressEnabled && strings.Contains(c.resp.Header.Get(internal.SecWebSocketExtensions.Key), "permessage-deflate") {
compressEnabled = true
}
if err := c.conn.SetDeadline(time.Time{}); err != nil {
return nil, c.resp, err
}
if err := setNoDelay(c.conn); err != nil {
return nil, c.resp, err
}
return serveWebSocket(false, c.option.getConfig(), new(sliceMap), c.conn, br, c.eventHandler, compressEnabled), c.resp, nil
}
if err := <-channel; err != nil {
return nil, c.resp, err
}
if err := c.checkHeaders(); err != nil {
return nil, c.resp, err
}
if err := c.conn.SetDeadline(time.Time{}); err != nil {
return nil, c.resp, err
}
if err := setNoDelay(c.conn); err != nil {
return nil, c.resp, err
}
var compressEnabled = c.option.CompressEnabled && strings.Contains(c.resp.Header.Get(internal.SecWebSocketExtensions.Key), "permessage-deflate")
return serveWebSocket(false, c.option.getConfig(), new(sliceMap), c.conn, br, c.eventHandler, compressEnabled), c.resp, nil
}

func (c *dialer) checkHeaders() error {
if c.resp.StatusCode != 101 {
return internal.ErrStatusCode
}
if !internal.HttpHeaderEqual(c.resp.Header.Get(internal.Connection.Key), internal.Connection.Val) {
return internal.ErrHandshake
}
if !internal.HttpHeaderEqual(c.resp.Header.Get(internal.Upgrade.Key), internal.Upgrade.Val) {
return internal.ErrHandshake
}
var expectedKey = internal.ComputeAcceptKey(c.option.RequestHeader.Get(internal.SecWebSocketKey.Key))
var actualKey = c.resp.Header.Get(internal.SecWebSocketAccept.Key)
if actualKey != expectedKey {
if c.resp.Header.Get(internal.SecWebSocketAccept.Key) != internal.ComputeAcceptKey(c.secWebsocketKey) {
return internal.ErrHandshake
}
return nil
Expand Down
Loading

0 comments on commit ddc71e5

Please sign in to comment.