Skip to content

Commit

Permalink
fix(zset): fix random in ZRANDMEMBER command
Browse files Browse the repository at this point in the history
fixes dragonflydb#2850

Signed-off-by: Stepan Bagritsevich <[email protected]>
  • Loading branch information
Stepan Bagritsevich committed May 2, 2024
1 parent 082aba0 commit 8255c7e
Show file tree
Hide file tree
Showing 2 changed files with 224 additions and 47 deletions.
130 changes: 114 additions & 16 deletions src/server/zset_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,70 @@ OpResult<DbSlice::ItAndUpdater> FindZEntry(const ZParams& zparams, const OpArgs&
return DbSlice::ItAndUpdater{add_res.it, add_res.exp_it, std::move(add_res.post_updater)};
}

using RandomPick = std::size_t;

class PicksGenerator {
public:
virtual RandomPick Generate() = 0;
virtual ~PicksGenerator() = default;
};

class NonUniquePicksGenerator : public PicksGenerator {
public:
NonUniquePicksGenerator(std::size_t total_size) : total_size_(total_size) {
CHECK_GT(total_size, std::size_t(0));
}

RandomPick Generate() override {
return absl::Uniform(bitgen_, 0u, total_size_);
}

private:
const std::size_t total_size_;
absl::BitGen bitgen_{};
};

/*
* Generates unique index in O(1).
*
* picks_count specifies the number of random indexes to be generated.
* In other words, this is the number of times the Generate() function is called.
*
* The class uses Robert Floyd's sampling algorithm
* https://dl.acm.org/doi/pdf/10.1145/30401.315746
* */
class UniquePicksGenerator : public PicksGenerator {
public:
UniquePicksGenerator(std::size_t picks_count, std::size_t total_size)
: picked_indexes_(picks_count) {
CHECK_GE(total_size, picks_count);
current_random_limit_ = total_size - picks_count;
}

RandomPick Generate() override {
const std::size_t max_index = current_random_limit_++;
const RandomPick random_index = absl::Uniform(bitgen_, 0u, max_index + 1u);

if (!IndexWasPicked(random_index)) {
picked_indexes_.insert(random_index);
return random_index;
} else {
picked_indexes_.insert(max_index);
return max_index;
}
}

private:
bool IndexWasPicked(RandomPick pick) {
return picked_indexes_.find(pick) != picked_indexes_.end();
}

private:
std::size_t current_random_limit_;
std::unordered_set<RandomPick> picked_indexes_;
absl::BitGen bitgen_{};
};

bool ScoreToLongLat(const std::optional<double>& val, double* xy) {
if (!val.has_value())
return false;
Expand Down Expand Up @@ -1702,6 +1766,51 @@ OpResult<StringVec> OpScan(const OpArgs& op_args, std::string_view key, uint64_t
return res;
}

OpResult<ScoredArray> OpRandMember(int count, const ZSetFamily::RangeParams& params,
const OpArgs& op_args, string_view key) {
auto it = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET);
if (!it)
return it.status();

// Action::RANGE is a read-only operation, but requires const_cast
PrimeValue& pv = const_cast<PrimeValue&>(it.value()->second);

const std::size_t size = pv.Size();
const std::size_t picks_count =
count >= 0 ? std::min(static_cast<std::size_t>(count), size) : std::abs(count);

ScoredArray result{picks_count};
auto generator = [count, picks_count, size]() -> std::unique_ptr<PicksGenerator> {
if (count >= 0) {
return std::make_unique<UniquePicksGenerator>(picks_count, size);
} else {
return std::make_unique<NonUniquePicksGenerator>(size);
}
}();

if (picks_count * static_cast<std::uint64_t>(std::log2(size)) < size) {
for (std::size_t i = 0; i < picks_count; i++) {
const std::size_t picked_index = generator->Generate();

IntervalVisitor iv{Action::RANGE, params, &pv};
iv(ZSetFamily::IndexInterval{picked_index, picked_index});

result[i] = iv.PopResult().front();
}
} else {
IntervalVisitor iv{Action::RANGE, params, &pv};
iv(ZSetFamily::IndexInterval{0, -1});

ScoredArray all_elements = iv.PopResult();

for (std::size_t i = 0; i < picks_count; i++) {
result[i] = all_elements[generator->Generate()];
}
}

return result;
}

