Skip to content

Commit

Permalink
refator: use stdmap to replace swissmap in rotom
Browse files Browse the repository at this point in the history
  • Loading branch information
xgzlucario committed Aug 21, 2024
1 parent 68e9d26 commit 57847e8
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 70 deletions.
4 changes: 2 additions & 2 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func setCommand(writer *RESPWriter, args []RESP) {
writer.WriteError(errParseInteger)
return
}
ttl = time.Now().Add(time.Second * time.Duration(n)).UnixNano()
ttl = dict.GetNanoTime() + int64(time.Second)*int64(n)
extra = extra[2:]

// PX
Expand All @@ -114,7 +114,7 @@ func setCommand(writer *RESPWriter, args []RESP) {
writer.WriteError(errParseInteger)
return
}
ttl = time.Now().Add(time.Millisecond * time.Duration(n)).UnixNano()
ttl = dict.GetNanoTime() + int64(time.Millisecond)*int64(n)
extra = extra[2:]

// KEEPTTL
Expand Down
12 changes: 12 additions & 0 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,18 @@ func TestCommand(t *testing.T) {

res, _ := rdb.ZRem(ctx, "rank", "player1", "player2", "player999").Result()
assert.Equal(res, int64(2))

// err wrong type
rdb.Set(ctx, "key", "value", 0)

_, err := rdb.ZAdd(ctx, "key", redis.Z{}).Result()
assert.Equal(err.Error(), errWrongType.Error())

_, err = rdb.ZRank(ctx, "key", "member1").Result()
assert.Equal(err.Error(), errWrongType.Error())

_, err = rdb.ZRem(ctx, "key", "member1").Result()
assert.Equal(err.Error(), errWrongType.Error())
})

