diff --git a/velox/docs/functions/presto/aggregate.rst b/velox/docs/functions/presto/aggregate.rst index c8443653857f..38e798dc324f 100644 --- a/velox/docs/functions/presto/aggregate.rst +++ b/velox/docs/functions/presto/aggregate.rst @@ -411,6 +411,201 @@ __ https://www.cse.ust.hk/~raywong/comp5331/References/EfficientComputationOfFre As ``approx_percentile(x, w, percentages)``, but with a maximum rank error of ``accuracy``. +Classification Metrics Aggregate Functions +------------------------------------------ + +The following functions each measure how some metric of a binary +`confusion matrix `_ changes as a function of +classification thresholds. They are meant to be used in conjunction. + +For example, to find the `precision-recall curve `_, use + + .. code-block:: none + + WITH + recall_precision AS ( + SELECT + CLASSIFICATION_RECALL(10000, correct, pred) AS recalls, + CLASSIFICATION_PRECISION(10000, correct, pred) AS precisions + FROM + classification_dataset + ) + SELECT + recall, + precision + FROM + recall_precision + CROSS JOIN UNNEST(recalls, precisions) AS t(recall, precision) + +To get the corresponding thresholds for these values, use + + .. code-block:: none + + WITH + recall_precision AS ( + SELECT + CLASSIFICATION_THRESHOLDS(10000, correct, pred) AS thresholds, + CLASSIFICATION_RECALL(10000, correct, pred) AS recalls, + CLASSIFICATION_PRECISION(10000, correct, pred) AS precisions + FROM + classification_dataset + ) + SELECT + threshold, + recall, + precision + FROM + recall_precision + CROSS JOIN UNNEST(thresholds, recalls, precisions) AS t(threshold, recall, precision) + +To find the `ROC curve `_, use + + .. code-block:: none + + WITH + fallout_recall AS ( + SELECT + CLASSIFICATION_FALLOUT(10000, correct, pred) AS fallouts, + CLASSIFICATION_RECALL(10000, correct, pred) AS recalls + FROM + classification_dataset + ) + SELECT + fallout + recall, + FROM + recall_fallout + CROSS JOIN UNNEST(fallouts, recalls) AS t(fallout, recall) + + +.. function:: classification_miss_rate(buckets, y, x, weight) -> array + + Computes the miss-rate with up to ``buckets`` number of buckets. Returns + an array of miss-rate values. + + ``y`` should be a boolean outcome value; ``x`` should be predictions, each + between 0 and 1; ``weight`` should be non-negative values, indicating the weight of the instance. + + The + `miss-rate `_ + is defined as a sequence whose :math:`j`-th entry is + + .. math :: + + { + \sum_{i \;|\; x_i \leq t_j \bigwedge y_i = 1} \left[ w_i \right] + \over + \sum_{i \;|\; x_i \leq t_j \bigwedge y_i = 1} \left[ w_i \right] + + + \sum_{i \;|\; x_i > t_j \bigwedge y_i = 1} \left[ w_i \right] + }, + + where :math:`t_j` is the :math:`j`-th smallest threshold, + and :math:`y_i`, :math:`x_i`, and :math:`w_i` are the :math:`i`-th + entries of ``y``, ``x``, and ``weight``, respectively. + +.. function:: classification_miss_rate(buckets, y, x) -> array + + This function is equivalent to the variant of + :func:`!classification_miss_rate` that takes a ``weight``, with a per-item weight of ``1``. + +.. function:: classification_fall_out(buckets, y, x, weight) -> array + + Computes the fall-out with up to ``buckets`` number of buckets. Returns + an array of fall-out values. + + ``y`` should be a boolean outcome value; ``x`` should be predictions, each + between 0 and 1; ``weight`` should be non-negative values, indicating the weight of the instance. + + The + `fall-out `_ + is defined as a sequence whose :math:`j`-th entry is + + .. math :: + + { + \sum_{i \;|\; x_i > t_j \bigwedge y_i = 0} \left[ w_i \right] + \over + \sum_{i \;|\; y_i = 0} \left[ w_i \right] + }, + + where :math:`t_j` is the :math:`j`-th smallest threshold, + and :math:`y_i`, :math:`x_i`, and :math:`w_i` are the :math:`i`-th + entries of ``y``, ``x``, and ``weight``, respectively. + +.. function:: classification_fall_out(buckets, y, x) -> array + + This function is equivalent to the variant of + :func:`!classification_fall_out` that takes a ``weight``, with a per-item weight of ``1``. + +.. function:: classification_precision(buckets, y, x, weight) -> array + + Computes the precision with up to ``buckets`` number of buckets. Returns + an array of precision values. + + ``y`` should be a boolean outcome value; ``x`` should be predictions, each + between 0 and 1; ``weight`` should be non-negative values, indicating the weight of the instance. + + The + `precision `_ + is defined as a sequence whose :math:`j`-th entry is + + .. math :: + + { + \sum_{i \;|\; x_i > t_j \bigwedge y_i = 1} \left[ w_i \right] + \over + \sum_{i \;|\; x_i > t_j} \left[ w_i \right] + }, + + where :math:`t_j` is the :math:`j`-th smallest threshold, + and :math:`y_i`, :math:`x_i`, and :math:`w_i` are the :math:`i`-th + entries of ``y``, ``x``, and ``weight``, respectively. + +.. function:: classification_precision(buckets, y, x) -> array + + This function is equivalent to the variant of + :func:`!classification_precision` that takes a ``weight``, with a per-item weight of ``1``. + +.. function:: classification_recall(buckets, y, x, weight) -> array + + Computes the recall with up to ``buckets`` number of buckets. Returns + an array of recall values. + + ``y`` should be a boolean outcome value; ``x`` should be predictions, each + between 0 and 1; ``weight`` should be non-negative values, indicating the weight of the instance. + + The + `recall `_ + is defined as a sequence whose :math:`j`-th entry is + + .. math :: + + { + \sum_{i \;|\; x_i > t_j \bigwedge y_i = 1} \left[ w_i \right] + \over + \sum_{i \;|\; y_i = 1} \left[ w_i \right] + }, + + where :math:`t_j` is the :math:`j`-th smallest threshold, + and :math:`y_i`, :math:`x_i`, and :math:`w_i` are the :math:`i`-th + entries of ``y``, ``x``, and ``weight``, respectively. + +.. function:: classification_recall(buckets, y, x) -> array + + This function is equivalent to the variant of + :func:`!classification_recall` that takes a ``weight``, with a per-item weight of ``1``. + +.. function:: classification_thresholds(buckets, y, x) -> array + + Computes the thresholds with up to ``buckets`` number of buckets. Returns + an array of threshold values. + + ``y`` should be a boolean outcome value; ``x`` should be predictions, each + between 0 and 1. + + The thresholds are defined as a sequence whose :math:`j`-th entry is the :math:`j`-th smallest threshold. + Statistical Aggregate Functions ------------------------------- diff --git a/velox/docs/functions/presto/coverage.rst b/velox/docs/functions/presto/coverage.rst index 127297b7e197..a4df266e46b2 100644 --- a/velox/docs/functions/presto/coverage.rst +++ b/velox/docs/functions/presto/coverage.rst @@ -325,11 +325,11 @@ Here is a list of all scalar and aggregate Presto functions with functions that :func:`array_duplicates` :func:`dow` :func:`json_extract` :func:`repeat` st_union :func:`bool_and` :func:`rank` :func:`array_except` :func:`doy` :func:`json_extract_scalar` :func:`replace` st_within :func:`bool_or` :func:`row_number` :func:`array_frequency` :func:`e` :func:`json_format` replace_first st_x :func:`checksum` - :func:`array_has_duplicates` :func:`element_at` :func:`json_parse` :func:`reverse` st_xmax classification_fall_out - :func:`array_intersect` :func:`empty_approx_set` :func:`json_size` rgb st_xmin classification_miss_rate - :func:`array_join` :func:`ends_with` key_sampling_percent :func:`round` st_y classification_precision - array_least_frequent enum_key :func:`laplace_cdf` :func:`rpad` st_ymax classification_recall - :func:`array_max` :func:`exp` :func:`last_day_of_month` :func:`rtrim` st_ymin classification_thresholds + :func:`array_has_duplicates` :func:`element_at` :func:`json_parse` :func:`reverse` st_xmax :func: `classification_fall_out` + :func:`array_intersect` :func:`empty_approx_set` :func:`json_size` rgb st_xmin :func: `classification_miss_rate` + :func:`array_join` :func:`ends_with` key_sampling_percent :func:`round` st_y :func: `classification_precision` + array_least_frequent enum_key :func:`laplace_cdf` :func:`rpad` st_ymax :func: `classification_recall` + :func:`array_max` :func:`exp` :func:`last_day_of_month` :func:`rtrim` st_ymin :func: `classification_thresholds` array_max_by expand_envelope :func:`least` scale_qdigest :func:`starts_with` convex_hull_agg :func:`array_min` :func:`f_cdf` :func:`length` :func:`second` :func:`strpos` :func:`corr` array_min_by features :func:`levenshtein_distance` secure_rand :func:`strrpos` :func:`count` diff --git a/velox/functions/prestosql/aggregates/AggregateNames.h b/velox/functions/prestosql/aggregates/AggregateNames.h index 7cf2ff9810d8..27f539693d19 100644 --- a/velox/functions/prestosql/aggregates/AggregateNames.h +++ b/velox/functions/prestosql/aggregates/AggregateNames.h @@ -32,6 +32,11 @@ const char* const kBitwiseXor = "bitwise_xor_agg"; const char* const kBoolAnd = "bool_and"; const char* const kBoolOr = "bool_or"; const char* const kChecksum = "checksum"; +const char* const kClassificationFallout = "classification_fall_out"; +const char* const kClassificationPrecision = "classification_precision"; +const char* const kClassificationRecall = "classification_recall"; +const char* const kClassificationMissRate = "classification_miss_rate"; +const char* const kClassificationThreshold = "classification_thresholds"; const char* const kCorr = "corr"; const char* const kCount = "count"; const char* const kCountIf = "count_if"; diff --git a/velox/functions/prestosql/aggregates/CMakeLists.txt b/velox/functions/prestosql/aggregates/CMakeLists.txt index 986de438474e..29009602ea7f 100644 --- a/velox/functions/prestosql/aggregates/CMakeLists.txt +++ b/velox/functions/prestosql/aggregates/CMakeLists.txt @@ -28,6 +28,7 @@ velox_add_library( CountIfAggregate.cpp CovarianceAggregates.cpp ChecksumAggregate.cpp + ClassificationAggregation.cpp EntropyAggregates.cpp GeometricMeanAggregate.cpp HistogramAggregate.cpp diff --git a/velox/functions/prestosql/aggregates/ClassificationAggregation.cpp b/velox/functions/prestosql/aggregates/ClassificationAggregation.cpp new file mode 100644 index 000000000000..872e906c596c --- /dev/null +++ b/velox/functions/prestosql/aggregates/ClassificationAggregation.cpp @@ -0,0 +1,684 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/base/IOUtils.h" +#include "velox/exec/Aggregate.h" +#include "velox/functions/prestosql/aggregates/AggregateNames.h" +#include "velox/vector/FlatVector.h" + +namespace facebook::velox::aggregate::prestosql { +namespace { + +enum class ClassificationType { + kFallout = 0, + kPrecision = 1, + kRecall = 2, + kMissRate = 3, + kThresholds = 4, +}; + +/// Struct to represent the bucket of the FixedDoubleHistogram +/// at a given index. +struct Bucket { + Bucket(double _left, double _right, double _weight) + : left(_left), right(_right), weight(_weight) {} + const double left; + const double right; + const double weight; +}; + +/// Fixed-bucket histogram of weights as doubles. For each bucket, it stores the +/// total weight accumulated. +class FixedDoubleHistogram { + public: + explicit FixedDoubleHistogram(HashStringAllocator* allocator) + : weights_(StlAllocator(allocator)) {} + + void resizeWeights() { + validateParameters(bucketCount_, min_, max_); + weights_.resize(bucketCount_); + } + + /// API to support the case when bucket is created without a bucketCount + /// count. + void tryInit(int64_t bucketCount) { + if (bucketCount_ == -1) { + bucketCount_ = bucketCount; + resizeWeights(); + } + } + + /// Add weight to bucket based on the value of the prediction. + void add(double pred, double weight) { + if (weight == 0) { + return; + } + if (weight < 0) { + VELOX_USER_FAIL("Weight must be non-negative."); + } + if (pred < kMinPredictionValue || pred > kMaxPredictionValue) { + VELOX_USER_FAIL( + "Prediction value must be between {} and {}", + kMinPredictionValue, + kMaxPredictionValue); + } + auto index = getIndexForValue(bucketCount_, min_, max_, pred); + weights_.at(index) += weight; + totalWeights_ += weight; + maxUsedIndex_ = std::max(maxUsedIndex_, index); + } + + /// Returns a bucket in this histogram at a given index. + Bucket getBucket(int64_t index) { + return Bucket( + getLeftValueForIndex(bucketCount_, min_, max_, index), + getRightValueForIndex(bucketCount_, min_, max_, index), + weights_.at(index)); + } + + /// The size of the histogram is represented by maxUsedIndex_, which + /// represents the largest index in the buckets with a non-zero accrued value. + /// This helps us avoid O(n) operation for the size of the histogram. + int64_t size() const { + return maxUsedIndex_ + 1; + } + + int64_t bucketCount() const { + return bucketCount_; + } + + /// The state of the histogram can be serialized into a buffer. The format is + /// represented as [header][bucketCount][min][max][weights]. The header is + /// used to identify the version of the serialization format. The bucketCount, + /// min, and max are used to represent the parameters of the histogram. + /// Weights are the number of weights (equal to number of buckets) in the + /// histogram. + size_t serialize(char* output) const { + VELOX_CHECK(output); + common::OutputByteStream stream(output); + size_t bytesUsed = 0; + stream.append( + reinterpret_cast(&kSerializationVersionHeader), + sizeof(kSerializationVersionHeader)); + bytesUsed += sizeof(kSerializationVersionHeader); + + stream.append( + reinterpret_cast(&bucketCount_), sizeof(bucketCount_)); + bytesUsed += sizeof(bucketCount_); + + stream.append(reinterpret_cast(&min_), sizeof(min_)); + bytesUsed += sizeof(min_); + + stream.append(reinterpret_cast(&max_), sizeof(max_)); + bytesUsed += sizeof(max_); + + for (auto weight : weights_) { + stream.append(reinterpret_cast(&weight), sizeof(weight)); + bytesUsed += sizeof(weight); + } + + return bytesUsed; + } + + /// Merges the current histogram with another histogram represented as a + /// buffer. + void mergeWith(const char* data, size_t expectedSize) { + auto input = common::InputByteStream(data); + deserialize(*this, input, expectedSize); + } + + size_t serializationSize() const { + return sizeof(kSerializationVersionHeader) + sizeof(bucketCount_) + + sizeof(min_) + sizeof(max_) + (weights_.size() * sizeof(double)); + } + + /// This represents the total accrued weights in the bucket. The value is + /// cached to avoid recomputing it every time it is needed. + double totalWeights() const { + return totalWeights_; + } + + private: + /// Deserializes the histogram from a buffer. + static void deserialize( + FixedDoubleHistogram& histogram, + common::InputByteStream& in, + size_t expectedSize) { + if (FOLLY_UNLIKELY(expectedSize < minDeserializedBufferSize())) { + VELOX_USER_FAIL( + "Cannot deserialize FixedDoubleHistogram. Expected size: {}, actual size: {}", + minDeserializedBufferSize(), + expectedSize); + } + + uint8_t version; + in.copyTo(&version, 1); + VELOX_CHECK_EQ(version, kSerializationVersionHeader); + + int64_t bucketCount; + double min; + double max; + in.copyTo(&bucketCount, 1); + in.copyTo(&min, 1); + in.copyTo(&max, 1); + + /// This accounts for the case when the histogram is not initialized yet. + + if (histogram.bucketCount_ == -1) { + histogram.bucketCount_ = bucketCount; + histogram.min_ = min; + histogram.max_ = max; + histogram.resizeWeights(); + } else { + /// When merging histograms, all the parameters except for the values + /// accrued inside the buckets must be the same. + if (histogram.bucketCount_ != bucketCount) { + VELOX_USER_FAIL( + "Bucket count must be constant." + "Left bucket count: {}, right bucket count: {}", + histogram.bucketCount_, + bucketCount); + } + + if (histogram.min_ != min || histogram.max_ != max) { + VELOX_USER_FAIL( + "Cannot merge histograms with different min/max values. " + "Left min: {}, left max: {}, right min: {}, right max: {}", + histogram.min_, + histogram.max_, + min, + max); + } + } + + for (int64_t i = 0; i < bucketCount; ++i) { + double weight; + in.copyTo(&weight, 1); + histogram.weights_[i] += weight; + histogram.totalWeights_ += weight; + if (weight != 0) { + histogram.maxUsedIndex_ = std::max(histogram.maxUsedIndex_, i); + } + } + const size_t bytesRead = sizeof(kSerializationVersionHeader) + + sizeof(bucketCount) + sizeof(min) + sizeof(max) + + (bucketCount * sizeof(double)); + VELOX_CHECK_EQ(bytesRead, expectedSize); + return; + } + + /// The minimium size of a valid buffer to deserialize a histogram. + static constexpr size_t minDeserializedBufferSize() { + return ( + sizeof(kSerializationVersionHeader) + sizeof(int64_t) + sizeof(double) + + /// 2 Reresents the minimum number of buckets. + sizeof(double) + 2 * sizeof(double)); + } + + /// Returns the index of the bucket in the histogram that contains the + /// value. This is done by mapping value to [min, max) and then mapping that + /// value to the corresponding bucket. + static int64_t + getIndexForValue(int64_t bucketCount, double min, double max, double value) { + VELOX_CHECK( + value >= min && value <= max, + fmt::format( + "Value must be within range: {} [{}, {}]", value, min, max)); + return std::min( + static_cast((bucketCount * (value - min)) / (max - min)), + bucketCount - 1); + } + + static double getLeftValueForIndex( + int64_t bucketCount, + double min, + double max, + int64_t index) { + return min + index * (max - min) / bucketCount; + } + + static double getRightValueForIndex( + int64_t bucketCount, + double min, + double max, + int64_t index) { + return std::min( + max, getLeftValueForIndex(bucketCount, min, max, index + 1)); + } + + void validateParameters(int64_t bucketCount, double min, double max) { + VELOX_CHECK_LE(bucketCount, weights_.max_size()); + + if (bucketCount < 2) { + VELOX_USER_FAIL("Bucket count must be at least 2.0"); + } + + if (min >= max) { + VELOX_USER_FAIL("Min must be less than max. Min: {}, max: {}", min, max); + } + } + + static constexpr double kMinPredictionValue = 0.0; + static constexpr double kMaxPredictionValue = 1.0; + static constexpr uint8_t kSerializationVersionHeader = 1; + std::vector> weights_; + double totalWeights_{0}; + int64_t bucketCount_{-1}; + double min_{0}; + double max_{1.0}; + int64_t maxUsedIndex_{-1}; +}; + +template +struct Accumulator { + explicit Accumulator(HashStringAllocator* allocator) + : trueWeights_(allocator), falseWeights_(allocator) {} + + void + setWeights(int64_t bucketCount, bool outcome, double pred, double weight) { + VELOX_CHECK_EQ(bucketCount, trueWeights_.bucketCount()); + VELOX_CHECK_EQ(bucketCount, falseWeights_.bucketCount()); + + /// Similar to Java Presto, the max prediction value for the histogram + /// is set to be 0.99999999999 in order to ensure bin corresponding to 1 + /// is not reached. + static const double kMaxPredictionValue = 0.99999999999; + pred = std::min(pred, kMaxPredictionValue); + outcome ? trueWeights_.add(pred, weight) : falseWeights_.add(pred, weight); + } + + void tryInit(int64_t bucketCount) { + trueWeights_.tryInit(bucketCount); + falseWeights_.tryInit(bucketCount); + } + + vector_size_t size() const { + return trueWeights_.size(); + } + + size_t serialize(char* output) const { + size_t bytes = trueWeights_.serialize(output); + return bytes + falseWeights_.serialize(output + bytes); + } + + size_t serializationSize() const { + return trueWeights_.serializationSize() + falseWeights_.serializationSize(); + } + + void mergeWith(StringView serialized) { + auto input = serialized.data(); + VELOX_CHECK_EQ(serialized.size() % 2, 0); + const size_t bufferSize = serialized.size() / 2; + trueWeights_.mergeWith(input, bufferSize); + falseWeights_.mergeWith(input + serialized.size() / 2, bufferSize); + } + + void extractValues(FlatVector* flatResult, vector_size_t offset) { + const double totalTrueWeight = trueWeights_.totalWeights(); + const double totalFalseWeight = falseWeights_.totalWeights(); + + double runningFalseWeight = 0; + double runningTrueWeight = 0; + int64_t trueWeightIndex = 0; + while (trueWeightIndex < trueWeights_.bucketCount() && + totalTrueWeight > runningTrueWeight) { + auto trueBucketResult = trueWeights_.getBucket(trueWeightIndex); + auto falseBucketResult = falseWeights_.getBucket(trueWeightIndex); + + const double falsePositive = totalFalseWeight - runningFalseWeight; + const double negative = totalFalseWeight; + + if constexpr (type == ClassificationType::kFallout) { + flatResult->set(offset + trueWeightIndex, falsePositive / negative); + } else if constexpr (type == ClassificationType::kPrecision) { + const double truePositive = (totalTrueWeight - runningTrueWeight); + const double totalPositives = truePositive + falsePositive; + flatResult->set( + offset + trueWeightIndex, truePositive / totalPositives); + } else if constexpr (type == ClassificationType::kRecall) { + const double truePositive = (totalTrueWeight - runningTrueWeight); + flatResult->set( + offset + trueWeightIndex, truePositive / totalTrueWeight); + } else if constexpr (type == ClassificationType::kMissRate) { + flatResult->set( + offset + trueWeightIndex, runningTrueWeight / totalTrueWeight); + } else if constexpr (type == ClassificationType::kThresholds) { + flatResult->set(offset + trueWeightIndex, trueBucketResult.left); + } else { + VELOX_UNREACHABLE("Not expected to be called."); + } + + runningTrueWeight += trueBucketResult.weight; + runningFalseWeight += falseBucketResult.weight; + trueWeightIndex += 1; + } + } + + private: + FixedDoubleHistogram trueWeights_; + FixedDoubleHistogram falseWeights_; +}; + +template +class ClassificationAggregation : public exec::Aggregate { + public: + explicit ClassificationAggregation( + TypePtr resultType, + bool useDefaultWeight = false) + : Aggregate(std::move(resultType)), useDefaultWeight_(useDefaultWeight) {} + + int32_t accumulatorFixedWidthSize() const override { + return sizeof(Accumulator); + } + + bool isFixedSize() const override { + return false; + } + + void addSingleGroupRawInput( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + decodeArguments(rows, args); + auto accumulator = value>(group); + + auto tracker = trackRowSize(group); + rows.applyToSelected([&](auto row) { + if (decodedBuckets_.isNullAt(row) || decodedOutcome_.isNullAt(row) || + decodedPred_.isNullAt(row) || + (!useDefaultWeight_ && decodedWeight_.isNullAt(row))) { + return; + } + clearNull(group); + accumulator->tryInit(decodedBuckets_.valueAt(row)); + accumulator->setWeights( + decodedBuckets_.valueAt(row), + decodedOutcome_.valueAt(row), + decodedPred_.valueAt(row), + useDefaultWeight_ ? 1.0 : decodedWeight_.valueAt(row)); + }); + } + + // Step 4. + void addRawInput( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + decodeArguments(rows, args); + + rows.applyToSelected([&](vector_size_t row) { + if (decodedBuckets_.isNullAt(row) || decodedOutcome_.isNullAt(row) || + decodedPred_.isNullAt(row) || + (!useDefaultWeight_ && decodedWeight_.isNullAt(row))) { + return; + } + + auto& group = groups[row]; + auto tracker = trackRowSize(group); + + clearNull(group); + auto* accumulator = value>(group); + accumulator->tryInit(decodedBuckets_.valueAt(row)); + + accumulator->setWeights( + decodedBuckets_.valueAt(row), + decodedOutcome_.valueAt(row), + decodedPred_.valueAt(row), + useDefaultWeight_ ? 1.0 : decodedWeight_.valueAt(row)); + }); + } + + // Step 5. + void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) + override { + VELOX_CHECK(result); + auto flatResult = (*result)->asFlatVector(); + flatResult->resize(numGroups); + + uint64_t* rawNulls = nullptr; + if (flatResult->mayHaveNulls()) { + BufferPtr& nulls = flatResult->mutableNulls(flatResult->size()); + rawNulls = nulls->asMutable(); + } + + for (auto i = 0; i < numGroups; ++i) { + auto group = groups[i]; + if (isNull(group)) { + flatResult->setNull(i, true); + continue; + } + + if (rawNulls) { + bits::clearBit(rawNulls, i); + } + auto accumulator = value>(group); + auto serializationSize = accumulator->serializationSize(); + char* rawBuffer = + flatResult->getRawStringBufferWithSpace(serializationSize); + + VELOX_CHECK_EQ(accumulator->serialize(rawBuffer), serializationSize); + auto sv = StringView(rawBuffer, serializationSize); + flatResult->set(i, std::move(sv)); + } + } + + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) + override { + auto vector = (*result)->as(); + VELOX_CHECK(vector); + vector->resize(numGroups); + + vector_size_t numValues = 0; + uint64_t* rawNulls = getRawNulls(result->get()); + + for (auto i = 0; i < numGroups; ++i) { + auto* group = groups[i]; + auto* accumulator = value>(group); + const auto size = accumulator->size(); + if (isNull(group)) { + vector->setNull(i, true); + continue; + } + + clearNull(rawNulls, i); + numValues += size; + } + + auto flatResults = vector->elements()->asFlatVector(); + flatResults->resize(numValues); + + auto* rawOffsets = vector->offsets()->asMutable(); + auto* rawSizes = vector->sizes()->asMutable(); + + vector_size_t offset = 0; + for (auto i = 0; i < numGroups; ++i) { + auto* group = groups[i]; + + if (isNull(group)) { + continue; + } + auto* accumulator = value>(group); + const vector_size_t size = accumulator->size(); + + rawOffsets[i] = offset; + rawSizes[i] = size; + + accumulator->extractValues(flatResults, offset); + + offset += size; + } + } + + void addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + VELOX_CHECK_EQ(args.size(), 1); + decodedAcc_.decode(*args[0], rows); + + rows.applyToSelected([&](auto row) { + if (decodedAcc_.isNullAt(row)) { + return; + } + + auto group = groups[row]; + auto tracker = trackRowSize(group); + clearNull(group); + + auto serialized = decodedAcc_.valueAt(row); + + auto accumulator = value>(group); + accumulator->mergeWith(serialized); + }); + } + + void addSingleGroupIntermediateResults( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /* mayPushdown */) override { + VELOX_CHECK_EQ(args.size(), 1); + decodedAcc_.decode(*args[0], rows); + auto tracker = trackRowSize(group); + + rows.applyToSelected([&](auto row) { + if (decodedAcc_.isNullAt(row)) { + return; + } + + clearNull(group); + + auto serialized = decodedAcc_.valueAt(row); + + auto accumulator = value>(group); + accumulator->mergeWith(serialized); + }); + } + + protected: + void initializeNewGroupsInternal( + char** groups, + folly::Range indices) override { + exec::Aggregate::setAllNulls(groups, indices); + for (auto i : indices) { + auto group = groups[i]; + new (group + offset_) Accumulator(allocator_); + } + } + + void destroyInternal(folly::Range groups) override { + destroyAccumulators>(groups); + } + + private: + void decodeArguments( + const SelectivityVector& rows, + const std::vector& args) { + decodedBuckets_.decode(*args[0], rows, true); + decodedOutcome_.decode(*args[1], rows, true); + decodedPred_.decode(*args[2], rows, true); + if (!useDefaultWeight_) { + decodedWeight_.decode(*args[3], rows, true); + } + } + + DecodedVector decodedAcc_; + DecodedVector decodedBuckets_; + DecodedVector decodedOutcome_; + DecodedVector decodedPred_; + DecodedVector decodedWeight_; + const bool useDefaultWeight_{false}; +}; + +template +void registerAggregateFunctionImpl( + const std::string& name, + bool withCompanionFunctions, + bool overwrite, + const std::vector>& + signatures) { + exec::registerAggregateFunction( + name, + signatures, + [](core::AggregationNode::Step, + const std::vector& args, + const TypePtr& resultType, + const core::QueryConfig& /*config*/) + -> std::unique_ptr { + if (args.size() == 4) { + return std::make_unique>(resultType); + } else { + return std::make_unique>( + resultType, true); + } + }, + withCompanionFunctions, + overwrite); +} +} // namespace + +void registerClassificationFunctions( + const std::string& prefix, + bool withCompanionFunctions, + bool overwrite) { + const auto signatures = + std::vector>{ + exec::AggregateFunctionSignatureBuilder() + .returnType("array(double)") + .intermediateType("varbinary") + .argumentType("bigint") + .argumentType("boolean") + .argumentType("double") + .build(), + exec::AggregateFunctionSignatureBuilder() + .returnType("array(double)") + .intermediateType("varbinary") + .argumentType("bigint") + .argumentType("boolean") + .argumentType("double") + .argumentType("double") + .build()}; + registerAggregateFunctionImpl( + prefix + kClassificationFallout, + withCompanionFunctions, + overwrite, + signatures); + registerAggregateFunctionImpl( + prefix + kClassificationPrecision, + withCompanionFunctions, + overwrite, + signatures); + registerAggregateFunctionImpl( + prefix + kClassificationRecall, + withCompanionFunctions, + overwrite, + signatures); + registerAggregateFunctionImpl( + prefix + kClassificationMissRate, + withCompanionFunctions, + overwrite, + signatures); + registerAggregateFunctionImpl( + prefix + kClassificationThreshold, + withCompanionFunctions, + overwrite, + signatures); +} + +} // namespace facebook::velox::aggregate::prestosql diff --git a/velox/functions/prestosql/aggregates/RegisterAggregateFunctions.cpp b/velox/functions/prestosql/aggregates/RegisterAggregateFunctions.cpp index 53bf0ce22ba9..dbf1ee361cb2 100644 --- a/velox/functions/prestosql/aggregates/RegisterAggregateFunctions.cpp +++ b/velox/functions/prestosql/aggregates/RegisterAggregateFunctions.cpp @@ -47,6 +47,10 @@ extern void registerChecksumAggregate( const std::string& prefix, bool withCompanionFunctions, bool overwrite); +extern void registerClassificationFunctions( + const std::string& prefix, + bool withCompanionFunctions, + bool overwrite); extern void registerCountAggregate( const std::string& prefix, bool withCompanionFunctions, @@ -165,6 +169,7 @@ void registerAllAggregateFunctions( registerBoolAggregates(prefix, withCompanionFunctions, overwrite); registerCentralMomentsAggregates(prefix, withCompanionFunctions, overwrite); registerChecksumAggregate(prefix, withCompanionFunctions, overwrite); + registerClassificationFunctions(prefix, withCompanionFunctions, overwrite); registerCountAggregate(prefix, withCompanionFunctions, overwrite); registerCountIfAggregate(prefix, withCompanionFunctions, overwrite); registerCovarianceAggregates(prefix, withCompanionFunctions, overwrite); diff --git a/velox/functions/prestosql/aggregates/tests/CMakeLists.txt b/velox/functions/prestosql/aggregates/tests/CMakeLists.txt index 85265e3613ee..293e1e096367 100644 --- a/velox/functions/prestosql/aggregates/tests/CMakeLists.txt +++ b/velox/functions/prestosql/aggregates/tests/CMakeLists.txt @@ -25,6 +25,7 @@ add_executable( BoolAndOrTest.cpp CentralMomentsAggregationTest.cpp ChecksumAggregateTest.cpp + ClassificationAggregationTest.cpp CountAggregationTest.cpp CountDistinctTest.cpp CountIfAggregationTest.cpp diff --git a/velox/functions/prestosql/aggregates/tests/ClassificationAggregationTest.cpp b/velox/functions/prestosql/aggregates/tests/ClassificationAggregationTest.cpp new file mode 100644 index 000000000000..ec5edecfc401 --- /dev/null +++ b/velox/functions/prestosql/aggregates/tests/ClassificationAggregationTest.cpp @@ -0,0 +1,247 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/functions/lib/aggregates/tests/utils/AggregationTestBase.h" + +using namespace facebook::velox::functions::aggregate::test; + +namespace facebook::velox::aggregate::test { +namespace { + +class ClassificationAggregationTest : public AggregationTestBase { + protected: + void SetUp() override { + AggregationTestBase::SetUp(); + } +}; + +TEST_F(ClassificationAggregationTest, basic) { + auto runTest = [&](const std::string& expression, + RowVectorPtr input, + RowVectorPtr expected) { + testAggregations({input}, {}, {expression}, {expected}); + }; + + /// Test without any nulls. + auto input = makeRowVector({ + makeNullableFlatVector( + {true, false, true, false, false, false, false, false, true, false}), + makeNullableFlatVector( + {0.1, 0.2, 0.3, 0.3, 0.3, 0.3, 0.7, 1.0, 0.5, 0.5}), + }); + + /// Fallout test. + auto expected = makeRowVector({ + makeArrayVector({{1.0, 1.0, 3.0 / 7}}), + }); + runTest("classification_fall_out(5, c0, c1)", input, expected); + + /// Precision test. + expected = makeRowVector({ + makeArrayVector({{0.3, 2.0 / 9, 0.25}}), + }); + runTest("classification_precision(5, c0, c1)", input, expected); + + /// Recall test. + expected = makeRowVector({ + makeArrayVector({{1.0, 2.0 / 3, 1.0 / 3}}), + }); + runTest("classification_recall(5, c0, c1)", input, expected); + + /// Miss rate test. + expected = makeRowVector({ + makeArrayVector({{0, 1.0 / 3, 2.0 / 3}}), + }); + runTest("classification_miss_rate(5, c0, c1)", input, expected); + + /// Thresholds test. + expected = makeRowVector({ + makeArrayVector({{0, 0.2, 0.4}}), + }); + runTest("classification_thresholds(5, c0, c1)", input, expected); + + /// Test with some nulls. + input = makeRowVector({ + makeNullableFlatVector( + {std::nullopt, + false, + true, + false, + false, + false, + false, + false, + std::nullopt, + false}), + makeNullableFlatVector( + {0.1, 0.2, 0.3, 0.3, 0.3, 0.3, 0.7, 1.0, std::nullopt, std::nullopt}), + }); + + /// Fallout test. + expected = makeRowVector({makeArrayVector({{1.0, 1.0}})}); + runTest("classification_fall_out(5, c0, c1)", input, expected); + + /// Precision test. + expected = makeRowVector({makeArrayVector({{1.0 / 7, 1.0 / 7}})}); + runTest("classification_precision(5, c0, c1)", input, expected); + + /// Recall test. + expected = makeRowVector({makeArrayVector({{1, 1}})}); + runTest("classification_recall(5, c0, c1)", input, expected); + + /// Miss rate test. + expected = makeRowVector({makeArrayVector({{0, 0}})}); + runTest("classification_miss_rate(5, c0, c1)", input, expected); + + /// Thresholds test. + expected = makeRowVector({makeArrayVector({{0, 0.2}})}); + runTest("classification_thresholds(5, c0, c1)", input, expected); + + /// Test with all nulls. + input = makeRowVector({ + makeNullableFlatVector({std::nullopt, std::nullopt}), + makeNullableFlatVector({std::nullopt, std::nullopt}), + }); + + expected = makeRowVector({makeNullableArrayVector( + std::vector>>>{ + {std::nullopt}})}); + runTest("classification_fall_out(5, c0, c1)", input, expected); + runTest("classification_precision(5, c0, c1)", input, expected); + runTest("classification_recall(5, c0, c1)", input, expected); + runTest("classification_miss_rate(5, c0, c1)", input, expected); + runTest("classification_thresholds(5, c0, c1)", input, expected); + + /// Test invalid bucket count test + input = makeRowVector({ + makeNullableFlatVector({true}), + makeNullableFlatVector({1.0}), + }); + + constexpr std::array functions = { + "classification_fall_out", + "classification_precision", + "classification_recall", + "classification_miss_rate", + "classification_thresholds"}; + + /// Test invalid bucket count. + constexpr std::array invalidBuckets = {0, 1}; + for (const auto bucket : invalidBuckets) { + for (const auto function : functions) { + VELOX_ASSERT_THROW( + runTest( + fmt::format("{}({}, {}, {})", function, bucket, "c0", "c1"), + input, + expected), + "Bucket count must be at least 2.0"); + } + } + + /// Test invalid threshold. + for (const auto function : functions) { + VELOX_ASSERT_THROW( + runTest( + fmt::format("{}(5, {}, {}, {})", function, "c0", "c1", -0.1), + input, + expected), + "Weight must be non-negative."); + } + + /// Test invalid predictions. Note, a prediction of > 1 + /// will never actually be hit because convert the pred = std::min(pred, + /// 0.99999999999) + for (const auto function : functions) { + VELOX_ASSERT_THROW( + runTest( + fmt::format("{}({}, {}, {})", function, 5, "c0", -0.1), + input, + expected), + "Prediction value must be between 0 and 1"); + } +} + +TEST_F(ClassificationAggregationTest, groupBy) { + auto input = makeRowVector({ + makeNullableFlatVector({0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 3}), + makeNullableFlatVector( + {true, + false, + true, + false, + false, + false, + false, + false, + true, + true, + false}), + makeNullableFlatVector( + {0.1, 0.2, 0.3, 0.3, 0.3, 0.3, 0.7, 1.0, 1.0, 0.5, 0.5}), + }); + + auto runTest = [this]( + const std::string& expression, + RowVectorPtr input, + RowVectorPtr expected) { + testAggregations({input}, {"c0"}, {expression}, {expected}); + }; + auto keys = makeFlatVector({0, 1, 2, 3}); + runTest( + "classification_fall_out(5, c1, c2)", + input, + makeRowVector({ + keys, + makeArrayVector( + {{{1}, {1, 1}, {1, 1, 2.0 / 3, 2.0 / 3, 1.0 / 3}, {1, 1, 1}}}), + })); + runTest( + "classification_precision(5, c1, c2)", + input, + makeRowVector({ + keys, + makeArrayVector( + {{{0.5}, + {1.0 / 3, 1.0 / 3}, + {0.25, 0.25, 1.0 / 3, 1.0 / 3, 0.5}, + {0.5, 0.5, 0.5}}}), + })); + runTest( + "classification_recall(5, c1, c2)", + input, + makeRowVector({ + keys, + makeArrayVector({{{1}, {1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1}}}), + })); + runTest( + "classification_miss_rate(5, c1, c2)", + input, + makeRowVector({ + keys, + makeArrayVector({{{0}, {0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0}}}), + })); + runTest( + "classification_thresholds(5, c1, c2)", + input, + makeRowVector({ + keys, + makeArrayVector( + {{{0}, {0, 0.2}, {0, 0.2, 0.4, 0.6, 0.8}, {0, 0.2, 0.4}}}), + })); +} + +} // namespace +} // namespace facebook::velox::aggregate::test diff --git a/velox/functions/prestosql/fuzzer/AggregationFuzzerTest.cpp b/velox/functions/prestosql/fuzzer/AggregationFuzzerTest.cpp index 12650de90df4..de75197d2ade 100644 --- a/velox/functions/prestosql/fuzzer/AggregationFuzzerTest.cpp +++ b/velox/functions/prestosql/fuzzer/AggregationFuzzerTest.cpp @@ -29,6 +29,7 @@ #include "velox/functions/prestosql/fuzzer/ApproxPercentileInputGenerator.h" #include "velox/functions/prestosql/fuzzer/ApproxPercentileResultVerifier.h" #include "velox/functions/prestosql/fuzzer/ArbitraryResultVerifier.h" +#include "velox/functions/prestosql/fuzzer/ClassificationAggregationInputGenerator.h" #include "velox/functions/prestosql/fuzzer/MapUnionSumInputGenerator.h" #include "velox/functions/prestosql/fuzzer/MinMaxByResultVerifier.h" #include "velox/functions/prestosql/fuzzer/MinMaxInputGenerator.h" @@ -76,7 +77,16 @@ getCustomInputGenerators() { {"approx_set", std::make_shared()}, {"approx_percentile", std::make_shared()}, {"map_union_sum", std::make_shared()}, - }; + {"classification_fall_out", + std::make_shared()}, + {"classification_precision", + std::make_shared()}, + {"classification_recall", + std::make_shared()}, + {"classification_miss_rate", + std::make_shared()}, + {"classification_thresholds", + std::make_shared()}}; } } // namespace diff --git a/velox/functions/prestosql/fuzzer/ClassificationAggregationInputGenerator.h b/velox/functions/prestosql/fuzzer/ClassificationAggregationInputGenerator.h new file mode 100644 index 000000000000..7347b344b892 --- /dev/null +++ b/velox/functions/prestosql/fuzzer/ClassificationAggregationInputGenerator.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include "velox/exec/fuzzer/InputGenerator.h" + +namespace facebook::velox::exec::test { + +class ClassificationAggregationInputGenerator : public InputGenerator { + public: + std::vector generate( + const std::vector& types, + VectorFuzzer& fuzzer, + FuzzerGenerator& rng, + memory::MemoryPool* pool) override { + std::vector result; + result.reserve(types.size()); + + VELOX_CHECK( + types.size() == 3 || types.size() == 4, + fmt::format("Unexpected types count:{}", types.size())); + + VELOX_CHECK( + types[0]->isBigint(), "Unexpected type: {}", types[0]->toString()); + VELOX_CHECK( + types[2]->isDouble(), "Unexpected type: {}", types[2]->toString()); + + const auto size = fuzzer.getOptions().vectorSize; + velox::test::VectorMaker vectorMaker{pool}; + auto bucket = vectorMaker.flatVector(size, [&](auto /*row*/) { + /// The bucket must be the same everytime or else classification + /// aggregation function considers this an invalid input. Moreover, the + /// buckets are capped to 50'000 to prevent OOM-ing issues. The minimum is + /// 2 since that is the minimum valid bucket count. + static auto bucket = + boost::random::uniform_int_distribution(2, 50000)(rng); + return bucket; + }); + + auto pred = vectorMaker.flatVector(size, [&](auto /*row*/) { + /// Predictions must be > 0. + return std::uniform_real_distribution( + 0, std::numeric_limits::max())(rng); + }); + + result.emplace_back(std::move(bucket)); + result.emplace_back(fuzzer.fuzz(types[1])); + result.emplace_back(std::move(pred)); + if (types.size() == 4) { + VELOX_CHECK( + types[3]->isDouble(), "Unexpected type: {}", types[3]->toString()); + + auto weight = vectorMaker.flatVector(size, [&](auto /*row*/) { + /// Weights must be > 0. + return std::uniform_real_distribution( + 0, std::numeric_limits::max())(rng); + }); + result.emplace_back(std::move(weight)); + } + return result; + } + + void reset() override {} +}; + +} // namespace facebook::velox::exec::test diff --git a/velox/functions/prestosql/fuzzer/WindowFuzzerTest.cpp b/velox/functions/prestosql/fuzzer/WindowFuzzerTest.cpp index f7d863d7b809..d67008d2f005 100644 --- a/velox/functions/prestosql/fuzzer/WindowFuzzerTest.cpp +++ b/velox/functions/prestosql/fuzzer/WindowFuzzerTest.cpp @@ -25,6 +25,7 @@ #include "velox/functions/prestosql/fuzzer/ApproxDistinctResultVerifier.h" #include "velox/functions/prestosql/fuzzer/ApproxPercentileInputGenerator.h" #include "velox/functions/prestosql/fuzzer/ApproxPercentileResultVerifier.h" +#include "velox/functions/prestosql/fuzzer/ClassificationAggregationInputGenerator.h" #include "velox/functions/prestosql/fuzzer/MinMaxInputGenerator.h" #include "velox/functions/prestosql/fuzzer/WindowOffsetInputGenerator.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" @@ -74,7 +75,16 @@ getCustomInputGenerators() { {"lag", std::make_shared(1)}, {"nth_value", std::make_shared(1)}, {"ntile", std::make_shared(0)}, - }; + {"classification_fall_out", + std::make_shared()}, + {"classification_precision", + std::make_shared()}, + {"classification_recall", + std::make_shared()}, + {"classification_miss_rate", + std::make_shared()}, + {"classification_thresholds", + std::make_shared()}}; } } // namespace @@ -158,6 +168,11 @@ int main(int argc, char** argv) { "any_value", "arbitrary", "array_agg", + "classification_fall_out", + "classification_precision", + "classification_recall", + "classification_miss_rate", + "classification_thresholds", "set_agg", "set_union", "map_agg",