diff --git a/gateway/src/commonMain/kotlin/DefaultGateway.kt b/gateway/src/commonMain/kotlin/DefaultGateway.kt
index 75e89c1969d..669f456760f 100644
--- a/gateway/src/commonMain/kotlin/DefaultGateway.kt
+++ b/gateway/src/commonMain/kotlin/DefaultGateway.kt
@@ -168,30 +168,28 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway {
private suspend fun readSocket() {
- socket.incoming.asFlow().buffer(Channel.UNLIMITED).collect {
- when (it) {
- is Frame.Binary, is Frame.Text -> read(it)
- else -> { /*ignore*/
+ val frames = socket.incoming.asFlow()
+ .buffer(Channel.UNLIMITED)
+ .onEach { frame -> defaultGatewayLogger.trace { "Received raw frame: $frame" } }
+ val eventsJson = if (compression) {
+ frames.decompressFrames(inflater)
+ } else {
+ frames.mapNotNull { frame ->
+ when (frame) {
+ is Frame.Binary, is Frame.Text -> frame.data.decodeToString()
+ else -> null // ignore other frame types
}
}
}
- }
-
- private suspend fun read(frame: Frame) {
- defaultGatewayLogger.trace { "Received raw frame: $frame" }
- val json = when {
- compression -> with(inflater) { frame.inflateData() } ?: return
- else -> frame.data.decodeToString()
- }
-
- try {
- defaultGatewayLogger.trace { "Gateway <<< $json" }
- val event = jsonParser.decodeFromString(Event.DeserializationStrategy, json)
- data.eventFlow.emit(event)
- } catch (exception: Exception) {
- defaultGatewayLogger.error(exception) { "" }
+ eventsJson.collect { json ->
+ try {
+ defaultGatewayLogger.trace { "Gateway <<< $json" }
+ val event = jsonParser.decodeFromString(Event.DeserializationStrategy, json)
+ data.eventFlow.emit(event)
+ } catch (exception: Exception) {
+ defaultGatewayLogger.error(exception) { "" }
+ }
}
-
}
private suspend fun handleClose() {
diff --git a/gateway/src/commonMain/kotlin/Inflater.kt b/gateway/src/commonMain/kotlin/Inflater.kt
index 384091ef88a..bd9257d5f73 100644
--- a/gateway/src/commonMain/kotlin/Inflater.kt
+++ b/gateway/src/commonMain/kotlin/Inflater.kt
@@ -2,14 +2,48 @@ package dev.kord.gateway
import io.ktor.utils.io.core.*
import io.ktor.websocket.*
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.transform
internal interface Inflater : Closeable {
- /**
- * Inflates this frame.
- *
- * @return the inflated frame or null if the received frame was incomplete
- */
- fun Frame.inflateData(): String?
+ /** Decompresses [compressedLen] bytes from [compressed] and decodes them to a [String]. */
+ fun inflate(compressed: ByteArray, compressedLen: Int): String
}
internal expect fun Inflater(): Inflater
+
+// check if the last four bytes are equal to Z_SYNC_FLUSH suffix (00 00 ff ff),
+// see https://discord.com/developers/docs/topics/gateway#transport-compression
+private fun ByteArray.endsWithZlibSuffix(len: Int) = len >= 4
+ && this[len - 4] == 0x00.toByte()
+ && this[len - 3] == 0x00.toByte()
+ && this[len - 2] == 0xff.toByte()
+ && this[len - 1] == 0xff.toByte()
+
+internal fun Flow.decompressFrames(inflater: Inflater): Flow {
+ var buffer = ByteArray(0)
+ var bufferLen = 0
+ return transform { frame ->
+ when (frame) {
+ is Frame.Text, is Frame.Binary -> {
+ val data = frame.data
+ val dataLen = data.size
+ // skip copying into buffer if buffer is empty and data has suffix
+ if (bufferLen == 0 && data.endsWithZlibSuffix(dataLen)) {
+ emit(inflater.inflate(data, dataLen))
+ } else {
+ if (buffer.size - bufferLen < dataLen) {
+ buffer = buffer.copyOf(bufferLen + dataLen)
+ }
+ data.copyInto(buffer, destinationOffset = bufferLen)
+ bufferLen += dataLen
+ if (buffer.endsWithZlibSuffix(bufferLen)) {
+ emit(inflater.inflate(buffer, bufferLen))
+ bufferLen = 0
+ }
+ }
+ }
+ else -> {} // ignore other frame types
+ }
+ }
+}
diff --git a/gateway/src/jsMain/kotlin/Inflater.kt b/gateway/src/jsMain/kotlin/Inflater.kt
index 418c269d094..53d725b00f3 100644
--- a/gateway/src/jsMain/kotlin/Inflater.kt
+++ b/gateway/src/jsMain/kotlin/Inflater.kt
@@ -1,15 +1,14 @@
package dev.kord.gateway
import dev.kord.gateway.internal.Inflate
-import io.ktor.websocket.*
import node.buffer.Buffer
import node.buffer.BufferEncoding
internal actual fun Inflater() = object : Inflater {
private val inflate = Inflate()
- override fun Frame.inflateData(): String {
- val buffer = Buffer.from(data)
+ override fun inflate(compressed: ByteArray, compressedLen: Int): String {
+ val buffer = Buffer.from(compressed, byteOffset = 0, length = compressedLen)
return inflate.process(buffer).toString(BufferEncoding.utf8)
}
diff --git a/gateway/src/jvmMain/kotlin/Inflater.kt b/gateway/src/jvmMain/kotlin/Inflater.kt
index 348ba9ae389..ccfc286488c 100644
--- a/gateway/src/jvmMain/kotlin/Inflater.kt
+++ b/gateway/src/jvmMain/kotlin/Inflater.kt
@@ -1,19 +1,18 @@
package dev.kord.gateway
-import io.ktor.websocket.*
import java.io.ByteArrayOutputStream
import java.util.zip.InflaterOutputStream
internal actual fun Inflater() = object : Inflater {
private val delegate = java.util.zip.Inflater()
+ private val buffer = ByteArrayOutputStream()
- override fun Frame.inflateData(): String {
- val outputStream = ByteArrayOutputStream()
- InflaterOutputStream(outputStream, delegate).use {
- it.write(data)
+ override fun inflate(compressed: ByteArray, compressedLen: Int): String {
+ buffer.reset()
+ InflaterOutputStream(buffer, delegate).use {
+ it.write(compressed, /* off = */ 0, /* len = */ compressedLen)
}
-
- return outputStream.use { it.toByteArray().decodeToString() }
+ return buffer.toString("UTF-8")
}
override fun close() = delegate.end()
diff --git a/gateway/src/nativeMain/kotlin/Inflater.kt b/gateway/src/nativeMain/kotlin/Inflater.kt
index 5734b1eed5a..d61e15da5c3 100644
--- a/gateway/src/nativeMain/kotlin/Inflater.kt
+++ b/gateway/src/nativeMain/kotlin/Inflater.kt
@@ -1,19 +1,16 @@
package dev.kord.gateway
-import io.ktor.websocket.*
import kotlinx.cinterop.*
import platform.zlib.*
-private const val CHUNK_SIZE = 256 * 1000
-private val ZLIB_SUFFIX = ubyteArrayOf(0x00u, 0x00u, 0xffu, 0xffu)
-
-internal actual fun Inflater(): Inflater = NativeInflater()
+private class ZlibException(message: String) : IllegalStateException(message)
@OptIn(ExperimentalForeignApi::class)
-private class NativeInflater : Inflater {
- // see https://www.zlib.net/manual.html
+internal actual fun Inflater(): Inflater = object : Inflater {
+ // see https://zlib.net/manual.html
- private var frameBuffer = UByteArray(0)
+ private var decompressed = UByteArray(1024) // buffer only grows, is reused for every zlib inflate call
+ private var decompressedLen = 0
private val zStream = nativeHeap.alloc().also { zStream ->
// next_in, avail_in, zalloc, zfree and opaque must be initialized before calling inflateInit
@@ -34,39 +31,31 @@ private class NativeInflater : Inflater {
}
}
- override fun Frame.inflateData(): String? {
- frameBuffer += data.asUByteArray()
- // check if the last four bytes are equal to ZLIB_SUFFIX
- if (frameBuffer.size < 4 ||
- !frameBuffer.copyOfRange(frameBuffer.size - 4, frameBuffer.size).contentEquals(ZLIB_SUFFIX)
- ) {
- return null
- }
- var out = ByteArray(0)
- memScoped {
- val uncompressedDataSize = CHUNK_SIZE // allocate enough space for the uncompressed data
- val uncompressedData = allocArray(uncompressedDataSize)
- zStream.apply {
- next_in = frameBuffer.refTo(0).getPointer(memScope)
- avail_in = frameBuffer.size.convert()
- }
+ private fun throwZlibException(msg: CPointer?, ret: Int): Nothing =
+ throw ZlibException(msg?.toKString() ?: zError(ret)?.toKString() ?: ret.toString())
- do {
- zStream.apply {
- next_out = uncompressedData
- avail_out = uncompressedDataSize.convert()
+ override fun inflate(compressed: ByteArray, compressedLen: Int): String =
+ compressed.asUByteArray().usePinned { compressedPinned ->
+ zStream.next_in = compressedPinned.addressOf(0)
+ zStream.avail_in = compressedLen.convert()
+ decompressedLen = 0
+ while (true) {
+ val ret = decompressed.usePinned { decompressedPinned ->
+ zStream.next_out = decompressedPinned.addressOf(decompressedLen)
+ zStream.avail_out = (decompressed.size - decompressedLen).convert()
+ inflate(zStream.ptr, Z_NO_FLUSH)
}
- inflate(zStream.ptr, Z_NO_FLUSH).check(listOf(Z_OK, Z_STREAM_END)) {
- frameBuffer = UByteArray(0)
+ if (ret != Z_OK && ret != Z_STREAM_END) {
+ throwZlibException(zStream.msg, ret)
}
- out += uncompressedData.readBytes(uncompressedDataSize - zStream.avail_out.convert())
- } while (zStream.avail_out == 0u)
+ if (zStream.avail_in == 0u || zStream.avail_out != 0u) break
+ // grow decompressed buffer
+ decompressedLen = decompressed.size
+ decompressed = decompressed.copyOf(decompressed.size * 2)
+ }
+ decompressed.asByteArray().decodeToString(endIndex = decompressed.size - zStream.avail_out.convert())
}
- frameBuffer = UByteArray(0)
- return out.decodeToString()
- }
-
override fun close() {
val ret = inflateEnd(zStream.ptr)
try {
@@ -76,23 +65,3 @@ private class NativeInflater : Inflater {
}
}
}
-
-@ExperimentalForeignApi
-private fun Int.check(validCodes: List = listOf(Z_OK), cleanup: () -> Unit = {}) {
- if (this !in validCodes) {
- try {
- throw ZlibException(zErrorMessage(this).toString())
- } finally {
- cleanup()
- }
- }
-}
-
-private class ZlibException(message: String?) : IllegalStateException(message)
-
-@ExperimentalForeignApi
-private fun zErrorMessage(errorCode: Int) = zError(errorCode)?.toKString() ?: errorCode
-
-@ExperimentalForeignApi
-private fun throwZlibException(msg: CPointer?, ret: Int): Nothing =
- throw ZlibException(msg?.toKString() ?: zError(ret)?.toKString() ?: ret.toString())