Skip to content

Commit

Permalink
Allow passing initial seq, session_id and resume_url attributes…
Browse files Browse the repository at this point in the history
… to the shard, as well as retrieving them.

- Additionally export all impl classes to `hikari` namespace
  • Loading branch information
davfsa committed Jul 18, 2023
1 parent 7055102 commit bf57d3b
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 26 deletions.
7 changes: 2 additions & 5 deletions hikari/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
# SOFTWARE.
"""A sane Python framework for writing modern Discord bots.
To get started, you will want to initialize an instance of `hikari.impl.bot.GatewayBot`
To get started, you will want to initialize an instance of `hikari.impl.gateway_bot.GatewayBot`
for writing a gateway based bot, `hikari.impl.rest_bot.RESTBot` for a REST based bot,
or `hikari.impl.rest.RESTApp` if you only need to use the REST API.
"""
Expand Down Expand Up @@ -100,10 +100,7 @@
from hikari.files import Rawish
from hikari.files import Resourceish
from hikari.guilds import *
from hikari.impl import ClientCredentialsStrategy
from hikari.impl import GatewayBot
from hikari.impl import RESTApp
from hikari.impl import RESTBot
from hikari.impl import *
from hikari.intents import *
from hikari.interactions.base_interactions import *
from hikari.interactions.command_interactions import *
Expand Down
5 changes: 1 addition & 4 deletions hikari/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,7 @@ from hikari.files import Pathish as Pathish
from hikari.files import Rawish as Rawish
from hikari.files import Resourceish as Resourceish
from hikari.guilds import *
from hikari.impl import ClientCredentialsStrategy as ClientCredentialsStrategy
from hikari.impl import GatewayBot as GatewayBot
from hikari.impl import RESTApp as RESTApp
from hikari.impl import RESTBot as RESTBot
from hikari.impl import *
from hikari.intents import *
from hikari.interactions.base_interactions import *
from hikari.interactions.command_interactions import *
Expand Down
15 changes: 15 additions & 0 deletions hikari/api/shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,21 @@ def is_connected(self) -> bool:
def shard_count(self) -> int:
"""Return the total number of shards expected in the entire application."""

@property
@abc.abstractmethod
def session_id(self) -> typing.Optional[str]:
"""The session id for the shard."""

@property
@abc.abstractmethod
def seq(self) -> typing.Optional[int]:
"""The sequence number for the shard."""

@property
@abc.abstractmethod
def resume_gateway_url(self) -> typing.Optional[str]:
"""The resume gateway url for the shard."""

@abc.abstractmethod
def get_user_id(self) -> snowflakes.Snowflake:
"""Return the user ID.
Expand Down
3 changes: 2 additions & 1 deletion hikari/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
from hikari.impl.config import *
from hikari.impl.entity_factory import *
from hikari.impl.event_manager import *
from hikari.impl.event_manager_base import *
from hikari.impl.event_manager_base import EventManagerBase
from hikari.impl.event_manager_base import EventStream
from hikari.impl.gateway_bot import *
from hikari.impl.interaction_server import *
from hikari.impl.rate_limits import *
Expand Down
3 changes: 2 additions & 1 deletion hikari/impl/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ from hikari.impl.cache import *
from hikari.impl.config import *
from hikari.impl.entity_factory import *
from hikari.impl.event_manager import *
from hikari.impl.event_manager_base import *
from hikari.impl.event_manager_base import EventManagerBase as EventManagerBase
from hikari.impl.event_manager_base import EventStream as EventStream
from hikari.impl.gateway_bot import *
from hikari.impl.interaction_server import *
from hikari.impl.rate_limits import *
Expand Down
6 changes: 3 additions & 3 deletions hikari/impl/gateway_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,7 +987,7 @@ async def start(
large_threshold=large_threshold,
shard_id=shard_id,
shard_count=shard_count,
url=requirements.url,
gateway_url=requirements.url,
)
for shard_id in window
)
Expand Down Expand Up @@ -1261,7 +1261,7 @@ async def _start_one_shard(
large_threshold: int,
shard_id: int,
shard_count: int,
url: str,
gateway_url: str,
) -> None:
new_shard = shard_impl.GatewayShardImpl(
http_settings=self._http_settings,
Expand All @@ -1279,7 +1279,7 @@ async def _start_one_shard(
shard_id=shard_id,
shard_count=shard_count,
token=self._token,
url=url,
gateway_url=gateway_url,
)
try:
start = time.monotonic()
Expand Down
52 changes: 43 additions & 9 deletions hikari/impl/shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ class GatewayShardImpl(shard.GatewayShard):
----------
token : str
The bot token to use.
url : str
gateway_url : str
The gateway URL to use. This should not contain a query-string or
fragments.
event_manager : hikari.api.event_manager.EventManager
Expand Down Expand Up @@ -417,6 +417,18 @@ class GatewayShardImpl(shard.GatewayShard):
The proxy settings to use while negotiating a websocket.
data_format : str
Data format to use for inbound data. Only supported format is `"json"`.
initial_seq : typing.Optional[int]
The initial session sequence to start at.
initial_session_id : typing.Optional[str]
The initial session id to start with.
initial_resume_gateway_url : typing.Optional[str]
The initial resume gateway url to use.
Raises
------
ValueError
If not all of `initial_seq`, `initial_session_id`, and `initial_resume_gateway_url`
are passed, with any one of them being given a value.
"""

