Skip to content

Commit

Permalink
New tests (#249)
Browse files Browse the repository at this point in the history
- Improve RateLimitedError message
  • Loading branch information
davfsa authored Oct 4, 2020
1 parent 5c7460b commit 55a9673
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 27 deletions.
6 changes: 3 additions & 3 deletions hikari/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,10 +457,10 @@ class RateLimitedError(ClientHTTPResponseError):
status: http.HTTPStatus = attr.ib(default=http.HTTPStatus.TOO_MANY_REQUESTS, init=False)
"""The HTTP status code for the response."""

reason: str = attr.ib(init=False)
"""The error reason."""
message: str = attr.ib(init=False)
"""The error message."""

@reason.default
@message.default
def _(self) -> str:
return f"You are being rate-limited for {self.retry_after:,} seconds on route {self.route}. Please slow down!"

Expand Down
3 changes: 2 additions & 1 deletion hikari/impl/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,8 +603,9 @@ async def _request(
except self._RetryRequest:
continue

@staticmethod
@typing.final
def _stringify_http_message(self, headers: data_binding.Headers, body: typing.Any) -> str:
def _stringify_http_message(headers: data_binding.Headers, body: typing.Any) -> str:
string = "\n".join(
f" {name}: {value}" if name != _AUTHORIZATION_HEADER else f" {name}: **REDACTED TOKEN**"
for name, value in headers.items()
Expand Down
154 changes: 142 additions & 12 deletions tests/hikari/impl/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from hikari.impl import special_endpoints
from hikari.internal import net
from hikari.internal import routes
from hikari.internal import ux
from tests.hikari import client_session_stub
from tests.hikari import hikari_test_helpers

Expand Down Expand Up @@ -117,8 +118,20 @@ class StubRestClient:
return StubRestClient()

@pytest.fixture()
def rest_provider(self, rest_client):
return rest._RESTProvider(lambda: mock.Mock(), None, lambda: mock.Mock(), lambda: rest_client)
def cache(self):
return mock.Mock()

@pytest.fixture()
def executor(self):
return mock.Mock()

@pytest.fixture()
def entity_factory(self):
return mock.Mock()

@pytest.fixture()
def rest_provider(self, rest_client, cache, executor, entity_factory):
return rest._RESTProvider(lambda: entity_factory, executor, lambda: cache, lambda: rest_client)

def test_rest_property(self, rest_provider, rest_client):
assert rest_provider.rest == rest_client
Expand All @@ -129,6 +142,18 @@ def test_http_settings_property(self, rest_provider, rest_client):
def test_proxy_settings_property(self, rest_provider, rest_client):
assert rest_provider.proxy_settings == rest_client.proxy_settings

def test_entity_factory_property(self, rest_provider, entity_factory):
assert rest_provider.entity_factory == entity_factory

def test_cache_property(self, rest_provider, cache):
assert rest_provider.cache == cache

def test_executor_property(self, rest_provider, executor):
assert rest_provider.executor == executor

def test_me_property(self, rest_provider, cache):
assert rest_provider.me == cache.get_me()


###########
# RESTApp #
Expand All @@ -148,6 +173,24 @@ def rest_app():


class TestRESTApp:
def test__init__when_connector_factory_is_None(self):
http_settings = object()

with mock.patch.object(rest, "BasicLazyCachedTCPConnectorFactory") as factory:
rest_app = rest.RESTApp(
connector_factory=None,
connector_owner=False,
executor=None,
http_settings=http_settings,
proxy_settings=None,
url=None,
)

factory.assert_called_once_with(http_settings)

assert rest_app._connector_factory is factory()
assert rest_app._connector_owner is True

def test_executor_property(self, rest_app):
mock_executor = object()
rest_app._executor = mock_executor
Expand All @@ -166,10 +209,9 @@ def test_proxy_settings(self, rest_app):
def test_acquire(self, rest_app):
mock_event_loop = object()
rest_app._event_loop = mock_event_loop
mock_entity_factory = object()

stack = contextlib.ExitStack()
stack.enter_context(mock.patch.object(entity_factory, "EntityFactoryImpl", return_value=mock_entity_factory))
_entity_factory = stack.enter_context(mock.patch.object(entity_factory, "EntityFactoryImpl"))
mock_client = stack.enter_context(mock.patch.object(rest, rest.RESTClientImpl.__qualname__))
stack.enter_context(mock.patch.object(asyncio, "get_running_loop", return_value=mock_event_loop))

Expand All @@ -179,7 +221,7 @@ def test_acquire(self, rest_app):
mock_client.assert_called_once_with(
connector_factory=rest_app._connector_factory,
connector_owner=rest_app._connector_owner,
entity_factory=mock_entity_factory,
entity_factory=_entity_factory(),
executor=rest_app._executor,
http_settings=rest_app._http_settings,
proxy_settings=rest_app._proxy_settings,
Expand All @@ -188,6 +230,26 @@ def test_acquire(self, rest_app):
rest_url=rest_app._url,
)

def test_acquire_when_even_loop_not_set(self, rest_app):
mock_event_loop = object()

stack = contextlib.ExitStack()
_entity_factory = stack.enter_context(mock.patch.object(entity_factory, "EntityFactoryImpl"))
stack.enter_context(mock.patch.object(rest, rest.RESTClientImpl.__qualname__))
stack.enter_context(mock.patch.object(asyncio, "get_running_loop", return_value=mock_event_loop))

with stack:
rest_app.acquire(token="token", token_type="Type")

assert rest_app._event_loop is mock_event_loop

# This is just to test the lambdas so it counts towards coverage
assert _entity_factory.call_count == 1
factory = _entity_factory.call_args_list[0][0][0]
factory.entity_factory
factory.cache
factory.rest

def test_acquire_when__event_loop_and_loop_do_not_equal(self, rest_app):
rest_app._event_loop = object()
with mock.patch.object(asyncio, "get_running_loop"):
Expand Down Expand Up @@ -479,6 +541,19 @@ def test__generate_allowed_mentions(self, rest_client, function_input, expected_
def test__transform_emoji_to_url_format(self, rest_client, emoji, expected_return):
assert rest_client._transform_emoji_to_url_format(emoji) == expected_return

def test__stringify_http_message_when_body_is_None(self, rest_client):
headers = {"HEADER1": "value1", "HEADER2": "value2", "Authorization": "this will never see the light of day"}
expected_return = " HEADER1: value1\n HEADER2: value2\n Authorization: **REDACTED TOKEN**"
assert rest_client._stringify_http_message(headers, None) == expected_return

@pytest.mark.parametrize(("body", "expected"), [(bytes("hello :)", "ascii"), "hello :)"), (123, "123")])
def test__stringify_http_message_when_body_is_not_None(self, rest_client, body, expected):
headers = {"HEADER1": "value1", "HEADER2": "value2", "Authorization": "this will never see the light of day"}
expected_return = (
f" HEADER1: value1\n HEADER2: value2\n Authorization: **REDACTED TOKEN**\n\n {expected}"
)
assert rest_client._stringify_http_message(headers, body) == expected_return

#######################
# Non-async endpoints #
#######################
Expand Down Expand Up @@ -743,6 +818,14 @@ class ExitException(Exception):

return ExitException

async def test___aenter__and__aexit__(self, rest_client):
with mock.patch.object(rest_client, "close") as close:
async with rest_client as client:
assert client is rest_client
close.assert_not_called()

close.assert_awaited_once_with()

@hikari_test_helpers.timeout()
async def test__request_when_buckets_not_started(self, rest_client, exit_exception):
route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)
Expand Down Expand Up @@ -770,6 +853,7 @@ async def test__request_when__token_is_None(self, rest_client, exit_exception):
rest_client.buckets.is_started = True
rest_client._token = None
rest_client._acquire_client_session = mock.Mock(return_value=mock_session)
rest_client._stringify_http_message = mock.Mock()
with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()):
with pytest.raises(exit_exception):
await rest_client._request(route)
Expand All @@ -784,12 +868,13 @@ async def test__request_when__token_is_not_None(self, rest_client, exit_exceptio
rest_client.buckets.is_started = True
rest_client._token = "token"
rest_client._acquire_client_session = mock.Mock(return_value=mock_session)
rest_client._stringify_http_message = mock.Mock()
with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()):
with pytest.raises(exit_exception):
await rest_client._request(route)

_, kwargs = mock_session.request.call_args_list[0]
assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "token"
_, kwargs = mock_session.request.call_args_list[0]
assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "token"

@hikari_test_helpers.timeout()
async def test__request_when_no_auth_passed(self, rest_client, exit_exception):
Expand All @@ -798,6 +883,7 @@ async def test__request_when_no_auth_passed(self, rest_client, exit_exception):
rest_client.buckets.is_started = True
rest_client._token = "token"
rest_client._acquire_client_session = mock.Mock(return_value=mock_session)
rest_client._stringify_http_message = mock.Mock()
with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()):
with pytest.raises(exit_exception):
await rest_client._request(route, no_auth=True)
Expand All @@ -814,9 +900,9 @@ class StubResponse:
route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)
mock_session = mock.AsyncMock(request=mock.AsyncMock(return_value=StubResponse()))
rest_client.buckets.is_started = True
rest_client._debug = False
rest_client._parse_ratelimits = mock.AsyncMock()
rest_client._acquire_client_session = mock.Mock(return_value=mock_session)
rest_client._stringify_http_message = mock.Mock()
with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()):
assert (await rest_client._request(route)) is None