void ZAddGeneric(string_view key, const ZParams& zparams, ScoredMemberSpan memb_sp,
ConnectionContext* cntx) {
auto cb = [&](Transaction* t, EngineShard* shard) {
Expand Down Expand Up @@ -2323,43 +2432,32 @@ void ZSetFamily::ZRandMember(CmdArgList args, ConnectionContext* cntx) {
if (args.size() > 3)
return cntx->SendError(WrongNumArgsError("ZRANDMEMBER"));

ZRangeSpec range_spec;
range_spec.interval = IndexInterval(0, -1);

CmdArgParser parser{args};
string_view key = parser.Next();

bool is_count = parser.HasNext();
int count = is_count ? parser.Next<int>() : 1;

range_spec.params.with_scores = static_cast<bool>(parser.Check("WITHSCORES").IgnoreCase());
ZSetFamily::RangeParams params;
params.with_scores = static_cast<bool>(parser.Check("WITHSCORES").IgnoreCase());

if (parser.HasNext())
return cntx->SendError(absl::StrCat("Unsupported option:", string_view(parser.Next())));

if (auto err = parser.Error(); err)
return cntx->SendError(err->MakeReply());

bool sign = count < 0;
range_spec.params.limit = std::abs(count);

const auto cb = [&](Transaction* t, EngineShard* shard) {
return OpRange(range_spec, t->GetOpArgs(shard), key);
return OpRandMember(count, params, t->GetOpArgs(shard), key);
};

OpResult<ScoredArray> result = cntx->transaction->ScheduleSingleHopT(cb);
auto* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
if (result) {
if (sign && !result->empty()) {
for (auto i = result->size(); i < range_spec.params.limit; ++i) {
// we can return duplicate elements, so first is OK
result->push_back(result->front());
}
}
rb->SendScoredArray(result.value(), range_spec.params.with_scores);
rb->SendScoredArray(result.value(), params.with_scores);
} else if (result.status() == OpStatus::KEY_NOTFOUND) {
if (is_count) {
rb->SendScoredArray(ScoredArray(), range_spec.params.with_scores);
rb->SendScoredArray(ScoredArray(), params.with_scores);
} else {
rb->SendNull();
}
Expand Down
141 changes: 110 additions & 31 deletions src/server/zset_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,53 +77,132 @@ TEST_F(ZSetFamilyTest, ZRem) {
}

TEST_F(ZSetFamilyTest, ZRandMember) {
auto resp = Run({
"zadd",
"x",
"1",
"a",
"2",
"b",
"3",
"c",
});
auto resp = Run({"ZAdd", "x", "1", "a", "2", "b", "3", "c"});
EXPECT_THAT(resp, IntArg(3));

// Test if count > 0
resp = Run({"ZRandMember", "x"});
ASSERT_THAT(resp, ArgType(RespExpr::STRING));
EXPECT_THAT(resp, "a");
EXPECT_THAT(resp, AnyOf("a", "b", "c"));

resp = Run({"ZRandMember", "x", "1"});
ASSERT_THAT(resp, ArgType(RespExpr::STRING));
EXPECT_THAT(resp, AnyOf("a", "b", "c"));

resp = Run({"ZRandMember", "x", "2"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "b"));
ASSERT_THAT(resp, ArrLen(2));
EXPECT_THAT(resp.GetVec(), IsSubsetOf({"a", "b", "c"}));

resp = Run({"ZRandMember", "x", "0"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
EXPECT_EQ(resp.GetVec().size(), 0);
resp = Run({"ZRandMember", "x", "3"});
ASSERT_THAT(resp, ArrLen(3));
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "b", "c"));

resp = Run({"ZRandMember", "k"});
ASSERT_THAT(resp, ArgType(RespExpr::NIL));
// Test if count < 0
std::unordered_set<std::string> expected_entries({"a", "b", "c"});

resp = Run({"ZRandMember", "k", "2"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
EXPECT_EQ(resp.GetVec().size(), 0);
auto expect_elements = [](const auto& expected_elements, const auto& actual_elements) {
for (const auto& x : actual_elements) {
if (expected_elements.find(x) == expected_elements.end()) {
return false;
}
}
return true;
};

auto parse_response = [](const auto& resp) {
auto vec = resp.GetVec();

std::vector<std::string> entries;
std::transform(vec.begin(), vec.end(), std::back_inserter(entries),
[](auto& x) { return x.GetString(); });
return entries;
};

resp = Run({"ZRandMember", "x", "-1"});
ASSERT_THAT(resp, ArgType(RespExpr::STRING));
EXPECT_THAT(resp, AnyOf("a", "b", "c"));

resp = Run({"ZRandMember", "x", "-5"});
ASSERT_THAT(resp, ArrLen(5));
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "b", "c", "a", "a"));
resp = Run({"ZRandMember", "x", "-2"});
ASSERT_THAT(resp, ArrLen(2));
EXPECT_TRUE(expect_elements(expected_entries, parse_response(resp)));

resp = Run({"ZRandMember", "x", "5"});
resp = Run({"ZRandMember", "x", "-3"});
ASSERT_THAT(resp, ArrLen(3));
EXPECT_TRUE(expect_elements(expected_entries, parse_response(resp)));

// Test if count < 0, but the absolute value is larger than the size of the sorted set
resp = Run({"ZRandMember", "x", "-15"});
ASSERT_THAT(resp, ArrLen(15));
EXPECT_TRUE(expect_elements(expected_entries, parse_response(resp)));

// Test if count is 0
ASSERT_THAT(Run({"ZRandMember", "x", "0"}), ArrLen(0));

// Test if count is larger than the size of the sorted set
resp = Run({"ZRandMember", "x", "15"});
ASSERT_THAT(resp, ArrLen(3));
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "b", "c"));

resp = Run({"ZRandMember", "x", "-5", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(10));
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "1", "b", "2", "c", "3", "a", "1", "a", "1"));
// Test if sorted set is empty
EXPECT_THAT(Run({"ZAdd", "empty::zset", "1", "one"}), IntArg(1));
EXPECT_THAT(Run({"ZRem", "empty::zset", "one"}), IntArg(1));
ASSERT_THAT(Run({"ZRandMember", "empty::zset", "0"}), ArrLen(0));
ASSERT_THAT(Run({"ZRandMember", "empty::zset", "3"}), ArrLen(0));
ASSERT_THAT(Run({"ZRandMember", "empty::zset", "-4"}), ArrLen(0));

// Test if key does not exist
ASSERT_THAT(Run({"ZRandMember", "y"}), ArgType(RespExpr::NIL));
ASSERT_THAT(Run({"ZRandMember", "y", "0"}), ArrLen(0));

// Test WITHSCORES
using ZSetEntry = std::pair<std::string, std::string>;
std::set<ZSetEntry> expected_entries_with_scores{{"a", "1"}, {"b", "2"}, {"c", "3"}};

auto parse_response_with_scores = [](const auto& resp) {
auto vec = resp.GetVec();

std::vector<ZSetEntry> entries;
for (std::size_t i = 1; i < vec.size(); i += 2) {
entries.emplace_back(vec[i - 1].GetString(), vec[i].GetString());
}
return entries;
};

resp = Run({"ZRandMember", "x", "1", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(2));
EXPECT_THAT(parse_response_with_scores(resp), IsSubsetOf(expected_entries_with_scores));

resp = Run({"ZRandMember", "x", "2", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(4));
EXPECT_THAT(parse_response_with_scores(resp), IsSubsetOf(expected_entries_with_scores));

resp = Run({"ZRandMember", "x", "3", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(6));
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "1", "b", "2", "c", "3"));
EXPECT_THAT(parse_response_with_scores(resp),
UnorderedElementsAre(std::make_pair("a", "1"), std::make_pair("b", "2"),
std::make_pair("c", "3")));

resp = Run({"ZRandMember", "x", "3", "WITHSCORES", "test"});
EXPECT_THAT(resp, ErrArg("wrong number of arguments"));
resp = Run({"ZRandMember", "x", "15", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(6));
EXPECT_THAT(parse_response_with_scores(resp),
UnorderedElementsAre(std::make_pair("a", "1"), std::make_pair("b", "2"),
std::make_pair("c", "3")));

resp = Run({"ZRandMember", "x", "-1", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(2));
EXPECT_TRUE(expect_elements(expected_entries_with_scores, parse_response_with_scores(resp)));

resp = Run({"ZRandMember", "x", "-2", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(4));
EXPECT_TRUE(expect_elements(expected_entries_with_scores, parse_response_with_scores(resp)));

resp = Run({"ZRandMember", "x", "-3", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(6));
EXPECT_TRUE(expect_elements(expected_entries_with_scores, parse_response_with_scores(resp)));

resp = Run({"ZRandMember", "x", "-15", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(30));
EXPECT_TRUE(expect_elements(expected_entries_with_scores, parse_response_with_scores(resp)));
}

TEST_F(ZSetFamilyTest, ZMScore) {
Expand Down

0 comments on commit 8255c7e

Please sign in to comment.