diff --git a/hikari/impl/rate_limits.py b/hikari/impl/rate_limits.py index ef73088937..c3cf798b29 100644 --- a/hikari/impl/rate_limits.py +++ b/hikari/impl/rate_limits.py @@ -441,8 +441,8 @@ class ExponentialBackOff: base : builtins.float The base to use. Defaults to `2.0`. maximum : builtins.float - The max value the backoff can be in a single iteration before an - `asyncio.TimeoutError` is raised. Defaults to `64.0` seconds. + The max value the backoff can be in a single iteration. Anything above + this will be capped to this base value plus random jitter. jitter_multiplier : builtins.float The multiplier for the random jitter. Defaults to `1.0`. Set to `0` to disable jitter. @@ -510,17 +510,17 @@ def __next__(self) -> float: """Get the next back off to sleep by.""" try: value = self.base ** self.increment + + if value >= self.maximum: + value = self.maximum + else: + # This should only be incremented after we verify we haven't hit the maximum value. + self.increment += 1 except OverflowError: # If this happened then we can be sure that we've passed maximum. - raise asyncio.TimeoutError from None - - if value >= self.maximum: - raise asyncio.TimeoutError from None + value = self.maximum - # This should only be incremented after we verify we haven't hit the maximum value. - self.increment += 1 - value += random.random() * self.jitter_multiplier # nosec # noqa S311 rng for cryptography - return value + return value + random.random() * self.jitter_multiplier # nosec # noqa S311 rng for cryptography def __iter__(self) -> ExponentialBackOff: """Return this object, as it is an iterator.""" diff --git a/hikari/impl/shard.py b/hikari/impl/shard.py index a13796ddf5..fa57d123d7 100644 --- a/hikari/impl/shard.py +++ b/hikari/impl/shard.py @@ -27,7 +27,6 @@ import asyncio import contextlib -import http import json import logging import platform @@ -121,25 +120,33 @@ class _V6GatewayTransport(aiohttp.ClientWebSocketResponse): Payload logging is also performed here. """ - __slots__: typing.Sequence[str] = ("_zlib", "_logger", "_log_filterer") + __slots__: typing.Sequence[str] = ("zlib", "logger", "log_filterer", "sent_close") # Initialized from `connect' - _zlib: _ZlibDecompressor - _logger: logging.Logger - _log_filterer: typing.Callable[[str], str] + zlib: _ZlibDecompressor + logger: logging.Logger + log_filterer: typing.Callable[[str], str] + sent_close: bool def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: super().__init__(*args, **kwargs) - self._zlib = zlib.decompressobj() - - async def close(self, *, code: int = 1000, message: bytes = b"") -> bool: - if not self._closed and not self._closing: - self._logger.debug("sending close frame with code %s and message %s", int(code), message) - try: - return await asyncio.wait_for(super().close(code=code, message=message), timeout=5) - except asyncio.TimeoutError: - self._logger.debug("failed to send close frame in time, probably connection issues") - return False + self.zlib = zlib.decompressobj() + self.sent_close = False + + async def send_close(self, *, code: int = 1000, message: bytes = b"") -> bool: + # aiohttp may close the socket by invoking close() internally. By giving + # a different name, we can ensure aiohttp won't invoke this method. + # We can then guarantee any call to this method was made by us, as + # opposed to, for example, Windows injecting a spurious EOF when + # something disconnects, which makes aiohttp just shut down as if we + # did it. + if not self.sent_close: + self.logger.debug("sending close frame with code %s and message %s", int(code), message) + try: + return await asyncio.wait_for(super().close(code=code, message=message), timeout=5) + except asyncio.TimeoutError: + self.logger.debug("failed to send close frame in time, probably connection issues") + return False async def receive_json( self, @@ -148,9 +155,9 @@ async def receive_json( timeout: typing.Optional[float] = None, ) -> typing.Any: pl = await self._receive_and_check(timeout) - if self._logger.getEffectiveLevel() <= ux.TRACE: - filtered = self._log_filterer(pl) # type: ignore - self._logger.log(ux.TRACE, "received payload with size %s\n %s", len(pl), filtered) + if self.logger.getEffectiveLevel() <= ux.TRACE: + filtered = self.log_filterer(pl) # type: ignore + self.logger.log(ux.TRACE, "received payload with size %s\n %s", len(pl), filtered) return loads(pl) async def send_json( @@ -161,9 +168,9 @@ async def send_json( dumps: aiohttp.typedefs.JSONEncoder = json.dumps, ) -> None: pl = dumps(data) - if self._logger.getEffectiveLevel() <= ux.TRACE: - filtered = self._log_filterer(pl) # type: ignore - self._logger.log(ux.TRACE, "sending payload with size %s\n %s", len(pl), filtered) + if self.logger.getEffectiveLevel() <= ux.TRACE: + filtered = self.log_filterer(pl) # type: ignore + self.logger.log(ux.TRACE, "sending payload with size %s\n %s", len(pl), filtered) await self.send_str(pl, compress) async def _receive_and_check(self, timeout: typing.Optional[float], /) -> str: @@ -175,7 +182,7 @@ async def _receive_and_check(self, timeout: typing.Optional[float], /) -> str: if message.type == aiohttp.WSMsgType.CLOSE: close_code = int(message.data) reason = message.extra - self._logger.error("connection closed with code %s (%s)", close_code, reason) + self.logger.error("connection closed with code %s (%s)", close_code, reason) can_reconnect = close_code < 4000 or close_code in ( errors.ShardCloseCode.UNKNOWN_ERROR, @@ -189,7 +196,10 @@ async def _receive_and_check(self, timeout: typing.Optional[float], /) -> str: raise errors.GatewayServerClosedConnectionError(reason, close_code, can_reconnect) elif message.type == aiohttp.WSMsgType.CLOSING or message.type == aiohttp.WSMsgType.CLOSED: - raise asyncio.CancelledError("Socket closed") + # May be caused by the server shutting us down. + # May be caused by Windows injecting an EOF if something disconnects, as some + # network drivers appear to do this. + raise errors.GatewayError("Socket has closed") elif len(buff) != 0 and message.type != aiohttp.WSMsgType.BINARY: raise errors.GatewayError(f"Unexpected message type received {message.type.name}, expected BINARY") @@ -198,7 +208,7 @@ async def _receive_and_check(self, timeout: typing.Optional[float], /) -> str: buff.extend(message.data) if buff.endswith(b"\x00\x00\xff\xff"): - return self._zlib.decompress(buff).decode("utf-8") + return self.zlib.decompress(buff).decode("utf-8") elif message.type == aiohttp.WSMsgType.TEXT: return message.data # type: ignore @@ -206,10 +216,10 @@ async def _receive_and_check(self, timeout: typing.Optional[float], /) -> str: else: # Assume exception for now. ex = self.exception() - self._logger.warning( + self.logger.warning( "encountered unexpected error: %s", ex, - exc_info=ex if self._logger.isEnabledFor(logging.DEBUG) else None, + exc_info=ex if self.logger.isEnabledFor(logging.DEBUG) else None, ) raise errors.GatewayError("Unexpected websocket exception from gateway") from ex @@ -233,84 +243,84 @@ async def connect( and keeps all of the nested boilerplate out of the way of the rest of the code, for the most part anyway. """ + exit_stack = contextlib.AsyncExitStack() + try: - async with aiohttp.ClientSession( - connector=aiohttp.TCPConnector( - limit=1, - use_dns_cache=False, - verify_ssl=http_config.verify_ssl, - enable_cleanup_closed=True, - force_close=True, - ), - raise_for_status=True, - timeout=aiohttp.ClientTimeout( - total=http_config.timeouts.total, - connect=http_config.timeouts.acquire_and_connect, - sock_read=http_config.timeouts.request_socket_read, - sock_connect=http_config.timeouts.request_socket_connect, - ), - trust_env=proxy_config.trust_env, - ws_response_class=cls, - ) as cs: - try: - async with cs.ws_connect( - max_msg_size=0, - proxy=proxy_config.url, - proxy_headers=proxy_config.headers, - url=url, - ) as ws: - raised = False - try: - assert isinstance(ws, cls) - ws._logger = logger - # We store this so we can remove it from debug logs - # which enables people to send me logs in issues safely. - # Also MyPy raises a false positive about this... - ws._log_filterer = log_filterer # type: ignore - - yield ws - except errors.GatewayError: - raised = True - raise - except Exception as ex: - raised = True - raise errors.GatewayError(f"Unexpected {type(ex).__name__}: {ex}") from ex - finally: - if ws.closed: - logger.log(ux.TRACE, "ws was already closed") - - elif raised: - await ws.close( - code=errors.ShardCloseCode.UNEXPECTED_CONDITION, - message=b"unexpected fatal client error :-(", - ) - - elif not ws._closing: - # We use a special close code here that prevents Discord - # randomly invalidating our session. Undocumented behaviour is - # nice like that... - await ws.close( - code=_RESUME_CLOSE_CODE, - message=b"client is shutting down", - ) - - except aiohttp.ClientConnectionError as ex: - message = f"Failed to connect to Discord: {ex!r}" - raise errors.GatewayConnectionError(message) from ex - - except aiohttp.WSServerHandshakeError as ex: - try: - reason = http.HTTPStatus(ex.status).name - except ValueError: - reason = "Unknown Reason" + client_session = await exit_stack.enter_async_context( + aiohttp.ClientSession( + connector=aiohttp.TCPConnector( + limit=1, + use_dns_cache=False, + verify_ssl=http_config.verify_ssl, + enable_cleanup_closed=True, + force_close=True, + ), + raise_for_status=True, + timeout=aiohttp.ClientTimeout( + total=http_config.timeouts.total, + connect=http_config.timeouts.acquire_and_connect, + sock_read=http_config.timeouts.request_socket_read, + sock_connect=http_config.timeouts.request_socket_connect, + ), + trust_env=proxy_config.trust_env, + ws_response_class=cls, + ) + ) - message = ( - f"Discord produced a {ex.status} {reason} response " - f"when attempting to upgrade to a websocket: {ex.message!r}" + web_socket = await exit_stack.enter_async_context( + client_session.ws_connect( + max_msg_size=0, + proxy=proxy_config.url, + proxy_headers=proxy_config.headers, + url=url, + ) + ) + + assert isinstance(web_socket, cls) + + raised = False + try: + web_socket.logger = logger + # We store this so we can remove it from debug logs + # which enables people to send me logs in issues safely. + # Also MyPy raises a false positive about this... + web_socket.log_filterer = log_filterer # type: ignore + + yield web_socket + except errors.GatewayError: + raised = True + raise + except Exception as ex: + raised = True + raise errors.GatewayError(f"Unexpected {type(ex).__name__}: {ex}") from ex + finally: + if web_socket.closed: + logger.log(ux.TRACE, "ws was already closed") + + elif raised: + await web_socket.send_close( + code=errors.ShardCloseCode.UNEXPECTED_CONDITION, + message=b"unexpected fatal client error :-(", + ) + + elif not web_socket._closing: + # We use a special close code here that prevents Discord + # randomly invalidating our session. Undocumented behaviour is + # nice like that... + await web_socket.send_close( + code=_RESUME_CLOSE_CODE, + message=b"client is shutting down", ) - raise errors.GatewayError(message) from ex + except (aiohttp.ClientOSError, aiohttp.ClientConnectionError, aiohttp.WSServerHandshakeError) as ex: + # Windows will sometimes raise an aiohttp.ClientOSError + # If we cannot do DNS lookup, this will fail with a ClientConnectionError + # usually. + raise errors.GatewayConnectionError(f"Failed to connect to Discord: {ex!r}") from ex + finally: + await exit_stack.aclose() + # We have to sleep to allow aiohttp time to close SSL transports... # https://github.com/aio-libs/aiohttp/issues/1925 # https://docs.aiohttp.org/en/stable/client_advanced.html#graceful-shutdown @@ -694,6 +704,43 @@ async def _heartbeat(self, heartbeat_interval: float) -> bool: # We should continue continue + async def _poll_events(self) -> typing.Optional[bool]: + payload = await self._ws.receive_json(timeout=5) # type: ignore[union-attr] + + op = payload[_OP] # opcode int + d = payload[_D] # data/payload. Usually a dict or a bool for INVALID_SESSION + + if op == _DISPATCH: + t = payload[_T] # event name str + s = payload[_S] # seq int + self._logger.log(ux.TRACE, "dispatching %s with seq %s", t, s) + self._dispatch(t, s, d) + elif op == _HEARTBEAT: + await self._send_heartbeat_ack() + self._logger.log(ux.TRACE, "sent HEARTBEAT") + elif op == _HEARTBEAT_ACK: + now = date.monotonic() + self._last_heartbeat_ack_received = now + self._heartbeat_latency = now - self._last_heartbeat_sent + self._logger.log(ux.TRACE, "received HEARTBEAT ACK in %.1fms", self._heartbeat_latency * 1_000) + elif op == _RECONNECT: + # We should be able to resume... + self._logger.debug("received instruction to reconnect, will resume existing session") + return True + elif op == _INVALID_SESSION: + # We can resume if the payload was `true`. + if not d: + self._logger.debug("received invalid session, will need to start a new session") + self._seq = None + self._session_id = None + else: + self._logger.debug("received invalid session, will resume existing session") + return True + else: + self._logger.log(ux.TRACE, "unknown opcode %s received, it will be ignored...", op) + + return None + async def _resume(self) -> None: await self._ws.send_json( # type: ignore[union-attr] { @@ -712,174 +759,144 @@ async def _run(self) -> None: initial_increment=_BACKOFF_INCREMENT_START, ) - try: - while True: - if date.monotonic() - last_started_at < _BACKOFF_WINDOW: - time = next(backoff) - self._logger.debug("backing off reconnecting for %.2fs to prevent spam", time) - - try: - await asyncio.wait_for(self._closing.wait(), timeout=time) - # We were told to close. - return - except asyncio.TimeoutError: - # We are going to run once. - pass + while True: + if date.monotonic() - last_started_at < _BACKOFF_WINDOW: + time = next(backoff) + self._logger.info("backing off reconnecting for %.2fs", time) try: - last_started_at = date.monotonic() - if not await self._run_once(): - self._logger.debug("shard has shut down") + await asyncio.wait_for(self._closing.wait(), timeout=time) + # We were told to close. + return + except asyncio.TimeoutError: + # We are going to run once. + pass - except errors.GatewayConnectionError as ex: - self._logger.error("failed to connect to server, reason was: %s. Will retry shortly", ex.__cause__) + try: + last_started_at = date.monotonic() + if not await self._run_once(): + self._logger.debug("shard has shut down") + + except errors.GatewayConnectionError as ex: + self._logger.error( + "failed to communicate with server, reason was: %s. Will retry shortly", + ex.__cause__, + ) - except errors.GatewayServerClosedConnectionError as ex: - if not ex.can_reconnect: - raise + except errors.GatewayServerClosedConnectionError as ex: + if not ex.can_reconnect: + raise - self._logger.info( - "server has closed connection, will reconnect if possible [code:%s, reason:%s]", - ex.code, - ex.reason, - ) + self._logger.info( + "server has closed connection, will reconnect if possible [code:%s, reason:%s]", + ex.code, + ex.reason, + ) - except errors.GatewayError as ex: - self._logger.debug("encountered generic gateway error", exc_info=ex) - raise + except errors.GatewayError as ex: + self._logger.debug("encountered generic gateway error", exc_info=ex) + raise - except Exception as ex: - self._logger.debug("encountered some unhandled error", exc_info=ex) - raise - finally: - self._closed.set() - self._logger.info("shard %s has shut down permanently", self._shard_id) + except Exception as ex: + self._logger.debug("encountered some unhandled error", exc_info=ex) + raise + + finally: + self._closed.set() + self._logger.info("shard %s has shut down", self._shard_id) async def _run_once(self) -> bool: self._closing.clear() self._handshake_completed.clear() dispatch_disconnect = False - try: - async with _V6GatewayTransport.connect( + + exit_stack = contextlib.AsyncExitStack() + + self._ws = await exit_stack.enter_async_context( + _V6GatewayTransport.connect( http_config=self._http_settings, log_filterer=_log_filterer(self._token), logger=self._logger, proxy_config=self._proxy_settings, url=self._url, - ) as self._ws: - # Dispatch CONNECTED synthetic event. - self._event_consumer(self, "CONNECTED", {}) - dispatch_disconnect = True - - # Expect HELLO. - payload = await self._ws.receive_json() - if payload[_OP] != _HELLO: - self._logger.debug( - "expected HELLO opcode, received %s which makes no sense, closing with PROTOCOL ERROR ", - "(_run_once => raise and do not reconnect)", - payload[_OP], - ) - await self._ws.close(code=errors.ShardCloseCode.PROTOCOL_ERROR, message=b"Expected HELLO op") - raise errors.GatewayError(f"Expected opcode {_HELLO}, but received {payload[_OP]}") + ) + ) - heartbeat_latency = float(payload[_D]["heartbeat_interval"]) / 1_000.0 - heartbeat_task = asyncio.create_task(self._heartbeat(heartbeat_latency)) + try: + # Dispatch CONNECTED synthetic event. + self._event_consumer(self, "CONNECTED", {}) + dispatch_disconnect = True + + heartbeat_task = await self._wait_for_hello() + + try: + if self._seq is not None: + self._logger.debug("resuming session %s", self._session_id) + await self._resume() + else: + self._logger.debug("identifying with new session") + await self._identify() if self._closing.is_set(): self._logger.debug( - "closing flag was set before we could handshake, disconnecting with GOING AWAY " + "closing flag was set during handshake, disconnecting with GOING AWAY " "(_run_once => do not reconnect)" ) - await self._ws.close(code=errors.ShardCloseCode.GOING_AWAY, message=b"shard disconnecting") + await self._ws.send_close( # type: ignore[union-attr] + code=errors.ShardCloseCode.GOING_AWAY, message=b"shard disconnecting" + ) return False - try: - if self._seq is not None: - self._logger.debug("resuming session %s", self._session_id) - await self._resume() - else: - self._logger.debug("identifying with new session") - await self._identify() - - if self._closing.is_set(): - self._logger.debug( - "closing flag was set during handshake, disconnecting with GOING AWAY " - "(_run_once => do not reconnect)" - ) - await self._ws.close(code=errors.ShardCloseCode.GOING_AWAY, message=b"shard disconnecting") - return False - - # Event polling. - while not self._closing.is_set() and not heartbeat_task.done() and not heartbeat_task.cancelled(): - try: - payload = await self._ws.receive_json(timeout=5) - except asyncio.TimeoutError: - # Don't wait forever, check if the heartbeat has died. - continue - - op = payload[_OP] # opcode int - d = payload[_D] # data/payload. Usually a dict or a bool for INVALID_SESSION - - if op == _DISPATCH: - t = payload[_T] # event name str - s = payload[_S] # seq int - self._logger.log(ux.TRACE, "dispatching %s with seq %s", t, s) - self._dispatch(t, s, d) - elif op == _HEARTBEAT: - await self._send_heartbeat_ack() - self._logger.log(ux.TRACE, "sent HEARTBEAT") - elif op == _HEARTBEAT_ACK: - now = date.monotonic() - self._last_heartbeat_ack_received = now - self._heartbeat_latency = now - self._last_heartbeat_sent - self._logger.log( - ux.TRACE, "received HEARTBEAT ACK in %.1fms", self._heartbeat_latency * 1_000 - ) - elif op == _RECONNECT: - # We should be able to resume... - self._logger.debug("received instruction to reconnect, will resume existing session") - return True - elif op == _INVALID_SESSION: - # We can resume if the payload was `true`. - if not d: - self._logger.debug("received invalid session, will need to start a new session") - self._seq = None - self._session_id = None - else: - self._logger.debug("received invalid session, will resume existing session") - return True - else: - self._logger.log(ux.TRACE, "unknown opcode %s received, it will be ignored...", op) - - # If the heartbeat died due to an error, it should be raised here. - # This will currently allow us to try to resume if that happens - # We return True if zombied. - if await heartbeat_task: - now = date.monotonic() - self._logger.error( - "connection is a zombie, last heartbeat sent %.2fs ago", - now - self._last_heartbeat_sent, - ) - self._logger.debug("will attempt to reconnect (_run_once => reconnect)") - return True + # Event polling. + while not self._closing.is_set() and not heartbeat_task.done() and not heartbeat_task.cancelled(): + try: + result = await self._poll_events() - self._logger.debug( - "shard has requested graceful termination, so will not attempt to reconnect " - "(_run_once => do not reconnect)" + if result is not None: + return result + except asyncio.TimeoutError: + # We should check if the shard is still alive and then poll again after. + pass + + # If the heartbeat died due to an error, it should be raised here. + # This will currently allow us to try to resume if that happens + # We return True if zombied. + if await heartbeat_task: + now = date.monotonic() + self._logger.error( + "connection is a zombie, last heartbeat sent %.2fs ago", + now - self._last_heartbeat_sent, ) - await self._ws.close(code=errors.ShardCloseCode.GOING_AWAY, message=b"shard disconnecting") - return False + self._logger.debug("will attempt to reconnect (_run_once => reconnect)") + return True - finally: - heartbeat_task.cancel() + self._logger.debug( + "shard has requested graceful termination, so will not attempt to reconnect " + "(_run_once => do not reconnect)" + ) + await self._ws.send_close( # type: ignore[union-attr] + code=errors.ShardCloseCode.GOING_AWAY, + message=b"shard disconnecting", + ) + return False + + finally: + heartbeat_task.cancel() finally: + ws = self._ws self._ws = None + await exit_stack.aclose() if dispatch_disconnect: # If we managed to connect, we must always send the DISCONNECT event # afterwards. self._event_consumer(self, "DISCONNECTED", {}) + # Check if we made the socket close or handled it. If we didn't, we should always try to + # reconnect, as aiohttp is probably closing it internally without telling us properly. + if not ws.sent_close: # type: ignore[union-attr] + return True + async def _send_heartbeat(self) -> None: await self._ws.send_json({_OP: _HEARTBEAT, _D: self._seq}) # type: ignore[union-attr] self._last_heartbeat_sent = date.monotonic() @@ -940,3 +957,33 @@ def _serialize_datetime(dt: typing.Optional[datetime.datetime]) -> typing.Option return None return int(dt.timestamp() * 1_000) + + async def _wait_for_hello(self) -> asyncio.Task[bool]: + # Expect HELLO. + payload = await self._ws.receive_json() # type: ignore[union-attr] + if payload[_OP] != _HELLO: + self._logger.debug( + "expected HELLO opcode, received %s which makes no sense, closing with PROTOCOL ERROR ", + "(_run_once => raise and do not reconnect)", + payload[_OP], + ) + await self._ws.send_close( # type: ignore[union-attr] + code=errors.ShardCloseCode.PROTOCOL_ERROR, + message=b"Expected HELLO op", + ) + raise errors.GatewayError(f"Expected opcode {_HELLO}, but received {payload[_OP]}") + + if self._closing.is_set(): + self._logger.debug( + "closing flag was set before we could handshake, disconnecting with GOING AWAY " + "(_run_once => do not reconnect)" + ) + await self._ws.send_close( # type: ignore[union-attr] + code=errors.ShardCloseCode.GOING_AWAY, + message=b"shard disconnecting", + ) + raise asyncio.CancelledError("closing flag was set before we could handshake") + + heartbeat_interval = float(payload[_D]["heartbeat_interval"]) / 1_000.0 + heartbeat_task = asyncio.create_task(self._heartbeat(heartbeat_interval)) + return heartbeat_task diff --git a/tests/hikari/impl/test_rate_limits.py b/tests/hikari/impl/test_rate_limits.py index 8280ebe901..9e4b49c6c0 100644 --- a/tests/hikari/impl/test_rate_limits.py +++ b/tests/hikari/impl/test_rate_limits.py @@ -450,8 +450,7 @@ def test_increment_raises_on_numerical_limitation(self): base=5, maximum=sys.float_info.max, jitter_multiplier=0.0, initial_increment=power ) - with pytest.raises(asyncio.TimeoutError): - next(eb) + assert next(eb) == sys.float_info.max def test_increment_maximum(self): max_bound = 64 @@ -460,16 +459,14 @@ def test_increment_maximum(self): for _ in range(iterations): next(eb) - with pytest.raises(asyncio.TimeoutError): - next(eb) + assert next(eb) == max_bound def test_increment_does_not_increment_when_on_maximum(self): - eb = rate_limits.ExponentialBackOff(2, 32, initial_increment=5) + eb = rate_limits.ExponentialBackOff(2, 32, initial_increment=5, jitter_multiplier=0) assert eb.increment == 5 - with pytest.raises(asyncio.TimeoutError): - next(eb) + assert next(eb) == 32 assert eb.increment == 5 diff --git a/tests/hikari/impl/test_shard.py b/tests/hikari/impl/test_shard.py index 96a84c40d4..c70d05d613 100644 --- a/tests/hikari/impl/test_shard.py +++ b/tests/hikari/impl/test_shard.py @@ -65,8 +65,8 @@ class Test_V6GatewayTransport: def transport_impl(self): with mock.patch.object(aiohttp.ClientWebSocketResponse, "__init__"): transport = shard._V6GatewayTransport() - transport._logger = mock.Mock(getEffectiveLevel=mock.Mock(return_value=5)) - transport._log_filterer = mock.Mock() + transport.logger = mock.Mock(getEffectiveLevel=mock.Mock(return_value=5)) + transport.log_filterer = mock.Mock() yield transport def test__init__calls_super(self): @@ -75,44 +75,22 @@ def test__init__calls_super(self): init.assert_called_once_with("arg1", "arg2", some_kwarg="kwarg1") - async def test_close_when_closed_doesnt_log(self, transport_impl): - transport_impl._closed = True - transport_impl._closing = False - transport_impl._logger = mock.Mock() - - with mock.patch.object(aiohttp.ClientWebSocketResponse, "close") as close: - await transport_impl.close(code=1234, message=b"some message") - - transport_impl._logger.debug.assert_not_called() - close.assert_called_once_with(code=1234, message=b"some message") - - async def test_close_when_closing_doesnt_log(self, transport_impl): - transport_impl._closed = False - transport_impl._closing = True - transport_impl._logger = mock.Mock() - - with mock.patch.object(aiohttp.ClientWebSocketResponse, "close") as close: - await transport_impl.close(code=1234, message=b"some message") - - transport_impl._logger.debug.assert_not_called() - close.assert_called_once_with(code=1234, message=b"some message") - - async def test_close_when_not_closed_nor_closing_logs(self, transport_impl): + async def test_send_close_when_not_closed_nor_closing_logs(self, transport_impl): transport_impl._closed = False transport_impl._closing = False - transport_impl._logger = mock.Mock() + transport_impl.logger = mock.Mock() with mock.patch.object(aiohttp.ClientWebSocketResponse, "close") as close: - await transport_impl.close(code=1234, message=b"some message") + await transport_impl.send_close(code=1234, message=b"some message") - transport_impl._logger.debug.assert_called_once_with( + transport_impl.logger.debug.assert_called_once_with( "sending close frame with code %s and message %s", 1234, b"some message" ) close.assert_called_once_with(code=1234, message=b"some message") async def test_receive_json(self, transport_impl): transport_impl._receive_and_check = mock.AsyncMock(return_value="{'json_response': null}") - transport_impl._log_payload = mock.Mock() + transport_impl.log_payload = mock.Mock() mock_loads = mock.Mock(return_value={"json_response": None}) assert await transport_impl.receive_json(loads=mock_loads, timeout=69) == {"json_response": None} @@ -122,7 +100,7 @@ async def test_receive_json(self, transport_impl): async def test_send_json(self, transport_impl): transport_impl.send_str = mock.AsyncMock() - transport_impl._log_payload = mock.Mock() + transport_impl.log_payload = mock.Mock() mock_dumps = mock.Mock(return_value="{'json_send': null}") await transport_impl.send_json({"json_send": None}, 420, dumps=mock_dumps) @@ -156,7 +134,7 @@ def __init__( async def test__receive_and_check_when_message_type_is_CLOSE_and_should_reconnect(self, code, transport_impl): stub_response = self.StubResponse(type=aiohttp.WSMsgType.CLOSE, extra="some error extra", data=code) transport_impl.receive = mock.AsyncMock(return_value=stub_response) - transport_impl._logger = mock.Mock() + transport_impl.logger = mock.Mock() with pytest.raises(errors.GatewayServerClosedConnectionError) as exinfo: await transport_impl._receive_and_check(10) @@ -174,7 +152,7 @@ async def test__receive_and_check_when_message_type_is_CLOSE_and_should_reconnec async def test__receive_and_check_when_message_type_is_CLOSE_and_should_not_reconnect(self, code, transport_impl): stub_response = self.StubResponse(type=aiohttp.WSMsgType.CLOSE, extra="dont reconnect", data=code) transport_impl.receive = mock.AsyncMock(return_value=stub_response) - transport_impl._logger = mock.Mock() + transport_impl.logger = mock.Mock() with pytest.raises(errors.GatewayServerClosedConnectionError) as exinfo: await transport_impl._receive_and_check(10) @@ -189,7 +167,7 @@ async def test__receive_and_check_when_message_type_is_CLOSING(self, transport_i stub_response = self.StubResponse(type=aiohttp.WSMsgType.CLOSING) transport_impl.receive = mock.AsyncMock(return_value=stub_response) - with pytest.raises(asyncio.CancelledError, match="Socket closed"): + with pytest.raises(errors.GatewayError, match="Socket has closed"): await transport_impl._receive_and_check(10) transport_impl.receive.assert_awaited_once_with(10) @@ -198,7 +176,7 @@ async def test__receive_and_check_when_message_type_is_CLOSED(self, transport_im stub_response = self.StubResponse(type=aiohttp.WSMsgType.CLOSED) transport_impl.receive = mock.AsyncMock(return_value=stub_response) - with pytest.raises(asyncio.CancelledError, match="Socket closed"): + with pytest.raises(errors.GatewayError, match="Socket has closed"): await transport_impl._receive_and_check(10) transport_impl.receive.assert_awaited_once_with(10) @@ -208,12 +186,12 @@ async def test__receive_and_check_when_message_type_is_BINARY(self, transport_im response2 = self.StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"data") response3 = self.StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"\x00\x00\xff\xff") transport_impl.receive = mock.AsyncMock(side_effect=[response1, response2, response3]) - transport_impl._zlib = mock.Mock(decompress=mock.Mock(return_value=b"utf-8 encoded bytes")) + transport_impl.zlib = mock.Mock(decompress=mock.Mock(return_value=b"utf-8 encoded bytes")) assert await transport_impl._receive_and_check(10) == "utf-8 encoded bytes" transport_impl.receive.assert_awaited_with(10) - transport_impl._zlib.decompress.assert_called_once_with(bytearray(b"somedata\x00\x00\xff\xff")) + transport_impl.zlib.decompress.assert_called_once_with(bytearray(b"somedata\x00\x00\xff\xff")) async def test__receive_and_check_when_buff_but_next_is_not_BINARY(self, transport_impl): response1 = self.StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"some") @@ -237,7 +215,7 @@ async def test__receive_and_check_when_message_type_is_TEXT(self, transport_impl async def test__receive_and_check_when_message_type_is_unknown(self, transport_impl): transport_impl.receive = mock.AsyncMock(return_value=self.StubResponse(type=aiohttp.WSMsgType.ERROR)) transport_impl.exception = mock.Mock(return_value=Exception) - transport_impl._logger = mock.Mock() + transport_impl.logger = mock.Mock() with pytest.raises(errors.GatewayError, match="Unexpected websocket exception from gateway"): await transport_impl._receive_and_check(10) @@ -247,6 +225,8 @@ async def test__receive_and_check_when_message_type_is_unknown(self, transport_i async def test_connect_yields_websocket(self, http_settings, proxy_settings): class MockWS(hikari_test_helpers.AsyncContextManagerMock, shard._V6GatewayTransport): closed = True + send_close = mock.AsyncMock() + sent_close = False def __init__(self): pass @@ -273,7 +253,7 @@ def __init__(self): url="https://some.url", log_filterer=log_filterer, ) as ws: - assert ws._logger is logger + assert ws.logger is logger tcp_connector.assert_called_once_with( limit=1, @@ -308,7 +288,8 @@ def __init__(self): async def test_connect_when_gateway_error_after_connecting(self, http_settings, proxy_settings): class MockWS(hikari_test_helpers.AsyncContextManagerMock, shard._V6GatewayTransport): closed = False - close = mock.AsyncMock() + sent_close = False + send_close = mock.AsyncMock() def __init__(self): pass @@ -336,7 +317,7 @@ def __init__(self): ): hikari_test_helpers.raiser(errors.GatewayError("some reason")) - mock_websocket.close.assert_awaited_once_with( + mock_websocket.send_close.assert_awaited_once_with( code=errors.ShardCloseCode.UNEXPECTED_CONDITION, message=b"unexpected fatal client error :-(" ) @@ -347,7 +328,8 @@ def __init__(self): async def test_connect_when_unexpected_error_after_connecting(self, http_settings, proxy_settings): class MockWS(hikari_test_helpers.AsyncContextManagerMock, shard._V6GatewayTransport): closed = False - close = mock.AsyncMock() + send_close = mock.AsyncMock() + sent_close = False def __init__(self): pass @@ -375,7 +357,7 @@ def __init__(self): ): hikari_test_helpers.raiser(ValueError("testing")) - mock_websocket.close.assert_awaited_once_with( + mock_websocket.send_close.assert_awaited_once_with( code=errors.ShardCloseCode.UNEXPECTED_CONDITION, message=b"unexpected fatal client error :-(" ) @@ -387,7 +369,8 @@ async def test_connect_when_no_error_and_not_closing(self, http_settings, proxy_ class MockWS(hikari_test_helpers.AsyncContextManagerMock, shard._V6GatewayTransport): closed = False _closing = False - close = mock.AsyncMock() + sent_close = False + send_close = mock.AsyncMock() def __init__(self): pass @@ -414,7 +397,9 @@ def __init__(self): ): pass - mock_websocket.close.assert_awaited_once_with(code=shard._RESUME_CLOSE_CODE, message=b"client is shutting down") + mock_websocket.send_close.assert_awaited_once_with( + code=shard._RESUME_CLOSE_CODE, message=b"client is shutting down" + ) sleep.assert_awaited_once_with(0.25) mock_client_session.assert_used_once() @@ -504,8 +489,8 @@ async def test_connect_when_handshake_error_with_unknown_reason(self, http_setti pytest.raises( errors.GatewayError, match=( - "Discord produced a 123 Unknown Reason response " - "when attempting to upgrade to a websocket: 'some error'" + r"Failed to connect to Discord: " + r"WSServerHandshakeError\(None, None, status=123, message='some error'\)" ), ) ) @@ -542,8 +527,8 @@ async def test_connect_when_handshake_error_with_known_reason(self, http_setting pytest.raises( errors.GatewayError, match=( - "Discord produced a 500 INTERNAL_SERVER_ERROR response " - "when attempting to upgrade to a websocket: 'some error'" + r"Failed to connect to Discord: WSServerHandshakeError" + r"\(None, None, status=500, message='some error'\)" ), ) )