Expand All @@ -834,9 +920,9 @@ async def read(self):
route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)
mock_session = mock.AsyncMock(request=mock.AsyncMock(return_value=StubResponse()))
rest_client.buckets.is_started = True
rest_client._debug = True
rest_client._parse_ratelimits = mock.AsyncMock()
rest_client._acquire_client_session = mock.Mock(return_value=mock_session)
rest_client._stringify_http_message = mock.Mock()
with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()):
assert (await rest_client._request(route)) == {"something": None}

Expand All @@ -851,9 +937,9 @@ class StubResponse:
route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)
mock_session = mock.AsyncMock(request=mock.AsyncMock(return_value=StubResponse()))
rest_client.buckets.is_started = True
rest_client._debug = False
rest_client._parse_ratelimits = mock.AsyncMock()
rest_client._acquire_client_session = mock.Mock(return_value=mock_session)
rest_client._stringify_http_message = mock.Mock()
with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()):
with pytest.raises(errors.HTTPError):
await rest_client._request(route)
Expand All @@ -868,10 +954,10 @@ class StubResponse:
route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)
mock_session = mock.AsyncMock(request=mock.AsyncMock(return_value=StubResponse()))
rest_client.buckets.is_started = True
rest_client._debug = False
rest_client._parse_ratelimits = mock.AsyncMock()
rest_client._handle_error_response = mock.AsyncMock(side_effect=exit_exception)
rest_client._acquire_client_session = mock.Mock(return_value=mock_session)
rest_client._stringify_http_message = mock.Mock()
with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()):
with pytest.raises(exit_exception):
await rest_client._request(route)
Expand All @@ -886,12 +972,39 @@ class StubResponse:
route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)
mock_session = mock.AsyncMock(request=mock.AsyncMock(side_effect=[rest_client._RetryRequest, exit_exception]))
rest_client.buckets.is_started = True
rest_client._debug = False
rest_client._acquire_client_session = mock.Mock(return_value=mock_session)
with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()):
with pytest.raises(exit_exception):
await rest_client._request(route)