__slots__: typing.Sequence[str] = (
Expand Down Expand Up @@ -467,24 +479,34 @@ def __init__(
large_threshold: int = 250,
shard_id: int = 0,
shard_count: int = 1,
initial_seq: typing.Optional[int] = None,
initial_session_id: typing.Optional[str] = None,
initial_resume_gateway_url: typing.Optional[str] = None,
http_settings: config.HTTPSettings,
proxy_settings: config.ProxySettings,
data_format: str = shard.GatewayDataFormat.JSON,
event_manager: event_manager_.EventManager,
event_factory: event_factory_.EventFactory,
token: str,
url: str,
gateway_url: str,
) -> None:
if data_format != shard.GatewayDataFormat.JSON:
raise NotImplementedError(f"Unsupported gateway data format: {data_format}")

if compression and compression != shard.GatewayCompression.TRANSPORT_ZLIB_STREAM:
raise NotImplementedError(f"Unsupported compression format {compression}")

if bool(initial_seq) ^ bool(initial_session_id) ^ bool(initial_resume_gateway_url):
# It makes no sense to allow passing RESUME data if not all the data is passed
raise ValueError(
"You must specify exactly all or neither of "
"`initial_seq`, `initial_session_id` or `initial_resume_gateway_url`"
)

self._activity = initial_activity
self._event_manager = event_manager
self._event_factory = event_factory
self._gateway_url = url
self._gateway_url = gateway_url
self._handshake_event: typing.Optional[asyncio.Event] = None
self._heartbeat_latency = float("nan")
self._http_settings = http_settings
Expand All @@ -501,9 +523,9 @@ def __init__(
f"shard {shard_id} non-priority rate limit", *_NON_PRIORITY_RATELIMIT
)
self._proxy_settings = proxy_settings
self._resume_gateway_url: typing.Optional[str] = None
self._seq: typing.Optional[int] = None
self._session_id: typing.Optional[str] = None
self._resume_gateway_url = initial_resume_gateway_url
self._seq = initial_seq
self._session_id = initial_session_id
self._shard_count = shard_count
self._shard_id = shard_id
self._status = initial_status
Expand Down Expand Up @@ -541,6 +563,18 @@ def is_connected(self) -> bool:
def shard_count(self) -> int:
return self._shard_count

@property
def session_id(self) -> typing.Optional[str]:
return self._session_id

@property
def seq(self) -> typing.Optional[int]:
return self._seq

@property
def resume_gateway_url(self) -> typing.Optional[str]:
return self._resume_gateway_url

async def close(self) -> None:
if not self._keep_alive_task:
raise errors.ComponentStateConflictError("Cannot close an inactive shard")
Expand Down Expand Up @@ -766,9 +800,9 @@ async def _poll_events(self) -> None:
can_reconnect = payload[_D] # We can resume if the payload data is `true`.
if not can_reconnect:
self._logger.info("received invalid session, will need to start a new session")
self._seq = None
self._resume_gateway_url = None
self._session_id = None
self._seq = 0
self._resume_gateway_url = ""
self._session_id = ""
else:
self._logger.info("received invalid session, will resume existing session")

Expand Down
6 changes: 3 additions & 3 deletions tests/hikari/impl/test_gateway_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,7 +948,7 @@ async def test_start_one_shard(self, bot):
large_threshold=1000,
shard_id=1,
shard_count=3,
url="https://some.website",
gateway_url="https://some.website",
)

shard.assert_called_once_with(
Expand Down Expand Up @@ -989,7 +989,7 @@ async def test_start_one_shard_when_not_alive(self, bot):
large_threshold=1000,
shard_id=1,
shard_count=3,
url="https://some.website",
gateway_url="https://some.website",
)

assert bot._shards == {}
Expand All @@ -1015,7 +1015,7 @@ async def test_start_one_shard_when_exception(self, bot, is_alive):
large_threshold=1000,
shard_id=1,
shard_count=3,
url="https://some.website",
gateway_url="https://some.website",
)

assert bot._shards == {}
Expand Down

0 comments on commit bf57d3b

Please sign in to comment.