Skip to content

Commit

Permalink
Several bugfixes (#382)
Browse files Browse the repository at this point in the history
- New error `ComponentNotRunningError` that will be raised when trying to interact with a component that is not running
  - Actually closing when one of the shards has been manually shut down
  - `sent_close` not being set when calling `send_close`
  - Unused ratelimiter now hoocked up to prevent getting disconnected on too many requests
  - `_shards` is now populated as soon as the shard object is created to be able to be used through `bot.x` on shard events
  • Loading branch information
davfsa authored Dec 2, 2020
1 parent a125515 commit 059b10c
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 34 deletions.
12 changes: 12 additions & 0 deletions hikari/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"HikariError",
"HikariWarning",
"HikariInterrupt",
"ComponentNotRunningError",
"NotFoundError",
"RateLimitedError",
"RateLimitTooLongError",
Expand Down Expand Up @@ -100,6 +101,17 @@ class HikariInterrupt(KeyboardInterrupt, HikariError):
"""The signal name that was raised."""


@attr.s(auto_exc=True, slots=True, repr=False, weakref_slot=False)
class ComponentNotRunningError(HikariError):
"""An exception thrown if trying to interact with a component that is not running."""

reason: str = attr.ib()
"""A string to explain the issue."""

def __str__(self) -> str:
return self.reason


@attr.s(auto_exc=True, slots=True, repr=False, weakref_slot=False)
class GatewayError(HikariError):
"""A base exception type for anything that can be thrown by the Gateway."""
Expand Down
2 changes: 1 addition & 1 deletion hikari/events/lifetime_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class StoppedEvent(base_events.Event):
closed within a coroutine function.
!!! warning
The application will not proceed to leave the `_rest.run` call until all
The application will not proceed to leave the `bot.run` call until all
event handlers for this event have completed/terminated. This
prevents the risk of race conditions occurring where a script may
terminate the process before a callback can occur.
Expand Down
24 changes: 19 additions & 5 deletions hikari/impl/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ class BotApp(traits.BotAware, event_dispatcher.EventDispatcher):
"_executor",
"_http_settings",
"_intents",
"_is_alive",
"_proxy_settings",
"_raw_event_consumer",
"_rest",
Expand Down Expand Up @@ -269,6 +270,7 @@ def __init__(
self._banner = banner
self._closing_event = asyncio.Event()
self._closed = False
self._is_alive = False
self._executor = executor
self._http_settings = http_settings if http_settings is not None else config.HTTPSettings()
self._intents = intents
Expand Down Expand Up @@ -395,6 +397,10 @@ def voice(self) -> voice_.VoiceComponent:
def rest(self) -> rest_.RESTClient:
return self._rest

@property
def is_alive(self) -> bool:
return self._is_alive

async def close(self, force: bool = True) -> None:
"""Kill the application by shutting all components down."""
if not self._closing_event.is_set():
Expand Down Expand Up @@ -442,6 +448,7 @@ async def handle(name: str, awaitable: typing.Awaitable[typing.Any]) -> None:

# Clear out shard map
self._shards.clear()
self._is_alive = False

await self.dispatch(lifetime_events.StoppedEvent(app=self))

Expand Down Expand Up @@ -771,6 +778,7 @@ async def start(
ux.check_for_updates(self._http_settings, self._proxy_settings),
name="check for package updates",
)
self._is_alive = False
requirements_task = asyncio.create_task(self._rest.fetch_gateway_bot(), name="fetch gateway sharding settings")
await self.dispatch(lifetime_events.StartingEvent(app=self))
requirements = await requirements_task
Expand Down Expand Up @@ -836,15 +844,16 @@ async def start(
except asyncio.TimeoutError:
# If any shards stopped silently, we should close.
if any(not s.is_alive for s in self._shards.values()):
_LOGGER.info("one of the shards has been manually shut down (no error), will now shut down")
_LOGGER.warning("one of the shards has been manually shut down (no error), will now shut down")
await self.close()
return
# new window starts.

except Exception as ex:
_LOGGER.critical("an exception occurred in one of the started shards during bot startup: %r", ex)
raise

started_shards = await aio.all_of(
await aio.all_of(
*(
self._start_one_shard(
activity=activity,
Expand All @@ -861,20 +870,22 @@ async def start(
)
)

for started_shard in started_shards:
self._shards[started_shard.id] = started_shard

await self.dispatch(lifetime_events.StartedEvent(app=self))

_LOGGER.info("application started successfully in approx %.2f seconds", time.monotonic() - start_time)

def _check_if_alive(self) -> None:
if self._is_alive:
raise errors.ComponentNotRunningError("bot is not running so it cannot be interacted with")

def stream(
self,
event_type: typing.Type[event_dispatcher.EventT_co],
/,
timeout: typing.Union[float, int, None],
limit: typing.Optional[int] = None,
) -> event_stream.Streamer[event_dispatcher.EventT_co]:
self._check_if_alive()
return self._events.stream(event_type, timeout=timeout, limit=limit)

def subscribe(
Expand All @@ -894,6 +905,7 @@ async def wait_for(
timeout: typing.Union[float, int, None],
predicate: typing.Optional[event_dispatcher.PredicateT[event_dispatcher.EventT_co]] = None,
) -> event_dispatcher.EventT_co:
self._check_if_alive()
return await self._events.wait_for(event_type, timeout=timeout, predicate=predicate)

async def update_presence(
Expand All @@ -904,6 +916,7 @@ async def update_presence(
activity: undefined.UndefinedNoneOr[presences.Activity] = undefined.UNDEFINED,
afk: undefined.UndefinedOr[bool] = undefined.UNDEFINED,
) -> None:
self._check_if_alive()
self._validate_activity(activity)

coros = [
Expand Down Expand Up @@ -949,6 +962,7 @@ async def _start_one_shard(
token=self._token,
url=url,
)
self._shards[new_shard.id] = new_shard

start = time.monotonic()
await aio.first_completed(new_shard.start(), self._closing_event.wait())
Expand Down
37 changes: 29 additions & 8 deletions hikari/impl/shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ async def send_close(self, *, code: int = 1000, message: bytes = b"") -> bool:
# something disconnects, which makes aiohttp just shut down as if we
# did it.
if not self.sent_close:
self.sent_close = True
self.logger.debug("sending close frame with code %s and message %s", int(code), message)
try:
return await asyncio.wait_for(super().close(code=code, message=message), timeout=5)
Expand Down Expand Up @@ -500,7 +501,7 @@ async def close(self) -> None:
"shard.close() was called and the websocket was still alive -- "
"disconnecting immediately with GOING AWAY"
)
await self._ws.close(code=errors.ShardCloseCode.GOING_AWAY, message=b"shard disconnecting")
await self._ws.send_close(code=errors.ShardCloseCode.GOING_AWAY, message=b"shard disconnecting")
self._closing.set()
finally:
self._chunking_rate_limit.close()
Expand All @@ -516,6 +517,23 @@ async def join(self) -> None:
"""Wait for this shard to close, if running."""
await self._closed.wait()

async def _send_json(
self,
data: data_binding.JSONObject,
compress: typing.Optional[int] = None,
*,
dumps: aiohttp.typedefs.JSONEncoder = json.dumps,
) -> None:
await self._total_rate_limit.acquire()

await self._ws.send_json(data=data, compress=compress, dumps=dumps) # type: ignore[union-attr]

def _check_if_alive(self) -> None:
if not self.is_alive:
raise errors.ComponentNotRunningError(
f"shard {self._shard_id} is not running so it cannot be interacted with"
)

async def request_guild_members(
self,
guild: snowflakes.SnowflakeishOr[guilds.PartialGuild],
Expand All @@ -526,6 +544,7 @@ async def request_guild_members(
users: undefined.UndefinedOr[snowflakes.SnowflakeishSequence[users_.User]] = undefined.UNDEFINED,
nonce: undefined.UndefinedOr[str] = undefined.UNDEFINED,
) -> None:
self._check_if_alive()
if not query and not limit and not self._intents & intents_.Intents.GUILD_MEMBERS:
raise errors.MissingIntentError(intents_.Intents.GUILD_MEMBERS)

Expand Down Expand Up @@ -554,7 +573,7 @@ async def request_guild_members(
payload.put_snowflake_array("user_ids", users)
payload.put("nonce", nonce)

await self._ws.send_json({_OP: _REQUEST_GUILD_MEMBERS, _D: payload}) # type: ignore[union-attr]
await self._send_json({_OP: _REQUEST_GUILD_MEMBERS, _D: payload})

async def start(self) -> None:
if self._run_task is not None:
Expand Down Expand Up @@ -582,14 +601,15 @@ async def update_presence(
activity: undefined.UndefinedNoneOr[presences.Activity] = undefined.UNDEFINED,
status: undefined.UndefinedOr[presences.Status] = undefined.UNDEFINED,
) -> None:
self._check_if_alive()
presence_payload = self._serialize_and_store_presence_payload(
idle_since=idle_since,
afk=afk,
activity=activity,
status=status,
)
payload: data_binding.JSONObject = {_OP: _PRESENCE_UPDATE, _D: presence_payload}
await self._ws.send_json(payload) # type: ignore[union-attr]
await self._send_json(payload)

async def update_voice_state(
self,
Expand All @@ -599,7 +619,8 @@ async def update_voice_state(
self_mute: bool = False,
self_deaf: bool = False,
) -> None:
await self._ws.send_json( # type: ignore[union-attr]
self._check_if_alive()
await self._send_json(
{
_OP: _VOICE_STATE_UPDATE,
_D: {
Expand Down Expand Up @@ -661,7 +682,7 @@ async def _identify(self) -> None:

payload[_D]["presence"] = self._serialize_and_store_presence_payload()

await self._ws.send_json(payload) # type: ignore[union-attr]
await self._send_json(payload)

async def _heartbeat(self, heartbeat_interval: float) -> bool:
# Return True if zombied or should reconnect, false if time to die forever.
Expand Down Expand Up @@ -734,7 +755,7 @@ async def _poll_events(self) -> typing.Optional[bool]:
return None

async def _resume(self) -> None:
await self._ws.send_json( # type: ignore[union-attr]
await self._send_json(
{
_OP: _RESUME,
_D: {"token": self._token, "seq": self._seq, "session_id": self._session_id},
Expand Down Expand Up @@ -907,11 +928,11 @@ async def _run_once(self) -> bool:
return True

async def _send_heartbeat(self) -> None:
await self._ws.send_json({_OP: _HEARTBEAT, _D: self._seq}) # type: ignore[union-attr]
await self._send_json({_OP: _HEARTBEAT, _D: self._seq})
self._last_heartbeat_sent = time.monotonic()

async def _send_heartbeat_ack(self) -> None:
await self._ws.send_json({_OP: _HEARTBEAT_ACK, _D: None}) # type: ignore[union-attr]
await self._send_json({_OP: _HEARTBEAT_ACK, _D: None})

@staticmethod
def _serialize_activity(activity: typing.Optional[presences.Activity]) -> data_binding.JSONish:
Expand Down
15 changes: 15 additions & 0 deletions hikari/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,21 @@ class BotAware(RESTAware, ShardAware, EventFactoryAware, DispatcherAware, typing

__slots__: typing.Sequence[str] = ()

@property
def is_alive(self) -> bool:
"""Check whether the bot is running or not.
This is useful as some functions might raise
`hikari.errors.ComponentNotRunningError` if this is
`builtins.False`.
Returns
-------
builtins.bool
Whether the bot is running or not.
"""
raise NotImplementedError

async def join(self, until_close: bool = True) -> None:
"""Wait indefinitely until the application closes.
Expand Down
Loading

0 comments on commit 059b10c

Please sign in to comment.