From bf57d3be2919c618b146e8dea34fec592c2c9276 Mon Sep 17 00:00:00 2001 From: davfsa Date: Tue, 18 Jul 2023 14:59:00 +0200 Subject: [PATCH] Allow passing initial `seq`, `session_id` and `resume_url` attributes to the shard, as well as retrieving them. - Additionally export all impl classes to `hikari` namespace --- hikari/__init__.py | 7 ++-- hikari/__init__.pyi | 5 +-- hikari/api/shard.py | 15 ++++++++ hikari/impl/__init__.py | 3 +- hikari/impl/__init__.pyi | 3 +- hikari/impl/gateway_bot.py | 6 ++-- hikari/impl/shard.py | 52 ++++++++++++++++++++++----- tests/hikari/impl/test_gateway_bot.py | 6 ++-- 8 files changed, 71 insertions(+), 26 deletions(-) diff --git a/hikari/__init__.py b/hikari/__init__.py index 45a0066f97..ccc8b67f15 100644 --- a/hikari/__init__.py +++ b/hikari/__init__.py @@ -22,7 +22,7 @@ # SOFTWARE. """A sane Python framework for writing modern Discord bots. -To get started, you will want to initialize an instance of `hikari.impl.bot.GatewayBot` +To get started, you will want to initialize an instance of `hikari.impl.gateway_bot.GatewayBot` for writing a gateway based bot, `hikari.impl.rest_bot.RESTBot` for a REST based bot, or `hikari.impl.rest.RESTApp` if you only need to use the REST API. """ @@ -100,10 +100,7 @@ from hikari.files import Rawish from hikari.files import Resourceish from hikari.guilds import * -from hikari.impl import ClientCredentialsStrategy -from hikari.impl import GatewayBot -from hikari.impl import RESTApp -from hikari.impl import RESTBot +from hikari.impl import * from hikari.intents import * from hikari.interactions.base_interactions import * from hikari.interactions.command_interactions import * diff --git a/hikari/__init__.pyi b/hikari/__init__.pyi index 1e8c845337..005d3e953c 100644 --- a/hikari/__init__.pyi +++ b/hikari/__init__.pyi @@ -74,10 +74,7 @@ from hikari.files import Pathish as Pathish from hikari.files import Rawish as Rawish from hikari.files import Resourceish as Resourceish from hikari.guilds import * -from hikari.impl import ClientCredentialsStrategy as ClientCredentialsStrategy -from hikari.impl import GatewayBot as GatewayBot -from hikari.impl import RESTApp as RESTApp -from hikari.impl import RESTBot as RESTBot +from hikari.impl import * from hikari.intents import * from hikari.interactions.base_interactions import * from hikari.interactions.command_interactions import * diff --git a/hikari/api/shard.py b/hikari/api/shard.py index b0378a880c..9243baa31f 100644 --- a/hikari/api/shard.py +++ b/hikari/api/shard.py @@ -104,6 +104,21 @@ def is_connected(self) -> bool: def shard_count(self) -> int: """Return the total number of shards expected in the entire application.""" + @property + @abc.abstractmethod + def session_id(self) -> typing.Optional[str]: + """The session id for the shard.""" + + @property + @abc.abstractmethod + def seq(self) -> typing.Optional[int]: + """The sequence number for the shard.""" + + @property + @abc.abstractmethod + def resume_gateway_url(self) -> typing.Optional[str]: + """The resume gateway url for the shard.""" + @abc.abstractmethod def get_user_id(self) -> snowflakes.Snowflake: """Return the user ID. diff --git a/hikari/impl/__init__.py b/hikari/impl/__init__.py index a631354b56..2834183eb1 100644 --- a/hikari/impl/__init__.py +++ b/hikari/impl/__init__.py @@ -35,7 +35,8 @@ from hikari.impl.config import * from hikari.impl.entity_factory import * from hikari.impl.event_manager import * -from hikari.impl.event_manager_base import * +from hikari.impl.event_manager_base import EventManagerBase +from hikari.impl.event_manager_base import EventStream from hikari.impl.gateway_bot import * from hikari.impl.interaction_server import * from hikari.impl.rate_limits import * diff --git a/hikari/impl/__init__.pyi b/hikari/impl/__init__.pyi index 3e9f90a956..133b797d30 100644 --- a/hikari/impl/__init__.pyi +++ b/hikari/impl/__init__.pyi @@ -6,7 +6,8 @@ from hikari.impl.cache import * from hikari.impl.config import * from hikari.impl.entity_factory import * from hikari.impl.event_manager import * -from hikari.impl.event_manager_base import * +from hikari.impl.event_manager_base import EventManagerBase as EventManagerBase +from hikari.impl.event_manager_base import EventStream as EventStream from hikari.impl.gateway_bot import * from hikari.impl.interaction_server import * from hikari.impl.rate_limits import * diff --git a/hikari/impl/gateway_bot.py b/hikari/impl/gateway_bot.py index 0d3f884304..0703884a99 100644 --- a/hikari/impl/gateway_bot.py +++ b/hikari/impl/gateway_bot.py @@ -987,7 +987,7 @@ async def start( large_threshold=large_threshold, shard_id=shard_id, shard_count=shard_count, - url=requirements.url, + gateway_url=requirements.url, ) for shard_id in window ) @@ -1261,7 +1261,7 @@ async def _start_one_shard( large_threshold: int, shard_id: int, shard_count: int, - url: str, + gateway_url: str, ) -> None: new_shard = shard_impl.GatewayShardImpl( http_settings=self._http_settings, @@ -1279,7 +1279,7 @@ async def _start_one_shard( shard_id=shard_id, shard_count=shard_count, token=self._token, - url=url, + gateway_url=gateway_url, ) try: start = time.monotonic() diff --git a/hikari/impl/shard.py b/hikari/impl/shard.py index 607eb83162..ff6f43cf92 100644 --- a/hikari/impl/shard.py +++ b/hikari/impl/shard.py @@ -373,7 +373,7 @@ class GatewayShardImpl(shard.GatewayShard): ---------- token : str The bot token to use. - url : str + gateway_url : str The gateway URL to use. This should not contain a query-string or fragments. event_manager : hikari.api.event_manager.EventManager @@ -417,6 +417,18 @@ class GatewayShardImpl(shard.GatewayShard): The proxy settings to use while negotiating a websocket. data_format : str Data format to use for inbound data. Only supported format is `"json"`. + initial_seq : typing.Optional[int] + The initial session sequence to start at. + initial_session_id : typing.Optional[str] + The initial session id to start with. + initial_resume_gateway_url : typing.Optional[str] + The initial resume gateway url to use. + + Raises + ------ + ValueError + If not all of `initial_seq`, `initial_session_id`, and `initial_resume_gateway_url` + are passed, with any one of them being given a value. """ __slots__: typing.Sequence[str] = ( @@ -467,13 +479,16 @@ def __init__( large_threshold: int = 250, shard_id: int = 0, shard_count: int = 1, + initial_seq: typing.Optional[int] = None, + initial_session_id: typing.Optional[str] = None, + initial_resume_gateway_url: typing.Optional[str] = None, http_settings: config.HTTPSettings, proxy_settings: config.ProxySettings, data_format: str = shard.GatewayDataFormat.JSON, event_manager: event_manager_.EventManager, event_factory: event_factory_.EventFactory, token: str, - url: str, + gateway_url: str, ) -> None: if data_format != shard.GatewayDataFormat.JSON: raise NotImplementedError(f"Unsupported gateway data format: {data_format}") @@ -481,10 +496,17 @@ def __init__( if compression and compression != shard.GatewayCompression.TRANSPORT_ZLIB_STREAM: raise NotImplementedError(f"Unsupported compression format {compression}") + if bool(initial_seq) ^ bool(initial_session_id) ^ bool(initial_resume_gateway_url): + # It makes no sense to allow passing RESUME data if not all the data is passed + raise ValueError( + "You must specify exactly all or neither of " + "`initial_seq`, `initial_session_id` or `initial_resume_gateway_url`" + ) + self._activity = initial_activity self._event_manager = event_manager self._event_factory = event_factory - self._gateway_url = url + self._gateway_url = gateway_url self._handshake_event: typing.Optional[asyncio.Event] = None self._heartbeat_latency = float("nan") self._http_settings = http_settings @@ -501,9 +523,9 @@ def __init__( f"shard {shard_id} non-priority rate limit", *_NON_PRIORITY_RATELIMIT ) self._proxy_settings = proxy_settings - self._resume_gateway_url: typing.Optional[str] = None - self._seq: typing.Optional[int] = None - self._session_id: typing.Optional[str] = None + self._resume_gateway_url = initial_resume_gateway_url + self._seq = initial_seq + self._session_id = initial_session_id self._shard_count = shard_count self._shard_id = shard_id self._status = initial_status @@ -541,6 +563,18 @@ def is_connected(self) -> bool: def shard_count(self) -> int: return self._shard_count + @property + def session_id(self) -> typing.Optional[str]: + return self._session_id + + @property + def seq(self) -> typing.Optional[int]: + return self._seq + + @property + def resume_gateway_url(self) -> typing.Optional[str]: + return self._resume_gateway_url + async def close(self) -> None: if not self._keep_alive_task: raise errors.ComponentStateConflictError("Cannot close an inactive shard") @@ -766,9 +800,9 @@ async def _poll_events(self) -> None: can_reconnect = payload[_D] # We can resume if the payload data is `true`. if not can_reconnect: self._logger.info("received invalid session, will need to start a new session") - self._seq = None - self._resume_gateway_url = None - self._session_id = None + self._seq = 0 + self._resume_gateway_url = "" + self._session_id = "" else: self._logger.info("received invalid session, will resume existing session") diff --git a/tests/hikari/impl/test_gateway_bot.py b/tests/hikari/impl/test_gateway_bot.py index b854a3f75a..fc1a57af7b 100644 --- a/tests/hikari/impl/test_gateway_bot.py +++ b/tests/hikari/impl/test_gateway_bot.py @@ -948,7 +948,7 @@ async def test_start_one_shard(self, bot): large_threshold=1000, shard_id=1, shard_count=3, - url="https://some.website", + gateway_url="https://some.website", ) shard.assert_called_once_with( @@ -989,7 +989,7 @@ async def test_start_one_shard_when_not_alive(self, bot): large_threshold=1000, shard_id=1, shard_count=3, - url="https://some.website", + gateway_url="https://some.website", ) assert bot._shards == {} @@ -1015,7 +1015,7 @@ async def test_start_one_shard_when_exception(self, bot, is_alive): large_threshold=1000, shard_id=1, shard_count=3, - url="https://some.website", + gateway_url="https://some.website", ) assert bot._shards == {}