From 0b3848c7c6eab666638b48c62c41c5f2e6c7d559 Mon Sep 17 00:00:00 2001 From: Michael Rittmeister Date: Mon, 11 Mar 2024 22:12:49 +0100 Subject: [PATCH] Improve native inflater - Correctly process ZLIB_SUFFIX - Make new check function to check exit codes --- .../src/commonMain/kotlin/DefaultGateway.kt | 2 +- gateway/src/commonMain/kotlin/Inflater.kt | 7 +++- gateway/src/nativeMain/kotlin/Inflater.kt | 42 +++++++++++-------- 3 files changed, 32 insertions(+), 19 deletions(-) diff --git a/gateway/src/commonMain/kotlin/DefaultGateway.kt b/gateway/src/commonMain/kotlin/DefaultGateway.kt index 635fd38b905..8aa56ebca56 100644 --- a/gateway/src/commonMain/kotlin/DefaultGateway.kt +++ b/gateway/src/commonMain/kotlin/DefaultGateway.kt @@ -180,7 +180,7 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway { private suspend fun read(frame: Frame) { defaultGatewayLogger.trace { "Received raw frame: $frame" } val json = when { - compression -> with(inflater) { frame.inflateData() } + compression -> with(inflater) { frame.inflateData() } ?: return else -> frame.data.decodeToString() } diff --git a/gateway/src/commonMain/kotlin/Inflater.kt b/gateway/src/commonMain/kotlin/Inflater.kt index c2a642fa5c7..384091ef88a 100644 --- a/gateway/src/commonMain/kotlin/Inflater.kt +++ b/gateway/src/commonMain/kotlin/Inflater.kt @@ -4,7 +4,12 @@ import io.ktor.utils.io.core.* import io.ktor.websocket.* internal interface Inflater : Closeable { - fun Frame.inflateData(): String + /** + * Inflates this frame. + * + * @return the inflated frame or null if the received frame was incomplete + */ + fun Frame.inflateData(): String? } internal expect fun Inflater(): Inflater diff --git a/gateway/src/nativeMain/kotlin/Inflater.kt b/gateway/src/nativeMain/kotlin/Inflater.kt index 3abeafcd393..dc41eabb46c 100644 --- a/gateway/src/nativeMain/kotlin/Inflater.kt +++ b/gateway/src/nativeMain/kotlin/Inflater.kt @@ -6,28 +6,35 @@ import platform.zlib.* private const val MAX_WBITS = 15 // Maximum window size in bits private const val CHUNK_SIZE = 256 * 1000 +private val ZLIB_SUFFIX = ubyteArrayOf(0x00u, 0x00u, 0xffu, 0xffu) internal actual fun Inflater(): Inflater = NativeInflater() @OptIn(ExperimentalForeignApi::class) private class NativeInflater : Inflater { + private var frameBuffer = UByteArray(0) + private val zStream = nativeHeap.alloc().apply { - val initResponse = inflateInit2(ptr, MAX_WBITS) - if (initResponse != Z_OK) { + inflateInit2(ptr, MAX_WBITS).check { nativeHeap.free(this) - throw ZLibException("Could not initialize zlib: ${zErrorMessage(initResponse)}") } } - override fun Frame.inflateData(): String { - val compressedData = data + override fun Frame.inflateData(): String? { + frameBuffer += data.asUByteArray() + // check if the last four bytes are equal to ZLIB_SUFFIX + if (frameBuffer.size < 4 || + !frameBuffer.copyOfRange(frameBuffer.size - 4, frameBuffer.size).contentEquals(ZLIB_SUFFIX) + ) { + return null + } var out = ByteArray(0) memScoped { val uncompressedDataSize = CHUNK_SIZE // allocate enough space for the uncompressed data val uncompressedData = allocArray(uncompressedDataSize) zStream.apply { - next_in = compressedData.refTo(0).getPointer(memScope).reinterpret() - avail_in = compressedData.size.convert() + next_in = frameBuffer.refTo(0).getPointer(memScope) + avail_in = frameBuffer.size.convert() } do { @@ -35,27 +42,28 @@ private class NativeInflater : Inflater { next_out = uncompressedData avail_out = uncompressedDataSize.convert() } - val resultCode = inflate(zStream.ptr, Z_NO_FLUSH) - if (resultCode != Z_OK && resultCode != Z_STREAM_END) { - throw ZLibException( - "An error occurred during decompression of frame: ${zErrorMessage(resultCode)}" - ) + inflate(zStream.ptr, Z_NO_FLUSH).check(listOf(Z_OK, Z_STREAM_END)) { + frameBuffer = UByteArray(0) } out += uncompressedData.readBytes(uncompressedDataSize - zStream.avail_out.convert()) } while (zStream.avail_out == 0u) } + frameBuffer = UByteArray(0) return out.decodeToString() } override fun close() { + inflateEnd(zStream.ptr).check { nativeHeap.free(zStream) } + } +} + +private fun Int.check(validCodes: List = listOf(Z_OK), cleanup: () -> Unit = {}) { + if (this !in validCodes) { try { - val response = inflateEnd(zStream.ptr) - if(response != Z_OK) { - throw ZLibException("Could not end zstream: ${zErrorMessage(response)}") - } + throw ZLibException(zErrorMessage(this).toString()) } finally { - nativeHeap.free(zStream) + cleanup() } } }