Skip to content

Commit

Permalink
More inflater improvements
Browse files Browse the repository at this point in the history
Co-authored-by: Michael Rittmeister <[email protected]>
  • Loading branch information
lukellmann and DRSchlaubi committed Apr 5, 2024
1 parent 89cf26e commit 0f4df33
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 92 deletions.
38 changes: 18 additions & 20 deletions gateway/src/commonMain/kotlin/DefaultGateway.kt
Original file line number Diff line number Diff line change
Expand Up @@ -168,30 +168,28 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway {


private suspend fun readSocket() {
socket.incoming.asFlow().buffer(Channel.UNLIMITED).collect {
when (it) {
is Frame.Binary, is Frame.Text -> read(it)
else -> { /*ignore*/
val frames = socket.incoming.asFlow()
.buffer(Channel.UNLIMITED)
.onEach { frame -> defaultGatewayLogger.trace { "Received raw frame: $frame" } }
val eventsJson = if (compression) {
frames.decompressFrames(inflater)
} else {
frames.mapNotNull { frame ->
when (frame) {
is Frame.Binary, is Frame.Text -> frame.data.decodeToString()
else -> null // ignore other frame types
}
}
}
}

private suspend fun read(frame: Frame) {
defaultGatewayLogger.trace { "Received raw frame: $frame" }
val json = when {
compression -> with(inflater) { frame.inflateData() } ?: return
else -> frame.data.decodeToString()
}

try {
defaultGatewayLogger.trace { "Gateway <<< $json" }
val event = jsonParser.decodeFromString(Event.DeserializationStrategy, json)
data.eventFlow.emit(event)
} catch (exception: Exception) {
defaultGatewayLogger.error(exception) { "" }
eventsJson.collect { json ->
try {
defaultGatewayLogger.trace { "Gateway <<< $json" }
val event = jsonParser.decodeFromString(Event.DeserializationStrategy, json)
data.eventFlow.emit(event)
} catch (exception: Exception) {
defaultGatewayLogger.error(exception) { "" }
}
}

}

private suspend fun handleClose() {
Expand Down
46 changes: 40 additions & 6 deletions gateway/src/commonMain/kotlin/Inflater.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,48 @@ package dev.kord.gateway

import io.ktor.utils.io.core.*
import io.ktor.websocket.*
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.transform

internal interface Inflater : Closeable {
/**
* Inflates this frame.
*
* @return the inflated frame or null if the received frame was incomplete
*/
fun Frame.inflateData(): String?
/** Decompresses [compressedLen] bytes from [compressed] and decodes them to a [String]. */
fun inflate(compressed: ByteArray, compressedLen: Int): String
}

internal expect fun Inflater(): Inflater

// check if the last four bytes are equal to Z_SYNC_FLUSH suffix (00 00 ff ff),
// see https://discord.com/developers/docs/topics/gateway#transport-compression
private fun ByteArray.endsWithZlibSuffix(len: Int) = len >= 4
&& this[len - 4] == 0x00.toByte()
&& this[len - 3] == 0x00.toByte()
&& this[len - 2] == 0xff.toByte()
&& this[len - 1] == 0xff.toByte()

internal fun Flow<Frame>.decompressFrames(inflater: Inflater): Flow<String> {
var buffer = ByteArray(0)
var bufferLen = 0
return transform { frame ->
when (frame) {
is Frame.Text, is Frame.Binary -> {
val data = frame.data
val dataLen = data.size
// skip copying into buffer if buffer is empty and data has suffix
if (bufferLen == 0 && data.endsWithZlibSuffix(dataLen)) {
emit(inflater.inflate(data, dataLen))
} else {
if (buffer.size - bufferLen < dataLen) {
buffer = buffer.copyOf(bufferLen + dataLen)
}
data.copyInto(buffer, destinationOffset = bufferLen)
bufferLen += dataLen
if (buffer.endsWithZlibSuffix(bufferLen)) {
emit(inflater.inflate(buffer, bufferLen))
bufferLen = 0
}
}
}
else -> {} // ignore other frame types
}
}
}
5 changes: 2 additions & 3 deletions gateway/src/jsMain/kotlin/Inflater.kt
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
package dev.kord.gateway

import dev.kord.gateway.internal.Inflate
import io.ktor.websocket.*
import node.buffer.Buffer
import node.buffer.BufferEncoding

internal actual fun Inflater() = object : Inflater {
private val inflate = Inflate()

override fun Frame.inflateData(): String {
val buffer = Buffer.from(data)
override fun inflate(compressed: ByteArray, compressedLen: Int): String {
val buffer = Buffer.from(compressed, byteOffset = 0, length = compressedLen)

return inflate.process(buffer).toString(BufferEncoding.utf8)
}
Expand Down
13 changes: 6 additions & 7 deletions gateway/src/jvmMain/kotlin/Inflater.kt
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
package dev.kord.gateway

import io.ktor.websocket.*
import java.io.ByteArrayOutputStream
import java.util.zip.InflaterOutputStream

internal actual fun Inflater() = object : Inflater {
private val delegate = java.util.zip.Inflater()
private val buffer = ByteArrayOutputStream()

override fun Frame.inflateData(): String {
val outputStream = ByteArrayOutputStream()
InflaterOutputStream(outputStream, delegate).use {
it.write(data)
override fun inflate(compressed: ByteArray, compressedLen: Int): String {
buffer.reset()
InflaterOutputStream(buffer, delegate).use {
it.write(compressed, /* off = */ 0, /* len = */ compressedLen)
}

return outputStream.use { it.toByteArray().decodeToString() }
return buffer.toString("UTF-8")
}

override fun close() = delegate.end()
Expand Down
81 changes: 25 additions & 56 deletions gateway/src/nativeMain/kotlin/Inflater.kt
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
package dev.kord.gateway

import io.ktor.websocket.*
import kotlinx.cinterop.*
import platform.zlib.*

private const val CHUNK_SIZE = 256 * 1000
private val ZLIB_SUFFIX = ubyteArrayOf(0x00u, 0x00u, 0xffu, 0xffu)

internal actual fun Inflater(): Inflater = NativeInflater()
private class ZlibException(message: String) : IllegalStateException(message)

@OptIn(ExperimentalForeignApi::class)
private class NativeInflater : Inflater {
// see https://www.zlib.net/manual.html
internal actual fun Inflater(): Inflater = object : Inflater {
// see https://zlib.net/manual.html

private var frameBuffer = UByteArray(0)
private var decompressed = UByteArray(1024) // buffer only grows, is reused for every zlib inflate call
private var decompressedLen = 0

private val zStream = nativeHeap.alloc<z_stream>().also { zStream ->
// next_in, avail_in, zalloc, zfree and opaque must be initialized before calling inflateInit
Expand All @@ -34,39 +31,31 @@ private class NativeInflater : Inflater {
}
}

override fun Frame.inflateData(): String? {
frameBuffer += data.asUByteArray()
// check if the last four bytes are equal to ZLIB_SUFFIX
if (frameBuffer.size < 4 ||
!frameBuffer.copyOfRange(frameBuffer.size - 4, frameBuffer.size).contentEquals(ZLIB_SUFFIX)
) {
return null
}
var out = ByteArray(0)
memScoped {
val uncompressedDataSize = CHUNK_SIZE // allocate enough space for the uncompressed data
val uncompressedData = allocArray<uByteVar>(uncompressedDataSize)
zStream.apply {
next_in = frameBuffer.refTo(0).getPointer(memScope)
avail_in = frameBuffer.size.convert()
}
private fun throwZlibException(msg: CPointer<ByteVar>?, ret: Int): Nothing =
throw ZlibException(msg?.toKString() ?: zError(ret)?.toKString() ?: ret.toString())

do {
zStream.apply {
next_out = uncompressedData
avail_out = uncompressedDataSize.convert()
override fun inflate(compressed: ByteArray, compressedLen: Int): String =
compressed.asUByteArray().usePinned { compressedPinned ->
zStream.next_in = compressedPinned.addressOf(0)
zStream.avail_in = compressedLen.convert()
decompressedLen = 0
while (true) {
val ret = decompressed.usePinned { decompressedPinned ->
zStream.next_out = decompressedPinned.addressOf(decompressedLen)
zStream.avail_out = (decompressed.size - decompressedLen).convert()
inflate(zStream.ptr, Z_NO_FLUSH)
}
inflate(zStream.ptr, Z_NO_FLUSH).check(listOf(Z_OK, Z_STREAM_END)) {
frameBuffer = UByteArray(0)
if (ret != Z_OK && ret != Z_STREAM_END) {
throwZlibException(zStream.msg, ret)
}
out += uncompressedData.readBytes(uncompressedDataSize - zStream.avail_out.convert<Int>())
} while (zStream.avail_out == 0u)
if (zStream.avail_in == 0u || zStream.avail_out != 0u) break
// grow decompressed buffer
decompressedLen = decompressed.size
decompressed = decompressed.copyOf(decompressed.size * 2)
}
decompressed.asByteArray().decodeToString(endIndex = decompressed.size - zStream.avail_out.convert<Int>())
}

frameBuffer = UByteArray(0)
return out.decodeToString()
}

override fun close() {
val ret = inflateEnd(zStream.ptr)
try {
Expand All @@ -76,23 +65,3 @@ private class NativeInflater : Inflater {
}
}
}

@ExperimentalForeignApi
private fun Int.check(validCodes: List<Int> = listOf(Z_OK), cleanup: () -> Unit = {}) {
if (this !in validCodes) {
try {
throw ZlibException(zErrorMessage(this).toString())
} finally {
cleanup()
}
}
}

private class ZlibException(message: String?) : IllegalStateException(message)

@ExperimentalForeignApi
private fun zErrorMessage(errorCode: Int) = zError(errorCode)?.toKString() ?: errorCode

@ExperimentalForeignApi
private fun throwZlibException(msg: CPointer<ByteVar>?, ret: Int): Nothing =
throw ZlibException(msg?.toKString() ?: zError(ret)?.toKString() ?: ret.toString())

0 comments on commit 0f4df33

Please sign in to comment.