Skip to content

Commit

Permalink
Improve native inflater
Browse files Browse the repository at this point in the history
- Correctly process ZLIB_SUFFIX
- Make new check function to check exit codes
  • Loading branch information
DRSchlaubi committed Mar 11, 2024
1 parent 7a1f87a commit 0b3848c
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 19 deletions.
2 changes: 1 addition & 1 deletion gateway/src/commonMain/kotlin/DefaultGateway.kt
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway {
private suspend fun read(frame: Frame) {
defaultGatewayLogger.trace { "Received raw frame: $frame" }
val json = when {
compression -> with(inflater) { frame.inflateData() }
compression -> with(inflater) { frame.inflateData() } ?: return
else -> frame.data.decodeToString()
}

Expand Down
7 changes: 6 additions & 1 deletion gateway/src/commonMain/kotlin/Inflater.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@ import io.ktor.utils.io.core.*
import io.ktor.websocket.*

internal interface Inflater : Closeable {
fun Frame.inflateData(): String
/**
* Inflates this frame.
*
* @return the inflated frame or null if the received frame was incomplete
*/
fun Frame.inflateData(): String?
}

internal expect fun Inflater(): Inflater
42 changes: 25 additions & 17 deletions gateway/src/nativeMain/kotlin/Inflater.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,56 +6,64 @@ import platform.zlib.*

private const val MAX_WBITS = 15 // Maximum window size in bits
private const val CHUNK_SIZE = 256 * 1000
private val ZLIB_SUFFIX = ubyteArrayOf(0x00u, 0x00u, 0xffu, 0xffu)

internal actual fun Inflater(): Inflater = NativeInflater()

@OptIn(ExperimentalForeignApi::class)
private class NativeInflater : Inflater {
private var frameBuffer = UByteArray(0)

private val zStream = nativeHeap.alloc<z_stream>().apply {
val initResponse = inflateInit2(ptr, MAX_WBITS)
if (initResponse != Z_OK) {
inflateInit2(ptr, MAX_WBITS).check {
nativeHeap.free(this)
throw ZLibException("Could not initialize zlib: ${zErrorMessage(initResponse)}")
}
}

override fun Frame.inflateData(): String {
val compressedData = data
override fun Frame.inflateData(): String? {
frameBuffer += data.asUByteArray()
// check if the last four bytes are equal to ZLIB_SUFFIX
if (frameBuffer.size < 4 ||
!frameBuffer.copyOfRange(frameBuffer.size - 4, frameBuffer.size).contentEquals(ZLIB_SUFFIX)
) {
return null
}
var out = ByteArray(0)
memScoped {
val uncompressedDataSize = CHUNK_SIZE // allocate enough space for the uncompressed data
val uncompressedData = allocArray<uByteVar>(uncompressedDataSize)
zStream.apply {
next_in = compressedData.refTo(0).getPointer(memScope).reinterpret()
avail_in = compressedData.size.convert()
next_in = frameBuffer.refTo(0).getPointer(memScope)
avail_in = frameBuffer.size.convert()
}

do {
zStream.apply {
next_out = uncompressedData
avail_out = uncompressedDataSize.convert()
}
val resultCode = inflate(zStream.ptr, Z_NO_FLUSH)
if (resultCode != Z_OK && resultCode != Z_STREAM_END) {
throw ZLibException(
"An error occurred during decompression of frame: ${zErrorMessage(resultCode)}"
)
inflate(zStream.ptr, Z_NO_FLUSH).check(listOf(Z_OK, Z_STREAM_END)) {
frameBuffer = UByteArray(0)
}
out += uncompressedData.readBytes(uncompressedDataSize - zStream.avail_out.convert<Int>())
} while (zStream.avail_out == 0u)
}

frameBuffer = UByteArray(0)
return out.decodeToString()
}

override fun close() {
inflateEnd(zStream.ptr).check { nativeHeap.free(zStream) }
}
}

private fun Int.check(validCodes: List<Int> = listOf(Z_OK), cleanup: () -> Unit = {}) {
if (this !in validCodes) {
try {
val response = inflateEnd(zStream.ptr)
if(response != Z_OK) {
throw ZLibException("Could not end zstream: ${zErrorMessage(response)}")
}
throw ZLibException(zErrorMessage(this).toString())
} finally {
nativeHeap.free(zStream)
cleanup()
}
}
}
Expand Down

0 comments on commit 0b3848c

Please sign in to comment.