
Merge pull request #102 from nekokatt/bugfix/closures
Bugfix/closures
Nekokatt authored Aug 30, 2020
2 parents 3e9f15f + d818f30 commit 20e24f9
Showing 10 changed files with 491 additions and 313 deletions.
107 changes: 61 additions & 46 deletions hikari/impl/bot.py
@@ -47,7 +47,6 @@
from hikari.events import lifetime_events
from hikari.impl import entity_factory as entity_factory_impl
from hikari.impl import event_factory as event_factory_impl
from hikari.impl import rate_limits
from hikari.impl import rest as rest_client_impl
from hikari.impl import shard as gateway_shard_impl
from hikari.impl import stateful_cache as cache_impl
@@ -201,20 +200,20 @@ class BotApp(

__slots__: typing.Sequence[str] = (
"_cache",
"_guild_chunker",
"_connector_factory",
"_debug",
"_entity_factory",
"_event_manager",
"_event_factory",
"_executor",
"_global_ratelimit",
"_guild_chunker",
"_http_settings",
"_initial_activity",
"_initial_idle_since",
"_initial_is_afk",
"_initial_status",
"_intents",
"_has_aborted",
"_large_threshold",
"_max_concurrency",
"_proxy_settings",
@@ -273,13 +272,13 @@ def __init__(
self._entity_factory = entity_factory_impl.EntityFactoryImpl(app=self)
self._event_factory = event_factory_impl.EventFactoryImpl(app=self)
self._executor = executor
self._global_ratelimit = rate_limits.ManualRateLimiter()
self._http_settings = config.HTTPSettings() if http_settings is None else http_settings
self._initial_activity = initial_activity
self._initial_idle_since = initial_idle_since
self._initial_is_afk = initial_is_afk
self._initial_status = initial_status
self._intents = intents
self._has_aborted = False
self._large_threshold = large_threshold
self._max_concurrency = 1
self._proxy_settings = config.ProxySettings() if proxy_settings is None else proxy_settings
@@ -508,7 +507,11 @@ async def start(self) -> None:
self._tasks.clear()
self._shard_gather_task = None

await self._init()
try:
await self._init()
except Exception:
await self.close()
raise

self._request_close_event.clear()

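The first fix wraps the _init() call so that a failure during start-up releases whatever was already allocated before re-raising. The shape of that fix, reduced to a toy app (a sketch; the class and its behaviour are illustrative, not hikari's):

import asyncio

class App:
    async def _init(self) -> None:
        raise RuntimeError("simulated start-up failure")

    async def close(self) -> None:
        # Stand-in for the real teardown.
        print("resources released")

    async def start(self) -> None:
        try:
            await self._init()
        except Exception:
            # Release whatever _init() managed to allocate, then let the
            # caller see the original error.
            await self.close()
            raise

async def main() -> None:
    try:
        await App().start()
    except RuntimeError as ex:
        print(f"start failed: {ex}")

asyncio.run(main())
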
@@ -537,7 +540,15 @@ async def start(self) -> None:
window[shard_id] = asyncio.create_task(shard_obj.start(), name=f"start gateway shard {shard_id}")

# Wait for the group to start.
await asyncio.gather(*window.values())
gatherer = asyncio.gather(*window.values())
waiter = asyncio.create_task(self._request_close_event.wait(), name="listen for bot closure events")

await asyncio.wait((gatherer, waiter), return_when=asyncio.FIRST_COMPLETED)

if not waiter.done():
waiter.cancel()
else:
gatherer.cancel()

# Store the keep-alive tasks and continue.
for shard_id, start_task in window.items():
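
The new start-up logic no longer blocks unconditionally on the shard gather: it races the gather against the request-close event and cancels whichever side loses, so a shutdown request can interrupt a stuck start-up. A minimal, self-contained sketch of the same pattern (worker names and timings are illustrative, not from hikari):

import asyncio

async def worker(n: int) -> None:
    # Stand-in for shard_obj.start(): pretend to perform a slow handshake.
    await asyncio.sleep(n / 10)

async def start_group(close_event: asyncio.Event) -> None:
    window = {n: asyncio.create_task(worker(n), name=f"start worker {n}") for n in (1, 2, 3)}

    gatherer = asyncio.gather(*window.values())
    waiter = asyncio.create_task(close_event.wait(), name="listen for closure events")

    # Resume as soon as every worker has started OR a close was requested.
    await asyncio.wait((gatherer, waiter), return_when=asyncio.FIRST_COMPLETED)

    if not waiter.done():
        # Normal path: everything started, so stop listening for closure.
        waiter.cancel()
    else:
        # A close was requested mid-start-up: abandon the remaining workers.
        gatherer.cancel()

async def main() -> None:
    await start_group(asyncio.Event())

asyncio.run(main())
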
@@ -546,6 +557,7 @@ async def start(self) -> None:
finally:
if len(self._tasks) != len(self._shards):
_LOGGER.warning("application was aborted midway through initialization, so never managed to start")
await self.close()
raise errors.GatewayClientClosedError("Client was aborted midway through initialization")

finish_time = date.monotonic()
@@ -609,32 +621,23 @@ def dispatch(self, event: base_events.Event) -> asyncio.Future[typing.Any]:
return self.dispatcher.dispatch(event)

async def close(self) -> None:
"""Request that all shards disconnect and the application shuts down.
"""Immediately destroy all shards that are running and stop."""
self._request_close_event.set()

This will close all shards that are running, and then close any
REST components and connectors.
"""
self._guild_chunker.close()
# Prevent calling this multiple times.
if self._has_aborted:
return

try:
# This way if we cancel the stopping task, we still shut down properly.
self._request_close_event.set()
_LOGGER.info("stopping %s shard(s)", len(self._shards))

try:
if self._shards:
await self.dispatch(lifetime_events.StoppingEvent(app=self))
await self._abort_shards()
finally:
# The starting event occurs before the bot starts, regardless of if
# it had started or not, so it seems sensible stopped event has the
# same semantics.
self._tasks.clear()
await self.dispatch(lifetime_events.StoppedEvent(app=self))
finally:
await self._rest.close()
await self._connector_factory.close()
self._global_ratelimit.close()
self._has_aborted = True
self._guild_chunker.close()
await self.dispatch(lifetime_events.StoppingEvent(app=self))
await self._abort_shards()
self._tasks.clear()
await self.dispatch(lifetime_events.StoppedEvent(app=self))
await self._rest.close()
await self._connector_factory.close()
self._shard_gather_task = None
self._request_close_event.clear()

def run(
self,
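
close() is now guarded by the _has_aborted flag, so a second call (for example when the signal handler's close request races the lifecycle gatherer's own teardown) returns immediately instead of tearing resources down twice. The guard idiom in isolation, sketched with hypothetical names:

import asyncio

class Service:
    def __init__(self) -> None:
        self._request_close_event = asyncio.Event()
        self._has_aborted = False

    async def close(self) -> None:
        # Always signal closure first so anything racing on the event
        # (such as a start-up waiter) wakes up immediately.
        self._request_close_event.set()

        # The guard makes close() idempotent: only the first caller
        # performs the actual teardown.
        if self._has_aborted:
            return
        self._has_aborted = True

        # ... abort shards, dispatch stopping/stopped events, close HTTP ...
        self._request_close_event.clear()

async def main() -> None:
    service = Service()
    await service.close()
    await service.close()  # no-op: the flag short-circuits the teardown

asyncio.run(main())
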
@@ -744,7 +747,7 @@ def run(

def die() -> None:
_LOGGER.info("received signal to shut down client")
asyncio.ensure_future(self.close())
self._request_close_event.set()

for signum in kill_signals:
# Windows is dumb and doesn't support signals properly.
@@ -758,30 +761,33 @@ def die() -> None:
finally:
loop.run_until_complete(self.join())
except errors.GatewayClientClosedError as ex:
_LOGGER.info(str(ex))
_LOGGER.info("client closed with reason: %s", ex)
finally:
for signum in kill_signals:
# Windows is dumb and doesn't support signals properly.
with contextlib.suppress(NotImplementedError):
loop.remove_signal_handler(signum)

if finalize_loop_on_close:
_LOGGER.debug("closing asyncgens for event loop %s", loop)
loop.run_until_complete(loop.shutdown_asyncgens())
remaining_tasks = [t for t in asyncio.all_tasks(loop) if not t.done()]

remaining_tasks = asyncio.all_tasks(loop)
if remaining_tasks:
_LOGGER.warning("forcefully stopping %s remaining tasks", len(remaining_tasks))
_LOGGER.debug("forcefully stopping %s remaining tasks", len(remaining_tasks))

for task in remaining_tasks:
task.cancel()
loop.run_until_complete(asyncio.gather(*remaining_tasks, return_exceptions=True))

# Don't warn that these were never retrieved.
with contextlib.suppress(asyncio.InvalidStateError):
task.exception()

for task in remaining_tasks:
if not task.cancelled():
exception = task.exception()
if exception is not None:
_LOGGER.warning("unhandled exception during shutdown", exc_info=exception)
else:
_LOGGER.debug("no tasks are running, congratulations on writing a tidy application")

_LOGGER.debug("closing asyncgens for event loop %s", loop)
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()

async def join(self) -> None:
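
The finaliser in run() now cancels only tasks that are still pending, drains them with return_exceptions=True so one failure cannot abort the cleanup of the others, and then logs any task that died with a genuine exception rather than a cancellation. Roughly the same sequence as a standalone helper (a sketch, not hikari's public API); call it after the loop has finished its main work:

import asyncio
import logging

_LOGGER = logging.getLogger(__name__)

def finalize_loop(loop: asyncio.AbstractEventLoop) -> None:
    # Only tasks that have not already finished need cancelling.
    remaining_tasks = [t for t in asyncio.all_tasks(loop) if not t.done()]

    if remaining_tasks:
        _LOGGER.debug("forcefully stopping %s remaining tasks", len(remaining_tasks))

        for task in remaining_tasks:
            task.cancel()

        # Drain the cancellations; return_exceptions=True swallows the
        # CancelledErrors so the gather itself cannot blow up here.
        loop.run_until_complete(asyncio.gather(*remaining_tasks, return_exceptions=True))

        # Report anything that died with a real error, not a cancellation.
        for task in remaining_tasks:
            if not task.cancelled():
                exception = task.exception()
                if exception is not None:
                    _LOGGER.warning("unhandled exception during shutdown", exc_info=exception)

    _LOGGER.debug("closing asyncgens for event loop %s", loop)
    loop.run_until_complete(loop.shutdown_asyncgens())
    loop.close()
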
@@ -939,9 +945,8 @@ def _max_concurrency_chunker(self) -> typing.Iterator[typing.Iterator[int]]:
async def _abort_shards(self) -> None:
"""Close all shards and wait for them to terminate."""
for shard_id in self._shards:
if self._shards[shard_id].is_alive:
_LOGGER.debug("stopping shard %s", shard_id)
await self._shards[shard_id].close()
_LOGGER.debug("stopping shard %s", shard_id)
await self._shards[shard_id].close()
await asyncio.gather(*self._tasks.values(), return_exceptions=True)

async def _gather_shard_lifecycles(self) -> None:
@@ -950,12 +955,22 @@ async def _gather_shard_lifecycles(self) -> None:
Ensure shards are requested to close before the coroutine function
completes.
"""
_LOGGER.debug("gathering shards")
gatherer = asyncio.gather(*self._tasks.values())
waiter = asyncio.create_task(self._request_close_event.wait(), name="listen for bot closure events")

try:
_LOGGER.debug("gathering shards")
await asyncio.gather(*self._tasks.values())
await asyncio.wait([gatherer, waiter], return_when=asyncio.FIRST_COMPLETED)

if not waiter.done():
waiter.cancel()
finally:
_LOGGER.debug("gather terminated, shutting down shard(s)")
await asyncio.shield(self.close())
aborter = asyncio.shield(self.close())
try:
await gatherer
finally:
await aborter

async def _shard_management_lifecycle(self) -> None:
"""Start all shards and then wait for them to finish."""
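
A recurring idea in the bot.py changes above is shielding the teardown: _gather_shard_lifecycles() wraps self.close() in asyncio.shield() so that, even if the lifecycle coroutine is itself cancelled while draining the gather, the close sequence still runs to completion. A compact sketch of that shape (close() here is a stand-in, not the real method):

import asyncio

async def close() -> None:
    # Stand-in for BotApp.close(): abort shards, close clients, and so on.
    await asyncio.sleep(0)

async def run_until_closed(tasks: "list[asyncio.Task]", close_event: asyncio.Event) -> None:
    gatherer = asyncio.gather(*tasks)
    waiter = asyncio.create_task(close_event.wait(), name="listen for bot closure events")

    try:
        await asyncio.wait([gatherer, waiter], return_when=asyncio.FIRST_COMPLETED)
        if not waiter.done():
            waiter.cancel()
    finally:
        # The shield keeps close() running to completion even if this
        # coroutine is cancelled while it drains the gather below.
        aborter = asyncio.shield(close())
        try:
            await gatherer
        finally:
            await aborter

async def main() -> None:
    tasks = [asyncio.create_task(asyncio.sleep(0.1)) for _ in range(3)]
    await run_until_closed(tasks, asyncio.Event())

asyncio.run(main())
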
6 changes: 5 additions & 1 deletion hikari/impl/buckets.py
@@ -382,7 +382,7 @@ def start(self, poll_period: float = _POLL_PERIOD, expire_after: float = _EXPIRE
as the rate limit has reset. Defaults to `10` seconds.
"""
if not self.gc_task:
self.gc_task = asyncio.get_running_loop().create_task(self.gc(poll_period, expire_after))
self.gc_task = asyncio.create_task(self.gc(poll_period, expire_after))

def close(self) -> None:
"""Close the garbage collector and kill any tasks waiting on ratelimits.
@@ -396,6 +396,10 @@ def close(self) -> None:
self.real_hashes_to_buckets.clear()
self.routes_to_hashes.clear()

if self.gc_task is not None:
self.gc_task.cancel()
self.gc_task = None

# Ignore docstring not starting in an imperative mood
async def gc(self, poll_period: float, expire_after: float) -> None: # noqa: D401
"""The garbage collector loop.
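
The bucket manager change follows a simple contract for owned background tasks: spawn them with asyncio.create_task() (equivalent to asyncio.get_running_loop().create_task(), just shorter), and in close() cancel the task and drop the reference so a later start() can spawn a fresh one. The contract in isolation (class and method names are made up for the sketch):

import asyncio
import typing

class GarbageCollector:
    def __init__(self) -> None:
        self.gc_task: typing.Optional[asyncio.Task] = None

    def start(self, poll_period: float = 20.0) -> None:
        if not self.gc_task:
            # Requires a running loop, which also guarantees the task is
            # bound to the loop that is actually in use.
            self.gc_task = asyncio.create_task(self.gc(poll_period))

    def close(self) -> None:
        if self.gc_task is not None:
            self.gc_task.cancel()
            # Drop the reference so a later start() spawns a fresh task.
            self.gc_task = None

    async def gc(self, poll_period: float) -> None:
        while True:
            await asyncio.sleep(poll_period)
            # ... expire stale buckets here ...

async def main() -> None:
    collector = GarbageCollector()
    collector.start(poll_period=0.1)
    await asyncio.sleep(0.3)
    collector.close()

asyncio.run(main())
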
1 change: 1 addition & 0 deletions hikari/impl/rate_limits.py
@@ -129,6 +129,7 @@ def close(self) -> None:

if self.throttle_task is not None:
self.throttle_task.cancel()
self.throttle_task = None

failed_tasks = 0
while self.queue:
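
The one-line rate-limiter fix completes the same contract: once close() cancels the throttle task, the attribute must be reset, otherwise a later close() or a restart of the limiter would poke at a task that has already been cancelled. Sketched with an illustrative class name:

import asyncio
import typing

class ManualLimiter:
    def __init__(self) -> None:
        self.throttle_task: typing.Optional[asyncio.Task] = None

    def close(self) -> None:
        if self.throttle_task is not None:
            self.throttle_task.cancel()
            # Without this reset, a second close() or a restarted throttle
            # would operate on an already-cancelled, finished task.
            self.throttle_task = None
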
6 changes: 4 additions & 2 deletions hikari/impl/rest.py
@@ -96,6 +96,8 @@ class BasicLazyCachedTCPConnectorFactory(rest_api.ConnectorFactory):

def __init__(self, **kwargs: typing.Any) -> None:
self.connector: typing.Optional[aiohttp.TCPConnector] = None
kwargs.setdefault("force_close", True)
kwargs.setdefault("enable_cleanup_closed", True)
self.connector_kwargs = kwargs

async def close(self) -> None:
@@ -424,6 +426,7 @@ async def close(self) -> None:
"""Close the HTTP client and any open HTTP connections."""
if self._client_session is not None:
await self._client_session.close()
await self._connector_factory.close()
self.global_rate_limit.close()
self.buckets.close()
self._closed_event.set()
@@ -444,9 +447,8 @@ def _acquire_client_session(self) -> aiohttp.ClientSession:
if self._client_session is None:
self._closed_event.clear()
self._client_session = aiohttp.ClientSession(
# Should not need a lock, since we don't technically await anything.
connector=self._connector_factory.acquire(),
connector_owner=self._connector_owner,
connector_owner=False,
version=aiohttp.HttpVersion11,
timeout=aiohttp.ClientTimeout(
total=self._http_settings.timeouts.total,
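
The REST changes make connector ownership explicit: the lazy factory now defaults its TCPConnector to force_close=True and enable_cleanup_closed=True so closed transports are reaped eagerly, the session is always created with connector_owner=False so closing the session cannot close the shared connector, and REST close() also closes the factory. A minimal sketch of that ownership split using plain aiohttp (the factory class here is illustrative, not hikari's):

import asyncio
import typing

import aiohttp

class ConnectorFactory:
    def __init__(self, **kwargs: typing.Any) -> None:
        self.connector: typing.Optional[aiohttp.TCPConnector] = None
        # Close transports eagerly and reap any that linger after SSL
        # shutdown, so application teardown does not leak connections.
        kwargs.setdefault("force_close", True)
        kwargs.setdefault("enable_cleanup_closed", True)
        self.connector_kwargs = kwargs

    def acquire(self) -> aiohttp.TCPConnector:
        if self.connector is None:
            self.connector = aiohttp.TCPConnector(**self.connector_kwargs)
        return self.connector

    async def close(self) -> None:
        if self.connector is not None:
            await self.connector.close()
            self.connector = None

async def main() -> None:
    factory = ConnectorFactory()
    # connector_owner=False: closing the session leaves the shared
    # connector alive; only the factory decides when it dies.
    session = aiohttp.ClientSession(connector=factory.acquire(), connector_owner=False)
    await session.close()
    await factory.close()

asyncio.run(main())
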
(The remaining 6 changed files in this commit are not shown here.)
