diff --git a/gateway/src/commonMain/kotlin/DefaultGateway.kt b/gateway/src/commonMain/kotlin/DefaultGateway.kt index 75e89c1969d..669f456760f 100644 --- a/gateway/src/commonMain/kotlin/DefaultGateway.kt +++ b/gateway/src/commonMain/kotlin/DefaultGateway.kt @@ -168,30 +168,28 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway { private suspend fun readSocket() { - socket.incoming.asFlow().buffer(Channel.UNLIMITED).collect { - when (it) { - is Frame.Binary, is Frame.Text -> read(it) - else -> { /*ignore*/ + val frames = socket.incoming.asFlow() + .buffer(Channel.UNLIMITED) + .onEach { frame -> defaultGatewayLogger.trace { "Received raw frame: $frame" } } + val eventsJson = if (compression) { + frames.decompressFrames(inflater) + } else { + frames.mapNotNull { frame -> + when (frame) { + is Frame.Binary, is Frame.Text -> frame.data.decodeToString() + else -> null // ignore other frame types } } } - } - - private suspend fun read(frame: Frame) { - defaultGatewayLogger.trace { "Received raw frame: $frame" } - val json = when { - compression -> with(inflater) { frame.inflateData() } ?: return - else -> frame.data.decodeToString() - } - - try { - defaultGatewayLogger.trace { "Gateway <<< $json" } - val event = jsonParser.decodeFromString(Event.DeserializationStrategy, json) - data.eventFlow.emit(event) - } catch (exception: Exception) { - defaultGatewayLogger.error(exception) { "" } + eventsJson.collect { json -> + try { + defaultGatewayLogger.trace { "Gateway <<< $json" } + val event = jsonParser.decodeFromString(Event.DeserializationStrategy, json) + data.eventFlow.emit(event) + } catch (exception: Exception) { + defaultGatewayLogger.error(exception) { "" } + } } - } private suspend fun handleClose() { diff --git a/gateway/src/commonMain/kotlin/Inflater.kt b/gateway/src/commonMain/kotlin/Inflater.kt index 384091ef88a..bd9257d5f73 100644 --- a/gateway/src/commonMain/kotlin/Inflater.kt +++ b/gateway/src/commonMain/kotlin/Inflater.kt @@ -2,14 +2,48 @@ package dev.kord.gateway import io.ktor.utils.io.core.* import io.ktor.websocket.* +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.transform internal interface Inflater : Closeable { - /** - * Inflates this frame. - * - * @return the inflated frame or null if the received frame was incomplete - */ - fun Frame.inflateData(): String? + /** Decompresses [compressedLen] bytes from [compressed] and decodes them to a [String]. */ + fun inflate(compressed: ByteArray, compressedLen: Int): String } internal expect fun Inflater(): Inflater + +// check if the last four bytes are equal to Z_SYNC_FLUSH suffix (00 00 ff ff), +// see https://discord.com/developers/docs/topics/gateway#transport-compression +private fun ByteArray.endsWithZlibSuffix(len: Int) = len >= 4 + && this[len - 4] == 0x00.toByte() + && this[len - 3] == 0x00.toByte() + && this[len - 2] == 0xff.toByte() + && this[len - 1] == 0xff.toByte() + +internal fun Flow.decompressFrames(inflater: Inflater): Flow { + var buffer = ByteArray(0) + var bufferLen = 0 + return transform { frame -> + when (frame) { + is Frame.Text, is Frame.Binary -> { + val data = frame.data + val dataLen = data.size + // skip copying into buffer if buffer is empty and data has suffix + if (bufferLen == 0 && data.endsWithZlibSuffix(dataLen)) { + emit(inflater.inflate(data, dataLen)) + } else { + if (buffer.size - bufferLen < dataLen) { + buffer = buffer.copyOf(bufferLen + dataLen) + } + data.copyInto(buffer, destinationOffset = bufferLen) + bufferLen += dataLen + if (buffer.endsWithZlibSuffix(bufferLen)) { + emit(inflater.inflate(buffer, bufferLen)) + bufferLen = 0 + } + } + } + else -> {} // ignore other frame types + } + } +} diff --git a/gateway/src/jsMain/kotlin/Inflater.kt b/gateway/src/jsMain/kotlin/Inflater.kt index 418c269d094..53d725b00f3 100644 --- a/gateway/src/jsMain/kotlin/Inflater.kt +++ b/gateway/src/jsMain/kotlin/Inflater.kt @@ -1,15 +1,14 @@ package dev.kord.gateway import dev.kord.gateway.internal.Inflate -import io.ktor.websocket.* import node.buffer.Buffer import node.buffer.BufferEncoding internal actual fun Inflater() = object : Inflater { private val inflate = Inflate() - override fun Frame.inflateData(): String { - val buffer = Buffer.from(data) + override fun inflate(compressed: ByteArray, compressedLen: Int): String { + val buffer = Buffer.from(compressed, byteOffset = 0, length = compressedLen) return inflate.process(buffer).toString(BufferEncoding.utf8) } diff --git a/gateway/src/jvmMain/kotlin/Inflater.kt b/gateway/src/jvmMain/kotlin/Inflater.kt index 348ba9ae389..ccfc286488c 100644 --- a/gateway/src/jvmMain/kotlin/Inflater.kt +++ b/gateway/src/jvmMain/kotlin/Inflater.kt @@ -1,19 +1,18 @@ package dev.kord.gateway -import io.ktor.websocket.* import java.io.ByteArrayOutputStream import java.util.zip.InflaterOutputStream internal actual fun Inflater() = object : Inflater { private val delegate = java.util.zip.Inflater() + private val buffer = ByteArrayOutputStream() - override fun Frame.inflateData(): String { - val outputStream = ByteArrayOutputStream() - InflaterOutputStream(outputStream, delegate).use { - it.write(data) + override fun inflate(compressed: ByteArray, compressedLen: Int): String { + buffer.reset() + InflaterOutputStream(buffer, delegate).use { + it.write(compressed, /* off = */ 0, /* len = */ compressedLen) } - - return outputStream.use { it.toByteArray().decodeToString() } + return buffer.toString("UTF-8") } override fun close() = delegate.end() diff --git a/gateway/src/nativeMain/kotlin/Inflater.kt b/gateway/src/nativeMain/kotlin/Inflater.kt index 5734b1eed5a..d61e15da5c3 100644 --- a/gateway/src/nativeMain/kotlin/Inflater.kt +++ b/gateway/src/nativeMain/kotlin/Inflater.kt @@ -1,19 +1,16 @@ package dev.kord.gateway -import io.ktor.websocket.* import kotlinx.cinterop.* import platform.zlib.* -private const val CHUNK_SIZE = 256 * 1000 -private val ZLIB_SUFFIX = ubyteArrayOf(0x00u, 0x00u, 0xffu, 0xffu) - -internal actual fun Inflater(): Inflater = NativeInflater() +private class ZlibException(message: String) : IllegalStateException(message) @OptIn(ExperimentalForeignApi::class) -private class NativeInflater : Inflater { - // see https://www.zlib.net/manual.html +internal actual fun Inflater(): Inflater = object : Inflater { + // see https://zlib.net/manual.html - private var frameBuffer = UByteArray(0) + private var decompressed = UByteArray(1024) // buffer only grows, is reused for every zlib inflate call + private var decompressedLen = 0 private val zStream = nativeHeap.alloc().also { zStream -> // next_in, avail_in, zalloc, zfree and opaque must be initialized before calling inflateInit @@ -34,39 +31,31 @@ private class NativeInflater : Inflater { } } - override fun Frame.inflateData(): String? { - frameBuffer += data.asUByteArray() - // check if the last four bytes are equal to ZLIB_SUFFIX - if (frameBuffer.size < 4 || - !frameBuffer.copyOfRange(frameBuffer.size - 4, frameBuffer.size).contentEquals(ZLIB_SUFFIX) - ) { - return null - } - var out = ByteArray(0) - memScoped { - val uncompressedDataSize = CHUNK_SIZE // allocate enough space for the uncompressed data - val uncompressedData = allocArray(uncompressedDataSize) - zStream.apply { - next_in = frameBuffer.refTo(0).getPointer(memScope) - avail_in = frameBuffer.size.convert() - } + private fun throwZlibException(msg: CPointer?, ret: Int): Nothing = + throw ZlibException(msg?.toKString() ?: zError(ret)?.toKString() ?: ret.toString()) - do { - zStream.apply { - next_out = uncompressedData - avail_out = uncompressedDataSize.convert() + override fun inflate(compressed: ByteArray, compressedLen: Int): String = + compressed.asUByteArray().usePinned { compressedPinned -> + zStream.next_in = compressedPinned.addressOf(0) + zStream.avail_in = compressedLen.convert() + decompressedLen = 0 + while (true) { + val ret = decompressed.usePinned { decompressedPinned -> + zStream.next_out = decompressedPinned.addressOf(decompressedLen) + zStream.avail_out = (decompressed.size - decompressedLen).convert() + inflate(zStream.ptr, Z_NO_FLUSH) } - inflate(zStream.ptr, Z_NO_FLUSH).check(listOf(Z_OK, Z_STREAM_END)) { - frameBuffer = UByteArray(0) + if (ret != Z_OK && ret != Z_STREAM_END) { + throwZlibException(zStream.msg, ret) } - out += uncompressedData.readBytes(uncompressedDataSize - zStream.avail_out.convert()) - } while (zStream.avail_out == 0u) + if (zStream.avail_in == 0u || zStream.avail_out != 0u) break + // grow decompressed buffer + decompressedLen = decompressed.size + decompressed = decompressed.copyOf(decompressed.size * 2) + } + decompressed.asByteArray().decodeToString(endIndex = decompressed.size - zStream.avail_out.convert()) } - frameBuffer = UByteArray(0) - return out.decodeToString() - } - override fun close() { val ret = inflateEnd(zStream.ptr) try { @@ -76,23 +65,3 @@ private class NativeInflater : Inflater { } } } - -@ExperimentalForeignApi -private fun Int.check(validCodes: List = listOf(Z_OK), cleanup: () -> Unit = {}) { - if (this !in validCodes) { - try { - throw ZlibException(zErrorMessage(this).toString()) - } finally { - cleanup() - } - } -} - -private class ZlibException(message: String?) : IllegalStateException(message) - -@ExperimentalForeignApi -private fun zErrorMessage(errorCode: Int) = zError(errorCode)?.toKString() ?: errorCode - -@ExperimentalForeignApi -private fun throwZlibException(msg: CPointer?, ret: Int): Nothing = - throw ZlibException(msg?.toKString() ?: zError(ret)?.toKString() ?: ret.toString())