diff --git a/hikari/api/event_dispatcher.py b/hikari/api/event_dispatcher.py index 44c194f0a3..bbef968b05 100644 --- a/hikari/api/event_dispatcher.py +++ b/hikari/api/event_dispatcher.py @@ -35,8 +35,8 @@ EventT_co = typing.TypeVar("EventT_co", bound=base_events.Event, covariant=True) EventT_inv = typing.TypeVar("EventT_inv", bound=base_events.Event) - PredicateT = typing.Callable[[EventT_co], typing.Union[bool, typing.Coroutine[typing.Any, typing.Any, bool]]] - AsyncCallbackT = typing.Callable[[EventT_inv], typing.Coroutine[typing.Any, typing.Any, None]] + PredicateT = typing.Callable[[EventT_co], bool] + CallbackT = typing.Callable[[EventT_inv], typing.Coroutine[typing.Any, typing.Any, None]] class EventDispatcher(abc.ABC): @@ -134,9 +134,7 @@ async def on_everyone_mentioned(event): # For the sake of UX, I will check this at runtime instead and let the # user use a static type checker. @abc.abstractmethod - def subscribe( - self, event_type: typing.Type[typing.Any], callback: AsyncCallbackT[typing.Any] - ) -> AsyncCallbackT[typing.Any]: + def subscribe(self, event_type: typing.Type[typing.Any], callback: CallbackT[typing.Any]) -> CallbackT[typing.Any]: """Subscribe a given callback to a given event type. Parameters @@ -180,7 +178,7 @@ async def on_message(event): # For the sake of UX, I will check this at runtime instead and let the # user use a static type checker. @abc.abstractmethod - def unsubscribe(self, event_type: typing.Type[typing.Any], callback: AsyncCallbackT[typing.Any]) -> None: + def unsubscribe(self, event_type: typing.Type[typing.Any], callback: CallbackT[typing.Any]) -> None: """Unsubscribe a given callback from a given event type, if present. Parameters @@ -210,7 +208,7 @@ async def on_message(event): @abc.abstractmethod def get_listeners( self, event_type: typing.Type[EventT_co], *, polymorphic: bool = True, - ) -> typing.Collection[AsyncCallbackT[EventT_co]]: + ) -> typing.Collection[CallbackT[EventT_co]]: """Get the listeners for a given event type, if there are any. Parameters @@ -240,7 +238,7 @@ def get_listeners( @abc.abstractmethod def listen( self, event_type: typing.Optional[typing.Type[EventT_co]] = None, - ) -> typing.Callable[[AsyncCallbackT[EventT_co]], AsyncCallbackT[EventT_co]]: + ) -> typing.Callable[[CallbackT[EventT_co]], CallbackT[EventT_co]]: """Generate a decorator to subscribe a callback to an event type. This is a second-order decorator. @@ -285,11 +283,13 @@ async def wait_for( The event type to listen for. This will listen for subclasses of this type additionally. predicate - A function or coroutine taking the event as the single parameter. + A function taking the event as the single parameter. This should return `builtins.True` if the event is one you want to return, or `builtins.False` if the event should not be returned. If left as `None` (the default), then the first matching event type that the bot receives (or any subtype) will be the one returned. + + ASYNC PREDICATES ARE NOT SUPPORTED. timeout : typing.Optional[builtins.float or builtins.int] The amount of time to wait before raising an `asyncio.TimeoutError` and giving up instead. This is measured in seconds. If diff --git a/hikari/cli.py b/hikari/cli.py index 59a970f7f7..08070c0f75 100644 --- a/hikari/cli.py +++ b/hikari/cli.py @@ -48,3 +48,4 @@ def main() -> None: sys.stderr.write(f"hikari v{version} {sha1}\n") sys.stderr.write(f"located at {path}\n") sys.stderr.write(f"{py_impl} {py_ver} {py_compiler}\n") + sys.stderr.write(" ".join(frag.strip() for frag in platform.uname() if frag and frag.strip()) + "\n") diff --git a/hikari/events/base_events.py b/hikari/events/base_events.py index a073c05d98..3dfdef9543 100644 --- a/hikari/events/base_events.py +++ b/hikari/events/base_events.py @@ -162,21 +162,6 @@ class ExceptionEvent(Event, typing.Generic[FailedEventT]): side-effects on the application runtime. """ - app: traits.RESTAware = attr.ib(metadata={attr_extensions.SKIP_DEEP_COPY: True}) - # <>. - - shard: typing.Optional[gateway_shard.GatewayShard] = attr.ib(metadata={attr_extensions.SKIP_DEEP_COPY: True}) - """Shard that received the event. - - Returns - ------- - hikari.api.shard.GatewayShard - Shard that raised this exception. - - This may be `builtins.None` if no specific shard was the cause of this - exception (e.g. when starting up or shutting down). - """ - exception: Exception = attr.ib() """Exception that was raised. @@ -201,6 +186,28 @@ class ExceptionEvent(Event, typing.Generic[FailedEventT]): # for us to remove this effect. This functionally changes nothing but it helps MyPy. _failed_callback: FailedCallbackT[FailedEventT] = attr.ib() + @property + def app(self) -> traits.RESTAware: + # <>. + return self.failed_event.app + + @property + def shard(self) -> typing.Optional[gateway_shard.GatewayShard]: + """Shard that received the event, if there was one associated. + + Returns + ------- + typing.Optional[hikari.api.shard.GatewayShard] + Shard that raised this exception. + + This may be `builtins.None` if no specific shard was the cause of this + exception (e.g. when starting up or shutting down). + """ + shard = getattr(self.failed_event, "shard", None) + if isinstance(shard, gateway_shard.GatewayShard): + return shard + return None + @property def failed_callback(self) -> FailedCallbackT[FailedEventT]: """Event callback that threw an exception. diff --git a/hikari/impl/bot.py b/hikari/impl/bot.py index e3fdb4720b..f6c13da524 100644 --- a/hikari/impl/bot.py +++ b/hikari/impl/bot.py @@ -57,6 +57,7 @@ from hikari.impl import stateless_event_manager from hikari.impl import stateless_guild_chunker as stateless_guild_chunker_impl from hikari.impl import voice +from hikari.utilities import aio from hikari.utilities import art from hikari.utilities import constants from hikari.utilities import date @@ -574,30 +575,30 @@ async def start(self) -> None: def listen( self, event_type: typing.Optional[typing.Type[event_dispatcher.EventT_co]] = None, ) -> typing.Callable[ - [event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co]], - event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co], + [event_dispatcher.CallbackT[event_dispatcher.EventT_co]], + event_dispatcher.CallbackT[event_dispatcher.EventT_co], ]: # <> return self.dispatcher.listen(event_type) def get_listeners( self, event_type: typing.Type[event_dispatcher.EventT_co], *, polymorphic: bool = True, - ) -> typing.Collection[event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co]]: + ) -> typing.Collection[event_dispatcher.CallbackT[event_dispatcher.EventT_co]]: # <> return self.dispatcher.get_listeners(event_type, polymorphic=polymorphic) def subscribe( self, event_type: typing.Type[event_dispatcher.EventT_co], - callback: event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co], - ) -> event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co]: + callback: event_dispatcher.CallbackT[event_dispatcher.EventT_co], + ) -> event_dispatcher.CallbackT[event_dispatcher.EventT_co]: # <> return self.dispatcher.subscribe(event_type, callback) def unsubscribe( self, event_type: typing.Type[event_dispatcher.EventT_co], - callback: event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co], + callback: event_dispatcher.CallbackT[event_dispatcher.EventT_co], ) -> None: # <> return self.dispatcher.unsubscribe(event_type, callback) @@ -642,7 +643,12 @@ async def close(self) -> None: await self._connector_factory.close() self._global_ratelimit.close() - def run(self) -> None: + def run( + self, + *, + loop: typing.Optional[asyncio.AbstractEventLoop] = None, + slow_callback_duration: typing.Optional[float] = None, + ) -> None: """Run this application on the current thread in an event loop. This will use the event loop that is set for the current thread, or @@ -658,19 +664,47 @@ def run(self) -> None: The application is always guaranteed to be shut down before this function completes or propagates any exception. + + Parameters + ---------- + loop : typing.Optional[asyncio.AbstractEventLoop] + Event loop to run on. This defaults to `builtins.None`. + + If `builtins.None`, the event loop set for the current thread will + be used. If the thread does not have an event loop, then one will + be created first and registered to the running thread. + + It is advisable to only have one event loop per thread. Generally + you should not have a need to specify this. + slow_callback_duration : typing.Optional[builtins.float] + How long a coroutine should block for in seconds before it shows a + warning. + + This defaults to being `builtins.None`, which will disable the + feature (since it may cause a small increase in execution latency). + If specified as a number, it will be enabled. """ - try: - loop = asyncio.get_event_loop() - except RuntimeError: - _LOGGER.debug("no event loop registered on this thread; now creating one...") - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + if loop is None: + try: + loop = asyncio.get_event_loop() + _LOGGER.debug("using default thread's event loop") + except RuntimeError: + _LOGGER.debug("no event loop registered on this thread; now creating one...") + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # We always expect this to be populated by now. + loop: asyncio.AbstractEventLoop + + if slow_callback_duration and slow_callback_duration > 0: + aio.patch_slow_callback_detection(slow_callback_duration) try: self._map_signal_handlers( loop.add_signal_handler, lambda *_: loop.create_task(self.close(), name="signal interrupt shutting down application"), ) + _LOGGER.debug("using default thread's event loop", loop) loop.run_until_complete(self._shard_management_lifecycle()) except KeyboardInterrupt as ex: diff --git a/hikari/impl/event_manager_base.py b/hikari/impl/event_manager_base.py index 0ed4d90c30..7799f55e62 100644 --- a/hikari/impl/event_manager_base.py +++ b/hikari/impl/event_manager_base.py @@ -36,7 +36,6 @@ from hikari import traits from hikari.api import event_dispatcher from hikari.events import base_events -from hikari.events import shard_events from hikari.utilities import aio from hikari.utilities import data_binding from hikari.utilities import reflect @@ -50,7 +49,7 @@ if typing.TYPE_CHECKING: ListenerMapT = typing.MutableMapping[ typing.Type[event_dispatcher.EventT_co], - typing.MutableSequence[event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co]], + typing.MutableSequence[event_dispatcher.CallbackT[event_dispatcher.EventT_co]], ] WaiterT = typing.Tuple[ event_dispatcher.PredicateT[event_dispatcher.EventT_co], asyncio.Future[event_dispatcher.EventT_co] @@ -79,7 +78,7 @@ def __init__(self, app: traits.BotAware, intents: typing.Optional[intents_.Inten self._listeners: ListenerMapT[base_events.Event] = {} self._waiters: WaiterMapT[base_events.Event] = {} - async def consume_raw_event( + def consume_raw_event( self, shard: gateway_shard.GatewayShard, event_name: str, payload: data_binding.JSONObject ) -> None: try: @@ -87,21 +86,21 @@ async def consume_raw_event( except AttributeError: _LOGGER.debug("ignoring unknown event %s", event_name) else: - await callback(shard, payload) + asyncio.create_task(callback(shard, payload)) def subscribe( self, event_type: typing.Type[event_dispatcher.EventT_co], - callback: event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co], + callback: event_dispatcher.CallbackT[event_dispatcher.EventT_co], *, _nested: int = 0, - ) -> event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co]: - if not asyncio.iscoroutinefunction(callback): - raise TypeError("Event callbacks must be coroutine functions (`async def')") - - if not inspect.isclass(event_type) or not issubclass(event_type, base_events.Event): + ) -> event_dispatcher.CallbackT[event_dispatcher.EventT_co]: + if not issubclass(event_type, base_events.Event): raise TypeError("Cannot subscribe to a non-Event type") + if not inspect.iscoroutinefunction(callback): + raise TypeError("Cannot subscribe a non-coroutine function callback") + # `_nested` is used to show the correct source code snippet if an intent # warning is triggered. self._check_intents(event_type, _nested) @@ -144,12 +143,9 @@ def _check_intents(self, event_type: typing.Type[event_dispatcher.EventT_co], ne def get_listeners( self, event_type: typing.Type[event_dispatcher.EventT_co], *, polymorphic: bool = True, - ) -> typing.Collection[event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co]]: - if not inspect.isclass(event_type) or not issubclass(event_type, base_events.Event): - raise TypeError(f"Can only get listeners for subclasses of {base_events.Event.__name__}") - + ) -> typing.Collection[event_dispatcher.CallbackT[event_dispatcher.EventT_co]]: if polymorphic: - listeners: typing.List[event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co]] = [] + listeners: typing.List[event_dispatcher.CallbackT[event_dispatcher.EventT_co]] = [] for subscribed_event_type, subscribed_listeners in self._listeners.items(): if issubclass(subscribed_event_type, event_type): listeners += subscribed_listeners @@ -164,7 +160,7 @@ def get_listeners( def unsubscribe( self, event_type: typing.Type[event_dispatcher.EventT_co], - callback: event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co], + callback: event_dispatcher.CallbackT[event_dispatcher.EventT_co], ) -> None: if event_type in self._listeners: _LOGGER.debug( @@ -181,12 +177,12 @@ def unsubscribe( def listen( self, event_type: typing.Optional[typing.Type[event_dispatcher.EventT_co]] = None, ) -> typing.Callable[ - [event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co]], - event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co], + [event_dispatcher.CallbackT[event_dispatcher.EventT_co]], + event_dispatcher.CallbackT[event_dispatcher.EventT_co], ]: def decorator( - callback: event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co], - ) -> event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_co]: + callback: event_dispatcher.CallbackT[event_dispatcher.EventT_co], + ) -> event_dispatcher.CallbackT[event_dispatcher.EventT_co]: nonlocal event_type signature = reflect.resolve_signature(callback) @@ -220,42 +216,31 @@ def dispatch(self, event: event_dispatcher.EventT_inv) -> asyncio.Future[typing. tasks: typing.List[typing.Coroutine[None, typing.Any, None]] = [] for cls in mro[: mro.index(base_events.Event) + 1]: - if cls in self._listeners: for callback in self._listeners[cls]: tasks.append(self._invoke_callback(callback, event)) if cls in self._waiters: - for predicate, future in self._waiters[cls]: - tasks.append(self._test_waiter(event, predicate, future)) # type: ignore[misc] + for predicate, future in tuple(self._waiters[cls]): + try: + result = predicate(event) + if not result: + continue + except Exception as ex: + future.set_exception(ex) + else: + future.set_result(event) + + waiter_set = self._waiters[cls] + waiter_set.remove((predicate, future)) return asyncio.gather(*tasks) if tasks else aio.completed_future() - @staticmethod - async def _test_waiter( - event: event_dispatcher.EventT_inv, - predicate: event_dispatcher.PredicateT[event_dispatcher.EventT_inv], - future: asyncio.Future[event_dispatcher.EventT_inv], - ) -> None: - try: - result = predicate(event) - if asyncio.iscoroutine(result): - result = await result # type: ignore - - if not result: - return - - except Exception as ex: - future.set_exception(ex) - else: - future.set_result(event) - async def _invoke_callback( - self, callback: event_dispatcher.AsyncCallbackT[event_dispatcher.EventT_inv], event: event_dispatcher.EventT_inv + self, callback: event_dispatcher.CallbackT[event_dispatcher.EventT_inv], event: event_dispatcher.EventT_inv ) -> None: try: await callback(event) - except Exception as ex: # Skip the first frame in logs, we don't care for it. trio = type(ex), ex, ex.__traceback__.tb_next if ex.__traceback__ is not None else None @@ -268,11 +253,7 @@ async def _invoke_callback( ) else: exception_event = base_events.ExceptionEvent( - app=self._app, - shard=getattr(event, "shard") if isinstance(event, shard_events.ShardEvent) else None, - exception=ex, - failed_event=event, - failed_callback=callback, + exception=ex, failed_event=event, failed_callback=callback, ) log = _LOGGER.debug if self.get_listeners(type(exception_event), polymorphic=True) else _LOGGER.error @@ -304,13 +285,7 @@ async def wait_for( waiter_set.add(pair) # type: ignore[arg-type] - try: - if timeout is not None: - return await asyncio.wait_for(future, timeout=timeout) - else: - return await future - - finally: - waiter_set.remove(pair) # type: ignore[arg-type] - if not waiter_set: - del self._waiters[event_type] + if timeout is not None: + return await asyncio.wait_for(future, timeout=timeout) + else: + return await future diff --git a/hikari/impl/shard.py b/hikari/impl/shard.py index ae9a65171b..50e26e1495 100644 --- a/hikari/impl/shard.py +++ b/hikari/impl/shard.py @@ -75,11 +75,11 @@ class GatewayShardImplV6(shard.GatewayShard): logs. If `builtins.False`, only the fact that data has been sent/received will be logged. event_consumer - A coroutine function consuming a `GatewayShardImplV6`, + A non-coroutine function consuming a `GatewayShardImplV6`, a `builtins.str` event name, and a `hikari.utilities.data_binding.JSONObject` event object as parameters. - This should return `builtins.None`, and will be called asynchronously - with each event that fires. + This should return `builtins.None`, and will be called with each event + that fires. http_settings : hikari.config.HTTPSettings The HTTP-related settings to use while negotiating a websocket. initial_activity : typing.Optional[hikari.presences.Activity] @@ -216,9 +216,7 @@ def __init__( compression: typing.Optional[str] = shard.GatewayCompression.PAYLOAD_ZLIB_STREAM, data_format: str = shard.GatewayDataFormat.JSON, debug: bool = False, - event_consumer: typing.Callable[ - [shard.GatewayShard, str, data_binding.JSONObject], typing.Coroutine[None, None, None] - ], + event_consumer: typing.Callable[[shard.GatewayShard, str, data_binding.JSONObject], None], http_settings: config.HTTPSettings, initial_activity: typing.Optional[presences.Activity] = None, initial_idle_since: typing.Optional[datetime.datetime] = None, @@ -590,7 +588,7 @@ async def _run_once(self, client_session: aiohttp.ClientSession) -> None: # Technically we are connected after the hello, but this ensures we can send and receive # before firing that event. - self._dispatch("CONNECTED", {}) + self._event_consumer(self, "CONNECTED", {}) try: @@ -613,12 +611,12 @@ async def _run_once(self, client_session: aiohttp.ClientSession) -> None: finally: heartbeat.cancel() finally: - self._dispatch("DISCONNECTED", {}) + self._event_consumer(self, "DISCONNECTED", {}) self._connected_at = None async def _close_ws(self, code: int, message: str) -> None: self._logger.debug("sending close frame with code %s and message %r", int(code), message) - # None if the websocket error'ed on initialization. + # None if the websocket error on initialization. if self._ws is not None: await self._ws.close(code=code, message=bytes(message, "utf-8")) @@ -751,7 +749,7 @@ async def _poll_events(self) -> None: self._logger.info("shard has resumed [session:%s, seq:%s]", self._session_id, self._seq) self._handshake_event.set() - self._dispatch(event, data) + self._event_consumer(self, event, data) elif op == self._Opcode.HEARTBEAT: self._logger.debug("received HEARTBEAT; sending HEARTBEAT ACK") @@ -864,11 +862,6 @@ async def _send_json(self, payload: data_binding.JSONObject) -> None: self._log_debug_payload(message, "sending json payload [op:%s]", payload.get("op")) await self._ws.send_str(message) # type: ignore[union-attr] - def _dispatch(self, event_name: str, event: data_binding.JSONObject) -> asyncio.Task[None]: - return asyncio.create_task( - self._event_consumer(self, event_name, event), name=f"gateway shard {self._shard_id} dispatch {event_name}", - ) - def _log_debug_payload(self, payload: str, message: str, *args: typing.Any) -> None: # Prevent logging these payloads if logging is not enabled. This aids performance a little. if not self._logger.isEnabledFor(logging.DEBUG): diff --git a/hikari/utilities/aio.py b/hikari/utilities/aio.py index 8c5ac19e9c..ac953752be 100644 --- a/hikari/utilities/aio.py +++ b/hikari/utilities/aio.py @@ -23,15 +23,60 @@ from __future__ import annotations -__all__: typing.Final[typing.List[str]] = ["completed_future", "is_async_iterator", "is_async_iterable"] +__all__: typing.Final[typing.List[str]] = [ + "patch_slow_callback_detection", + "completed_future", + "is_async_iterator", + "is_async_iterable", +] import asyncio import inspect +import logging import typing +from hikari.utilities import date + T_co = typing.TypeVar("T_co", covariant=True) T_inv = typing.TypeVar("T_inv") +_LOGGER: typing.Final[logging.Logger] = logging.getLogger(__name__) + + +@typing.no_type_check +def patch_slow_callback_detection(duration: float = 0.5) -> None: + """Patches some asyncio internals to allow detection of slow callbacks. + + Any callbacks that take more than the given `duration` will be logged. + + Parameters + ---------- + duration : float + The duration to wait for before classifying a task as being "slow". + """ + _LOGGER.debug("setting slow callback duration for loop to %.0fms", duration * 1_000) + + original_run = getattr(asyncio.Handle, "_run") + + def stringify(self: asyncio.Handle) -> str: + if _LOGGER.isEnabledFor(logging.WARNING): + if isinstance(getattr(self._callback, "__self__", None), asyncio.Task): + return repr(self._callback.__self__) + return str(self._callback) + return "" + + def run(self: asyncio.Handle) -> None: + """Instrumented runner.""" + start = date.monotonic() + try: + original_run(self) + finally: + period = date.monotonic() - start + if period >= duration: + _LOGGER.warning("Callback %s blocked for %.1fms!", stringify(self), period * 1_000) + + setattr(asyncio.Handle, "_run", run) + def completed_future(result: typing.Optional[T_inv] = None, /) -> asyncio.Future[typing.Optional[T_inv]]: """Create a future on the current running loop that is completed, then return it. diff --git a/tests/hikari/events/test_base_events.py b/tests/hikari/events/test_base_events.py index 72f3a0d9fb..2c87e02220 100644 --- a/tests/hikari/events/test_base_events.py +++ b/tests/hikari/events/test_base_events.py @@ -23,12 +23,13 @@ import pytest from hikari import intents +from hikari.api import shard as gateway_shard from hikari.events import base_events @base_events.requires_intents(intents.Intents.GUILDS) @attr.s(eq=False, hash=False, init=False, kw_only=True, slots=True) -class DummyGuildEVent(base_events.Event): +class DummyGuildEvent(base_events.Event): pass @@ -46,7 +47,7 @@ class ErrorEvent(base_events.Event): @attr.s(eq=False, hash=False, init=False, kw_only=True, slots=True) -class DummyGuildDerivedEvent(DummyGuildEVent): +class DummyGuildDerivedEvent(DummyGuildEvent): pass @@ -58,12 +59,12 @@ class DummyPresenceDerivedEvent(DummyPresenceEvent): def test_is_no_recursive_throw_event_marked(): assert base_events.is_no_recursive_throw_event(DummyPresenceEvent) assert base_events.is_no_recursive_throw_event(ErrorEvent) - assert not base_events.is_no_recursive_throw_event(DummyGuildEVent) + assert not base_events.is_no_recursive_throw_event(DummyGuildEvent) assert not base_events.is_no_recursive_throw_event(DummyGuildDerivedEvent) def test_requires_intents(): - assert list(base_events.get_required_intents_for(DummyGuildEVent)) == [intents.Intents.GUILDS] + assert list(base_events.get_required_intents_for(DummyGuildEvent)) == [intents.Intents.GUILDS] assert list(base_events.get_required_intents_for(DummyPresenceEvent)) == [intents.Intents.GUILD_PRESENCES] assert list(base_events.get_required_intents_for(ErrorEvent)) == [] @@ -90,13 +91,23 @@ def error(self): @pytest.fixture def event(self, error): return base_events.ExceptionEvent( - app=object(), - shard=object(), - exception=error, - failed_event=mock.Mock(base_events.Event), - failed_callback=mock.AsyncMock(), + exception=error, failed_event=mock.Mock(base_events.Event), failed_callback=mock.AsyncMock(), ) + def test_app_property(self, event): + app = mock.Mock() + event.failed_event.app = app + assert event.app is app + + @pytest.mark.parametrize("has_shard", [True, False]) + def test_shard_property(self, has_shard, event): + shard = mock.Mock(spec_set=gateway_shard.GatewayShard) + if has_shard: + event.failed_event.shard = shard + assert event.shard is shard + else: + assert event.shard is None + def test_failed_callback_property(self, event): stub_callback = object() event._failed_callback = stub_callback diff --git a/tests/hikari/impl/test_event_manager_base.py b/tests/hikari/impl/test_event_manager_base.py index e8ed021a08..d1f4539a03 100644 --- a/tests/hikari/impl/test_event_manager_base.py +++ b/tests/hikari/impl/test_event_manager_base.py @@ -45,18 +45,20 @@ class EventManagerBaseImpl(event_manager_base.EventManagerBase): @pytest.mark.asyncio async def test_consume_raw_event_when_AttributeError(self, event_manager): with mock.patch.object(event_manager_base, "_LOGGER") as logger: - await event_manager.consume_raw_event(None, "UNEXISTING_EVENT", {}) + event_manager.consume_raw_event(None, "UNEXISTING_EVENT", {}) logger.debug.assert_called_once_with("ignoring unknown event %s", "UNEXISTING_EVENT") @pytest.mark.asyncio async def test_consume_raw_event_when_found(self, event_manager): - event_manager.on_existing_event = mock.AsyncMock() + event_manager.on_existing_event = mock.Mock() shard = object() - await event_manager.consume_raw_event(shard, "EXISTING_EVENT", {}) + with mock.patch("asyncio.create_task") as create_task: + event_manager.consume_raw_event(shard, "EXISTING_EVENT", {}) - event_manager.on_existing_event.assert_awaited_once_with(shard, {}) + event_manager.on_existing_event.assert_called_once_with(shard, {}) + create_task.assert_called_once_with(event_manager.on_existing_event(shard, {})) def test_subscribe_when_callback_is_not_coroutine(self, event_manager): def test(): @@ -143,8 +145,7 @@ def test__check_intents_when_intents_incorrect(self, event_manager): ) def test_get_listeners_when_not_event(self, event_manager): - with pytest.raises(TypeError): - event_manager.get_listeners("test") + assert len(event_manager.get_listeners("test")) == 0 def test_get_listeners_polimorphic(self, event_manager): event_manager._listeners = { diff --git a/tests/hikari/impl/test_shard.py b/tests/hikari/impl/test_shard.py index 4667eaf2ac..19d1b6b78a 100644 --- a/tests/hikari/impl/test_shard.py +++ b/tests/hikari/impl/test_shard.py @@ -292,10 +292,9 @@ def client(self, http_settings=http_settings, proxy_settings=proxy_settings): ) client = hikari_test_helpers.mock_methods_on( client, - except_=("_run_once_shielded", "_InvalidSession", "_Reconnect", "_SocketClosed", "_dispatch", "_Opcode",), - also_mock=["_backoff", "_handshake_event", "_request_close_event", "_logger"], + except_=("_run_once_shielded", "_InvalidSession", "_Reconnect", "_SocketClosed", "_Opcode",), + also_mock=["_backoff", "_handshake_event", "_request_close_event", "_logger", "_event_consumer"], ) - client._dispatch = mock.AsyncMock() # Disable backoff checking by making the condition a negative tautology. client._RESTART_RATELIMIT_WINDOW = -1 return client @@ -458,7 +457,7 @@ def client(self, http_settings, proxy_settings): client = hikari_test_helpers.mock_methods_on( client, except_=("_run_once", "_InvalidSession", "_Reconnect", "_SocketClosed", "_Opcode",), - also_mock=["_backoff", "_handshake_event", "_request_close_event", "_logger",], + also_mock=["_backoff", "_handshake_event", "_request_close_event", "_logger", "_event_consumer"], ) # Disable backoff checking by making the condition a negative tautology. client._RESTART_RATELIMIT_WINDOW = -1 @@ -617,7 +616,7 @@ class Error(Exception): with pytest.raises(Error): await client._run_once(client_session) - client._dispatch.assert_any_call("CONNECTED", {}) + client._event_consumer.assert_any_call(client, "CONNECTED", {}) @hikari_test_helpers.timeout() async def test_heartbeat_is_not_started_before_handshake_completes(self, client, client_session): @@ -665,14 +664,14 @@ async def test_heartbeat_is_stopped_when_poll_events_stops(self, client, client_ async def test_dispatches_disconnect_if_connected(self, client, client_session): await client._run_once(client_session) - client._dispatch.assert_any_call("CONNECTED", {}) - client._dispatch.assert_any_call("DISCONNECTED", {}) + client._event_consumer.assert_any_call(client, "CONNECTED", {}) + client._event_consumer.assert_any_call(client, "DISCONNECTED", {}) async def test_no_dispatch_disconnect_if_not_connected(self, client, client_session): client_session.ws_connect = mock.Mock(side_effect=RuntimeError) with pytest.raises(RuntimeError): await client._run_once(client_session) - client._dispatch.assert_not_called() + client._event_consumer.assert_not_called() async def test_connected_at_reset_to_None_on_exit(self, client, client_session): await client._run_once(client_session) @@ -1034,14 +1033,13 @@ async def test_when_opcode_is_DISPATCH_and_event_is_READY(self, client, exit_err "s": 101, } client._receive_json = mock.AsyncMock(side_effect=[payload, exit_error]) - client._dispatch = mock.Mock() timestamp = datetime.datetime.now() with mock.patch.object(hikari_date, "monotonic", return_value=timestamp): with pytest.raises(exit_error): await client._poll_events() - client._dispatch.assert_called_once_with("READY", data_payload) + client._event_consumer.assert_any_call(client, "READY", data_payload) assert client._handshake_event.is_set() assert client._session_id == 123 assert client._seq == 101 @@ -1057,12 +1055,11 @@ async def test_when_opcode_is_DISPATCH_and_event_is_RESUME(self, client, exit_er "s": 101, } client._receive_json = mock.AsyncMock(side_effect=[payload, exit_error]) - client._dispatch = mock.Mock() with pytest.raises(exit_error): await client._poll_events() - client._dispatch.assert_called_once_with("RESUME", "some data") + client._event_consumer.assert_any_call(client, "RESUME", "some data") assert client._handshake_event.is_set() @hikari_test_helpers.timeout() @@ -1074,12 +1071,11 @@ async def test_when_opcode_is_DISPATCH_and_event_is_not_handled(self, client, ex "s": 101, } client._receive_json = mock.AsyncMock(side_effect=[payload, exit_error]) - client._dispatch = mock.Mock() with pytest.raises(exit_error): await client._poll_events() - client._dispatch.assert_called_once_with("UNKNOWN", "some data") + client._event_consumer.assert_any_call(client, "UNKNOWN", "some data") @hikari_test_helpers.timeout() async def test_when_opcode_is_HEARTBEAT(self, client, exit_error): @@ -1284,21 +1280,6 @@ async def test_send_json(self, client): client._ws.send_str.assert_awaited_once_with('{"some": "payload"}') -class TestDispatch: - def test_dispatch(self, client): - mock_task = object() - mock_coroutine = object() - client._app = mock.Mock() - client._event_consumer = mock.Mock(return_value=mock_coroutine) - client._shard_id = 123 - - with mock.patch.object(asyncio, "create_task", return_value=mock_task) as create_task: - assert client._dispatch("MESSAGE_CREATE", {"some": "payload"}) == mock_task - - client._event_consumer.assert_called_once_with(client, "MESSAGE_CREATE", {"some": "payload"}) - create_task.assert_called_once_with(mock_coroutine, name="gateway shard 123 dispatch MESSAGE_CREATE") - - class TestLogDebugPayload: def test_when_logging_debug_disabled(self, client): client._logger.isEnabledFor = mock.Mock(return_value=False)