diff --git a/velox/functions/prestosql/tests/UuidFunctionsTest.cpp b/velox/functions/prestosql/tests/UuidFunctionsTest.cpp index ada3b64533894..cce14e6917ee5 100644 --- a/velox/functions/prestosql/tests/UuidFunctionsTest.cpp +++ b/velox/functions/prestosql/tests/UuidFunctionsTest.cpp @@ -63,10 +63,13 @@ TEST_F(UuidFunctionsTest, castAsVarchar) { // Verify that CAST results as the same as boost::lexical_cast. We do not use // boost::lexical_cast to implement CAST because it is too slow. auto expected = makeFlatVector(size, [&](auto row) { - const auto uuid = uuids->valueAt(row); + auto uuid = uuids->valueAt(row); + auto charPtr = reinterpret_cast(&uuid); boost::uuids::uuid u; - memcpy(&u, &uuid, 16); + for (size_t i = 0; i < 16; ++i) { + u.data[15 - i] = charPtr[i]; + } return boost::lexical_cast(u); }); diff --git a/velox/functions/prestosql/types/UuidType.cpp b/velox/functions/prestosql/types/UuidType.cpp index 8d0b5b5b22fa0..dd4b8a5a05481 100644 --- a/velox/functions/prestosql/types/UuidType.cpp +++ b/velox/functions/prestosql/types/UuidType.cpp @@ -97,8 +97,8 @@ class UuidCastOperator : public exec::CastOperator { size_t offset = 0; for (auto i = 0; i < 16; ++i) { - result.data()[offset] = kHexTable[uuidBytes[i] * 2]; - result.data()[offset + 1] = kHexTable[uuidBytes[i] * 2 + 1]; + result.data()[offset] = kHexTable[uuidBytes[15 - i] * 2]; + result.data()[offset + 1] = kHexTable[uuidBytes[15 - i] * 2 + 1]; offset += 2; if (i == 3 || i == 5 || i == 7 || i == 9) { @@ -125,7 +125,10 @@ class UuidCastOperator : public exec::CastOperator { auto uuid = boost::lexical_cast(uuidString); int128_t u; - memcpy(&u, &uuid, 16); + auto charPtr = reinterpret_cast(&u); + for (size_t i = 0; i < 16; ++i) { + charPtr[i] = uuid.data[15 - i]; + } flatResult->set(row, u); }); diff --git a/velox/serializers/PrestoSerializer.cpp b/velox/serializers/PrestoSerializer.cpp index 4036e47912667..5f3c8e5da8e36 100644 --- a/velox/serializers/PrestoSerializer.cpp +++ b/velox/serializers/PrestoSerializer.cpp @@ -15,6 +15,7 @@ */ #include "velox/serializers/PrestoSerializer.h" +#include #include #include @@ -22,6 +23,7 @@ #include "velox/common/base/Crc.h" #include "velox/common/base/RawVector.h" #include "velox/common/memory/ByteStream.h" +#include "velox/functions/prestosql/types/UuidType.h" #include "velox/vector/BiasVector.h" #include "velox/vector/ComplexVector.h" #include "velox/vector/DictionaryVector.h" @@ -442,6 +444,42 @@ void readDecimalValues( } } +int128_t readUuidValue(ByteInputStream* source) { + // ByteInputStream does not support reading int128_t values. + // UUIDs are serialized as 2 int64 values with msb int64 value first. + auto high = source->read(); + auto low = source->read(); + return HugeInt::build(high, low); +} + +void readUuidValues( + ByteInputStream* source, + vector_size_t size, + vector_size_t offset, + const BufferPtr& nulls, + vector_size_t nullCount, + const BufferPtr& values) { + auto rawValues = values->asMutable(); + if (nullCount) { + checkValuesSize(values, nulls, size, offset); + + int32_t toClear = offset; + bits::forEachSetBit( + nulls->as(), offset, offset + size, [&](int32_t row) { + // Set the values between the last non-null and this to type default. + for (; toClear < row; ++toClear) { + rawValues[toClear] = 0; + } + rawValues[row] = readUuidValue(source); + toClear = row + 1; + }); + } else { + for (int32_t row = 0; row < size; ++row) { + rawValues[offset + row] = readUuidValue(source); + } + } +} + /// When deserializing vectors under row vectors that introduce /// nulls, the child vector must have a gap at the place where a /// parent RowVector has a null. So, if there is a parent RowVector @@ -565,6 +603,16 @@ void read( values); return; } + if (isUuidType(type)) { + readUuidValues( + source, + numNewValues, + resultOffset, + flatResult->nulls(), + nullCount, + values); + return; + } readValues( source, numNewValues, @@ -1364,6 +1412,7 @@ class VectorStream { useLosslessTimestamp_(opts.useLosslessTimestamp), nullsFirst_(opts.nullsFirst), isLongDecimal_(type_->isLongDecimal()), + isUuid_(isUuidType(type_)), opts_(opts), encoding_(getEncoding(encoding, vector)), nulls_(streamArena, true, true), @@ -1709,6 +1758,10 @@ class VectorStream { return isLongDecimal_; } + bool isUuid() const { + return isUuid_; + } + void clear() { encoding_ = std::nullopt; initializeHeader(typeToEncodingName(type_), *streamArena_); @@ -1784,6 +1837,7 @@ class VectorStream { const bool useLosslessTimestamp_; const bool nullsFirst_; const bool isLongDecimal_; + const bool isUuid_; const SerdeOpts opts_; std::optional encoding_; int32_t nonNullCount_{0}; @@ -1841,6 +1895,14 @@ FOLLY_ALWAYS_INLINE int128_t toJavaDecimalValue(int128_t value) { return value; } +FOLLY_ALWAYS_INLINE int128_t toJavaUuidValue(int128_t value) { + // Presto Java UuidType uses java.util.UUID that expects 2 long values + // with most significant bits first, swap upper and lower to adjust. + auto low = HugeInt::upper(value); + auto high = HugeInt::lower(value); + return HugeInt::build(high, low); +} + template <> void VectorStream::append(folly::Range values) { for (auto& value : values) { @@ -1848,6 +1910,9 @@ void VectorStream::append(folly::Range values) { if (isLongDecimal_) { val = toJavaDecimalValue(value); } + else if (isUuid_) { + val = toJavaUuidValue(value); + } values_.append(folly::Range(&val, 1)); } } @@ -2392,7 +2457,8 @@ void copyWords( const int32_t* indices, int32_t numIndices, const T* values, - bool isLongDecimal = false) { + bool isLongDecimal = false, + bool isUuid = false) { if (std::is_same_v && isLongDecimal) { for (auto i = 0; i < numIndices; ++i) { reinterpret_cast(destination)[i] = toJavaDecimalValue( @@ -2400,6 +2466,13 @@ void copyWords( } return; } + if (std::is_same_v && isUuid) { + for (auto i = 0; i < numIndices; ++i) { + reinterpret_cast(destination)[i] = toJavaUuidValue( + reinterpret_cast(values)[indices[i]]); + } + return; + } for (auto i = 0; i < numIndices; ++i) { destination[i] = values[indices[i]]; } @@ -2412,9 +2485,10 @@ void copyWordsWithRows( const int32_t* indices, int32_t numIndices, const T* values, - bool isLongDecimal = false) { + bool isLongDecimal = false, + bool isUuid = false) { if (!indices) { - copyWords(destination, rows, numIndices, values, isLongDecimal); + copyWords(destination, rows, numIndices, values, isLongDecimal, isUuid); return; } if (std::is_same_v && isLongDecimal) { @@ -2424,6 +2498,13 @@ void copyWordsWithRows( } return; } + else if (std::is_same_v && isUuid) { + for (auto i = 0; i < numIndices; ++i) { + reinterpret_cast(destination)[i] = toJavaUuidValue( + reinterpret_cast(values)[rows[indices[i]]]); + } + return; + } for (auto i = 0; i < numIndices; ++i) { destination[i] = values[rows[indices[i]]]; } @@ -2484,7 +2565,8 @@ void appendNonNull( nonNullIndices, numNonNull, values, - stream->isLongDecimal()); + stream->isLongDecimal(), + stream->isUuid()); } } @@ -2577,7 +2659,7 @@ void serializeFlatVector( AppendWindow window(stream->values(), scratch); T* output = window.get(rows.size()); copyWords( - output, rows.data(), rows.size(), rawValues, stream->isLongDecimal()); + output, rows.data(), rows.size(), rawValues, stream->isLongDecimal(), stream->isUuid()); return; } diff --git a/velox/serializers/tests/PrestoSerializerTest.cpp b/velox/serializers/tests/PrestoSerializerTest.cpp index 9f4817a00c76b..6055ab4291cfa 100644 --- a/velox/serializers/tests/PrestoSerializerTest.cpp +++ b/velox/serializers/tests/PrestoSerializerTest.cpp @@ -15,6 +15,7 @@ */ #include "velox/serializers/PrestoSerializer.h" #include +#include #include #include #include "velox/common/base/tests/GTestUtils.h" @@ -1054,6 +1055,23 @@ TEST_P(PrestoSerializerTest, longDecimal) { testRoundTrip(vector); } +TEST_P(PrestoSerializerTest, uuid) { + std::vector uuidValues(200); + + for (int row = 0; row < uuidValues.size(); row++) { + uuidValues[row] = (int128_t) 0xD1 << row % 120; + } + auto vector = makeFlatVector(uuidValues, UUID()); + + testRoundTrip(vector); + + // Add some nulls. + for (auto i = 0; i < uuidValues.size(); i += 7) { + vector->setNull(i, true); + } + testRoundTrip(vector); +} + // Test that hierarchically encoded columns (rows) have their encodings // preserved by the PrestoBatchVectorSerializer. TEST_P(PrestoSerializerTest, encodingsBatchVectorSerializer) {