Adjusted backoff to not error but keep at a constant limit when saturated.

- Had to reorder code to get past flake8 "function too long".
- Fixed failing test cases and mypy issues.
- Backoffs now never kill the bot; they only loop.
Nekokatt committed Sep 16, 2020
1 parent c9cdc42 commit 6af035e
Showing 4 changed files with 171 additions and 164 deletions.
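
The heart of the change: `ExponentialBackOff.__next__` previously raised `asyncio.TimeoutError` once the computed delay reached `maximum`; it now clamps the delay to `maximum` (plus jitter) and keeps yielding it forever. A condensed sketch of the new behaviour, assuming the simplified constructor below (the real class also takes `initial_increment` and lives in hikari/impl/rate_limits.py):

import random


class CappedExponentialBackOff:
    """Illustrative stand-in for ExponentialBackOff as of this commit."""

    def __init__(self, base: float = 2.0, maximum: float = 64.0, jitter_multiplier: float = 1.0) -> None:
        self.base = base
        self.maximum = maximum
        self.jitter_multiplier = jitter_multiplier
        self.increment = 0

    def __iter__(self) -> "CappedExponentialBackOff":
        return self

    def __next__(self) -> float:
        try:
            value = self.base ** self.increment
            if value >= self.maximum:
                # Saturate at the cap instead of raising asyncio.TimeoutError.
                value = self.maximum
            else:
                # Only advance the exponent while we are still below the cap.
                self.increment += 1
        except OverflowError:
            # The exponent grew past float range, so we are certainly past the cap.
            value = self.maximum
        return value + random.random() * self.jitter_multiplier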
20 changes: 10 additions & 10 deletions hikari/impl/rate_limits.py
@@ -441,8 +441,8 @@ class ExponentialBackOff:
base : builtins.float
The base to use. Defaults to `2.0`.
maximum : builtins.float
The max value the backoff can be in a single iteration before an
`asyncio.TimeoutError` is raised. Defaults to `64.0` seconds.
The max value the backoff can be in a single iteration. Anything above
this will be capped to this base value plus random jitter.
jitter_multiplier : builtins.float
The multiplier for the random jitter. Defaults to `1.0`.
Set to `0` to disable jitter.
@@ -510,17 +510,17 @@ def __next__(self) -> float:
"""Get the next back off to sleep by."""
try:
value = self.base ** self.increment

if value >= self.maximum:
value = self.maximum
else:
# This should only be incremented after we verify we haven't hit the maximum value.
self.increment += 1
except OverflowError:
# If this happened then we can be sure that we've passed maximum.
raise asyncio.TimeoutError from None

if value >= self.maximum:
raise asyncio.TimeoutError from None
value = self.maximum

# This should only be incremented after we verify we haven't hit the maximum value.
self.increment += 1
value += random.random() * self.jitter_multiplier # nosec # noqa S311 rng for cryptography
return value
return value + random.random() * self.jitter_multiplier # nosec # noqa S311 rng for cryptography

def __iter__(self) -> ExponentialBackOff:
"""Return this object, as it is an iterator."""
223 changes: 124 additions & 99 deletions hikari/impl/shard.py
@@ -276,9 +276,10 @@ async def connect(
)
)

assert isinstance(web_socket, cls)

raised = False
try:
assert isinstance(web_socket, cls)
web_socket.logger = logger
# We store this so we can remove it from debug logs
# which enables people to send me logs in issues safely.
@@ -311,11 +312,11 @@ async def connect(
message=b"client is shutting down",
)

except (aiohttp.ClientOSError, aiohttp.ClientConnectionError) as ex:
except (aiohttp.ClientOSError, aiohttp.ClientConnectionError, aiohttp.WSServerHandshakeError) as ex:
# Windows will sometimes raise an aiohttp.ClientOSError
# If we cannot do DNS lookup, this will fail with a ClientConnectionError
# usually.
raise errors.GatewayConnectionError(f"Failed to connect to gateway {type(ex).__name__}: {ex}") from ex
raise errors.GatewayConnectionError(f"Failed to connect to Discord: {ex!r}") from ex

finally:
await exit_stack.aclose()
@@ -703,6 +704,43 @@ async def _heartbeat(self, heartbeat_interval: float) -> bool:
# We should continue
continue

async def _poll_events(self) -> typing.Optional[bool]:
payload = await self._ws.receive_json(timeout=5) # type: ignore[union-attr]

op = payload[_OP] # opcode int
d = payload[_D] # data/payload. Usually a dict or a bool for INVALID_SESSION

if op == _DISPATCH:
t = payload[_T] # event name str
s = payload[_S] # seq int
self._logger.log(ux.TRACE, "dispatching %s with seq %s", t, s)
self._dispatch(t, s, d)
elif op == _HEARTBEAT:
await self._send_heartbeat_ack()
self._logger.log(ux.TRACE, "sent HEARTBEAT")
elif op == _HEARTBEAT_ACK:
now = date.monotonic()
self._last_heartbeat_ack_received = now
self._heartbeat_latency = now - self._last_heartbeat_sent
self._logger.log(ux.TRACE, "received HEARTBEAT ACK in %.1fms", self._heartbeat_latency * 1_000)
elif op == _RECONNECT:
# We should be able to resume...
self._logger.debug("received instruction to reconnect, will resume existing session")
return True
elif op == _INVALID_SESSION:
# We can resume if the payload was `true`.
if not d:
self._logger.debug("received invalid session, will need to start a new session")
self._seq = None
self._session_id = None
else:
self._logger.debug("received invalid session, will resume existing session")
return True
else:
self._logger.log(ux.TRACE, "unknown opcode %s received, it will be ignored...", op)

return None
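
For reference, the _OP, _D, _T and _S constants used by _poll_events map to the standard Discord gateway payload keys "op", "d", "t" and "s". A representative DISPATCH payload, abbreviated:

# Roughly what _poll_events receives for a dispatched event (op 0 is DISPATCH):
payload = {
    "op": 0,                               # opcode: which kind of payload this is
    "t": "MESSAGE_CREATE",                 # event name; only present on DISPATCH
    "s": 42,                               # sequence number; replayed on RESUME
    "d": {"id": "1234", "content": "hi"},  # event data; its shape depends on "t"
}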

async def _resume(self) -> None:
await self._ws.send_json( # type: ignore[union-attr]
{
@@ -721,50 +759,51 @@ async def _run(self) -> None:
initial_increment=_BACKOFF_INCREMENT_START,
)

try:
while True:
if date.monotonic() - last_started_at < _BACKOFF_WINDOW:
time = next(backoff)
self._logger.debug("backing off reconnecting for %.2fs to prevent spam", time)

try:
await asyncio.wait_for(self._closing.wait(), timeout=time)
# We were told to close.
return
except asyncio.TimeoutError:
# We are going to run once.
pass
while True:
if date.monotonic() - last_started_at < _BACKOFF_WINDOW:
time = next(backoff)
self._logger.info("backing off reconnecting for %.2fs", time)

try:
last_started_at = date.monotonic()
if not await self._run_once():
self._logger.debug("shard has shut down")
await asyncio.wait_for(self._closing.wait(), timeout=time)
# We were told to close.
return
except asyncio.TimeoutError:
# We are going to run once.
pass

except errors.GatewayConnectionError as ex:
self._logger.error(
"failed to communicate with server, reason was: %s. Will retry shortly", ex.__cause__,
)
try:
last_started_at = date.monotonic()
if not await self._run_once():
self._logger.debug("shard has shut down")

except errors.GatewayConnectionError as ex:
self._logger.error(
"failed to communicate with server, reason was: %s. Will retry shortly",
ex.__cause__,
)

except errors.GatewayServerClosedConnectionError as ex:
if not ex.can_reconnect:
raise
except errors.GatewayServerClosedConnectionError as ex:
if not ex.can_reconnect:
raise

self._logger.info(
"server has closed connection, will reconnect if possible [code:%s, reason:%s]",
ex.code,
ex.reason,
)
self._logger.info(
"server has closed connection, will reconnect if possible [code:%s, reason:%s]",
ex.code,
ex.reason,
)

except errors.GatewayError as ex:
self._logger.debug("encountered generic gateway error", exc_info=ex)
raise
except errors.GatewayError as ex:
self._logger.debug("encountered generic gateway error", exc_info=ex)
raise

except Exception as ex:
self._logger.debug("encountered some unhandled error", exc_info=ex)
raise
finally:
self._closed.set()
self._logger.info("shard %s has shut down permanently", self._shard_id)
except Exception as ex:
self._logger.debug("encountered some unhandled error", exc_info=ex)
raise

finally:
self._closed.set()
self._logger.info("shard %s has shut down", self._shard_id)

async def _run_once(self) -> bool:
self._closing.clear()
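
One detail of the _run loop above that is easy to miss: the backoff only applies when the previous run started less than _BACKOFF_WINDOW seconds ago, so a long-lived connection that finally drops reconnects immediately, and only rapidly failing attempts are delayed. A sketch of that gate; the constant's value here is an assumption for illustration, not taken from the diff:

import time

_BACKOFF_WINDOW = 30.0  # assumed value; the real constant is defined in shard.py


def should_back_off(last_started_at: float) -> bool:
    # Delay only when the previous attempt started recently, i.e. we are
    # churning through failing connections rather than recovering from a
    # long-lived session that just dropped.
    return time.monotonic() - last_started_at < _BACKOFF_WINDOW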
Expand All @@ -788,27 +827,7 @@ async def _run_once(self) -> bool:
self._event_consumer(self, "CONNECTED", {})
dispatch_disconnect = True

# Expect HELLO.
payload = await self._ws.receive_json()
if payload[_OP] != _HELLO:
self._logger.debug(
"expected HELLO opcode, received %s which makes no sense, closing with PROTOCOL ERROR ",
"(_run_once => raise and do not reconnect)",
payload[_OP],
)
await self._ws.send_close(code=errors.ShardCloseCode.PROTOCOL_ERROR, message=b"Expected HELLO op")
raise errors.GatewayError(f"Expected opcode {_HELLO}, but received {payload[_OP]}")

heartbeat_latency = float(payload[_D]["heartbeat_interval"]) / 1_000.0
heartbeat_task = asyncio.create_task(self._heartbeat(heartbeat_latency))

if self._closing.is_set():
self._logger.debug(
"closing flag was set before we could handshake, disconnecting with GOING AWAY "
"(_run_once => do not reconnect)"
)
await self._ws.send_close(code=errors.ShardCloseCode.GOING_AWAY, message=b"shard disconnecting")
return False
heartbeat_task = await self._wait_for_hello()

try:
if self._seq is not None:
@@ -823,48 +842,21 @@ async def _run_once(self) -> bool:
"closing flag was set during handshake, disconnecting with GOING AWAY "
"(_run_once => do not reconnect)"
)
await self._ws.send_close(code=errors.ShardCloseCode.GOING_AWAY, message=b"shard disconnecting")
await self._ws.send_close( # type: ignore[union-attr]
code=errors.ShardCloseCode.GOING_AWAY, message=b"shard disconnecting"
)
return False

# Event polling.
while not self._closing.is_set() and not heartbeat_task.done() and not heartbeat_task.cancelled():
try:
payload = await self._ws.receive_json(timeout=5)
result = await self._poll_events()

if result is not None:
return result
except asyncio.TimeoutError:
# Don't wait forever, check if the heartbeat has died.
continue

op = payload[_OP] # opcode int
d = payload[_D] # data/payload. Usually a dict or a bool for INVALID_SESSION

if op == _DISPATCH:
t = payload[_T] # event name str
s = payload[_S] # seq int
self._logger.log(ux.TRACE, "dispatching %s with seq %s", t, s)
self._dispatch(t, s, d)
elif op == _HEARTBEAT:
await self._send_heartbeat_ack()
self._logger.log(ux.TRACE, "sent HEARTBEAT")
elif op == _HEARTBEAT_ACK:
now = date.monotonic()
self._last_heartbeat_ack_received = now
self._heartbeat_latency = now - self._last_heartbeat_sent
self._logger.log(ux.TRACE, "received HEARTBEAT ACK in %.1fms", self._heartbeat_latency * 1_000)
elif op == _RECONNECT:
# We should be able to resume...
self._logger.debug("received instruction to reconnect, will resume existing session")
return True
elif op == _INVALID_SESSION:
# We can resume if the payload was `true`.
if not d:
self._logger.debug("received invalid session, will need to start a new session")
self._seq = None
self._session_id = None
else:
self._logger.debug("received invalid session, will resume existing session")
return True
else:
self._logger.log(ux.TRACE, "unknown opcode %s received, it will be ignored...", op)
# We should check if the shard is still alive and then poll again after.
pass

# If the heartbeat died due to an error, it should be raised here.
# This will currently allow us to try to resume if that happens
@@ -882,7 +874,10 @@ async def _run_once(self) -> bool:
"shard has requested graceful termination, so will not attempt to reconnect "
"(_run_once => do not reconnect)"
)
await self._ws.send_close(code=errors.ShardCloseCode.GOING_AWAY, message=b"shard disconnecting")
await self._ws.send_close( # type: ignore[union-attr]
code=errors.ShardCloseCode.GOING_AWAY,
message=b"shard disconnecting",
)
return False

finally:
@@ -899,7 +894,7 @@ async def _run_once(self) -> bool:

# Check if we made the socket close or handled it. If we didn't, we should always try to
# reconnect, as aiohttp is probably closing it internally without telling us properly.
if not ws.sent_close:
if not ws.sent_close: # type: ignore[union-attr]
return True

async def _send_heartbeat(self) -> None:
@@ -962,3 +957,33 @@ def _serialize_datetime(dt: typing.Optional[datetime.datetime]) -> typing.Option
return None

return int(dt.timestamp() * 1_000)

async def _wait_for_hello(self) -> asyncio.Task[bool]:
# Expect HELLO.
payload = await self._ws.receive_json() # type: ignore[union-attr]
if payload[_OP] != _HELLO:
self._logger.debug(
"expected HELLO opcode, received %s which makes no sense, closing with PROTOCOL ERROR ",
"(_run_once => raise and do not reconnect)",
payload[_OP],
)
await self._ws.send_close( # type: ignore[union-attr]
code=errors.ShardCloseCode.PROTOCOL_ERROR,
message=b"Expected HELLO op",
)
raise errors.GatewayError(f"Expected opcode {_HELLO}, but received {payload[_OP]}")

if self._closing.is_set():
self._logger.debug(
"closing flag was set before we could handshake, disconnecting with GOING AWAY "
"(_run_once => do not reconnect)"
)
await self._ws.send_close( # type: ignore[union-attr]
code=errors.ShardCloseCode.GOING_AWAY,
message=b"shard disconnecting",
)
raise asyncio.CancelledError("closing flag was set before we could handshake")

heartbeat_interval = float(payload[_D]["heartbeat_interval"]) / 1_000.0
heartbeat_task = asyncio.create_task(self._heartbeat(heartbeat_interval))
return heartbeat_task
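
As a reminder of what _wait_for_hello parses: the first payload the gateway sends is HELLO (op 10), whose data carries the heartbeat interval in milliseconds, hence the division by 1_000. Roughly:

# Representative HELLO payload; 41250 is the example interval from Discord's docs.
hello = {"op": 10, "d": {"heartbeat_interval": 41250}}
heartbeat_interval = float(hello["d"]["heartbeat_interval"]) / 1_000.0  # 41.25 seconds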
11 changes: 4 additions & 7 deletions tests/hikari/impl/test_rate_limits.py
@@ -450,8 +450,7 @@ def test_increment_raises_on_numerical_limitation(self):
base=5, maximum=sys.float_info.max, jitter_multiplier=0.0, initial_increment=power
)

with pytest.raises(asyncio.TimeoutError):
next(eb)
assert next(eb) == sys.float_info.max

def test_increment_maximum(self):
max_bound = 64
@@ -460,16 +459,14 @@ def test_increment_maximum(self):
for _ in range(iterations):
next(eb)

with pytest.raises(asyncio.TimeoutError):
next(eb)
assert next(eb) == max_bound

def test_increment_does_not_increment_when_on_maximum(self):
eb = rate_limits.ExponentialBackOff(2, 32, initial_increment=5)
eb = rate_limits.ExponentialBackOff(2, 32, initial_increment=5, jitter_multiplier=0)

assert eb.increment == 5

with pytest.raises(asyncio.TimeoutError):
next(eb)
assert next(eb) == 32

assert eb.increment == 5
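
The updated assertions follow directly from the clamping arithmetic: with base=2 and increment=5, the raw value is 2**5 == 32, which reaches the maximum exactly, so the value is clamped and the increment is deliberately left alone. Spelled out, mirroring the test above:

from hikari.impl import rate_limits

eb = rate_limits.ExponentialBackOff(2, 32, initial_increment=5, jitter_multiplier=0)
assert next(eb) == 32  # 2**5 == 32 hits the cap, so the value saturates
assert next(eb) == 32  # and keeps saturating on every later call
assert eb.increment == 5  # the exponent never advances once at the cap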
