diff --git a/common/src/commonMain/kotlin/DiscordBitSet.kt b/common/src/commonMain/kotlin/DiscordBitSet.kt index ece22501c9d..33b2b1bbb93 100644 --- a/common/src/commonMain/kotlin/DiscordBitSet.kt +++ b/common/src/commonMain/kotlin/DiscordBitSet.kt @@ -1,5 +1,6 @@ package dev.kord.common +import dev.kord.common.serialization.LongOrStringSerializer import kotlinx.serialization.KSerializer import kotlinx.serialization.Serializable import kotlinx.serialization.descriptors.PrimitiveKind @@ -16,7 +17,7 @@ private const val WIDTH = Long.SIZE_BITS public fun EmptyBitSet(): DiscordBitSet = DiscordBitSet() internal expect fun formatIntegerFromLittleEndianLongArray(data: LongArray): String -internal expect fun parseIntegerToBigEndianByteArray(value: String): ByteArray +internal expect fun parseNonNegativeIntegerToBigEndianByteArray(value: String): ByteArray @Serializable(with = DiscordBitSet.Serializer::class) public class DiscordBitSet(internal var data: LongArray) { // data is in little-endian order @@ -116,7 +117,8 @@ public class DiscordBitSet(internal var data: LongArray) { // data is in little- internal object Serializer : KSerializer { override val descriptor = PrimitiveSerialDescriptor("dev.kord.common.DiscordBitSet", PrimitiveKind.STRING) override fun serialize(encoder: Encoder, value: DiscordBitSet) = encoder.encodeString(value.value) - override fun deserialize(decoder: Decoder) = DiscordBitSet(decoder.decodeString()) + override fun deserialize(decoder: Decoder) = + DiscordBitSet(decoder.decodeSerializableValue(LongOrStringSerializer)) } } @@ -129,7 +131,7 @@ public fun DiscordBitSet(value: String): DiscordBitSet { return DiscordBitSet(longArrayOf(value.toULong().toLong())) } - val bytes = parseIntegerToBigEndianByteArray(value) + val bytes = parseNonNegativeIntegerToBigEndianByteArray(value) val longSize = (bytes.size / Long.SIZE_BYTES) + 1 val destination = LongArray(longSize) diff --git a/common/src/commonTest/kotlin/BitSetTests.kt b/common/src/commonTest/kotlin/BitSetTests.kt index 2ff2c4f401c..e70cde4e0b5 100644 --- a/common/src/commonTest/kotlin/BitSetTests.kt +++ b/common/src/commonTest/kotlin/BitSetTests.kt @@ -1,5 +1,9 @@ package dev.kord.common +import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.encodeToJsonElement import kotlin.js.JsName import kotlin.random.Random import kotlin.random.nextLong @@ -126,4 +130,36 @@ class BitSetTests { assertTrue(bits.value.all { it in '0'..'9' }) } } + + @Test + fun negative_values_cant_be_parsed() { + assertFailsWith { DiscordBitSet("-1") } + assertFailsWith { DiscordBitSet("-99999999999999999999999999999999") } + } + + private val numberStrings = listOf("0", "1", "1024", "6543654", "59946645771238946") + + // https://github.com/kordlib/kord/issues/911 + @Test + fun deserialization_works_with_json_strings_and_numbers() { + numberStrings.forEach { number -> + val string = "\"$number\"" + val expected = DiscordBitSet(number) + assertEquals(expected, Json.decodeFromString(string)) + assertEquals(expected, Json.decodeFromString(number)) + } + } + + @Test + fun serialization_works_and_produces_json_strings() { + numberStrings.forEach { number -> + val bitSet = DiscordBitSet(number) + val string = Json.encodeToString(bitSet) + val json = Json.encodeToJsonElement(bitSet) + assertEquals("\"$number\"", string) + assertIs(json) + assertTrue(json.isString) + assertEquals(number, json.content) + } + } } diff --git a/common/src/jvmMain/kotlin/DiscordBitSetJvm.kt b/common/src/jvmMain/kotlin/DiscordBitSetJvm.kt index 80f04d72b97..b9aa772c636 100644 --- a/common/src/jvmMain/kotlin/DiscordBitSetJvm.kt +++ b/common/src/jvmMain/kotlin/DiscordBitSetJvm.kt @@ -10,4 +10,6 @@ internal actual fun formatIntegerFromLittleEndianLongArray(data: LongArray): Str return BigInteger(/* signum = */ 1, /* magnitude = */ buffer.array()).toString() } -internal actual fun parseIntegerToBigEndianByteArray(value: String): ByteArray = BigInteger(value).toByteArray() +internal actual fun parseNonNegativeIntegerToBigEndianByteArray(value: String): ByteArray = BigInteger(value) + .also { if (it.signum() < 0) throw NumberFormatException("Invalid DiscordBitSet format: '$value'") } + .toByteArray() diff --git a/common/src/nonJvmMain/kotlin/DiscordBitSet.kt b/common/src/nonJvmMain/kotlin/DiscordBitSet.kt index 3c0e5f121a3..c883976ebf8 100644 --- a/common/src/nonJvmMain/kotlin/DiscordBitSet.kt +++ b/common/src/nonJvmMain/kotlin/DiscordBitSet.kt @@ -11,5 +11,7 @@ internal actual fun formatIntegerFromLittleEndianLongArray(data: LongArray) = BigInteger.fromByteArray(readBytes(), Sign.POSITIVE).toString() } -internal actual fun parseIntegerToBigEndianByteArray(value: String): ByteArray = - BigInteger.parseString(value).toByteArray() +internal actual fun parseNonNegativeIntegerToBigEndianByteArray(value: String): ByteArray = BigInteger + .parseString(value) + .also { if (it.isNegative) throw NumberFormatException("Invalid DiscordBitSet format: '$value'") } + .toByteArray()