Skip to content

Commit

Permalink
fix: fix readx not updated when read an error format RESP cmd
Browse files Browse the repository at this point in the history
  • Loading branch information
satoshi-099 committed Aug 12, 2024
1 parent cc1df88 commit 28805f9
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 46 deletions.
4 changes: 2 additions & 2 deletions aof.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ func (aof *Aof) Read(fn func(args []RESP)) error {
reader := NewReader(data)
argsBuf := make([]RESP, 8)
for {
values, err := reader.ReadNextCommand(argsBuf)
args, _, err := reader.ReadNextCommand(argsBuf)
if err != nil {
if err == io.EOF {
break
}
return err
}
fn(values)
fn(args)
}

return nil
Expand Down
1 change: 1 addition & 0 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ var cmdTable []*Command = []*Command{
// TODO
{"mset", todoCommand, 0, false},
{"xadd", todoCommand, 0, false},
{"client", todoCommand, 0, false},
}

func equalFold(a, b string) bool {
Expand Down
6 changes: 3 additions & 3 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,11 +320,11 @@ func TestCommand(t *testing.T) {

t.Run("concurrency", func(t *testing.T) {
var wg sync.WaitGroup
for i := 0; i < 2000; i++ {
for i := 0; i < 1000; i++ {
wg.Add(1)
go func() {
key := fmt.Sprintf("key-%08x", rand.Int())
value := fmt.Sprintf("val-%08x", rand.Int())
key := fmt.Sprintf("key%08x", rand.Int())
value := fmt.Sprintf("val%08x", rand.Int())

_, err := rdb.Set(ctx, key, value, 0).Result()
assert.Nil(err)
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ require (
github.com/stretchr/testify v1.9.0
github.com/tidwall/mmap v0.3.0
github.com/zeebo/xxh3 v1.0.2
golang.org/x/sys v0.23.0
golang.org/x/sys v0.24.0
)

require (
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM=
golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg=
golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
gonum.org/v1/gonum v0.0.0-20181121035319-3f7ecaa7e8ca h1:PupagGYwj8+I4ubCxcmcBRk3VlUWtTg5huQpZR9flmE=
gonum.org/v1/gonum v0.0.0-20181121035319-3f7ecaa7e8ca/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo=
Expand Down
19 changes: 11 additions & 8 deletions resp.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@ func parseInt(buf []byte) (n int, after []byte, err error) {

// ReadNextCommand reads the next RESP command from the RESPReader.
// It parses both `COMMAND_BULK` and `COMMAND_INLINE` formats.
func (r *RESPReader) ReadNextCommand(argsBuf []RESP) (args []RESP, err error) {
if len(r.b) == 0 {
return nil, io.EOF
func (r *RESPReader) ReadNextCommand(argsBuf []RESP) (args []RESP, n int, err error) {
srclen := len(r.b)
if srclen == 0 {
return nil, 0, io.EOF
}
args = argsBuf[:0]

Expand All @@ -60,24 +61,24 @@ func (r *RESPReader) ReadNextCommand(argsBuf []RESP) (args []RESP, err error) {
// command_bulk format
num, after, err := parseInt(r.b[1:])
if err != nil {
return nil, err
return nil, 0, err
}
r.b = after

// read bulk strings for range
for i := 0; i < num; i++ {
if len(r.b) == 0 || r.b[0] != BULK {
return nil, errInvalidArguments
return nil, 0, errInvalidArguments
}

num, after, err := parseInt(r.b[1:])
if err != nil {
return nil, err
return nil, 0, err
}

// bound check
if num < 0 || num+2 > len(after) {
return nil, errInvalidArguments
return nil, 0, errInvalidArguments
}

args = append(args, after[:num])
Expand All @@ -90,11 +91,13 @@ func (r *RESPReader) ReadNextCommand(argsBuf []RESP) (args []RESP, err error) {
// command_inline format
before, after, ok := bytes.Cut(r.b, CRLF)
if !ok {
return nil, errInvalidArguments
return nil, 0, errInvalidArguments
}
args = append(args, before)
r.b = after
}

n = srclen - len(r.b)
return
}

Expand Down
48 changes: 29 additions & 19 deletions resp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,14 @@ func TestReader(t *testing.T) {

t.Run("error-reader", func(t *testing.T) {
// read nil
_, err := NewReader(nil).ReadNextCommand(nil)
_, n, err := NewReader(nil).ReadNextCommand(nil)
assert.Equal(n, 0)
assert.NotNil(err)

for _, prefix := range []byte{BULK, INTEGER, ARRAY} {
data := append([]byte{prefix}, "an error message"...)
_, err := NewReader(data).ReadNextCommand(nil)
_, n, err := NewReader(data).ReadNextCommand(nil)
assert.Equal(n, 0)
assert.NotNil(err)
}
})
Expand Down Expand Up @@ -84,37 +86,45 @@ func TestReader(t *testing.T) {
})

t.Run("command-bulk", func(t *testing.T) {
args, err := NewReader([]byte("*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$3\r\nbar\r\n")).ReadNextCommand(nil)
assert.Equal(args[0].ToString(), "SET")
assert.Equal(args[1].ToString(), "foo")
assert.Equal(args[2].ToString(), "bar")
cmdStr := []byte("*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$3\r\nbar\r\n")
args, n, err := NewReader(cmdStr).ReadNextCommand(nil)
assert.Equal(args, []RESP{RESP("SET"), RESP("foo"), RESP("bar")})
assert.Equal(n, len(cmdStr))
assert.Nil(err)

// error
args, err = NewReader([]byte("*A\r\n$3\r\nGET\r\n$3\r\nfoo\r\n")).ReadNextCommand(nil)
assert.Equal(len(args), 0)
// error format cmd
_, _, err = NewReader([]byte("*A\r\n$3\r\nGET\r\n$3\r\nfoo\r\n")).ReadNextCommand(nil)
assert.ErrorIs(err, errParseInteger)

args, err = NewReader([]byte("*3\r\n$A\r\nGET\r\n$3\r\nfoo\r\n")).ReadNextCommand(nil)
assert.Equal(len(args), 0)
_, _, err = NewReader([]byte("*3\r\n$A\r\nGET\r\n$3\r\nfoo\r\n")).ReadNextCommand(nil)
assert.ErrorIs(err, errParseInteger)

args, err = NewReader([]byte("*3\r\n+PING")).ReadNextCommand(nil)
assert.Equal(len(args), 0)
_, _, err = NewReader([]byte("*3\r\n+PING")).ReadNextCommand(nil)
assert.NotNil(err)

args, err = NewReader([]byte("*3\r\n$3ABC")).ReadNextCommand(nil)
assert.Equal(len(args), 0)
_, _, err = NewReader([]byte("*3\r\n$3ABC")).ReadNextCommand(nil)
assert.NotNil(err)

args, err = NewReader([]byte("*1\r\n")).ReadNextCommand(nil)
assert.Equal(len(args), 0)
_, _, err = NewReader([]byte("*1\r\n")).ReadNextCommand(nil)
assert.NotNil(err)

// multi cmd contains error format
{
rd := NewReader([]byte("*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$3\r\nbar\r\n---ERROR MSG---"))
_, n, err = rd.ReadNextCommand(nil)
assert.Equal(n, 31)
assert.Nil(err)

_, n, err = rd.ReadNextCommand(nil)
assert.Equal(n, 0)
assert.NotNil(err)
}
})

t.Run("command-inline", func(t *testing.T) {
args, err := NewReader([]byte("PING\r\n")).ReadNextCommand(nil)
assert.Equal(args[0].ToString(), "PING")
args, n, err := NewReader([]byte("PING\r\n")).ReadNextCommand(nil)
assert.Equal(args[0], RESP("PING"))
assert.Equal(n, 6)
assert.Nil(err)
})
}
Expand Down
28 changes: 17 additions & 11 deletions rotom.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,12 @@ type DB struct {
}

type Client struct {
fd int
queryLen int
queryBuf []byte
fd int

recvx int
readx int
queryBuf []byte

argsBuf []RESP
replyWriter *RESPWriter
}
Expand Down Expand Up @@ -106,29 +109,30 @@ func ReadQueryFromClient(loop *AeLoop, fd int, extra interface{}) {
readSize := 0

READ:
n, err := Read(fd, client.queryBuf[client.queryLen:])
n, err := Read(fd, client.queryBuf[client.recvx:])
if err != nil {
log.Error().Msgf("client %v read err: %v", fd, err)
freeClient(client)
return
}
readSize += n
client.queryLen += n
client.recvx += n

if readSize == 0 {
log.Warn().Msgf("client %d read query empty, now free", fd)
freeClient(client)
return
}

if client.queryLen > MAX_QUERY_DATA_LEN {
if client.recvx >= MAX_QUERY_DATA_LEN {
log.Error().Msgf("client %d read query data too large, now free", fd)
freeClient(client)
return
}

// queryBuf need grow up
if client.queryLen == len(client.queryBuf) {
client.queryBuf = append(client.queryBuf, make([]byte, client.queryLen)...)
if client.recvx == len(client.queryBuf) {
client.queryBuf = append(client.queryBuf, make([]byte, client.recvx)...)
log.Warn().Msgf("client %d queryBuf grow up to size %s", fd, readableSize(len(client.queryBuf)))
goto READ
}
Expand All @@ -137,7 +141,8 @@ READ:
}

func resetClient(client *Client) {
client.queryLen = 0
client.readx = 0
client.recvx = 0
}

func freeClient(client *Client) {
Expand All @@ -147,18 +152,19 @@ func freeClient(client *Client) {
}

func ProcessQueryBuf(client *Client) {
queryBuf := client.queryBuf[:client.queryLen]
queryBuf := client.queryBuf[client.readx:client.recvx]

reader := NewReader(queryBuf)
for {
args, err := reader.ReadNextCommand(client.argsBuf)
args, n, err := reader.ReadNextCommand(client.argsBuf)
if err != nil {
if err == io.EOF {
break
}
log.Error().Msgf("read resp error: %v", err)
return
}
client.readx += n

command := args[0].ToStringUnsafe()
args = args[1:]
Expand Down

0 comments on commit 28805f9

Please sign in to comment.