From ad7f46f86cee2fb690494d1527eb6f8a00446299 Mon Sep 17 00:00:00 2001 From: Nekokatt Date: Mon, 24 Aug 2020 03:15:14 +0100 Subject: [PATCH] Bugfix/77 wait for race (#82) * Optimised raw event dispatching to uncover bug. Looks like, at least on my machine, asyncio immediately invokes anything you await rather than switching to another task on the queue first unless the call does raw IO. I have confirmed this with Epoll, Poll and Select selector implementations on a non-debug asyncio SelectorEventLoop implementation. This means that the bulk of dispatching an event would currently occur as soon as the event is dispatched rather than after another task runs, which could lead to immediate slowdown if other tasks are queued. Switching to sync dispatching and using create task to invoke the callback management "later" seems to speed up this implementation significantly and allows other race conditions we have not accounted for properly as part of #77 to be detectable with test scripts that saturate the event loop. * Updated CLi script to show OS type as well. * Added code to allow debugging of asyncio loop blocking incidents. * Fixes #77 dispatcher wait_for race condition. * Removed async predicates for wait_for, removing last parts of race condition hopefully. * Fixes #77 dispatcher wait_for race condition. --- hikari/api/event_dispatcher.py | 18 ++-- hikari/cli.py | 1 + hikari/events/base_events.py | 37 ++++---- hikari/impl/bot.py | 60 ++++++++++--- hikari/impl/event_manager_base.py | 93 +++++++------------- hikari/impl/shard.py | 23 ++--- hikari/utilities/aio.py | 47 +++++++++- tests/hikari/events/test_base_events.py | 29 ++++-- tests/hikari/impl/test_event_manager_base.py | 13 +-- tests/hikari/impl/test_shard.py | 39 +++----- 10 files changed, 204 insertions(+), 156 deletions(-) diff --git a/hikari/api/event_dispatcher.py b/hikari/api/event_dispatcher.py index 44c194f0a3..bbef968b05 100644 --- a/hikari/api/event_dispatcher.py +++ b/hikari/api/event_dispatcher.py @@ -35,8 +35,8 @@ EventT_co = typing.TypeVar("EventT_co", bound=base_events.Event, covariant=True) EventT_inv = typing.TypeVar("EventT_inv", bound=base_events.Event) - PredicateT = typing.Callable[[EventT_co], typing.Union[bool, typing.Coroutine[typing.Any, typing.Any, bool]]] - AsyncCallbackT = typing.Callable[[EventT_inv], typing.Coroutine[typing.Any, typing.Any, None]] + PredicateT = typing.Callable[[EventT_co], bool] + CallbackT = typing.Callable[[EventT_inv], typing.Coroutine[typing.Any, typing.Any, None]] class EventDispatcher(abc.ABC): @@ -134,9 +134,7 @@ async def on_everyone_mentioned(event): # For the sake of UX, I will check this at runtime instead and let the # user use a static type checker. @abc.abstractmethod - def subscribe( - self, event_type: typing.Type[typing.Any], callback: AsyncCallbackT[typing.Any] - ) -> AsyncCallbackT[typing.Any]: + def subscribe(self, event_type: typing.Type[typing.Any], callback: CallbackT[typing.Any]) -> CallbackT[typing.Any]: """Subscribe a given callback to a given event type. Parameters @@ -180,7 +178,7 @@ async def on_message(event): # For the sake of UX, I will check this at runtime instead and let the # user use a static type checker. @abc.abstractmethod - def unsubscribe(self, event_type: typing.Type[typing.Any], callback: AsyncCallbackT[typing.Any]) -> None: + def unsubscribe(self, event_type: typing.Type[typing.Any], callback: CallbackT[typing.Any]) -> None: """Unsubscribe a given callback from a given event type, if present. Parameters @@ -210,7 +208,7 @@ async def on_message(event): @abc.abstractmethod def get_listeners( self, event_type: typing.Type[EventT_co], *, polymorphic: bool = True, - ) -> typing.Collection[AsyncCallbackT[EventT_co]]: + ) -> typing.Collection[CallbackT[EventT_co]]: """Get the listeners for a given event type, if there are any. Parameters @@ -240,7 +238,7 @@ def get_listeners( @abc.abstractmethod def listen( self, event_type: typing.Optional[typing.Type[EventT_co]] = None, - ) -> typing.Callable[[AsyncCallbackT[EventT_co]], AsyncCallbackT[EventT_co]]: + ) -> typing.Callable[[CallbackT[EventT_co]], CallbackT[EventT_co]]: """Generate a decorator to subscribe a callback to an event type. This is a second-order decorator. @@ -285,11 +283,13 @@ async def wait_for( The event type to listen for. This will listen for subclasses of this type additionally. predicate - A function or coroutine taking the event as the single parameter. + A function taking the event as the single parameter. This should return `builtins.True` if the event is one you want to return, or `builtins.False` if the event should not be returned. If left as `None` (the default), then the first matching event type that the bot receives (or any subtype) will be the one returned. + + ASYNC PREDICATES ARE NOT SUPPORTED. timeout : typing.Optional[builtins.float or builtins.int] The amount of time to wait before raising an `asyncio.TimeoutError` and giving up instead. This is measured in seconds. If diff --git a/hikari/cli.py b/hikari/cli.py index 59a970f7f7..08070c0f75 100644 --- a/hikari/cli.py +++ b/hikari/cli.py @@ -48,3 +48,4 @@ def main() -> None: sys.stderr.write(f"hikari v{version} {sha1}\n") sys.stderr.write(f"located at {path}\n") sys.stderr.write(f"{py_impl} {py_ver} {py_compiler}\n") + sys.stderr.write(" ".join(frag.strip() for frag in platform.uname() if frag and frag.strip()) + "\n") diff --git a/hikari/events/base_events.py b/hikari/events/base_events.py index a073c05d98..3dfdef9543 100644 --- a/hikari/events/base_events.py +++ b/hikari/events/base_events.py @@ -162,21 +162,6 @@ class ExceptionEvent(Event, typing.Generic[FailedEventT]): side-effects on the application runtime. """ - app: traits.RESTAware = attr.ib(metadata={attr_extensions.SKIP_DEEP_COPY: True}) - # <>. - - shard: typing.Optional[gateway_shard.GatewayShard] = attr.ib(metadata={attr_extensions.SKIP_DEEP_COPY: True}) - """Shard that received the event. - - Returns - ------- - hikari.api.shard.GatewayShard - Shard that raised this exception. - - This may be `builtins.None` if no specific shard was the cause of this - exception (e.g. when starting up or shutting down). - """ - exception: Exception = attr.ib() """Exception that was raised. @@ -201,6 +186,28 @@ class ExceptionEvent(Event, typing.Generic[FailedEventT]): # for us to remove this effect. This functionally changes nothing but it helps MyPy. _failed_callback: FailedCallbackT[FailedEventT] = attr.ib() + @property + def app(self) -> traits.RESTAware: + # <>. + return self.failed_event.app + + @property + def shard(self) -> typing.Optional[gateway_shard.GatewayShard]: + """Shard that received the event, if there was one associated. + + Returns + ------- + typing.Optional[hikari.api.shard.GatewayShard] + Shard that raised this exception. + + This may be `builtins.None` if no specific shard was the cause of this + exception (e.g. when starting up or shutting down). + """ + shard = getattr(self.failed_event, "shard", None) + if isinstance(shard, gateway_shard.GatewayShard): + return shard + return None + @property def failed_callback(self) -> FailedCallbackT[FailedEventT]: """Event callback that threw an exception. diff --git a/hikari/impl/bot.py b/hikari/impl/bot.py index e3fdb4720b..f6c13da524 100644 --- a/hikari/impl/bot.py +++ b/hikari/impl/bot.py @@ -57,6 +57,7 @@ from hikari.impl import stateless_event_manager from hikari.impl import stateless_guild_chunker as stateless_guild_chunker_impl from hikari.impl import voice +from hikari.utilities import aio from hikari.utilities import art from hikari.utilities import constants from hikari.utilities import date @@ -574,30 +575,30 @@ async def start(self) -> None: def listen( self, event_type: typing.Optional[typing.Type[event_dispatcher.EventT_co]] = None, ) -> typing.Callable[ - [event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co]], - event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co], + [event_dispatcher.CallbackT[event_dispatcher.EventT_co]], + event_dispatcher.CallbackT[event_dispatcher.EventT_co], ]: # <> return self.dispatcher.listen(event_type) def get_listeners( self, event_type: typing.Type[event_dispatcher.EventT_co], *, polymorphic: bool = True, - ) -> typing.Collection[event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co]]: + ) -> typing.Collection[event_dispatcher.CallbackT[event_dispatcher.EventT_co]]: # <> return self.dispatcher.get_listeners(event_type, polymorphic=polymorphic) def subscribe( self, event_type: typing.Type[event_dispatcher.EventT_co], - callback: event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co], - ) -> event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co]: + callback: event_dispatcher.CallbackT[event_dispatcher.EventT_co], + ) -> event_dispatcher.CallbackT[event_dispatcher.EventT_co]: # <> return self.dispatcher.subscribe(event_type, callback) def unsubscribe( self, event_type: typing.Type[event_dispatcher.EventT_co], - callback: event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co], + callback: event_dispatcher.CallbackT[event_dispatcher.EventT_co], ) -> None: # <> return self.dispatcher.unsubscribe(event_type, callback) @@ -642,7 +643,12 @@ async def close(self) -> None: await self._connector_factory.close() self._global_ratelimit.close() - def run(self) -> None: + def run( + self, + *, + loop: typing.Optional[asyncio.AbstractEventLoop] = None, + slow_callback_duration: typing.Optional[float] = None, + ) -> None: """Run this application on the current thread in an event loop. This will use the event loop that is set for the current thread, or @@ -658,19 +664,47 @@ def run(self) -> None: The application is always guaranteed to be shut down before this function completes or propagates any exception. + + Parameters + ---------- + loop : typing.Optional[asyncio.AbstractEventLoop] + Event loop to run on. This defaults to `builtins.None`. + + If `builtins.None`, the event loop set for the current thread will + be used. If the thread does not have an event loop, then one will + be created first and registered to the running thread. + + It is advisable to only have one event loop per thread. Generally + you should not have a need to specify this. + slow_callback_duration : typing.Optional[builtins.float] + How long a coroutine should block for in seconds before it shows a + warning. + + This defaults to being `builtins.None`, which will disable the + feature (since it may cause a small increase in execution latency). + If specified as a number, it will be enabled. """ - try: - loop = asyncio.get_event_loop() - except RuntimeError: - _LOGGER.debug("no event loop registered on this thread; now creating one...") - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + if loop is None: + try: + loop = asyncio.get_event_loop() + _LOGGER.debug("using default thread's event loop") + except RuntimeError: + _LOGGER.debug("no event loop registered on this thread; now creating one...") + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # We always expect this to be populated by now. + loop: asyncio.AbstractEventLoop + + if slow_callback_duration and slow_callback_duration > 0: + aio.patch_slow_callback_detection(slow_callback_duration) try: self._map_signal_handlers( loop.add_signal_handler, lambda *_: loop.create_task(self.close(), name="signal interrupt shutting down application"), ) + _LOGGER.debug("using default thread's event loop", loop) loop.run_until_complete(self._shard_management_lifecycle()) except KeyboardInterrupt as ex: diff --git a/hikari/impl/event_manager_base.py b/hikari/impl/event_manager_base.py index 0ed4d90c30..7799f55e62 100644 --- a/hikari/impl/event_manager_base.py +++ b/hikari/impl/event_manager_base.py @@ -36,7 +36,6 @@ from hikari import traits from hikari.api import event_dispatcher from hikari.events import base_events -from hikari.events import shard_events from hikari.utilities import aio from hikari.utilities import data_binding from hikari.utilities import reflect @@ -50,7 +49,7 @@ if typing.TYPE_CHECKING: ListenerMapT = typing.MutableMapping[ typing.Type[event_dispatcher.EventT_co], - typing.MutableSequence[event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co]], + typing.MutableSequence[event_dispatcher.CallbackT[event_dispatcher.EventT_co]], ] WaiterT = typing.Tuple[ event_dispatcher.PredicateT[event_dispatcher.EventT_co], asyncio.Future[event_dispatcher.EventT_co] @@ -79,7 +78,7 @@ def __init__(self, app: traits.BotAware, intents: typing.Optional[intents_.Inten self._listeners: ListenerMapT[base_events.Event] = {} self._waiters: WaiterMapT[base_events.Event] = {} - async def consume_raw_event( + def consume_raw_event( self, shard: gateway_shard.GatewayShard, event_name: str, payload: data_binding.JSONObject ) -> None: try: @@ -87,21 +86,21 @@ async def consume_raw_event( except AttributeError: _LOGGER.debug("ignoring unknown event %s", event_name) else: - await callback(shard, payload) + asyncio.create_task(callback(shard, payload)) def subscribe( self, event_type: typing.Type[event_dispatcher.EventT_co], - callback: event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co], + callback: event_dispatcher.CallbackT[event_dispatcher.EventT_co], *, _nested: int = 0, - ) -> event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co]: - if not asyncio.iscoroutinefunction(callback): - raise TypeError("Event callbacks must be coroutine functions (`async def')") - - if not inspect.isclass(event_type) or not issubclass(event_type, base_events.Event): + ) -> event_dispatcher.CallbackT[event_dispatcher.EventT_co]: + if not issubclass(event_type, base_events.Event): raise TypeError("Cannot subscribe to a non-Event type") + if not inspect.iscoroutinefunction(callback): + raise TypeError("Cannot subscribe a non-coroutine function callback") + # `_nested` is used to show the correct source code snippet if an intent # warning is triggered. self._check_intents(event_type, _nested) @@ -144,12 +143,9 @@ def _check_intents(self, event_type: typing.Type[event_dispatcher.EventT_co], ne def get_listeners( self, event_type: typing.Type[event_dispatcher.EventT_co], *, polymorphic: bool = True, - ) -> typing.Collection[event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co]]: - if not inspect.isclass(event_type) or not issubclass(event_type, base_events.Event): - raise TypeError(f"Can only get listeners for subclasses of {base_events.Event.__name__}") - + ) -> typing.Collection[event_dispatcher.CallbackT[event_dispatcher.EventT_co]]: if polymorphic: - listeners: typing.List[event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co]] = [] + listeners: typing.List[event_dispatcher.CallbackT[event_dispatcher.EventT_co]] = [] for subscribed_event_type, subscribed_listeners in self._listeners.items(): if issubclass(subscribed_event_type, event_type): listeners += subscribed_listeners @@ -164,7 +160,7 @@ def get_listeners( def unsubscribe( self, event_type: typing.Type[event_dispatcher.EventT_co], - callback: event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co], + callback: event_dispatcher.CallbackT[event_dispatcher.EventT_co], ) -> None: if event_type in self._listeners: _LOGGER.debug( @@ -181,12 +177,12 @@ def unsubscribe( def listen( self, event_type: typing.Optional[typing.Type[event_dispatcher.EventT_co]] = None, ) -> typing.Callable[ - [event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co]], - event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co], + [event_dispatcher.CallbackT[event_dispatcher.EventT_co]], + event_dispatcher.CallbackT[event_dispatcher.EventT_co], ]: def decorator( - callback: event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co], - ) -> event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co]: + callback: event_dispatcher.CallbackT[event_dispatcher.EventT_co], + ) -> event_dispatcher.CallbackT[event_dispatcher.EventT_co]: nonlocal event_type signature = reflect.resolve_signature(callback) @@ -220,42 +216,31 @@ def dispatch(self, event: event_dispatcher.EventT_inv) -> asyncio.Future[typing. tasks: typing.List[typing.Coroutine[None, typing.Any, None]] = [] for cls in mro[: mro.index(base_events.Event) + 1]: - if cls in self._listeners: for callback in self._listeners[cls]: tasks.append(self._invoke_callback(callback, event)) if cls in self._waiters: - for predicate, future in self._waiters[cls]: - tasks.append(self._test_waiter(event, predicate, future)) # type: ignore[misc] + for predicate, future in tuple(self._waiters[cls]): + try: + result = predicate(event) + if not result: + continue + except Exception as ex: + future.set_exception(ex) + else: + future.set_result(event) + + waiter_set = self._waiters[cls] + waiter_set.remove((predicate, future)) return asyncio.gather(*tasks) if tasks else aio.completed_future() - @staticmethod - async def _test_waiter( - event: event_dispatcher.EventT_inv, - predicate: event_dispatcher.PredicateT[event_dispatcher.EventT_inv], - future: asyncio.Future[event_dispatcher.EventT_inv], - ) -> None: - try: - result = predicate(event) - if asyncio.iscoroutine(result): - result = await result # type: ignore - - if not result: - return - - except Exception as ex: - future.set_exception(ex) - else: - future.set_result(event) - async def _invoke_callback( - self, callback: event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_inv], event: event_dispatcher.EventT_inv + self, callback: event_dispatcher.CallbackT[event_dispatcher.EventT_inv], event: event_dispatcher.EventT_inv ) -> None: try: await callback(event) - except Exception as ex: # Skip the first frame in logs, we don't care for it. trio = type(ex), ex, ex.__traceback__.tb_next if ex.__traceback__ is not None else None @@ -268,11 +253,7 @@ async def _invoke_callback( ) else: exception_event = base_events.ExceptionEvent( - app=self._app, - shard=getattr(event, "shard") if isinstance(event, shard_events.ShardEvent) else None, - exception=ex, - failed_event=event, - failed_callback=callback, + exception=ex, failed_event=event, failed_callback=callback, ) log = _LOGGER.debug if self.get_listeners(type(exception_event), polymorphic=True) else _LOGGER.error @@ -304,13 +285,7 @@ async def wait_for( waiter_set.add(pair) # type: ignore[arg-type] - try: - if timeout is not None: - return await asyncio.wait_for(future, timeout=timeout) - else: - return await future - - finally: - waiter_set.remove(pair) # type: ignore[arg-type] - if not waiter_set: - del self._waiters[event_type] + if timeout is not None: + return await asyncio.wait_for(future, timeout=timeout) + else: + return await future diff --git a/hikari/impl/shard.py b/hikari/impl/shard.py index ae9a65171b..50e26e1495 100644 --- a/hikari/impl/shard.py +++ b/hikari/impl/shard.py @@ -75,11 +75,11 @@ class GatewayShardImplV6(shard.GatewayShard): logs. If `builtins.False`, only the fact that data has been sent/received will be logged. event_consumer - A coroutine function consuming a `GatewayShardImplV6`, + A non-coroutine function consuming a `GatewayShardImplV6`, a `builtins.str` event name, and a `hikari.utilities.data_binding.JSONObject` event object as parameters. - This should return `builtins.None`, and will be called asynchronously - with each event that fires. + This should return `builtins.None`, and will be called with each event + that fires. http_settings : hikari.config.HTTPSettings The HTTP-related settings to use while negotiating a websocket. initial_activity : typing.Optional[hikari.presences.Activity] @@ -216,9 +216,7 @@ def __init__( compression: typing.Optional[str] = shard.GatewayCompression.PAYLOAD_ZLIB_STREAM, data_format: str = shard.GatewayDataFormat.JSON, debug: bool = False, - event_consumer: typing.Callable[ - [shard.GatewayShard, str, data_binding.JSONObject], typing.Coroutine[None, None, None] - ], + event_consumer: typing.Callable[[shard.GatewayShard, str, data_binding.JSONObject], None], http_settings: config.HTTPSettings, initial_activity: typing.Optional[presences.Activity] = None, initial_idle_since: typing.Optional[datetime.datetime] = None, @@ -590,7 +588,7 @@ async def _run_once(self, client_session: aiohttp.ClientSession) -> None: # Technically we are connected after the hello, but this ensures we can send and receive # before firing that event. - self._dispatch("CONNECTED", {}) + self._event_consumer(self, "CONNECTED", {}) try: @@ -613,12 +611,12 @@ async def _run_once(self, client_session: aiohttp.ClientSession) -> None: finally: heartbeat.cancel() finally: - self._dispatch("DISCONNECTED", {}) + self._event_consumer(self, "DISCONNECTED", {}) self._connected_at = None async def _close_ws(self, code: int, message: str) -> None: self._logger.debug("sending close frame with code %s and message %r", int(code), message) - # None if the websocket error'ed on initialization. + # None if the websocket error on initialization. if self._ws is not None: await self._ws.close(code=code, message=bytes(message, "utf-8")) @@ -751,7 +749,7 @@ async def _poll_events(self) -> None: self._logger.info("shard has resumed [session:%s, seq:%s]", self._session_id, self._seq) self._handshake_event.set() - self._dispatch(event, data) + self._event_consumer(self, event, data) elif op == self._Opcode.HEARTBEAT: self._logger.debug("received HEARTBEAT; sending HEARTBEAT ACK") @@ -864,11 +862,6 @@ async def _send_json(self, payload: data_binding.JSONObject) -> None: self._log_debug_payload(message, "sending json payload [op:%s]", payload.get("op")) await self._ws.send_str(message) # type: ignore[union-attr] - def _dispatch(self, event_name: str, event: data_binding.JSONObject) -> asyncio.Task[None]: - return asyncio.create_task( - self._event_consumer(self, event_name, event), name=f"gateway shard {self._shard_id} dispatch {event_name}", - ) - def _log_debug_payload(self, payload: str, message: str, *args: typing.Any) -> None: # Prevent logging these payloads if logging is not enabled. This aids performance a little. if not self._logger.isEnabledFor(logging.DEBUG): diff --git a/hikari/utilities/aio.py b/hikari/utilities/aio.py index 8c5ac19e9c..ac953752be 100644 --- a/hikari/utilities/aio.py +++ b/hikari/utilities/aio.py @@ -23,15 +23,60 @@ from __future__ import annotations -__all__: typing.Final[typing.List[str]] = ["completed_future", "is_async_iterator", "is_async_iterable"] +__all__: typing.Final[typing.List[str]] = [ + "patch_slow_callback_detection", + "completed_future", + "is_async_iterator", + "is_async_iterable", +] import asyncio import inspect +import logging import typing +from hikari.utilities import date + T_co = typing.TypeVar("T_co", covariant=True) T_inv = typing.TypeVar("T_inv") +_LOGGER: typing.Final[logging.Logger] = logging.getLogger(__name__) + + +@typing.no_type_check +def patch_slow_callback_detection(duration: float = 0.5) -> None: + """Patches some asyncio internals to allow detection of slow callbacks. + + Any callbacks that take more than the given `duration` will be logged. + + Parameters + ---------- + duration : float + The duration to wait for before classifying a task as being "slow". + """ + _LOGGER.debug("setting slow callback duration for loop to %.0fms", duration * 1_000) + + original_run = getattr(asyncio.Handle, "_run") + + def stringify(self: asyncio.Handle) -> str: + if _LOGGER.isEnabledFor(logging.WARNING): + if isinstance(getattr(self._callback, "__self__", None), asyncio.Task): + return repr(self._callback.__self__) + return str(self._callback) + return "" + + def run(self: asyncio.Handle) -> None: + """Instrumented runner.""" + start = date.monotonic() + try: + original_run(self) + finally: + period = date.monotonic() - start + if period >= duration: + _LOGGER.warning("Callback %s blocked for %.1fms!", stringify(self), period * 1_000) + + setattr(asyncio.Handle, "_run", run) + def completed_future(result: typing.Optional[T_inv] = None, /) -> asyncio.Future[typing.Optional[T_inv]]: """Create a future on the current running loop that is completed, then return it. diff --git a/tests/hikari/events/test_base_events.py b/tests/hikari/events/test_base_events.py index 72f3a0d9fb..2c87e02220 100644 --- a/tests/hikari/events/test_base_events.py +++ b/tests/hikari/events/test_base_events.py @@ -23,12 +23,13 @@ import pytest from hikari import intents +from hikari.api import shard as gateway_shard from hikari.events import base_events @base_events.requires_intents(intents.Intents.GUILDS) @attr.s(eq=False, hash=False, init=False, kw_only=True, slots=True) -class DummyGuildEVent(base_events.Event): +class DummyGuildEvent(base_events.Event): pass @@ -46,7 +47,7 @@ class ErrorEvent(base_events.Event): @attr.s(eq=False, hash=False, init=False, kw_only=True, slots=True) -class DummyGuildDerivedEvent(DummyGuildEVent): +class DummyGuildDerivedEvent(DummyGuildEvent): pass @@ -58,12 +59,12 @@ class DummyPresenceDerivedEvent(DummyPresenceEvent): def test_is_no_recursive_throw_event_marked(): assert base_events.is_no_recursive_throw_event(DummyPresenceEvent) assert base_events.is_no_recursive_throw_event(ErrorEvent) - assert not base_events.is_no_recursive_throw_event(DummyGuildEVent) + assert not base_events.is_no_recursive_throw_event(DummyGuildEvent) assert not base_events.is_no_recursive_throw_event(DummyGuildDerivedEvent) def test_requires_intents(): - assert list(base_events.get_required_intents_for(DummyGuildEVent)) == [intents.Intents.GUILDS] + assert list(base_events.get_required_intents_for(DummyGuildEvent)) == [intents.Intents.GUILDS] assert list(base_events.get_required_intents_for(DummyPresenceEvent)) == [intents.Intents.GUILD_PRESENCES] assert list(base_events.get_required_intents_for(ErrorEvent)) == [] @@ -90,13 +91,23 @@ def error(self): @pytest.fixture def event(self, error): return base_events.ExceptionEvent( - app=object(), - shard=object(), - exception=error, - failed_event=mock.Mock(base_events.Event), - failed_callback=mock.AsyncMock(), + exception=error, failed_event=mock.Mock(base_events.Event), failed_callback=mock.AsyncMock(), ) + def test_app_property(self, event): + app = mock.Mock() + event.failed_event.app = app + assert event.app is app + + @pytest.mark.parametrize("has_shard", [True, False]) + def test_shard_property(self, has_shard, event): + shard = mock.Mock(spec_set=gateway_shard.GatewayShard) + if has_shard: + event.failed_event.shard = shard + assert event.shard is shard + else: + assert event.shard is None + def test_failed_callback_property(self, event): stub_callback = object() event._failed_callback = stub_callback diff --git a/tests/hikari/impl/test_event_manager_base.py b/tests/hikari/impl/test_event_manager_base.py index e8ed021a08..d1f4539a03 100644 --- a/tests/hikari/impl/test_event_manager_base.py +++ b/tests/hikari/impl/test_event_manager_base.py @@ -45,18 +45,20 @@ class EventManagerBaseImpl(event_manager_base.EventManagerBase): @pytest.mark.asyncio async def test_consume_raw_event_when_AttributeError(self, event_manager): with mock.patch.object(event_manager_base, "_LOGGER") as logger: - await event_manager.consume_raw_event(None, "UNEXISTING_EVENT", {}) + event_manager.consume_raw_event(None, "UNEXISTING_EVENT", {}) logger.debug.assert_called_once_with("ignoring unknown event %s", "UNEXISTING_EVENT") @pytest.mark.asyncio async def test_consume_raw_event_when_found(self, event_manager): - event_manager.on_existing_event = mock.AsyncMock() + event_manager.on_existing_event = mock.Mock() shard = object() - await event_manager.consume_raw_event(shard, "EXISTING_EVENT", {}) + with mock.patch("asyncio.create_task") as create_task: + event_manager.consume_raw_event(shard, "EXISTING_EVENT", {}) - event_manager.on_existing_event.assert_awaited_once_with(shard, {}) + event_manager.on_existing_event.assert_called_once_with(shard, {}) + create_task.assert_called_once_with(event_manager.on_existing_event(shard, {})) def test_subscribe_when_callback_is_not_coroutine(self, event_manager): def test(): @@ -143,8 +145,7 @@ def test__check_intents_when_intents_incorrect(self, event_manager): ) def test_get_listeners_when_not_event(self, event_manager): - with pytest.raises(TypeError): - event_manager.get_listeners("test") + assert len(event_manager.get_listeners("test")) == 0 def test_get_listeners_polimorphic(self, event_manager): event_manager._listeners = { diff --git a/tests/hikari/impl/test_shard.py b/tests/hikari/impl/test_shard.py index 4667eaf2ac..19d1b6b78a 100644 --- a/tests/hikari/impl/test_shard.py +++ b/tests/hikari/impl/test_shard.py @@ -292,10 +292,9 @@ def client(self, http_settings=http_settings, proxy_settings=proxy_settings): ) client = hikari_test_helpers.mock_methods_on( client, - except_=("_run_once_shielded", "_InvalidSession", "_Reconnect", "_SocketClosed", "_dispatch", "_Opcode",), - also_mock=["_backoff", "_handshake_event", "_request_close_event", "_logger"], + except_=("_run_once_shielded", "_InvalidSession", "_Reconnect", "_SocketClosed", "_Opcode",), + also_mock=["_backoff", "_handshake_event", "_request_close_event", "_logger", "_event_consumer"], ) - client._dispatch = mock.AsyncMock() # Disable backoff checking by making the condition a negative tautology. client._RESTART_RATELIMIT_WINDOW = -1 return client @@ -458,7 +457,7 @@ def client(self, http_settings, proxy_settings): client = hikari_test_helpers.mock_methods_on( client, except_=("_run_once", "_InvalidSession", "_Reconnect", "_SocketClosed", "_Opcode",), - also_mock=["_backoff", "_handshake_event", "_request_close_event", "_logger",], + also_mock=["_backoff", "_handshake_event", "_request_close_event", "_logger", "_event_consumer"], ) # Disable backoff checking by making the condition a negative tautology. client._RESTART_RATELIMIT_WINDOW = -1 @@ -617,7 +616,7 @@ class Error(Exception): with pytest.raises(Error): await client._run_once(client_session) - client._dispatch.assert_any_call("CONNECTED", {}) + client._event_consumer.assert_any_call(client, "CONNECTED", {}) @hikari_test_helpers.timeout() async def test_heartbeat_is_not_started_before_handshake_completes(self, client, client_session): @@ -665,14 +664,14 @@ async def test_heartbeat_is_stopped_when_poll_events_stops(self, client, client_ async def test_dispatches_disconnect_if_connected(self, client, client_session): await client._run_once(client_session) - client._dispatch.assert_any_call("CONNECTED", {}) - client._dispatch.assert_any_call("DISCONNECTED", {}) + client._event_consumer.assert_any_call(client, "CONNECTED", {}) + client._event_consumer.assert_any_call(client, "DISCONNECTED", {}) async def test_no_dispatch_disconnect_if_not_connected(self, client, client_session): client_session.ws_connect = mock.Mock(side_effect=RuntimeError) with pytest.raises(RuntimeError): await client._run_once(client_session) - client._dispatch.assert_not_called() + client._event_consumer.assert_not_called() async def test_connected_at_reset_to_None_on_exit(self, client, client_session): await client._run_once(client_session) @@ -1034,14 +1033,13 @@ async def test_when_opcode_is_DISPATCH_and_event_is_READY(self, client, exit_err "s": 101, } client._receive_json = mock.AsyncMock(side_effect=[payload, exit_error]) - client._dispatch = mock.Mock() timestamp = datetime.datetime.now() with mock.patch.object(hikari_date, "monotonic", return_value=timestamp): with pytest.raises(exit_error): await client._poll_events() - client._dispatch.assert_called_once_with("READY", data_payload) + client._event_consumer.assert_any_call(client, "READY", data_payload) assert client._handshake_event.is_set() assert client._session_id == 123 assert client._seq == 101 @@ -1057,12 +1055,11 @@ async def test_when_opcode_is_DISPATCH_and_event_is_RESUME(self, client, exit_er "s": 101, } client._receive_json = mock.AsyncMock(side_effect=[payload, exit_error]) - client._dispatch = mock.Mock() with pytest.raises(exit_error): await client._poll_events() - client._dispatch.assert_called_once_with("RESUME", "some data") + client._event_consumer.assert_any_call(client, "RESUME", "some data") assert client._handshake_event.is_set() @hikari_test_helpers.timeout() @@ -1074,12 +1071,11 @@ async def test_when_opcode_is_DISPATCH_and_event_is_not_handled(self, client, ex "s": 101, } client._receive_json = mock.AsyncMock(side_effect=[payload, exit_error]) - client._dispatch = mock.Mock() with pytest.raises(exit_error): await client._poll_events() - client._dispatch.assert_called_once_with("UNKNOWN", "some data") + client._event_consumer.assert_any_call(client, "UNKNOWN", "some data") @hikari_test_helpers.timeout() async def test_when_opcode_is_HEARTBEAT(self, client, exit_error): @@ -1284,21 +1280,6 @@ async def test_send_json(self, client): client._ws.send_str.assert_awaited_once_with('{"some": "payload"}') -class TestDispatch: - def test_dispatch(self, client): - mock_task = object() - mock_coroutine = object() - client._app = mock.Mock() - client._event_consumer = mock.Mock(return_value=mock_coroutine) - client._shard_id = 123 - - with mock.patch.object(asyncio, "create_task", return_value=mock_task) as create_task: - assert client._dispatch("MESSAGE_CREATE", {"some": "payload"}) == mock_task - - client._event_consumer.assert_called_once_with(client, "MESSAGE_CREATE", {"some": "payload"}) - create_task.assert_called_once_with(mock_coroutine, name="gateway shard 123 dispatch MESSAGE_CREATE") - - class TestLogDebugPayload: def test_when_logging_debug_disabled(self, client): client._logger.isEnabledFor = mock.Mock(return_value=False)