diff --git a/gateway/api/gateway.api b/gateway/api/gateway.api index a69e58ec756..757479bb113 100644 --- a/gateway/api/gateway.api +++ b/gateway/api/gateway.api @@ -102,6 +102,43 @@ public final class dev/kord/gateway/AutoModerationRuleUpdate : dev/kord/gateway/ public fun toString ()Ljava/lang/String; } +public abstract class dev/kord/gateway/BaseGateway : dev/kord/gateway/Gateway { + public fun ()V + public fun detach (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun getCoroutineContext ()Lkotlin/coroutines/CoroutineContext; + protected abstract fun getDispatcher ()Lkotlinx/coroutines/CoroutineDispatcher; + public fun getEvents ()Lkotlinx/coroutines/flow/MutableSharedFlow; + public synthetic fun getEvents ()Lkotlinx/coroutines/flow/SharedFlow; + protected final fun getLog ()Lmu/KLogger; + protected final fun getState ()Ldev/kord/gateway/BaseGateway$State; + protected abstract fun onDetach (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + protected abstract fun onSend (Ldev/kord/gateway/Command;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + protected abstract fun onStart (Ldev/kord/gateway/GatewayConfiguration;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + protected abstract fun onStop (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun send (Ldev/kord/gateway/Command;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + protected final fun setState (Ldev/kord/gateway/BaseGateway$State;)V + public fun start (Ldev/kord/gateway/GatewayConfiguration;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun stop (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + protected final fun throwStateError ()Ljava/lang/Void; +} + +protected abstract class dev/kord/gateway/BaseGateway$State { + public synthetic fun (ZLkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun getRetry ()Z +} + +public final class dev/kord/gateway/BaseGateway$State$Detached : dev/kord/gateway/BaseGateway$State { + public static final field INSTANCE Ldev/kord/gateway/BaseGateway$State$Detached; +} + +public final class dev/kord/gateway/BaseGateway$State$Running : dev/kord/gateway/BaseGateway$State { + public fun (Z)V +} + +public final class dev/kord/gateway/BaseGateway$State$Stopped : dev/kord/gateway/BaseGateway$State { + public static final field INSTANCE Ldev/kord/gateway/BaseGateway$State$Stopped; +} + public final class dev/kord/gateway/ChannelCreate : dev/kord/gateway/DispatchEvent { public fun (Ldev/kord/common/entity/DiscordChannel;Ljava/lang/Integer;)V public final fun component1 ()Ldev/kord/common/entity/DiscordChannel; @@ -219,16 +256,12 @@ public final class dev/kord/gateway/Command$SerializationStrategy : kotlinx/seri public synthetic fun serialize (Lkotlinx/serialization/encoding/Encoder;Ljava/lang/Object;)V } -public final class dev/kord/gateway/DefaultGateway : dev/kord/gateway/Gateway { +public final class dev/kord/gateway/DefaultGateway : dev/kord/gateway/BaseGateway { public static final field Companion Ldev/kord/gateway/DefaultGateway$Companion; public fun (Ldev/kord/gateway/DefaultGatewayData;)V - public fun detach (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public fun getCoroutineContext ()Lkotlin/coroutines/CoroutineContext; - public fun getEvents ()Lkotlinx/coroutines/flow/SharedFlow; + public fun getEvents ()Lkotlinx/coroutines/flow/MutableSharedFlow; + public synthetic fun getEvents ()Lkotlinx/coroutines/flow/SharedFlow; public fun getPing ()Lkotlinx/coroutines/flow/StateFlow; - public fun send (Ldev/kord/gateway/Command;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public fun start (Ldev/kord/gateway/GatewayConfiguration;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public fun stop (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } public final class dev/kord/gateway/DefaultGateway$Companion { @@ -618,6 +651,7 @@ public abstract interface class dev/kord/gateway/Gateway : kotlinx/coroutines/Co } public final class dev/kord/gateway/Gateway$Companion { + public final fun connectionManaged (Ldev/kord/gateway/DefaultGatewayData;Lkotlin/jvm/functions/Function1;)Ldev/kord/gateway/Gateway; public final fun none ()Ldev/kord/gateway/Gateway; } @@ -2018,6 +2052,94 @@ public final class dev/kord/gateway/builder/Shards { public fun toString ()Ljava/lang/String; } +public abstract interface class dev/kord/gateway/connection/GatewayConnection { + public abstract fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public abstract fun getPing ()Lkotlinx/coroutines/flow/StateFlow; + public abstract fun open (Ldev/kord/gateway/connection/GatewayConnection$Data;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public abstract fun send (Ldev/kord/gateway/Command;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public abstract interface class dev/kord/gateway/connection/GatewayConnection$CloseReason { +} + +public final class dev/kord/gateway/connection/GatewayConnection$CloseReason$Error : dev/kord/gateway/connection/GatewayConnection$CloseReason { + public fun (Ljava/lang/Throwable;)V + public final fun getCause ()Ljava/lang/Throwable; +} + +public final class dev/kord/gateway/connection/GatewayConnection$CloseReason$InvalidSession : dev/kord/gateway/connection/GatewayConnection$CloseReason { + public static final field INSTANCE Ldev/kord/gateway/connection/GatewayConnection$CloseReason$InvalidSession; +} + +public final class dev/kord/gateway/connection/GatewayConnection$CloseReason$Manual : dev/kord/gateway/connection/GatewayConnection$CloseReason { + public static final field INSTANCE Ldev/kord/gateway/connection/GatewayConnection$CloseReason$Manual; +} + +public final class dev/kord/gateway/connection/GatewayConnection$CloseReason$Plain : dev/kord/gateway/connection/GatewayConnection$CloseReason { + public fun (ILjava/lang/String;Ldev/kord/gateway/Resume;)V + public final fun getCode ()I + public final fun getMessage ()Ljava/lang/String; + public final fun getResume ()Ldev/kord/gateway/Resume; +} + +public final class dev/kord/gateway/connection/GatewayConnection$CloseReason$Reconnect : dev/kord/gateway/connection/GatewayConnection$CloseReason { + public static final field INSTANCE Ldev/kord/gateway/connection/GatewayConnection$CloseReason$Reconnect; +} + +public final class dev/kord/gateway/connection/GatewayConnection$CloseReason$ResumableInvalidSession : dev/kord/gateway/connection/GatewayConnection$CloseReason { + public fun (Ldev/kord/gateway/Resume;)V + public final fun getResume ()Ldev/kord/gateway/Resume; +} + +public final class dev/kord/gateway/connection/GatewayConnection$CloseReason$ResumableReconnect : dev/kord/gateway/connection/GatewayConnection$CloseReason { + public fun (Ldev/kord/gateway/Resume;)V + public final fun getResume ()Ldev/kord/gateway/Resume; +} + +public final class dev/kord/gateway/connection/GatewayConnection$Data { + public fun (Ldev/kord/common/entity/DiscordShard;Ljava/net/URI;Ldev/kord/gateway/connection/GatewayConnection$Session;Lio/ktor/client/HttpClient;Lkotlinx/serialization/json/Json;Lkotlinx/coroutines/flow/MutableSharedFlow;Ldev/kord/common/ratelimit/RateLimiter;Ldev/kord/gateway/ratelimit/IdentifyRateLimiter;Ldev/kord/gateway/retry/Retry;)V + public final fun component1 ()Ldev/kord/common/entity/DiscordShard; + public final fun component2 ()Ljava/net/URI; + public final fun component3 ()Ldev/kord/gateway/connection/GatewayConnection$Session; + public final fun component4 ()Lio/ktor/client/HttpClient; + public final fun component5 ()Lkotlinx/serialization/json/Json; + public final fun component6 ()Lkotlinx/coroutines/flow/MutableSharedFlow; + public final fun component7 ()Ldev/kord/common/ratelimit/RateLimiter; + public final fun component8 ()Ldev/kord/gateway/ratelimit/IdentifyRateLimiter; + public final fun component9 ()Ldev/kord/gateway/retry/Retry; + public final fun copy (Ldev/kord/common/entity/DiscordShard;Ljava/net/URI;Ldev/kord/gateway/connection/GatewayConnection$Session;Lio/ktor/client/HttpClient;Lkotlinx/serialization/json/Json;Lkotlinx/coroutines/flow/MutableSharedFlow;Ldev/kord/common/ratelimit/RateLimiter;Ldev/kord/gateway/ratelimit/IdentifyRateLimiter;Ldev/kord/gateway/retry/Retry;)Ldev/kord/gateway/connection/GatewayConnection$Data; + public static synthetic fun copy$default (Ldev/kord/gateway/connection/GatewayConnection$Data;Ldev/kord/common/entity/DiscordShard;Ljava/net/URI;Ldev/kord/gateway/connection/GatewayConnection$Session;Lio/ktor/client/HttpClient;Lkotlinx/serialization/json/Json;Lkotlinx/coroutines/flow/MutableSharedFlow;Ldev/kord/common/ratelimit/RateLimiter;Ldev/kord/gateway/ratelimit/IdentifyRateLimiter;Ldev/kord/gateway/retry/Retry;ILjava/lang/Object;)Ldev/kord/gateway/connection/GatewayConnection$Data; + public fun equals (Ljava/lang/Object;)Z + public final fun getClient ()Lio/ktor/client/HttpClient; + public final fun getEventFlow ()Lkotlinx/coroutines/flow/MutableSharedFlow; + public final fun getIdentifyRateLimiter ()Ldev/kord/gateway/ratelimit/IdentifyRateLimiter; + public final fun getJson ()Lkotlinx/serialization/json/Json; + public final fun getReconnectRetry ()Ldev/kord/gateway/retry/Retry; + public final fun getSendRateLimiter ()Ldev/kord/common/ratelimit/RateLimiter; + public final fun getSession ()Ldev/kord/gateway/connection/GatewayConnection$Session; + public final fun getShard ()Ldev/kord/common/entity/DiscordShard; + public final fun getUri ()Ljava/net/URI; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public abstract interface class dev/kord/gateway/connection/GatewayConnection$Session { +} + +public final class dev/kord/gateway/connection/GatewayConnection$Session$New : dev/kord/gateway/connection/GatewayConnection$Session { + public fun (Ldev/kord/gateway/Identify;)V + public final fun getIdentify ()Ldev/kord/gateway/Identify; +} + +public final class dev/kord/gateway/connection/GatewayConnection$Session$Resumed : dev/kord/gateway/connection/GatewayConnection$Session { + public fun (Ldev/kord/gateway/Resume;)V + public final fun getResume ()Ldev/kord/gateway/Resume; +} + +public final class dev/kord/gateway/connection/GatewayConnectionKt { + public static final fun GatewayConnection ()Ldev/kord/gateway/connection/GatewayConnection; +} + public abstract interface class dev/kord/gateway/ratelimit/IdentifyRateLimiter { public abstract fun consume (ILkotlinx/coroutines/flow/SharedFlow;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public abstract fun getMaxConcurrency ()I diff --git a/gateway/src/main/kotlin/BaseGateway.kt b/gateway/src/main/kotlin/BaseGateway.kt new file mode 100644 index 00000000000..b9066a84490 --- /dev/null +++ b/gateway/src/main/kotlin/BaseGateway.kt @@ -0,0 +1,141 @@ +package dev.kord.gateway + +import kotlinx.atomicfu.AtomicRef +import kotlinx.atomicfu.atomic +import kotlinx.atomicfu.update +import kotlinx.coroutines.CoroutineDispatcher +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel +import kotlinx.coroutines.flow.MutableSharedFlow +import mu.KLogger +import mu.KotlinLogging +import kotlin.coroutines.CoroutineContext + +/** + * Base abstraction for a gateway implementation. + */ +public abstract class BaseGateway : Gateway { + + /** + * The logger for this gateway. + */ + protected val log: KLogger = KotlinLogging.logger { } + + /** + * The current state of this gateway as atomic reference. + * @see State + */ + private val atomicState: AtomicRef = atomic(State.Stopped) + + /** + * The current state of this gateway. + * @see State + */ + protected var state: State by atomicState + + /** + * The dispatcher used to run the gateway. + * It will be used to assemble the [coroutineContext] of this gateway. + * By default, there is [SupervisorJob] before the dispatcher in [coroutineContext]. + */ + protected abstract val dispatcher: CoroutineDispatcher + + override val coroutineContext: CoroutineContext get() = SupervisorJob() + dispatcher + + override val events: MutableSharedFlow = MutableSharedFlow() + + override suspend fun start(configuration: GatewayConfiguration) { + requireState() + atomicState.update { State.Running(true) } + onStart(configuration) + } + + /** + * This method is called just after the [start] method, + * once the state is updated to [State.Running]. + * The state is ensured to be valid before this method is called. + */ + protected abstract suspend fun onStart(configuration: GatewayConfiguration) + + override suspend fun stop() { + requireStateIsNot() + events.emit(Close.UserClose) + atomicState.update { State.Stopped } + onStop() + } + + /** + * This method is called just after the [stop] method, + * once the [Close.UserClose] event is emitted, + * and the state is updated to [State.Stopped]. + * The state is ensured to be valid before this method is called. + */ + protected abstract suspend fun onStop() + + override suspend fun detach() { + (this as CoroutineScope).cancel() + if (state is State.Detached) return + atomicState.update { State.Detached } + events.emit(Close.Detach) + onDetach() + } + + /** + * This method is called just after the [detach] method, + * once the state is updated to [State.Detached], + * and the [Close.Detach] event is emitted. + * The state is ensured to be valid before this method is called. + */ + protected abstract suspend fun onDetach() + + override suspend fun send(command: Command) { + requireStateIsNot() + onSend(command) + } + + /** + * This method is called just after the [send] method. + * The state is ensured to be valid before this method is called. + */ + protected abstract suspend fun onSend(command: Command) + + /** + * Checks whether the current [state] is not of type [T]. + * If it is, an [IllegalStateException] is thrown with a describing message. + */ + protected inline fun requireStateIsNot() { + if (state !is T) return + throwStateError() + } + + /** + * Checks whether the current [state] is of type [T]. + * If it isn't, an [IllegalStateException] is thrown with a describing message. + */ + protected inline fun requireState() { + if (state is T) return + throwStateError() + } + + /** + * Throws an [IllegalStateException] with a describing message based on the current [state]. + */ + protected fun throwStateError(): Nothing { + when (state) { + is State.Stopped -> error("The gateway is already stopped.") + is State.Running -> error("The gateway is already running, call stop() first.") + is State.Detached -> error("The Gateway has been detached and can no longer be used, create a new instance instead.") + } + } + + /** + * Represents the current state of the gateway. + * @param retry whether the gateway should attempt to reconnect when it stops. + */ + protected sealed class State(public val retry: Boolean) { + public object Stopped : State(false) + public class Running(retry: Boolean) : State(retry) + public object Detached : State(false) + } +} diff --git a/gateway/src/main/kotlin/ConnectionManagedGateway.kt b/gateway/src/main/kotlin/ConnectionManagedGateway.kt new file mode 100644 index 00000000000..0848c3c48b7 --- /dev/null +++ b/gateway/src/main/kotlin/ConnectionManagedGateway.kt @@ -0,0 +1,130 @@ +package dev.kord.gateway + +import dev.kord.common.entity.optional.optional +import dev.kord.common.entity.optional.optionalInt +import dev.kord.gateway.connection.GatewayConnection +import dev.kord.gateway.connection.GatewayConnection.CloseReason +import dev.kord.gateway.connection.GatewayConnection.Session +import kotlinx.atomicfu.AtomicRef +import kotlinx.atomicfu.atomic +import kotlinx.atomicfu.loop +import kotlinx.coroutines.CoroutineDispatcher +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.flow.MutableSharedFlow +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.launchIn +import kotlinx.coroutines.flow.onEach +import kotlinx.serialization.json.Json +import java.net.URI +import kotlin.coroutines.CoroutineContext +import kotlin.time.Duration + +internal typealias ConnectionProvider = suspend () -> GatewayConnection + +internal class ConnectionManagedGateway( + private val connectionProvider: ConnectionProvider, + private val data: DefaultGatewayData +) : BaseGateway() { + + override val dispatcher: CoroutineDispatcher = data.dispatcher + override val coroutineContext: CoroutineContext = SupervisorJob() + dispatcher + override val ping: MutableStateFlow = MutableStateFlow(null) + override val events: MutableSharedFlow = data.eventFlow + private val atomicConnection: AtomicRef = atomic(null) + private val atomicSession: AtomicRef = atomic(null) + private val jsonParser = Json { + ignoreUnknownKeys = true + isLenient = true + } + + override suspend fun onStart(configuration: GatewayConfiguration) { + data.reconnectRetry.reset() + while (data.reconnectRetry.hasNext && state is State.Running) { + val connection = connectionProvider().also { atomicConnection.value = it } + val pingForward = connection.ping.onEach { ping.value = it }.launchIn(this) + val session = atomicSession.value ?: newSessionFromConfig(configuration).also { atomicSession.value = it } + val connectionData = GatewayConnection.Data( + shard = configuration.shard, + uri = URI.create(data.url), + session = session, + client = data.client, + json = jsonParser, + eventFlow = data.eventFlow, + sendRateLimiter = data.sendRateLimiter, + identifyRateLimiter = data.identifyRateLimiter, + reconnectRetry = data.reconnectRetry + ) + val closeReason = connection.open(connectionData) + pingForward.cancel() + when (closeReason) { + is CloseReason.ResumableInvalidSession -> { + log.trace { "Gateway resumable invalid session." } + atomicSession.value = Session.Resumed(closeReason.resume) + } + + is CloseReason.ResumableReconnect -> { + log.trace { "Gateway resumable reconnect." } + atomicSession.value = Session.Resumed(closeReason.resume) + } + + is CloseReason.Manual -> { + log.trace { "Gateway connection closed manually." } + break + } + + is CloseReason.Error -> { + log.error(closeReason.cause) { "Gateway connection closed with error." } + if (closeReason.cause is java.nio.channels.UnresolvedAddressException) { + data.eventFlow.emit(Close.Timeout) + } + } + + is CloseReason.Plain -> { + val closeReasonCode = GatewayCloseCode.values().find { it.code == closeReason.code } + if (closeReasonCode == null || !closeReasonCode.retry) { + error("Gateway closed: ${closeReason.code} ${closeReason.message}") + } + atomicSession.value = if (closeReasonCode.resetSession && closeReason.resume != null) { + Session.Resumed(closeReason.resume) + } else null + } + + else -> atomicSession.value = null + } + } + } + + private fun newSessionFromConfig(config: GatewayConfiguration): Session.New { + val identify = Identify( + token = config.token, + properties = IdentifyProperties(os, config.name, config.name), + compress = false.optional(), + largeThreshold = config.threshold.optionalInt(), + shard = config.shard.optional(), + presence = config.presence, + intents = config.intents + ) + return Session.New(identify) + } + + override suspend fun onStop() { + useConnection { it.close() } + } + + override suspend fun onDetach() { + useConnection { it.close() } + } + + override suspend fun onSend(command: Command) { + useConnection { it.send(command) } + } + + private inline fun useConnection(block: (GatewayConnection) -> Unit) { + atomicConnection.loop { + if (it != null) { + block(it) + return + } + } + } +} diff --git a/gateway/src/main/kotlin/DefaultGateway.kt b/gateway/src/main/kotlin/DefaultGateway.kt index 933e2a8ce6c..bf0e33a36a0 100644 --- a/gateway/src/main/kotlin/DefaultGateway.kt +++ b/gateway/src/main/kotlin/DefaultGateway.kt @@ -13,9 +13,6 @@ import io.ktor.client.request.* import io.ktor.http.* import io.ktor.util.logging.* import io.ktor.websocket.* -import kotlinx.atomicfu.AtomicRef -import kotlinx.atomicfu.atomic -import kotlinx.atomicfu.update import kotlinx.coroutines.* import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.ReceiveChannel @@ -23,23 +20,13 @@ import kotlinx.coroutines.flow.* import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock import kotlinx.serialization.json.Json -import mu.KotlinLogging import java.io.ByteArrayOutputStream import java.util.zip.Inflater import java.util.zip.InflaterOutputStream import kotlin.contracts.InvocationKind import kotlin.contracts.contract -import kotlin.coroutines.CoroutineContext import kotlin.time.Duration -private val defaultGatewayLogger = KotlinLogging.logger { } - -private sealed class State(val retry: Boolean) { - object Stopped : State(false) - class Running(retry: Boolean) : State(retry) - object Detached : State(false) -} - /** * @param url The url to connect to. * @param client The [HttpClient] from which a WebSocket will be created, requires the [WebSockets] plugin to be @@ -61,21 +48,16 @@ public data class DefaultGatewayData( /** * The default Gateway implementation of Kord, using an [HttpClient] for the underlying webSocket */ -public class DefaultGateway(private val data: DefaultGatewayData) : Gateway { - - override val coroutineContext: CoroutineContext = SupervisorJob() + data.dispatcher +public class DefaultGateway(private val data: DefaultGatewayData) : BaseGateway() { private val compression: Boolean - private val _ping = MutableStateFlow(null) - override val ping: StateFlow get() = _ping - override val events: SharedFlow = data.eventFlow + override val ping: StateFlow get() = _ping + override val events: MutableSharedFlow = data.eventFlow private lateinit var socket: DefaultClientWebSocketSession - private val state: AtomicRef = atomic(State.Stopped) - private val handshakeHandler: HandshakeHandler private lateinit var inflater: Inflater @@ -85,6 +67,8 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway { isLenient = true } + override val dispatcher: CoroutineDispatcher = data.dispatcher + private val stateMutex = Mutex() init { @@ -99,11 +83,11 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway { InvalidSessionHandler(events) { restart(it) } } - //running on default dispatchers because ktor does *not* like running on an EmptyCoroutineContext from main - override suspend fun start(configuration: GatewayConfiguration): Unit = withContext(Dispatchers.Default) { - resetState(configuration) + override suspend fun onStart(configuration: GatewayConfiguration): Unit = withContext(Dispatchers.Default) { + handshakeHandler.configuration = configuration + data.reconnectRetry.reset() - while (data.reconnectRetry.hasNext && state.value is State.Running) { + while (data.reconnectRetry.hasNext && state is State.Running) { try { val (needsIdentify, gatewayUrl) = handshakeHandler.needsIdentifyAndGatewayUrl @@ -111,7 +95,7 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway { data.identifyRateLimiter.consume(shardId = configuration.shard.index, events) } - defaultGatewayLogger.trace { "opening gateway connection to $gatewayUrl" } + log.trace { "Opening gateway connection to $gatewayUrl." } socket = data.client.webSocketSession { url(gatewayUrl) } /** @@ -121,7 +105,7 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway { */ inflater = Inflater() } catch (exception: Exception) { - defaultGatewayLogger.error(exception) + log.error(exception) if (exception is java.nio.channels.UnresolvedAddressException) { data.eventFlow.emit(Close.Timeout) } @@ -133,47 +117,37 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway { try { readSocket() } catch (exception: Exception) { - defaultGatewayLogger.error(exception) + log.error(exception) } - defaultGatewayLogger.trace { "gateway connection closing" } + log.trace { "Gateway connection closing." } try { handleClose() } catch (exception: Exception) { - defaultGatewayLogger.error(exception) + log.error(exception) } - defaultGatewayLogger.trace { "handled gateway connection closed" } + log.trace { "Handled gateway connection closed." } - if (state.value.retry) data.reconnectRetry.retry() - else data.eventFlow.emit(Close.RetryLimitReached) + if (state.retry) { + data.reconnectRetry.retry() + } else { + events.emit(Close.RetryLimitReached) + } } _ping.value = null if (!data.reconnectRetry.hasNext) { - defaultGatewayLogger.warn { "retry limit exceeded, gateway closing" } - } - } - - private suspend fun resetState(configuration: GatewayConfiguration) = stateMutex.withLock { - when (state.value) { - is State.Running -> throw IllegalStateException(gatewayRunningError) - State.Detached -> throw IllegalStateException(gatewayDetachedError) - State.Stopped -> Unit + log.warn { "Retry limit exceeded, gateway closing." } } - - handshakeHandler.configuration = configuration - data.reconnectRetry.reset() - state.update { State.Running(true) } //resetting state } - private suspend fun readSocket() { socket.incoming.asFlow().buffer(Channel.UNLIMITED).collect { when (it) { is Frame.Binary, is Frame.Text -> read(it) - else -> { /*ignore*/ + else -> { /* Ignored. */ } } } @@ -197,11 +171,11 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway { } try { - defaultGatewayLogger.trace { "Gateway <<< $json" } + log.trace { "Gateway <<< $json" } val event = jsonParser.decodeFromString(Event.DeserializationStrategy, json) ?: return - data.eventFlow.emit(event) + events.emit(event) } catch (exception: Exception) { - defaultGatewayLogger.error(exception) + log.error(exception) } } @@ -211,70 +185,58 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway { socket.closeReason.await() } ?: return - defaultGatewayLogger.trace { "Gateway closed: ${reason.code} ${reason.message}" } + log.trace { "Gateway closed: ${reason.code} ${reason.message}" } val discordReason = values().firstOrNull { it.code == reason.code.toInt() } ?: return data.eventFlow.emit(Close.DiscordClose(discordReason, discordReason.retry)) when { !discordReason.retry -> { - state.update { State.Stopped } + state = State.Stopped throw IllegalStateException("Gateway closed: ${reason.code} ${reason.message}") } + discordReason.resetSession -> { - setStopped() + state = State.Running(true) } } } - // This avoids a bug with the atomicfu compiler plugin - private fun setStopped() { - state.update { State.Running(true) } - } - private fun ReceiveChannel.asFlow() = flow { try { for (value in this@asFlow) emit(value) } catch (ignore: CancellationException) { - //reading was stopped from somewhere else, ignore + // Reading was stopped from somewhere else, ignored. } } - override suspend fun stop() { - check(state.value !is State.Detached) { "The resources of this gateway are detached, create another one" } - data.eventFlow.emit(Close.UserClose) - state.update { State.Stopped } + override suspend fun onStop() { _ping.value = null if (socketOpen) socket.close(CloseReason(1000, "leaving")) } internal suspend fun restart(code: Close) { - check(state.value !is State.Detached) { "The resources of this gateway are detached, create another one" } - state.update { State.Running(false) } + requireStateIsNot() + state = State.Running(false) if (socketOpen) { data.eventFlow.emit(code) socket.close(CloseReason(4900, "reconnecting")) } } - override suspend fun detach() { - (this as CoroutineScope).cancel() - if (state.value is State.Detached) return - state.update { State.Detached } + override suspend fun onDetach() { _ping.value = null - data.eventFlow.emit(Close.Detach) if (::socket.isInitialized) { socket.close() } } - override suspend fun send(command: Command): Unit = stateMutex.withLock { - check(state.value !is State.Detached) { "The resources of this gateway are detached, create another one" } + override suspend fun onSend(command: Command) { sendUnsafe(command) } private suspend fun trySend(command: Command) = stateMutex.withLock { - if (state.value !is State.Running) return@withLock + if (state !is State.Running) return@withLock sendUnsafe(command) } @@ -282,22 +244,18 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway { data.sendRateLimiter.consume() val json = Json.encodeToString(Command.SerializationStrategy, command) if (command is Identify) { - defaultGatewayLogger.trace { - val copy = command.copy(token = "token") + log.trace { + val copy = command.copy(token = "Hidden") "Gateway >>> ${Json.encodeToString(Command.SerializationStrategy, copy)}" } - } else defaultGatewayLogger.trace { "Gateway >>> $json" } + } else log.trace { "Gateway >>> $json" } socket.send(Frame.Text(json)) } @OptIn(ExperimentalCoroutinesApi::class) private val socketOpen get() = ::socket.isInitialized && !socket.outgoing.isClosedForSend && !socket.incoming.isClosedForReceive - public companion object { - private const val gatewayRunningError = "The Gateway is already running, call stop() first." - private const val gatewayDetachedError = - "The Gateway has been detached and can no longer be used, create a new instance instead." - } + public companion object } public inline fun DefaultGateway(builder: DefaultGatewayBuilder.() -> Unit = {}): DefaultGateway { diff --git a/gateway/src/main/kotlin/Gateway.kt b/gateway/src/main/kotlin/Gateway.kt index dcafb555a4c..232668cf53a 100644 --- a/gateway/src/main/kotlin/Gateway.kt +++ b/gateway/src/main/kotlin/Gateway.kt @@ -4,6 +4,7 @@ import dev.kord.common.entity.Snowflake import dev.kord.common.entity.optional.Optional import dev.kord.gateway.builder.PresenceBuilder import dev.kord.gateway.builder.RequestGuildMembersBuilder +import dev.kord.gateway.connection.GatewayConnection import io.ktor.util.logging.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.Channel @@ -99,6 +100,14 @@ public interface Gateway : CoroutineScope { */ public fun none(): Gateway = None + /** + * Creates a [Gateway] that will use the provided connection and [data] to connect to the Discord gateway. + */ + public fun connectionManaged( + data: DefaultGatewayData, + connectionProvider: suspend () -> GatewayConnection + ): Gateway = ConnectionManagedGateway(connectionProvider, data) + } } diff --git a/gateway/src/main/kotlin/connection/DefaultGatewayConnection.kt b/gateway/src/main/kotlin/connection/DefaultGatewayConnection.kt new file mode 100644 index 00000000000..c175c74e35c --- /dev/null +++ b/gateway/src/main/kotlin/connection/DefaultGatewayConnection.kt @@ -0,0 +1,282 @@ +package dev.kord.gateway.connection + +import dev.kord.gateway.* +import dev.kord.gateway.connection.GatewayConnection.* +import io.ktor.client.plugins.websocket.* +import io.ktor.client.request.* +import io.ktor.http.* +import io.ktor.websocket.* +import kotlinx.atomicfu.AtomicBoolean +import kotlinx.atomicfu.AtomicRef +import kotlinx.atomicfu.atomic +import kotlinx.atomicfu.loop +import kotlinx.coroutines.* +import kotlinx.coroutines.flow.MutableStateFlow +import mu.KLogger +import mu.KotlinLogging +import java.io.ByteArrayOutputStream +import java.io.Closeable +import java.nio.charset.StandardCharsets +import java.util.zip.InflaterOutputStream +import kotlin.time.Duration +import kotlin.time.Duration.Companion.seconds +import kotlin.time.TimeSource +import kotlin.time.TimeSource.Monotonic.ValueTimeMark +import io.ktor.websocket.CloseReason as KtorCloseReason + +/** + * Default implementation of [GatewayConnection]. + */ +internal class DefaultGatewayConnection : GatewayConnection { + + override val ping: MutableStateFlow = MutableStateFlow(null) + private lateinit var inflater: FrameInflater + private lateinit var data: Data + private lateinit var session: DefaultClientWebSocketSession + private lateinit var scope: CoroutineScope + private val atomicPossiblyZombie: AtomicBoolean = atomic(false) + private var possiblyZombie: Boolean by atomicPossiblyZombie + private val atomicSequence: AtomicRef = atomic(null) + private val sequence: Int? by atomicSequence + private val log: KLogger = KotlinLogging.logger { } + private val atomicHeartbeatTimeMark: AtomicRef = atomic(TimeSource.Monotonic.markNow()) + private var heartbeatTimeMark: ValueTimeMark by atomicHeartbeatTimeMark + private val atomicReceivedHello: AtomicBoolean = atomic(false) + private val atomicReadyData: AtomicRef = atomic(null) + private val atomicReconnectRequested: AtomicBoolean = atomic(false) + private val atomicInvalidSession: AtomicRef = atomic(null) + private val atomicState: AtomicRef = atomic(State.Uninitialized) + private val atomicManualClose: AtomicBoolean = atomic(false) + private val atomicHeartbeatJob: AtomicRef = atomic(null) + private val sessionToken: String + get() = when (val sessionData = data.session) { + is Session.New -> sessionData.identify.token + is Session.Resumed -> sessionData.resume.token + } + + override suspend fun open(data: Data): CloseReason { + if (!atomicState.compareAndSet(State.Uninitialized, State.Opening)) errorInvalidState() + this.data = data + val url = Url(data.uri) + val isZLibCompressed = url.parameters.contains("compress", "zlib-stream") + inflater = if (isZLibCompressed) FrameInflater.ZLib() else FrameInflater.None + + coroutineScope { + scope = this + session = data.client.webSocketSession { + url(url) + } + processIncoming() + atomicHeartbeatJob.value?.cancel() + } + + val reason = resolveCloseReason() + if (!atomicState.compareAndSet(expect = State.Closing, update = State.Closed)) errorInvalidState() + return reason + } + + private suspend fun resolveCloseReason(): CloseReason { + val resume = atomicReadyData.value?.let { Resume(sessionToken, it.sessionId, sequence ?: 0) } + + val isReconnectRequested = atomicReconnectRequested.value + if (isReconnectRequested) { + return if (resume != null) { + CloseReason.ResumableReconnect(resume) + } else CloseReason.Reconnect + } + + val invalidSessionData = atomicInvalidSession.value + if (invalidSessionData != null) { + return if (invalidSessionData.resumable && resume != null) { + CloseReason.ResumableInvalidSession(resume) + } else CloseReason.InvalidSession + } + + val isClosedManually = atomicManualClose.value + if (isClosedManually) return CloseReason.Manual + + val sessionCloseReason = session.closeReason.await() + if (sessionCloseReason != null) { + return CloseReason.Plain(sessionCloseReason.code.toInt(), sessionCloseReason.message, resume) + } + + return CloseReason.Error(IllegalStateException("Could not resolve close reason.")) + } + + private suspend fun processIncoming() { + inflater.use { inflater -> + for (frame in session.incoming) when (frame) { + is Frame.Binary, is Frame.Text -> { + processInflatedFrame(inflater.inflate(frame)) + } + + else -> {} // Ignore other, they are handled by the session. + } + } + } + + private suspend fun processInflatedFrame(byteArray: ByteArray) { + val jsonString = byteArray.toString(StandardCharsets.UTF_8) + log.info { "Gateway <<< $jsonString" } + runCatching { + data.json.decodeFromString(Event.DeserializationStrategy, jsonString)?.also { processEvent(it) } + }.onFailure { + log.catching(it) + }.onSuccess { event -> + (event as? DispatchEvent)?.sequence?.let { atomicSequence.value = it } + } + } + + private suspend fun processEvent(event: Event) { + when (event) { + HeartbeatACK -> processHeartbeatACK() + is Heartbeat -> scope.launch { sendHeartbeat() } + is Hello -> processHello(event) + is Ready -> processReady(event) + Reconnect -> atomicReconnectRequested.compareAndSet(expect = false, update = true) + is InvalidSession -> atomicInvalidSession.compareAndSet(expect = null, update = event) + is DispatchEvent, is Close -> data.eventFlow.emit(event) + } + } + + private fun processHeartbeatACK() { + ping.value = atomicHeartbeatTimeMark.value.elapsedNow() + possiblyZombie = false + } + + private fun processReady(ready: Ready) { + val hasReceivedReadyBefore = !atomicReadyData.compareAndSet(expect = null, update = ready.data) + if (hasReceivedReadyBefore) { + log.warn { "Received more than one Ready event." } + return + } + } + + private fun processHello(hello: Hello) { + data.reconnectRetry.reset() + + val hasReceivedHelloBefore = !atomicReceivedHello.compareAndSet(expect = false, update = true) + if (hasReceivedHelloBefore) { + log.warn { "Received more than one Hello opcode." } + } else { + if (!atomicState.compareAndSet(expect = State.Opening, update = State.Open)) errorInvalidState() + atomicHeartbeatJob.compareAndSet(null, scope.launch { heartBeating(hello.heartbeatInterval.seconds) }) + } + + val resumeOrIdentify = when (val sessionData = data.session) { + is Session.New -> sessionData.identify + is Session.Resumed -> sessionData.resume + } + scope.launch { + data.identifyRateLimiter.consume(data.shard.index, data.eventFlow) + send(resumeOrIdentify) + } + } + + private suspend fun heartBeating(interval: Duration) { + val coroutineContext = currentCoroutineContext() + while (atomicState.value == State.Open && coroutineContext.isActive) { + val isZombie = !atomicPossiblyZombie.compareAndSet(expect = false, update = true) + if (isZombie) { + atomicReconnectRequested.compareAndSet(expect = false, update = true) + session.close(CLOSE_REASON_RECONNECTING) + break + } + sendHeartbeat() + delay(interval) + } + } + + private suspend fun sendHeartbeat() { + data.sendRateLimiter.consume() + val jsonString = data.json.encodeToString(Command.SerializationStrategy, Command.Heartbeat(sequence)) + heartbeatTimeMark = TimeSource.Monotonic.markNow() + session.send(jsonString) + } + + override suspend fun send(command: Command) { + atomicState.loop { currentState -> + when (currentState) { + State.Open -> { + data.sendRateLimiter.consume() + val jsonString = data.json.encodeToString(Command.SerializationStrategy, command) + + log.info { + val credentialFreeCopy = when (command) { + is Identify -> command.copy(token = "Hidden") + is Resume -> command.copy(token = "Hidden") + else -> null + } + val credentialFreeJson = if (credentialFreeCopy != null) { + data.json.encodeToString(Command.SerializationStrategy, credentialFreeCopy) + } else jsonString + + "Gateway >>> $credentialFreeJson" + } + + session.send(jsonString) + return + } + + State.Opening -> yield() + State.Closing, State.Closed -> errorInvalidState() + State.Uninitialized -> errorInvalidState() + } + } + } + + override suspend fun close() { + if (!atomicState.compareAndSet(expect = State.Open, update = State.Closing)) errorInvalidState() + atomicManualClose.compareAndSet(expect = false, update = true) + session.close(CLOSE_REASON_LEAVING) + } + + private fun errorInvalidState(): Nothing { + when (atomicState.value) { + State.Uninitialized -> error("Connection is not initialized.") + State.Opening -> error("Connection is opening.") + State.Open -> error("Connection is already open.") + State.Closing -> error("Connection is closing.") + State.Closed -> error("Connection is closed.") + } + } + + private interface FrameInflater : Closeable { + + fun inflate(frame: Frame): ByteArray + + object None : FrameInflater { + override fun inflate(frame: Frame): ByteArray = frame.data + override fun close() {} + } + + class ZLib : FrameInflater { + + private val buffer = ByteArrayOutputStream() + private val inflaterOutput = InflaterOutputStream(buffer) + + override fun inflate(frame: Frame): ByteArray { + inflaterOutput.apply { + write(frame.data) + flush() + } + val inflated = buffer.toByteArray() + buffer.reset() + return inflated + } + + override fun close() { + inflaterOutput.close() + buffer.reset() + } + } + } + + private enum class State { Uninitialized, Opening, Open, Closing, Closed } + + private companion object { + + private val CLOSE_REASON_LEAVING = KtorCloseReason(code = KtorCloseReason.Codes.NORMAL, message = "Leaving") + private val CLOSE_REASON_RECONNECTING = KtorCloseReason(code = 4900, message = "Reconnecting") + } +} diff --git a/gateway/src/main/kotlin/connection/GatewayConnection.kt b/gateway/src/main/kotlin/connection/GatewayConnection.kt new file mode 100644 index 00000000000..3e34ca8de63 --- /dev/null +++ b/gateway/src/main/kotlin/connection/GatewayConnection.kt @@ -0,0 +1,134 @@ +package dev.kord.gateway.connection + +import dev.kord.common.entity.DiscordShard +import dev.kord.common.ratelimit.RateLimiter +import dev.kord.gateway.Command +import dev.kord.gateway.Event +import dev.kord.gateway.Identify +import dev.kord.gateway.Resume +import dev.kord.gateway.ratelimit.IdentifyRateLimiter +import dev.kord.gateway.retry.Retry +import io.ktor.client.* +import kotlinx.coroutines.flow.MutableSharedFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.serialization.json.Json +import java.net.URI +import kotlin.time.Duration + +/** + * Creates default [GatewayConnection] implementation. + * @see [GatewayConnection] + */ +public fun GatewayConnection(): GatewayConnection = DefaultGatewayConnection() + +/** + * Represents a connection to the Discord gateway. + * Lifecycle of a connection is very simple: [open] and [close]. + */ +public interface GatewayConnection { + + /** + * The current ping of the connection. + * Null if the connection is closed or not opened yet. + */ + public val ping: StateFlow + + /** + * Opens the connection to the Discord gateway. + * @param data the data required to open the connection. + * @throws IllegalStateException if the connection is already closed/open. + * @see GatewayConnection.Data + */ + public suspend fun open(data: Data): CloseReason + + /** + * Closes the connection to the Discord gateway. + * @throws IllegalStateException if the connection is already closed, or is not open yet. + */ + public suspend fun close() + + /** + * Sends the [command] to the Discord gateway. + * @throws IllegalStateException if the connection is closed. + */ + public suspend fun send(command: Command) + + /** + * Data required to open a connection. + * @see GatewayConnection.open + */ + public data class Data( + val shard: DiscordShard, + val uri: URI, + val session: Session, + val client: HttpClient, + val json: Json, + val eventFlow: MutableSharedFlow, + val sendRateLimiter: RateLimiter, + val identifyRateLimiter: IdentifyRateLimiter, + val reconnectRetry: Retry + ) + + /** + * Represents a gateway connection session. + */ + public sealed interface Session { + + /** + * Represents a new session. + */ + public class New(public val identify: Identify) : Session + + /** + * Represents a resumed session. + */ + public class Resumed(public val resume: Resume) : Session + } + + /** + * Represents the reason why a connection was closed. + */ + public sealed interface CloseReason { + + /** + * The connection was closed manually. + */ + public object Manual : CloseReason + + /** + * The connection was closed due to an invalid session. + */ + public object InvalidSession : CloseReason + + /** + * The connection was closed due to an invalid session. + * Resumable variant. + */ + public class ResumableInvalidSession(public val resume: Resume) : CloseReason + + /** + * The connection was closed due to a reconnect required. + */ + public object Reconnect : CloseReason + + /** + * The connection was closed due to a reconnect required. + * Resumable variant. + */ + public class ResumableReconnect(public val resume: Resume) : CloseReason + + /** + * The connection was closed due to an error. + */ + public class Error(public val cause: Throwable) : CloseReason + + /** + * The connection was closed due to a close frame. + */ + public class Plain( + public val code: Int, + public val message: String?, + public val resume: Resume? + ) : CloseReason + } +}