Skip to content

Commit

Permalink
feat: add zrank command
Browse files Browse the repository at this point in the history
  • Loading branch information
xgzlucario committed Aug 1, 2024
1 parent 3d1f8f9 commit 52c5222
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 61 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ This is rotom, a tiny Redis Server written in Go. It replicates the core event l
2. Compatible with the Redis RESP protocol, allowing any Redis client to connect to rotom.
3. Implements data structures such as dict, list, map, zipmap, set, zipset, and zset.
4. Supports AOF.
5. Supports 18 commonly used commands.
5. Supports 20+ commonly used commands.

## AELoop

Expand Down
2 changes: 1 addition & 1 deletion README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
2. 兼容 Redis RESP 协议,你可以使用任何 redis 客户端连接 rotom
3. 实现了 dict, list, map, zipmap, set, zipset, zset 数据结构
4. AOF 支持
5. 支持 18 种常用命令
5. 支持 20 多种常用命令

## AELoop 事件循环

Expand Down
60 changes: 37 additions & 23 deletions command.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"bytes"
"fmt"
"strconv"
"strings"
Expand Down Expand Up @@ -45,6 +46,7 @@ var cmdTable []*Command = []*Command{
{"srem", sremCommand, 2, true},
{"spop", spopCommand, 1, true},
{"zadd", zaddCommand, 3, true},
{"zrank", zrankCommand, 2, false},
{"zpopmin", zpopminCommand, 1, true},
{"zrange", zrangeCommand, 3, false},
{"ping", pingCommand, 0, false},
Expand Down Expand Up @@ -299,29 +301,24 @@ func lrangeCommand(writer *RESPWriter, args []RESP) {
writer.WriteError(err)
return
}
end, err := args[2].ToInt()
stop, err := args[2].ToInt()
if err != nil {
writer.WriteError(err)
return
}

ls, err := fetchList(key)
if err != nil {
writer.WriteError(err)
return
}

// calculate list size
size := end - start
if end == -1 {
size = ls.Size()
}
if size < 0 {
size = 0
if stop == -1 {
stop = ls.Size()
}
start = min(start, stop)

writer.WriteArrayHead(size)
ls.Range(start, end, func(data []byte) {
writer.WriteArrayHead(stop - start)
ls.Range(start, stop, func(data []byte) {
writer.WriteBulk(data)
})
}
Expand All @@ -336,13 +333,13 @@ func saddCommand(writer *RESPWriter, args []RESP) {
return
}

var newItems int
var count int
for i := 0; i < len(args); i++ {
if set.Add(args[i].ToString()) {
newItems++
count++
}
}
writer.WriteInteger(newItems)
writer.WriteInteger(count)
}

func sremCommand(writer *RESPWriter, args []RESP) {
Expand Down Expand Up @@ -390,20 +387,37 @@ func zaddCommand(writer *RESPWriter, args []RESP) {
return
}

var newFields int
var count int
for i := 0; i < len(args); i += 2 {
score, err := args[i].ToInt()
score, err := args[i].ToFloat()
if err != nil {
writer.WriteError(err)
return
}

key := args[i+1].ToString()
if zset.Set(key, float64(score)) {
newFields++
if zset.Set(key, score) {
count++
}
}
writer.WriteInteger(newFields)
writer.WriteInteger(count)
}

func zrankCommand(writer *RESPWriter, args []RESP) {
key := args[0].ToStringUnsafe()
member := args[1].ToStringUnsafe()

zset, err := fetchZSet(key)
if err != nil {
writer.WriteError(err)
return
}

rank, _ := zset.Rank(member)
if rank < 0 {
writer.WriteNull()
} else {
writer.WriteInteger(rank)
}
}

func zrangeCommand(writer *RESPWriter, args []RESP) {
Expand All @@ -427,12 +441,12 @@ func zrangeCommand(writer *RESPWriter, args []RESP) {
stop = zset.Len()
}

withScores := len(args) == 4 && strings.EqualFold(args[3].ToStringUnsafe(), "WITHSCORES")
withScores := len(args) == 4 && bytes.EqualFold(args[3], []byte("WITHSCORES"))
if withScores {
writer.WriteArrayHead((stop - start) * 2)
zset.Range(start, stop, func(key string, score float64) {
writer.WriteBulkString(key)
writer.WriteBulkString(strconv.Itoa(int(score)))
writer.WriteFloat(score)
})

} else {
Expand Down Expand Up @@ -466,7 +480,7 @@ func zpopminCommand(writer *RESPWriter, args []RESP) {
for range size {
key, score := zset.PopMin()
writer.WriteBulkString(key)
writer.WriteBulkString(strconv.Itoa(int(score)))
writer.WriteFloat(score)
}
}

Expand Down
37 changes: 33 additions & 4 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,27 @@ func TestCommand(t *testing.T) {
res, _ = rdb.LRange(ctx, "list", 1, 3).Result()
assert.Equal(res, []string{"b", "a"})

res, err := rdb.LRange(ctx, "list", 3, 2).Result()
assert.Equal(len(res), 0)
assert.Nil(err)

// lpop
val, _ := rdb.LPop(ctx, "list").Result()
assert.Equal(val, "c")

// rpop
val, _ = rdb.RPop(ctx, "list").Result()
assert.Equal(val, "f")

// pop nil
{
_, err := rdb.LPop(ctx, "list-empty").Result()
assert.Equal(err, redis.Nil)

_, err = rdb.RPop(ctx, "list-empty").Result()
assert.Equal(err, redis.Nil)
}

})

t.Run("set", func(t *testing.T) {
Expand Down Expand Up @@ -183,10 +197,25 @@ func TestCommand(t *testing.T) {

n, _ = rdb.ZAdd(ctx, "rank",
redis.Z{Member: "player1", Score: 100},
redis.Z{Member: "player2", Score: 300},
redis.Z{Member: "player2", Score: 300.5},
redis.Z{Member: "player3", Score: 100}).Result()
assert.Equal(n, int64(2))

// zrank
{
res, _ := rdb.ZRank(ctx, "rank", "player1").Result()
assert.Equal(res, int64(0))

res, _ = rdb.ZRank(ctx, "rank", "player2").Result()
assert.Equal(res, int64(2))

res, _ = rdb.ZRank(ctx, "rank", "player3").Result()
assert.Equal(res, int64(1))

_, err := rdb.ZRank(ctx, "rank", "player999").Result()
assert.Equal(err, redis.Nil)
}

// zrange
{
members, _ := rdb.ZRange(ctx, "rank", 0, -1).Result()
Expand All @@ -205,13 +234,13 @@ func TestCommand(t *testing.T) {
assert.Equal(res, []redis.Z{
{Member: "player1", Score: 100},
{Member: "player3", Score: 100},
{Member: "player2", Score: 300},
{Member: "player2", Score: 300.5},
})

res, _ = rdb.ZRangeWithScores(ctx, "rank", 1, 3).Result()
assert.Equal(res, []redis.Z{
{Member: "player3", Score: 100},
{Member: "player2", Score: 300},
{Member: "player2", Score: 300.5},
})

res, _ = rdb.ZRangeWithScores(ctx, "rank", 70, 60).Result()
Expand All @@ -228,7 +257,7 @@ func TestCommand(t *testing.T) {

res, _ = rdb.ZPopMin(ctx, "rank").Result()
assert.Equal(res, []redis.Z{
{Member: "player2", Score: 300},
{Member: "player2", Score: 300.5},
})
}
})
Expand Down
17 changes: 15 additions & 2 deletions internal/zset/zset.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (z *ZSet) Delete(key string) (float64, bool) {
}

func (z *ZSet) PopMin() (key string, score float64) {
z.skl.ForEachIf(func(n node, s struct{}) bool {
z.skl.ForEachIf(func(n node, _ struct{}) bool {
key = n.key
score = n.score
return false
Expand All @@ -70,9 +70,22 @@ func (z *ZSet) PopMin() (key string, score float64) {
return
}

func (z *ZSet) Rank(key string) (int, float64) {
score, ok := z.m.Get(key)
if !ok {
return -1, 0
}
index := -1
z.skl.ForEachIf(func(n node, _ struct{}) bool {
index++
return n.key != key
})
return index, score
}

func (z *ZSet) Range(start, stop int, fn func(key string, score float64)) {
var index int
z.skl.ForEachIf(func(n node, s struct{}) bool {
z.skl.ForEachIf(func(n node, _ struct{}) bool {
if index >= start && index < stop {
fn(n.key, n.score)
}
Expand Down
53 changes: 29 additions & 24 deletions resp.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package main

import (
"bytes"
"io"
"slices"
"strconv"
Expand Down Expand Up @@ -106,19 +105,19 @@ func (r *RESPReader) ReadNextCommand(argsBuf []RESP) (args []RESP, err error) {

// RESPWriter is a writer that helps construct RESP (Redis Serialization Protocol) messages.
type RESPWriter struct {
b *bytes.Buffer
b []byte
}

// NewWriter initializes a new RESPWriter with a given capacity.
func NewWriter(cap int) *RESPWriter {
return &RESPWriter{bytes.NewBuffer(make([]byte, 0, cap))}
return &RESPWriter{make([]byte, 0, cap)}
}

// WriteArrayHead writes the RESP array header with the given length.
func (w *RESPWriter) WriteArrayHead(arrayLen int) {
w.b.WriteByte(ARRAY)
w.b.WriteString(strconv.Itoa(arrayLen))
w.b.Write(CRLF)
w.b = append(w.b, ARRAY)
w.b = strconv.AppendUint(w.b, uint64(arrayLen), 10)
w.b = append(w.b, CRLF...)
}

// WriteBulk writes a RESP bulk string from a byte slice.
Expand All @@ -127,43 +126,47 @@ func (w *RESPWriter) WriteBulk(bluk []byte) {
}

// WriteBulkString writes a RESP bulk string from a string.
func (w *RESPWriter) WriteBulkString(bluk string) {
w.b.WriteByte(BULK)
w.b.WriteString(strconv.Itoa(len(bluk)))
w.b.Write(CRLF)
w.b.WriteString(bluk)
w.b.Write(CRLF)
func (w *RESPWriter) WriteBulkString(bulk string) {
w.b = append(w.b, BULK)
w.b = strconv.AppendUint(w.b, uint64(len(bulk)), 10)
w.b = append(w.b, CRLF...)
w.b = append(w.b, bulk...)
w.b = append(w.b, CRLF...)
}

// WriteError writes a RESP error message.
func (w *RESPWriter) WriteError(err error) {
w.b.WriteByte(ERROR)
w.b.WriteString(err.Error())
w.b.Write(CRLF)
w.b = append(w.b, ERROR)
w.b = append(w.b, err.Error()...)
w.b = append(w.b, CRLF...)
}

// WriteString writes a RESP simple string.
func (w *RESPWriter) WriteString(str string) {
w.b.WriteByte(STRING)
w.b.WriteString(str)
w.b.Write(CRLF)
w.b = append(w.b, STRING)
w.b = append(w.b, str...)
w.b = append(w.b, CRLF...)
}

// WriteInteger writes a RESP integer.
func (w *RESPWriter) WriteInteger(num int) {
w.b.WriteByte(INTEGER)
w.b.WriteString(strconv.Itoa(num))
w.b.Write(CRLF)
w.b = append(w.b, INTEGER)
w.b = strconv.AppendUint(w.b, uint64(num), 10)
w.b = append(w.b, CRLF...)
}

// WriteFloat writes a RESP bulk string from a float64.
func (w *RESPWriter) WriteFloat(num float64) {
w.WriteBulkString(strconv.FormatFloat(num, 'f', -1, 64))
}

// WriteNull writes a RESP null bulk string.
func (w *RESPWriter) WriteNull() {
w.b.WriteString("$-1")
w.b.Write(CRLF)
w.b = append(w.b, "$-1\r\n"...)
}

// Reset resets the internal buffer.
func (w *RESPWriter) Reset() { w.b.Reset() }
func (w *RESPWriter) Reset() { w.b = w.b[:0] }

// RESP represents the RESP (Redis Serialization Protocol) message in byte slice format.
type RESP []byte
Expand All @@ -174,6 +177,8 @@ func (r RESP) ToStringUnsafe() string { return b2s(r) }

func (r RESP) ToInt() (int, error) { return strconv.Atoi(b2s(r)) }

func (r RESP) ToFloat() (float64, error) { return strconv.ParseFloat(b2s(r), 64) }

func (r RESP) Clone() []byte { return slices.Clone(r) }

func b2s(b []byte) string {
Expand Down
Loading

0 comments on commit 52c5222

Please sign in to comment.