diff --git a/hikari/errors.py b/hikari/errors.py index 624960d87d..ed944f5777 100644 --- a/hikari/errors.py +++ b/hikari/errors.py @@ -27,6 +27,7 @@ "HikariError", "HikariWarning", "HikariInterrupt", + "ComponentNotRunningError", "NotFoundError", "RateLimitedError", "RateLimitTooLongError", @@ -100,6 +101,17 @@ class HikariInterrupt(KeyboardInterrupt, HikariError): """The signal name that was raised.""" +@attr.s(auto_exc=True, slots=True, repr=False, weakref_slot=False) +class ComponentNotRunningError(HikariError): + """An exception thrown if trying to interact with a component that is not running.""" + + reason: str = attr.ib() + """A string to explain the issue.""" + + def __str__(self) -> str: + return self.reason + + @attr.s(auto_exc=True, slots=True, repr=False, weakref_slot=False) class GatewayError(HikariError): """A base exception type for anything that can be thrown by the Gateway.""" diff --git a/hikari/events/lifetime_events.py b/hikari/events/lifetime_events.py index 289ec9649c..47727e8581 100644 --- a/hikari/events/lifetime_events.py +++ b/hikari/events/lifetime_events.py @@ -115,7 +115,7 @@ class StoppedEvent(base_events.Event): closed within a coroutine function. !!! warning - The application will not proceed to leave the `_rest.run` call until all + The application will not proceed to leave the `bot.run` call until all event handlers for this event have completed/terminated. This prevents the risk of race conditions occurring where a script may terminate the process before a callback can occur. diff --git a/hikari/impl/bot.py b/hikari/impl/bot.py index 9942d1cdf7..e48fea5679 100644 --- a/hikari/impl/bot.py +++ b/hikari/impl/bot.py @@ -234,6 +234,7 @@ class BotApp(traits.BotAware, event_dispatcher.EventDispatcher): "_executor", "_http_settings", "_intents", + "_is_alive", "_proxy_settings", "_raw_event_consumer", "_rest", @@ -269,6 +270,7 @@ def __init__( self._banner = banner self._closing_event = asyncio.Event() self._closed = False + self._is_alive = False self._executor = executor self._http_settings = http_settings if http_settings is not None else config.HTTPSettings() self._intents = intents @@ -395,6 +397,10 @@ def voice(self) -> voice_.VoiceComponent: def rest(self) -> rest_.RESTClient: return self._rest + @property + def is_alive(self) -> bool: + return self._is_alive + async def close(self, force: bool = True) -> None: """Kill the application by shutting all components down.""" if not self._closing_event.is_set(): @@ -442,6 +448,7 @@ async def handle(name: str, awaitable: typing.Awaitable[typing.Any]) -> None: # Clear out shard map self._shards.clear() + self._is_alive = False await self.dispatch(lifetime_events.StoppedEvent(app=self)) @@ -771,6 +778,7 @@ async def start( ux.check_for_updates(self._http_settings, self._proxy_settings), name="check for package updates", ) + self._is_alive = False requirements_task = asyncio.create_task(self._rest.fetch_gateway_bot(), name="fetch gateway sharding settings") await self.dispatch(lifetime_events.StartingEvent(app=self)) requirements = await requirements_task @@ -836,7 +844,8 @@ async def start( except asyncio.TimeoutError: # If any shards stopped silently, we should close. if any(not s.is_alive for s in self._shards.values()): - _LOGGER.info("one of the shards has been manually shut down (no error), will now shut down") + _LOGGER.warning("one of the shards has been manually shut down (no error), will now shut down") + await self.close() return # new window starts. @@ -844,7 +853,7 @@ async def start( _LOGGER.critical("an exception occurred in one of the started shards during bot startup: %r", ex) raise - started_shards = await aio.all_of( + await aio.all_of( *( self._start_one_shard( activity=activity, @@ -861,13 +870,14 @@ async def start( ) ) - for started_shard in started_shards: - self._shards[started_shard.id] = started_shard - await self.dispatch(lifetime_events.StartedEvent(app=self)) _LOGGER.info("application started successfully in approx %.2f seconds", time.monotonic() - start_time) + def _check_if_alive(self) -> None: + if self._is_alive: + raise errors.ComponentNotRunningError("bot is not running so it cannot be interacted with") + def stream( self, event_type: typing.Type[event_dispatcher.EventT_co], @@ -875,6 +885,7 @@ def stream( timeout: typing.Union[float, int, None], limit: typing.Optional[int] = None, ) -> event_stream.Streamer[event_dispatcher.EventT_co]: + self._check_if_alive() return self._events.stream(event_type, timeout=timeout, limit=limit) def subscribe( @@ -894,6 +905,7 @@ async def wait_for( timeout: typing.Union[float, int, None], predicate: typing.Optional[event_dispatcher.PredicateT[event_dispatcher.EventT_co]] = None, ) -> event_dispatcher.EventT_co: + self._check_if_alive() return await self._events.wait_for(event_type, timeout=timeout, predicate=predicate) async def update_presence( @@ -904,6 +916,7 @@ async def update_presence( activity: undefined.UndefinedNoneOr[presences.Activity] = undefined.UNDEFINED, afk: undefined.UndefinedOr[bool] = undefined.UNDEFINED, ) -> None: + self._check_if_alive() self._validate_activity(activity) coros = [ @@ -949,6 +962,7 @@ async def _start_one_shard( token=self._token, url=url, ) + self._shards[new_shard.id] = new_shard start = time.monotonic() await aio.first_completed(new_shard.start(), self._closing_event.wait()) diff --git a/hikari/impl/shard.py b/hikari/impl/shard.py index 4e1f2d40dd..f32b0b63e3 100644 --- a/hikari/impl/shard.py +++ b/hikari/impl/shard.py @@ -142,6 +142,7 @@ async def send_close(self, *, code: int = 1000, message: bytes = b"") -> bool: # something disconnects, which makes aiohttp just shut down as if we # did it. if not self.sent_close: + self.sent_close = True self.logger.debug("sending close frame with code %s and message %s", int(code), message) try: return await asyncio.wait_for(super().close(code=code, message=message), timeout=5) @@ -500,7 +501,7 @@ async def close(self) -> None: "shard.close() was called and the websocket was still alive -- " "disconnecting immediately with GOING AWAY" ) - await self._ws.close(code=errors.ShardCloseCode.GOING_AWAY, message=b"shard disconnecting") + await self._ws.send_close(code=errors.ShardCloseCode.GOING_AWAY, message=b"shard disconnecting") self._closing.set() finally: self._chunking_rate_limit.close() @@ -516,6 +517,23 @@ async def join(self) -> None: """Wait for this shard to close, if running.""" await self._closed.wait() + async def _send_json( + self, + data: data_binding.JSONObject, + compress: typing.Optional[int] = None, + *, + dumps: aiohttp.typedefs.JSONEncoder = json.dumps, + ) -> None: + await self._total_rate_limit.acquire() + + await self._ws.send_json(data=data, compress=compress, dumps=dumps) # type: ignore[union-attr] + + def _check_if_alive(self) -> None: + if not self.is_alive: + raise errors.ComponentNotRunningError( + f"shard {self._shard_id} is not running so it cannot be interacted with" + ) + async def request_guild_members( self, guild: snowflakes.SnowflakeishOr[guilds.PartialGuild], @@ -526,6 +544,7 @@ async def request_guild_members( users: undefined.UndefinedOr[snowflakes.SnowflakeishSequence[users_.User]] = undefined.UNDEFINED, nonce: undefined.UndefinedOr[str] = undefined.UNDEFINED, ) -> None: + self._check_if_alive() if not query and not limit and not self._intents & intents_.Intents.GUILD_MEMBERS: raise errors.MissingIntentError(intents_.Intents.GUILD_MEMBERS) @@ -554,7 +573,7 @@ async def request_guild_members( payload.put_snowflake_array("user_ids", users) payload.put("nonce", nonce) - await self._ws.send_json({_OP: _REQUEST_GUILD_MEMBERS, _D: payload}) # type: ignore[union-attr] + await self._send_json({_OP: _REQUEST_GUILD_MEMBERS, _D: payload}) async def start(self) -> None: if self._run_task is not None: @@ -582,6 +601,7 @@ async def update_presence( activity: undefined.UndefinedNoneOr[presences.Activity] = undefined.UNDEFINED, status: undefined.UndefinedOr[presences.Status] = undefined.UNDEFINED, ) -> None: + self._check_if_alive() presence_payload = self._serialize_and_store_presence_payload( idle_since=idle_since, afk=afk, @@ -589,7 +609,7 @@ async def update_presence( status=status, ) payload: data_binding.JSONObject = {_OP: _PRESENCE_UPDATE, _D: presence_payload} - await self._ws.send_json(payload) # type: ignore[union-attr] + await self._send_json(payload) async def update_voice_state( self, @@ -599,7 +619,8 @@ async def update_voice_state( self_mute: bool = False, self_deaf: bool = False, ) -> None: - await self._ws.send_json( # type: ignore[union-attr] + self._check_if_alive() + await self._send_json( { _OP: _VOICE_STATE_UPDATE, _D: { @@ -661,7 +682,7 @@ async def _identify(self) -> None: payload[_D]["presence"] = self._serialize_and_store_presence_payload() - await self._ws.send_json(payload) # type: ignore[union-attr] + await self._send_json(payload) async def _heartbeat(self, heartbeat_interval: float) -> bool: # Return True if zombied or should reconnect, false if time to die forever. @@ -734,7 +755,7 @@ async def _poll_events(self) -> typing.Optional[bool]: return None async def _resume(self) -> None: - await self._ws.send_json( # type: ignore[union-attr] + await self._send_json( { _OP: _RESUME, _D: {"token": self._token, "seq": self._seq, "session_id": self._session_id}, @@ -907,11 +928,11 @@ async def _run_once(self) -> bool: return True async def _send_heartbeat(self) -> None: - await self._ws.send_json({_OP: _HEARTBEAT, _D: self._seq}) # type: ignore[union-attr] + await self._send_json({_OP: _HEARTBEAT, _D: self._seq}) self._last_heartbeat_sent = time.monotonic() async def _send_heartbeat_ack(self) -> None: - await self._ws.send_json({_OP: _HEARTBEAT_ACK, _D: None}) # type: ignore[union-attr] + await self._send_json({_OP: _HEARTBEAT_ACK, _D: None}) @staticmethod def _serialize_activity(activity: typing.Optional[presences.Activity]) -> data_binding.JSONish: diff --git a/hikari/traits.py b/hikari/traits.py index bd88fda5d8..71f40672bc 100644 --- a/hikari/traits.py +++ b/hikari/traits.py @@ -447,6 +447,21 @@ class BotAware(RESTAware, ShardAware, EventFactoryAware, DispatcherAware, typing __slots__: typing.Sequence[str] = () + @property + def is_alive(self) -> bool: + """Check whether the bot is running or not. + + This is useful as some functions might raise + `hikari.errors.ComponentNotRunningError` if this is + `builtins.False`. + + Returns + ------- + builtins.bool + Whether the bot is running or not. + """ + raise NotImplementedError + async def join(self, until_close: bool = True) -> None: """Wait indefinitely until the application closes. diff --git a/tests/hikari/impl/test_shard.py b/tests/hikari/impl/test_shard.py index b9c73949c3..96ed88436a 100644 --- a/tests/hikari/impl/test_shard.py +++ b/tests/hikari/impl/test_shard.py @@ -638,29 +638,41 @@ def test_shard_count_property(self, client): client._shard_count = 69 assert client.shard_count == 69 + def test_shard__check_if_alive_when_not_alive(self, client): + with mock.patch.object(shard.GatewayShardImpl, "is_alive", new=False): + with pytest.raises(errors.ComponentNotRunningError): + client._check_if_alive() + + @hikari_test_helpers.assert_does_not_raise(errors.ComponentNotRunningError) + def test_shard__check_if_alive_when_alive(self, client): + with mock.patch.object(shard.GatewayShardImpl, "is_alive", new=True): + client._check_if_alive() + async def test_close_when_closing_set(self, client): client._closing = mock.Mock(is_set=mock.Mock(return_value=True)) - client._ws = mock.Mock() + client._send_close = mock.Mock() client._chunking_rate_limit = mock.Mock() client._total_rate_limit = mock.Mock() await client.close() client._closing.set.assert_not_called() - client._ws.close.assert_not_called() + client._send_close.assert_not_called() client._chunking_rate_limit.close.assert_not_called() client._total_rate_limit.close.assert_not_called() async def test_close_when_closing_not_set(self, client): client._closing = mock.Mock(is_set=mock.Mock(return_value=False)) - client._ws = mock.Mock(close=mock.AsyncMock()) + client._ws = mock.Mock(send_close=mock.AsyncMock()) client._chunking_rate_limit = mock.Mock() client._total_rate_limit = mock.Mock() await client.close() client._closing.set.assert_called_once_with() - client._ws.close.assert_awaited_once_with(code=errors.ShardCloseCode.GOING_AWAY, message=b"shard disconnecting") + client._ws.send_close.assert_awaited_once_with( + code=errors.ShardCloseCode.GOING_AWAY, message=b"shard disconnecting" + ) client._chunking_rate_limit.close.assert_called_once_with() client._total_rate_limit.close.assert_called_once_with() @@ -695,69 +707,85 @@ async def test_join(self, client): client._closed.wait.assert_awaited_once_with() async def test_request_guild_members_when_no_query_and_no_limit_and_GUILD_MEMBERS_not_enabled(self, client): + client._check_if_alive = mock.Mock() client._intents = intents.Intents.GUILD_INTEGRATIONS with pytest.raises(errors.MissingIntentError): await client.request_guild_members(123, query="", limit=0) + client._check_if_alive.assert_called_once_with() async def test_request_guild_members_when_presences_and_GUILD_PRESENCES_not_enabled(self, client): + client._check_if_alive = mock.Mock() client._intents = intents.Intents.GUILD_INTEGRATIONS with pytest.raises(errors.MissingIntentError): await client.request_guild_members(123, query="test", limit=1, include_presences=True) + client._check_if_alive.assert_called_once_with() async def test_request_guild_members_when_presences_false_and_GUILD_PRESENCES_not_enabled(self, client): + client._check_if_alive = mock.Mock() client._intents = intents.Intents.GUILD_INTEGRATIONS - client._ws = mock.Mock(send_json=mock.AsyncMock()) + client._send_json = mock.AsyncMock() await client.request_guild_members(123, query="test", limit=1, include_presences=False) - client._ws.send_json.assert_awaited_once_with( + client._send_json.assert_awaited_once_with( { "op": 8, "d": {"guild_id": "123", "query": "test", "presences": False, "limit": 1}, } ) + client._check_if_alive.assert_called_once_with() @pytest.mark.parametrize("kwargs", [{"query": "some query"}, {"limit": 1}]) async def test_request_guild_members_when_specifiying_users_with_limit_or_query(self, client, kwargs): + client._check_if_alive = mock.Mock() client._intents = intents.Intents.GUILD_INTEGRATIONS with pytest.raises(ValueError, match="Cannot specify limit/query with users"): await client.request_guild_members(123, users=[], **kwargs) + client._check_if_alive.assert_called_once_with() @pytest.mark.parametrize("limit", [-1, 101]) async def test_request_guild_members_when_limit_under_0_or_over_100(self, client, limit): + client._check_if_alive = mock.Mock() client._intents = intents.Intents.ALL with pytest.raises(ValueError, match="'limit' must be between 0 and 100, both inclusive"): await client.request_guild_members(123, limit=limit) + client._check_if_alive.assert_called_once_with() async def test_request_guild_members_when_users_over_100(self, client): + client._check_if_alive = mock.Mock() client._intents = intents.Intents.ALL with pytest.raises(ValueError, match="'users' is limited to 100 users"): await client.request_guild_members(123, users=range(101)) + client._check_if_alive.assert_called_once_with() async def test_request_guild_members_when_nonce_over_32_chars(self, client): + client._check_if_alive = mock.Mock() client._intents = intents.Intents.ALL with pytest.raises(ValueError, match="'nonce' can be no longer than 32 byte characters long."): await client.request_guild_members(123, nonce="x" * 33) + client._check_if_alive.assert_called_once_with() @pytest.mark.parametrize("include_presences", [True, False]) async def test_request_guild_members(self, client, include_presences): client._intents = intents.Intents.ALL - client._ws = mock.Mock(send_json=mock.AsyncMock()) + client._check_if_alive = mock.Mock() + client._send_json = mock.AsyncMock() await client.request_guild_members(123, include_presences=include_presences) - client._ws.send_json.assert_awaited_once_with( + client._send_json.assert_awaited_once_with( { "op": 8, "d": {"guild_id": "123", "query": "", "presences": include_presences, "limit": 0}, } ) + client._check_if_alive.assert_called_once_with() async def test_start_when_already_running(self, client): client._run_task = object() @@ -816,8 +844,8 @@ async def test_start(self, client): wait.assert_awaited_once_with((waiter, run_task), return_when=asyncio.FIRST_COMPLETED) async def test_update_presence(self, client): + client._check_if_alive = mock.Mock() presence_payload = object() - client._ws = mock.Mock(send_json=mock.AsyncMock()) client._serialize_and_store_presence_payload = mock.Mock(return_value=presence_payload) client._send_json = mock.AsyncMock() @@ -828,13 +856,15 @@ async def test_update_presence(self, client): activity=None, ) - client._ws.send_json.assert_awaited_once_with({"op": 3, "d": presence_payload}) + client._send_json.assert_awaited_once_with({"op": 3, "d": presence_payload}) + client._check_if_alive.assert_called_once_with() @pytest.mark.parametrize("channel", [12345, None]) @pytest.mark.parametrize("self_deaf", [True, False]) @pytest.mark.parametrize("self_mute", [True, False]) async def test_update_voice_state(self, client, channel, self_deaf, self_mute): - client._ws = mock.Mock(send_json=mock.AsyncMock()) + client._check_if_alive = mock.Mock() + client._send_json = mock.AsyncMock() payload = { "channel_id": str(channel) if channel is not None else None, "guild_id": "6969420", @@ -844,7 +874,7 @@ async def test_update_voice_state(self, client, channel, self_deaf, self_mute): await client.update_voice_state("6969420", channel, self_mute=self_mute, self_deaf=self_deaf) - client._ws.send_json.assert_awaited_once_with({"op": 4, "d": payload}) + client._send_json.assert_awaited_once_with({"op": 4, "d": payload}) def test_dispatch_when_READY(self, client): client._seq = 0 @@ -921,7 +951,7 @@ async def test__identify(self, client): client._shard_id = 0 client._shard_count = 1 client._serialize_and_store_presence_payload = mock.Mock(return_value={"presence": "payload"}) - client._ws = mock.Mock(send_json=mock.AsyncMock()) + client._send_json = mock.AsyncMock() stack = contextlib.ExitStack() stack.enter_context(mock.patch.object(platform, "system", return_value="Potato PC")) stack.enter_context(mock.patch.object(platform, "architecture", return_value=["ARM64"])) @@ -947,7 +977,7 @@ async def test__identify(self, client): "presence": {"presence": "payload"}, }, } - client._ws.send_json.assert_awaited_once_with(expected_json) + client._send_json.assert_awaited_once_with(expected_json) @hikari_test_helpers.timeout() async def test__heartbeat(self, client): @@ -978,7 +1008,7 @@ async def test__resume(self, client): client._token = "token" client._seq = 123 client._session_id = 456 - client._ws = mock.Mock(send_json=mock.AsyncMock()) + client._send_json = mock.AsyncMock() await client._resume() @@ -986,7 +1016,7 @@ async def test__resume(self, client): "op": 6, "d": {"token": "token", "seq": 123, "session_id": 456}, } - client._ws.send_json.assert_awaited_once_with(expected_json) + client._send_json.assert_awaited_once_with(expected_json) @pytest.mark.skip("TODO") async def test__run(self, client): @@ -997,22 +1027,22 @@ async def test__run_once(self, client): ... async def test__send_heartbeat(self, client): - client._ws = mock.Mock(send_json=mock.AsyncMock()) + client._send_json = mock.AsyncMock() client._last_heartbeat_sent = 0 client._seq = 10 with mock.patch.object(time, "monotonic", return_value=200): await client._send_heartbeat() - client._ws.send_json.assert_awaited_once_with({"op": 1, "d": 10}) + client._send_json.assert_awaited_once_with({"op": 1, "d": 10}) assert client._last_heartbeat_sent == 200 async def test__send_heartbeat_ack(self, client): - client._ws = mock.Mock(send_json=mock.AsyncMock()) + client._send_json = mock.AsyncMock() await client._send_heartbeat_ack() - client._ws.send_json.assert_awaited_once_with({"op": 11, "d": None}) + client._send_json.assert_awaited_once_with({"op": 11, "d": None}) def test__serialize_activity_when_activity_is_None(self, client): assert client._serialize_activity(None) is None diff --git a/tests/hikari/test_errors.py b/tests/hikari/test_errors.py index 0be3c8bc5b..137fbb4136 100644 --- a/tests/hikari/test_errors.py +++ b/tests/hikari/test_errors.py @@ -33,6 +33,15 @@ def test_is_standard_property(self, code, expected): assert errors.ShardCloseCode(code).is_standard is expected +class TestComponentNotRunningError: + @pytest.fixture() + def error(self): + return errors.ComponentNotRunningError("some reason") + + def test_str(self, error): + assert str(error) == "some reason" + + class TestGatewayError: @pytest.fixture() def error(self):