diff --git a/command.go b/command.go index 3044635..d851095 100644 --- a/command.go +++ b/command.go @@ -10,6 +10,7 @@ import ( "github.com/xgzlucario/rotom/internal/hash" "github.com/xgzlucario/rotom/internal/list" "github.com/xgzlucario/rotom/internal/zset" + lua "github.com/yuin/gopher-lua" ) var ( @@ -58,6 +59,7 @@ var cmdTable []*Command = []*Command{ {"zrank", zrankCommand, 2, false}, {"zpopmin", zpopminCommand, 1, true}, {"zrange", zrangeCommand, 3, false}, + {"eval", evalCommand, 2, true}, {"ping", pingCommand, 0, false}, {"flushdb", flushdbCommand, 0, true}, // TODO @@ -537,6 +539,39 @@ func flushdbCommand(writer *RESPWriter, _ []RESP) { writer.WriteString("OK") } +func evalCommand(writer *RESPWriter, args []RESP) { + L := server.lua + script := args[0].ToString() + + numKeys, err := args[1].ToInt() + if err != nil { + writer.WriteError(err) + return + } + + // set "KEYS" table + table := L.CreateTable(numKeys, 0) + for i := range numKeys { + table.Append(lua.LString(args[i+2])) + } + L.SetGlobal("KEYS", table) + + if err := L.DoString(script); err != nil { + writer.WriteError(err) + return + } + + ret := L.Get(-1) + switch ret.Type() { + case lua.LTString: + writer.WriteBulkString(ret.String()) + case lua.LTNil: + writer.WriteNull() + default: + writer.WriteString("OK") + } +} + func todoCommand(writer *RESPWriter, _ []RESP) { writer.WriteString("OK") } diff --git a/command_test.go b/command_test.go index c0974e4..0d95c4c 100644 --- a/command_test.go +++ b/command_test.go @@ -376,6 +376,22 @@ func TestCommand(t *testing.T) { assert.Equal(err.Error(), errWrongType.Error()) }) + t.Run("eval", func(t *testing.T) { + keys := []string{"evalKey", "qwer"} + + // set + _, err := rdb.Eval(ctx, "call('set',KEYS[0],KEYS[1])", keys).Result() + assert.Equal(err, redis.Nil) + + // set with return + res, _ := rdb.Eval(ctx, "return call('set',KEYS[0],KEYS[1])", keys).Result() + assert.Equal(res, "OK") + + // get + res, _ = rdb.Eval(ctx, "return call('get',KEYS[0])", keys[:1]).Result() + assert.Equal(res, keys[1]) + }) + t.Run("flushdb", func(t *testing.T) { rdb.Set(ctx, "test-flush", "1", 0) res, _ := rdb.FlushDB(ctx).Result() diff --git a/go.mod b/go.mod index c3bdbe3..fdf2d4d 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,8 @@ require ( github.com/rs/zerolog v1.33.0 github.com/stretchr/testify v1.9.0 github.com/tidwall/mmap v0.3.0 - golang.org/x/sys v0.24.0 + github.com/yuin/gopher-lua v1.1.1 + golang.org/x/sys v0.25.0 ) require ( diff --git a/go.sum b/go.sum index 8fb972a..96b0fc2 100644 --- a/go.sum +++ b/go.sum @@ -50,6 +50,8 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tidwall/mmap v0.3.0 h1:XXt1YsiXCF5/UAu3pLbu6g7iulJ9jsbs6vt7UpiV0sY= github.com/tidwall/mmap v0.3.0/go.mod h1:2/dNzF5zA+te/JVHfrqNLcRkb8LjdH3c80vYHFQEZRk= +github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= +github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8= golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= @@ -57,8 +59,8 @@ golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= -golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= +golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gonum.org/v1/gonum v0.0.0-20181121035319-3f7ecaa7e8ca h1:PupagGYwj8+I4ubCxcmcBRk3VlUWtTg5huQpZR9flmE= gonum.org/v1/gonum v0.0.0-20181121035319-3f7ecaa7e8ca/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= diff --git a/lua.go b/lua.go new file mode 100644 index 0000000..e18d3d8 --- /dev/null +++ b/lua.go @@ -0,0 +1,37 @@ +package main + +import ( + "github.com/xgzlucario/rotom/internal/dict" + lua "github.com/yuin/gopher-lua" +) + +func luaCall(L *lua.LState) int { + fn := L.ToString(1) + keys := L.GetGlobal("KEYS").(*lua.LTable) + switch fn { + case "set": + return luaSet(L, keys) + case "get": + return luaGet(L, keys) + } + return -1 +} + +func luaSet(L *lua.LState, keys *lua.LTable) int { + key := keys.RawGetInt(1).String() + value := keys.RawGetInt(2).String() + db.dict.Set(key, []byte(value)) + L.Push(lua.LString("OK")) + return 1 +} + +func luaGet(L *lua.LState, keys *lua.LTable) int { + key := keys.RawGetInt(1).String() + value, ttl := db.dict.Get(key) + if ttl != dict.KEY_NOT_EXIST { + L.Push(lua.LString(value.([]byte))) + } else { + L.Push(lua.LNil) + } + return 1 +} diff --git a/rotom.go b/rotom.go index 26d77c2..b48ed16 100644 --- a/rotom.go +++ b/rotom.go @@ -11,6 +11,7 @@ import ( "github.com/xgzlucario/rotom/internal/hash" "github.com/xgzlucario/rotom/internal/list" "github.com/xgzlucario/rotom/internal/zset" + lua "github.com/yuin/gopher-lua" ) const ( @@ -33,22 +34,20 @@ type DB struct { } type Client struct { - fd int - - recvx int - readx int - queryBuf []byte - + fd int + recvx int + readx int + queryBuf []byte argsBuf []RESP replyWriter *RESPWriter } type Server struct { - fd int - config *Config - aeLoop *AeLoop - clients map[int]*Client - + fd int + config *Config + aeLoop *AeLoop + clients map[int]*Client + lua *lua.LState outOfMemory bool } @@ -215,15 +214,21 @@ func SendReplyToClient(loop *AeLoop, fd int, extra interface{}) { func initServer(config *Config) (err error) { server.config = config server.clients = make(map[int]*Client) + // init aeloop server.aeLoop, err = AeLoopCreate() if err != nil { return err } + // init tcp server server.fd, err = TcpServer(config.Port) if err != nil { Close(server.fd) return err } + // init lua state + L := lua.NewState() + L.SetGlobal("call", L.NewFunction(luaCall)) + server.lua = L return nil }