From 0abf17c2c5c71a5b1643b8ccd28ab69cab3bb49f Mon Sep 17 00:00:00 2001 From: "Nekoka.tt" <3903853-nekokatt@users.noreply.gitlab.com> Date: Wed, 16 Sep 2020 21:23:17 +0100 Subject: [PATCH 01/12] Fixed OSError caused by Windows in spurious cases on socket closure. In the process, I also: - Added a load of validation and documentation to settings config classes. - Made it so the pypi checker supports proxies. - Removed the verify_ssl and allow_redirects and merged them into a new ssl (ssl.SSLContext) and max_redirects parameter. - Creating standard TCPConnector and ClientSession instances now is done in hikari.utilities.net. - READY logging message for a shard now shows how many guilds that shard has, approximately. --- hikari/config.py | 304 +++++++++++++++++++++++++++----- hikari/impl/bot.py | 9 +- hikari/impl/rest.py | 41 +++-- hikari/impl/shard.py | 41 ++--- hikari/utilities/net.py | 102 ++++++++++- hikari/utilities/ux.py | 30 +++- tests/hikari/impl/test_shard.py | 32 ++-- 7 files changed, 444 insertions(+), 115 deletions(-) diff --git a/hikari/config.py b/hikari/config.py index 220e4e5839..cf309ca7d0 100644 --- a/hikari/config.py +++ b/hikari/config.py @@ -23,6 +23,8 @@ from __future__ import annotations +import yarl + __all__: typing.Final[typing.List[str]] = [ "BasicAuthHeader", "ProxySettings", @@ -31,6 +33,7 @@ ] import base64 +import ssl as ssl_ import typing import attr @@ -42,52 +45,136 @@ _PROXY_AUTHENTICATION_HEADER: typing.Final[str] = "Proxy-Authentication" +def _ssl_factory(value: typing.Union[bool, ssl_.SSLContext]) -> ssl_.SSLContext: + if isinstance(value, bool): + ssl = ssl_.create_default_context() + # We can't turn SSL verification off without disabling hostname verification first. + # If we are using verification, this will just leave it enabled, so it is fine. + ssl.check_hostname = value + ssl.verify_mode = ssl_.CERT_REQUIRED if value else ssl_.CERT_NONE + else: + ssl = value + return ssl + + @attr_extensions.with_copy -@attr.s(slots=True, kw_only=True, repr=False, weakref_slot=False) +@attr.s(slots=True, kw_only=True, repr=True, weakref_slot=False) class BasicAuthHeader: """An object that can be set as a producer for a basic auth header.""" - username: str = attr.ib() - """Username for the header.""" + username: str = attr.ib(validator=attr.validators.instance_of(str)) + """Username for the header. + + Returns + ------- + builtins.str + The username to use. This must not contain `":"`. + """ + + password: str = attr.ib(repr=False, validator=attr.validators.instance_of(str)) + """Password to use. + + Returns + ------- + builtins.str + The password to use. + """ + + charset: str = attr.ib(default="utf-8", validator=attr.validators.instance_of(str)) + """Encoding to use for the username and password. - password: str = attr.ib() - """Password for the header.""" + Default is `"utf-8"`, but you may choose to use something else, + including third-party encodings (e.g. IBM's EBCDIC codepages). + + Returns + ------- + builtins.str + The encoding to use. + """ @property def header(self) -> str: - """Generate the header value and return it.""" - raw_token = f"{self.username}:{self.password}".encode("ascii") - token_part = base64.b64encode(raw_token).decode("ascii") + """Create the full `Authentication` header value. + + Returns + ------- + builtins.str + A base64-encoded string containing + `"{username}:{password}`. 
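+
+        For example, username `"user"` with password `"pass"` yields
+        `"Basic dXNlcjpwYXNz"`, assuming `_BASICAUTH_TOKEN_PREFIX` is the
+        standard `"Basic"` scheme prefix (its definition is not shown in
+        this diff).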
+ """ + raw_token = f"{self.username}:{self.password}".encode(self.charset) + token_part = base64.b64encode(raw_token).decode(self.charset) return f"{_BASICAUTH_TOKEN_PREFIX} {token_part}" def __str__(self) -> str: return self.header - __repr__ = __str__ - @attr_extensions.with_copy @attr.s(slots=True, kw_only=True, weakref_slot=False) class ProxySettings: - """The proxy settings to use.""" + """Settings for configuring an HTTP-based proxy.""" + + auth: typing.Any = attr.ib(default=None) + """Authentication header value to use. + + When cast to a `builtins.str`, this should provide the full value + for the authentication header. - auth: typing.Optional[typing.Any] = attr.ib(default=None) - """An object that when cast to a string, yields the proxy auth header.""" + If you are using basic auth, you should consider using the + `BasicAuthHeader` helper object here, as this will provide any + transformations you may require into a base64 string. + + The default is to have this set to `builtins.None`, which will + result in no authentication being provided. + + Returns + ------- + typing.Any + The value for the `Authentication` header, or `builtins.None` + to disable. + """ headers: typing.Optional[data_binding.Headers] = attr.ib(default=None) """Additional headers to use for requests via a proxy, if required.""" - url: typing.Optional[str] = attr.ib(default=None) - """The URL of the proxy to use.""" + url: typing.Union[None, str, yarl.URL] = attr.ib(default=None) + """Proxy URL to use. - trust_env: bool = attr.ib(default=False) - """If `builtins.True`, and no proxy info is given, then `HTTP_PROXY` and - `HTTPS_PROXY` will be used from the environment variables if present. + Defaults to `builtins.None` which disables the use of an explicit proxy. + + Returns + ------- + typing.Union[builtins.None, builtins.str, yarl.URL] + The proxy URL to use, or `builtins.None` to disable it. + """ + + @url.validator + def _(self, _: attr.Attribute[typing.Optional[str]], value: typing.Optional[str]) -> None: + if value is not None and not isinstance(value, (str, yarl.URL)): + raise TypeError("ProxySettings.url must be None, a str, or a yarl.URL instance") + + trust_env: bool = attr.ib(default=False, validator=attr.validators.instance_of(bool)) + """Toggle whether to look for a `netrc` file or environment variables. + + If `builtins.True`, and no `url` is given on this object, then + `HTTP_PROXY` and `HTTPS_PROXY` will be used from the environment + variables, or a `netrc` file may be read to determine credentials. - Any proxy credentials will be read from the user's `netrc` file - (https://www.gnu.org/software/inetutils/manual/html_node/The-_002enetrc-file.html) If `builtins.False`, then this information is instead ignored. - Defaults to `builtins.False` if unspecified. + + Defaults to `builtins.False` to prevent potentially unwanted behavior. + + !!! note + For more details of using `netrc`, visit: + https://www.gnu.org/software/inetutils/manual/html_node/The-_002enetrc-file.html + + Returns + ------- + builtins.bool + `builtins.True` if allowing the use of environment variables + and/or `netrc` to determine proxy settings; `builtins.False` + if this should be disabled explicitly. """ @property @@ -116,45 +203,182 @@ class HTTPTimeoutSettings: """Settings to control HTTP request timeouts.""" acquire_and_connect: typing.Optional[float] = attr.ib(default=None) - """Timeout for `request_socket_connect` PLUS connection acquisition.""" + """Timeout for `request_socket_connect` PLUS connection acquisition. 
+ + By default, this has no timeout allocated. + + Returns + ------- + typing.Optional[builtins.float] + The timeout, or `builtins.None` to disable it. + """ request_socket_connect: typing.Optional[float] = attr.ib(default=None) - """Timeout for connecting a socket.""" + """Timeout for connecting a socket. + + By default, this has no timeout allocated. + + Returns + ------- + typing.Optional[builtins.float] + The timeout, or `builtins.None` to disable it. + """ request_socket_read: typing.Optional[float] = attr.ib(default=None) - """Timeout for reading a socket.""" + """Timeout for reading a socket. + + By default, this has no timeout allocated. + + Returns + ------- + typing.Optional[builtins.float] + The timeout, or `builtins.None` to disable it. + """ total: typing.Optional[float] = attr.ib(default=30.0) """Total timeout for entire request. - Defaults to 30 seconds. + By default, this has a 30 second timeout allocated. + + Returns + ------- + typing.Optional[builtins.float] + The timeout, or `builtins.None` to disable it. """ + @acquire_and_connect.validator + @request_socket_connect.validator + @request_socket_read.validator + @total.validator + def _(self, attrib: attr.Attribute[typing.Optional[float]], value: typing.Optional[float]) -> None: + # This error won't occur until some time in the future where it will be annoying to + # try and determine the root cause, so validate it NOW. + if value is not None and (not isinstance(value, (float, int)) or value <= 0): # type: ignore[unreachable] + raise ValueError(f"HTTPTimeoutSettings.{attrib.name} must be None, or a POSITIVE float/int") + @attr_extensions.with_copy @attr.s(slots=True, kw_only=True, weakref_slot=False) class HTTPSettings: - """Settings to control the HTTP client.""" + """Settings to control HTTP clients.""" + + enable_cleanup_closed: bool = attr.ib(default=True, validator=attr.validators.instance_of(bool)) + """Toggle whether to clean up closed transports. - allow_redirects: bool = attr.ib(default=False) - """If `builtins.True`, allow following redirects from `3xx` HTTP responses. + This defaults to `builtins.True` to combat various protocol and asyncio + issues present when using Microsoft Windows. If you are sure you know + what you are doing, you may instead set this to `False` to disable this + behavior internally. - Generally you do not want to enable this unless you have a good reason to. + Returns + ------- + builtins.bool + `builtins.True` to enable this behavior, `builtins.False` to disable + it. """ - max_redirects: int = attr.ib(default=10) - """The maximum number of redirects to allow. + force_close_transports: bool = attr.ib(default=True, validator=attr.validators.instance_of(bool)) + """Toggle whether to force close transports on shutdown. - If `allow_redirects` is `builtins.False`, then this is ignored. + This defaults to `builtins.True` to combat various protocol and asyncio + issues present when using Microsoft Windows. If you are sure you know + what you are doing, you may instead set this to `False` to disable this + behavior internally. + + Returns + ------- + builtins.bool + `builtins.True` to enable this behavior, `builtins.False` to disable + it. """ - timeouts: HTTPTimeoutSettings = attr.ib(factory=HTTPTimeoutSettings) - """Settings to control HTTP request timeouts.""" + max_redirects: typing.Optional[int] = attr.ib(default=10) + """Behavior for handling redirect HTTP responses. + + If a `builtins.int`, allow following redirects from `3xx` HTTP responses + for up to this many redirects. 
Exceeding this value will raise an exception.
+
+    If `builtins.None`, then disallow any redirects.
+
+    The default is `10`.
+
+    Generally, it is safer to disable redirects outright by setting this
+    to `builtins.None`. You may find a case in the future where you need
+    to enable this if Discord change their URL without warning.
+
+    !!! note
+        This will only apply to the REST API. WebSockets remain unaffected
+        by any value set here.
+
+    Returns
+    -------
+    typing.Optional[builtins.int]
+        The maximum number of redirects to allow per request.
+        `builtins.None` disables redirect handling entirely; a redirect
+        response will then raise an exception instead.
+    """
+
+    @max_redirects.validator
+    def _(self, _: attr.Attribute[typing.Optional[int]], value: typing.Optional[int]) -> None:
+        # This error won't occur until some time in the future where it will be annoying to
+        # try and determine the root cause, so validate it NOW.
+        if value is not None and (not isinstance(value, int) or value <= 0):  # type: ignore[unreachable]
+            raise ValueError("http_settings.max_redirects must be None or a POSITIVE integer")
+
+    ssl: ssl_.SSLContext = attr.ib(
+        default=True,
+        converter=_ssl_factory,
+        validator=attr.validators.instance_of(ssl_.SSLContext),  # type: ignore[assignment,arg-type]
+    )
+    """SSL context to use.
+
+    This may be __assigned__ a `builtins.bool` or an `ssl.SSLContext` object.
+
+    If assigned to `builtins.True`, a default SSL context is generated by
+    this class that will enforce SSL verification. This is then stored in
+    this field.
+
+    If `builtins.False`, then a default SSL context is generated by this
+    class that will **NOT** enforce SSL verification. This is then stored
+    in this field.
+
+    If an instance of `ssl.SSLContext`, then this context will be used.
+
+    !!! warning
+        Setting a custom value here may have security implications, or
+        may result in the application being unable to connect to Discord
+        at all.
+
+    !!! warning
+        Disabling SSL verification is almost always ill-advised. Your
+        application would no longer check whether it is connecting to
+        Discord, or to some third-party spoof designed to steal personal
+        credentials such as your application token.
+
+        There may be cases where SSL certificates do not get updated,
+        and disabling this explicitly may work around such issues, but
+        you should immediately seek a better solution where possible
+        if any form of security matters to you.
+
+    Returns
+    -------
+    ssl.SSLContext
+        The SSL context to use for this application.
+    """
+
+    timeouts: HTTPTimeoutSettings = attr.ib(
+        factory=HTTPTimeoutSettings, validator=attr.validators.instance_of(HTTPTimeoutSettings)
+    )
+    """Settings to control HTTP request timeouts.
+
+    If this is not explicitly defined, sane defaults that suit most
+    applications are used instead.
+
+    Returns
+    -------
+    HTTPTimeoutSettings
+        The HTTP timeout settings to use.
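+
+        For example, a hedged sketch that overrides only the total timeout
+        and keeps every other default:
+
+            hikari.config.HTTPSettings(timeouts=hikari.config.HTTPTimeoutSettings(total=60.0))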
""" diff --git a/hikari/impl/bot.py b/hikari/impl/bot.py index 3c7eeb2684..be412384b9 100644 --- a/hikari/impl/bot.py +++ b/hikari/impl/bot.py @@ -322,7 +322,7 @@ def __init__( # RESTful API. self._rest = rest_impl.RESTClientImpl( - connector_factory=rest_impl.BasicLazyCachedTCPConnectorFactory(), + connector_factory=rest_impl.BasicLazyCachedTCPConnectorFactory(self._http_settings), connector_owner=True, entity_factory=self._entity_factory, executor=self._executor, @@ -768,7 +768,10 @@ async def start( # Dispatch the update checker, the sharding requirements checker, and dispatch # the starting event together to save a little time on startup. if check_for_updates: - asyncio.create_task(ux.check_for_updates(), name="check for package updates") + asyncio.create_task( + ux.check_for_updates(self._http_settings, self._proxy_settings), + name="check for package updates", + ) requirements_task = asyncio.create_task(self._rest.fetch_gateway_bot(), name="fetch gateway sharding settings") await self.dispatch(lifetime_events.StartingEvent(app=self)) requirements = await requirements_task @@ -822,7 +825,7 @@ async def start( # die in this time, we shut down immediately. # If we time out, the joining tasks get discarded and we spin up the next # block of shards, if applicable. - _LOGGER.debug("waiting for 5 seconds until next shard startup window") + _LOGGER.info("the next startup window is in 5 seconds, please wait...") await aio.first_completed(aio.all_of(*shard_joiners, timeout=5), close_waiter) if not close_waiter.cancelled(): diff --git a/hikari/impl/rest.py b/hikari/impl/rest.py index f5ff635017..1cce95950e 100644 --- a/hikari/impl/rest.py +++ b/hikari/impl/rest.py @@ -119,13 +119,11 @@ class BasicLazyCachedTCPConnectorFactory(rest_api.ConnectorFactory): """Lazy cached TCP connector factory.""" - __slots__: typing.Sequence[str] = ("connector", "connector_kwargs") + __slots__: typing.Sequence[str] = ("connector", "http_settings") - def __init__(self, **kwargs: typing.Any) -> None: + def __init__(self, http_settings: config.HTTPSettings) -> None: self.connector: typing.Optional[aiohttp.TCPConnector] = None - kwargs.setdefault("enable_cleanup_closed", True) - kwargs.setdefault("force_close", True) - self.connector_kwargs = kwargs + self.http_settings = http_settings async def close(self) -> None: if self.connector is not None: @@ -134,7 +132,7 @@ async def close(self) -> None: def acquire(self) -> aiohttp.BaseConnector: if self.connector is None: - self.connector = aiohttp.TCPConnector(**self.connector_kwargs) + self.connector = net.create_tcp_connector(self.http_settings) return self.connector @@ -207,6 +205,10 @@ class RESTApp(traits.ExecutorAware): manually. The latter is useful if you wish to maintain a shared connection pool across your application with other non-Hikari components. + + !!! warning + If you do not give a `connector_factory`, this will be IGNORED + and always be treated as `builtins.True` internally. executor : typing.Optional[concurrent.futures.Executor] The executor to use for blocking file IO operations. 
If `builtins.None` is passed, then the default `concurrent.futures.ThreadPoolExecutor` for @@ -246,6 +248,9 @@ def __init__( proxy_settings: typing.Optional[config.ProxySettings] = None, url: typing.Optional[str] = None, ) -> None: + self._http_settings = config.HTTPSettings() if http_settings is None else http_settings + self._proxy_settings = config.ProxySettings() if proxy_settings is None else proxy_settings + # Lazy initialized later, since we must initialize this in the event # loop we run the application from, otherwise aiohttp throws complaints # at us. Quart, amongst other libraries, causes issues with this by @@ -253,12 +258,14 @@ def __init__( # the connector here and initialised this class in global scope, it # would potentially end up using the wrong event loop and aiohttp # would then fail when creating an HTTP request. - self._connector_factory: rest_api.ConnectorFactory = connector_factory or BasicLazyCachedTCPConnectorFactory() + if connector_factory is None: + connector_factory = BasicLazyCachedTCPConnectorFactory(self._http_settings) + connector_owner = True + + self._connector_factory = connector_factory self._connector_owner = connector_owner self._event_loop: typing.Optional[asyncio.AbstractEventLoop] = None self._executor = executor - self._http_settings = config.HTTPSettings() if http_settings is None else http_settings - self._proxy_settings = config.ProxySettings() if proxy_settings is None else proxy_settings self._url = url @property @@ -480,16 +487,13 @@ def __exit__(self, exc_type: typing.Type[Exception], exc_val: Exception, exc_tb: def _acquire_client_session(self) -> aiohttp.ClientSession: if self._client_session is None: self._closed_event.clear() - self._client_session = aiohttp.ClientSession( + self._client_session = net.create_client_session( connector=self._connector_factory.acquire(), + # No, this is correct. We manage closing the connector ourselves in this class if we are + # told we own it. This works around some other lifespan issues. 
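+                # aiohttp will therefore never close this connector itself when
+                # the session closes; the factory that produced it stays in
+                # charge of its lifetime.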
connector_owner=False, - version=aiohttp.HttpVersion11, - timeout=aiohttp.ClientTimeout( - total=self._http_settings.timeouts.total, - connect=self._http_settings.timeouts.acquire_and_connect, - sock_read=self._http_settings.timeouts.request_socket_read, - sock_connect=self._http_settings.timeouts.request_socket_connect, - ), + http_settings=self._http_settings, + raise_for_status=False, trust_env=self._proxy_settings.trust_env, ) _LOGGER.log(ux.TRACE, "acquired new aiohttp client session") @@ -558,11 +562,10 @@ async def _request( params=query, json=json, data=form, - allow_redirects=self._http_settings.allow_redirects, + allow_redirects=self._http_settings.max_redirects is not None, max_redirects=self._http_settings.max_redirects, proxy=self._proxy_settings.url, proxy_headers=self._proxy_settings.all_headers, - verify_ssl=self._http_settings.verify_ssl, ) time_taken = (date.monotonic() - start) * 1_000 diff --git a/hikari/impl/shard.py b/hikari/impl/shard.py index fa57d123d7..870a1e4081 100644 --- a/hikari/impl/shard.py +++ b/hikari/impl/shard.py @@ -47,6 +47,7 @@ from hikari.impl import rate_limits from hikari.utilities import data_binding from hikari.utilities import date +from hikari.utilities import net from hikari.utilities import ux if typing.TYPE_CHECKING: @@ -228,9 +229,9 @@ async def _receive_and_check(self, timeout: typing.Optional[float], /) -> str: async def connect( cls, *, - http_config: config.HTTPSettings, + http_settings: config.HTTPSettings, logger: logging.Logger, - proxy_config: config.ProxySettings, + proxy_settings: config.ProxySettings, log_filterer: typing.Callable[[str], str], url: str, ) -> typing.AsyncGenerator[_V6GatewayTransport, None]: @@ -246,32 +247,16 @@ async def connect( exit_stack = contextlib.AsyncExitStack() try: + connector = net.create_tcp_connector(http_settings, dns_cache=False, limit=1) client_session = await exit_stack.enter_async_context( - aiohttp.ClientSession( - connector=aiohttp.TCPConnector( - limit=1, - use_dns_cache=False, - verify_ssl=http_config.verify_ssl, - enable_cleanup_closed=True, - force_close=True, - ), - raise_for_status=True, - timeout=aiohttp.ClientTimeout( - total=http_config.timeouts.total, - connect=http_config.timeouts.acquire_and_connect, - sock_read=http_config.timeouts.request_socket_read, - sock_connect=http_config.timeouts.request_socket_connect, - ), - trust_env=proxy_config.trust_env, - ws_response_class=cls, - ) + net.create_client_session(connector, True, http_settings, True, proxy_settings.trust_env, cls) ) web_socket = await exit_stack.enter_async_context( client_session.ws_connect( max_msg_size=0, - proxy=proxy_config.url, - proxy_headers=proxy_config.headers, + proxy=proxy_settings.url, + proxy_headers=proxy_settings.headers, url=url, ) ) @@ -641,11 +626,13 @@ def _dispatch(self, name: str, seq: int, data: data_binding.JSONObject) -> None: user_id = user_pl["id"] self._user_id = snowflakes.Snowflake(user_id) tag = user_pl["username"] + "#" + user_pl["discriminator"] + unavailable_guild_count = len(data["guilds"]) self._logger.info( - "shard is ready [session:%s, user_id:%s, tag:%s]", + "shard is ready [session:%s, user_id:%s, tag:%s, guilds:%s]", self._session_id, user_id, tag, + unavailable_guild_count, ) self._handshake_completed.set() @@ -774,8 +761,8 @@ async def _run(self) -> None: try: last_started_at = date.monotonic() - if not await self._run_once(): - self._logger.debug("shard has shut down") + # TODO: should I be using the result of this still, or is it dead code? Is it a bug I created? 
+            await self._run_once()
 
         except errors.GatewayConnectionError as ex:
             self._logger.error(
                 "failed to communicate with server, reason was: %s. Will retry shortly",
                 ex.__cause__,
             )
@@ -814,10 +801,10 @@ async def _run_once(self) -> bool:
 
         self._ws = await exit_stack.enter_async_context(
             _V6GatewayTransport.connect(
-                http_config=self._http_settings,
+                http_settings=self._http_settings,
                 log_filterer=_log_filterer(self._token),
                 logger=self._logger,
-                proxy_config=self._proxy_settings,
+                proxy_settings=self._proxy_settings,
                 url=self._url,
             )
         )
diff --git a/hikari/utilities/net.py b/hikari/utilities/net.py
index 945cd5c1b3..f03a3c047e 100644
--- a/hikari/utilities/net.py
+++ b/hikari/utilities/net.py
@@ -23,15 +23,15 @@
 
 from __future__ import annotations
 
-__all__: typing.Final[typing.List[str]] = ["generate_error_response"]
+__all__: typing.Final[typing.List[str]] = ["generate_error_response", "create_tcp_connector", "create_client_session"]
 
 import http
 import typing
 
-from hikari import errors
+import aiohttp
 
-if typing.TYPE_CHECKING:
-    import aiohttp
+from hikari import config
+from hikari import errors
 
 
 async def generate_error_response(response: aiohttp.ClientResponse) -> errors.HTTPError:
@@ -59,3 +59,97 @@ async def generate_error_response(response: aiohttp.ClientResponse) -> errors.HT
         cls = errors.HTTPResponseError
 
     return cls(real_url, status, response.headers, raw_body)
+
+
+def create_tcp_connector(
+    http_settings: config.HTTPSettings,
+    *,
+    dns_cache: typing.Union[bool, int] = True,
+    limit: int = 100,
+) -> aiohttp.TCPConnector:
+    """Create a TCP connector and return it.
+
+    Parameters
+    ----------
+    dns_cache : typing.Union[builtins.bool, builtins.int]
+        If `builtins.True`, DNS caching is used with a default TTL of 10 seconds.
+        If `builtins.False`, DNS caching is disabled. If a `builtins.int` is
+        given, then DNS caching is enabled with an explicit TTL set.
+    http_settings : config.HTTPSettings
+        HTTP settings to use for the connector.
+    limit : builtins.int
+        Maximum number of connections to allow in the pool.
+
+    Returns
+    -------
+    aiohttp.TCPConnector
+        TCP connector to use.
+    """
+    return aiohttp.TCPConnector(
+        enable_cleanup_closed=http_settings.enable_cleanup_closed,
+        force_close=http_settings.force_close_transports,
+        limit=limit,
+        ssl_context=http_settings.ssl,
+        ttl_dns_cache=dns_cache if not isinstance(dns_cache, bool) else 10,
+        use_dns_cache=dns_cache is not False,
+    )
+
+
+def create_client_session(
+    connector: aiohttp.BaseConnector,
+    connector_owner: bool,
+    http_settings: config.HTTPSettings,
+    raise_for_status: bool,
+    trust_env: bool,
+    ws_response_cls: typing.Type[aiohttp.ClientWebSocketResponse] = aiohttp.ClientWebSocketResponse,
+) -> aiohttp.ClientSession:
+    """Generate a client session using the given settings.
+
+    !!! warning
+        You must invoke this from within a running event loop.
+
+    !!! note
+        If you pass an explicit connector and set `connector_owner` to
+        `builtins.False`, the session will not own the connector. You will
+        be expected to manually close it __after__ the returned client
+        session is closed to prevent leaking resources.
+
+    Parameters
+    ----------
+    connector : aiohttp.BaseConnector
+        The connector to use.
+    connector_owner : builtins.bool
+        If `builtins.True`, then the client session will close the
+        connector on shutdown. Otherwise, you must do it manually.
+    http_settings : hikari.config.HTTPSettings
+        HTTP settings to use.
+    raise_for_status : builtins.bool
+        `builtins.True` to raise an exception by default if a request
+        fails, or `builtins.False` to return the response regardless.
+    trust_env : builtins.bool
+        `builtins.True` to trust anything in environment variables
+        and the `netrc` file, `builtins.False` to ignore it.
+    ws_response_cls : typing.Type[aiohttp.ClientWebSocketResponse]
+        The WebSocket response class to use when creating websockets
+        with this session. Defaults to `aiohttp.ClientWebSocketResponse`.
+
+    Returns
+    -------
+    aiohttp.ClientSession
+        The client session to use.
+    """
+    return aiohttp.ClientSession(
+        connector=connector,
+        connector_owner=connector_owner,
+        raise_for_status=raise_for_status,
+        timeout=aiohttp.ClientTimeout(
+            connect=http_settings.timeouts.acquire_and_connect,
+            sock_connect=http_settings.timeouts.request_socket_connect,
+            sock_read=http_settings.timeouts.request_socket_read,
+            total=http_settings.timeouts.total,
+        ),
+        trust_env=trust_env,
+        version=aiohttp.HttpVersion11,
+        ws_response_class=ws_response_cls,
+    )
diff --git a/hikari/utilities/ux.py b/hikari/utilities/ux.py
index f9b714c4fc..cb7c78961d 100644
--- a/hikari/utilities/ux.py
+++ b/hikari/utilities/ux.py
@@ -22,6 +22,8 @@
 """User-experience extensions and utilities."""
 from __future__ import annotations
 
+from hikari.utilities import net
+
 __all__: typing.List[str] = ["init_logging", "print_banner", "supports_color", "HikariVersion", "check_for_updates"]
 
 import contextlib
@@ -35,11 +37,13 @@
 import sys
 import typing
 
-import aiohttp
 import colorlog  # type: ignore[import]
 
 from hikari import _about as about
 
+if typing.TYPE_CHECKING:
+    from hikari import config
+
 # While this is discouraged for most purposes in libraries, this enables us to
 # filter out the vast majority of clutter that most network logger calls
 # create. This also has a very minute performance improvement for trace logging
@@ -241,13 +245,27 @@ class HikariVersion(distutils.version.StrictVersion):
 )
 
 
-async def check_for_updates() -> None:
+async def check_for_updates(
+    http_settings: config.HTTPSettings,
+    proxy_settings: config.ProxySettings,
+) -> None:
     """Perform a check for newer versions of the library, logging any found."""
     try:
-        async with aiohttp.request(
-            "GET", "https://pypi.org/pypi/hikari/json", timeout=aiohttp.ClientTimeout(total=1.5), raise_for_status=True
-        ) as resp:
-            data = await resp.json()
+        async with net.create_client_session(
+            connector=net.create_tcp_connector(dns_cache=False, limit=1, http_settings=http_settings),
+            connector_owner=True,
+            http_settings=http_settings,
+            raise_for_status=True,
+            trust_env=proxy_settings.trust_env,
+        ) as cs:
+            async with cs.get(
+                "https://pypi.org/pypi/hikari/json",
+                allow_redirects=http_settings.max_redirects is not None,
+                max_redirects=http_settings.max_redirects if http_settings.max_redirects is not None else 10,
+                proxy=proxy_settings.url,
+                proxy_headers=proxy_settings.all_headers,
+            ) as resp:
+                data = await resp.json()
 
         this_version = HikariVersion(about.__version__)
         is_dev = this_version.prerelease is not None
diff --git a/tests/hikari/impl/test_shard.py b/tests/hikari/impl/test_shard.py
index c70d05d613..7b3c5dae5d 100644
--- a/tests/hikari/impl/test_shard.py
+++ b/tests/hikari/impl/test_shard.py
@@ -247,8 +247,8 @@ def __init__(self):
 
         with stack:
             async with shard._V6GatewayTransport.connect(
-                http_config=http_settings,
-                proxy_config=proxy_settings,
+                http_settings=http_settings,
+                proxy_settings=proxy_settings,
                 logger=logger,
                 url="https://some.url",
                 log_filterer=log_filterer,
@@ -309,8 +309,8 @@ def __init__(self):
 
         with stack:
             async with shard._V6GatewayTransport.connect(
-                http_config=http_settings,
-                proxy_config=proxy_settings,
+
http_settings=http_settings, + proxy_settings=proxy_settings, logger=logger, url="https://some.url", log_filterer=log_filterer, @@ -349,8 +349,8 @@ def __init__(self): with stack: async with shard._V6GatewayTransport.connect( - http_config=http_settings, - proxy_config=proxy_settings, + http_settings=http_settings, + proxy_settings=proxy_settings, logger=logger, url="https://some.url", log_filterer=log_filterer, @@ -389,8 +389,8 @@ def __init__(self): with stack: async with shard._V6GatewayTransport.connect( - http_config=http_settings, - proxy_config=proxy_settings, + http_settings=http_settings, + proxy_settings=proxy_settings, logger=logger, url="https://some.url", log_filterer=log_filterer, @@ -428,8 +428,8 @@ def __init__(self): with stack: async with shard._V6GatewayTransport.connect( - http_config=http_settings, - proxy_config=proxy_settings, + http_settings=http_settings, + proxy_settings=proxy_settings, logger=logger, url="https://some.url", log_filterer=log_filterer, @@ -461,8 +461,8 @@ async def test_connect_when_error_connecting(self, http_settings, proxy_settings with stack: async with shard._V6GatewayTransport.connect( - http_config=http_settings, - proxy_config=proxy_settings, + http_settings=http_settings, + proxy_settings=proxy_settings, logger=logger, url="https://some.url", log_filterer=log_filterer, @@ -499,8 +499,8 @@ async def test_connect_when_handshake_error_with_unknown_reason(self, http_setti with stack: async with shard._V6GatewayTransport.connect( - http_config=http_settings, - proxy_config=proxy_settings, + http_settings=http_settings, + proxy_settings=proxy_settings, logger=logger, url="https://some.url", log_filterer=log_filterer, @@ -537,8 +537,8 @@ async def test_connect_when_handshake_error_with_known_reason(self, http_setting with stack: async with shard._V6GatewayTransport.connect( - http_config=http_settings, - proxy_config=proxy_settings, + http_settings=http_settings, + proxy_settings=proxy_settings, logger=logger, url="https://some.url", log_filterer=log_filterer, From f61e7f3da7503bc9862c2c21355b802a6500c3c0 Mon Sep 17 00:00:00 2001 From: "Nekoka.tt" <3903853-nekokatt@users.noreply.gitlab.com> Date: Wed, 16 Sep 2020 21:43:08 +0100 Subject: [PATCH 02/12] Improve logging for sharded bots. --- hikari/impl/bot.py | 4 ++++ hikari/impl/shard.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/hikari/impl/bot.py b/hikari/impl/bot.py index be412384b9..adb0e97ecc 100644 --- a/hikari/impl/bot.py +++ b/hikari/impl/bot.py @@ -767,6 +767,8 @@ async def start( # Dispatch the update checker, the sharding requirements checker, and dispatch # the starting event together to save a little time on startup. 
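+        # Record when startup began so the total boot time can be logged once
+        # everything is up (see the "started successfully" message added below).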
+ start_time = date.monotonic() + if check_for_updates: asyncio.create_task( ux.check_for_updates(self._http_settings, self._proxy_settings), @@ -867,6 +869,8 @@ async def start( await self.dispatch(lifetime_events.StartedEvent(app=self)) + _LOGGER.info("application started successfully in approx %.0f seconds", date.monotonic() - start_time) + def stream( self, event_type: typing.Type[event_dispatcher.EventT_co], diff --git a/hikari/impl/shard.py b/hikari/impl/shard.py index 870a1e4081..2b093af234 100644 --- a/hikari/impl/shard.py +++ b/hikari/impl/shard.py @@ -790,7 +790,7 @@ async def _run(self) -> None: finally: self._closed.set() - self._logger.info("shard %s has shut down", self._shard_id) + self._logger.info("shard has disconnected and shut down", self._shard_id) async def _run_once(self) -> bool: self._closing.clear() From 3dda9f9a95261a72f417272d5526d37415eab7e8 Mon Sep 17 00:00:00 2001 From: "Nekoka.tt" <3903853-nekokatt@users.noreply.gitlab.com> Date: Wed, 16 Sep 2020 21:49:12 +0100 Subject: [PATCH 03/12] Fixed an issue with 32 bit Python distributions. Seems hash() returns a 32-bit value on a 32 bit computer. That caused a test case to fail by chance. That has now been fixed. --- tests/hikari/test_snowflake.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/hikari/test_snowflake.py b/tests/hikari/test_snowflake.py index ca746ace5b..98f5c8ccf6 100644 --- a/tests/hikari/test_snowflake.py +++ b/tests/hikari/test_snowflake.py @@ -53,7 +53,7 @@ def test_internal_worker_id(self, neko_snowflake): assert neko_snowflake.internal_worker_id == 2 def test_hash(self, neko_snowflake, raw_id): - assert hash(neko_snowflake) == raw_id + assert hash(neko_snowflake) == hash(raw_id) def test_int_cast(self, neko_snowflake, raw_id): assert int(neko_snowflake) == raw_id @@ -107,7 +107,7 @@ def test_created_at(self, neko_unique): ) def test__hash__(self, neko_unique, raw_id): - assert hash(neko_unique) == raw_id + assert hash(neko_unique) == hash(raw_id) def test__eq__(self, neko_snowflake, raw_id): class NekoUnique(snowflakes.Unique): From 76b59435577401efffa4d929c1845f55decb6558 Mon Sep 17 00:00:00 2001 From: "Nekoka.tt" <3903853-nekokatt@users.noreply.gitlab.com> Date: Wed, 16 Sep 2020 22:13:53 +0100 Subject: [PATCH 04/12] Fixed test cases, re-enabled logging of reconnect/invalid-session messages on INFO. --- hikari/impl/shard.py | 8 +++---- tests/hikari/impl/test_rest.py | 24 +++++++++++++++----- tests/hikari/impl/test_shard.py | 40 ++++++++++++++++++++++++--------- 3 files changed, 53 insertions(+), 19 deletions(-) diff --git a/hikari/impl/shard.py b/hikari/impl/shard.py index 2b093af234..156ee65a40 100644 --- a/hikari/impl/shard.py +++ b/hikari/impl/shard.py @@ -712,16 +712,16 @@ async def _poll_events(self) -> typing.Optional[bool]: self._logger.log(ux.TRACE, "received HEARTBEAT ACK in %.1fms", self._heartbeat_latency * 1_000) elif op == _RECONNECT: # We should be able to resume... - self._logger.debug("received instruction to reconnect, will resume existing session") + self._logger.info("received instruction to reconnect, will resume existing session") return True elif op == _INVALID_SESSION: # We can resume if the payload was `true`. 
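+            # (`d` is this opcode's payload: Discord sends a boolean here
+            # indicating whether the session may be resumed.)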
if not d: - self._logger.debug("received invalid session, will need to start a new session") + self._logger.info("received invalid session, will need to start a new session") self._seq = None self._session_id = None else: - self._logger.debug("received invalid session, will resume existing session") + self._logger.info("received invalid session, will resume existing session") return True else: self._logger.log(ux.TRACE, "unknown opcode %s received, it will be ignored...", op) @@ -790,7 +790,7 @@ async def _run(self) -> None: finally: self._closed.set() - self._logger.info("shard has disconnected and shut down", self._shard_id) + self._logger.info("shard has disconnected and shut down") async def _run_once(self) -> bool: self._closing.clear() diff --git a/tests/hikari/impl/test_rest.py b/tests/hikari/impl/test_rest.py index 8e5de6b3a0..694e5ea959 100644 --- a/tests/hikari/impl/test_rest.py +++ b/tests/hikari/impl/test_rest.py @@ -54,19 +54,31 @@ @pytest.fixture() -def connector_factory(): - return rest.BasicLazyCachedTCPConnectorFactory(test=123) +def http_settings(): + return mock.Mock(spec_set=config.HTTPSettings) + + +@pytest.fixture() +def connector_factory(http_settings): + return rest.BasicLazyCachedTCPConnectorFactory(http_settings) class TestBasicLazyCachedTCPConnectorFactory: - def test_acquire_when_connector_is_None(self, connector_factory): + def test_acquire_when_connector_is_None(self, connector_factory, http_settings): connector_mock = object() connector_factory.connector = None with mock.patch.object(aiohttp, "TCPConnector", return_value=connector_mock) as tcp_connector: assert connector_factory.acquire() is connector_mock assert connector_factory.connector is connector_mock - tcp_connector.assert_called_once_with(test=123, force_close=True, enable_cleanup_closed=True) + tcp_connector.assert_called_once_with( + force_close=http_settings.force_close_transports, + enable_cleanup_closed=http_settings.enable_cleanup_closed, + limit=100, + ssl_context=http_settings.ssl, + ttl_dns_cache=10, + use_dns_cache=True, + ) def test_acquire_when_connector_is_not_None(self, connector_factory): connector_mock = object() @@ -404,9 +416,11 @@ def test__acquire_client_session_when_None(self, rest_client): client_session.assert_called_once_with( connector=connector_mock, connector_owner=False, - version=aiohttp.HttpVersion11, + raise_for_status=False, timeout=aiohttp.ClientTimeout(total=10, connect=5, sock_read=4, sock_connect=1), trust_env=False, + version=aiohttp.HttpVersion11, + ws_response_class=aiohttp.ClientWebSocketResponse, ) def test__acquire_client_session_when_not_None_and_open(self, rest_client): diff --git a/tests/hikari/impl/test_shard.py b/tests/hikari/impl/test_shard.py index 7b3c5dae5d..0060488abc 100644 --- a/tests/hikari/impl/test_shard.py +++ b/tests/hikari/impl/test_shard.py @@ -257,10 +257,11 @@ def __init__(self): tcp_connector.assert_called_once_with( limit=1, + ttl_dns_cache=10, use_dns_cache=False, - verify_ssl=http_settings.verify_ssl, - enable_cleanup_closed=True, - force_close=True, + ssl_context=http_settings.ssl, + enable_cleanup_closed=http_settings.enable_cleanup_closed, + force_close=http_settings.force_close_transports, ) client_timeout.assert_called_once_with( total=http_settings.timeouts.total, @@ -270,9 +271,11 @@ def __init__(self): ) client_session.assert_called_once_with( connector=tcp_connector(), + connector_owner=True, raise_for_status=True, timeout=client_timeout(), trust_env=proxy_settings.trust_env, + version=aiohttp.HttpVersion11, 
ws_response_class=shard._V6GatewayTransport, ) mock_client_session.ws_connect.assert_called_once_with( @@ -827,7 +830,7 @@ async def test_update_voice_state(self, client, channel, self_deaf, self_mute): client._ws.send_json.assert_awaited_once_with({"op": 4, "d": payload}) - def test__dipatch_when_READY(self, client): + def test_dispatch_when_READY(self, client): client._seq = 0 client._session_id = 0 client._user_id = 0 @@ -836,21 +839,38 @@ def test__dipatch_when_READY(self, client): client._event_consumer = mock.Mock() client._dispatch( - "READY", 10, {"session_id": 101, "user": {"id": 123, "username": "hikari", "discriminator": "5863"}} + "READY", + 10, + { + "session_id": 101, + "user": {"id": 123, "username": "hikari", "discriminator": "5863"}, + "guilds": [ + {"id": "123"}, + {"id": "456"}, + {"id": "789"}, + ], + }, ) assert client._seq == 10 assert client._session_id == 101 assert client._user_id == 123 client._logger.info.assert_called_once_with( - "shard is ready [session:%s, user_id:%s, tag:%s]", - 101, - 123, - "hikari#5863", + "shard is ready [session:%s, user_id:%s, tag:%s, guilds:%s]", 101, 123, "hikari#5863", 3 ) client._handshake_completed.set.assert_called_once_with() client._event_consumer.assert_called_once_with( - client, "READY", {"session_id": 101, "user": {"id": 123, "username": "hikari", "discriminator": "5863"}} + client, + "READY", + { + "session_id": 101, + "user": {"id": 123, "username": "hikari", "discriminator": "5863"}, + "guilds": [ + {"id": "123"}, + {"id": "456"}, + {"id": "789"}, + ], + }, ) def test__dipatch_when_RESUME(self, client): From edee462ce7abd13c636db9c7dfcbd423442d5a3e Mon Sep 17 00:00:00 2001 From: "Nekoka.tt" <3903853-nekokatt@users.noreply.gitlab.com> Date: Wed, 16 Sep 2020 22:25:14 +0100 Subject: [PATCH 05/12] Tweaked shard logging. --- hikari/impl/shard.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/hikari/impl/shard.py b/hikari/impl/shard.py index 156ee65a40..b0fba5aa57 100644 --- a/hikari/impl/shard.py +++ b/hikari/impl/shard.py @@ -84,7 +84,7 @@ _BACKOFF_WINDOW: typing.Final[float] = 30.0 _BACKOFF_BASE: typing.Final[float] = 1.85 _BACKOFF_INCREMENT_START: typing.Final[int] = 2 -_BACKOFF_CAP: typing.Final[float] = 600.0 +_BACKOFF_CAP: typing.Final[float] = 60.0 # Discord seems to invalidate sessions if I send a 1xxx, which is useless # for invalid session and reconnect messages where I want to be able to # resume. @@ -763,6 +763,7 @@ async def _run(self) -> None: last_started_at = date.monotonic() # TODO: should I be using the result of this still, or is it dead code? Is it a bug I created? await self._run_once() + self._logger.info("shard has disconnected and shut down normally") except errors.GatewayConnectionError as ex: self._logger.error( @@ -780,17 +781,21 @@ async def _run(self) -> None: ex.reason, ) + # We don't want to back off from this. If Discord keep closing the connection, it is their issue. + # If we back off here, we'll find a mass outage will prevent shards from becoming healthy on + # reconnect in large sharded bots for a very long period of time. 
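+            # (Assuming each delay is roughly base ** increment, the constants
+            # above give ~3.4s, 6.3s, 11.7s, 21.7s, ... capped at _BACKOFF_CAP,
+            # so skipping the backoff here matters at scale.)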
+ backoff.reset() + except errors.GatewayError as ex: - self._logger.debug("encountered generic gateway error", exc_info=ex) + self._logger.error("encountered generic gateway error", exc_info=ex) raise except Exception as ex: - self._logger.debug("encountered some unhandled error", exc_info=ex) + self._logger.error("encountered some unhandled error", exc_info=ex) raise finally: self._closed.set() - self._logger.info("shard has disconnected and shut down") async def _run_once(self) -> bool: self._closing.clear() From 690a5ebbaadc2d5afaf42b4c655cd5903d02f5cf Mon Sep 17 00:00:00 2001 From: "Nekoka.tt" <3903853-nekokatt@users.noreply.gitlab.com> Date: Wed, 16 Sep 2020 22:48:21 +0100 Subject: [PATCH 06/12] Fixed regression where zombie caused stuff to die permanently. --- hikari/impl/shard.py | 86 +++++++++++++++++++++++--------------------- 1 file changed, 45 insertions(+), 41 deletions(-) diff --git a/hikari/impl/shard.py b/hikari/impl/shard.py index b0fba5aa57..64745e894c 100644 --- a/hikari/impl/shard.py +++ b/hikari/impl/shard.py @@ -746,56 +746,60 @@ async def _run(self) -> None: initial_increment=_BACKOFF_INCREMENT_START, ) - while True: - if date.monotonic() - last_started_at < _BACKOFF_WINDOW: - time = next(backoff) - self._logger.info("backing off reconnecting for %.2fs", time) + try: + while True: + if date.monotonic() - last_started_at < _BACKOFF_WINDOW: + time = next(backoff) + self._logger.info("backing off reconnecting for %.2fs", time) + + try: + await asyncio.wait_for(self._closing.wait(), timeout=time) + # We were told to close. + return + except asyncio.TimeoutError: + # We are going to run once. + pass try: - await asyncio.wait_for(self._closing.wait(), timeout=time) - # We were told to close. - return - except asyncio.TimeoutError: - # We are going to run once. - pass + last_started_at = date.monotonic() + should_die = await self._run_once() - try: - last_started_at = date.monotonic() - # TODO: should I be using the result of this still, or is it dead code? Is it a bug I created? - await self._run_once() - self._logger.info("shard has disconnected and shut down normally") - - except errors.GatewayConnectionError as ex: - self._logger.error( - "failed to communicate with server, reason was: %s. Will retry shortly", - ex.__cause__, - ) + if not should_die: + continue - except errors.GatewayServerClosedConnectionError as ex: - if not ex.can_reconnect: - raise + self._logger.info("shard has disconnected and shut down normally") + return - self._logger.info( - "server has closed connection, will reconnect if possible [code:%s, reason:%s]", - ex.code, - ex.reason, - ) + except errors.GatewayConnectionError as ex: + self._logger.error( + "failed to communicate with server, reason was: %s. Will retry shortly", + ex.__cause__, + ) - # We don't want to back off from this. If Discord keep closing the connection, it is their issue. - # If we back off here, we'll find a mass outage will prevent shards from becoming healthy on - # reconnect in large sharded bots for a very long period of time. - backoff.reset() + except errors.GatewayServerClosedConnectionError as ex: + if not ex.can_reconnect: + raise - except errors.GatewayError as ex: - self._logger.error("encountered generic gateway error", exc_info=ex) - raise + self._logger.info( + "server has closed connection, will reconnect if possible [code:%s, reason:%s]", + ex.code, + ex.reason, + ) - except Exception as ex: - self._logger.error("encountered some unhandled error", exc_info=ex) - raise + # We don't want to back off from this. 
If Discord keep closing the connection, it is their issue. + # If we back off here, we'll find a mass outage will prevent shards from becoming healthy on + # reconnect in large sharded bots for a very long period of time. + backoff.reset() - finally: - self._closed.set() + except errors.GatewayError as ex: + self._logger.error("encountered generic gateway error", exc_info=ex) + raise + + except Exception as ex: + self._logger.error("encountered some unhandled error", exc_info=ex) + raise + finally: + self._closed.set() async def _run_once(self) -> bool: self._closing.clear() From ab8218b06a1c6f5a0f3e5227db68710b7a517922 Mon Sep 17 00:00:00 2001 From: "Nekoka.tt" <3903853-nekokatt@users.noreply.gitlab.com> Date: Wed, 16 Sep 2020 22:50:20 +0100 Subject: [PATCH 07/12] Fixed zombie logging message. --- hikari/impl/shard.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/hikari/impl/shard.py b/hikari/impl/shard.py index 64745e894c..b2453a79ae 100644 --- a/hikari/impl/shard.py +++ b/hikari/impl/shard.py @@ -674,7 +674,11 @@ async def _heartbeat(self, heartbeat_interval: float) -> bool: while True: if self._last_heartbeat_ack_received <= self._last_heartbeat_sent: # Gateway is zombie - self._logger.log(ux.TRACE, "zombied") + self._logger.warning( + "connection has not received a HEARTBEAT_ACK for approx %.1fs and is being disconnected, " + "expect a reconnect shortly", + date.monotonic() - self._last_heartbeat_ack_received, + ) return True self._logger.log( From 2dd60b594bb52c306be1c45231a0c618aa5362a3 Mon Sep 17 00:00:00 2001 From: "Nekoka.tt" <3903853-nekokatt@users.noreply.gitlab.com> Date: Wed, 16 Sep 2020 22:52:27 +0100 Subject: [PATCH 08/12] Fixed inverted condition. --- hikari/impl/shard.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/hikari/impl/shard.py b/hikari/impl/shard.py index b2453a79ae..d82e9c8bf3 100644 --- a/hikari/impl/shard.py +++ b/hikari/impl/shard.py @@ -766,13 +766,11 @@ async def _run(self) -> None: try: last_started_at = date.monotonic() - should_die = await self._run_once() + should_restart = await self._run_once() - if not should_die: - continue - - self._logger.info("shard has disconnected and shut down normally") - return + if not should_restart: + self._logger.info("shard has disconnected and shut down normally") + return except errors.GatewayConnectionError as ex: self._logger.error( From cbaea48d8eb8f178d4fe23a0cc0b46ce19ed043e Mon Sep 17 00:00:00 2001 From: "Nekoka.tt" <3903853-nekokatt@users.noreply.gitlab.com> Date: Wed, 16 Sep 2020 23:04:38 +0100 Subject: [PATCH 09/12] Fixed garbled banner on Windows when redirecting streams. 
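The banner is one large write full of ANSI escape sequences. If the
process carries on before stdout has drained, a redirected stream or a
Windows console can emit it partially, which appears to be what caused
the garbling. A minimal sketch of the workaround applied below (names
match hikari/utilities/ux.py):

    sys.stdout.write(string.Template(raw_banner).safe_substitute(args))
    sys.stdout.flush()  # push the whole banner out in one go
    time.sleep(0.125)   # give redirected streams a moment to drain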
---
 hikari/utilities/ux.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/hikari/utilities/ux.py b/hikari/utilities/ux.py
index cb7c78961d..1c4100600d 100644
--- a/hikari/utilities/ux.py
+++ b/hikari/utilities/ux.py
@@ -22,6 +22,8 @@
 """User-experience extensions and utilities."""
 from __future__ import annotations
 
+import time
+
 from hikari.utilities import net
 
 __all__: typing.List[str] = ["init_logging", "print_banner", "supports_color", "HikariVersion", "check_for_updates"]
 
 import contextlib
@@ -191,6 +193,9 @@ def print_banner(package: typing.Optional[str], allow_color: bool, force_color:
             args[code] = ""
 
     sys.stdout.write(string.Template(raw_banner).safe_substitute(args))
+    # Give the stream some time to flush
+    sys.stdout.flush()
+    time.sleep(0.125)
 
 
 def supports_color(allow_color: bool, force_color: bool) -> bool:

From 1e29d916dad7c259035fd4341c795d2c1aa8db99 Mon Sep 17 00:00:00 2001
From: "Nekoka.tt" <3903853-nekokatt@users.noreply.gitlab.com>
Date: Wed, 16 Sep 2020 23:19:49 +0100
Subject: [PATCH 10/12] Fixed bug where shutting down shards would warn about a backoff that won't occur.

---
 hikari/impl/shard.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/hikari/impl/shard.py b/hikari/impl/shard.py
index d82e9c8bf3..5745f83bc8 100644
--- a/hikari/impl/shard.py
+++ b/hikari/impl/shard.py
@@ -751,7 +751,7 @@ async def _run(self) -> None:
         )
 
         try:
-            while True:
+            while not self._closing.is_set() and not self._closed:
                 if date.monotonic() - last_started_at < _BACKOFF_WINDOW:
                     time = next(backoff)
                     self._logger.info("backing off reconnecting for %.2fs", time)

From 69432f40d06b5e6f22887c03ed1554c7437e7896 Mon Sep 17 00:00:00 2001
From: "Nekoka.tt" <3903853-nekokatt@users.noreply.gitlab.com>
Date: Wed, 16 Sep 2020 23:42:51 +0100
Subject: [PATCH 11/12] Fixed more edge case scenarios for gateway termination and fixed tests.

---
 hikari/impl/shard.py            | 16 ++++++++++------
 tests/hikari/impl/test_shard.py |  3 ++-
 2 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/hikari/impl/shard.py b/hikari/impl/shard.py
index 5745f83bc8..32735e5f41 100644
--- a/hikari/impl/shard.py
+++ b/hikari/impl/shard.py
@@ -666,14 +666,14 @@ async def _identify(self) -> None:
         await self._ws.send_json(payload)  # type: ignore[union-attr]
 
     async def _heartbeat(self, heartbeat_interval: float) -> bool:
-        # Return True if zombied.
+        # Return True if zombied or should reconnect, False if it is time to die forever.
         # Prevent immediately zombie-ing.
         self._last_heartbeat_ack_received = date.monotonic()
         self._logger.debug("starting heartbeat with interval %ss", heartbeat_interval)
 
-        while True:
+        while not self._closing.is_set() and not self._closed.is_set():
             if self._last_heartbeat_ack_received <= self._last_heartbeat_sent:
-                # Gateway is zombie
+                # Gateway is zombie, close and request reconnect.
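+                # A "zombied" connection is one where the last HEARTBEAT we sent
+                # was never acknowledged; with monotonic timestamps that is
+                # exactly the case when ack_received <= sent.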
self._logger.warning( "connection has not received a HEARTBEAT_ACK for approx %.1fs and is being disconnected, " "expect a reconnect shortly", @@ -690,11 +690,14 @@ async def _heartbeat(self, heartbeat_interval: float) -> bool: try: await asyncio.wait_for(self._closing.wait(), timeout=heartbeat_interval) # We are closing - return False + break except asyncio.TimeoutError: # We should continue continue + self._logger.debug("heartbeat task is finishing now") + return False + async def _poll_events(self) -> typing.Optional[bool]: payload = await self._ws.receive_json(timeout=5) # type: ignore[union-attr] @@ -742,6 +745,7 @@ async def _resume(self) -> None: async def _run(self) -> None: self._closed.clear() + self._closing.clear() last_started_at = -float("inf") backoff = rate_limits.ExponentialBackOff( @@ -751,7 +755,7 @@ async def _run(self) -> None: ) try: - while not self._closing.is_set() and not self._closed: + while not self._closing.is_set() and not self._closed.is_set(): if date.monotonic() - last_started_at < _BACKOFF_WINDOW: time = next(backoff) self._logger.info("backing off reconnecting for %.2fs", time) @@ -801,10 +805,10 @@ async def _run(self) -> None: self._logger.error("encountered some unhandled error", exc_info=ex) raise finally: + self._closing.set() self._closed.set() async def _run_once(self) -> bool: - self._closing.clear() self._handshake_completed.clear() dispatch_disconnect = False diff --git a/tests/hikari/impl/test_shard.py b/tests/hikari/impl/test_shard.py index 0060488abc..da0e76a872 100644 --- a/tests/hikari/impl/test_shard.py +++ b/tests/hikari/impl/test_shard.py @@ -970,7 +970,8 @@ async def test__identify_when_intents(self, client): async def test__heartbeat(self, client): client._last_heartbeat_sent = 5 client._logger = mock.Mock() - client._closing = mock.Mock() + client._closing = mock.Mock(is_set=mock.Mock(return_value=False)) + client._closed = mock.Mock(is_set=mock.Mock(return_value=False)) client._send_heartbeat = mock.AsyncMock() with mock.patch.object(date, "monotonic", return_value=10): From d0a221f2239e35538065937f7961ca8c447d24d8 Mon Sep 17 00:00:00 2001 From: "Nekoka.tt" <3903853-nekokatt@users.noreply.gitlab.com> Date: Thu, 17 Sep 2020 00:17:59 +0100 Subject: [PATCH 12/12] Fixed import in wrong place. --- hikari/config.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/hikari/config.py b/hikari/config.py index cf309ca7d0..0c06b96bfd 100644 --- a/hikari/config.py +++ b/hikari/config.py @@ -23,8 +23,6 @@ from __future__ import annotations -import yarl - __all__: typing.Final[typing.List[str]] = [ "BasicAuthHeader", "ProxySettings", @@ -37,6 +35,7 @@ import typing import attr +import yarl from hikari.utilities import attr_extensions from hikari.utilities import data_binding
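
A hedged sketch of the configuration surface this series introduces
(names come from hikari/config.py above; the example is illustrative
and not part of any patch):

    import ssl

    from hikari import config

    # Defaults: verified SSL and a 30 second total timeout.
    settings = config.HTTPSettings()

    # Opt into more redirects and supply a custom SSL context instead.
    context = ssl.create_default_context()
    settings = config.HTTPSettings(ssl=context, max_redirects=5)

    # Route traffic through a proxy, trusting HTTP(S)_PROXY and netrc data.
    proxy = config.ProxySettings(url="http://localhost:8080", trust_env=True)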