t.Run("flushdb", func(t *testing.T) {
Expand Down
2 changes: 0 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ go 1.22
require (
github.com/chen3feng/stl4go v0.1.1
github.com/deckarep/golang-set/v2 v2.6.0
github.com/dolthub/swiss v0.2.1
github.com/influxdata/tdigest v0.0.1
github.com/redis/go-redis/v9 v9.5.2
github.com/rs/zerolog v1.33.0
Expand All @@ -18,7 +17,6 @@ require (
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dolthub/maphash v0.1.0 // indirect
github.com/edsrzf/mmap-go v1.1.0 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/kr/pretty v0.3.0 // indirect
Expand Down
4 changes: 0 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ github.com/deckarep/golang-set/v2 v2.6.0 h1:XfcQbWM1LlMB8BsJ8N9vW5ehnnPVIw0je80N
github.com/deckarep/golang-set/v2 v2.6.0/go.mod h1:VAky9rY/yGXJOLEDv3OMci+7wtDpOF4IN+y82NBOac4=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dolthub/maphash v0.1.0 h1:bsQ7JsF4FkkWyrP3oCnFJgrCUAFbFf3kOl4L/QxPDyQ=
github.com/dolthub/maphash v0.1.0/go.mod h1:gkg4Ch4CdCDu5h6PMriVLawB7koZ+5ijb9puGMV50a4=
github.com/dolthub/swiss v0.2.1 h1:gs2osYs5SJkAaH5/ggVJqXQxRXtWshF6uE0lgR/Y3Gw=
github.com/dolthub/swiss v0.2.1/go.mod h1:8AhKZZ1HK7g18j7v7k6c5cYIGEZJcPn0ARsai8cUrh0=
github.com/edsrzf/mmap-go v1.1.0 h1:6EUwBLQ/Mcr1EYLE4Tn1VdW1A4ckqCQWZBw8Hr0kjpQ=
github.com/edsrzf/mmap-go v1.1.0/go.mod h1:19H/e8pUPLicwkyNgOykDXkJ9F0MHE+Z52B8EIth78Q=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
Expand Down
10 changes: 0 additions & 10 deletions internal/dict/benchmark/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"runtime/debug"
"time"

"github.com/dolthub/swiss"
"github.com/influxdata/tdigest"
"github.com/xgzlucario/rotom/internal/dict"
)
Expand Down Expand Up @@ -60,15 +59,6 @@ func main() {
m[k] = v
td.Add(float64(time.Since(start)), 1)
}

case "swiss":
m := swiss.NewMap[string, any](8)
for i := 0; i < entries; i++ {
k, v := genKV(i)
start := time.Now()
m.Put(k, v)
td.Add(float64(time.Since(start)), 1)
}
}
cost := time.Since(start)

Expand Down
52 changes: 26 additions & 26 deletions internal/dict/dict.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package dict
import (
"sync/atomic"
"time"

"github.com/dolthub/swiss"
)

const (
Expand Down Expand Up @@ -34,93 +32,95 @@ func GetNanoTime() int64 {

// Dict is the hashmap for Rotom.
type Dict struct {
data *swiss.Map[string, any]
expire *swiss.Map[string, int64]
data map[string]any
expire map[string]int64
}

func New() *Dict {
return &Dict{
data: swiss.NewMap[string, any](64),
expire: swiss.NewMap[string, int64](64),
data: make(map[string]any, 64),
expire: make(map[string]int64, 64),
}
}

func (dict *Dict) Get(key string) (any, int) {
data, ok := dict.data.Get(key)
data, ok := dict.data[key]
if !ok {
// key not exist
return nil, KEY_NOT_EXIST
}

nsec, ok := dict.expire.Get(key)
nsec, ok := dict.expire[key]
if !ok {
return data, TTL_FOREVER
}

// key expired
if nsec < _nsec.Load() {
dict.data.Delete(key)
dict.expire.Delete(key)
delete(dict.data, key)
delete(dict.expire, key)
return nil, KEY_NOT_EXIST
}

return data, nsec2duration(nsec)
}

func (dict *Dict) Set(key string, data any) {
dict.data.Put(key, data)
dict.data[key] = data
}

func (dict *Dict) SetWithTTL(key string, data any, ttl int64) {
if ttl > 0 {
dict.expire.Put(key, ttl)
dict.expire[key] = ttl
}
dict.data.Put(key, data)
dict.data[key] = data
}

func (dict *Dict) Delete(key string) bool {
_, ok := dict.data.Get(key)
_, ok := dict.data[key]
if !ok {
return false
}
dict.data.Delete(key)
dict.expire.Delete(key)
delete(dict.data, key)
delete(dict.expire, key)
return true
}

// SetTTL set expire time for key.
// return `0` if key not exist or expired.
// return `1` if set successed.
func (dict *Dict) SetTTL(key string, ttl int64) int {
_, ok := dict.data.Get(key)
_, ok := dict.data[key]
if !ok {
// key not exist
return 0
}

// check key if already expired
nsec, ok := dict.expire.Get(key)
nsec, ok := dict.expire[key]
if ok && nsec < _nsec.Load() {
dict.data.Delete(key)
dict.expire.Delete(key)
delete(dict.data, key)
delete(dict.expire, key)
return 0
}

// set ttl
dict.expire.Put(key, ttl)
dict.expire[key] = ttl
return 1
}

func (dict *Dict) EvictExpired() {
var count int
dict.expire.Iter(func(key string, nsec int64) bool {
for key, nsec := range dict.expire {
if _nsec.Load() > nsec {
dict.expire.Delete(key)
dict.data.Delete(key)
delete(dict.expire, key)
delete(dict.data, key)
}
count++
return count > 20
})
if count > 20 {
return
}
}
}

func nsec2duration(nsec int64) (second int) {
Expand Down
26 changes: 12 additions & 14 deletions internal/hash/map.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
package hash

import (
"github.com/dolthub/swiss"
)

type MapI interface {
Set(key string, val []byte) bool
Get(key string) ([]byte, bool)
Expand All @@ -15,34 +11,36 @@ type MapI interface {
var _ MapI = (*Map)(nil)

type Map struct {
data *swiss.Map[string, []byte]
data map[string][]byte
}

func NewMap() *Map {
return &Map{swiss.NewMap[string, []byte](256)}
return &Map{make(map[string][]byte, 256)}
}

func (m *Map) Get(key string) ([]byte, bool) {
return m.data.Get(key)
val, ok := m.data[key]
return val, ok
}

func (m *Map) Set(key string, val []byte) bool {
_, ok := m.data.Get(key)
m.data.Put(key, val)
_, ok := m.data[key]
m.data[key] = val
return !ok
}

func (m *Map) Remove(key string) bool {
return m.data.Delete(key)
_, ok := m.data[key]
delete(m.data, key)
return ok
}

func (m *Map) Len() int {
return m.data.Count()
return len(m.data)
}

func (m *Map) Scan(fn func(key string, val []byte)) {
m.data.Iter(func(key string, val []byte) (stop bool) {
for key, val := range m.data {
fn(key, val)
return false
})
}
}
22 changes: 11 additions & 11 deletions internal/zset/zset.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"cmp"

"github.com/chen3feng/stl4go"
"github.com/dolthub/swiss"
)

type node struct {
Expand All @@ -20,41 +19,42 @@ func nodeCompare(a, b node) int {
}

type ZSet struct {
m *swiss.Map[string, float64]
m map[string]float64
skl *stl4go.SkipList[node, struct{}]
}

func NewZSet() *ZSet {
return &ZSet{
m: swiss.NewMap[string, float64](8),
m: make(map[string]float64),
skl: stl4go.NewSkipListFunc[node, struct{}](nodeCompare),
}
}

func (z *ZSet) Get(key string) (float64, bool) {
return z.m.Get(key)
val, ok := z.m[key]
return val, ok
}

func (z *ZSet) Set(key string, score float64) bool {
old, ok := z.m.Get(key)
old, ok := z.m[key]
if ok {
// same
if score == old {
return false
}
z.skl.Remove(node{key, old})
}
z.m.Put(key, score)
z.m[key] = score
z.skl.Insert(node{key, score}, struct{}{})
return !ok
}

func (z *ZSet) Remove(key string) bool {
score, ok := z.m.Get(key)
score, ok := z.m[key]
if !ok {
return false
}
z.m.Delete(key)
delete(z.m, key)
z.skl.Remove(node{key, score})
return true
}
Expand All @@ -65,13 +65,13 @@ func (z *ZSet) PopMin() (key string, score float64) {
score = n.score
return false
})
z.m.Delete(key)
delete(z.m, key)
z.skl.Remove(node{key, score})
return
}

func (z *ZSet) Rank(key string) (int, float64) {
score, ok := z.m.Get(key)
score, ok := z.m[key]
if !ok {
return -1, 0
}
Expand All @@ -95,5 +95,5 @@ func (z *ZSet) Range(start, stop int, fn func(key string, score float64)) {
}

func (z *ZSet) Len() int {
return z.m.Count()
return len(z.m)
}
2 changes: 1 addition & 1 deletion rotom.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ READ:
client.recvx += n

if readSize == 0 {
log.Warn().Msgf("client %d read query empty, now free", fd)
log.Info().Msgf("client %d read query empty, now free", fd)
freeClient(client)
return
}
Expand Down

0 comments on commit 57847e8

Please sign in to comment.