Skip to content

Commit

Permalink
refactor: remove dict object type
Browse files Browse the repository at this point in the history
  • Loading branch information
satoshi-099 committed Aug 20, 2024
1 parent a69e5be commit 110e630
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 207 deletions.
7 changes: 6 additions & 1 deletion aof_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,17 @@ func TestAof(t *testing.T) {
defer aof.Close()
})

t.Run("read-error", func(t *testing.T) {
t.Run("empty-aof", func(t *testing.T) {
aof, _ := NewAof("not-exist.aof")
defer aof.Close()

aof.Read(func(args []RESP) {
panic("should not call")
})
})

t.Run("read-wrong-file", func(t *testing.T) {
_, err := NewAof("internal")
assert.NotNil(err)
})
}
110 changes: 46 additions & 64 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,24 +141,21 @@ func incrCommand(writer *RESPWriter, args []RESP) {
return
}

switch object.Type() {
case dict.TypeInteger:
num := object.Data().(int) + 1
object.SetData(num)
switch v := object.(type) {
case int:
num := v + 1
writer.WriteInteger(num)
db.dict.Set(strings.Clone(key), num)

case dict.TypeString:
case []byte:
// conv to integer
bytes := object.Data().([]byte)
num, err := RESP(bytes).ToInt()
num, err := RESP(v).ToInt()
if err != nil {
writer.WriteError(errParseInteger)
return
}
num++
bytes = bytes[:0]
bytes = strconv.AppendInt(bytes, int64(num), 10)
object.SetData(bytes)
strconv.AppendInt(v[:0], int64(num), 10)
writer.WriteInteger(num)

default:
Expand All @@ -168,22 +165,16 @@ func incrCommand(writer *RESPWriter, args []RESP) {

func getCommand(writer *RESPWriter, args []RESP) {
key := args[0].ToStringUnsafe()

object, ttl := db.dict.Get(key)
if ttl == dict.KEY_NOT_EXIST {
writer.WriteNull()
return
}

switch object.Type() {
case dict.TypeInteger:
num := object.Data().(int)
writer.WriteBulkString(strconv.Itoa(num))

case dict.TypeString:
bytes := object.Data().([]byte)
writer.WriteBulk(bytes)

switch v := object.(type) {
case int:
writer.WriteBulkString(strconv.Itoa(v))
case []byte:
writer.WriteBulk(v)
default:
writer.WriteError(errWrongType)
}
Expand All @@ -200,7 +191,7 @@ func delCommand(writer *RESPWriter, args []RESP) {
}

func hsetCommand(writer *RESPWriter, args []RESP) {
hash := args[0].ToStringUnsafe()
hash := args[0]
args = args[1:]

if len(args)%2 == 1 {
Expand All @@ -226,15 +217,13 @@ func hsetCommand(writer *RESPWriter, args []RESP) {
}

func hgetCommand(writer *RESPWriter, args []RESP) {
hash := args[0].ToStringUnsafe()
hash := args[0]
key := args[1].ToStringUnsafe()

hmap, err := fetchMap(hash)
if err != nil {
writer.WriteError(errWrongType)
return
}

value, ok := hmap.Get(key)
if ok {
writer.WriteBulk(value)
Expand All @@ -244,9 +233,8 @@ func hgetCommand(writer *RESPWriter, args []RESP) {
}

func hdelCommand(writer *RESPWriter, args []RESP) {
hash := args[0].ToStringUnsafe()
hash := args[0]
keys := args[1:]

hmap, err := fetchMap(hash)
if err != nil {
writer.WriteError(err)
Expand All @@ -262,7 +250,7 @@ func hdelCommand(writer *RESPWriter, args []RESP) {
}

func hgetallCommand(writer *RESPWriter, args []RESP) {
hash := args[0].ToStringUnsafe()
hash := args[0]
hmap, err := fetchMap(hash)
if err != nil {
writer.WriteError(err)
Expand All @@ -276,7 +264,7 @@ func hgetallCommand(writer *RESPWriter, args []RESP) {
}

func lpushCommand(writer *RESPWriter, args []RESP) {
key := args[0].ToStringUnsafe()
key := args[0]
ls, err := fetchList(key, true)
if err != nil {
writer.WriteError(err)
Expand All @@ -289,7 +277,7 @@ func lpushCommand(writer *RESPWriter, args []RESP) {
}

func rpushCommand(writer *RESPWriter, args []RESP) {
key := args[0].ToStringUnsafe()
key := args[0]
ls, err := fetchList(key, true)
if err != nil {
writer.WriteError(err)
Expand All @@ -302,7 +290,7 @@ func rpushCommand(writer *RESPWriter, args []RESP) {
}

func lpopCommand(writer *RESPWriter, args []RESP) {
key := args[0].ToStringUnsafe()
key := args[0]
ls, err := fetchList(key)
if err != nil {
writer.WriteError(err)
Expand All @@ -317,7 +305,7 @@ func lpopCommand(writer *RESPWriter, args []RESP) {
}

func rpopCommand(writer *RESPWriter, args []RESP) {
key := args[0].ToStringUnsafe()
key := args[0]
ls, err := fetchList(key)
if err != nil {
writer.WriteError(err)
Expand All @@ -332,7 +320,7 @@ func rpopCommand(writer *RESPWriter, args []RESP) {
}

func lrangeCommand(writer *RESPWriter, args []RESP) {
key := args[0].ToStringUnsafe()
key := args[0]
start, err := args[1].ToInt()
if err != nil {
writer.WriteError(err)
Expand Down Expand Up @@ -361,26 +349,23 @@ func lrangeCommand(writer *RESPWriter, args []RESP) {
}

func saddCommand(writer *RESPWriter, args []RESP) {
key := args[0].ToStringUnsafe()
args = args[1:]

key := args[0]
set, err := fetchSet(key, true)
if err != nil {
writer.WriteError(err)
return
}

var count int
for i := 0; i < len(args); i++ {
if set.Add(args[i].ToString()) {
for _, arg := range args[1:] {
if set.Add(arg.ToString()) {
count++
}
}
writer.WriteInteger(count)
}

func sremCommand(writer *RESPWriter, args []RESP) {
key := args[0].ToStringUnsafe()
key := args[0]
set, err := fetchSet(key)
if err != nil {
writer.WriteError(err)
Expand All @@ -396,7 +381,7 @@ func sremCommand(writer *RESPWriter, args []RESP) {
}

func spopCommand(writer *RESPWriter, args []RESP) {
key := args[0].ToStringUnsafe()
key := args[0]
set, err := fetchSet(key)
if err != nil {
writer.WriteError(err)
Expand All @@ -411,7 +396,7 @@ func spopCommand(writer *RESPWriter, args []RESP) {
}

func zaddCommand(writer *RESPWriter, args []RESP) {
key := args[0].ToStringUnsafe()
key := args[0]
args = args[1:]

zset, err := fetchZSet(key, true)
Expand All @@ -436,7 +421,7 @@ func zaddCommand(writer *RESPWriter, args []RESP) {
}

func zrankCommand(writer *RESPWriter, args []RESP) {
key := args[0].ToStringUnsafe()
key := args[0]
member := args[1].ToStringUnsafe()

zset, err := fetchZSet(key)
Expand All @@ -454,7 +439,7 @@ func zrankCommand(writer *RESPWriter, args []RESP) {
}

func zremCommand(writer *RESPWriter, args []RESP) {
key := args[0].ToStringUnsafe()
key := args[0]
zset, err := fetchZSet(key)
if err != nil {
writer.WriteError(err)
Expand All @@ -470,7 +455,7 @@ func zremCommand(writer *RESPWriter, args []RESP) {
}

func zrangeCommand(writer *RESPWriter, args []RESP) {
key := args[0].ToStringUnsafe()
key := args[0]
start, err := args[1].ToInt()
if err != nil {
writer.WriteError(err)
Expand Down Expand Up @@ -509,7 +494,7 @@ func zrangeCommand(writer *RESPWriter, args []RESP) {
}

func zpopminCommand(writer *RESPWriter, args []RESP) {
key := args[0].ToStringUnsafe()
key := args[0]
count := 1
var err error
if len(args) > 1 {
Expand Down Expand Up @@ -544,56 +529,53 @@ func todoCommand(writer *RESPWriter, _ []RESP) {
writer.WriteString("OK")
}

func fetchMap(key string, setnx ...bool) (Map, error) {
func fetchMap(key []byte, setnx ...bool) (Map, error) {
return fetch(key, func() Map { return hash.NewZipMap() }, setnx...)
}

func fetchList(key string, setnx ...bool) (List, error) {
func fetchList(key []byte, setnx ...bool) (List, error) {
return fetch(key, func() List { return list.New() }, setnx...)
}

func fetchSet(key string, setnx ...bool) (Set, error) {
func fetchSet(key []byte, setnx ...bool) (Set, error) {
return fetch(key, func() Set { return hash.NewZipSet() }, setnx...)
}

func fetchZSet(key string, setnx ...bool) (ZSet, error) {
func fetchZSet(key []byte, setnx ...bool) (ZSet, error) {
return fetch(key, func() ZSet { return zset.NewZSet() }, setnx...)
}

func fetch[T any](key string, new func() T, setnx ...bool) (T, error) {
object, ttl := db.dict.Get(key)
func fetch[T any](key []byte, new func() T, setnx ...bool) (T, error) {
object, ttl := db.dict.Get(b2s(key))

if ttl != dict.KEY_NOT_EXIST {
v, ok := object.Data().(T)
v, ok := object.(T)
if !ok {
return v, errWrongType
}

// conversion zipped structure
if len(setnx) > 0 && setnx[0] {
switch object.Type() {
case dict.TypeZipMap:
zm := object.Data().(*hash.ZipMap)
if zm.Len() < 256 {
switch data := object.(type) {
case *hash.ZipMap:
if data.Len() < 256 {
break
}
object.SetData(zm.ToMap())
db.dict.Set(string(key), data.ToMap())

case dict.TypeZipSet:
zm := object.Data().(*hash.ZipSet)
if zm.Len() < 512 {
case *hash.ZipSet:
if data.Len() < 512 {
break
}
object.SetData(zm.ToSet())
db.dict.Set(string(key), data.ToSet())
}
}
return v, nil
}

v := new()
if len(setnx) > 0 && setnx[0] {
// make sure `key` is copy
db.dict.Set(strings.Clone(key), v)
db.dict.Set(string(key), v)
}

return v, nil
Expand Down
20 changes: 20 additions & 0 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,14 @@ func TestCommand(t *testing.T) {
assert.Equal(n, int64(1))
})

t.Run("error-get", func(t *testing.T) {
lskey := fmt.Sprintf("ls-%d", time.Now().UnixNano())
rdb.RPush(ctx, lskey, "1")

_, err := rdb.Get(ctx, lskey).Result()
assert.Equal(err.Error(), errWrongType.Error())
})

t.Run("setex", func(t *testing.T) {
res, _ := rdb.Set(ctx, "foo", "bar", time.Second).Result()
assert.Equal(res, "OK")
Expand Down Expand Up @@ -251,6 +259,18 @@ func TestCommand(t *testing.T) {
rdb.SAdd(ctx, "set", "k1", "k2", "k3").Result()
res, _ := rdb.SRem(ctx, "set", "k1", "k2", "k999").Result()
assert.Equal(res, int64(2))

// error wrong type
rdb.Set(ctx, "key", "value", 0)

_, err = rdb.SAdd(ctx, "key", "1").Result()
assert.Equal(err.Error(), errWrongType.Error())

_, err = rdb.SRem(ctx, "key", "1").Result()
assert.Equal(err.Error(), errWrongType.Error())

_, err = rdb.SPop(ctx, "key").Result()
assert.Equal(err.Error(), errWrongType.Error())
})

t.Run("zset", func(t *testing.T) {
Expand Down
Loading

0 comments on commit 110e630

Please sign in to comment.