Skip to content

Commit

Permalink
feat: eval support KEYS and ARGV params
Browse files Browse the repository at this point in the history
  • Loading branch information
satoshi-099 committed Sep 7, 2024
1 parent 0059c24 commit 6a90904
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 32 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ This is rotom, a high performance, low latency tiny Redis Server written in Go.
3. Implements data structures such as dict, list, map, zipmap, set, zipset, and zset.
4. Supports AOF.
5. Supports 20+ commonly used commands.
6. Supports execute lua scripts.

## AELoop

Expand Down
1 change: 1 addition & 0 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
3. 实现了 dict, list, map, zipmap, set, zipset, zset 数据结构
4. AOF 支持
5. 支持 20 多种常用命令
6. 支持执行 lua 脚本

## AELoop 事件循环

Expand Down
52 changes: 39 additions & 13 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -542,34 +542,60 @@ func flushdbCommand(writer *RESPWriter, _ []RESP) {
func evalCommand(writer *RESPWriter, args []RESP) {
L := server.lua
script := args[0].ToString()

numKeys, err := args[1].ToInt()
if err != nil {
writer.WriteError(err)
return
}
numArgv := len(args) - 2 - numKeys

// set "KEYS" table
table := L.CreateTable(numKeys, 0)
for i := range numKeys {
table.Append(lua.LString(args[i+2]))
if numKeys > 0 {
keyTable := L.CreateTable(numKeys, 0)
for i := 2; i < numKeys+2; i++ {
keyTable.Append(lua.LString(args[i]))
}
L.SetGlobal("KEYS", keyTable)
}

// set "ARGV" table
if numArgv > 0 {
argvTable := L.CreateTable(numArgv, 0)
for i := 2 + numKeys; i < len(args); i++ {
argvTable.Append(lua.LString(args[i]))
}
L.SetGlobal("ARGV", argvTable)
}
L.SetGlobal("KEYS", table)

if err := L.DoString(script); err != nil {
writer.WriteError(err)
return
}

ret := L.Get(-1)
switch ret.Type() {
case lua.LTString:
writer.WriteBulkString(ret.String())
case lua.LTNil:
writer.WriteNull()
default:
writer.WriteString("OK")
var serialize func(isRoot bool, ret lua.LValue)
serialize = func(isRoot bool, ret lua.LValue) {
switch res := ret.(type) {
case lua.LString:
writer.WriteBulkString(res.String())

case lua.LNumber:
writer.WriteInteger(int(res)) // convert to integer

case *lua.LTable:
writer.WriteArrayHead(res.Len())
res.ForEach(func(index, value lua.LValue) {
serialize(false, value)
})

default:
writer.WriteNull()
}

if isRoot && ret.Type() != lua.LTNil {
L.Pop(1)
}
}
serialize(true, L.Get(-1))
}

func todoCommand(writer *RESPWriter, _ []RESP) {
Expand Down
45 changes: 35 additions & 10 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,19 +377,44 @@ func TestCommand(t *testing.T) {
})

t.Run("eval", func(t *testing.T) {
keys := []string{"evalKey", "qwer"}
{
res, err := rdb.Eval(ctx,
"return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}",
[]string{"key1", "key2"},
[]any{"first", "second"},
).Result()
assert.Equal(res, []any{"key1", "key2", "first", "second"})
assert.Nil(err)
}
{
res, err := rdb.Eval(ctx, "return {1,2,3}", []string{}).Result()
assert.Equal(res, []any{int64(1), int64(2), int64(3)})
assert.Nil(err)
}
{
keys := []string{"evalKey", "qwer"}

// set
_, err := rdb.Eval(ctx, "call('set',KEYS[0],KEYS[1])", keys).Result()
assert.Equal(err, redis.Nil)
// set
_, err := rdb.Eval(ctx, "call('set',KEYS[0],KEYS[1])", keys).Result()
assert.Equal(err, redis.Nil)

// set with return
res, _ := rdb.Eval(ctx, "return call('set',KEYS[0],KEYS[1])", keys).Result()
assert.Equal(res, "OK")
// set with return
res, _ := rdb.Eval(ctx, "return call('set',KEYS[0],KEYS[1])", keys).Result()
assert.Equal(res, "OK")

// get
res, _ = rdb.Eval(ctx, "return call('get',KEYS[0])", keys[:1]).Result()
assert.Equal(res, keys[1])
// get
res, _ = rdb.Eval(ctx, "return call('get',KEYS[0])", keys[:1]).Result()
assert.Equal(res, keys[1])

// get nil
_, err = rdb.Eval(ctx, "return call('get',KEYS[0])", []string{"notExistKey"}).Result()
assert.Equal(err, redis.Nil)
}
{
// unknown function
_, err := rdb.Eval(ctx, "call('wwwww','aaa')", []string{}).Result()
assert.NotNil(err) // TODO
}
})

t.Run("flushdb", func(t *testing.T) {
Expand Down
24 changes: 15 additions & 9 deletions lua.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,32 @@ import (

func luaCall(L *lua.LState) int {
fn := L.ToString(1)
keys := L.GetGlobal("KEYS").(*lua.LTable)
var keyTable, argvTable *lua.LTable
if t := L.GetGlobal("KEYS"); t.Type() == lua.LTTable {
keyTable = t.(*lua.LTable)
}
if t := L.GetGlobal("ARGV"); t.Type() != lua.LTNil {
argvTable = t.(*lua.LTable)
}
switch fn {
case "set":
return luaSet(L, keys)
return luaSet(L, keyTable, argvTable)
case "get":
return luaGet(L, keys)
return luaGet(L, keyTable, argvTable)
}
return -1
return 0
}

func luaSet(L *lua.LState, keys *lua.LTable) int {
key := keys.RawGetInt(1).String()
value := keys.RawGetInt(2).String()
func luaSet(L *lua.LState, keyTable, _ *lua.LTable) int {
key := keyTable.RawGetInt(1).String()
value := keyTable.RawGetInt(2).String()
db.dict.Set(key, []byte(value))
L.Push(lua.LString("OK"))
return 1
}

func luaGet(L *lua.LState, keys *lua.LTable) int {
key := keys.RawGetInt(1).String()
func luaGet(L *lua.LState, keyTable, _ *lua.LTable) int {
key := keyTable.RawGetInt(1).String()
value, ttl := db.dict.Get(key)
if ttl != dict.KEY_NOT_EXIST {
L.Push(lua.LString(value.([]byte)))
Expand Down

0 comments on commit 6a90904

Please sign in to comment.