@pytest.mark.parametrize("logger_level", [ux.TRACE, ux.TRACE + 10])
@hikari_test_helpers.timeout()
async def test__request_logger(self, rest_client, exit_exception, logger_level):
class StubResponse:
status = http.HTTPStatus.NO_CONTENT
headers = {}
reason = "cause why not"

async def read(self):
return None

route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)
mock_session = mock.AsyncMock(request=mock.AsyncMock(return_value=StubResponse()))
rest_client.buckets.is_started = True
rest_client._parse_ratelimits = mock.AsyncMock()
rest_client._acquire_client_session = mock.Mock(return_value=mock_session)

with mock.patch.object(
rest, "_LOGGER", new=mock.Mock(getEffectiveLevel=mock.Mock(return_value=logger_level))
) as logger:
with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()):
await rest_client._request(route)

if logger_level > ux.TRACE:
assert logger.log.call_count == 0
else:
assert logger.log.call_count == 2

async def test__handle_error_response(self, rest_client, exit_exception):
mock_response = mock.Mock()
with mock.patch.object(net, "generate_error_response", return_value=exit_exception) as generate_error_response:
Expand Down Expand Up @@ -989,6 +1102,23 @@ async def json(self):
with pytest.raises(rest_client._RetryRequest):
await rest_client._parse_ratelimits(route, StubResponse())

