From 110e6304860c2ea8d55517a216b8a7c739da8455 Mon Sep 17 00:00:00 2001 From: guangzhixu Date: Tue, 20 Aug 2024 14:29:28 +0800 Subject: [PATCH] refactor: remove dict object type --- aof_test.go | 7 ++- command.go | 110 ++++++++++++++++--------------------- command_test.go | 20 +++++++ internal/dict/dict.go | 74 +++++++++++-------------- internal/dict/dict_test.go | 42 ++++---------- internal/dict/object.go | 68 ----------------------- 6 files changed, 114 insertions(+), 207 deletions(-) delete mode 100644 internal/dict/object.go diff --git a/aof_test.go b/aof_test.go index 555323a..4edd516 100644 --- a/aof_test.go +++ b/aof_test.go @@ -35,7 +35,7 @@ func TestAof(t *testing.T) { defer aof.Close() }) - t.Run("read-error", func(t *testing.T) { + t.Run("empty-aof", func(t *testing.T) { aof, _ := NewAof("not-exist.aof") defer aof.Close() @@ -43,4 +43,9 @@ func TestAof(t *testing.T) { panic("should not call") }) }) + + t.Run("read-wrong-file", func(t *testing.T) { + _, err := NewAof("internal") + assert.NotNil(err) + }) } diff --git a/command.go b/command.go index 59ad37e..f62ea1e 100644 --- a/command.go +++ b/command.go @@ -141,24 +141,21 @@ func incrCommand(writer *RESPWriter, args []RESP) { return } - switch object.Type() { - case dict.TypeInteger: - num := object.Data().(int) + 1 - object.SetData(num) + switch v := object.(type) { + case int: + num := v + 1 writer.WriteInteger(num) + db.dict.Set(strings.Clone(key), num) - case dict.TypeString: + case []byte: // conv to integer - bytes := object.Data().([]byte) - num, err := RESP(bytes).ToInt() + num, err := RESP(v).ToInt() if err != nil { writer.WriteError(errParseInteger) return } num++ - bytes = bytes[:0] - bytes = strconv.AppendInt(bytes, int64(num), 10) - object.SetData(bytes) + strconv.AppendInt(v[:0], int64(num), 10) writer.WriteInteger(num) default: @@ -168,22 +165,16 @@ func incrCommand(writer *RESPWriter, args []RESP) { func getCommand(writer *RESPWriter, args []RESP) { key := args[0].ToStringUnsafe() - object, ttl := db.dict.Get(key) if ttl == dict.KEY_NOT_EXIST { writer.WriteNull() return } - - switch object.Type() { - case dict.TypeInteger: - num := object.Data().(int) - writer.WriteBulkString(strconv.Itoa(num)) - - case dict.TypeString: - bytes := object.Data().([]byte) - writer.WriteBulk(bytes) - + switch v := object.(type) { + case int: + writer.WriteBulkString(strconv.Itoa(v)) + case []byte: + writer.WriteBulk(v) default: writer.WriteError(errWrongType) } @@ -200,7 +191,7 @@ func delCommand(writer *RESPWriter, args []RESP) { } func hsetCommand(writer *RESPWriter, args []RESP) { - hash := args[0].ToStringUnsafe() + hash := args[0] args = args[1:] if len(args)%2 == 1 { @@ -226,15 +217,13 @@ func hsetCommand(writer *RESPWriter, args []RESP) { } func hgetCommand(writer *RESPWriter, args []RESP) { - hash := args[0].ToStringUnsafe() + hash := args[0] key := args[1].ToStringUnsafe() - hmap, err := fetchMap(hash) if err != nil { writer.WriteError(errWrongType) return } - value, ok := hmap.Get(key) if ok { writer.WriteBulk(value) @@ -244,9 +233,8 @@ func hgetCommand(writer *RESPWriter, args []RESP) { } func hdelCommand(writer *RESPWriter, args []RESP) { - hash := args[0].ToStringUnsafe() + hash := args[0] keys := args[1:] - hmap, err := fetchMap(hash) if err != nil { writer.WriteError(err) @@ -262,7 +250,7 @@ func hdelCommand(writer *RESPWriter, args []RESP) { } func hgetallCommand(writer *RESPWriter, args []RESP) { - hash := args[0].ToStringUnsafe() + hash := args[0] hmap, err := fetchMap(hash) if err != nil { writer.WriteError(err) @@ -276,7 +264,7 @@ func hgetallCommand(writer *RESPWriter, args []RESP) { } func lpushCommand(writer *RESPWriter, args []RESP) { - key := args[0].ToStringUnsafe() + key := args[0] ls, err := fetchList(key, true) if err != nil { writer.WriteError(err) @@ -289,7 +277,7 @@ func lpushCommand(writer *RESPWriter, args []RESP) { } func rpushCommand(writer *RESPWriter, args []RESP) { - key := args[0].ToStringUnsafe() + key := args[0] ls, err := fetchList(key, true) if err != nil { writer.WriteError(err) @@ -302,7 +290,7 @@ func rpushCommand(writer *RESPWriter, args []RESP) { } func lpopCommand(writer *RESPWriter, args []RESP) { - key := args[0].ToStringUnsafe() + key := args[0] ls, err := fetchList(key) if err != nil { writer.WriteError(err) @@ -317,7 +305,7 @@ func lpopCommand(writer *RESPWriter, args []RESP) { } func rpopCommand(writer *RESPWriter, args []RESP) { - key := args[0].ToStringUnsafe() + key := args[0] ls, err := fetchList(key) if err != nil { writer.WriteError(err) @@ -332,7 +320,7 @@ func rpopCommand(writer *RESPWriter, args []RESP) { } func lrangeCommand(writer *RESPWriter, args []RESP) { - key := args[0].ToStringUnsafe() + key := args[0] start, err := args[1].ToInt() if err != nil { writer.WriteError(err) @@ -361,18 +349,15 @@ func lrangeCommand(writer *RESPWriter, args []RESP) { } func saddCommand(writer *RESPWriter, args []RESP) { - key := args[0].ToStringUnsafe() - args = args[1:] - + key := args[0] set, err := fetchSet(key, true) if err != nil { writer.WriteError(err) return } - var count int - for i := 0; i < len(args); i++ { - if set.Add(args[i].ToString()) { + for _, arg := range args[1:] { + if set.Add(arg.ToString()) { count++ } } @@ -380,7 +365,7 @@ func saddCommand(writer *RESPWriter, args []RESP) { } func sremCommand(writer *RESPWriter, args []RESP) { - key := args[0].ToStringUnsafe() + key := args[0] set, err := fetchSet(key) if err != nil { writer.WriteError(err) @@ -396,7 +381,7 @@ func sremCommand(writer *RESPWriter, args []RESP) { } func spopCommand(writer *RESPWriter, args []RESP) { - key := args[0].ToStringUnsafe() + key := args[0] set, err := fetchSet(key) if err != nil { writer.WriteError(err) @@ -411,7 +396,7 @@ func spopCommand(writer *RESPWriter, args []RESP) { } func zaddCommand(writer *RESPWriter, args []RESP) { - key := args[0].ToStringUnsafe() + key := args[0] args = args[1:] zset, err := fetchZSet(key, true) @@ -436,7 +421,7 @@ func zaddCommand(writer *RESPWriter, args []RESP) { } func zrankCommand(writer *RESPWriter, args []RESP) { - key := args[0].ToStringUnsafe() + key := args[0] member := args[1].ToStringUnsafe() zset, err := fetchZSet(key) @@ -454,7 +439,7 @@ func zrankCommand(writer *RESPWriter, args []RESP) { } func zremCommand(writer *RESPWriter, args []RESP) { - key := args[0].ToStringUnsafe() + key := args[0] zset, err := fetchZSet(key) if err != nil { writer.WriteError(err) @@ -470,7 +455,7 @@ func zremCommand(writer *RESPWriter, args []RESP) { } func zrangeCommand(writer *RESPWriter, args []RESP) { - key := args[0].ToStringUnsafe() + key := args[0] start, err := args[1].ToInt() if err != nil { writer.WriteError(err) @@ -509,7 +494,7 @@ func zrangeCommand(writer *RESPWriter, args []RESP) { } func zpopminCommand(writer *RESPWriter, args []RESP) { - key := args[0].ToStringUnsafe() + key := args[0] count := 1 var err error if len(args) > 1 { @@ -544,47 +529,45 @@ func todoCommand(writer *RESPWriter, _ []RESP) { writer.WriteString("OK") } -func fetchMap(key string, setnx ...bool) (Map, error) { +func fetchMap(key []byte, setnx ...bool) (Map, error) { return fetch(key, func() Map { return hash.NewZipMap() }, setnx...) } -func fetchList(key string, setnx ...bool) (List, error) { +func fetchList(key []byte, setnx ...bool) (List, error) { return fetch(key, func() List { return list.New() }, setnx...) } -func fetchSet(key string, setnx ...bool) (Set, error) { +func fetchSet(key []byte, setnx ...bool) (Set, error) { return fetch(key, func() Set { return hash.NewZipSet() }, setnx...) } -func fetchZSet(key string, setnx ...bool) (ZSet, error) { +func fetchZSet(key []byte, setnx ...bool) (ZSet, error) { return fetch(key, func() ZSet { return zset.NewZSet() }, setnx...) } -func fetch[T any](key string, new func() T, setnx ...bool) (T, error) { - object, ttl := db.dict.Get(key) +func fetch[T any](key []byte, new func() T, setnx ...bool) (T, error) { + object, ttl := db.dict.Get(b2s(key)) if ttl != dict.KEY_NOT_EXIST { - v, ok := object.Data().(T) + v, ok := object.(T) if !ok { return v, errWrongType } // conversion zipped structure if len(setnx) > 0 && setnx[0] { - switch object.Type() { - case dict.TypeZipMap: - zm := object.Data().(*hash.ZipMap) - if zm.Len() < 256 { + switch data := object.(type) { + case *hash.ZipMap: + if data.Len() < 256 { break } - object.SetData(zm.ToMap()) + db.dict.Set(string(key), data.ToMap()) - case dict.TypeZipSet: - zm := object.Data().(*hash.ZipSet) - if zm.Len() < 512 { + case *hash.ZipSet: + if data.Len() < 512 { break } - object.SetData(zm.ToSet()) + db.dict.Set(string(key), data.ToSet()) } } return v, nil @@ -592,8 +575,7 @@ func fetch[T any](key string, new func() T, setnx ...bool) (T, error) { v := new() if len(setnx) > 0 && setnx[0] { - // make sure `key` is copy - db.dict.Set(strings.Clone(key), v) + db.dict.Set(string(key), v) } return v, nil diff --git a/command_test.go b/command_test.go index 16aec2b..27beb31 100644 --- a/command_test.go +++ b/command_test.go @@ -60,6 +60,14 @@ func TestCommand(t *testing.T) { assert.Equal(n, int64(1)) }) + t.Run("error-get", func(t *testing.T) { + lskey := fmt.Sprintf("ls-%d", time.Now().UnixNano()) + rdb.RPush(ctx, lskey, "1") + + _, err := rdb.Get(ctx, lskey).Result() + assert.Equal(err.Error(), errWrongType.Error()) + }) + t.Run("setex", func(t *testing.T) { res, _ := rdb.Set(ctx, "foo", "bar", time.Second).Result() assert.Equal(res, "OK") @@ -251,6 +259,18 @@ func TestCommand(t *testing.T) { rdb.SAdd(ctx, "set", "k1", "k2", "k3").Result() res, _ := rdb.SRem(ctx, "set", "k1", "k2", "k999").Result() assert.Equal(res, int64(2)) + + // error wrong type + rdb.Set(ctx, "key", "value", 0) + + _, err = rdb.SAdd(ctx, "key", "1").Result() + assert.Equal(err.Error(), errWrongType.Error()) + + _, err = rdb.SRem(ctx, "key", "1").Result() + assert.Equal(err.Error(), errWrongType.Error()) + + _, err = rdb.SPop(ctx, "key").Result() + assert.Equal(err.Error(), errWrongType.Error()) }) t.Run("zset", func(t *testing.T) { diff --git a/internal/dict/dict.go b/internal/dict/dict.go index 46ac0ba..ecc838c 100644 --- a/internal/dict/dict.go +++ b/internal/dict/dict.go @@ -8,7 +8,7 @@ import ( ) const ( - TTL_DEFAULT = -1 + TTL_FOREVER = -1 KEY_NOT_EXIST = -2 ) @@ -34,70 +34,57 @@ func GetNanoTime() int64 { // Dict is the hashmap for Rotom. type Dict struct { - data *swiss.Map[string, *Object] + data *swiss.Map[string, any] expire *swiss.Map[string, int64] } func New() *Dict { return &Dict{ - data: swiss.NewMap[string, *Object](64), + data: swiss.NewMap[string, any](64), expire: swiss.NewMap[string, int64](64), } } -func (dict *Dict) Get(key string) (*Object, int) { - object, ok := dict.data.Get(key) +func (dict *Dict) Get(key string) (any, int) { + data, ok := dict.data.Get(key) if !ok { // key not exist return nil, KEY_NOT_EXIST } - object.lastAccessd = _sec.Load() + nsec, ok := dict.expire.Get(key) + if !ok { + return data, TTL_FOREVER + } - if object.hasTTL { - nsec, _ := dict.expire.Get(key) - // key expired - if nsec < _nsec.Load() { - dict.data.Delete(key) - dict.expire.Delete(key) - return nil, KEY_NOT_EXIST - } - return object, nsec2duration(nsec) + // key expired + if nsec < _nsec.Load() { + dict.data.Delete(key) + dict.expire.Delete(key) + return nil, KEY_NOT_EXIST } - return object, TTL_DEFAULT + return data, nsec2duration(nsec) } func (dict *Dict) Set(key string, data any) { - dict.data.Put(key, &Object{ - typ: typeOfData(data), - lastAccessd: _sec.Load(), - data: data, - }) + dict.data.Put(key, data) } func (dict *Dict) SetWithTTL(key string, data any, ttl int64) { - object := &Object{ - typ: typeOfData(data), - lastAccessd: _sec.Load(), - data: data, - } if ttl > 0 { dict.expire.Put(key, ttl) - object.hasTTL = true } - dict.data.Put(key, object) + dict.data.Put(key, data) } func (dict *Dict) Delete(key string) bool { - object, ok := dict.data.Get(key) + _, ok := dict.data.Get(key) if !ok { return false } dict.data.Delete(key) - if object.hasTTL { - dict.expire.Delete(key) - } + dict.expire.Delete(key) return true } @@ -105,22 +92,21 @@ func (dict *Dict) Delete(key string) bool { // return `0` if key not exist or expired. // return `1` if set successed. func (dict *Dict) SetTTL(key string, ttl int64) int { - object, ok := dict.data.Get(key) + _, ok := dict.data.Get(key) if !ok { // key not exist return 0 } - if object.hasTTL { - nsec, _ := dict.expire.Get(key) - // key expired - if nsec < _nsec.Load() { - dict.data.Delete(key) - dict.expire.Delete(key) - return 0 - } + + // check key if already expired + nsec, ok := dict.expire.Get(key) + if ok && nsec < _nsec.Load() { + dict.data.Delete(key) + dict.expire.Delete(key) + return 0 } + // set ttl - object.hasTTL = true dict.expire.Put(key, ttl) return 1 } @@ -136,3 +122,7 @@ func (dict *Dict) EvictExpired() { return count > 20 }) } + +func nsec2duration(nsec int64) (second int) { + return int(nsec-_nsec.Load()) / int(time.Second) +} diff --git a/internal/dict/dict_test.go b/internal/dict/dict_test.go index 7a3ef57..c178f14 100644 --- a/internal/dict/dict_test.go +++ b/internal/dict/dict_test.go @@ -5,9 +5,6 @@ import ( "time" "github.com/stretchr/testify/assert" - "github.com/xgzlucario/rotom/internal/hash" - "github.com/xgzlucario/rotom/internal/list" - "github.com/xgzlucario/rotom/internal/zset" ) func TestDict(t *testing.T) { @@ -17,13 +14,12 @@ func TestDict(t *testing.T) { dict := New() dict.Set("key", []byte("hello")) - object, ttl := dict.Get("key") - assert.Equal(ttl, TTL_DEFAULT) - assert.Equal(object.Data(), []byte("hello")) - assert.Equal(object.Type(), TypeString) + data, ttl := dict.Get("key") + assert.Equal(ttl, TTL_FOREVER) + assert.Equal(data, []byte("hello")) - object, ttl = dict.Get("none") - assert.Nil(object) + data, ttl = dict.Get("none") + assert.Nil(data) assert.Equal(ttl, KEY_NOT_EXIST) }) @@ -33,21 +29,20 @@ func TestDict(t *testing.T) { dict.SetWithTTL("key", []byte("hello"), time.Now().Add(time.Minute).UnixNano()) time.Sleep(time.Second / 10) - object, ttl := dict.Get("key") + data, ttl := dict.Get("key") assert.Equal(ttl, 59) - assert.Equal(object.Data(), []byte("hello")) - assert.Equal(object.Type(), TypeString) + assert.Equal(data, []byte("hello")) res := dict.SetTTL("key", time.Now().Add(-time.Second).UnixNano()) assert.Equal(res, 1) - res = dict.SetTTL("not-exist", TTL_DEFAULT) + res = dict.SetTTL("not-exist", TTL_FOREVER) assert.Equal(res, 0) // get expired - object, ttl = dict.Get("key") + data, ttl = dict.Get("key") assert.Equal(ttl, KEY_NOT_EXIST) - assert.Nil(object) + assert.Nil(data) // setTTL expired dict.SetWithTTL("keyx", []byte("hello"), time.Now().Add(-time.Second).UnixNano()) @@ -70,20 +65,3 @@ func TestDict(t *testing.T) { assert.True(ok) }) } - -func TestOnject(t *testing.T) { - assert := assert.New(t) - - object := new(Object) - object.SetData([]byte("hello")) - object.SetData(1) - object.SetData(hash.NewZipMap()) - object.SetData(hash.NewMap()) - object.SetData(hash.NewZipSet()) - object.SetData(hash.NewSet()) - object.SetData(list.New()) - object.SetData(zset.NewZSet()) - assert.Panics(func() { - object.SetData(time.Now()) - }) -} diff --git a/internal/dict/object.go b/internal/dict/object.go deleted file mode 100644 index 653c099..0000000 --- a/internal/dict/object.go +++ /dev/null @@ -1,68 +0,0 @@ -package dict - -import ( - "fmt" - "time" - - "github.com/xgzlucario/rotom/internal/hash" - "github.com/xgzlucario/rotom/internal/list" - "github.com/xgzlucario/rotom/internal/zset" -) - -// Type defines all rotom data types. -type Type byte - -const ( - TypeString Type = iota + 1 - TypeInteger - TypeMap - TypeZipMap - TypeSet - TypeZipSet - TypeList - TypeZSet -) - -// Object is the basic element for storing in dict. -type Object struct { - typ Type - hasTTL bool - lastAccessd uint32 - data any -} - -func (o *Object) Type() Type { return o.typ } - -func (o *Object) Data() any { return o.data } - -func (o *Object) SetData(data any) { - o.typ = typeOfData(data) - o.data = data -} - -func nsec2duration(nsec int64) (second int) { - return int(nsec-_nsec.Load()) / int(time.Second) -} - -func typeOfData(data any) Type { - switch data.(type) { - case []byte: - return TypeString - case int: - return TypeInteger - case *hash.Map: - return TypeMap - case *hash.ZipMap: - return TypeZipMap - case *hash.Set: - return TypeSet - case *hash.ZipSet: - return TypeZipSet - case *list.QuickList: - return TypeList - case *zset.ZSet: - return TypeZSet - default: - panic(fmt.Sprintf("unknown type: %T", data)) - } -}