Skip to content

Commit

Permalink
fix(zset): fix random in ZRANDMEMBER command
Browse files Browse the repository at this point in the history
fixes dragonflydb#2850

Signed-off-by: Stepan Bagritsevich <[email protected]>
  • Loading branch information
Stepan Bagritsevich authored and BagritsevichStepan committed May 5, 2024
1 parent dee4a17 commit e604df2
Show file tree
Hide file tree
Showing 2 changed files with 262 additions and 47 deletions.
127 changes: 111 additions & 16 deletions src/server/zset_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,70 @@ OpResult<DbSlice::ItAndUpdater> FindZEntry(const ZParams& zparams, const OpArgs&
return DbSlice::ItAndUpdater{add_res.it, add_res.exp_it, std::move(add_res.post_updater)};
}

// Integral type of the values handed out by the pick generators.
using RandomPick = std::uint32_t;

// Abstract source of picks. Concrete generators decide how each pick is drawn
// and whether successive picks may repeat.
class PicksGenerator {
 public:
  // Virtual so that a generator can be destroyed through a base pointer.
  virtual ~PicksGenerator() = default;

  // Produces the next pick.
  virtual RandomPick Generate() = 0;
};

// Draws every pick independently from [0, max_range), so the same value may be
// returned more than once.
class NonUniquePicksGenerator : public PicksGenerator {
 public:
  // max_range is the exclusive upper bound of the picks and must be positive.
  // The constructor is explicit so that a plain integer is never silently
  // converted into a generator.
  explicit NonUniquePicksGenerator(RandomPick max_range) : max_range_(max_range) {
    CHECK_GT(max_range, RandomPick(0));
  }

  // Returns a uniformly distributed value in [0, max_range).
  RandomPick Generate() override {
    return absl::Uniform(bitgen_, 0u, max_range_);
  }

 private:
  const RandomPick max_range_;
  absl::BitGen bitgen_{};
};

/*
* Generates unique index in O(1).
*
* picks_count specifies the number of random indexes to be generated.
* In other words, this is the number of times the Generate() function is called.
*
* The class uses Robert Floyd's sampling algorithm
* https://dl.acm.org/doi/pdf/10.1145/30401.315746
* */
class UniquePicksGenerator : public PicksGenerator {
public:
UniquePicksGenerator(std::uint32_t picks_count, RandomPick max_range)
: remaining_picks_count_(picks_count), picked_indexes_(picks_count) {
CHECK_GE(max_range, picks_count);
current_random_limit_ = max_range - picks_count;
}

RandomPick Generate() override {
DCHECK_GT(remaining_picks_count_, 0u);

remaining_picks_count_--;

const RandomPick max_index = current_random_limit_++;
const RandomPick random_index = absl::Uniform(bitgen_, 0u, max_index + 1u);

const bool random_index_is_picked = picked_indexes_.emplace(random_index).second;
if (random_index_is_picked) {
return random_index;
}

picked_indexes_.insert(max_index);
return max_index;
}

private:
RandomPick current_random_limit_;
std::uint32_t remaining_picks_count_;
std::unordered_set<RandomPick> picked_indexes_;
absl::BitGen bitgen_{};
};