async def test__parse_ratelimits_when_retry_after_is_not_close_enough(self, rest_client):
class StubResponse:
status = http.HTTPStatus.TOO_MANY_REQUESTS
content_type = rest._APPLICATION_JSON
headers = {
rest._DATE_HEADER: "Thu, 02 Jul 2020 20:55:08 GMT",
rest._X_RATELIMIT_RESET_AFTER_HEADER: "0.002",
}
real_url = "https://some.url"

async def json(self):
return {"retry_after": "4"}

route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)
with pytest.raises(errors.RateLimitedError):
await rest_client._parse_ratelimits(route, StubResponse())

async def test_close_when__client_session_is_None(self, rest_client):
rest_client._client_session = None
rest_client._connector_factory = mock.AsyncMock()
Expand Down
15 changes: 6 additions & 9 deletions tests/hikari/internal/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,22 +210,19 @@ def test___setitem___removes_old_entry_instead_of_replacing(self):
mock_map["ok"] = "foo"
assert list(mock_map.items())[2] == ("ok", "foo")

# TODO: fix this so that it is not flaky.
# https://travis-ci.org/github/nekokatt/hikari/jobs/724494888#L797
@pytest.mark.skip("flaky test, might fail on Windows runners.")
@pytest.mark.asyncio
async def test___setitem___garbage_collection(self):
mock_map = collections.TimedCacheMap(
expiry=datetime.timedelta(seconds=hikari_test_helpers.REASONABLE_QUICK_RESPONSE_TIME * 3)
)
mock_map.update({"OK": "no", "blam": "booga"})
mock_map["OK"] = "no"
await asyncio.sleep(hikari_test_helpers.REASONABLE_QUICK_RESPONSE_TIME * 2)
assert mock_map == {"OK": "no", "blam": "booga"}
mock_map.update({"ayanami": "rei", "owo": "awoo"})
assert mock_map == {"OK": "no", "blam": "booga", "ayanami": "rei", "owo": "awoo"}
assert mock_map == {"OK": "no"}
mock_map["ayanami"] = "rei"
assert mock_map == {"OK": "no", "ayanami": "rei"}
await asyncio.sleep(hikari_test_helpers.REASONABLE_QUICK_RESPONSE_TIME * 2)
mock_map.update({"nyaa": "qt"})
assert mock_map == {"ayanami": "rei", "owo": "awoo", "nyaa": "qt"}
mock_map["nyaa"] = "qt"
assert mock_map == {"ayanami": "rei", "nyaa": "qt"}


class TestLimitedCapacityCacheMap:
Expand Down
10 changes: 8 additions & 2 deletions tests/hikari/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def test_format_icon_when_hash_is_None(self, model):
assert model.format_icon() is None


@pytest.mark.asyncio
class TestTextChannel:
@pytest.fixture()
def model(self, mock_app):
Expand All @@ -196,7 +197,6 @@ def model(self, mock_app):
type=channels.ChannelType.GUILD_TEXT,
)

@pytest.mark.asyncio
async def test_history(self, model):
model.app.rest.fetch_messages = mock.AsyncMock()

Expand All @@ -213,7 +213,6 @@ async def test_history(self, model):
around=datetime.datetime(2020, 4, 1, 0, 30, 0),
)

@pytest.mark.asyncio
async def test_send(self, model):
model.app.rest.create_message = mock.AsyncMock()
mock_attachment = object()
Expand Down Expand Up @@ -245,6 +244,13 @@ async def test_send(self, model):
role_mentions=[789, 567],
)

def test_trigger_typing(self, model):
model.app.rest.trigger_typing = mock.Mock()

model.trigger_typing()

model.app.rest.trigger_typing.assert_called_once_with(12345679)


class TestGuildChannel:
@pytest.fixture()
Expand Down

0 comments on commit 55a9673

Please sign in to comment.