Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement AES Encryption for Voice #897

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
148 changes: 119 additions & 29 deletions voice/api/voice.api

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ public sealed class EncryptionMode(
`value`: String,
) : EncryptionMode(value)

public object AeadAes256Gcm : EncryptionMode("aead_aes256_gcm")

public object AeadAes256GcmRtpSize : EncryptionMode("aead_aes256_gcm_rtpsize")

public object XSalsa20Poly1305 : EncryptionMode("xsalsa20_poly1305")

public object XSalsa20Poly1305Suffix : EncryptionMode("xsalsa20_poly1305_suffix")
Expand All @@ -65,6 +69,8 @@ public sealed class EncryptionMode(
*/
public val entries: List<EncryptionMode> by lazy(mode = PUBLICATION) {
listOf(
AeadAes256Gcm,
AeadAes256GcmRtpSize,
XSalsa20Poly1305,
XSalsa20Poly1305Suffix,
XSalsa20Poly1305Lite,
Expand All @@ -77,6 +83,8 @@ public sealed class EncryptionMode(
* specified [value].
*/
public fun from(`value`: String): EncryptionMode = when (value) {
"aead_aes256_gcm" -> AeadAes256Gcm
"aead_aes256_gcm_rtpsize" -> AeadAes256GcmRtpSize
"xsalsa20_poly1305" -> XSalsa20Poly1305
"xsalsa20_poly1305_suffix" -> XSalsa20Poly1305Suffix
"xsalsa20_poly1305_lite" -> XSalsa20Poly1305Lite
Expand Down
4 changes: 3 additions & 1 deletion voice/src/main/kotlin/EncryptionMode.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
STRING_KORD_ENUM, name = "EncryptionMode",
docUrl = "https://discord.com/developers/docs/topics/voice-connections#establishing-a-voice-udp-connection-encryption-modes",
entries = [
Entry("AeadAes256Gcm", stringValue = "aead_aes256_gcm"),
Entry("AeadAes256GcmRtpSize", stringValue = "aead_aes256_gcm_rtpsize"),
Entry("XSalsa20Poly1305", stringValue = "xsalsa20_poly1305"),
Entry("XSalsa20Poly1305Suffix", stringValue = "xsalsa20_poly1305_suffix"),
Entry("XSalsa20Poly1305Lite", stringValue = "xsalsa20_poly1305_lite")
Entry("XSalsa20Poly1305Lite", stringValue = "xsalsa20_poly1305_lite"),
]
)

Expand Down
9 changes: 4 additions & 5 deletions voice/src/main/kotlin/VoiceConnection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import dev.kord.common.annotation.KordVoice
import dev.kord.common.entity.Snowflake
import dev.kord.gateway.Gateway
import dev.kord.gateway.UpdateVoiceStatus
import dev.kord.voice.encryption.strategies.NonceStrategy
import dev.kord.voice.encryption.VoiceEncryption
import dev.kord.voice.gateway.VoiceGateway
import dev.kord.voice.gateway.VoiceGatewayConfiguration
import dev.kord.voice.handlers.StreamsHandler
Expand Down Expand Up @@ -41,7 +41,6 @@ public data class VoiceConnectionData(
* @param audioProvider a [AudioProvider] that will provide [AudioFrame] when required.
* @param frameInterceptor a [FrameInterceptor] that will intercept all outgoing [AudioFrame]s.
* @param frameSender the [AudioFrameSender] that will handle the sending of audio packets.
* @param nonceStrategy the [NonceStrategy] that is used during encryption of audio.
*/
@KordVoice
public class VoiceConnection(
Expand All @@ -54,8 +53,8 @@ public class VoiceConnection(
public val audioProvider: AudioProvider,
public val frameInterceptor: FrameInterceptor,
public val frameSender: AudioFrameSender,
public val nonceStrategy: NonceStrategy,
connectionDetachDuration: Duration
public val encryption: VoiceEncryption,
connectionDetachDuration: Duration,
) {
public val scope: CoroutineScope =
CoroutineScope(SupervisorJob() + CoroutineName("kord-voice-connection[${data.guildId.value}]"))
Expand Down Expand Up @@ -148,7 +147,7 @@ public suspend inline fun VoiceConnection(
selfId: Snowflake,
channelId: Snowflake,
guildId: Snowflake,
builder: VoiceConnectionBuilder.() -> Unit = {}
builder: VoiceConnectionBuilder.() -> Unit = {},
): VoiceConnection {
contract { callsInPlace(builder, InvocationKind.EXACTLY_ONCE) }
return VoiceConnectionBuilder(gateway, selfId, channelId, guildId).apply(builder).build()
Expand Down
90 changes: 46 additions & 44 deletions voice/src/main/kotlin/VoiceConnectionBuilder.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import dev.kord.gateway.Gateway
import dev.kord.gateway.UpdateVoiceStatus
import dev.kord.gateway.VoiceServerUpdate
import dev.kord.gateway.VoiceStateUpdate
import dev.kord.voice.encryption.strategies.LiteNonceStrategy
import dev.kord.voice.encryption.strategies.NonceStrategy
import dev.kord.voice.encryption.AeadAes256Gcm
import dev.kord.voice.encryption.VoiceEncryption
import dev.kord.voice.exception.VoiceConnectionInitializationException
import dev.kord.voice.gateway.DefaultVoiceGatewayBuilder
import dev.kord.voice.gateway.VoiceGateway
Expand All @@ -31,7 +31,7 @@ public class VoiceConnectionBuilder(
public var gateway: Gateway,
public var selfId: Snowflake,
public var channelId: Snowflake,
public var guildId: Snowflake
public var guildId: Snowflake,
) {
/**
* The amount in milliseconds to wait for the events required to create a [VoiceConnection]. Default is 5000, or 5 seconds.
Expand Down Expand Up @@ -65,9 +65,10 @@ public class VoiceConnectionBuilder(

/**
* The nonce strategy to be used for the encryption of audio packets.
* If `null`, [dev.kord.voice.encryption.strategies.LiteNonceStrategy] will be used.
* If `null` & voice receive if disabled, [dev.kord.voice.encryption.AeadAes256Gcm] will be used,
* otherwise [dev.kord.voice.encryption.XSalsa20Poly1305] with the Lite strategy will be used.
*/
public var nonceStrategy: NonceStrategy? = null
public var encryption: VoiceEncryption? = null

/**
* A boolean indicating whether your voice state will be muted.
Expand Down Expand Up @@ -111,49 +112,50 @@ public class VoiceConnectionBuilder(
this.voiceGatewayBuilder = builder
}

private suspend fun Gateway.updateVoiceState(): Pair<VoiceConnectionData, VoiceGatewayConfiguration> = coroutineScope {
val voiceStateDeferred = async {
withTimeoutOrNull(timeout) {
gateway.events.filterIsInstance<VoiceStateUpdate>()
.filter { it.voiceState.guildId.value == guildId && it.voiceState.userId == selfId }
.first()
.voiceState
private suspend fun Gateway.updateVoiceState(): Pair<VoiceConnectionData, VoiceGatewayConfiguration> =
coroutineScope {
val voiceStateDeferred = async {
withTimeoutOrNull(timeout) {
gateway.events.filterIsInstance<VoiceStateUpdate>()
.filter { it.voiceState.guildId.value == guildId && it.voiceState.userId == selfId }
.first()
.voiceState
}
}
}

val voiceServerDeferred = async {
withTimeoutOrNull(timeout) {
gateway.events.filterIsInstance<VoiceServerUpdate>()
.filter { it.voiceServerUpdateData.guildId == guildId }
.first()
.voiceServerUpdateData
val voiceServerDeferred = async {
withTimeoutOrNull(timeout) {
gateway.events.filterIsInstance<VoiceServerUpdate>()
.filter { it.voiceServerUpdateData.guildId == guildId }
.first()
.voiceServerUpdateData
}
}
}

send(
UpdateVoiceStatus(
guildId = guildId,
channelId = channelId,
selfMute = selfMute,
selfDeaf = selfDeaf,
send(
UpdateVoiceStatus(
guildId = guildId,
channelId = channelId,
selfMute = selfMute,
selfDeaf = selfDeaf,
)
)
)

val voiceServer = voiceServerDeferred.await()
val voiceState = voiceStateDeferred.await()
val voiceServer = voiceServerDeferred.await()
val voiceState = voiceStateDeferred.await()

if (voiceServer == null || voiceState == null)
throw VoiceConnectionInitializationException("Did not receive a VoiceStateUpdate and or a VoiceServerUpdate in time!")
if (voiceServer == null || voiceState == null)
throw VoiceConnectionInitializationException("Did not receive a VoiceStateUpdate and or a VoiceServerUpdate in time!")

VoiceConnectionData(
selfId,
guildId,
voiceState.sessionId
) to VoiceGatewayConfiguration(
voiceServer.token,
"wss://${voiceServer.endpoint}/?v=${KordConfiguration.VOICE_GATEWAY_VERSION}",
)
}
VoiceConnectionData(
selfId,
guildId,
voiceState.sessionId
) to VoiceGatewayConfiguration(
voiceServer.token,
"wss://${voiceServer.endpoint}/?v=${KordConfiguration.VOICE_GATEWAY_VERSION}",
)
}

/**
* @throws dev.kord.voice.exception.VoiceConnectionInitializationException when there was a problem retrieving voice information from Discord.
Expand All @@ -166,19 +168,19 @@ public class VoiceConnectionBuilder(
.build()
val udpSocket = udpSocket ?: GlobalVoiceUdpSocket
val audioProvider = audioProvider ?: EmptyAudioPlayerProvider
val nonceStrategy = nonceStrategy ?: LiteNonceStrategy()
val encryption = encryption ?: AeadAes256Gcm
val frameInterceptor = frameInterceptor ?: DefaultFrameInterceptor()
val audioSender =
audioSender ?: DefaultAudioFrameSender(
DefaultAudioFrameSenderData(
udpSocket,
frameInterceptor,
audioProvider,
nonceStrategy
encryption
)
)
val streams =
streams ?: if (receiveVoice) DefaultStreams(voiceGateway, udpSocket, nonceStrategy) else NOPStreams
streams ?: if (receiveVoice) DefaultStreams(voiceGateway, udpSocket, encryption) else NOPStreams

return VoiceConnection(
voiceConnectionData,
Expand All @@ -190,7 +192,7 @@ public class VoiceConnectionBuilder(
audioProvider,
frameInterceptor,
audioSender,
nonceStrategy,
encryption,
connectionDetachDuration
)
}
Expand Down
107 changes: 107 additions & 0 deletions voice/src/main/kotlin/encryption/AeadAes256Gcm.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package dev.kord.voice.encryption

import dev.kord.voice.EncryptionMode
import dev.kord.voice.encryption.VoiceEncryption.Box
import dev.kord.voice.encryption.VoiceEncryption.Unbox
import dev.kord.voice.io.ByteArrayView
import dev.kord.voice.io.MutableByteArrayCursor
import dev.kord.voice.io.mutableCursor
import dev.kord.voice.io.view
import dev.kord.voice.udp.RTPPacket
import javax.crypto.Cipher
import javax.crypto.spec.GCMParameterSpec
import javax.crypto.spec.SecretKeySpec

/**
* An [encryption method][VoiceEncryption] that uses the AES-256 GCM cipher.
*/
public data object AeadAes256Gcm : VoiceEncryption {
private const val AUTH_TAG_LEN = 16
private const val NONCE_LEN = 4
private const val IV_LEN = 12

override val mode: EncryptionMode get() = EncryptionMode.AeadAes256Gcm

override val nonceLength: Int get() = 4

override fun createBox(key: ByteArray): Box = BoxImpl(key)

override fun createUnbox(key: ByteArray): Unbox = UnboxImpl(key)

private abstract class Common(key: ByteArray) {
protected val iv = ByteArray(IV_LEN)
protected val ivCursor = iv.mutableCursor()

protected val cipherKey = SecretKeySpec(key, "AES")
protected val cipher: Cipher = Cipher.getInstance("AES/GCM/NoPadding")

fun apply(
mode: Int,
src: ByteArrayView,
dst: MutableByteArrayCursor,
aead: ByteArrayView,
nonce: ByteArray,
writeNonce: MutableByteArrayCursor.(nonce: ByteArray) -> Unit,
): Boolean {
iv.fill(0)
ivCursor.reset()
ivCursor.apply { writeNonce(nonce) }

init(mode)
cipher.updateAAD(aead.data, aead.dataStart, aead.viewSize)
dst.cursor += cipher.doFinal(src.data, src.dataStart, src.viewSize, dst.data, dst.cursor)

return true
}

fun init(mode: Int) {
cipher.init(mode, cipherKey, GCMParameterSpec(AUTH_TAG_LEN * 8, iv, 0, IV_LEN))
}
}

private class BoxImpl(key: ByteArray) : Box, Common(key) {
val nonceBuffer: ByteArray = ByteArray(NONCE_LEN)
val nonceCursor by lazy { nonceBuffer.mutableCursor() }
val nonceView by lazy { nonceBuffer.view() }

override val overhead: Int
get() = AUTH_TAG_LEN + NONCE_LEN

override fun apply(
src: ByteArrayView,
dst: MutableByteArrayCursor,
aead: ByteArrayView,
nonce: ByteArray,
): Boolean = apply(Cipher.ENCRYPT_MODE, src, dst, aead, nonce, MutableByteArrayCursor::writeByteArray)

override fun appendNonce(nonce: ByteArrayView, dst: MutableByteArrayCursor) {
dst.writeByteView(nonce)
}

override fun generateNonce(header: () -> ByteArrayView): ByteArrayView {
nonceCursor.reset()
nonceCursor.writeByteView(header().view(0, NONCE_LEN)!!)
return nonceView
}
}

private class UnboxImpl(key: ByteArray) : Unbox, Common(key) {
override fun apply(
src: ByteArrayView,
dst: MutableByteArrayCursor,
aead: ByteArrayView,
nonce: ByteArray,
): Boolean = apply(Cipher.DECRYPT_MODE, src, dst, aead, nonce) { writeByteView(it.view(0, NONCE_LEN)!!) }

override fun getNonce(packet: RTPPacket): ByteArrayView = with(packet.payload) {
// grab the last NONCE_LEN bytes of the packet payload.
val nonce = view(dataEnd - NONCE_LEN, dataEnd)
?: error("Failed to strip nonce from RTP packet payload.")

// resize the payload view to exclude the nonce.
resize(0, dataEnd - NONCE_LEN)

return nonce
}
}
}
Loading