diff --git a/command.go b/command.go index c30e314..52be542 100644 --- a/command.go +++ b/command.go @@ -83,7 +83,7 @@ func lookupCommand(name string) (*Command, error) { func (cmd *Command) processCommand(writer *RESPWriter, args []RESP) { if len(args) < cmd.minArgsNum { - writer.WriteError(errInvalidArguments) + writer.WriteError(errWrongArguments) return } cmd.handler(writer, args) @@ -209,7 +209,7 @@ func hsetCommand(writer *RESPWriter, args []RESP) { args = args[1:] if len(args)%2 == 1 { - writer.WriteError(errInvalidArguments) + writer.WriteError(errWrongArguments) return } @@ -353,6 +353,8 @@ func lrangeCommand(writer *RESPWriter, args []RESP) { if stop == -1 { stop = ls.Size() + } else if stop < ls.Size() { + stop++ // range 1 3 means range[1,3] } start = min(start, stop) @@ -542,30 +544,6 @@ func flushdbCommand(writer *RESPWriter, _ []RESP) { func evalCommand(writer *RESPWriter, args []RESP) { L := server.lua script := args[0].ToString() - numKeys, err := args[1].ToInt() - if err != nil { - writer.WriteError(err) - return - } - numArgv := len(args) - 2 - numKeys - - // set "KEYS" table - if numKeys > 0 { - keyTable := L.CreateTable(numKeys, 0) - for i := 2; i < numKeys+2; i++ { - keyTable.Append(lua.LString(args[i])) - } - L.SetGlobal("KEYS", keyTable) - } - - // set "ARGV" table - if numArgv > 0 { - argvTable := L.CreateTable(numArgv, 0) - for i := 2 + numKeys; i < len(args); i++ { - argvTable.Append(lua.LString(args[i])) - } - L.SetGlobal("ARGV", argvTable) - } if err := L.DoString(script); err != nil { writer.WriteError(err) diff --git a/command_test.go b/command_test.go index 19a8708..9244c08 100644 --- a/command_test.go +++ b/command_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/alicebob/miniredis/v2" "github.com/redis/go-redis/v9" "github.com/stretchr/testify/assert" ) @@ -17,7 +18,7 @@ func startup() { config := &Config{ Port: 20082, AppendOnly: true, - AppendFileName: "appendonly-test.aof", + AppendFileName: "test.aof", } os.Remove(config.AppendFileName) config4Server(config) @@ -27,18 +28,33 @@ func startup() { server.aeLoop.AeMain() } -var ctx = context.Background() - func TestCommand(t *testing.T) { - assert := assert.New(t) - - go startup() - time.Sleep(time.Second / 2) - - // wait for client starup - rdb := redis.NewClient(&redis.Options{ - Addr: ":20082", + t.Run("miniredis", func(t *testing.T) { + s := miniredis.RunT(t) + rdb := redis.NewClient(&redis.Options{ + Addr: s.Addr(), + }) + sleepFn := func(dur time.Duration) { + s.FastForward(dur) + } + testCommand(t, rdb, sleepFn) }) + t.Run("rotom", func(t *testing.T) { + go startup() + time.Sleep(time.Second / 2) + rdb := redis.NewClient(&redis.Options{ + Addr: ":20082", + }) + sleepFn := func(dur time.Duration) { + time.Sleep(dur) + } + testCommand(t, rdb, sleepFn) + }) +} + +func testCommand(t *testing.T, rdb *redis.Client, sleepFn func(time.Duration)) { + assert := assert.New(t) + ctx := context.Background() t.Run("ping", func(t *testing.T) { res, _ := rdb.Ping(ctx).Result() @@ -53,66 +69,56 @@ func TestCommand(t *testing.T) { assert.Equal(res, "bar") res, err := rdb.Get(ctx, "none").Result() - assert.Equal(err, redis.Nil) assert.Equal(res, "") + assert.Equal(err, redis.Nil) n, _ := rdb.Del(ctx, "foo", "none").Result() assert.Equal(n, int64(1)) - // setnx - ok, err := rdb.SetNX(ctx, "key-nx", "123", redis.KeepTTL).Result() - assert.Nil(err) - assert.True(ok) - - ok, err = rdb.SetNX(ctx, "key-nx", "123", redis.KeepTTL).Result() - assert.Nil(err) - assert.False(ok) - }) - - t.Run("error-get", func(t *testing.T) { - lskey := fmt.Sprintf("ls-%d", time.Now().UnixNano()) - rdb.RPush(ctx, lskey, "1") - - _, err := rdb.Get(ctx, lskey).Result() - assert.Equal(err.Error(), errWrongType.Error()) - }) - - t.Run("setex", func(t *testing.T) { - res, _ := rdb.Set(ctx, "foo", "bar", time.Second).Result() - assert.Equal(res, "OK") - - res, _ = rdb.Get(ctx, "foo").Result() - assert.Equal(res, "bar") - - time.Sleep(time.Second + time.Millisecond) + // setex + { + res, _ := rdb.Set(ctx, "foo", "bar", time.Second).Result() + assert.Equal(res, "OK") - _, err := rdb.Get(ctx, "foo").Result() - assert.Equal(err, redis.Nil) - }) + res, _ = rdb.Get(ctx, "foo").Result() + assert.Equal(res, "bar") - t.Run("setpx", func(t *testing.T) { - res, _ := rdb.Set(ctx, "foo", "bar", time.Millisecond*100).Result() - assert.Equal(res, "OK") + sleepFn(time.Second + time.Millisecond) - res, _ = rdb.Get(ctx, "foo").Result() - assert.Equal(res, "bar") + _, err := rdb.Get(ctx, "foo").Result() + assert.Equal(err, redis.Nil) + } + // setpx + { + res, _ := rdb.Set(ctx, "foo", "bar", time.Millisecond*100).Result() + assert.Equal(res, "OK") - time.Sleep(time.Millisecond * 101) + res, _ = rdb.Get(ctx, "foo").Result() + assert.Equal(res, "bar") - _, err := rdb.Get(ctx, "foo").Result() - assert.Equal(err, redis.Nil) - }) + sleepFn(time.Millisecond * 101) - t.Run("pipline", func(t *testing.T) { - pip := rdb.Pipeline() - pip.RPush(ctx, "ls-pip", "A", "B", "C") - pip.LPop(ctx, "ls-pip") + _, err := rdb.Get(ctx, "foo").Result() + assert.Equal(err, redis.Nil) + } + // setnx + { + ok, err := rdb.SetNX(ctx, "keynx", "123", redis.KeepTTL).Result() + assert.Nil(err) + assert.True(ok) - _, err := pip.Exec(ctx) - assert.Nil(err) + ok, err = rdb.SetNX(ctx, "keynx", "123", redis.KeepTTL).Result() + assert.Nil(err) + assert.False(ok) + } + // error + { + lskey := fmt.Sprintf("ls-%x", time.Now().UnixNano()) + rdb.RPush(ctx, lskey, "1") - res, _ := rdb.LRange(ctx, "ls-pip", 0, -1).Result() - assert.Equal(res, []string{"B", "C"}) + _, err := rdb.Get(ctx, lskey).Result() + assert.Equal(err.Error(), errWrongType.Error()) + } }) t.Run("incr", func(t *testing.T) { @@ -170,10 +176,10 @@ func TestCommand(t *testing.T) { // error hset _, err := rdb.HSet(ctx, "map").Result() - assert.Equal(err.Error(), errInvalidArguments.Error()) + assert.Contains(err.Error(), errWrongArguments.Error()) _, err = rdb.HSet(ctx, "map", "k1", "v1", "k2").Result() - assert.Equal(err.Error(), errInvalidArguments.Error()) + assert.Contains(err.Error(), errWrongArguments.Error()) // err wrong type rdb.Set(ctx, "key", "value", 0) @@ -205,7 +211,7 @@ func TestCommand(t *testing.T) { assert.Equal(res, []string{"c", "b", "a", "d", "e", "f"}) res, _ = rdb.LRange(ctx, "list", 1, 3).Result() - assert.Equal(res, []string{"b", "a"}) + assert.Equal(res, []string{"b", "a", "d"}) res, err := rdb.LRange(ctx, "list", 3, 2).Result() assert.Equal(len(res), 0) @@ -377,44 +383,30 @@ func TestCommand(t *testing.T) { }) t.Run("eval", func(t *testing.T) { - { - res, err := rdb.Eval(ctx, - "return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}", - []string{"key1", "key2"}, - []any{"first", "second"}, - ).Result() - assert.Equal(res, []any{"key1", "key2", "first", "second"}) - assert.Nil(err) - } - { - res, err := rdb.Eval(ctx, "return {1,2,3}", []string{}).Result() - assert.Equal(res, []any{int64(1), int64(2), int64(3)}) - assert.Nil(err) - } - { - keys := []string{"evalKey", "qwer"} + res, _ := rdb.Eval(ctx, "return {'key1','key2','key3'}", nil).Result() + assert.Equal(res, []any{"key1", "key2", "key3"}) - // set - _, err := rdb.Eval(ctx, "call('set',KEYS[0],KEYS[1])", keys).Result() - assert.Equal(err, redis.Nil) + res, _ = rdb.Eval(ctx, "return {1,2,3}", nil).Result() + assert.Equal(res, []any{int64(1), int64(2), int64(3)}) - // set with return - res, _ := rdb.Eval(ctx, "return call('set',KEYS[0],KEYS[1])", keys).Result() - assert.Equal(res, "OK") + // set + _, err := rdb.Eval(ctx, "redis.call('set','xgz','qwe')", nil).Result() + assert.Equal(err, redis.Nil) - // get - res, _ = rdb.Eval(ctx, "return call('get',KEYS[0])", keys[:1]).Result() - assert.Equal(res, keys[1]) + res, _ = rdb.Eval(ctx, "return redis.call('set','xgz','qwe')", nil).Result() + assert.Equal(res, "OK") - // get nil - _, err = rdb.Eval(ctx, "return call('get',KEYS[0])", []string{"notExistKey"}).Result() - assert.Equal(err, redis.Nil) - } - { - // unknown function - _, err := rdb.Eval(ctx, "call('wwwww','aaa')", []string{}).Result() - assert.NotNil(err) // TODO - } + // get + res, _ = rdb.Eval(ctx, "return redis.call('get','xgz')", nil).Result() + assert.Equal(res, "qwe") + + // get nil + _, err = rdb.Eval(ctx, "return redis.call('get','not-ex-evalkey')", nil).Result() + assert.Equal(err, redis.Nil) + + // error call + _, err = rdb.Eval(ctx, "return redis.call('myfunc','key')", nil).Result() + assert.NotNil(err) }) t.Run("flushdb", func(t *testing.T) { @@ -437,9 +429,8 @@ func TestCommand(t *testing.T) { _, err := rdb.Set(ctx, key, value, 0).Result() assert.Nil(err) - res, err := rdb.Get(ctx, key).Result() + res, _ := rdb.Get(ctx, key).Result() assert.Equal(res, value) - assert.Nil(err) wg.Done() }() @@ -447,11 +438,11 @@ func TestCommand(t *testing.T) { wg.Wait() }) - t.Run("bigKey", func(t *testing.T) { - body := make([]byte, MAX_QUERY_DATA_LEN) - _, err := rdb.Set(ctx, "bigKey", body, 0).Result() - assert.NotNil(err) - }) + // t.Run("bigKey", func(t *testing.T) { + // body := make([]byte, MAX_QUERY_DATA_LEN) + // _, err := rdb.Set(ctx, "bigKey", body, 0).Result() + // assert.NotNil(err) + // }) t.Run("trans-zipmap", func(t *testing.T) { for i := 0; i <= 256; i++ { @@ -467,8 +458,9 @@ func TestCommand(t *testing.T) { } }) - t.Run("client-closed", func(t *testing.T) { - rdb.Close() + t.Run("closed", func(t *testing.T) { + err := rdb.Close() + assert.Nil(err) }) } diff --git a/errors.go b/errors.go index ce37600..0e8b883 100644 --- a/errors.go +++ b/errors.go @@ -5,11 +5,11 @@ import ( ) var ( - errWrongType = errors.New("WRONGTYPE Operation against a key holding the wrong kind of value") - errParseInteger = errors.New("ERR value is not an integer or out of range") - errCRLFNotFound = errors.New("ERR CRLF not found in line") - errInvalidArguments = errors.New("ERR invalid number of arguments") - errUnknownCommand = errors.New("ERR unknown command") - errOOM = errors.New("ERR command not allowed when out of memory") - errSyntax = errors.New("ERR syntax error") + errWrongType = errors.New("WRONGTYPE Operation against a key holding the wrong kind of value") + errParseInteger = errors.New("ERR value is not an integer or out of range") + errCRLFNotFound = errors.New("ERR CRLF not found in line") + errWrongArguments = errors.New("ERR wrong number of arguments") + errUnknownCommand = errors.New("ERR unknown command") + errOOM = errors.New("ERR command not allowed when out of memory") + errSyntax = errors.New("ERR syntax error") ) diff --git a/go.mod b/go.mod index fdf2d4d..aaa1ee6 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/xgzlucario/rotom go 1.22 require ( + github.com/alicebob/miniredis/v2 v2.33.0 github.com/chen3feng/stl4go v0.1.1 github.com/deckarep/golang-set/v2 v2.6.0 github.com/influxdata/tdigest v0.0.1 @@ -15,6 +16,7 @@ require ( ) require ( + github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect diff --git a/go.sum b/go.sum index 96b0fc2..ff057c1 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,7 @@ +github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk= +github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= +github.com/alicebob/miniredis/v2 v2.33.0 h1:uvTF0EDeu9RLnUEG27Db5I68ESoIxTiXbNUiji6lZrA= +github.com/alicebob/miniredis/v2 v2.33.0/go.mod h1:MhP4a3EU7aENRi9aO+tHfTBZicLqQevyi/DJpoj6mi0= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= diff --git a/lua.go b/lua.go index bb58d1b..1dfb04c 100644 --- a/lua.go +++ b/lua.go @@ -5,34 +5,35 @@ import ( lua "github.com/yuin/gopher-lua" ) -func luaCall(L *lua.LState) int { +func OpenRedis(L *lua.LState) int { + mod := L.RegisterModule("redis", map[string]lua.LGFunction{ + "call": libCall, + }) + L.Push(mod) + return 1 +} + +func libCall(L *lua.LState) int { fn := L.ToString(1) - var keyTable, argvTable *lua.LTable - if t := L.GetGlobal("KEYS"); t.Type() == lua.LTTable { - keyTable = t.(*lua.LTable) - } - if t := L.GetGlobal("ARGV"); t.Type() != lua.LTNil { - argvTable = t.(*lua.LTable) - } switch fn { case "set": - return luaSet(L, keyTable, argvTable) + return libSet(L) case "get": - return luaGet(L, keyTable, argvTable) + return libGet(L) } - return 0 + return -1 } -func luaSet(L *lua.LState, keyTable, _ *lua.LTable) int { - key := keyTable.RawGetInt(1).String() - value := keyTable.RawGetInt(2).String() +func libSet(L *lua.LState) int { + key := L.ToString(2) + value := L.ToString(3) db.dict.Set(key, []byte(value)) L.Push(lua.LString("OK")) return 1 } -func luaGet(L *lua.LState, keyTable, _ *lua.LTable) int { - key := keyTable.RawGetInt(1).String() +func libGet(L *lua.LState) int { + key := L.ToString(2) value, ttl := db.dict.Get(key) if ttl != dict.KEY_NOT_EXIST { L.Push(lua.LString(value.([]byte))) diff --git a/resp.go b/resp.go index 27e53a0..2d8cfb3 100644 --- a/resp.go +++ b/resp.go @@ -68,7 +68,7 @@ func (r *RESPReader) ReadNextCommand(argsBuf []RESP) (args []RESP, n int, err er // read bulk strings for range for i := 0; i < num; i++ { if len(r.b) == 0 || r.b[0] != BULK { - return nil, 0, errInvalidArguments + return nil, 0, errWrongArguments } num, after, err := parseInt(r.b[1:]) @@ -78,7 +78,7 @@ func (r *RESPReader) ReadNextCommand(argsBuf []RESP) (args []RESP, n int, err er // bound check if num < 0 || num+2 > len(after) { - return nil, 0, errInvalidArguments + return nil, 0, errWrongArguments } args = append(args, after[:num]) @@ -91,7 +91,7 @@ func (r *RESPReader) ReadNextCommand(argsBuf []RESP) (args []RESP, n int, err er // command_inline format before, after, ok := bytes.Cut(r.b, CRLF) if !ok { - return nil, 0, errInvalidArguments + return nil, 0, errWrongArguments } args = append(args, before) r.b = after diff --git a/rotom.go b/rotom.go index b48ed16..df165ea 100644 --- a/rotom.go +++ b/rotom.go @@ -227,8 +227,11 @@ func initServer(config *Config) (err error) { } // init lua state L := lua.NewState() - L.SetGlobal("call", L.NewFunction(luaCall)) + L.Push(L.NewFunction(OpenRedis)) + L.Push(lua.LString("redis")) + L.Call(1, 0) server.lua = L + return nil }