Skip to content

Commit

Permalink
fix(zset): fix random in ZRANDMEMBER command
Browse files Browse the repository at this point in the history
fixes dragonflydb#2850

Signed-off-by: Stepan Bagritsevich <[email protected]>
  • Loading branch information
BagritsevichStepan authored and Stepan Bagritsevich committed May 1, 2024
1 parent 07d076a commit c7e7f8b
Show file tree
Hide file tree
Showing 2 changed files with 215 additions and 47 deletions.
121 changes: 105 additions & 16 deletions src/server/zset_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,56 @@ OpResult<DbSlice::ItAndUpdater> FindZEntry(const ZParams& zparams, const OpArgs&
return DbSlice::ItAndUpdater{add_res.it, add_res.exp_it, std::move(add_res.post_updater)};
}

using RandomPick = std::size_t;
using PicksArray = std::vector<RandomPick>;

/*
* Generates an array of non-unique indexes in O(picks_count).
* picks_count specifies the number of random indexes.
* */
PicksArray GenerateRandomPicks(std::size_t picks_count, std::size_t total_size) {
CHECK_GT(total_size, std::size_t(0));

PicksArray picks;
picks.resize(picks_count);

absl::BitGen bitgen;

for (std::size_t i = 0; i < picks_count; i++) {
picks[i] = absl::Uniform(bitgen, 0u, total_size);
}
return picks;
}

/*
* Generates an array of unique indexes in O(picks_count).
* picks_count specifies the number of random indexes.
*
* The function uses Robert Floyd's sampling algorithm
* https://dl.acm.org/doi/pdf/10.1145/30401.315746
* */
PicksArray GenerateUniqueRandomPicks(std::size_t picks_count, std::size_t total_size) {
CHECK_GE(total_size, picks_count);

PicksArray picks;
std::unordered_set<std::size_t> picked_indexes{picks_count};

absl::BitGen bitgen;

for (std::size_t i = total_size - picks_count; i < total_size; ++i) {
std::size_t random_index = absl::Uniform(bitgen, 0u, i + 1u);
if (!picked_indexes.contains(random_index)) {
picks.push_back(random_index);
picked_indexes.insert(random_index);
} else {
picks.push_back(i);
picked_indexes.insert(i);
}
}
DCHECK_EQ(picks.size(), picks_count);
return picks;
}

bool ScoreToLongLat(const std::optional<double>& val, double* xy) {
if (!val.has_value())
return false;
Expand Down Expand Up @@ -1702,6 +1752,56 @@ OpResult<StringVec> OpScan(const OpArgs& op_args, std::string_view key, uint64_t
return res;
}

OpResult<ScoredArray> OpRandMember(int count, const ZSetFamily::RangeParams& params,
const OpArgs& op_args, string_view key) {
auto it = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET);
if (!it)
return it.status();

// Action::RANGE is a read-only operation, but requires const_cast
PrimeValue& pv = const_cast<PrimeValue&>(it.value()->second);

const std::size_t size = pv.Size();
if (!size) {
return ScoredArray();
}

const std::size_t picks_count =
count >= 0 ? std::min(static_cast<std::size_t>(count), size) : std::abs(count);

ScoredArray result{picks_count};
PicksArray picks = count >= 0 ? GenerateUniqueRandomPicks(picks_count, size)
: GenerateRandomPicks(picks_count, size);

/* CASE 1:
* The number of requested elements (count) is significantly less than the total size.
* In this case, we generate random indexes, and search for the elements at this index (each
* search for O(log(size)). In total O(picks_count * log(size)). */
if (picks_count * static_cast<std::uint64_t>(std::log2(size)) <
size) { // convert to std::uint64_t to prevent overflow
for (std::size_t i = 0; i < picks_count; ++i) {
IntervalVisitor iv{Action::RANGE, params, &pv};
iv(ZSetFamily::IndexInterval{picks[i], picks[i]});
result[i] = iv.PopResult().front();
}
} else {
/* CASE 2:
* The number of requested elements (count) does not differ much from the total size.
* In this case, we just traverse all elements and randomly add them to the result.
* In total O(size). */
IntervalVisitor iv{Action::RANGE, params, &pv};
iv(ZSetFamily::IndexInterval{0, -1});

ScoredArray all_elements = iv.PopResult();

for (std::size_t i = 0; i < picks_count; ++i) {
result[i] = all_elements[picks[i]];
}
}

return result;
}