bool ScoreToLongLat(const std::optional<double>& val, double* xy) {
if (!val.has_value())
return false;
Expand Down Expand Up @@ -1702,6 +1766,48 @@ OpResult<StringVec> OpScan(const OpArgs& op_args, std::string_view key, uint64_t
return res;
}

// Picks random entries of the sorted set stored at `key`.
//
//   count >= 0: returns min(count, set size) entries, all at distinct indexes.
//   count <  0: returns exactly |count| entries; the same entry may repeat.
//
// `params` is forwarded to the range visitor that extracts the entries.
// If the key lookup fails, its status is returned instead of an array.
OpResult<ScoredArray> OpRandMember(int count, const ZSetFamily::RangeParams& params,
                                   const OpArgs& op_args, string_view key) {
  auto it = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET);
  if (!it)
    return it.status();

  // Action::RANGE is a read-only operation, but requires const_cast
  PrimeValue& pv = const_cast<PrimeValue&>(it.value()->second);

  const std::size_t size = pv.Size();
  if (size == 0) {
    // Nothing to pick from. Leave before building a generator
    // (NonUniquePicksGenerator CHECK-fails on an empty range) and before
    // evaluating log2(0) below.
    return ScoredArray{};
  }

  // |count| is computed in 64 bits because std::abs(INT_MIN) is not
  // representable as an int.
  const std::size_t picks_count =
      count >= 0 ? std::min(static_cast<std::size_t>(count), size)
                 : static_cast<std::size_t>(-static_cast<std::int64_t>(count));

  std::unique_ptr<PicksGenerator> generator;
  if (count >= 0) {
    generator = std::make_unique<UniquePicksGenerator>(picks_count, size);
  } else {
    generator = std::make_unique<NonUniquePicksGenerator>(size);
  }

  ScoredArray result(picks_count);

  // Few picks relative to the set size: fetch every pick with its own
  // single-index range query. Otherwise read the whole set once and index into it.
  // NOTE(review): the cut-over presumably models a single-index lookup as
  // O(log size) - confirm against IntervalVisitor.
  if (picks_count * static_cast<std::uint64_t>(std::log2(size)) < size) {
    for (std::size_t i = 0; i < picks_count; i++) {
      const std::size_t picked_index = generator->Generate();

      // A fresh visitor per pick, since its result is consumed right away.
      IntervalVisitor iv{Action::RANGE, params, &pv};
      iv(ZSetFamily::IndexInterval{picked_index, picked_index});

      // The visitor is discarded after this statement, so its entry can be moved out.
      result[i] = std::move(iv.PopResult().front());
    }
  } else {
    IntervalVisitor iv{Action::RANGE, params, &pv};
    iv(ZSetFamily::IndexInterval{0, -1});

    const ScoredArray all_elements = iv.PopResult();

    // Copy rather than move: with a negative count the same index may be picked again.
    for (std::size_t i = 0; i < picks_count; i++) {
      result[i] = all_elements[generator->Generate()];
    }
  }

  return result;
}

