Skip to content

Commit

Permalink
feat: support eval(lua scripts) command
Browse files Browse the repository at this point in the history
  • Loading branch information
satoshi-099 committed Sep 6, 2024
1 parent b7511c6 commit 0059c24
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 14 deletions.
35 changes: 35 additions & 0 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/xgzlucario/rotom/internal/hash"
"github.com/xgzlucario/rotom/internal/list"
"github.com/xgzlucario/rotom/internal/zset"
lua "github.com/yuin/gopher-lua"
)

var (
Expand Down Expand Up @@ -58,6 +59,7 @@ var cmdTable []*Command = []*Command{
{"zrank", zrankCommand, 2, false},
{"zpopmin", zpopminCommand, 1, true},
{"zrange", zrangeCommand, 3, false},
{"eval", evalCommand, 2, true},
{"ping", pingCommand, 0, false},
{"flushdb", flushdbCommand, 0, true},
// TODO
Expand Down Expand Up @@ -537,6 +539,39 @@ func flushdbCommand(writer *RESPWriter, _ []RESP) {
writer.WriteString("OK")
}

func evalCommand(writer *RESPWriter, args []RESP) {
L := server.lua
script := args[0].ToString()

numKeys, err := args[1].ToInt()
if err != nil {
writer.WriteError(err)
return
}

// set "KEYS" table
table := L.CreateTable(numKeys, 0)
for i := range numKeys {
table.Append(lua.LString(args[i+2]))
}
L.SetGlobal("KEYS", table)

if err := L.DoString(script); err != nil {
writer.WriteError(err)
return
}

ret := L.Get(-1)
switch ret.Type() {
case lua.LTString:
writer.WriteBulkString(ret.String())
case lua.LTNil:
writer.WriteNull()
default:
writer.WriteString("OK")
}
}

func todoCommand(writer *RESPWriter, _ []RESP) {
writer.WriteString("OK")
}
Expand Down
16 changes: 16 additions & 0 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,22 @@ func TestCommand(t *testing.T) {
assert.Equal(err.Error(), errWrongType.Error())
})

t.Run("eval", func(t *testing.T) {
keys := []string{"evalKey", "qwer"}

// set
_, err := rdb.Eval(ctx, "call('set',KEYS[0],KEYS[1])", keys).Result()
assert.Equal(err, redis.Nil)

// set with return
res, _ := rdb.Eval(ctx, "return call('set',KEYS[0],KEYS[1])", keys).Result()
assert.Equal(res, "OK")

// get
res, _ = rdb.Eval(ctx, "return call('get',KEYS[0])", keys[:1]).Result()
assert.Equal(res, keys[1])
})

t.Run("flushdb", func(t *testing.T) {
rdb.Set(ctx, "test-flush", "1", 0)
res, _ := rdb.FlushDB(ctx).Result()
Expand Down
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ require (
github.com/rs/zerolog v1.33.0
github.com/stretchr/testify v1.9.0
github.com/tidwall/mmap v0.3.0
golang.org/x/sys v0.24.0
github.com/yuin/gopher-lua v1.1.1
golang.org/x/sys v0.25.0
)

require (
Expand Down
6 changes: 4 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,17 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/mmap v0.3.0 h1:XXt1YsiXCF5/UAu3pLbu6g7iulJ9jsbs6vt7UpiV0sY=
github.com/tidwall/mmap v0.3.0/go.mod h1:2/dNzF5zA+te/JVHfrqNLcRkb8LjdH3c80vYHFQEZRk=
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8=
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg=
golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34=
golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
gonum.org/v1/gonum v0.0.0-20181121035319-3f7ecaa7e8ca h1:PupagGYwj8+I4ubCxcmcBRk3VlUWtTg5huQpZR9flmE=
gonum.org/v1/gonum v0.0.0-20181121035319-3f7ecaa7e8ca/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo=
Expand Down
37 changes: 37 additions & 0 deletions lua.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package main

import (
"github.com/xgzlucario/rotom/internal/dict"
lua "github.com/yuin/gopher-lua"
)

func luaCall(L *lua.LState) int {
fn := L.ToString(1)
keys := L.GetGlobal("KEYS").(*lua.LTable)
switch fn {
case "set":
return luaSet(L, keys)
case "get":
return luaGet(L, keys)
}
return -1
}

func luaSet(L *lua.LState, keys *lua.LTable) int {
key := keys.RawGetInt(1).String()
value := keys.RawGetInt(2).String()
db.dict.Set(key, []byte(value))
L.Push(lua.LString("OK"))
return 1
}

func luaGet(L *lua.LState, keys *lua.LTable) int {
key := keys.RawGetInt(1).String()
value, ttl := db.dict.Get(key)
if ttl != dict.KEY_NOT_EXIST {
L.Push(lua.LString(value.([]byte)))
} else {
L.Push(lua.LNil)
}
return 1
}
27 changes: 16 additions & 11 deletions rotom.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/xgzlucario/rotom/internal/hash"
"github.com/xgzlucario/rotom/internal/list"
"github.com/xgzlucario/rotom/internal/zset"
lua "github.com/yuin/gopher-lua"
)

const (
Expand All @@ -33,22 +34,20 @@ type DB struct {
}

type Client struct {
fd int

recvx int
readx int
queryBuf []byte

fd int
recvx int
readx int
queryBuf []byte
argsBuf []RESP
replyWriter *RESPWriter
}

type Server struct {
fd int
config *Config
aeLoop *AeLoop
clients map[int]*Client

fd int
config *Config
aeLoop *AeLoop
clients map[int]*Client
lua *lua.LState
outOfMemory bool
}

Expand Down Expand Up @@ -215,15 +214,21 @@ func SendReplyToClient(loop *AeLoop, fd int, extra interface{}) {
func initServer(config *Config) (err error) {
server.config = config
server.clients = make(map[int]*Client)
// init aeloop
server.aeLoop, err = AeLoopCreate()
if err != nil {
return err
}
// init tcp server
server.fd, err = TcpServer(config.Port)
if err != nil {
Close(server.fd)
return err
}
// init lua state
L := lua.NewState()
L.SetGlobal("call", L.NewFunction(luaCall))
server.lua = L
return nil
}

Expand Down

0 comments on commit 0059c24

Please sign in to comment.