void ZAddGeneric(string_view key, const ZParams& zparams, ScoredMemberSpan memb_sp,
ConnectionContext* cntx) {
auto cb = [&](Transaction* t, EngineShard* shard) {
Expand Down Expand Up @@ -2323,43 +2423,32 @@ void ZSetFamily::ZRandMember(CmdArgList args, ConnectionContext* cntx) {
if (args.size() > 3)
return cntx->SendError(WrongNumArgsError("ZRANDMEMBER"));

ZRangeSpec range_spec;
range_spec.interval = IndexInterval(0, -1);

CmdArgParser parser{args};
string_view key = parser.Next();

bool is_count = parser.HasNext();
int count = is_count ? parser.Next<int>() : 1;

range_spec.params.with_scores = static_cast<bool>(parser.Check("WITHSCORES").IgnoreCase());
ZSetFamily::RangeParams params;
params.with_scores = static_cast<bool>(parser.Check("WITHSCORES").IgnoreCase());

if (parser.HasNext())
return cntx->SendError(absl::StrCat("Unsupported option:", string_view(parser.Next())));

if (auto err = parser.Error(); err)
return cntx->SendError(err->MakeReply());

bool sign = count < 0;
range_spec.params.limit = std::abs(count);

const auto cb = [&](Transaction* t, EngineShard* shard) {
return OpRange(range_spec, t->GetOpArgs(shard), key);
return OpRandMember(count, params, t->GetOpArgs(shard), key);
};

OpResult<ScoredArray> result = cntx->transaction->ScheduleSingleHopT(cb);
auto* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
if (result) {
if (sign && !result->empty()) {
for (auto i = result->size(); i < range_spec.params.limit; ++i) {
// we can return duplicate elements, so first is OK
result->push_back(result->front());
}
}
rb->SendScoredArray(result.value(), range_spec.params.with_scores);
rb->SendScoredArray(result.value(), params.with_scores);
} else if (result.status() == OpStatus::KEY_NOTFOUND) {
if (is_count) {
rb->SendScoredArray(ScoredArray(), range_spec.params.with_scores);
rb->SendScoredArray(ScoredArray(), params.with_scores);
} else {
rb->SendNull();
}
Expand Down
141 changes: 110 additions & 31 deletions src/server/zset_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,53 +77,132 @@ TEST_F(ZSetFamilyTest, ZRem) {
}

TEST_F(ZSetFamilyTest, ZRandMember) {
auto resp = Run({
"zadd",
"x",
"1",
"a",
"2",
"b",
"3",
"c",
});
auto resp = Run({"ZAdd", "x", "1", "a", "2", "b", "3", "c"});
EXPECT_THAT(resp, IntArg(3));

// Test if count > 0
resp = Run({"ZRandMember", "x"});
ASSERT_THAT(resp, ArgType(RespExpr::STRING));
EXPECT_THAT(resp, "a");
EXPECT_THAT(resp, AnyOf("a", "b", "c"));

resp = Run({"ZRandMember", "x", "1"});
ASSERT_THAT(resp, ArgType(RespExpr::STRING));
EXPECT_THAT(resp, AnyOf("a", "b", "c"));

resp = Run({"ZRandMember", "x", "2"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "b"));
ASSERT_THAT(resp, ArrLen(2));
EXPECT_THAT(resp.GetVec(), IsSubsetOf({"a", "b", "c"}));

resp = Run({"ZRandMember", "x", "0"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
EXPECT_EQ(resp.GetVec().size(), 0);
resp = Run({"ZRandMember", "x", "3"});
ASSERT_THAT(resp, ArrLen(3));
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "b", "c"));

resp = Run({"ZRandMember", "k"});
ASSERT_THAT(resp, ArgType(RespExpr::NIL));
// Test if count < 0
std::unordered_set<std::string> expected_entries({"a", "b", "c"});

resp = Run({"ZRandMember", "k", "2"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
EXPECT_EQ(resp.GetVec().size(), 0);
auto expect_elements = [](const auto& expected_elements, const auto& actual_elements) {
for (const auto& x : actual_elements) {
if (!expected_elements.contains(x)) {
return false;
}
}
return true;
};

auto parse_response = [](const auto& resp) {
auto vec = resp.GetVec();

std::vector<std::string> entries;
std::transform(vec.begin(), vec.end(), std::back_inserter(entries),
[](auto& x) { return x.GetString(); });
return entries;
};

resp = Run({"ZRandMember", "x", "-1"});
ASSERT_THAT(resp, ArgType(RespExpr::STRING));
EXPECT_THAT(resp, AnyOf("a", "b", "c"));

resp = Run({"ZRandMember", "x", "-5"});
ASSERT_THAT(resp, ArrLen(5));
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "b", "c", "a", "a"));
resp = Run({"ZRandMember", "x", "-2"});
ASSERT_THAT(resp, ArrLen(2));
EXPECT_TRUE(expect_elements(expected_entries, parse_response(resp)));

resp = Run({"ZRandMember", "x", "5"});
resp = Run({"ZRandMember", "x", "-3"});
ASSERT_THAT(resp, ArrLen(3));
EXPECT_TRUE(expect_elements(expected_entries, parse_response(resp)));

// Test if count < 0, but the absolute value is larger than the size of the sorted set
resp = Run({"ZRandMember", "x", "-15"});
ASSERT_THAT(resp, ArrLen(15));
EXPECT_TRUE(expect_elements(expected_entries, parse_response(resp)));

// Test if count is 0
ASSERT_THAT(Run({"ZRandMember", "x", "0"}), ArrLen(0));

// Test if count is larger than the size of the sorted set
resp = Run({"ZRandMember", "x", "15"});
ASSERT_THAT(resp, ArrLen(3));
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "b", "c"));

resp = Run({"ZRandMember", "x", "-5", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(10));
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "1", "b", "2", "c", "3", "a", "1", "a", "1"));
// Test if sorted set is empty
EXPECT_THAT(Run({"ZAdd", "empty::zset", "1", "one"}), IntArg(1));
EXPECT_THAT(Run({"ZRem", "empty::zset", "one"}), IntArg(1));
ASSERT_THAT(Run({"ZRandMember", "empty::zset", "0"}), ArrLen(0));
ASSERT_THAT(Run({"ZRandMember", "empty::zset", "3"}), ArrLen(0));
ASSERT_THAT(Run({"ZRandMember", "empty::zset", "-4"}), ArrLen(0));

// Test if key does not exist
ASSERT_THAT(Run({"ZRandMember", "y"}), ArgType(RespExpr::NIL));
ASSERT_THAT(Run({"ZRandMember", "y", "0"}), ArrLen(0));

// Test WITHSCORES
using ZSetEntry = std::pair<std::string, std::string>;
std::set<ZSetEntry> expected_entries_with_scores{{"a", "1"}, {"b", "2"}, {"c", "3"}};

auto parse_response_with_scores = [](const auto& resp) {
auto vec = resp.GetVec();

std::vector<ZSetEntry> entries;
for (std::size_t i = 1; i < vec.size(); i += 2) {
entries.emplace_back(vec[i - 1].GetString(), vec[i].GetString());
}
return entries;
};

resp = Run({"ZRandMember", "x", "1", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(2));
EXPECT_THAT(parse_response_with_scores(resp), IsSubsetOf(expected_entries_with_scores));

resp = Run({"ZRandMember", "x", "2", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(4));
EXPECT_THAT(parse_response_with_scores(resp), IsSubsetOf(expected_entries_with_scores));

resp = Run({"ZRandMember", "x", "3", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(6));
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "1", "b", "2", "c", "3"));
EXPECT_THAT(parse_response_with_scores(resp),
UnorderedElementsAre(std::make_pair("a", "1"), std::make_pair("b", "2"),
std::make_pair("c", "3")));

resp = Run({"ZRandMember", "x", "3", "WITHSCORES", "test"});
EXPECT_THAT(resp, ErrArg("wrong number of arguments"));
resp = Run({"ZRandMember", "x", "15", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(6));
EXPECT_THAT(parse_response_with_scores(resp),
UnorderedElementsAre(std::make_pair("a", "1"), std::make_pair("b", "2"),
std::make_pair("c", "3")));

resp = Run({"ZRandMember", "x", "-1", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(2));
EXPECT_TRUE(expect_elements(expected_entries_with_scores, parse_response_with_scores(resp)));

resp = Run({"ZRandMember", "x", "-2", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(4));
EXPECT_TRUE(expect_elements(expected_entries_with_scores, parse_response_with_scores(resp)));

resp = Run({"ZRandMember", "x", "-3", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(6));
EXPECT_TRUE(expect_elements(expected_entries_with_scores, parse_response_with_scores(resp)));

resp = Run({"ZRandMember", "x", "-15", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(30));
EXPECT_TRUE(expect_elements(expected_entries_with_scores, parse_response_with_scores(resp)));
}

TEST_F(ZSetFamilyTest, ZMScore) {
Expand Down

0 comments on commit c7e7f8b

Please sign in to comment.