void ZAddGeneric(string_view key, const ZParams& zparams, ScoredMemberSpan memb_sp,
ConnectionContext* cntx) {
auto cb = [&](Transaction* t, EngineShard* shard) {
Expand Down Expand Up @@ -2323,43 +2429,32 @@ void ZSetFamily::ZRandMember(CmdArgList args, ConnectionContext* cntx) {
if (args.size() > 3)
return cntx->SendError(WrongNumArgsError("ZRANDMEMBER"));

ZRangeSpec range_spec;
range_spec.interval = IndexInterval(0, -1);

CmdArgParser parser{args};
string_view key = parser.Next();

bool is_count = parser.HasNext();
int count = is_count ? parser.Next<int>() : 1;

range_spec.params.with_scores = static_cast<bool>(parser.Check("WITHSCORES").IgnoreCase());
ZSetFamily::RangeParams params;
params.with_scores = static_cast<bool>(parser.Check("WITHSCORES").IgnoreCase());

if (parser.HasNext())
return cntx->SendError(absl::StrCat("Unsupported option:", string_view(parser.Next())));

if (auto err = parser.Error(); err)
return cntx->SendError(err->MakeReply());

bool sign = count < 0;
range_spec.params.limit = std::abs(count);

const auto cb = [&](Transaction* t, EngineShard* shard) {
return OpRange(range_spec, t->GetOpArgs(shard), key);
return OpRandMember(count, params, t->GetOpArgs(shard), key);
};

OpResult<ScoredArray> result = cntx->transaction->ScheduleSingleHopT(cb);
auto* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
if (result) {
if (sign && !result->empty()) {
for (auto i = result->size(); i < range_spec.params.limit; ++i) {
// we can return duplicate elements, so first is OK
result->push_back(result->front());
}
}
rb->SendScoredArray(result.value(), range_spec.params.with_scores);
rb->SendScoredArray(result.value(), params.with_scores);
} else if (result.status() == OpStatus::KEY_NOTFOUND) {
if (is_count) {
rb->SendScoredArray(ScoredArray(), range_spec.params.with_scores);
rb->SendScoredArray(ScoredArray(), params.with_scores);
} else {
rb->SendNull();
}
Expand Down
182 changes: 151 additions & 31 deletions src/server/zset_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,84 @@ class ZSetFamilyTest : public BaseFamilyTest {
protected:
};

// A (member, score) pair, both kept as strings.
using ScoredElement = std::pair<std::string, std::string>;

// Converts a flat array [member0, score0, member1, score1, ...] into a vector of
// (member, score) pairs. A trailing element without a partner is ignored.
//
// The array is taken by const reference so that passing an lvalue does not copy
// the whole container. It only has to provide size(), operator[] and elements
// with a const GetString().
template <typename Array> std::vector<ScoredElement> ParseToScoredArray(const Array& arr) {
  std::vector<ScoredElement> scored_elements;
  scored_elements.reserve(arr.size() / 2);
  for (std::size_t i = 1; i < arr.size(); i += 2) {
    scored_elements.emplace_back(arr[i - 1].GetString(), arr[i].GetString());
  }
  return scored_elements;
}

// Matches an array in which the string value of every entry is contained in
// `elements`. Entries may repeat, and not every member of `elements` has to appear.
MATCHER_P(ConsistsOfMatcher, elements, "") {
  const auto& entries = arg.GetVec();
  return std::all_of(entries.begin(), entries.end(), [&](const auto& entry) {
    return elements.count(entry.GetString()) > 0;
  });
}

// Matches a flat [member, score, member, score, ...] array whose every
// (member, score) pair is contained in `elements`. Pairs may repeat.
// An array with an odd number of entries never matches.
MATCHER_P(ConsistsOfScoredElementsMatcher, elements, "") {
  const auto& flat = arg.GetVec();
  if (flat.size() % 2 != 0) {
    return false;
  }

  const auto pairs = ParseToScoredArray(flat);
  return std::all_of(pairs.begin(), pairs.end(),
                     [&](const auto& pair) { return elements.count(pair) > 0; });
}

// Matches a flat [member, score, ...] array whose (member, score) pairs form a
// sub-multiset of `elements_list`: a pair occurring k times in the array must
// occur at least k times in `elements_list`.
// An array with an odd number of entries never matches.
MATCHER_P(IsScoredSubsetOfMatcher, elements_list, "") {
  const auto& flat = arg.GetVec();
  if (flat.size() % 2 != 0) {
    return false;
  }

  // std::includes needs both ranges sorted, so work on sorted copies.
  auto actual = ParseToScoredArray(flat);
  std::vector<ScoredElement> expected(elements_list.begin(), elements_list.end());
  std::sort(actual.begin(), actual.end());
  std::sort(expected.begin(), expected.end());

  return std::includes(expected.begin(), expected.end(), actual.begin(), actual.end());
}

// Matches a flat [member, score, ...] array whose (member, score) pairs are a
// permutation of `elements_list`: same pairs, same multiplicities, any order.
// An array with an odd number of entries never matches.
MATCHER_P(UnorderedScoredElementsAreMatcher, elements_list, "") {
  const auto& flat = arg.GetVec();
  const bool has_complete_pairs = flat.size() % 2 == 0;
  if (!has_complete_pairs) {
    return false;
  }

  const auto actual = ParseToScoredArray(flat);
  return std::is_permutation(actual.begin(), actual.end(), elements_list.begin(),
                             elements_list.end());
}

// Builds a matcher that accepts arrays made up only of the given strings.
auto ConsistsOf(std::initializer_list<std::string> elements) {
  std::unordered_set<std::string> allowed(elements.begin(), elements.end());
  return ConsistsOfMatcher(std::move(allowed));
}

// Builds a matcher that accepts flat [member, score, ...] arrays made up only of
// the given (member, score) pairs.
auto ConsistsOfScoredElements(std::initializer_list<std::pair<std::string, std::string>> elements) {
  std::set<std::pair<std::string, std::string>> allowed(elements.begin(), elements.end());
  return ConsistsOfScoredElementsMatcher(std::move(allowed));
}

// Builds a matcher that accepts flat [member, score, ...] arrays whose pairs are
// a sub-multiset of `elements`.
//
// The pairs are copied into a vector before being handed to the matcher.
// Copying a std::initializer_list does not copy its backing array, so a matcher
// holding the list itself would dangle as soon as it outlived the expression
// that created it.
auto IsScoredSubsetOf(std::initializer_list<std::pair<std::string, std::string>> elements) {
  return IsScoredSubsetOfMatcher(
      std::vector<std::pair<std::string, std::string>>(elements.begin(), elements.end()));
}

// Builds a matcher that accepts flat [member, score, ...] arrays whose pairs are
// a permutation of `elements`.
//
// The pairs are copied into a vector before being handed to the matcher.
// Copying a std::initializer_list does not copy its backing array, so a matcher
// holding the list itself would dangle as soon as it outlived the expression
// that created it.
auto UnorderedScoredElementsAre(
    std::initializer_list<std::pair<std::string, std::string>> elements) {
  return UnorderedScoredElementsAreMatcher(
      std::vector<std::pair<std::string, std::string>>(elements.begin(), elements.end()));
}

TEST_F(ZSetFamilyTest, Add) {
auto resp = Run({"zadd", "x", "1.1", "a"});
EXPECT_THAT(resp, IntArg(1));
Expand Down Expand Up @@ -77,53 +155,95 @@ TEST_F(ZSetFamilyTest, ZRem) {
}

TEST_F(ZSetFamilyTest, ZRandMember) {
auto resp = Run({
"zadd",
"x",
"1",
"a",
"2",
"b",
"3",
"c",
});
auto resp = Run({"ZAdd", "x", "1", "a", "2", "b", "3", "c"});
EXPECT_THAT(resp, IntArg(3));

// Test if count > 0
resp = Run({"ZRandMember", "x"});
ASSERT_THAT(resp, ArgType(RespExpr::STRING));
EXPECT_THAT(resp, "a");
EXPECT_THAT(resp, AnyOf("a", "b", "c"));

resp = Run({"ZRandMember", "x", "1"});
ASSERT_THAT(resp, ArgType(RespExpr::STRING));
EXPECT_THAT(resp, AnyOf("a", "b", "c"));

resp = Run({"ZRandMember", "x", "2"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "b"));
ASSERT_THAT(resp, ArrLen(2));
EXPECT_THAT(resp.GetVec(), IsSubsetOf({"a", "b", "c"}));

resp = Run({"ZRandMember", "x", "0"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
EXPECT_EQ(resp.GetVec().size(), 0);
resp = Run({"ZRandMember", "x", "3"});
ASSERT_THAT(resp, ArrLen(3));
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "b", "c"));

// Test if count < 0
resp = Run({"ZRandMember", "x", "-1"});
ASSERT_THAT(resp, ArgType(RespExpr::STRING));
EXPECT_THAT(resp, AnyOf("a", "b", "c"));

resp = Run({"ZRandMember", "k"});
ASSERT_THAT(resp, ArgType(RespExpr::NIL));
resp = Run({"ZRandMember", "x", "-2"});
ASSERT_THAT(resp, ArrLen(2));
EXPECT_THAT(resp, ConsistsOf({"a", "b", "c"}));

resp = Run({"ZRandMember", "k", "2"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
EXPECT_EQ(resp.GetVec().size(), 0);
resp = Run({"ZRandMember", "x", "-3"});
ASSERT_THAT(resp, ArrLen(3));
EXPECT_THAT(resp, ConsistsOf({"a", "b", "c"}));

resp = Run({"ZRandMember", "x", "-5"});
ASSERT_THAT(resp, ArrLen(5));
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "b", "c", "a", "a"));
// Test if count < 0, but the absolute value is larger than the size of the sorted set
resp = Run({"ZRandMember", "x", "-15"});
ASSERT_THAT(resp, ArrLen(15));
EXPECT_THAT(resp, ConsistsOf({"a", "b", "c"}));

resp = Run({"ZRandMember", "x", "5"});
// Test if count is 0
ASSERT_THAT(Run({"ZRandMember", "x", "0"}), ArrLen(0));

// Test if count is larger than the size of the sorted set
resp = Run({"ZRandMember", "x", "15"});
ASSERT_THAT(resp, ArrLen(3));
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "b", "c"));

resp = Run({"ZRandMember", "x", "-5", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(10));
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "1", "b", "2", "c", "3", "a", "1", "a", "1"));
// Test if sorted set is empty
EXPECT_THAT(Run({"ZAdd", "empty::zset", "1", "one"}), IntArg(1));
EXPECT_THAT(Run({"ZRem", "empty::zset", "one"}), IntArg(1));
ASSERT_THAT(Run({"ZRandMember", "empty::zset", "0"}), ArrLen(0));
ASSERT_THAT(Run({"ZRandMember", "empty::zset", "3"}), ArrLen(0));
ASSERT_THAT(Run({"ZRandMember", "empty::zset", "-4"}), ArrLen(0));

// Test if key does not exist
ASSERT_THAT(Run({"ZRandMember", "y"}), ArgType(RespExpr::NIL));
ASSERT_THAT(Run({"ZRandMember", "y", "0"}), ArrLen(0));

// Test WITHSCORES
resp = Run({"ZRandMember", "x", "1", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(2));
EXPECT_THAT(resp, IsScoredSubsetOf({{"a", "1"}, {"b", "2"}, {"c", "3"}}));

resp = Run({"ZRandMember", "x", "2", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(4));
EXPECT_THAT(resp, IsScoredSubsetOf({{"a", "1"}, {"b", "2"}, {"c", "3"}}));

resp = Run({"ZRandMember", "x", "3", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(6));
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "1", "b", "2", "c", "3"));
EXPECT_THAT(resp, UnorderedScoredElementsAre({{"a", "1"}, {"b", "2"}, {"c", "3"}}));

resp = Run({"ZRandMember", "x", "3", "WITHSCORES", "test"});
EXPECT_THAT(resp, ErrArg("wrong number of arguments"));
resp = Run({"ZRandMember", "x", "15", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(6));
EXPECT_THAT(resp, UnorderedScoredElementsAre({{"a", "1"}, {"b", "2"}, {"c", "3"}}));

resp = Run({"ZRandMember", "x", "-1", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(2));
EXPECT_THAT(resp, ConsistsOfScoredElements({{"a", "1"}, {"b", "2"}, {"c", "3"}}));

resp = Run({"ZRandMember", "x", "-2", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(4));
EXPECT_THAT(resp, ConsistsOfScoredElements({{"a", "1"}, {"b", "2"}, {"c", "3"}}));

resp = Run({"ZRandMember", "x", "-3", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(6));
EXPECT_THAT(resp, ConsistsOfScoredElements({{"a", "1"}, {"b", "2"}, {"c", "3"}}));

resp = Run({"ZRandMember", "x", "-15", "WITHSCORES"});
ASSERT_THAT(resp, ArrLen(30));
EXPECT_THAT(resp, ConsistsOfScoredElements({{"a", "1"}, {"b", "2"}, {"c", "3"}}));
}

TEST_F(ZSetFamilyTest, ZMScore) {
Expand Down

0 comments on commit e604df2

Please sign in to comment.