diff --git a/velox/functions/lib/Utf8Utils.cpp b/velox/functions/lib/Utf8Utils.cpp index 02354db6cbc1..2b027b6f2caf 100644 --- a/velox/functions/lib/Utf8Utils.cpp +++ b/velox/functions/lib/Utf8Utils.cpp @@ -18,10 +18,6 @@ #include "velox/external/utf8proc/utf8procImpl.h" namespace facebook::velox::functions { -namespace { - -// Returns the length of a UTF-8 character indicated by the first byte. Returns -// -1 for invalid UTF-8 first byte. int firstByteCharLength(const char* u_input) { auto u = (const unsigned char*)u_input; unsigned char u0 = u[0]; @@ -59,8 +55,6 @@ int firstByteCharLength(const char* u_input) { return -1; } -} // namespace - int32_t tryGetUtf8CharLength(const char* input, int64_t size, int32_t& codePoint) { VELOX_DCHECK_NOT_NULL(input); diff --git a/velox/functions/lib/Utf8Utils.h b/velox/functions/lib/Utf8Utils.h index eb994219ec36..3c8950ee369a 100644 --- a/velox/functions/lib/Utf8Utils.h +++ b/velox/functions/lib/Utf8Utils.h @@ -82,4 +82,8 @@ FOLLY_ALWAYS_INLINE int validateAndGetNextUtf8Length( return -1; } +/// Returns the length of a UTF-8 character indicated by the first byte. Returns +/// -1 for invalid UTF-8 first byte. +int firstByteCharLength(const char* u_input); + } // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/URLFunctions.h b/velox/functions/prestosql/URLFunctions.h index 1ce0c42d0e4a..372bfdef2142 100644 --- a/velox/functions/prestosql/URLFunctions.h +++ b/velox/functions/prestosql/URLFunctions.h @@ -16,6 +16,7 @@ #pragma once #include +#include "velox/external/utf8proc/utf8procImpl.h" #include "velox/functions/Macros.h" #include "velox/functions/lib/Utf8Utils.h" #include "velox/functions/prestosql/URIParser.h" @@ -30,6 +31,13 @@ constexpr std::array kEncodedReplacementCharacterStrings = "%EF%BF%BD%EF%BF%BD%EF%BF%BD%EF%BF%BD", "%EF%BF%BD%EF%BF%BD%EF%BF%BD%EF%BF%BD%EF%BF%BD", "%EF%BF%BD%EF%BF%BD%EF%BF%BD%EF%BF%BD%EF%BF%BD%EF%BF%BD"}; +constexpr std::array kDecodedReplacementCharacterStrings{ + "\xef\xbf\xbd", + "\xef\xbf\xbd\xef\xbf\xbd", + "\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd", + "\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd", + "\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd", + "\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd"}; FOLLY_ALWAYS_INLINE StringView submatch(const boost::cmatch& match, int idx) { const auto& sub = match[idx]; @@ -46,6 +54,29 @@ FOLLY_ALWAYS_INLINE void charEscape(unsigned char c, char* output) { output[2] = toHex(c % 16); } +template +FOLLY_ALWAYS_INLINE bool isMultipleInvalidSequences( + const T& inputBuffer, + size_t inputIndex) { + return + // 0xe0 followed by a value less than 0xe0 or 0xf0 followed by a + // value less than 0x90 is considered an overlong encoding. + (inputBuffer[inputIndex] == '\xe0' && + (inputBuffer[inputIndex + 1] & 0xe0) == 0x80) || + (inputBuffer[inputIndex] == '\xf0' && + (inputBuffer[inputIndex + 1] & 0xf0) == 0x80) || + // 0xf4 followed by a byte >= 0x90 looks valid to + // tryGetUtf8CharLength, but is actually outside the range of valid + // code points. + (inputBuffer[inputIndex] == '\xf4' && + (inputBuffer[inputIndex + 1] & 0xf0) != 0x80) || + // The bytes 0xf5-0xff, 0xc0, and 0xc1 look like the start of + // multi-byte code points to tryGetUtf8CharLength, but are not part of + // any valid code point. + (unsigned char)inputBuffer[inputIndex] > 0xf4 || + inputBuffer[inputIndex] == '\xc0' || inputBuffer[inputIndex] == '\xc1'; +} + /// Escapes ``input`` by encoding it so that it can be safely included in /// URL query parameter names and values: /// @@ -99,31 +130,10 @@ FOLLY_ALWAYS_INLINE void urlEscape(TOutString& output, const TInString& input) { // encodings or subsequences outside the range of valid 4 byte // sequences. In both these cases we should just write out a // replacement character for every byte in the sequence. - size_t replaceCharactersToWriteOut = 1; - if (inputIndex < inputSize - 1) { - bool isMultipleInvalidSequences = - // 0xe0 followed by a value less than 0xe0 or 0xf0 followed by a - // value less than 0x90 is considered an overlong encoding. - (inputBuffer[inputIndex] == '\xe0' && - (inputBuffer[inputIndex + 1] & 0xe0) == 0x80) || - (inputBuffer[inputIndex] == '\xf0' && - (inputBuffer[inputIndex + 1] & 0xf0) == 0x80) || - // 0xf4 followed by a byte >= 0x90 looks valid to - // tryGetUtf8CharLength, but is actually outside the range of - // valid code points. - (inputBuffer[inputIndex] == '\xf4' && - (inputBuffer[inputIndex + 1] & 0xf0) != 0x80) || - // The bytes 0xf5-0xff, 0xc0, and 0xc1 look like the start of - // multi-byte code points to tryGetUtf8CharLength, but are not - // part of any valid code point. - (unsigned char)inputBuffer[inputIndex] > 0xf4 || - inputBuffer[inputIndex] == '\xc0' || - inputBuffer[inputIndex] == '\xc1'; - - if (isMultipleInvalidSequences) { - replaceCharactersToWriteOut = charLength * -1; - } - } + size_t replaceCharactersToWriteOut = inputIndex < inputSize - 1 && + isMultipleInvalidSequences(inputBuffer, inputIndex) + ? -charLength + : 1; const auto& replacementCharacterString = kEncodedReplacementCharacterStrings @@ -141,6 +151,28 @@ FOLLY_ALWAYS_INLINE void urlEscape(TOutString& output, const TInString& input) { output.resize(outIndex); } +FOLLY_ALWAYS_INLINE char decodeByte(const char* p, const char* end) { + char buf[3]; + buf[2] = '\0'; + + if (p + 2 < end) { + buf[0] = p[1]; + buf[1] = p[2]; + p += 2; + + char* endptr; + char val = strtol(buf, &endptr, 16); + + if (endptr != buf + 2) { + VELOX_USER_FAIL("Illegal hex characters in escape (%) pattern: {}", buf); + } + + return val; + } else { + VELOX_USER_FAIL("Incomplete trailing escape (%) pattern"); + } +} + template FOLLY_ALWAYS_INLINE void urlUnescape( TOutString& output, @@ -151,9 +183,7 @@ FOLLY_ALWAYS_INLINE void urlUnescape( auto outputBuffer = output.data(); const char* p = input.data(); const char* end = p + inputSize; - char buf[3]; - buf[2] = '\0'; - char* endptr; + for (; p < end; ++p) { if constexpr (unescapePlus) { if (*p == '+') { @@ -162,20 +192,68 @@ FOLLY_ALWAYS_INLINE void urlUnescape( } } if (*p == '%') { - if (p + 2 < end) { - buf[0] = p[1]; - buf[1] = p[2]; - int val = strtol(buf, &endptr, 16); - if (endptr == buf + 2) { - *outputBuffer++ = (char)val; - p += 2; - } else { - VELOX_USER_FAIL( - "Illegal hex characters in escape (%) pattern: {}", buf); - } + char firstByte = decodeByte(p, end); + int32_t charLength = firstByteCharLength(&firstByte); + + if (charLength == 1) { + // This is an ASCII character, just write it out. + *outputBuffer++ = firstByte; + } else if (charLength < 0) { + // This isn't the start of a valid UTF-8 character, write out the + // replacement character. + const auto& replacementString = kDecodedReplacementCharacterStrings[0]; + std::memcpy( + outputBuffer, replacementString.data(), replacementString.length()); + outputBuffer += replacementString.length(); } else { - VELOX_USER_FAIL("Incomplete trailing escape (%) pattern"); + char* charStart = outputBuffer; + *outputBuffer++ = firstByte; + int32_t charLengthRemaining = charLength - 1; + + // Iterate over each percent encoded byte of the UTF-8 character. + while (charLengthRemaining > 0 && p + 3 < end && *(p + 3) == '%') { + char val = decodeByte(p + 3, end); + + if (!utf_cont(val)) { + // If the byte is not a continuation character this is not valid + // UTF-8 abort so we can write out replacement character(s). + break; + } + + // Skip over the previous percent encoded value in the input. We only + // do this after checking if the current byte is valid because if the + // current byte is invalid, it might be a valid byte in the next code + // point. + p += 3; + *outputBuffer++ = val; + charLengthRemaining--; + } + + int32_t codePoint; + if (charLengthRemaining > 0 || + tryGetUtf8CharLength(charStart, charLength, codePoint) < 0) { + // If we exited the loop early it means we encountered a byte that + // wasn't part of a valid UTF-8 code point. If tryGetUtf8CharLength + // returns a negative value it means even though the bytes looked like + // valid UTF-8 they were not, e.g. they were an overlong code point. + size_t charLength = outputBuffer - charStart; + size_t replaceCharactersToWriteOut = + isMultipleInvalidSequences(charStart, 0) ? charLength : 1; + const auto& replacementString = kDecodedReplacementCharacterStrings + [replaceCharactersToWriteOut - 1]; + + outputBuffer = charStart; + std::memcpy( + outputBuffer, + replacementString.data(), + replacementString.length()); + outputBuffer += replacementString.length(); + } } + + // Skip over the last percent encoded value in the code point (the for + // loop will handle skipping over the third char). + p += 2; } else { *outputBuffer++ = *p; } diff --git a/velox/functions/prestosql/tests/URLFunctionsTest.cpp b/velox/functions/prestosql/tests/URLFunctionsTest.cpp index 3fb18c22e7b1..6802105c137c 100644 --- a/velox/functions/prestosql/tests/URLFunctionsTest.cpp +++ b/velox/functions/prestosql/tests/URLFunctionsTest.cpp @@ -612,6 +612,38 @@ TEST_F(URLFunctionsTest, urlDecode) { urlDecode("http%3A%2F%2F%E3%83%86%E3%82%B9%E3%83%88")); EXPECT_EQ("~@:.-*_+ \u2603", urlDecode("%7E%40%3A.-*_%2B+%E2%98%83")); EXPECT_EQ("test", urlDecode("test")); + // Test a single byte invalid UTF-8 character. + EXPECT_EQ("te\xef\xbf\xbdst", urlDecode("te%88st")); + // Test a multi-byte invalid UTF-8 character. (If the first byte is between + // 0xe0 and 0xef, it should be a 3 byte character, but we only have 2 bytes + // here.) + EXPECT_EQ("te\xef\xbf\xbdst", urlDecode("te%e0%b8st")); + // Test an overlong 3 byte UTF-8 character + EXPECT_EQ("\xef\xbf\xbd\xef\xbf\xbd", urlDecode("%e0%94")); + // Test an overlong 3 byte UTF-8 character with a continuation byte. + EXPECT_EQ("\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd", urlDecode("%e0%94%83")); + // Test an overlong 4 byte UTF-8 character + EXPECT_EQ("\xef\xbf\xbd\xef\xbf\xbd", urlDecode("%f0%84")); + // Test an overlong 4 byte UTF-8 character with continuation bytes. + EXPECT_EQ( + "\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd", + urlDecode("%f0%84%90%90")); + // Test a 4 byte UTF-8 character outside the range of valid values. + EXPECT_EQ( + "\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd", + urlDecode("%fa%80%80%80")); + // Test the beginning of a 4 byte UTF-8 character followed by a + // non-continuation byte. + EXPECT_EQ("\xef\xbf\xbd\xef\xbf\xbd", urlDecode("%f0%e0")); + // Test the invalid byte 0xc0. + EXPECT_EQ("\xef\xbf\xbd\xef\xbf\xbd", urlDecode("%c0%83")); + // Test the invalid byte 0xc1. + EXPECT_EQ("\xef\xbf\xbd\xef\xbf\xbd", urlDecode("%c1%83")); + // Test a 4 byte UTF-8 character that looks valid, but is actually outside the + // range of valid values. + EXPECT_EQ( + "\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd", + urlDecode("%f4%92%83%83")); EXPECT_THROW(urlDecode("http%3A%2F%2"), VeloxUserError); EXPECT_THROW(urlDecode("http%3A%2F%"), VeloxUserError);