diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 130ddf50d2..67c4c06109 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,6 +4,10 @@ on: push: pull_request: +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: test: strategy: @@ -20,7 +24,7 @@ jobs: uses: actions/checkout@v3 - name: Setup python - uses: actions/setup-python@v3 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} @@ -43,7 +47,7 @@ jobs: mv .coverage .coverage.${{ matrix.os }}.${{ matrix.python-version }} - name: Upload coverage - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: coverage path: .coverage.${{ matrix.os }}.${{ matrix.python-version }} @@ -60,12 +64,12 @@ jobs: uses: actions/checkout@v3 - name: Setup python - uses: actions/setup-python@v3 + uses: actions/setup-python@v4 with: python-version: 3.8 - name: Download coverage - uses: actions/download-artifact@v2 + uses: actions/download-artifact@v3 with: name: coverage @@ -90,7 +94,7 @@ jobs: uses: actions/checkout@v3 - name: Setup python - uses: actions/setup-python@v3 + uses: actions/setup-python@v4 with: python-version: 3.8 @@ -117,7 +121,7 @@ jobs: uses: actions/checkout@v3 - name: Setup python - uses: actions/setup-python@v3 + uses: actions/setup-python@v4 with: python-version: 3.8 @@ -134,7 +138,7 @@ jobs: uses: actions/checkout@v3 - name: Setup python - uses: actions/setup-python@v3 + uses: actions/setup-python@v4 with: python-version: 3.8 @@ -148,7 +152,7 @@ jobs: - name: Upload artifacts if: github.event_name != 'release' - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: pages path: public @@ -159,21 +163,17 @@ jobs: # other jobs are subject to change ci-done: needs: [upload-coverage, linting, twemoji, pages] - if: always() + if: always() && !cancelled() runs-on: ubuntu-latest steps: - - name: Mark status based on past job status + - name: Set status based on required jobs env: - # All new need jobs need to be added here with the prefix "RESULT_" - RESULT_UPLOAD_COVERAGE: ${{ needs.upload-coverage.result }} - RESULT_LINTING: ${{ needs.linting.result }} - RESULT_TWEMOJI: ${{ needs.twemoji.result }} - RESULT_PAGES: ${{ needs.pages.result }} + RESULTS: ${{ join(needs.*.result, ' ') }} run: | - if [ "$(env | grep 'RESULT_')" = "$(env | grep "RESULT_" | grep '=success')" ]; then - exit 0 - else - exit 1 - fi + for result in $RESULTS; do + if [ "$result" != "success" ]; then + exit 1 + fi + done diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 5aece32749..520879b98e 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -8,6 +8,10 @@ on: schedule: - cron: "0 0 * * *" # Every day at 00:00 +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: analyze: name: Analyze @@ -18,9 +22,9 @@ jobs: uses: actions/checkout@v3 - name: Initialize CodeQL - uses: github/codeql-action/init@v1 + uses: github/codeql-action/init@v2 with: languages: python - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v1 + uses: github/codeql-action/analyze@v2 diff --git a/.github/workflows/fragments-check.yml b/.github/workflows/fragments-check.yml index 5f40c08c16..f261a89e90 100644 --- a/.github/workflows/fragments-check.yml +++ b/.github/workflows/fragments-check.yml @@ -6,6 +6,10 @@ on: branches: - master +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: check-fragment-added: if: github.event.pull_request.user.type != 'Bot' && !contains(github.event.pull_request.labels.*.name, 'skip-fragment-check') @@ -20,7 +24,7 @@ jobs: fetch-depth: 0 - name: Setup python - uses: actions/setup-python@v3 + uses: actions/setup-python@v4 with: python-version: 3.8 diff --git a/.github/workflows/prepare-release.yml b/.github/workflows/prepare-release.yml index 91b1dfb02a..96da9a7512 100644 --- a/.github/workflows/prepare-release.yml +++ b/.github/workflows/prepare-release.yml @@ -16,7 +16,7 @@ jobs: uses: actions/checkout@v3 - name: Setup python - uses: actions/setup-python@v3 + uses: actions/setup-python@v4 with: python-version: 3.8 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 73f30e0ad5..45678b1f2d 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -15,7 +15,7 @@ jobs: token: ${{ secrets.PAT_TOKEN }} - name: Setup python - uses: actions/setup-python@v3 + uses: actions/setup-python@v4 with: python-version: 3.8 diff --git a/CHANGELOG.md b/CHANGELOG.md index c480359176..a3106cb194 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,55 @@ This file is updated every release with the use of `towncrier` from the fragment .. towncrier release notes start +Hikari 2.0.0.dev109 (2022-06-26) +================================ + +Breaking Changes +---------------- + +- Removal of all application commands v1 related fields and endpoints. + - Discord has completely disabled some endpoints, so we unfortunately can't + deprecate them instead of removing them ([#1148](https://github.com/hikari-py/hikari/issues/1148)) +- Removed the `resolved` attribute from `AutocompleteInteraction` as autocomplete interactions never have resolved objects. ([#1152](https://github.com/hikari-py/hikari/issues/1152)) +- `build` methods are now typed as returning `MutableMapping[str, typing.Any]`. ([#1164](https://github.com/hikari-py/hikari/issues/1164)) + + +Deprecation +----------- + +- `messages.Mentions` object deprecated + - Alternatives can be found in the base message object ([#1149](https://github.com/hikari-py/hikari/issues/1149)) + + +Features +-------- + +- Add `create` method to `CommandBuilder`. ([#1016](https://github.com/hikari-py/hikari/issues/1016)) +- Support for attachments in REST-based interaction responses. ([#1048](https://github.com/hikari-py/hikari/issues/1048)) +- Add option to disable automatic member chunking. + Added the `auto_chunk_members` kwarg to `GatewayBot` and `EventManagerImpl`, which when `False` will disable automatic member chunking. ([#1084](https://github.com/hikari-py/hikari/issues/1084)) +- Allow passing multiple event types to the listen decorator. + Parse union type hints for events if listen decorator is empty. ([#1103](https://github.com/hikari-py/hikari/issues/1103)) +- Animated guild banner support. ([#1116](https://github.com/hikari-py/hikari/issues/1116)) +- Implement application commands permission v2. + - New `default_member_permissions` and `is_dm_enabled` related fields. + - Added `hikari.events.application_events.ApplicationCommandPermissionsUpdate`. + - Added `APPLICATION_COMMAND_PERMISSION_UPDATE` audit log entry ([#1148](https://github.com/hikari-py/hikari/issues/1148)) + + +Bugfixes +-------- + +- Improved pyright support. ([#1108](https://github.com/hikari-py/hikari/issues/1108)) +- `RESTClientImpl.fetch_bans` now return a `LazyIterator` to allow pagination of values. ([#1119](https://github.com/hikari-py/hikari/issues/1119)) +- Fix unicode decode error caused by `latin-1` encoding when sending the banner. ([#1120](https://github.com/hikari-py/hikari/issues/1120)) +- Don't error on an out-of-spec HTTP status code (e.g one of Cloudflare's custom status codes). + `HTTPResponseError.status` may now be of type `http.HTTPStatus` or `int`. ([#1121](https://github.com/hikari-py/hikari/issues/1121)) +- Fix name of polish locale (`hikari.Locale.OL` -> `hikari.Locale.PL`) ([#1144](https://github.com/hikari-py/hikari/issues/1144)) +- Properly garbage collect message references in the cache + - Properly deserialize `PartialMessage.referenced_message` as a partial message ([#1192](https://github.com/hikari-py/hikari/issues/1192)) + + Hikari 2.0.0.dev108 (2022-03-27) ================================ diff --git a/README.md b/README.md index 88ada16509..88a34a9766 100644 --- a/README.md +++ b/README.md @@ -144,6 +144,7 @@ Hikari does not include a command framework by default, so you will want to pick - [`lightbulb`](https://github.com/tandemdude/hikari-lightbulb) - a simple and easy to use command framework for Hikari. - [`tanjun`](https://github.com/FasterSpeeding/Tanjun) - a flexible command framework designed to extend Hikari. +- [`crescent`](https://github.com/magpie-dev/hikari-crescent) - a command handler for Hikari that keeps your project neat and tidy. --- diff --git a/changes/1016.feature.md b/changes/1016.feature.md deleted file mode 100644 index f8168d8da9..0000000000 --- a/changes/1016.feature.md +++ /dev/null @@ -1 +0,0 @@ -Add `create` method to `CommandBuilder`. diff --git a/changes/1048.feature.md b/changes/1048.feature.md deleted file mode 100644 index d6b99b2000..0000000000 --- a/changes/1048.feature.md +++ /dev/null @@ -1 +0,0 @@ -Support for attachments in REST-based interaction responses. diff --git a/changes/1103.feature.md b/changes/1103.feature.md deleted file mode 100644 index c5eae50102..0000000000 --- a/changes/1103.feature.md +++ /dev/null @@ -1,2 +0,0 @@ -Allow passing multiple event types to the listen decorator. -Parse union type hints for events if listen decorator is empty. diff --git a/changes/1116.feature.md b/changes/1116.feature.md deleted file mode 100644 index edf0dd6bbb..0000000000 --- a/changes/1116.feature.md +++ /dev/null @@ -1 +0,0 @@ -Animated guild banner support. diff --git a/changes/1189.feature.md b/changes/1189.feature.md new file mode 100644 index 0000000000..6fa6c63424 --- /dev/null +++ b/changes/1189.feature.md @@ -0,0 +1 @@ +`GuildVoiceChannel` now inherits from `TextableGuildChannel` instead of `GuildChannel`. diff --git a/changes/1201.feature.md b/changes/1201.feature.md new file mode 100644 index 0000000000..e00faef7c6 --- /dev/null +++ b/changes/1201.feature.md @@ -0,0 +1 @@ +Add the `app_permissions` field to command and component interactions. diff --git a/dev-requirements.txt b/dev-requirements.txt index 71bd77170a..61b60e3033 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -6,13 +6,13 @@ mock==4.0.3 # Py.test stuff. -pytest==7.1.1 +pytest==7.1.2 pytest-asyncio==0.18.3 pytest-cov==3.0.0 -pytest-randomly==3.11.0 +pytest-randomly==3.12.0 # Coverage testing. -coverage[toml]==6.3.2 +coverage[toml]==6.4.1 # Other stuff async-timeout==4.0.2 # Used for timeouts in some test cases. @@ -28,27 +28,27 @@ sphobjinv==2.2.2 # TYPE CHECKING # ################# -mypy==0.942 -pyright==1.1.234 +mypy==0.961 +pyright==1.1.257 ####################### # DEPENDENCY CHECKING # ####################### -safety~=1.10.3 +safety~=2.0.0 ############## # FORMATTING # ############## -black==22.3.0 +black==22.6.0 isort==5.10.1 ########### # Linting # ########### -slotscheck==0.14.0 +slotscheck==0.14.1 ################## # SPELL CHECKING # diff --git a/docs/documentation.mako b/docs/documentation.mako index eedaa05f38..d90ad8c6b4 100644 --- a/docs/documentation.mako +++ b/docs/documentation.mako @@ -550,7 +550,7 @@ sphobjinv.DataObjStr( name = name, domain = "py", - role = "var", + role = "variable", uri = v.url(), priority = "1", dispname = "-", @@ -602,7 +602,7 @@ sphobjinv.DataObjStr( name = f.obj.__module__ + "." + f.obj.__qualname__, domain = "py", - role = "func", + role = "function", uri = f.url(), priority = "1", dispname = "-", diff --git a/examples/__init__.py b/examples/__init__.py index 10e6d281ed..81d74ef3bb 100644 --- a/examples/__init__.py +++ b/examples/__init__.py @@ -7,4 +7,4 @@ # # You should have received a copy of the CC0 Public Domain Dedication along with this software. # If not, see . -"""Allows mypy to run here.""" +"""Allows type-checkers to run here.""" diff --git a/flake8-requirements.txt b/flake8-requirements.txt index eb1d917c10..c4826d936c 100644 --- a/flake8-requirements.txt +++ b/flake8-requirements.txt @@ -4,23 +4,22 @@ flake8==4.0.1 # Ref: https://github.com/DmytroLitvinov/awesome-flake8-extensions flake8-bandit~=3.0.0 # runs bandit -flake8-black==0.3.2 # runs black +flake8-black==0.3.3 # runs black flake8-broken-line==0.4.0 # forbey "\" linebreaks flake8-builtins==1.5.3 # builtin shadowing checks flake8-coding==1.3.2 # coding magic-comment detection -flake8-comprehensions==3.8.0 # comprehension checks +flake8-comprehensions==3.10.0 # comprehension checks flake8-deprecated==1.3 # deprecated call checks flake8-docstrings==1.6.0 # pydocstyle support flake8-executable==2.1.1 # shebangs flake8-fixme==1.1.1 # "fix me" counter flake8-functions==0.0.7 # function linting -jinja2==3.0.3 # temporarily freeze jinja2 due to incompatibilities with flake8-html -flake8-html==0.4.1 # html output +flake8-html==0.4.2 # html output flake8-if-statements==0.1.0 # condition linting flake8-isort==4.1.1 # runs isort flake8-mutable==1.2.0 # mutable default argument detection flake8-pep3101==1.3.0 # new-style format strings only -flake8-print==4.0.0 # complain about print statements in code +flake8-print==5.0.0 # complain about print statements in code flake8-printf-formatting==1.1.2 # forbey printf-style python2 string formatting flake8-pytest-style==1.6.0 # pytest checks flake8-raise==0.0.5 # exception raising linting diff --git a/hikari/__init__.pyi b/hikari/__init__.pyi index 267fc5bb17..f989f82fd0 100644 --- a/hikari/__init__.pyi +++ b/hikari/__init__.pyi @@ -1,7 +1,7 @@ # DO NOT MANUALLY EDIT THIS FILE! # This file was automatically generated by `nox -s generate-stubs` -from typing import Any +from _typeshed import Incomplete from hikari import api as api from hikari import applications as applications @@ -102,4 +102,4 @@ from hikari.users import * from hikari.voices import * from hikari.webhooks import * -__all__: Any +__all__: Incomplete diff --git a/hikari/_about.py b/hikari/_about.py index efacec3310..1540fd8809 100644 --- a/hikari/_about.py +++ b/hikari/_about.py @@ -39,5 +39,5 @@ __issue_tracker__: typing.Final[str] = "https://github.com/hikari-py/hikari/issues" __license__: typing.Final[str] = "MIT" __url__: typing.Final[str] = "https://github.com/hikari-py/hikari" -__version__: typing.Final[str] = "2.0.0.dev109" +__version__: typing.Final[str] = "2.0.0.dev110" __git_sha1__: typing.Final[str] = "HEAD" diff --git a/hikari/api/config.py b/hikari/api/config.py index 839c5a1cbf..ada229e626 100644 --- a/hikari/api/config.py +++ b/hikari/api/config.py @@ -23,7 +23,7 @@ """Core interface for Hikari's configuration dataclasses.""" from __future__ import annotations -__all__: typing.Sequence[str] = ("CacheComponents",) +__all__: typing.Sequence[str] = ("CacheComponents", "CacheSettings", "HTTPSettings", "ProxySettings") import abc import typing diff --git a/hikari/api/event_factory.py b/hikari/api/event_factory.py index 91515b1b97..31677c135f 100644 --- a/hikari/api/event_factory.py +++ b/hikari/api/event_factory.py @@ -40,6 +40,7 @@ from hikari import users as user_models from hikari import voices as voices_models from hikari.api import shard as gateway_shard + from hikari.events import application_events from hikari.events import channel_events from hikari.events import guild_events from hikari.events import interaction_events @@ -61,6 +62,29 @@ class EventFactory(abc.ABC): __slots__: typing.Sequence[str] = () + ###################### + # APPLICATION EVENTS # + ###################### + + @abc.abstractmethod + def deserialize_application_command_permission_update_event( + self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject + ) -> application_events.ApplicationCommandPermissionsUpdateEvent: + """Parse a raw payload from Discord into an application command permissions update event object. + + Parameters + ---------- + shard : hikari.api.shard.GatewayShard + The shard that emitted this event. + payload : hikari.internal.data_binding.JSONObject + The dict payload to parse. + + Returns + ------- + hikari.events.application_events.ApplicationCommandPermissionsUpdateEvent + The parsed application command permissions update event. + """ + ################## # CHANNEL EVENTS # ################## diff --git a/hikari/api/rest.py b/hikari/api/rest.py index ceea4c2048..fd8f07af51 100644 --- a/hikari/api/rest.py +++ b/hikari/api/rest.py @@ -1199,9 +1199,6 @@ async def create_message( `role_mentions` or `user_mentions` or if both `attachment` and `attachments`, `component` and `components` or `embed` and `embeds` are specified. - builtins.TypeError - If `attachments`, `components` or `embeds` is passed but is not a - sequence. hikari.errors.BadRequestError This may be raised in several discrete situations, such as messages being empty with no attachments or embeds; messages with more than @@ -1446,9 +1443,6 @@ async def edit_message( builtins.ValueError If both `attachment` and `attachments`, `component` and `components` or `embed` and `embeds` are specified. - builtins.TypeError - If `attachments`, `components` or `embeds` is passed but is not a - sequence. hikari.errors.BadRequestError This may be raised in several discrete situations, such as messages being empty with no embeds; messages with more than 2000 characters @@ -2358,8 +2352,6 @@ async def execute_webhook( If more than 100 unique objects/entities are passed for `role_mentions` or `user_mentions` or if both `attachment` and `attachments` or `embed` and `embeds` are specified. - builtins.TypeError - If `attachments`, or `embeds` is passed but is not a sequence. hikari.errors.BadRequestError This may be raised in several discrete situations, such as messages being empty with no attachments or embeds; messages with more than @@ -2579,9 +2571,6 @@ async def edit_webhook_message( builtins.ValueError If both `attachment` and `attachments`, `component` and `components` or `embed` and `embeds` are specified. - builtins.TypeError - If `attachments`, `components` or `embeds` is passed but is not a - sequence. hikari.errors.BadRequestError This may be raised in several discrete situations, such as messages being empty with no attachments or embeds; messages with more than @@ -3286,6 +3275,9 @@ async def add_user_to_guild( Requires the `MANAGE_NICKNAMES` permission on the guild. nick : hikari.undefined.UndefinedOr[builtins.str] Deprecated alias for `nickname`. + + .. deprecated:: 2.0.0.dev106 + Use `nickname` instead. roles : hikari.undefined.UndefinedOr[hikari.snowflakes.SnowflakeishSequence[hikari.guilds.PartialRole]] If provided, the roles to add to the user when he joins the guild. This may be a collection objects or IDs of existing roles. @@ -5006,6 +4998,9 @@ async def edit_member( Requires the `MANAGE_NICKNAMES` permission. nick : hikari.undefined.UndefinedOr[builtins.str] Deprecated alias for `nickname`. + + .. deprecated:: 2.0.0.dev104 + Use `nickname` instead. roles : hikari.undefined.UndefinedOr[hikari.snowflakes.SnowflakeishSequence[hikari.guilds.PartialRole]] If provided, the new roles for the member. @@ -5509,21 +5504,42 @@ async def fetch_ban( """ @abc.abstractmethod - async def fetch_bans( + def fetch_bans( self, guild: snowflakes.SnowflakeishOr[guilds.PartialGuild], - ) -> typing.Sequence[guilds.GuildBan]: + /, + *, + newest_first: bool = False, + start_at: undefined.UndefinedOr[snowflakes.SearchableSnowflakeishOr[users.PartialUser]] = undefined.UNDEFINED, + ) -> iterators.LazyIterator[guilds.GuildBan]: """Fetch the bans of a guild. + !!! note + This call is not a coroutine function, it returns a special type of + lazy iterator that will perform API calls as you iterate across it. + See `hikari.iterators` for the full API for this iterator type. + Parameters ---------- guild : hikari.snowflakes.SnowflakeishOr[hikari.guilds.PartialGuild] The guild to fetch the bans from. This may be the object or the ID of an existing guild. + Other Parameters + ---------------- + newest_first : builtins.bool + Whether to fetch the newest first or the oldest first. + + Defaults to `builtins.False`. + start_at : undefined.UndefinedOr[snowflakes.SearchableSnowflakeishOr[users.PartialUser]] + If provided, will start at this snowflake. If you provide + a datetime object, it will be transformed into a snowflake. This + may also be a scheduled event object object. In this case, the + date the object was first created will be used. + Returns ------- - typing.Sequence[hikari.guilds.GuildBan] + hikari.iterators.LazyIterator[hikari.guilds.GuildBan] The requested bans. Raises @@ -6798,77 +6814,6 @@ async def fetch_application_commands( If an internal error occurs on Discord while handling the request. """ - @abc.abstractmethod - async def create_application_command( - self, - application: snowflakes.SnowflakeishOr[guilds.PartialApplication], - name: str, - description: str, - guild: undefined.UndefinedOr[snowflakes.SnowflakeishOr[guilds.PartialGuild]] = undefined.UNDEFINED, - *, - options: undefined.UndefinedOr[typing.Sequence[commands.CommandOption]] = undefined.UNDEFINED, - default_permission: undefined.UndefinedOr[bool] = undefined.UNDEFINED, - ) -> commands.SlashCommand: - r"""Create an application slash command. - - .. deprecated:: 2.0.0.dev106 - Use `RESTClient.create_slash_command` instead. - - Parameters - ---------- - application: hikari.snowflakes.SnowflakeishOr[hikari.guilds.PartialApplication] - Object or ID of the application to create a command for. - name : builtins.str - The command's name. This should match the regex `^[\w-]{1,32}$` in - Unicode mode and be lowercase. - description : builtins.str - The description to set for the command. - This should be inclusively between 1-100 characters in length. - guild : hikari.undefined.UndefinedOr[hikari.snowflakes.SnowflakeishOr[hikari.guilds.PartialGuild] - Object or ID of the specific guild this should be made for. - If left as `hikari.undefined.UNDEFINED` then this call will create - a global command rather than a guild specific one. - - Other Parameters - ---------------- - options : hikari.undefined.UndefinedOr[typing.Sequence[hikari.commands.CommandOption]] - A sequence of up to 10 options for this command. - default_permission : hikari.undefined.UndefinedOr[builtins.bool] - Whether this command should be enabled by default (without any - permissions) when added to a guild. - - Defaults to `builtins.True`. - - Returns - ------- - hikari.commands.SlashCommand - Object of the created command. - - Raises - ------ - hikari.errors.ForbiddenError - If you cannot access the provided application's commands. - hikari.errors.NotFoundError - If the provided application isn't found. - hikari.errors.BadRequestError - If any of the fields that are passed have an invalid value. - hikari.errors.UnauthorizedError - If you are unauthorized to make the request (invalid/missing token). - hikari.errors.RateLimitTooLongError - Raised in the event that a rate limit occurs that is - longer than `max_rate_limit` when making a request. - hikari.errors.RateLimitedError - Usually, Hikari will handle and retry on hitting - rate-limits automatically. This includes most bucket-specific - rate-limits and global rate-limits. In some rare edge cases, - however, Discord implements other undocumented rules for - rate-limiting, such as limits per attribute. These cannot be - detected or handled normally by Hikari due to their undocumented - nature, and will trigger this exception if they occur. - hikari.errors.InternalServerError - If an internal error occurs on Discord while handling the request. - """ - @abc.abstractmethod async def create_slash_command( self, @@ -6878,7 +6823,10 @@ async def create_slash_command( *, guild: undefined.UndefinedOr[snowflakes.SnowflakeishOr[guilds.PartialGuild]] = undefined.UNDEFINED, options: undefined.UndefinedOr[typing.Sequence[commands.CommandOption]] = undefined.UNDEFINED, - default_permission: undefined.UndefinedOr[bool] = undefined.UNDEFINED, + default_member_permissions: typing.Union[ + undefined.UndefinedType, int, permissions_.Permissions + ] = undefined.UNDEFINED, + dm_enabled: undefined.UndefinedOr[bool] = undefined.UNDEFINED, ) -> commands.SlashCommand: r"""Create an application command. @@ -6901,11 +6849,15 @@ async def create_slash_command( a global command rather than a guild specific one. options : hikari.undefined.UndefinedOr[typing.Sequence[hikari.commands.CommandOption]] A sequence of up to 10 options for this command. - default_permission : hikari.undefined.UndefinedOr[builtins.bool] - Whether this command should be enabled by default (without any - permissions) when added to a guild. + default_member_permissions : typing.Union[hikari.undefined.UndefinedType, int, hikari.permissions.Permissions] + Member permissions necessary to utilize this command by default. + + If `0`, then it will be available for all members. Note that this doesn't affect + administrators of the guild and overwrites. + dm_enabled : hikari.undefined.UndefinedOr[builtins.bool] + Whether this command is enabled in DMs with the bot. - Defaults to `builtins.True`. + This can only be applied to non-guild commands. Returns ------- @@ -6945,7 +6897,10 @@ async def create_context_menu_command( name: str, *, guild: undefined.UndefinedOr[snowflakes.SnowflakeishOr[guilds.PartialGuild]] = undefined.UNDEFINED, - default_permission: undefined.UndefinedOr[bool] = undefined.UNDEFINED, + default_member_permissions: typing.Union[ + undefined.UndefinedType, int, permissions_.Permissions + ] = undefined.UNDEFINED, + dm_enabled: undefined.UndefinedOr[bool] = undefined.UNDEFINED, ) -> commands.ContextMenuCommand: r"""Create an application command. @@ -6967,11 +6922,15 @@ async def create_context_menu_command( Object or ID of the specific guild this should be made for. If left as `hikari.undefined.UNDEFINED` then this call will create a global command rather than a guild specific one. - default_permission : hikari.undefined.UndefinedOr[builtins.bool] - Whether this command should be enabled by default (without any - permissions) when added to a guild. + default_member_permissions : typing.Union[hikari.undefined.UndefinedType, int, hikari.permissions.Permissions] + Member permissions necessary to utilize this command by default. + + If `0`, then it will be available for all members. Note that this doesn't affect + administrators of the guild and overwrites. + dm_enabled : hikari.undefined.UndefinedOr[builtins.bool] + Whether this command is enabled in DMs with the bot. - Defaults to `builtins.True`. + This can only be applied to non-guild commands. Returns ------- @@ -7071,6 +7030,10 @@ async def edit_application_command( name: undefined.UndefinedOr[str] = undefined.UNDEFINED, description: undefined.UndefinedOr[str] = undefined.UNDEFINED, options: undefined.UndefinedOr[typing.Sequence[commands.CommandOption]] = undefined.UNDEFINED, + default_member_permissions: typing.Union[ + undefined.UndefinedType, int, permissions_.Permissions + ] = undefined.UNDEFINED, + dm_enabled: undefined.UndefinedOr[bool] = undefined.UNDEFINED, ) -> commands.PartialCommand: """Edit a registered application command. @@ -7096,6 +7059,15 @@ async def edit_application_command( options : hikari.undefined.UndefinedOr[typing.Sequence[hikari.commands.CommandOption]] A sequence of up to 10 options to set for this command. Leave this as `hikari.undefined.UNDEFINED` to not change. + default_member_permissions : typing.Union[hikari.undefined.UndefinedType, int, hikari.permissions.Permissions] + Member permissions necessary to utilize this command by default. + + If `0`, then it will be available for all members. Note that this doesn't affect + administrators of the guild and overwrites. + dm_enabled : hikari.undefined.UndefinedOr[builtins.bool] + Whether this command is enabled in DMs with the bot. + + This can only be applied to non-guild commands. Returns ------- @@ -7262,62 +7234,7 @@ async def fetch_application_command_permissions( If an internal error occurs on Discord while handling the request. """ - @abc.abstractmethod - async def set_application_guild_commands_permissions( - self, - application: snowflakes.SnowflakeishOr[guilds.PartialApplication], - guild: snowflakes.SnowflakeishOr[guilds.PartialGuild], - permissions: typing.Mapping[ - snowflakes.SnowflakeishOr[commands.PartialCommand], typing.Sequence[commands.CommandPermission] - ], - ) -> typing.Sequence[commands.GuildCommandPermissions]: - """Set permissions in a guild for multiple commands. - - !!! note - This overwrites any previously set permissions for the specified - commands. - - Parameters - ---------- - application: hikari.snowflakes.SnowflakeishOr[hikari.guilds.PartialApplication] - Object or ID of the application to set the command permissions for. - guild : hikari.undefined.UndefinedOr[hikari.snowflakes.SnowflakeishOr[hikari.guilds.PartialGuild]] - Object or ID of the guild to set the command permissions for. - permissions : typing.Mapping[hikari.snowflakes.SnowflakeishOr[hikari.commands.PartialCommand], typing.Sequence[hikari.commands.CommandPermission]] - Mapping of objects and/or IDs of commands to sequences of the commands - to set for the specified guild. - - !!! warning - Only a maximum of up to 10 permissions can be set per command. - - Returns - ------- - typing.Sequence[hikari.commands.GuildCommandPermissions] - Sequence of the set guild command permissions. - - Raises - ------ - hikari.errors.ForbiddenError - If you cannot access the provided application's commands or guild. - hikari.errors.NotFoundError - If the provided application or command isn't found. - hikari.errors.UnauthorizedError - If you are unauthorized to make the request (invalid/missing token). - hikari.errors.RateLimitTooLongError - Raised in the event that a rate limit occurs that is - longer than `max_rate_limit` when making a request. - hikari.errors.RateLimitedError - Usually, Hikari will handle and retry on hitting - rate-limits automatically. This includes most bucket-specific - rate-limits and global rate-limits. In some rare edge cases, - however, Discord implements other undocumented rules for - rate-limiting, such as limits per attribute. These cannot be - detected or handled normally by Hikari due to their undocumented - nature, and will trigger this exception if they occur. - hikari.errors.InternalServerError - If an internal error occurs on Discord while handling the request. - """ # noqa: E501 - Line too long - + # THIS IS AN OAUTH2 FLOW ONLY @abc.abstractmethod async def set_application_command_permissions( self, @@ -7328,6 +7245,12 @@ async def set_application_command_permissions( ) -> commands.GuildCommandPermissions: """Set permissions for a specific command. + !!! note + This requires the `access_token` to have the + `hikari.applications.OAuth2Scope.APPLICATIONS_COMMANDS_PERMISSION_UPDATE` + scope enabled along with the authorization of a Bot which has `MANAGE_INVITES` + permission within the target guild. + !!! note This overwrites any previously set permissions. @@ -7740,9 +7663,6 @@ async def edit_interaction_response( builtins.ValueError If both `attachment` and `attachments`, `component` and `components` or `embed` and `embeds` are specified. - builtins.TypeError - If `attachments`, `components` or `embeds` is passed but is not a - sequence. hikari.errors.BadRequestError This may be raised in several discrete situations, such as messages being empty with no attachments or embeds; messages with more than diff --git a/hikari/api/special_endpoints.py b/hikari/api/special_endpoints.py index 058c4f714b..831030fd55 100644 --- a/hikari/api/special_endpoints.py +++ b/hikari/api/special_endpoints.py @@ -67,7 +67,6 @@ from hikari.api import entity_factory as entity_factory_ from hikari.api import rest as rest_api from hikari.interactions import base_interactions - from hikari.internal import data_binding from hikari.internal import time _T = typing.TypeVar("_T") @@ -539,7 +538,7 @@ def type(self) -> typing.Union[int, base_interactions.ResponseType]: @abc.abstractmethod def build( self, entity_factory: entity_factory_.EntityFactory, / - ) -> typing.Tuple[data_binding.JSONObject, typing.Sequence[files.Resource[files.AsyncReader]]]: + ) -> typing.Tuple[typing.MutableMapping[str, typing.Any], typing.Sequence[files.Resource[files.AsyncReader]]]: """Build a JSON object from this builder. Parameters @@ -549,7 +548,7 @@ def build( Returns ------- - typing.Tuple[hikari.internal.data_binding.JSONObject, typing.Sequence[files.Resource[Files.AsyncReader]] + typing.Tuple[typing.MutableMapping[str, typing.Any], typing.Sequence[files.Resource[Files.AsyncReader]] A tuple of the built json object representation of this builder and a sequence of up to 10 files to send with the response. """ @@ -968,10 +967,19 @@ def id(self) -> undefined.UndefinedOr[snowflakes.Snowflake]: @property @abc.abstractmethod - def default_permission(self) -> undefined.UndefinedOr[bool]: - """Whether the command should be enabled by default (without any permissions). + def default_member_permissions(self) -> typing.Union[undefined.UndefinedType, permissions_.Permissions, int]: + """Member permissions necessary to utilize this command by default. - Defaults to `builtins.bool`. + If `0`, then it will be available for all members. Note that this doesn't affect + administrators of the guild and overwrites. + """ + + @property + @abc.abstractmethod + def is_dm_enabled(self) -> undefined.UndefinedOr[bool]: + """Whether this command is enabled in DMs with the bot. + + Only applicable to globally-scoped commands. """ @abc.abstractmethod @@ -990,13 +998,33 @@ def set_id(self: _T, id_: undefined.UndefinedOr[snowflakes.Snowflakeish], /) -> """ @abc.abstractmethod - def set_default_permission(self: _T, state: undefined.UndefinedOr[bool], /) -> _T: - """Whether this command should be enabled by default (without any permissions). + def set_default_member_permissions( + self: _T, default_member_permissions: typing.Union[undefined.UndefinedType, int, permissions_.Permissions], / + ) -> _T: + """Set the member permissions necessary to utilize this command by default. + + Parameters + ---------- + default_member_permissions : hikari.undefined.UndefinedOr[builtins.bool] + The default member permissions to utilize this command by default. + + If `0`, then it will be available for all members. Note that this doesn't affect + administrators of the guild and overwrites. + + Returns + ------- + CommandBuilder + Object of this command builder for chained calls. + """ + + @abc.abstractmethod + def set_is_dm_enabled(self: _T, state: undefined.UndefinedOr[bool], /) -> _T: + """Set whether this command will be enabled in DMs with the bot. Parameters ---------- state : hikari.undefined.UndefinedOr[builtins.bool] - Whether this command should be enabled by default. + Whether this command is enabled in DMs with the bot. Returns ------- @@ -1005,7 +1033,7 @@ def set_default_permission(self: _T, state: undefined.UndefinedOr[bool], /) -> _ """ @abc.abstractmethod - def build(self, entity_factory: entity_factory_.EntityFactory, /) -> data_binding.JSONObject: + def build(self, entity_factory: entity_factory_.EntityFactory, /) -> typing.MutableMapping[str, typing.Any]: """Build a JSON object from this builder. Parameters @@ -1015,7 +1043,7 @@ def build(self, entity_factory: entity_factory_.EntityFactory, /) -> data_bindin Returns ------- - hikari.internal.data_binding.JSONObject + typing.MutableMapping[str, typing.Any] The built json object representation of this builder. """ @@ -1170,12 +1198,12 @@ class ComponentBuilder(abc.ABC): __slots__: typing.Sequence[str] = () @abc.abstractmethod - def build(self) -> data_binding.JSONObject: + def build(self) -> typing.MutableMapping[str, typing.Any]: """Build a JSON object from this builder. Returns ------- - hikari.internal.data_binding.JSONObject + typing.MutableMapping[str, typing.Any] The built json object representation of this builder. """ diff --git a/hikari/applications.py b/hikari/applications.py index 3a74ff49d9..d8a05701f3 100644 --- a/hikari/applications.py +++ b/hikari/applications.py @@ -137,7 +137,7 @@ class OAuth2Scope(str, enums.Enum): """ APPLICATIONS_COMMANDS = "applications.commands" - """Allows your application's (slash) commands to be used in a guild. + """Allows your application's commands to be used in a guild. This is used in Discord's special Bot Authorization Flow like `OAuth2Scope.BOT` in-order to join an application into a guild as an @@ -145,7 +145,10 @@ class OAuth2Scope(str, enums.Enum): """ APPLICATIONS_COMMANDS_UPDATE = "applications.commands.update" - """Allows your application to update it's (slash) commands via a bearer token.""" + """Allows your application to update its commands via a bearer token.""" + + APPLICATIONS_COMMANDS_PERMISSION_UPDATE = "applications.commands.permissions.update" + """Allows your application to update its commands permissions via a bearer token.""" APPLICATIONS_ENTITLEMENTS = "applications.entitlements" """Enables reading entitlements for a user's applications.""" diff --git a/hikari/audit_logs.py b/hikari/audit_logs.py index f39c888ffa..befe6ecc9b 100644 --- a/hikari/audit_logs.py +++ b/hikari/audit_logs.py @@ -101,6 +101,7 @@ class AuditLogChangeKey(str, enums.Enum): PERMISSIONS = "permissions" USER_LIMIT = "user_limit" COLOR = "color" + COMMAND_ID = "command_id" HOIST = "hoist" MENTIONABLE = "mentionable" ALLOW = "allow" @@ -194,6 +195,10 @@ class AuditLogEventType(int, enums.Enum): STICKER_CREATE = 90 STICKER_UPDATE = 91 STICKER_DELETE = 92 + GUILD_SCHEDULED_EVENT_CREATE = 100 + GUILD_SCHEDULED_EVENT_UPDATE = 101 + GUILD_SCHEDULED_EVENT_DELETE = 102 + APPLICATION_COMMAND_PERMISSION_UPDATE = 121 @attr.define(hash=False, kw_only=True, weakref_slot=False) diff --git a/hikari/channels.py b/hikari/channels.py index 0188521e23..af769dce39 100644 --- a/hikari/channels.py +++ b/hikari/channels.py @@ -1288,7 +1288,7 @@ class GuildNewsChannel(TextableGuildChannel): @attr.define(hash=True, kw_only=True, weakref_slot=False) -class GuildVoiceChannel(GuildChannel): +class GuildVoiceChannel(TextableGuildChannel): """Represents a voice channel.""" bitrate: int = attr.field(eq=False, hash=False, repr=True) @@ -1311,6 +1311,14 @@ class GuildVoiceChannel(GuildChannel): video_quality_mode: typing.Union[VideoQualityMode, int] = attr.field(eq=False, hash=False, repr=False) """The video quality mode for the voice channel.""" + last_message_id: typing.Optional[snowflakes.Snowflake] = attr.field(eq=False, hash=False, repr=False) + """The ID of the last message sent in this channel. + + !!! warning + This might point to an invalid or deleted message. Do not assume that + this will always be valid. + """ + @attr.define(hash=True, kw_only=True, weakref_slot=False) class GuildStageChannel(GuildChannel): diff --git a/hikari/colors.py b/hikari/colors.py index cbdb931783..3866ab01b4 100644 --- a/hikari/colors.py +++ b/hikari/colors.py @@ -424,34 +424,6 @@ def from_tuple_string(cls, tuple_str: str, /) -> Color: else: return cls.from_rgb(_to_rgb_int(r, "red"), _to_rgb_int(g, "green"), _to_rgb_int(b, "blue")) - # Partially chose to override these as the docstrings contain typos according to Sphinx. - @classmethod - def from_bytes( - cls, - bytes_: typing.Union[typing.Iterable[typing.SupportsIndex], typing.SupportsBytes], - byteorder: typing.Literal["little", "big"], - *, - signed: bool = True, - ) -> Color: - """Convert the bytes to a `Color`. - - Parameters - ---------- - bytes_ : typing.Iterable[builtins.int] - A iterable of int byte values. - byteorder : builtins.str - The endianness of the value represented by the bytes. - Can be `"big"` endian or `"little"` endian. - signed : builtins.bool - Whether the value is signed or unsigned. - - Returns - ------- - Color - The Color object. - """ - return Color(int.from_bytes(bytes_, byteorder, signed=signed)) - @classmethod def of(cls, value: Colorish, /) -> Color: """Convert the value to a `Color`. @@ -517,12 +489,12 @@ def of(cls, value: Colorish, /) -> Color: if len(value) != 3: raise ValueError(f"Color must be an RGB triplet if set to a {type(value).__name__} type") - if any(isinstance(c, float) for c in value): - r, g, b = value + r, g, b = value + + if isinstance(r, float) and isinstance(g, float) and isinstance(b, float): return cls.from_rgb_float(r, g, b) - if all(isinstance(c, int) for c in value): - r, g, b = value + if isinstance(r, int) and isinstance(g, int) and isinstance(b, int): return cls.from_rgb(r, g, b) if isinstance(value, str): diff --git a/hikari/commands.py b/hikari/commands.py index 618a06ab88..f13f121e02 100644 --- a/hikari/commands.py +++ b/hikari/commands.py @@ -40,6 +40,7 @@ import attr +from hikari import permissions from hikari import snowflakes from hikari import traits from hikari import undefined @@ -208,13 +209,15 @@ class PartialCommand(snowflakes.Unique): lowercase. """ - default_permission: bool = attr.field(eq=False, hash=False, repr=True) - """Whether the command is enabled by default when added to a guild. + default_member_permissions: permissions.Permissions = attr.field(eq=False, hash=False, repr=True) + """Member permissions necessary to utilize this command by default. - Defaults to `builtins.True`. This behaviour is overridden by command - permissions. + This excludes administrators of the guild and overwrites. """ + is_dm_enabled: bool = attr.field(eq=False, hash=False, repr=True) + """Whether this command is enabled in DMs with the bot.""" + guild_id: typing.Optional[snowflakes.Snowflake] = attr.field(eq=False, hash=False, repr=False) """ID of the guild this command is in. @@ -473,6 +476,9 @@ class CommandPermissionType(int, enums.Enum): USER = 2 """A command permission which toggles access for a specific user.""" + CHANNEL = 3 + """A command permission which toggles access in a specific channel.""" + @attr_extensions.with_copy @attr.define(kw_only=True, weakref_slot=False) @@ -480,7 +486,13 @@ class CommandPermission: """Representation of a permission which enables or disables a command for a user or role.""" id: snowflakes.Snowflake = attr.field(converter=snowflakes.Snowflake) - """Id of the role or user this permission changes the permission's state for.""" + """ID of the role or user this permission changes the permission's state for. + + There are some special constants for this field: + + * If equals to `guild_id`, then it applies to all members in a guild. + * If equals to (`guild_id` - 1), then it applies to all channels in a guild. + """ type: typing.Union[CommandPermissionType, int] = attr.field(converter=CommandPermissionType) """The entity this permission overrides the command's state for.""" @@ -494,6 +506,14 @@ class CommandPermission: class GuildCommandPermissions: """Representation of the permissions set for a command within a guild.""" + id: snowflakes.Snowflake = attr.field() + """ID of the entity these permissions apply to. + + This may be the ID of a specific command or the application ID. When this is equal + to `application_id`, the permissions apply to all commands that do not contain + explicit overwrites. + """ + application_id: snowflakes.Snowflake = attr.field() """ID of the application the relevant command belongs to.""" @@ -504,4 +524,4 @@ class GuildCommandPermissions: """ID of the guild these permissions are in.""" permissions: typing.Sequence[CommandPermission] = attr.field() - """Sequence of up to (and including) 10 of the command permissions set in this guild.""" + """Sequence of up to (and including) 100 of the command permissions set in this guild.""" diff --git a/hikari/embeds.py b/hikari/embeds.py index 24321c523b..50ddc7f547 100644 --- a/hikari/embeds.py +++ b/hikari/embeds.py @@ -52,20 +52,16 @@ import concurrent.futures import datetime - _T = typing.TypeVar("_T", bound="EmbedResource[files.AsyncReader]") - -AsyncReaderT = typing.TypeVar("AsyncReaderT", bound=files.AsyncReader) - @attr_extensions.with_copy @attr.define(kw_only=True, weakref_slot=False) -class EmbedResource(files.Resource[AsyncReaderT]): +class EmbedResource(files.Resource[files.AsyncReader]): """A base type for any resource provided in an embed. Resources can be downloaded and uploaded. """ - resource: files.Resource[AsyncReaderT] = attr.field(repr=True) + resource: files.Resource[files.AsyncReader] = attr.field(repr=True) """The resource this object wraps around.""" @property @@ -96,7 +92,7 @@ def stream( *, executor: typing.Optional[concurrent.futures.Executor] = None, head_only: bool = False, - ) -> files.AsyncReaderContextManager[AsyncReaderT]: + ) -> files.AsyncReaderContextManager[files.AsyncReader]: """Produce a stream of data for the resource. Parameters @@ -114,10 +110,10 @@ def stream( @attr.define(kw_only=True, weakref_slot=False) -class EmbedResourceWithProxy(EmbedResource[AsyncReaderT]): +class EmbedResourceWithProxy(EmbedResource): """Resource with a corresponding proxied element.""" - proxy_resource: typing.Optional[files.Resource[AsyncReaderT]] = attr.field(default=None, repr=False) + proxy_resource: typing.Optional[files.Resource[files.AsyncReader]] = attr.field(default=None, repr=False) """The proxied version of the resource, or `builtins.None` if not present. !!! note @@ -163,12 +159,12 @@ class EmbedFooter: text: typing.Optional[str] = attr.field(default=None, repr=True) """The footer text, or `builtins.None` if not present.""" - icon: typing.Optional[EmbedResourceWithProxy[files.AsyncReader]] = attr.field(default=None, repr=True) + icon: typing.Optional[EmbedResourceWithProxy] = attr.field(default=None, repr=True) """The URL of the footer icon, or `builtins.None` if not present.""" @attr.define(hash=False, kw_only=True, weakref_slot=False) -class EmbedImage(EmbedResourceWithProxy[AsyncReaderT]): +class EmbedImage(EmbedResourceWithProxy): """Represents an embed image.""" height: typing.Optional[int] = attr.field(default=None, repr=False) @@ -191,7 +187,7 @@ class EmbedImage(EmbedResourceWithProxy[AsyncReaderT]): @attr.define(hash=False, kw_only=True, weakref_slot=False) -class EmbedVideo(EmbedResourceWithProxy[AsyncReaderT]): +class EmbedVideo(EmbedResourceWithProxy): """Represents an embed video. !!! note @@ -246,7 +242,7 @@ class EmbedAuthor: This may be `builtins.None` if no hyperlink on the author's name is specified. """ - icon: typing.Optional[EmbedResourceWithProxy[files.AsyncReader]] = attr.field(default=None, repr=False) + icon: typing.Optional[EmbedResourceWithProxy] = attr.field(default=None, repr=False) """The author's icon, or `builtins.None` if not present.""" @@ -278,11 +274,11 @@ def is_inline(self, value: bool) -> None: self._inline = value -def _ensure_embed_resource(resource: files.Resourceish, cls: typing.Type[_T]) -> _T: +def _ensure_embed_resource(resource: files.Resourceish) -> files.Resource[files.AsyncReader]: if isinstance(resource, EmbedResource): - return cls(resource=resource.resource) + return resource.resource - return cls(resource=files.ensure_resource(resource)) + return files.ensure_resource(resource) class Embed: @@ -315,9 +311,9 @@ def from_received_embed( url: typing.Optional[str], color: typing.Optional[colors.Color], timestamp: typing.Optional[datetime.datetime], - image: typing.Optional[EmbedImage[files.AsyncReader]], - thumbnail: typing.Optional[EmbedImage[files.AsyncReader]], - video: typing.Optional[EmbedVideo[files.AsyncReader]], + image: typing.Optional[EmbedImage], + thumbnail: typing.Optional[EmbedImage], + video: typing.Optional[EmbedVideo], author: typing.Optional[EmbedAuthor], provider: typing.Optional[EmbedProvider], footer: typing.Optional[EmbedFooter], @@ -372,10 +368,10 @@ def __init__( self.description = description self.url = url self._author: typing.Optional[EmbedAuthor] = None - self._image: typing.Optional[EmbedImage[files.AsyncReader]] = None - self._video: typing.Optional[EmbedVideo[files.AsyncReader]] = None + self._image: typing.Optional[EmbedImage] = None + self._video: typing.Optional[EmbedVideo] = None self._provider: typing.Optional[EmbedProvider] = None - self._thumbnail: typing.Optional[EmbedImage[files.AsyncReader]] = None + self._thumbnail: typing.Optional[EmbedImage] = None self._footer: typing.Optional[EmbedFooter] = None # More boilerplate to allow this to be optional, but saves a useless list on every embed @@ -605,7 +601,7 @@ def footer(self) -> typing.Optional[EmbedFooter]: return self._footer @property - def image(self) -> typing.Optional[EmbedImage[files.AsyncReader]]: + def image(self) -> typing.Optional[EmbedImage]: """Return the image set in the embed. Will be `builtins.None` if not set. @@ -619,7 +615,7 @@ def image(self) -> typing.Optional[EmbedImage[files.AsyncReader]]: return self._image @property - def thumbnail(self) -> typing.Optional[EmbedImage[files.AsyncReader]]: + def thumbnail(self) -> typing.Optional[EmbedImage]: """Return the thumbnail set in the embed. Will be `builtins.None` if not set. @@ -633,7 +629,7 @@ def thumbnail(self) -> typing.Optional[EmbedImage[files.AsyncReader]]: return self._thumbnail @property - def video(self) -> typing.Optional[EmbedVideo[files.AsyncReader]]: + def video(self) -> typing.Optional[EmbedVideo]: """Return the video to show in the embed. Will be `builtins.None` if not set. @@ -742,7 +738,7 @@ def set_author( if name is None and url is None and icon is None: self._author = None else: - real_icon = _ensure_embed_resource(icon, EmbedResourceWithProxy) if icon is not None else None + real_icon = EmbedResourceWithProxy(resource=_ensure_embed_resource(icon)) if icon is not None else None self._author = EmbedAuthor(name=name, url=url, icon=real_icon) return self @@ -791,7 +787,7 @@ def set_footer(self, text: typing.Optional[str], *, icon: typing.Optional[files. self._footer = None else: - real_icon = _ensure_embed_resource(icon, EmbedResourceWithProxy) if icon is not None else None + real_icon = EmbedResourceWithProxy(resource=_ensure_embed_resource(icon)) if icon is not None else None self._footer = EmbedFooter(icon=real_icon, text=text) return self @@ -829,7 +825,7 @@ def set_image(self, image: typing.Optional[files.Resourceish] = None, /) -> Embe This embed. Allows for call chaining. """ if image is not None: - self._image = _ensure_embed_resource(image, EmbedImage) + self._image = EmbedImage(resource=_ensure_embed_resource(image)) else: self._image = None @@ -868,7 +864,7 @@ def set_thumbnail(self, image: typing.Optional[files.Resourceish] = None, /) -> This embed. Allows for call chaining. """ if image is not None: - self._thumbnail = _ensure_embed_resource(image, EmbedImage) + self._thumbnail = EmbedImage(resource=_ensure_embed_resource(image)) else: self._thumbnail = None diff --git a/hikari/errors.py b/hikari/errors.py index 4273d489d0..ce5c9f8bdf 100644 --- a/hikari/errors.py +++ b/hikari/errors.py @@ -231,8 +231,12 @@ class HTTPResponseError(HTTPError): url: str = attr.field() """The URL that produced this error message.""" - status: http.HTTPStatus = attr.field() - """The HTTP status code for the response.""" + status: typing.Union[http.HTTPStatus, int] = attr.field() + """The HTTP status code for the response. + + This will be `int` if it's outside the range of status codes in the HTTP + specification (e.g. one of Cloudflare's non-standard status codes). + """ headers: data_binding.Headers = attr.field() """The headers received in the error response.""" @@ -247,8 +251,12 @@ class HTTPResponseError(HTTPError): """The error code.""" def __str__(self) -> str: - name = self.status.name.replace("_", " ").title() - name_value = f"{name} {self.status.value}" + if isinstance(self.status, http.HTTPStatus): + name = self.status.name.replace("_", " ").title() + name_value = f"{name} {self.status.value}" + + else: + name_value = f"Unknown Status {self.status}" if self.code: code_str = f" ({self.code})" @@ -284,7 +292,7 @@ class BadRequestError(ClientHTTPResponseError): status: http.HTTPStatus = attr.field(default=http.HTTPStatus.BAD_REQUEST, init=False) """The HTTP status code for the response.""" - errors: typing.Optional[typing.Dict[str, data_binding.JSONObject]] = attr.field(default=None, kw_only=True) + errors: typing.Optional[typing.Mapping[str, data_binding.JSONObject]] = attr.field(default=None, kw_only=True) """Dict of top level field names to field specific error paths. For more information, this error format is loosely defined at diff --git a/hikari/events/__init__.py b/hikari/events/__init__.py index 5f4c5ce3cd..2e858ab8e5 100644 --- a/hikari/events/__init__.py +++ b/hikari/events/__init__.py @@ -24,6 +24,7 @@ from __future__ import annotations +from hikari.events.application_events import * from hikari.events.base_events import Event from hikari.events.base_events import ExceptionEvent from hikari.events.channel_events import * diff --git a/hikari/events/__init__.pyi b/hikari/events/__init__.pyi index d0009472e1..2bd26128da 100644 --- a/hikari/events/__init__.pyi +++ b/hikari/events/__init__.pyi @@ -1,6 +1,7 @@ # DO NOT MANUALLY EDIT THIS FILE! # This file was automatically generated by `nox -s generate-stubs` +from hikari.events.application_events import * from hikari.events.base_events import Event as Event from hikari.events.base_events import ExceptionEvent as ExceptionEvent from hikari.events.channel_events import * diff --git a/hikari/events/application_events.py b/hikari/events/application_events.py new file mode 100644 index 0000000000..0d9b454d5e --- /dev/null +++ b/hikari/events/application_events.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# cython: language_level=3 +# Copyright (c) 2020 Nekokatt +# Copyright (c) 2021-present davfsa +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +"""Events fired for application related changes.""" +from __future__ import annotations + +__all__: typing.Sequence[str] = ("ApplicationCommandPermissionsUpdateEvent",) + +import typing + +import attr + +from hikari.events import shard_events +from hikari.internal import attr_extensions + +if typing.TYPE_CHECKING: + from hikari import commands + from hikari import traits + from hikari.api import shard as gateway_shard + + +@attr_extensions.with_copy +@attr.define(kw_only=True, weakref_slot=False) +class ApplicationCommandPermissionsUpdateEvent(shard_events.ShardEvent): + """Event fired when permissions for an application command are updated.""" + + app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + # <>. + + shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + # <>. + + permissions: commands.GuildCommandPermissions = attr.field(repr=False) + """The updated application command permissions.""" diff --git a/hikari/events/typing_events.py b/hikari/events/typing_events.py index 40756df7a3..6d02278c2e 100644 --- a/hikari/events/typing_events.py +++ b/hikari/events/typing_events.py @@ -196,12 +196,12 @@ def user_id(self) -> snowflakes.Snowflake: # <>. return self.member.id - async def fetch_channel(self) -> typing.Union[channels.TextableGuildChannel]: + async def fetch_channel(self) -> channels.TextableGuildChannel: """Perform an API call to fetch an up-to-date image of this channel. Returns ------- - typing.Union[hikari.channels.TextableGuildChannel] + hikari.channels.TextableGuildChannel The channel. """ channel = await super().fetch_channel() diff --git a/hikari/impl/bot.py b/hikari/impl/bot.py index d77820cc2a..6107facd2e 100644 --- a/hikari/impl/bot.py +++ b/hikari/impl/bot.py @@ -207,6 +207,19 @@ class GatewayBot(traits.GatewayBotAware): Defaults to `hikari.intents.Intents.ALL_UNPRIVILEGED`. This allows you to change which intents your application will use on the gateway. This can be used to control and change the types of events you will receive. + auto_chunk_members : builtins.bool + Defaults to `builtins.True`. If `builtins.False`, then no member chunks + will be requested automatically, even if there are reasons to do so. + + All following statements must be true to automatically request chunks: + + 1. `auto_chunk_members` is `builtins.True`. + 2. The members intent is enabled. + 3. The server is marked as "large" or the presences intent is not enabled + (since Discord only sends other members when presences are declared, + we should also chunk small guilds if the presences are not declared). + 4. The members cache is enabled or there are listeners for the + `MemberChunkEvent`. logs : typing.Union[builtins.None, LoggerLevel, typing.Dict[str, typing.Any]] Defaults to `"INFO"`. @@ -320,6 +333,7 @@ def __init__( cache_settings: typing.Optional[config_impl.CacheSettings] = None, http_settings: typing.Optional[config_impl.HTTPSettings] = None, intents: intents_.Intents = intents_.Intents.ALL_UNPRIVILEGED, + auto_chunk_members: bool = True, logs: typing.Union[None, int, str, typing.Dict[str, typing.Any]] = "INFO", max_rate_limit: float = 300, max_retries: int = 3, @@ -352,7 +366,11 @@ def __init__( # Event handling self._event_manager = event_manager_impl.EventManagerImpl( - self._entity_factory, self._event_factory, self._intents, cache=self._cache + self._entity_factory, + self._event_factory, + self._intents, + auto_chunk_members=auto_chunk_members, + cache=self._cache, ) # Voice subsystem diff --git a/hikari/impl/buckets.py b/hikari/impl/buckets.py index bc5848d80a..1733c5ddb5 100644 --- a/hikari/impl/buckets.py +++ b/hikari/impl/buckets.py @@ -496,7 +496,7 @@ def do_gc_pass(self, expire_after: float) -> None: `RESTBucketManager.start` and `RESTBucketManager.close` to control this instead. """ - buckets_to_purge = [] + buckets_to_purge: typing.List[str] = [] now = time.monotonic() diff --git a/hikari/impl/cache.py b/hikari/impl/cache.py index 393a9919f5..6293788645 100644 --- a/hikari/impl/cache.py +++ b/hikari/impl/cache.py @@ -1479,8 +1479,8 @@ def _garbage_collect_message( if message.object.referenced_message: self._garbage_collect_message(message.object.referenced_message, decrement=1) - if message.object.mentions.users: - for user in message.object.mentions.users.values(): + if message.object.user_mentions: + for user in message.object.user_mentions.values(): self._garbage_collect_user(user, decrement=1) # If we got this far the message won't be in _message_entries as that'd infer that it hasn't been marked as @@ -1551,11 +1551,11 @@ def _set_message( author = self._set_user(message.author) member = self._set_member(message.member) if message.member else None - mention_users: undefined.UndefinedOr[ + user_mentions: undefined.UndefinedOr[ typing.Mapping[snowflakes.Snowflake, cache_utility.RefCell[users.User]] ] = undefined.UNDEFINED - if message.mentions.users is not undefined.UNDEFINED: - mention_users = {user_id: self._set_user(user) for user_id, user in message.mentions.users.items()} + if message.user_mentions is not undefined.UNDEFINED: + user_mentions = {user_id: self._set_user(user) for user_id, user in message.user_mentions.items()} interaction_user: typing.Optional[cache_utility.RefCell[users.User]] = None if message.interaction: @@ -1563,7 +1563,12 @@ def _set_message( referenced_message: typing.Optional[cache_utility.RefCell[cache_utility.MessageData]] = None if message.referenced_message: - referenced_message = self._set_message(message.referenced_message) + reference_id = message.referenced_message.id + referenced_message = self._message_entries.get(reference_id) or self._referenced_messages.get(reference_id) + + if referenced_message: + # Since the message is partial, if we don't have it cached, there is nothing we can do about it + referenced_message.object.update(message.referenced_message) # Only increment ref counts if this wasn't previously cached. if message.id not in self._referenced_messages and message.id not in self._message_entries: @@ -1573,8 +1578,8 @@ def _set_message( if referenced_message: self._increment_ref_count(referenced_message) - if mention_users is not undefined.UNDEFINED: - for user in mention_users.values(): + if user_mentions is not undefined.UNDEFINED: + for user in user_mentions.values(): self._increment_ref_count(user) if interaction_user: @@ -1584,7 +1589,7 @@ def _set_message( message, author=author, member=member, - mention_users=mention_users, + user_mentions=user_mentions, referenced_message=referenced_message, interaction_user=interaction_user, ) @@ -1625,19 +1630,19 @@ def update_message( self.set_message(message) elif cached_message_data := self._message_entries.get(message.id) or self._referenced_messages.get(message.id): - mention_user: undefined.UndefinedOr[ + user_mentions: undefined.UndefinedOr[ typing.Mapping[snowflakes.Snowflake, cache_utility.RefCell[users.User]] ] = undefined.UNDEFINED - if message.mentions.users is not undefined.UNDEFINED: - mention_user = {user_id: self._set_user(user) for user_id, user in message.mentions.users.items()} + if message.user_mentions is not undefined.UNDEFINED: + user_mentions = {user_id: self._set_user(user) for user_id, user in message.user_mentions.items()} # We want to ensure that any previously mentioned users are garbage collected if they're no longer # being mentioned. - if cached_message_data.object.mentions.users is not undefined.UNDEFINED: - for user_id, user in cached_message_data.object.mentions.users.items(): - if user_id not in mention_user: + if cached_message_data.object.user_mentions is not undefined.UNDEFINED: + for user_id, user in cached_message_data.object.user_mentions.items(): + if user_id not in user_mentions: self._garbage_collect_user(user, decrement=1) - cached_message_data.object.update(message, mention_users=mention_user) + cached_message_data.object.update(message, user_mentions=user_mentions) return cached_message, self.get_message(message.id) diff --git a/hikari/impl/config.py b/hikari/impl/config.py index 9fc035971a..581a878ca8 100644 --- a/hikari/impl/config.py +++ b/hikari/impl/config.py @@ -246,7 +246,7 @@ class HTTPTimeoutSettings: @request_socket_connect.validator @request_socket_read.validator @total.validator - def _(self, attrib: attr.Attribute[typing.Optional[float]], value: typing.Optional[float]) -> None: + def _(self, attrib: attr.Attribute[typing.Optional[float]], value: typing.Any) -> None: # This error won't occur until some time in the future where it will be annoying to # try and determine the root cause, so validate it NOW. if value is not None and (not isinstance(value, (float, int)) or value <= 0): @@ -317,7 +317,7 @@ class HTTPSettings(config.HTTPSettings): """ @max_redirects.validator - def _(self, _: attr.Attribute[typing.Optional[int]], value: typing.Optional[int]) -> None: + def _(self, _: attr.Attribute[typing.Optional[int]], value: typing.Any) -> None: # This error won't occur until some time in the future where it will be annoying to # try and determine the root cause, so validate it NOW. if value is not None and (not isinstance(value, int) or value <= 0): diff --git a/hikari/impl/entity_factory.py b/hikari/impl/entity_factory.py index 5d934649ff..9c052c3a5d 100644 --- a/hikari/impl/entity_factory.py +++ b/hikari/impl/entity_factory.py @@ -425,6 +425,7 @@ def __init__(self, app: traits.RESTAware) -> None: audit_log_models.AuditLogChangeKey.APPLICATION_ID: snowflakes.Snowflake, audit_log_models.AuditLogChangeKey.PERMISSIONS: _with_int_cast(permission_models.Permissions), audit_log_models.AuditLogChangeKey.COLOR: color_models.Color, + audit_log_models.AuditLogChangeKey.COMMAND_ID: snowflakes.Snowflake, audit_log_models.AuditLogChangeKey.ALLOW: _with_int_cast(permission_models.Permissions), audit_log_models.AuditLogChangeKey.DENY: _with_int_cast(permission_models.Permissions), audit_log_models.AuditLogChangeKey.CHANNEL_ID: snowflakes.Snowflake, @@ -639,7 +640,7 @@ def deserialize_implicit_token(self, query: data_binding.Query) -> application_m def _deserialize_audit_log_change_roles( self, payload: data_binding.JSONArray ) -> typing.Mapping[snowflakes.Snowflake, guild_models.PartialRole]: - roles = {} + roles: typing.Dict[snowflakes.Snowflake, guild_models.PartialRole] = {} for role_payload in payload: role = guild_models.PartialRole( app=self._app, id=snowflakes.Snowflake(role_payload["id"]), name=role_payload["name"] @@ -713,7 +714,7 @@ def deserialize_audit_log(self, payload: data_binding.JSONObject) -> audit_log_m for entry_payload in payload["audit_log_entries"]: entry_id = snowflakes.Snowflake(entry_payload["id"]) - changes = [] + changes: typing.List[audit_log_models.AuditLogChange] = [] if (change_payloads := entry_payload.get("changes")) is not None: for change_payload in change_payloads: key: typing.Union[audit_log_models.AuditLogChangeKey, str] = audit_log_models.AuditLogChangeKey( @@ -726,7 +727,9 @@ def deserialize_audit_log(self, payload: data_binding.JSONObject) -> audit_log_m new_value = value_converter(new_value) if new_value is not None else None old_value = value_converter(old_value) if old_value is not None else None - elif not isinstance(key, audit_log_models.AuditLogChangeKey): + elif not isinstance( + key, audit_log_models.AuditLogChangeKey + ): # pyright: ignore [reportUnnecessaryIsInstance] _LOGGER.debug("Unknown audit log change key found %r", key) changes.append(audit_log_models.AuditLogChange(key=key, new_value=new_value, old_value=old_value)) @@ -981,6 +984,11 @@ def deserialize_guild_voice_channel( channel_fields = self._set_guild_channel_attributes(payload, guild_id=guild_id) # Discord seems to be only returning this after it's been initially PATCHed in for older channels. video_quality_mode = payload.get("video_quality_mode", channel_models.VideoQualityMode.AUTO) + + last_message_id: typing.Optional[snowflakes.Snowflake] = None + if (raw_last_message_id := payload.get("last_message_id")) is not None: + last_message_id = snowflakes.Snowflake(raw_last_message_id) + return channel_models.GuildVoiceChannel( app=self._app, id=channel_fields.id, @@ -997,6 +1005,7 @@ def deserialize_guild_voice_channel( bitrate=int(payload["bitrate"]), user_limit=int(payload["user_limit"]), video_quality_mode=channel_models.VideoQualityMode(int(video_quality_mode)), + last_message_id=last_message_id, ) def deserialize_guild_stage_channel( @@ -1050,7 +1059,7 @@ def deserialize_embed(self, payload: data_binding.JSONObject) -> embed_models.Em timestamp = time.iso8601_datetime_string_to_datetime(payload["timestamp"]) if "timestamp" in payload else None fields: typing.Optional[typing.List[embed_models.EmbedField]] = None - image: typing.Optional[embed_models.EmbedImage[files.AsyncReader]] = None + image: typing.Optional[embed_models.EmbedImage] = None if (image_payload := payload.get("image")) and "url" in image_payload: proxy = files.ensure_resource(image_payload["proxy_url"]) if "proxy_url" in image_payload else None image = embed_models.EmbedImage( @@ -1060,7 +1069,7 @@ def deserialize_embed(self, payload: data_binding.JSONObject) -> embed_models.Em width=image_payload.get("width"), ) - thumbnail: typing.Optional[embed_models.EmbedImage[files.AsyncReader]] = None + thumbnail: typing.Optional[embed_models.EmbedImage] = None if (thumbnail_payload := payload.get("thumbnail")) and "url" in thumbnail_payload: proxy = files.ensure_resource(thumbnail_payload["proxy_url"]) if "proxy_url" in thumbnail_payload else None thumbnail = embed_models.EmbedImage( @@ -1070,7 +1079,7 @@ def deserialize_embed(self, payload: data_binding.JSONObject) -> embed_models.Em width=thumbnail_payload.get("width"), ) - video: typing.Optional[embed_models.EmbedVideo[files.AsyncReader]] = None + video: typing.Optional[embed_models.EmbedVideo] = None if (video_payload := payload.get("video")) and "url" in video_payload: raw_proxy_url = video_payload.get("proxy_url") video = embed_models.EmbedVideo( @@ -1084,7 +1093,7 @@ def deserialize_embed(self, payload: data_binding.JSONObject) -> embed_models.Em if provider_payload := payload.get("provider"): provider = embed_models.EmbedProvider(name=provider_payload.get("name"), url=provider_payload.get("url")) - icon: typing.Optional[embed_models.EmbedResourceWithProxy[files.AsyncReader]] + icon: typing.Optional[embed_models.EmbedResourceWithProxy] author: typing.Optional[embed_models.EmbedAuthor] = None if author_payload := payload.get("author"): icon = None @@ -1139,11 +1148,10 @@ def deserialize_embed(self, payload: data_binding.JSONObject) -> embed_models.Em ) def serialize_embed( # noqa: C901 - Function too complex - self, - embed: embed_models.Embed, + self, embed: embed_models.Embed ) -> typing.Tuple[data_binding.JSONObject, typing.List[files.Resource[files.AsyncReader]]]: - payload: data_binding.JSONObject = {} + payload: typing.Dict[str, typing.Any] = {} uploads: typing.List[files.Resource[files.AsyncReader]] = [] if embed.title is not None: @@ -1162,7 +1170,7 @@ def serialize_embed( # noqa: C901 - Function too complex payload["color"] = int(embed.color) if embed.footer is not None: - footer_payload: data_binding.JSONObject = {} + footer_payload: typing.MutableMapping[str, typing.Any] = {} if embed.footer.text is not None: footer_payload["text"] = embed.footer.text @@ -1176,7 +1184,7 @@ def serialize_embed( # noqa: C901 - Function too complex payload["footer"] = footer_payload if embed.image is not None: - image_payload: data_binding.JSONObject = {} + image_payload: typing.MutableMapping[str, typing.Any] = {} if not isinstance(embed.image.resource, files.WebResource): uploads.append(embed.image.resource) @@ -1185,7 +1193,7 @@ def serialize_embed( # noqa: C901 - Function too complex payload["image"] = image_payload if embed.thumbnail is not None: - thumbnail_payload: data_binding.JSONObject = {} + thumbnail_payload: typing.MutableMapping[str, typing.Any] = {} if not isinstance(embed.thumbnail.resource, files.WebResource): uploads.append(embed.thumbnail.resource) @@ -1194,7 +1202,7 @@ def serialize_embed( # noqa: C901 - Function too complex payload["thumbnail"] = thumbnail_payload if embed.author is not None: - author_payload: data_binding.JSONObject = {} + author_payload: typing.MutableMapping[str, typing.Any] = {} if embed.author.name is not None: author_payload["name"] = embed.author.name @@ -1210,7 +1218,7 @@ def serialize_embed( # noqa: C901 - Function too complex payload["author"] = author_payload if embed.fields: - field_payloads: data_binding.JSONArray = [] + field_payloads: typing.List[data_binding.JSONObject] = [] for i, field in enumerate(embed.fields): # Yep, these are technically two unreachable branches. However, this is an incredibly @@ -1340,7 +1348,7 @@ def deserialize_welcome_screen(self, payload: data_binding.JSONObject) -> guild_ return guild_models.WelcomeScreen(description=payload["description"], channels=channels) def serialize_welcome_channel(self, welcome_channel: guild_models.WelcomeChannel) -> data_binding.JSONObject: - payload: data_binding.JSONObject = { + payload: typing.Dict[str, typing.Any] = { "channel_id": str(welcome_channel.channel_id), "description": welcome_channel.description, } @@ -1378,10 +1386,9 @@ def deserialize_member( time.iso8601_datetime_string_to_datetime(raw_premium_since) if raw_premium_since is not None else None ) + communication_disabled_until: typing.Optional[datetime.datetime] = None if raw_communication_disabled_until := payload.get("communication_disabled_until"): communication_disabled_until = time.iso8601_datetime_string_to_datetime(raw_communication_disabled_until) - else: - communication_disabled_until = None return guild_models.Member( user=user, @@ -1777,6 +1784,14 @@ def deserialize_slash_command( if raw_options := payload.get("options"): options = [self._deserialize_command_option(option) for option in raw_options] + # Discord considers 0 the same thing as ADMINISTRATORS, but we make it nicer to work with + # by setting it correctly. + default_member_permissions = payload["default_member_permissions"] + if default_member_permissions == 0: + default_member_permissions = permission_models.Permissions.ADMINISTRATOR + else: + default_member_permissions = permission_models.Permissions(default_member_permissions or 0) + return commands.SlashCommand( app=self._app, id=snowflakes.Snowflake(payload["id"]), @@ -1785,7 +1800,8 @@ def deserialize_slash_command( name=payload["name"], description=payload["description"], options=options, - default_permission=payload.get("default_permission", True), + default_member_permissions=default_member_permissions, + is_dm_enabled=payload.get("dm_permission", False), guild_id=guild_id, version=snowflakes.Snowflake(payload["version"]), ) @@ -1800,13 +1816,22 @@ def deserialize_context_menu_command( raw_guild_id = payload["guild_id"] guild_id = snowflakes.Snowflake(raw_guild_id) if raw_guild_id is not None else None + # Discord considers 0 the same thing as ADMINISTRATORS, but we make it nicer to work with + # by setting it correctly. + default_member_permissions = payload["default_member_permissions"] + if default_member_permissions == 0: + default_member_permissions = permission_models.Permissions.ADMINISTRATOR + else: + default_member_permissions = permission_models.Permissions(default_member_permissions or 0) + return commands.ContextMenuCommand( app=self._app, id=snowflakes.Snowflake(payload["id"]), type=commands.CommandType(payload["type"]), application_id=snowflakes.Snowflake(payload["application_id"]), name=payload["name"], - default_permission=payload.get("default_permission", True), + default_member_permissions=default_member_permissions, + is_dm_enabled=payload.get("dm_permission", False), guild_id=guild_id, version=snowflakes.Snowflake(payload["version"]), ) @@ -1837,6 +1862,7 @@ def deserialize_guild_command_permissions( for perm in payload["permissions"] ] return commands.GuildCommandPermissions( + id=snowflakes.Snowflake(payload["id"]), application_id=snowflakes.Snowflake(payload["application_id"]), command_id=snowflakes.Snowflake(payload["id"]), guild_id=snowflakes.Snowflake(payload["guild_id"]), @@ -2032,6 +2058,7 @@ def deserialize_command_interaction( if raw_target_id := data_payload.get("target_id"): target_id = snowflakes.Snowflake(raw_target_id) + app_perms = payload.get("app_permissions") return command_interactions.CommandInteraction( app=self._app, application_id=snowflakes.Snowflake(payload["application_id"]), @@ -2051,6 +2078,7 @@ def deserialize_command_interaction( options=options, resolved=resolved, target_id=target_id, + app_permissions=permission_models.Permissions(app_perms) if app_perms is not None else None, ) def deserialize_autocomplete_interaction( @@ -2075,10 +2103,6 @@ def deserialize_autocomplete_interaction( member = None user = self.deserialize_user(payload["user"]) - resolved: typing.Optional[command_interactions.ResolvedOptionData] = None - if resolved_payload := data_payload.get("resolved"): - resolved = self._deserialize_resolved_option_data(resolved_payload, guild_id=guild_id) - return command_interactions.AutocompleteInteraction( app=self._app, application_id=snowflakes.Snowflake(payload["application_id"]), @@ -2094,7 +2118,6 @@ def deserialize_autocomplete_interaction( command_name=data_payload["name"], command_type=commands.CommandType(data_payload.get("type", commands.CommandType.SLASH)), options=options, - resolved=resolved, locale=locales.Locale(payload["locale"]), guild_locale=locales.Locale(payload["guild_locale"]) if "guild_locale" in payload else None, ) @@ -2117,6 +2140,8 @@ def deserialize_modal_interaction(self, payload: data_binding.JSONObject) -> mod member = None user = self.deserialize_user(payload["user"]) + app_perms = payload.get("app_permissions") + components: typing.List[typing.Any] = [] for component_payload in data_payload["components"]: try: @@ -2134,6 +2159,7 @@ def deserialize_modal_interaction(self, payload: data_binding.JSONObject) -> mod id=snowflakes.Snowflake(payload["id"]), type=base_interactions.InteractionType(payload["type"]), guild_id=guild_id, + app_permissions=permission_models.Permissions(app_perms) if app_perms is not None else None, guild_locale=locales.Locale(payload["guild_locale"]) if "guild_locale" in payload else None, locale=locales.Locale(payload["locale"]), channel_id=snowflakes.Snowflake(payload["channel_id"]), @@ -2156,7 +2182,7 @@ def deserialize_interaction(self, payload: data_binding.JSONObject) -> base_inte raise errors.UnrecognisedEntityError(f"Unrecognised interaction type {interaction_type}") def serialize_command_option(self, option: commands.CommandOption) -> data_binding.JSONObject: - payload: data_binding.JSONObject = { + payload: typing.MutableMapping[str, typing.Any] = { "type": option.type, "name": option.name, "description": option.description, @@ -2202,6 +2228,7 @@ def deserialize_component_interaction( member = None user = self.deserialize_user(payload["user"]) + app_perms = payload.get("app_permissions") return component_interactions.ComponentInteraction( app=self._app, application_id=snowflakes.Snowflake(payload["application_id"]), @@ -2219,6 +2246,7 @@ def deserialize_component_interaction( message=self.deserialize_message(payload["message"]), locale=locales.Locale(payload["locale"]), guild_locale=locales.Locale(payload["guild_locale"]) if "guild_locale" in payload else None, + app_permissions=permission_models.Permissions(app_perms) if app_perms is not None else None, ) ################## @@ -2489,6 +2517,20 @@ def deserialize_partial_message( # noqa CFQ001 - Function too long except errors.UnrecognisedEntityError: pass + channel_mentions: undefined.UndefinedOr[ + typing.Dict[snowflakes.Snowflake, channel_models.PartialChannel] + ] = undefined.UNDEFINED + if raw_channel_mentions := payload.get("mention_channels"): + channel_mentions = {c.id: c for c in map(self.deserialize_partial_channel, raw_channel_mentions)} + + user_mentions: undefined.UndefinedOr[typing.Dict[snowflakes.Snowflake, user_models.User]] = undefined.UNDEFINED + if raw_user_mentions := payload.get("mentions"): + user_mentions = {u.id: u for u in map(self.deserialize_user, raw_user_mentions)} + + role_mention_ids: undefined.UndefinedOr[typing.List[snowflakes.Snowflake]] = undefined.UNDEFINED + if raw_role_mention_ids := payload.get("mention_roles"): + role_mention_ids = [snowflakes.Snowflake(i) for i in raw_role_mention_ids] + message = message_models.PartialMessage( app=self._app, id=snowflakes.Snowflake(payload["id"]), @@ -2516,33 +2558,15 @@ def deserialize_partial_message( # noqa CFQ001 - Function too long application_id=application_id, interaction=interaction, components=components, + channel_mentions=channel_mentions, + user_mentions=user_mentions, + role_mention_ids=role_mention_ids, + mentions_everyone=payload.get("mention_everyone", undefined.UNDEFINED), # We initialize these next. mentions=NotImplemented, ) - channels: undefined.UndefinedOr[typing.Dict[snowflakes.Snowflake, channel_models.PartialChannel]] - channels = undefined.UNDEFINED - if raw_channels := payload.get("mention_channels"): - channels = {c.id: c for c in map(self.deserialize_partial_channel, raw_channels)} - - users: undefined.UndefinedOr[typing.Dict[snowflakes.Snowflake, user_models.User]] - users = undefined.UNDEFINED - if raw_users := payload.get("mentions"): - users = {u.id: u for u in map(self.deserialize_user, raw_users)} - - role_ids: undefined.UndefinedOr[typing.List[snowflakes.Snowflake]] = undefined.UNDEFINED - if raw_role_ids := payload.get("mention_roles"): - role_ids = [snowflakes.Snowflake(i) for i in raw_role_ids] - - everyone = payload.get("mention_everyone", undefined.UNDEFINED) - - message.mentions = message_models.Mentions( - message=message, - users=users, - role_ids=role_ids, - channels=channels, - everyone=everyone, - ) + message.mentions = message_models.Mentions(message=message) return message @@ -2580,9 +2604,9 @@ def deserialize_message( if "message_reference" in payload: message_reference = self._deserialize_message_reference(payload["message_reference"]) - referenced_message: typing.Optional[message_models.Message] = None + referenced_message: typing.Optional[message_models.PartialMessage] = None if referenced_message_payload := payload.get("referenced_message"): - referenced_message = self.deserialize_message(referenced_message_payload) + referenced_message = self.deserialize_partial_message(referenced_message_payload) application: typing.Optional[message_models.MessageApplication] = None if "application" in payload: @@ -2608,6 +2632,10 @@ def deserialize_message( except errors.UnrecognisedEntityError: pass + user_mentions = {u.id: u for u in map(self.deserialize_user, payload.get("mentions", ()))} + role_mention_ids = [snowflakes.Snowflake(i) for i in payload.get("mention_roles", ())] + channel_mentions = {u.id: u for u in map(self.deserialize_partial_channel, payload.get("mention_channels", ()))} + message = message_models.Message( app=self._app, id=snowflakes.Snowflake(payload["id"]), @@ -2635,37 +2663,15 @@ def deserialize_message( application_id=snowflakes.Snowflake(payload["application_id"]) if "application_id" in payload else None, interaction=interaction, components=components, + user_mentions=user_mentions, + channel_mentions=channel_mentions, + role_mention_ids=role_mention_ids, + mentions_everyone=payload.get("mention_everyone", False), # We initialize these next. mentions=NotImplemented, ) - if raw_channels := payload.get("mention_channels"): - channels = {c.id: c for c in map(self.deserialize_partial_channel, raw_channels)} - - else: - channels = {} - - if raw_users := payload.get("mentions"): - users = {u.id: u for u in map(self.deserialize_user, raw_users)} - - else: - users = {} - - if raw_role_ids := payload.get("mention_roles"): - role_ids = [snowflakes.Snowflake(i) for i in raw_role_ids] - - else: - role_ids = [] - - everyone = payload.get("mention_everyone", False) - - message.mentions = message_models.Mentions( - message=message, - users=users, - role_ids=role_ids, - channels=channels, - everyone=everyone, - ) + message.mentions = message_models.Mentions(message=message) return message @@ -2679,7 +2685,7 @@ def deserialize_member_presence( # noqa: CFQ001 - Max function length *, guild_id: undefined.UndefinedOr[snowflakes.Snowflake] = undefined.UNDEFINED, ) -> presence_models.MemberPresence: - activities = [] + activities: typing.List[presence_models.RichActivity] = [] for activity_payload in payload["activities"]: timestamps: typing.Optional[presence_models.ActivityTimestamps] = None if "timestamps" in activity_payload: diff --git a/hikari/impl/event_factory.py b/hikari/impl/event_factory.py index 49715d3b76..5489f0ffd5 100644 --- a/hikari/impl/event_factory.py +++ b/hikari/impl/event_factory.py @@ -38,6 +38,7 @@ from hikari import undefined from hikari import users as user_models from hikari.api import event_factory +from hikari.events import application_events from hikari.events import channel_events from hikari.events import guild_events from hikari.events import interaction_events @@ -73,6 +74,18 @@ class EventFactoryImpl(event_factory.EventFactory): def __init__(self, app: traits.RESTAware) -> None: self._app = app + ###################### + # APPLICATION EVENTS # + ###################### + + def deserialize_application_command_permission_update_event( + self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject + ) -> application_events.ApplicationCommandPermissionsUpdateEvent: + permissions = self._app.entity_factory.deserialize_guild_command_permissions(payload) + return application_events.ApplicationCommandPermissionsUpdateEvent( + app=self._app, shard=shard, permissions=permissions + ) + ################## # CHANNEL EVENTS # ################## diff --git a/hikari/impl/event_manager.py b/hikari/impl/event_manager.py index a9feacbf79..28467fab01 100644 --- a/hikari/impl/event_manager.py +++ b/hikari/impl/event_manager.py @@ -37,6 +37,7 @@ from hikari import presences as presences_ from hikari import snowflakes from hikari.api import config +from hikari.events import application_events from hikari.events import channel_events from hikari.events import guild_events from hikari.events import interaction_events @@ -92,7 +93,7 @@ async def _request_guild_members( class EventManagerImpl(event_manager_base.EventManagerBase): """Provides event handling logic for Discord events.""" - __slots__: typing.Sequence[str] = ("_cache", "_entity_factory") + __slots__: typing.Sequence[str] = ("_cache", "_entity_factory", "_auto_chunk_members") def __init__( self, @@ -101,9 +102,11 @@ def __init__( intents: intents_.Intents, /, *, + auto_chunk_members: bool = True, cache: typing.Optional[cache_.MutableCache] = None, ) -> None: self._cache = cache + self._auto_chunk_members = auto_chunk_members self._entity_factory = entity_factory components = cache.settings.components if cache else config.CacheComponents.NONE super().__init__(event_factory=event_factory, intents=intents, cache_components=components) @@ -127,6 +130,12 @@ async def on_resumed(self, shard: gateway_shard.GatewayShard, _: data_binding.JS """See https://discord.com/developers/docs/topics/gateway#resumed for more info.""" await self.dispatch(self._event_factory.deserialize_resumed_event(shard)) + @event_manager_base.filtered(application_events.ApplicationCommandPermissionsUpdateEvent) + async def on_application_command_permissions_update( + self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject + ) -> None: + await self.dispatch(self._event_factory.deserialize_application_command_permission_update_event(shard, payload)) + @event_manager_base.filtered(channel_events.GuildChannelCreateEvent, config.CacheComponents.GUILD_CHANNELS) async def on_channel_create(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#channel-create for more info.""" @@ -257,7 +266,8 @@ async def on_guild_create( # noqa: C901 - Function too complex # payload if presence intents are also declared, so if this isn't the case then we also want # to chunk small guilds. if ( - self._intents & intents_.Intents.GUILD_MEMBERS + self._auto_chunk_members + and self._intents & intents_.Intents.GUILD_MEMBERS and (payload.get("large") or not presences_declared) and ( self._cache_enabled_for(config.CacheComponents.MEMBERS) @@ -396,16 +406,16 @@ async def on_integration_create(self, shard: gateway_shard.GatewayShard, payload event = self._event_factory.deserialize_integration_create_event(shard, payload) await self.dispatch(event) - @event_manager_base.filtered(guild_events.IntegrationDeleteEvent) - async def on_integration_delete(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: - event = self._event_factory.deserialize_integration_delete_event(shard, payload) - await self.dispatch(event) - @event_manager_base.filtered(guild_events.IntegrationUpdateEvent) async def on_integration_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: event = self._event_factory.deserialize_integration_update_event(shard, payload) await self.dispatch(event) + @event_manager_base.filtered(guild_events.IntegrationDeleteEvent) + async def on_integration_delete(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: + event = self._event_factory.deserialize_integration_delete_event(shard, payload) + await self.dispatch(event) + @event_manager_base.filtered(member_events.MemberCreateEvent, config.CacheComponents.MEMBERS) async def on_guild_member_add(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None: """See https://discord.com/developers/docs/topics/gateway#guild-member-add for more info.""" diff --git a/hikari/impl/event_manager_base.py b/hikari/impl/event_manager_base.py index c4a98354c9..e855d314c4 100644 --- a/hikari/impl/event_manager_base.py +++ b/hikari/impl/event_manager_base.py @@ -63,7 +63,7 @@ typing.List[event_manager_.CallbackT[base_events.EventT]], ] _WaiterT = typing.Tuple[ - typing.Optional[event_manager_.PredicateT[base_events.EventT]], asyncio.Future[base_events.EventT] + typing.Optional[event_manager_.PredicateT[base_events.EventT]], "asyncio.Future[base_events.EventT]" ] _WaiterMapT = typing.Dict[typing.Type[base_events.EventT], typing.Set[_WaiterT[base_events.EventT]]] @@ -103,8 +103,8 @@ def __events_bitmask__(self) -> int: def _generate_weak_listener( reference: weakref.WeakMethod[typing.Any], -) -> typing.Callable[[base_events.EventT], typing.Coroutine[typing.Any, typing.Any, None]]: - async def call_weak_method(event: base_events.EventT) -> None: +) -> typing.Callable[[base_events.Event], typing.Coroutine[typing.Any, typing.Any, None]]: + async def call_weak_method(event: base_events.Event) -> None: method = reference() if method is None: raise TypeError( @@ -555,10 +555,7 @@ def decorator( return decorator - def dispatch(self, event: base_events.EventT) -> asyncio.Future[typing.Any]: - if not isinstance(event, base_events.Event): - raise TypeError(f"Events must be subclasses of {base_events.Event.__name__}, not {type(event).__name__}") - + def dispatch(self, event: base_events.Event) -> asyncio.Future[typing.Any]: tasks: typing.List[typing.Coroutine[None, typing.Any, None]] = [] for cls in event.dispatches(): @@ -613,6 +610,7 @@ async def wait_for( future: asyncio.Future[base_events.EventT] = asyncio.get_running_loop().create_future() + waiter_set: typing.MutableSet[_WaiterT[base_events.Event]] try: waiter_set = self._waiters[event_type] except KeyError: diff --git a/hikari/impl/rest.py b/hikari/impl/rest.py index 34792a0973..2c7bab0551 100644 --- a/hikari/impl/rest.py +++ b/hikari/impl/rest.py @@ -260,9 +260,6 @@ def proxy_settings(self) -> config_impl.ProxySettings: return self._rest().proxy_settings -_NONE_OR_UNDEFINED: typing.Final[typing.Tuple[None, undefined.UndefinedType]] = (None, undefined.UNDEFINED) - - class RESTApp(traits.ExecutorAware): """The base for a HTTP-only Discord application. @@ -762,7 +759,7 @@ async def _request( ) if trace_logging_enabled: - time_taken = (time.monotonic() - start) * 1_000 + time_taken = (time.monotonic() - start) * 1_000 # pyright: ignore[reportUnboundVariable] _LOGGER.log( ux.TRACE, "%s %s %s in %sms\n%s", @@ -1246,24 +1243,6 @@ def _build_message_payload( # noqa: C901- Function too complex if not undefined.any_undefined(embed, embeds): raise ValueError("You may only specify one of 'embed' or 'embeds', not both") - if attachments is not undefined.UNDEFINED and not isinstance(attachments, typing.Collection): - raise TypeError( - "You passed a non-collection to 'attachments', but this expects a collection. Maybe you meant to " - "use 'attachment' (singular) instead?" - ) - - if components not in _NONE_OR_UNDEFINED and not isinstance(components, typing.Collection): - raise TypeError( - "You passed a non-collection to 'components', but this expects a collection. Maybe you meant to " - "use 'component' (singular) instead?" - ) - - if embeds not in _NONE_OR_UNDEFINED and not isinstance(embeds, typing.Collection): - raise TypeError( - "You passed a non-collection to 'embeds', but this expects a collection. Maybe you meant to " - "use 'embed' (singular) instead?" - ) - if undefined.all_undefined(embed, embeds) and isinstance(content, embeds_.Embed): # Syntactic sugar, common mistake to accidentally send an embed # as the content, so let's detect this and fix it for the user. @@ -2786,13 +2765,24 @@ async def fetch_ban( assert isinstance(response, dict) return self._entity_factory.deserialize_guild_member_ban(response) - async def fetch_bans( - self, guild: snowflakes.SnowflakeishOr[guilds.PartialGuild] - ) -> typing.Sequence[guilds.GuildBan]: - route = routes.GET_GUILD_BANS.compile(guild=guild) - response = await self._request(route) - assert isinstance(response, list) - return [self._entity_factory.deserialize_guild_member_ban(ban_payload) for ban_payload in response] + def fetch_bans( + self, + guild: snowflakes.SnowflakeishOr[guilds.PartialGuild], + /, + *, + newest_first: bool = False, + start_at: undefined.UndefinedOr[snowflakes.SearchableSnowflakeishOr[users.PartialUser]] = undefined.UNDEFINED, + ) -> iterators.LazyIterator[guilds.GuildBan]: + if start_at is undefined.UNDEFINED: + start_at = snowflakes.Snowflake.max() if newest_first else snowflakes.Snowflake.min() + elif isinstance(start_at, datetime.datetime): + start_at = snowflakes.Snowflake.from_datetime(start_at) + else: + start_at = int(start_at) + + return special_endpoints_impl.GuildBanIterator( + self._entity_factory, self._request, guild, newest_first, str(start_at) + ) async def fetch_roles( self, @@ -3184,7 +3174,10 @@ async def _create_application_command( *, guild: undefined.UndefinedOr[snowflakes.SnowflakeishOr[guilds.PartialGuild]] = undefined.UNDEFINED, options: undefined.UndefinedOr[typing.Sequence[commands.CommandOption]] = undefined.UNDEFINED, - default_permission: undefined.UndefinedOr[bool] = undefined.UNDEFINED, + default_member_permissions: typing.Union[ + undefined.UndefinedType, int, permissions_.Permissions + ] = undefined.UNDEFINED, + dm_enabled: undefined.UndefinedOr[bool] = undefined.UNDEFINED, ) -> data_binding.JSONObject: if guild is undefined.UNDEFINED: route = routes.POST_APPLICATION_COMMAND.compile(application=application) @@ -3197,32 +3190,15 @@ async def _create_application_command( body.put("description", description) body.put("type", type) body.put_array("options", options, conversion=self._entity_factory.serialize_command_option) - body.put("default_permission", default_permission) + # Discord has some funky behaviour around what 0 means. They consider it to be the same as ADMINISTRATOR, + # but we consider it to be the same as None for developer sanity reasons + body.put("default_member_permissions", None if default_member_permissions == 0 else default_member_permissions) + body.put("dm_permission", dm_enabled) response = await self._request(route, json=body) assert isinstance(response, dict) return response - @deprecation.deprecated("2.0.0.dev106", "create_slash_command") - async def create_application_command( - self, - application: snowflakes.SnowflakeishOr[guilds.PartialApplication], - name: str, - description: str, - guild: undefined.UndefinedOr[snowflakes.SnowflakeishOr[guilds.PartialGuild]] = undefined.UNDEFINED, - *, - options: undefined.UndefinedOr[typing.Sequence[commands.CommandOption]] = undefined.UNDEFINED, - default_permission: undefined.UndefinedOr[bool] = undefined.UNDEFINED, - ) -> commands.SlashCommand: - return await self.create_slash_command( - application=application, - name=name, - description=description, - guild=guild, - options=options, - default_permission=default_permission, - ) - async def create_slash_command( self, application: snowflakes.SnowflakeishOr[guilds.PartialApplication], @@ -3231,7 +3207,10 @@ async def create_slash_command( *, guild: undefined.UndefinedOr[snowflakes.SnowflakeishOr[guilds.PartialGuild]] = undefined.UNDEFINED, options: undefined.UndefinedOr[typing.Sequence[commands.CommandOption]] = undefined.UNDEFINED, - default_permission: undefined.UndefinedOr[bool] = undefined.UNDEFINED, + default_member_permissions: typing.Union[ + undefined.UndefinedType, int, permissions_.Permissions + ] = undefined.UNDEFINED, + dm_enabled: undefined.UndefinedOr[bool] = undefined.UNDEFINED, ) -> commands.SlashCommand: response = await self._create_application_command( application=application, @@ -3240,7 +3219,8 @@ async def create_slash_command( description=description, guild=guild, options=options, - default_permission=default_permission, + default_member_permissions=default_member_permissions, + dm_enabled=dm_enabled, ) return self._entity_factory.deserialize_slash_command( response, guild_id=snowflakes.Snowflake(guild) if guild is not undefined.UNDEFINED else None @@ -3253,14 +3233,18 @@ async def create_context_menu_command( name: str, *, guild: undefined.UndefinedOr[snowflakes.SnowflakeishOr[guilds.PartialGuild]] = undefined.UNDEFINED, - default_permission: undefined.UndefinedOr[bool] = undefined.UNDEFINED, + default_member_permissions: typing.Union[ + undefined.UndefinedType, int, permissions_.Permissions + ] = undefined.UNDEFINED, + dm_enabled: undefined.UndefinedOr[bool] = undefined.UNDEFINED, ) -> commands.ContextMenuCommand: response = await self._create_application_command( application=application, type=type, name=name, guild=guild, - default_permission=default_permission, + default_member_permissions=default_member_permissions, + dm_enabled=dm_enabled, ) return self._entity_factory.deserialize_context_menu_command( response, guild_id=snowflakes.Snowflake(guild) if guild is not undefined.UNDEFINED else None @@ -3292,6 +3276,10 @@ async def edit_application_command( name: undefined.UndefinedOr[str] = undefined.UNDEFINED, description: undefined.UndefinedOr[str] = undefined.UNDEFINED, options: undefined.UndefinedOr[typing.Sequence[commands.CommandOption]] = undefined.UNDEFINED, + default_member_permissions: typing.Union[ + undefined.UndefinedType, int, permissions_.Permissions + ] = undefined.UNDEFINED, + dm_enabled: undefined.UndefinedOr[bool] = undefined.UNDEFINED, ) -> commands.PartialCommand: if guild is undefined.UNDEFINED: route = routes.PATCH_APPLICATION_COMMAND.compile(application=application, command=command) @@ -3305,6 +3293,10 @@ async def edit_application_command( body.put("name", name) body.put("description", description) body.put_array("options", options, conversion=self._entity_factory.serialize_command_option) + # Discord has some funky behaviour around what 0 means. They consider it to be the same as ADMINISTRATOR, + # but we consider it to be the same as None for developer sanity reasons + body.put("default_member_permissions", None if default_member_permissions == 0 else default_member_permissions) + body.put("dm_permission", dm_enabled) response = await self._request(route, json=body) assert isinstance(response, dict) @@ -3351,27 +3343,6 @@ async def fetch_application_command_permissions( assert isinstance(response, dict) return self._entity_factory.deserialize_guild_command_permissions(response) - async def set_application_guild_commands_permissions( - self, - application: snowflakes.SnowflakeishOr[guilds.PartialApplication], - guild: snowflakes.SnowflakeishOr[guilds.PartialGuild], - permissions: typing.Mapping[ - snowflakes.SnowflakeishOr[commands.PartialCommand], typing.Sequence[commands.CommandPermission] - ], - ) -> typing.Sequence[commands.GuildCommandPermissions]: - route = routes.PUT_APPLICATION_GUILD_COMMANDS_PERMISSIONS.compile(application=application, guild=guild) - body = [ - { - "id": str(snowflakes.Snowflake(command)), - "permissions": [self._entity_factory.serialize_command_permission(permission) for permission in perms], - } - for command, perms in permissions.items() - ] - response = await self._request(route, json=body) - - assert isinstance(response, list) - return [self._entity_factory.deserialize_guild_command_permissions(payload) for payload in response] - async def set_application_command_permissions( self, application: snowflakes.SnowflakeishOr[guilds.PartialApplication], diff --git a/hikari/impl/shard.py b/hikari/impl/shard.py index 8b7aa3b0dc..9c681fde6b 100644 --- a/hikari/impl/shard.py +++ b/hikari/impl/shard.py @@ -716,9 +716,9 @@ async def _identify(self) -> None: "compress": False, "large_threshold": self._large_threshold, "properties": { - "$os": f"{platform.system()} {platform.architecture()[0]}", - "$browser": f"hikari ({about.__version__}, aiohttp {aiohttp.__version__})", - "$device": f"hikari {about.__version__}", + "os": f"{platform.system()} {platform.architecture()[0]}", + "browser": f"hikari ({about.__version__}, aiohttp {aiohttp.__version__})", + "device": f"hikari {about.__version__}", }, "shard": [self._shard_id, self._shard_count], }, diff --git a/hikari/impl/special_endpoints.py b/hikari/impl/special_endpoints.py index 40c78d050a..ab75e13f7f 100644 --- a/hikari/impl/special_endpoints.py +++ b/hikari/impl/special_endpoints.py @@ -613,6 +613,58 @@ async def _next_chunk(self) -> typing.Optional[typing.Generator[applications.Own return (self._entity_factory.deserialize_own_guild(g) for g in chunk) +# We use an explicit forward reference for this, since this breaks potential +# circular import issues (once the file has executed, using those resources is +# not an issue for us). +class GuildBanIterator(iterators.BufferedLazyIterator["guilds.GuildBan"]): + """Iterator implementation for retrieving guild bans.""" + + __slots__: typing.Sequence[str] = ( + "_entity_factory", + "_guild_id", + "_request_call", + "_route", + "_first_id", + "_newest_first", + ) + + def __init__( + self, + entity_factory: entity_factory_.EntityFactory, + request_call: typing.Callable[ + ..., typing.Coroutine[None, None, typing.Union[None, data_binding.JSONObject, data_binding.JSONArray]] + ], + guild: snowflakes.SnowflakeishOr[guilds.PartialGuild], + newest_first: bool, + first_id: str, + ) -> None: + super().__init__() + self._guild_id = snowflakes.Snowflake(str(int(guild))) + self._route = routes.GET_GUILD_BANS.compile(guild=guild) + self._request_call = request_call + self._entity_factory = entity_factory + self._first_id = first_id + self._newest_first = newest_first + + async def _next_chunk(self) -> typing.Optional[typing.Generator[guilds.GuildBan, typing.Any, None]]: + query = data_binding.StringMapBuilder() + query.put("before" if self._newest_first else "after", self._first_id) + query.put("limit", 1000) + + chunk = await self._request_call(compiled_route=self._route, query=query) + assert isinstance(chunk, list) + + if not chunk: + return None + + if self._newest_first: + # These are always returned in ascending order by `.user.id`. + chunk.reverse() + + self._first_id = chunk[-1]["user"]["id"] + return (self._entity_factory.deserialize_guild_member_ban(b) for b in chunk) + + # We use an explicit forward reference for this, since this breaks potential # circular import issues (once the file has executed, using those resources is # not an issue for us). @@ -795,7 +847,7 @@ def set_choices( def build( self, _: entity_factory_.EntityFactory, / - ) -> typing.Tuple[data_binding.JSONObject, typing.Sequence[files.Resource[files.AsyncReader]]]: + ) -> typing.Tuple[typing.MutableMapping[str, typing.Any], typing.Sequence[files.Resource[files.AsyncReader]]]: data = {"choices": [{"name": choice.name, "value": choice.value} for choice in self._choices]} return {"type": self.type, "data": data}, () @@ -837,7 +889,7 @@ def set_flags( def build( self, _: entity_factory_.EntityFactory, / - ) -> typing.Tuple[data_binding.JSONObject, typing.Sequence[files.Resource[files.AsyncReader]]]: + ) -> typing.Tuple[typing.MutableMapping[str, typing.Any], typing.Sequence[files.Resource[files.AsyncReader]]]: if self._flags is not undefined.UNDEFINED: return {"type": self._type, "data": {"flags": self._flags}}, () @@ -1003,7 +1055,7 @@ def set_user_mentions( def build( self, entity_factory: entity_factory_.EntityFactory, / - ) -> typing.Tuple[data_binding.JSONObject, typing.Sequence[files.Resource[files.AsyncReader]]]: + ) -> typing.Tuple[typing.MutableMapping[str, typing.Any], typing.Sequence[files.Resource[files.AsyncReader]]]: data = data_binding.JSONObjectBuilder() data.put("content", self.content) @@ -1076,7 +1128,7 @@ def add_component( def build( self, entity_factory: entity_factory_.EntityFactory, / - ) -> typing.Tuple[data_binding.JSONObject, typing.Sequence[files.Resource[files.AsyncReader]]]: + ) -> typing.Tuple[typing.MutableMapping[str, typing.Any], typing.Sequence[files.Resource[files.AsyncReader]]]: data = data_binding.JSONObjectBuilder() data.put("title", self._title) data.put("custom_id", self._custom_id) @@ -1092,15 +1144,22 @@ class CommandBuilder(special_endpoints.CommandBuilder): _name: str = attr.field() _id: undefined.UndefinedOr[snowflakes.Snowflake] = attr.field(default=undefined.UNDEFINED, kw_only=True) - _default_permission: undefined.UndefinedOr[bool] = attr.field(default=undefined.UNDEFINED, kw_only=True) + _default_member_permissions: typing.Union[undefined.UndefinedType, int, permissions_.Permissions] = attr.field( + default=undefined.UNDEFINED, kw_only=True + ) + _is_dm_enabled: undefined.UndefinedOr[bool] = attr.field(default=undefined.UNDEFINED, kw_only=True) @property def id(self) -> undefined.UndefinedOr[snowflakes.Snowflake]: return self._id @property - def default_permission(self) -> undefined.UndefinedOr[bool]: - return self._default_permission + def default_member_permissions(self) -> typing.Union[undefined.UndefinedType, permissions_.Permissions, int]: + return self._default_member_permissions + + @property + def is_dm_enabled(self) -> undefined.UndefinedOr[bool]: + return self._is_dm_enabled @property def name(self) -> str: @@ -1110,16 +1169,25 @@ def set_id(self: _CommandBuilderT, id_: undefined.UndefinedOr[snowflakes.Snowfla self._id = snowflakes.Snowflake(id_) if id_ is not undefined.UNDEFINED else undefined.UNDEFINED return self - def set_default_permission(self: _CommandBuilderT, state: undefined.UndefinedOr[bool], /) -> _CommandBuilderT: - self._default_permission = state + def set_default_member_permissions( + self: _CommandBuilderT, + default_member_permissions: typing.Union[undefined.UndefinedType, int, permissions_.Permissions], + /, + ) -> _CommandBuilderT: + self._default_member_permissions = default_member_permissions + return self + + def set_is_dm_enabled(self: _CommandBuilderT, state: undefined.UndefinedOr[bool], /) -> _CommandBuilderT: + self._is_dm_enabled = state return self - def build(self, entity_factory: entity_factory_.EntityFactory, /) -> data_binding.JSONObjectBuilder: + def build(self, _: entity_factory_.EntityFactory, /) -> typing.MutableMapping[str, typing.Any]: data = data_binding.JSONObjectBuilder() data["name"] = self._name data["type"] = self.type data.put_snowflake("id", self._id) - data.put("default_permission", self._default_permission) + data.put("default_member_permissions", self._default_member_permissions) + data.put("dm_permission", self._is_dm_enabled) return data @@ -1147,8 +1215,12 @@ def add_option(self: _SlashCommandBuilderT, option: commands.CommandOption) -> _ def options(self) -> typing.Sequence[commands.CommandOption]: return self._options.copy() - def build(self, entity_factory: entity_factory_.EntityFactory, /) -> data_binding.JSONObjectBuilder: + def build(self, entity_factory: entity_factory_.EntityFactory, /) -> typing.MutableMapping[str, typing.Any]: data = super().build(entity_factory) + # Under this context we know this'll always be a JSONObjectBuilder but + # the return types need to be kept as MutableMapping to avoid exposing an + # internal type on the public API. + assert isinstance(data, data_binding.JSONObjectBuilder) data.put("description", self._description) data.put_array("options", self._options, conversion=entity_factory.serialize_command_option) return data @@ -1166,8 +1238,9 @@ async def create( self._name, self._description, guild=guild, - default_permission=self._default_permission, options=self._options, + default_member_permissions=self._default_member_permissions, + dm_enabled=self._is_dm_enabled, ) @@ -1193,7 +1266,12 @@ async def create( guild: undefined.UndefinedOr[snowflakes.SnowflakeishOr[guilds.PartialGuild]] = undefined.UNDEFINED, ) -> commands.ContextMenuCommand: return await rest.create_context_menu_command( - application, self._type, self._name, guild=guild, default_permission=self._default_permission + application, + self._type, + self._name, + guild=guild, + default_member_permissions=self._default_member_permissions, + dm_enabled=self._is_dm_enabled, ) @@ -1275,7 +1353,7 @@ def add_to_container(self) -> _ContainerProtoT: self._container.add_component(self) return self._container - def build(self) -> data_binding.JSONObject: + def build(self) -> typing.MutableMapping[str, typing.Any]: data = data_binding.JSONObjectBuilder() data["type"] = messages.ComponentType.BUTTON @@ -1376,7 +1454,7 @@ def add_to_menu(self) -> _SelectMenuBuilderT: self._menu.add_raw_option(self) return self._menu - def build(self) -> data_binding.JSONObject: + def build(self) -> typing.MutableMapping[str, typing.Any]: data = data_binding.JSONObjectBuilder() data["label"] = self._label @@ -1464,7 +1542,7 @@ def add_to_container(self) -> _ContainerProtoT: self._container.add_component(self) return self._container - def build(self) -> data_binding.JSONObject: + def build(self) -> typing.MutableMapping[str, typing.Any]: data = data_binding.JSONObjectBuilder() data["type"] = messages.ComponentType.SELECT_MENU @@ -1561,7 +1639,7 @@ def add_to_container(self) -> _ContainerProtoT: self._container.add_component(self) return self._container - def build(self) -> data_binding.JSONObject: + def build(self) -> typing.MutableMapping[str, typing.Any]: data = data_binding.JSONObjectBuilder() data["type"] = messages.ComponentType.TEXT_INPUT @@ -1650,7 +1728,7 @@ def add_text_input( self._assert_can_add_type(messages.ComponentType.TEXT_INPUT) return TextInputBuilder(container=self, custom_id=custom_id, label=label) - def build(self) -> data_binding.JSONObject: + def build(self) -> typing.MutableMapping[str, typing.Any]: return { "type": messages.ComponentType.ACTION_ROW, "components": [component.build() for component in self._components], diff --git a/hikari/interactions/command_interactions.py b/hikari/interactions/command_interactions.py index b5f9e4412b..9ac067d226 100644 --- a/hikari/interactions/command_interactions.py +++ b/hikari/interactions/command_interactions.py @@ -217,9 +217,6 @@ class BaseCommandInteraction(base_interactions.PartialInteraction): command_type: typing.Union[commands.CommandType, int] = attr.field(eq=False, hash=False, repr=True) """The type of the command.""" - resolved: typing.Optional[ResolvedOptionData] = attr.field(eq=False, hash=False, repr=False) - """Mappings of the objects resolved for the provided command options.""" - async def fetch_channel(self) -> channels.TextableChannel: """Fetch the guild channel this was triggered in. @@ -371,9 +368,15 @@ class CommandInteraction( ): """Represents a command interaction on Discord.""" + app_permissions: typing.Optional[permissions_.Permissions] = attr.field(eq=False, hash=False, repr=False) + """Permissions the bot has in this interaction's channel if it's in a guild.""" + options: typing.Optional[typing.Sequence[CommandInteractionOption]] = attr.field(eq=False, hash=False, repr=True) """Parameter values provided by the user invoking this command.""" + resolved: typing.Optional[ResolvedOptionData] = attr.field(eq=False, hash=False, repr=False) + """Mappings of the objects resolved for the provided command options.""" + target_id: typing.Optional[snowflakes.Snowflake] = attr.field(default=None, eq=False, hash=False, repr=True) """The target of the command. Only available if the command is a context menu command.""" diff --git a/hikari/interactions/component_interactions.py b/hikari/interactions/component_interactions.py index 1cb127da24..7f391bbcf4 100644 --- a/hikari/interactions/component_interactions.py +++ b/hikari/interactions/component_interactions.py @@ -38,6 +38,7 @@ from hikari import guilds from hikari import locales from hikari import messages + from hikari import permissions from hikari import snowflakes from hikari import users from hikari.api import special_endpoints @@ -141,6 +142,9 @@ class ComponentInteraction( locale: typing.Union[str, locales.Locale] = attr.field(eq=False, hash=False, repr=True) """The selected language of the user who triggered this component interaction.""" + app_permissions: typing.Optional[permissions.Permissions] = attr.field(eq=False, hash=False, repr=False) + """Permissions the bot has in this interaction's channel if it's in a guild.""" + def build_response(self, type_: _ImmediateTypesT, /) -> special_endpoints.InteractionMessageBuilder: """Get a message response builder for use in the REST server flow. diff --git a/hikari/interactions/modal_interactions.py b/hikari/interactions/modal_interactions.py index a24b8190c2..16629d8970 100644 --- a/hikari/interactions/modal_interactions.py +++ b/hikari/interactions/modal_interactions.py @@ -38,6 +38,7 @@ from hikari import channels from hikari import guilds from hikari import messages +from hikari import permissions from hikari import snowflakes from hikari import traits from hikari.interactions import base_interactions @@ -119,6 +120,9 @@ class ModalInteraction(base_interactions.MessageResponseMixin[ModalResponseTypes locale: str = attr.field(eq=False, hash=False, repr=True) """The selected language of the user who triggered this modal interaction.""" + app_permissions: typing.Optional[permissions.Permissions] = attr.field(eq=False, hash=False, repr=False) + """Permissions the bot has in this interaction's channel if it's in a guild.""" + components: typing.Sequence[messages.ActionRowComponent] = attr.field(eq=False, hash=False, repr=True) """Components in the modal.""" diff --git a/hikari/internal/attr_extensions.py b/hikari/internal/attr_extensions.py index 82fa474379..faf114d62b 100644 --- a/hikari/internal/attr_extensions.py +++ b/hikari/internal/attr_extensions.py @@ -60,7 +60,7 @@ def invalidate_deep_copy_cache() -> None: def get_fields_definition( - cls: typing.Type[ModelT], + cls: type, ) -> typing.Tuple[ typing.Sequence[typing.Tuple[attr.Attribute[typing.Any], str]], typing.Sequence[attr.Attribute[typing.Any]] ]: @@ -76,8 +76,8 @@ def get_fields_definition( typing.Sequence[typing.Tuple[builtins.str, builtins.str]] A sequence of tuples of string attribute names to string key-word names. """ - init_results = [] - non_init_results = [] + init_results: typing.List[typing.Tuple[attr.Attribute[typing.Any], str]] = [] + non_init_results: typing.List[attr.Attribute[typing.Any]] = [] for field in attr.fields(cls): if field.init: diff --git a/hikari/internal/cache.py b/hikari/internal/cache.py index d620728e0f..9980034731 100644 --- a/hikari/internal/cache.py +++ b/hikari/internal/cache.py @@ -33,7 +33,6 @@ "KnownCustomEmojiData", "RichActivityData", "MemberPresenceData", - "MentionsData", "MessageInteractionData", "MessageData", "VoiceStateData", @@ -626,84 +625,6 @@ def build_entity(self, app: traits.RESTAware, /) -> presences.MemberPresence: ) -@attr_extensions.with_copy -@attr.define(kw_only=True, repr=False, hash=False, weakref_slot=False) -class MentionsData(BaseData[messages.Mentions]): - """A model for storing message mentions data in an in-memory cache.""" - - users: undefined.UndefinedOr[typing.Mapping[snowflakes.Snowflake, RefCell[users_.User]]] = attr.field() - role_ids: undefined.UndefinedOr[typing.Tuple[snowflakes.Snowflake, ...]] = attr.field() - channels: undefined.UndefinedOr[typing.Mapping[snowflakes.Snowflake, channels_.PartialChannel]] = attr.field() - everyone: undefined.UndefinedOr[bool] = attr.field() - - @classmethod - def build_from_entity( - cls, - mentions: messages.Mentions, - /, - *, - users: undefined.UndefinedOr[typing.Mapping[snowflakes.Snowflake, RefCell[users_.User]]] = undefined.UNDEFINED, - ) -> MentionsData: - if not users and mentions.users is not undefined.UNDEFINED: - users = {user_id: RefCell(copy.copy(user)) for user_id, user in mentions.users.items()} - - channels: undefined.UndefinedOr[ - typing.Mapping[snowflakes.Snowflake, "channels_.PartialChannel"] - ] = undefined.UNDEFINED - if mentions.channels is not undefined.UNDEFINED: - channels = {channel_id: copy.copy(channel) for channel_id, channel in mentions.channels.items()} - - return cls( - users=users, - role_ids=tuple(mentions.role_ids) if mentions.role_ids is not undefined.UNDEFINED else undefined.UNDEFINED, - channels=channels, - everyone=mentions.everyone, - ) - - def build_entity( - self, _: traits.RESTAware, /, *, message: typing.Optional[messages.Message] = None - ) -> messages.Mentions: - users: undefined.UndefinedOr[typing.Mapping[snowflakes.Snowflake, users_.User]] = undefined.UNDEFINED - if self.users is not undefined.UNDEFINED: - users = {user_id: user.copy() for user_id, user in self.users.items()} - - channels: undefined.UndefinedOr[ - typing.Mapping[snowflakes.Snowflake, channels_.PartialChannel] - ] = undefined.UNDEFINED - if self.channels is not undefined.UNDEFINED: - channels = {channel_id: copy.copy(channel) for channel_id, channel in self.channels.items()} - - return messages.Mentions( - message=message or NotImplemented, - users=users, - role_ids=self.role_ids, - channels=channels, - everyone=self.everyone, - ) - - def update( - self, - mention: messages.Mentions, - /, - *, - users: undefined.UndefinedOr[typing.Mapping[snowflakes.Snowflake, RefCell[users_.User]]] = undefined.UNDEFINED, - ) -> None: - if users is not undefined.UNDEFINED: - self.users = users - - elif mention.users is not undefined.UNDEFINED: - self.users = {user_id: RefCell(copy.copy(user)) for user_id, user in mention.users.items()} - - if mention.role_ids is not undefined.UNDEFINED: - self.role_ids = tuple(mention.role_ids) - - if mention.channels is not undefined.UNDEFINED: - self.channels = {channel_id: copy.copy(channel) for channel_id, channel in mention.channels.items()} - - if mention.everyone is not undefined.UNDEFINED: - self.everyone = mention.everyone - - @attr_extensions.with_copy @attr.define(kw_only=True, repr=False, hash=False, weakref_slot=False) class MessageInteractionData(BaseData[messages.MessageInteraction]): @@ -762,7 +683,12 @@ class MessageData(BaseData[messages.Message]): timestamp: datetime.datetime = attr.field() edited_timestamp: typing.Optional[datetime.datetime] = attr.field() is_tts: bool = attr.field() - mentions: MentionsData = attr.field() + user_mentions: undefined.UndefinedOr[typing.Mapping[snowflakes.Snowflake, RefCell[users_.User]]] = attr.field() + role_mention_ids: undefined.UndefinedOr[typing.Tuple[snowflakes.Snowflake, ...]] = attr.field() + channel_mentions: undefined.UndefinedOr[ + typing.Mapping[snowflakes.Snowflake, channels_.PartialChannel] + ] = attr.field() + mentions_everyone: undefined.UndefinedOr[bool] = attr.field() attachments: typing.Tuple[messages.Attachment, ...] = attr.field() embeds: typing.Tuple[embeds_.Embed, ...] = attr.field() reactions: typing.Tuple[messages.Reaction, ...] = attr.field() @@ -788,7 +714,7 @@ def build_from_entity( *, author: typing.Optional[RefCell[users_.User]] = None, member: typing.Optional[RefCell[MemberData]] = None, - mention_users: undefined.UndefinedOr[ + user_mentions: undefined.UndefinedOr[ typing.Mapping[snowflakes.Snowflake, RefCell[users_.User]] ] = undefined.UNDEFINED, referenced_message: typing.Optional[RefCell[MessageData]] = None, @@ -797,12 +723,25 @@ def build_from_entity( if not member and message.member: member = RefCell(MemberData.build_from_entity(message.member)) - if not referenced_message and message.referenced_message: - referenced_message = RefCell(MessageData.build_from_entity(message.referenced_message)) + interaction = ( + MessageInteractionData.build_from_entity(message.interaction, user=interaction_user) + if message.interaction + else None + ) - interaction: typing.Optional[MessageInteractionData] = None - if message.interaction: - interaction = MessageInteractionData.build_from_entity(message.interaction, user=interaction_user) + if not user_mentions and message.user_mentions is not undefined.UNDEFINED: + user_mentions = {user_id: RefCell(copy.copy(user)) for user_id, user in message.user_mentions.items()} + + channel_mentions: undefined.UndefinedOr[typing.Mapping[snowflakes.Snowflake, channels_.PartialChannel]] = ( + {channel_id: copy.copy(channel) for channel_id, channel in message.channel_mentions.items()} + if message.channel_mentions is not undefined.UNDEFINED + else undefined.UNDEFINED + ) + role_mention_ids: undefined.UndefinedOr[typing.Tuple[snowflakes.Snowflake, ...]] = ( + tuple(message.role_mention_ids) + if message.role_mention_ids is not undefined.UNDEFINED + else undefined.UNDEFINED + ) return cls( id=message.id, @@ -814,7 +753,10 @@ def build_from_entity( timestamp=message.timestamp, edited_timestamp=message.edited_timestamp, is_tts=message.is_tts, - mentions=MentionsData.build_from_entity(message.mentions, users=mention_users), + user_mentions=user_mentions, + channel_mentions=channel_mentions, + role_mention_ids=role_mention_ids, + mentions_everyone=message.mentions_everyone, attachments=tuple(map(copy.copy, message.attachments)), embeds=tuple(map(_copy_embed, message.embeds)), reactions=tuple(map(copy.copy, message.reactions)), @@ -834,9 +776,16 @@ def build_from_entity( ) def build_entity(self, app: traits.RESTAware, /) -> messages.Message: - referenced_message: typing.Optional[messages.Message] = None - if self.referenced_message: - referenced_message = self.referenced_message.object.build_entity(app) + channel_mentions: undefined.UndefinedOr[typing.Mapping[snowflakes.Snowflake, channels_.PartialChannel]] = ( + {channel_id: copy.copy(channel) for channel_id, channel in self.channel_mentions.items()} + if self.channel_mentions is not undefined.UNDEFINED + else undefined.UNDEFINED + ) + user_mentions: undefined.UndefinedOr[typing.Mapping[snowflakes.Snowflake, users_.User]] = ( + {user_id: user.copy() for user_id, user in self.user_mentions.items()} + if self.user_mentions is not undefined.UNDEFINED + else undefined.UNDEFINED + ) message = messages.Message( id=self.id, @@ -850,6 +799,10 @@ def build_entity(self, app: traits.RESTAware, /) -> messages.Message: edited_timestamp=self.edited_timestamp, is_tts=self.is_tts, mentions=NotImplemented, + user_mentions=user_mentions, + channel_mentions=channel_mentions, + role_mention_ids=copy.copy(self.role_mention_ids), + mentions_everyone=self.mentions_everyone, attachments=tuple(map(copy.copy, self.attachments)), embeds=tuple(map(_copy_embed, self.embeds)), reactions=tuple(map(copy.copy, self.reactions)), @@ -862,12 +815,12 @@ def build_entity(self, app: traits.RESTAware, /) -> messages.Message: flags=self.flags, stickers=tuple(map(copy.copy, self.stickers)), nonce=self.nonce, - referenced_message=referenced_message, + referenced_message=self.referenced_message.object.build_entity(app) if self.referenced_message else None, interaction=self.interaction.build_entity(app) if self.interaction else None, application_id=self.application_id, components=self.components, ) - message.mentions = self.mentions.build_entity(app, message=message) + message.mentions = messages.Mentions(message=message) return message def update( @@ -875,7 +828,7 @@ def update( message: messages.PartialMessage, /, *, - mention_users: undefined.UndefinedOr[ + user_mentions: undefined.UndefinedOr[ typing.Mapping[snowflakes.Snowflake, RefCell[users_.User]] ] = undefined.UNDEFINED, ) -> None: @@ -897,7 +850,21 @@ def update( if message.components is not undefined.UNDEFINED: self.components = tuple(message.components) - self.mentions.update(message.mentions, users=mention_users) + if user_mentions is not undefined.UNDEFINED: + self.user_mentions = user_mentions + elif message.user_mentions is not undefined.UNDEFINED: + self.user_mentions = {user_id: RefCell(copy.copy(user)) for user_id, user in message.user_mentions.items()} + + if message.role_mention_ids is not undefined.UNDEFINED: + self.role_mention_ids = tuple(message.role_mention_ids) + + if message.channel_mentions is not undefined.UNDEFINED: + self.channel_mentions = { + channel_id: copy.copy(channel) for channel_id, channel in message.channel_mentions.items() + } + + if message.mentions_everyone is not undefined.UNDEFINED: + self.mentions_everyone = message.mentions_everyone @attr_extensions.with_copy diff --git a/hikari/internal/collections.py b/hikari/internal/collections.py index db61827baf..3c5509e083 100644 --- a/hikari/internal/collections.py +++ b/hikari/internal/collections.py @@ -134,28 +134,6 @@ def __setitem__(self, key: KeyT, value: ValueT) -> None: self._data[key] = value -class _FrozenDict(typing.MutableMapping[KeyT, ValueT]): - __slots__: typing.Sequence[str] = ("_source",) - - def __init__(self, source: typing.Dict[KeyT, typing.Tuple[float, ValueT]], /) -> None: - self._source = source - - def __getitem__(self, key: KeyT) -> ValueT: - return self._source[key][1] - - def __iter__(self) -> typing.Iterator[KeyT]: - return iter(self._source) - - def __len__(self) -> int: - return len(self._source) - - def __delitem__(self, key: KeyT) -> None: - del self._source[key] - - def __setitem__(self, key: KeyT, value: ValueT) -> None: - self._source[key] = (0.0, value) - - class LimitedCapacityCacheMap(ExtendedMutableMapping[KeyT, ValueT]): """Implementation of a capacity-limited most-recently-inserted mapping. @@ -357,8 +335,6 @@ def get_index_or_slice( Raises ------ - TypeError - If `index_or_slice` isn't a `builtins.slice` or `builtins.int`. IndexError If `index_or_slice` is an int and is outside the range of the mapping's contents. @@ -366,10 +342,7 @@ def get_index_or_slice( if isinstance(index_or_slice, slice): return tuple(itertools.islice(mapping.values(), index_or_slice.start, index_or_slice.stop, index_or_slice.step)) - if isinstance(index_or_slice, int): - try: - return next(itertools.islice(mapping.values(), index_or_slice, None)) - except StopIteration: - raise IndexError(index_or_slice) from None - - raise TypeError(f"sequence indices must be integers or slices, not {type(index_or_slice).__name__}") + try: + return next(itertools.islice(mapping.values(), index_or_slice, None)) + except StopIteration: + raise IndexError(index_or_slice) from None diff --git a/hikari/internal/data_binding.py b/hikari/internal/data_binding.py index b7390a4263..90017e8da1 100644 --- a/hikari/internal/data_binding.py +++ b/hikari/internal/data_binding.py @@ -60,10 +60,10 @@ # MyPy does not support recursive types yet. This has been ongoing for a long time, unfortunately. # See https://github.com/python/typing/issues/182 -JSONObject = typing.Dict[str, typing.Any] +JSONObject = typing.Mapping[str, typing.Any] """Type hint for a JSON-decoded object representation as a mapping.""" -JSONArray = typing.List[typing.Any] +JSONArray = typing.Sequence[typing.Any] """Type hint for a JSON-decoded array representation as a sequence.""" JSONish = typing.Union[str, int, float, bool, None, JSONArray, JSONObject] @@ -73,10 +73,7 @@ """Type hint for any valid that can be put in a StringMapBuilder""" _StringMapBuilderArg = typing.Union[ - typing.Mapping[str, str], - typing.Dict[str, str], - multidict.MultiMapping[str], - typing.Iterable[typing.Tuple[str, str]], + typing.Mapping[str, str], multidict.MultiMapping[str], typing.Iterable[typing.Tuple[str, str]] ] _APPLICATION_OCTET_STREAM: typing.Final[str] = "application/octet-stream" diff --git a/hikari/internal/net.py b/hikari/internal/net.py index 3024f51391..4e0f0cd817 100644 --- a/hikari/internal/net.py +++ b/hikari/internal/net.py @@ -62,7 +62,10 @@ async def generate_error_response(response: aiohttp.ClientResponse) -> errors.HT if response.status == http.HTTPStatus.NOT_FOUND: return errors.NotFoundError(*args) - status = http.HTTPStatus(response.status) + try: + status: typing.Union[http.HTTPStatus, int] = http.HTTPStatus(response.status) + except ValueError: + status = response.status if 400 <= status < 500: return errors.ClientHTTPResponseError(real_url, status, response.headers, raw_body) diff --git a/hikari/internal/reflect.py b/hikari/internal/reflect.py index d28af6c36e..627cfa4f42 100644 --- a/hikari/internal/reflect.py +++ b/hikari/internal/reflect.py @@ -62,7 +62,7 @@ def resolve_signature(func: typing.Callable[..., typing.Any]) -> inspect.Signatu signature = inspect.signature(func) resolved_typehints = typing.get_type_hints(func) - params = [] + params: typing.List[inspect.Parameter] = [] none_type = type(None) for name, param in signature.parameters.items(): @@ -98,7 +98,7 @@ def profiled(call: typing.Callable[..., _T]) -> typing.Callable[..., _T]: # pra @functools.wraps(call) def wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: - print("Profiling", call.__module__ + "." + call.__qualname__) # noqa: T001 print disallowed. + print("Profiling", call.__module__ + "." + call.__qualname__) # noqa: T201 print disallowed. cProfile.runctx(invoker, globals=globals(), locals=locals(), filename=None, sort=1) return locals()["result"] diff --git a/hikari/internal/ux.py b/hikari/internal/ux.py index 04217fc9d8..d38c744e0a 100644 --- a/hikari/internal/ux.py +++ b/hikari/internal/ux.py @@ -33,7 +33,6 @@ import re import string import sys -import time import typing import warnings @@ -223,10 +222,8 @@ def print_banner( for code in colorlog.escape_codes.escape_codes: args[code] = "" - sys.stdout.write(string.Template(raw_banner).safe_substitute(args)) - # Give the stream some time to flush - sys.stdout.flush() - time.sleep(0.125) + with open(sys.stdout.fileno(), "w", encoding="utf-8", closefd=False) as stdout: + stdout.write(string.Template(raw_banner).safe_substitute(args)) def supports_color(allow_color: bool, force_color: bool) -> bool: diff --git a/hikari/iterators.py b/hikari/iterators.py index 0aeeabe6ec..8b83de560a 100644 --- a/hikari/iterators.py +++ b/hikari/iterators.py @@ -115,7 +115,7 @@ def __call__(self, item: ValueT) -> bool: def __invert__(self) -> typing.Callable[[ValueT], bool]: return lambda item: not self(item) - def __or__(self, other: All[ValueT]) -> All[ValueT]: + def __or__(self, other: typing.Any) -> All[ValueT]: if not isinstance(other, All): raise TypeError(f"unsupported operand type(s) for |: {type(self).__name__!r} and {type(other).__name__!r}") diff --git a/hikari/locales.py b/hikari/locales.py index 929808830a..a02a5b10e1 100644 --- a/hikari/locales.py +++ b/hikari/locales.py @@ -70,7 +70,7 @@ class Locale(str, enums.Enum): NO = "no" """Norwegian""" - OL = "pl" + PL = "pl" """Polish""" PT_BR = "pt-BR" diff --git a/hikari/messages.py b/hikari/messages.py index 60495a2fa2..09437ee637 100644 --- a/hikari/messages.py +++ b/hikari/messages.py @@ -59,6 +59,7 @@ from hikari import undefined from hikari import urls from hikari.internal import attr_extensions +from hikari.internal import deprecation from hikari.internal import enums from hikari.internal import routes @@ -278,34 +279,46 @@ class Mentions: # through this mechanism. _message: PartialMessage = attr.field(repr=False) - users: undefined.UndefinedOr[typing.Mapping[snowflakes.Snowflake, users_.User]] = attr.field() - """Users who were notified by their mention in the message.""" - - role_ids: undefined.UndefinedOr[typing.Sequence[snowflakes.Snowflake]] = attr.field() - """IDs of roles that were notified by their mention in the message.""" - - channels: undefined.UndefinedOr[typing.Mapping[snowflakes.Snowflake, channels_.PartialChannel]] = attr.field() - """Channel mentions that reference channels in the target crosspost's guild. - - If the message is not crossposted, this will always be empty. - """ + @property + def channels(self) -> undefined.UndefinedOr[typing.Mapping[snowflakes.Snowflake, channels_.PartialChannel]]: + """Channel mentions that reference channels in the target crosspost's guild. - everyone: undefined.UndefinedOr[bool] = attr.field() - """Whether the message notifies using `@everyone` or `@here`.""" + If the message is not crossposted, this will always be empty. + """ + deprecation.warn_deprecated("Mentions.channels", alternative="channel_mentions in the base message object") + return self._message.channel_mentions @property def channels_ids(self) -> undefined.UndefinedOr[typing.Sequence[snowflakes.Snowflake]]: - if self.channels is undefined.UNDEFINED: - return undefined.UNDEFINED + """Sequence of IDs of the channels that were mentioned in the message.""" + deprecation.warn_deprecated( + "Mentions.channels_ids", alternative="channel_mention_ids in the base message object" + ) + return self._message.channel_mention_ids - return list(self.channels.keys()) + @property + def users(self) -> undefined.UndefinedOr[typing.Mapping[snowflakes.Snowflake, users_.User]]: + """Users who were notified by their mention in the message.""" + deprecation.warn_deprecated("Mentions.users", alternative="user_mentions in the base message object") + return self._message.user_mentions @property def user_ids(self) -> undefined.UndefinedOr[typing.Sequence[snowflakes.Snowflake]]: - if self.users is undefined.UNDEFINED: - return undefined.UNDEFINED + """Sequence of IDs of the users that were mentioned in the message.""" + deprecation.warn_deprecated("Mentions.user_ids", alternative="user_mentions_ids in the base message object") + return self._message.user_mentions_ids - return list(self.users.keys()) + @property + def role_ids(self) -> undefined.UndefinedOr[typing.Sequence[snowflakes.Snowflake]]: + """Sequence of IDs of roles that were notified by their mention in the message.""" + deprecation.warn_deprecated("Mentions.role_ids", alternative="role_mention_ids in the base message object") + return self._message.role_mention_ids + + @property + def everyone(self) -> undefined.UndefinedOr[bool]: + """Whether the message notifies using `@everyone` or `@here`.""" + deprecation.warn_deprecated("Mentions.everyone", alternative="mentions_everyone in the base message object") + return self._message.mentions_everyone def get_members(self) -> undefined.UndefinedOr[typing.Mapping[snowflakes.Snowflake, guilds.Member]]: """Discover any cached members notified by this message. @@ -328,18 +341,10 @@ def get_members(self) -> undefined.UndefinedOr[typing.Mapping[snowflakes.Snowfla means that there is a very small chance that some users provided in `notified_users` may not be present here. """ - if self.users is undefined.UNDEFINED: - return undefined.UNDEFINED - - if isinstance(self._message.app, traits.CacheAware) and self._message.guild_id is not None: - app = self._message.app - guild_id = self._message.guild_id - return self._map_cache_maybe_discover( - self.users, - lambda user_id: app.cache.get_member(guild_id, user_id), - ) - - return {} + deprecation.warn_deprecated( + "Mentions.get_members", alternative="get_member_mentions in the base message object" + ) + return self._message.get_member_mentions() def get_roles(self) -> undefined.UndefinedOr[typing.Mapping[snowflakes.Snowflake, guilds.Role]]: """Attempt to look up the roles that are notified by this message. @@ -363,29 +368,8 @@ def get_roles(self) -> undefined.UndefinedOr[typing.Mapping[snowflakes.Snowflake in `notifies_role_ids` may not be present here. This is a limitation of Discord, again. """ - if self.role_ids is undefined.UNDEFINED: - return undefined.UNDEFINED - - if isinstance(self._message.app, traits.CacheAware) and self._message.guild_id is not None: - app = self._message.app - return self._map_cache_maybe_discover( - self.role_ids, - app.cache.get_role, - ) - - return {} - - @staticmethod - def _map_cache_maybe_discover( - ids: typing.Iterable[snowflakes.Snowflake], - cache_call: typing.Callable[[snowflakes.Snowflake], typing.Optional[_T]], - ) -> typing.Dict[snowflakes.Snowflake, _T]: - results: typing.Dict[snowflakes.Snowflake, _T] = {} - for id_ in ids: - obj = cache_call(id_) - if obj is not None: - results[id_] = obj - return results + deprecation.warn_deprecated("Mentions.get_roles", alternative="get_role_mentions in the base message object") + return self._message.get_role_mentions() @attr_extensions.with_copy @@ -737,6 +721,18 @@ def __len__(self) -> int: return len(self.components) +def _map_cache_maybe_discover( + ids: typing.Iterable[snowflakes.Snowflake], + cache_call: typing.Callable[[snowflakes.Snowflake], typing.Optional[_T]], +) -> typing.Dict[snowflakes.Snowflake, _T]: + results: typing.Dict[snowflakes.Snowflake, _T] = {} + for id_ in ids: + obj = cache_call(id_) + if obj is not None: + results[id_] = obj + return results + + @attr_extensions.with_copy @attr.define(kw_only=True, repr=True, eq=False, weakref_slot=False) class PartialMessage(snowflakes.Unique): @@ -818,6 +814,54 @@ class PartialMessage(snowflakes.Unique): This is a Discord limitation. """ + user_mentions: undefined.UndefinedOr[typing.Mapping[snowflakes.Snowflake, users_.User]] = attr.field( + hash=False, eq=False, repr=False + ) + """Users who were notified by their mention in the message. + + !!! warning + If the contents have not mutated and this is a message update event, + some fields that are not affected may be empty instead. + + This is a Discord limitation. + """ + + role_mention_ids: undefined.UndefinedOr[typing.Sequence[snowflakes.Snowflake]] = attr.field( + hash=False, eq=False, repr=False + ) + """IDs of roles that were notified by their mention in the message. + + !!! warning + If the contents have not mutated and this is a message update event, + some fields that are not affected may be empty instead. + + This is a Discord limitation. + """ + + channel_mentions: undefined.UndefinedOr[ + typing.Mapping[snowflakes.Snowflake, channels_.PartialChannel] + ] = attr.field(hash=False, eq=False, repr=False) + """Channel mentions that reference channels in the target crosspost's guild. + + If the message is not crossposted, this will always be empty. + + !!! warning + If the contents have not mutated and this is a message update event, + some fields that are not affected may be empty instead. + + This is a Discord limitation. + """ + + mentions_everyone: undefined.UndefinedOr[bool] = attr.field(hash=False, eq=False, repr=False) + """Whether the message notifies using `@everyone` or `@here`. + + !!! warning + If the contents have not mutated and this is a message update event, + some fields that are not affected may be empty instead. + + This is a Discord limitation. + """ + attachments: undefined.UndefinedOr[typing.Sequence[Attachment]] = attr.field(hash=False, eq=False, repr=False) """The message attachments.""" @@ -869,7 +913,7 @@ class PartialMessage(snowflakes.Unique): This is a string used for validating a message was sent. """ - referenced_message: undefined.UndefinedNoneOr[Message] = attr.field(hash=False, eq=False, repr=False) + referenced_message: undefined.UndefinedNoneOr[PartialMessage] = attr.field(hash=False, eq=False, repr=False) """The message that was replied to. If `type` is `MessageType.REPLY` and `hikari.undefined.UNDEFINED`, Discord's @@ -890,6 +934,101 @@ class PartialMessage(snowflakes.Unique): components: undefined.UndefinedOr[typing.Sequence[PartialComponent]] = attr.field(hash=False, eq=False, repr=False) """Sequence of the components attached to this message.""" + @property + def channel_mention_ids(self) -> undefined.UndefinedOr[typing.Sequence[snowflakes.Snowflake]]: + """Ids of channels that reference channels in the target crosspost's guild. + + If the message is not crossposted, this will always be empty. + + !!! warning + If the contents have not mutated and this is a message update event, + some fields that are not affected may be empty instead. + + This is a Discord limitation. + """ + if self.channel_mentions is undefined.UNDEFINED: + return undefined.UNDEFINED + + return list(self.channel_mentions.keys()) + + @property + def user_mentions_ids(self) -> undefined.UndefinedOr[typing.Sequence[snowflakes.Snowflake]]: + """Ids of the users who were notified by their mention in the message. + + !!! warning + If the contents have not mutated and this is a message update event, + some fields that are not affected may be empty instead. + + This is a Discord limitation. + """ + if self.user_mentions is undefined.UNDEFINED: + return undefined.UNDEFINED + + return list(self.user_mentions.keys()) + + def get_member_mentions(self) -> undefined.UndefinedOr[typing.Mapping[snowflakes.Snowflake, guilds.Member]]: + """Discover any cached members notified by this message. + + If this message was sent in a DM, this will always be empty. + + !!! warning + This will only return valid results on gateway events. For REST + endpoints, this will potentially be empty. This is a limitation of + Discord's API, as they do not consistently notify of the ID of the + guild a message was sent in. + + !!! note + If you are using a stateless application such as a stateless bot + or a REST-only client, this will always be empty. Furthermore, + if you are running a stateful bot and have the GUILD_MEMBERS + intent disabled, this will also be empty. + + Members that are not cached will not appear in this mapping. This + means that there is a very small chance that some users provided + in `notified_users` may not be present here. + """ + if self.user_mentions is undefined.UNDEFINED: + return undefined.UNDEFINED + + if isinstance(self.app, traits.CacheAware) and self.guild_id is not None: + app = self.app + guild_id = self.guild_id + return _map_cache_maybe_discover( + self.user_mentions, lambda user_id: app.cache.get_member(guild_id, user_id) + ) + + return {} + + def get_role_mentions(self) -> undefined.UndefinedOr[typing.Mapping[snowflakes.Snowflake, guilds.Role]]: + """Attempt to look up the roles that are notified by this message. + + If this message was sent in a DM, this will always be empty. + + !!! warning + This will only return valid results on gateway events. For REST + endpoints, this will potentially be empty. This is a limitation of + Discord's API, as they do not consistently notify of the ID of the + guild a message was sent in. + + !!! note + If you are using a stateless application such as a stateless bot + or a REST-only client, this will always be empty. Furthermore, + if you are running a stateful bot and have the GUILD intent + disabled, this will also be empty. + + Roles that are not cached will not appear in this mapping. This + means that there is a very small chance that some role IDs provided + in `notifies_role_ids` may not be present here. This is a limitation + of Discord, again. + """ + if self.role_mention_ids is undefined.UNDEFINED: + return undefined.UNDEFINED + + if isinstance(self.app, traits.CacheAware) and self.guild_id is not None: + return _map_cache_maybe_discover(self.role_mention_ids, self.app.cache.get_role) + + return {} + def make_link(self, guild: typing.Optional[snowflakes.SnowflakeishOr[guilds.PartialGuild]]) -> str: """Generate a jump link to this message. @@ -909,7 +1048,6 @@ def make_link(self, guild: typing.Optional[snowflakes.SnowflakeishOr[guilds.Part builtins.str The jump link to the message. """ - # TODO: this doesn't seem like a safe assumption for rest only applications guild_id_str = "@me" if guild is None else str(int(guild)) return f"{urls.BASE_URL}/channels/{guild_id_str}/{self.channel_id}/{self.id}" @@ -1617,8 +1755,11 @@ class Message(PartialMessage): nonce: typing.Optional[str] = attr.field(hash=False, eq=False, repr=False) """The message nonce. This is a string used for validating a message was sent.""" - referenced_message: typing.Optional[Message] = attr.field(hash=False, eq=False, repr=False) - """The message that was replied to.""" + referenced_message: typing.Optional[PartialMessage] = attr.field(hash=False, eq=False, repr=False) + """The message that was replied to. + + If `type` is `MessageType.REPLY` and `builtins.None`, the message was deleted. + """ interaction: typing.Optional[MessageInteraction] = attr.field(hash=False, eq=False, repr=False) """Information about the interaction this message was created by.""" diff --git a/hikari/permissions.py b/hikari/permissions.py index 5e1f358840..86808fb661 100644 --- a/hikari/permissions.py +++ b/hikari/permissions.py @@ -37,52 +37,43 @@ class Permissions(enums.Flag): """Represents the permissions available in a given channel or guild. - This enum is an `enum.IntFlag`. This means that you can **combine multiple - permissions together** into one value using the bitwise-OR operator (`|`). - - my_perms = Permissions.MANAGE_CHANNELS | Permissions.MANAGE_GUILD - - your_perms = ( - Permissions.CREATE_INSTANT_INVITE - | Permissions.KICK_MEMBERS - | Permissions.BAN_MEMBERS - | Permissions.MANAGE_GUILD - ) - - You can **check if a permission is present** in a set of combined - permissions by using the bitwise-AND operator (`&`). This will return - the int-value of the permission if it is present, or `0` if not present. - - my_perms = Permissions.MANAGE_CHANNELS | Permissions.MANAGE_GUILD - - if my_perms & Permissions.MANAGE_CHANNELS: - if my_perms & Permissions.MANAGE_GUILD: - print("I have the permission to both manage the guild and the channels in it!") - else: - print("I have the permission to manage channels!") - else: - print("I don't have the permission to manage channels!") - - # Or you could simplify it: - - if my_perms & (Permissions.MANAGE_CHANNELS | Permissions.MANAGE_GUILD): - print("I have the permission to both manage the guild and the channels in it!") - elif my_perms & Permissions.MANAGE_CHANNELS: - print("I have the permission to manage channels!") - else: - print("I don't have the permission to manage channels!") - - If you need to **check that a permission is not present**, you can use the - bitwise-XOR operator (`^`) to check. If the permission is not present, it - will return a non-zero value, otherwise if it is present, it will return `0`. - - my_perms = Permissions.MANAGE_CHANNELS | Permissions.MANAGE_GUILD - - if my_perms ^ Permissions.MANAGE_CHANNELS: - print("Please give me the MANAGE_CHANNELS permission!") - - Lastly, if you need all the permissions set except the permission you want, - you can use the inversion operator (`~`) to do that. + This enum is an `enum.IntFlag`, which means that it is stored as a bit field + where each bit represents a permission. You can use bitwise operators + to efficiently manipulate and compare permissions. + + Examples + -------- + You can create an enum which combines multiple permissions using the bitwise OR operator (`|`): + + my_perms = Permissions.MANAGE_CHANNELS | Permissions.MANAGE_GUILD + + required_perms = ( + Permissions.CREATE_INSTANT_INVITE + | Permissions.KICK_MEMBERS + | Permissions.BAN_MEMBERS + | Permissions.MANAGE_GUILD + ) + + To find the intersection of two sets of permissions, use the bitwise AND + operator (`&`) between them. By then applying the `==` operator, you can check if all + permissions from one set are present in another set. This is useful, for instance, + for checking if a user has all the required permissions + + if (my_perms & required_perms) == required_perms: + print("I have all of the required permissions!") + else: + print("I am missing at least one required permission!") + + To determine which permissions from one set are missing from another, you can use the + bitwise equivalent of the set difference operation, as shown below. This can be used, + for instance, to find which of a user's permissions are missing from the required permissions. + + missing_perms = ~my_perms & required_perms + if (missing_perms): + print(f"I'm missing these permissions: {missing_perms}") + + Lastly, if you need all the permissions from a set except for a few, + you can use the bitwise NOT operator (`~`). # All permissions except ADMINISTRATOR. my_perms = ~Permissions.ADMINISTRATOR diff --git a/hikari/snowflakes.py b/hikari/snowflakes.py index b0f0657674..9c6cc0335b 100644 --- a/hikari/snowflakes.py +++ b/hikari/snowflakes.py @@ -57,9 +57,6 @@ class Snowflake(int): __slots__: typing.Sequence[str] = () - ___MIN___: Snowflake - ___MAX___: Snowflake - @property def created_at(self) -> datetime.datetime: """When the object was created.""" @@ -89,22 +86,12 @@ def from_datetime(cls, timestamp: datetime.datetime) -> Snowflake: @classmethod def min(cls) -> Snowflake: """Minimum value for a snowflakes.""" - try: - return cls.___MIN___ - - except AttributeError: - cls.___MIN___ = Snowflake(0) - return cls.___MIN___ + return cls(0) @classmethod def max(cls) -> Snowflake: """Maximum value for a snowflakes.""" - try: - return cls.___MAX___ - - except AttributeError: - cls.___MAX___ = Snowflake((1 << 63) - 1) - return cls.___MAX___ + return cls((1 << 63) - 1) @classmethod def from_data(cls, timestamp: datetime.datetime, worker_id: int, process_id: int, increment: int) -> Snowflake: diff --git a/pipelines/mypy.nox.py b/pipelines/mypy.nox.py index bb5a53957b..65c0cfde0b 100644 --- a/pipelines/mypy.nox.py +++ b/pipelines/mypy.nox.py @@ -34,7 +34,7 @@ @nox.session(reuse_venv=True) def mypy(session: nox.Session) -> None: - """Perform static type analysis on Python source code.""" + """Perform static type analysis on Python source code using mypy.""" session.install( "-r", "requirements.txt", diff --git a/pipelines/pyright.nox.py b/pipelines/pyright.nox.py index 60eecdbf17..4e2d6d0145 100644 --- a/pipelines/pyright.nox.py +++ b/pipelines/pyright.nox.py @@ -25,6 +25,27 @@ from pipelines import nox +@nox.session() +def pyright(session: nox.Session) -> None: + """Perform static type analysis on Python source code using pyright. + + At the time of writing this, this pipeline will not run successfully, + as hikari does not have 100% compatibility with pyright just yet. This + exists to make it easier to test and eventually reach that 100% compatibility. + """ + session.install( + "-r", + "requirements.txt", + "-r", + "speedup-requirements.txt", + "-r", + "server-requirements.txt", + "-r", + "dev-requirements.txt", + ) + session.run("python", "-m", "pyright") + + @nox.session() def verify_types(session: nox.Session) -> None: """Verify the "type completeness" of types exported by the library using Pyright.""" diff --git a/pyproject.toml b/pyproject.toml index dcbab9a8e0..29e5efc054 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,10 +59,24 @@ exclude_lines = [ ] [tool.pyright] -include = ["examples", "hikari"] +include = ["hikari", "examples"] +exclude = ["examples/simple_dashboard.py", "**/__init__.py", "hikari/internal/enums.py", "hikari/internal/fast_protocol.py"] pythonVersion = "3.8" typeCheckingMode = "strict" +reportUnnecessaryTypeIgnoreComment = "error" +reportMissingTypeStubs = "none" +reportImportCycles = "none" # Doesn't account for TYPE_CHECKING +reportIncompatibleMethodOverride = "none" # This relies on ordering for keyword-only arguments +reportOverlappingOverload = "none" # Type-Vars in last overloads may interfere +reportIncompatibleVariableOverride = "none" # Cannot overwrite abstract properties using attrs + +# Attrs validators will always be unknown +# https://github.com/python-attrs/attrs/issues/795 +reportUnknownMemberType = "warning" +reportUntypedFunctionDecorator = "warning" +reportOptionalMemberAccess = "warning" + [tool.pytest.ini_options] asyncio_mode = "strict" xfail_strict = true diff --git a/tests/hikari/hikari_test_helpers.py b/tests/hikari/hikari_test_helpers.py index 1d6a248b37..18b543752f 100644 --- a/tests/hikari/hikari_test_helpers.py +++ b/tests/hikari/hikari_test_helpers.py @@ -102,7 +102,7 @@ async def retry_wrapper(*args, **kwargs): ex = None for i in range(max_retries + 1): if i: - print("retry", i, "of", max_retries) # noqa: T001 - Print found + print("retry", i, "of", max_retries) # noqa: T201 - Print found try: await func(*args, **kwargs) return diff --git a/tests/hikari/impl/test_bot.py b/tests/hikari/impl/test_bot.py index 29cce434d2..0a632bc253 100644 --- a/tests/hikari/impl/test_bot.py +++ b/tests/hikari/impl/test_bot.py @@ -192,6 +192,7 @@ def test_init(self): cache_settings=cache_settings, http_settings=http_settings, intents=intents, + auto_chunk_members=False, logs="DEBUG", max_rate_limit=200, max_retries=0, @@ -205,7 +206,11 @@ def test_init(self): cache.assert_called_once_with(bot, cache_settings) assert bot._event_manager is event_manager.return_value event_manager.assert_called_once_with( - entity_factory.return_value, event_factory.return_value, intents, cache=cache.return_value + entity_factory.return_value, + event_factory.return_value, + intents, + auto_chunk_members=False, + cache=cache.return_value, ) assert bot._entity_factory is entity_factory.return_value entity_factory.assert_called_once_with(bot) diff --git a/tests/hikari/impl/test_cache.py b/tests/hikari/impl/test_cache.py index b9c69cea6e..5ebdbc89ba 100644 --- a/tests/hikari/impl/test_cache.py +++ b/tests/hikari/impl/test_cache.py @@ -2437,12 +2437,9 @@ def test__build_message(self, cache_impl): member_data = mock.Mock(build_entity=mock.Mock(return_value=mock_member)) mock_channel = mock.MagicMock() mock_mention_user = mock.MagicMock() - mention_data = cache_utilities.MentionsData( - users={snowflakes.Snowflake(4231): cache_utilities.RefCell(mock_mention_user)}, - role_ids=(snowflakes.Snowflake(21323123),), - channels={snowflakes.Snowflake(4444): mock_channel}, - everyone=True, - ) + mock_user_mentions = {snowflakes.Snowflake(4231): cache_utilities.RefCell(mock_mention_user)} + mock_role_mention_ids = (snowflakes.Snowflake(21323123),) + mock_channel_mentions = {snowflakes.Snowflake(4444): mock_channel} mock_attachment = mock.MagicMock(messages.Attachment) mock_embed_field = mock.MagicMock(embeds.EmbedField) mock_embed = mock.MagicMock(embeds.Embed, fields=(mock_embed_field,)) @@ -2468,7 +2465,10 @@ def test__build_message(self, cache_impl): timestamp=datetime.datetime(2020, 7, 30, 7, 10, 9, 550233, tzinfo=datetime.timezone.utc), edited_timestamp=datetime.datetime(2020, 8, 30, 7, 10, 9, 550233, tzinfo=datetime.timezone.utc), is_tts=True, - mentions=mention_data, + user_mentions=mock_user_mentions, + role_mention_ids=mock_role_mention_ids, + channel_mentions=mock_channel_mentions, + mentions_everyone=False, attachments=(mock_attachment,), embeds=(mock_embed,), reactions=(mock_reaction,), @@ -2499,13 +2499,10 @@ def test__build_message(self, cache_impl): assert result.timestamp == datetime.datetime(2020, 7, 30, 7, 10, 9, 550233, tzinfo=datetime.timezone.utc) assert result.edited_timestamp == datetime.datetime(2020, 8, 30, 7, 10, 9, 550233, tzinfo=datetime.timezone.utc) assert result.is_tts is True - - # MentionsData - assert result.mentions.users == {4231: mock_mention_user} - assert result.mentions.role_ids == (snowflakes.Snowflake(21323123),) - assert result.mentions.channels == {4444: mock_channel} - assert result.mentions.everyone is True - + assert result.user_mentions == {4231: mock_mention_user} + assert result.role_mention_ids == (snowflakes.Snowflake(21323123),) + assert result.channel_mentions == {4444: mock_channel} + assert result.mentions_everyone is False assert result.attachments == (mock_attachment,) for field in ( @@ -2545,12 +2542,6 @@ def test__build_message(self, cache_impl): assert result.components == (mock_component,) def test__build_message_with_null_fields(self, cache_impl): - mentions = cache_utilities.MentionsData( - role_ids=undefined.UNDEFINED, - channels=undefined.UNDEFINED, - everyone=undefined.UNDEFINED, - users=undefined.UNDEFINED, - ) message_data = cache_utilities.MessageData( id=snowflakes.Snowflake(32123123), channel_id=snowflakes.Snowflake(3123123123), @@ -2561,7 +2552,10 @@ def test__build_message_with_null_fields(self, cache_impl): timestamp=datetime.datetime(2020, 7, 30, 7, 10, 9, 550233, tzinfo=datetime.timezone.utc), edited_timestamp=None, is_tts=True, - mentions=mentions, + user_mentions=undefined.UNDEFINED, + role_mention_ids=undefined.UNDEFINED, + channel_mentions=undefined.UNDEFINED, + mentions_everyone=undefined.UNDEFINED, attachments=(), embeds=(), reactions=(), @@ -2589,10 +2583,10 @@ def test__build_message_with_null_fields(self, cache_impl): assert result.is_tts is True # MentionsData - assert result.mentions.users is undefined.UNDEFINED - assert result.mentions.role_ids is undefined.UNDEFINED - assert result.mentions.channels is undefined.UNDEFINED - assert result.mentions.everyone is undefined.UNDEFINED + assert result.user_mentions is undefined.UNDEFINED + assert result.role_mention_ids is undefined.UNDEFINED + assert result.channel_mentions is undefined.UNDEFINED + assert result.mentions_everyone is undefined.UNDEFINED assert result.webhook_id is None assert result.activity is None diff --git a/tests/hikari/impl/test_entity_factory.py b/tests/hikari/impl/test_entity_factory.py index baa667bced..5e9d70f838 100644 --- a/tests/hikari/impl/test_entity_factory.py +++ b/tests/hikari/impl/test_entity_factory.py @@ -100,6 +100,7 @@ def guild_voice_channel_payload(permission_overwrite_payload): "rtc_region": "europe", "parent_id": "456", "video_quality_mode": 1, + "last_message_id": 1234567890, } @@ -1764,6 +1765,51 @@ def test_deserialize_channel_handles_unknown_channel_type(self, entity_factory_i with pytest.raises(errors.UnrecognisedEntityError): entity_factory_impl.deserialize_channel({"type": -9999999999}) + @pytest.mark.parametrize( + ("type_", "fn"), + [ + (0, "deserialize_guild_text_channel"), + (2, "deserialize_guild_voice_channel"), + (4, "deserialize_guild_category"), + (5, "deserialize_guild_news_channel"), + (13, "deserialize_guild_stage_channel"), + ], + ) + def test_deserialize_channel_when_guild(self, mock_app, type_, fn): + payload = {"type": type_} + + with mock.patch.object(entity_factory.EntityFactoryImpl, fn) as expected_fn: + # We need to instantiate it after the mock so that the functions that are stored in the dicts + # are the ones we mock + entity_factory_impl = entity_factory.EntityFactoryImpl(app=mock_app) + + assert entity_factory_impl.deserialize_channel(payload, guild_id=123) is expected_fn.return_value + + expected_fn.assert_called_once_with(payload, guild_id=123) + + @pytest.mark.parametrize( + ("type_", "fn"), + [ + (1, "deserialize_dm"), + (3, "deserialize_group_dm"), + ], + ) + def test_deserialize_channel_when_dm(self, mock_app, type_, fn): + payload = {"type": type_} + + with mock.patch.object(entity_factory.EntityFactoryImpl, fn) as expected_fn: + # We need to instantiate it after the mock so that the functions that are stored in the dicts + # are the ones we mock + entity_factory_impl = entity_factory.EntityFactoryImpl(app=mock_app) + + assert entity_factory_impl.deserialize_channel(payload, guild_id=123123123) is expected_fn.return_value + + expected_fn.assert_called_once_with(payload) + + def test_deserialize_channel_when_unknown_type(self, entity_factory_impl): + with pytest.raises(errors.UnrecognisedEntityError): + entity_factory_impl.deserialize_channel({"type": -111}) + ################ # EMBED MODELS # ################ @@ -3097,7 +3143,7 @@ def test_deserialize_gateway_guild_ignores_unrecognised_channels(self, entity_fa ###################### @pytest.fixture() - def command_payload(self): + def slash_command_payload(self): return { "id": "1231231231", "application_id": "12354123", @@ -3105,7 +3151,8 @@ def command_payload(self): "type": 1, "name": "good name", "description": "very good description", - "default_permission": False, + "default_member_permissions": 8, + "dm_permission": True, "options": [ { "type": 1, @@ -3128,8 +3175,8 @@ def command_payload(self): "version": "123321123", } - def test_deserialize_command(self, entity_factory_impl, mock_app, command_payload): - command = entity_factory_impl.deserialize_command(payload=command_payload) + def test_deserialize_slash_command(self, entity_factory_impl, mock_app, slash_command_payload): + command = entity_factory_impl.deserialize_slash_command(payload=slash_command_payload) assert command.app is mock_app assert command.id == 1231231231 @@ -3137,7 +3184,8 @@ def test_deserialize_command(self, entity_factory_impl, mock_app, command_payloa assert command.guild_id == 49949494 assert command.name == "good name" assert command.description == "very good description" - assert command.default_permission is False + assert command.default_member_permissions == permission_models.Permissions.ADMINISTRATOR + assert command.is_dm_enabled is True assert command.version == 123321123 # CommandOption @@ -3177,22 +3225,24 @@ def test_deserialize_command(self, entity_factory_impl, mock_app, command_payloa assert isinstance(option, commands.CommandOption) assert isinstance(command, commands.SlashCommand) - def test_deserialize_command_with_passed_through_guild_id(self, entity_factory_impl): + def test_deserialize_slash_command_with_passed_through_guild_id(self, entity_factory_impl): payload = { "id": "1231231231", + "guild_id": "987654321", "application_id": "12354123", "type": 1, "name": "good name", "description": "very good description", "options": [], + "default_member_permissions": 0, "version": "123312", } - command = entity_factory_impl.deserialize_command(payload, guild_id=123123) + command = entity_factory_impl.deserialize_slash_command(payload, guild_id=123123) assert command.guild_id == 123123 - def test_deserialize_command_with_null_and_unset_values(self, entity_factory_impl): + def test_deserialize_slash_command_with_null_and_unset_values(self, entity_factory_impl): payload = { "id": "1231231231", "application_id": "12354123", @@ -3201,15 +3251,49 @@ def test_deserialize_command_with_null_and_unset_values(self, entity_factory_imp "name": "good name", "description": "very good description", "options": [], + "default_member_permissions": 0, "version": "43123", } - command = entity_factory_impl.deserialize_command(payload) + command = entity_factory_impl.deserialize_slash_command(payload) assert command.options is None - assert command.default_permission is True + assert command.is_dm_enabled is False assert isinstance(command, commands.SlashCommand) + def test_deserialize_slash_command_standardizes_default_member_permissions( + self, entity_factory_impl, slash_command_payload + ): + slash_command_payload["default_member_permissions"] = 0 + + command = entity_factory_impl.deserialize_slash_command(slash_command_payload) + + assert command.default_member_permissions == permission_models.Permissions.ADMINISTRATOR + + @pytest.mark.parametrize( + ("type_", "fn"), + [ + (1, "deserialize_slash_command"), + (2, "deserialize_context_menu_command"), + (3, "deserialize_context_menu_command"), + ], + ) + def test_deserialize_command(self, mock_app, type_, fn): + payload = {"type": type_} + + with mock.patch.object(entity_factory.EntityFactoryImpl, fn) as expected_fn: + # We need to instantiate it after the mock so that the functions that are stored in the dicts + # are the ones we mock + entity_factory_impl = entity_factory.EntityFactoryImpl(app=mock_app) + + assert entity_factory_impl.deserialize_command(payload, guild_id=123) is expected_fn.return_value + + expected_fn.assert_called_once_with(payload, guild_id=123) + + def test_deserialize_command_when_unknown_type(self, entity_factory_impl): + with pytest.raises(errors.UnrecognisedEntityError): + entity_factory_impl.deserialize_command({"type": -111}) + @pytest.fixture() def guild_command_permissions_payload(self): return { @@ -3310,7 +3394,7 @@ def test__deserialize_interaction_member(self, entity_factory_impl, interaction_ assert isinstance(member, base_interactions.InteractionMember) def test__deserialize_interaction_member_when_guild_id_already_in_roles_doesnt_duplicate( - self, entity_factory_impl, interaction_member_payload, user_payload + self, entity_factory_impl, interaction_member_payload ): interaction_member_payload["roles"] = [ 582345963851743243, @@ -3329,9 +3413,7 @@ def test__deserialize_interaction_member_when_guild_id_already_in_roles_doesnt_d 43123123, ] - def test__deserialize_interaction_member_with_unset_fields( - self, entity_factory_impl, interaction_member_payload, user_payload - ): + def test__deserialize_interaction_member_with_unset_fields(self, entity_factory_impl, interaction_member_payload): del interaction_member_payload["premium_since"] del interaction_member_payload["avatar"] del interaction_member_payload["communication_disabled_until"] @@ -3342,9 +3424,7 @@ def test__deserialize_interaction_member_with_unset_fields( assert member.premium_since is None assert member.raw_communication_disabled_until is None - def test__deserialize_interaction_member_with_passed_user( - self, entity_factory_impl, interaction_member_payload, user_payload - ): + def test__deserialize_interaction_member_with_passed_user(self, entity_factory_impl, interaction_member_payload): mock_user = object() member = entity_factory_impl._deserialize_interaction_member( interaction_member_payload, guild_id=43123123, user=mock_user @@ -3449,6 +3529,7 @@ def command_interaction_payload(self, interaction_member_payload, interaction_re "guild_locale": "en-US", "version": 69420, "application_id": "76234234", + "app_permissions": "54123", } def test_deserialize_command_interaction( @@ -3481,6 +3562,7 @@ def test_deserialize_command_interaction( assert interaction.resolved == entity_factory_impl._deserialize_resolved_option_data( interaction_resolved_data_payload, guild_id=43123123 ) + assert interaction.app_permissions == 54123 # CommandInteractionOption assert len(interaction.options) == 1 @@ -3508,17 +3590,42 @@ def test_deserialize_command_interaction( assert isinstance(interaction, command_interactions.CommandInteraction) + @pytest.fixture() + def context_menu_command_interaction_payload(self, interaction_member_payload, user_payload): + return { + "id": "3490190239012093", + "type": 4, + "guild_id": "43123123", + "data": { + "id": "43123123", + "name": "okokokok", + "type": 2, + "target_id": "115590097100865541", + "resolved": { + "users": { + "115590097100865541": user_payload, + } + }, + }, + "channel_id": "49949494", + "member": interaction_member_payload, + "token": "moe cat girls", + "locale": "es-ES", + "guild_locale": "en-US", + "version": 69420, + "application_id": "76234234", + "app_permissions": "54123123", + } + def test_deserialize_command_interaction_with_context_menu_field( - self, - entity_factory_impl, - context_menu_command_interaction_payload, + self, entity_factory_impl, context_menu_command_interaction_payload ): interaction = entity_factory_impl.deserialize_command_interaction(context_menu_command_interaction_payload) assert interaction.target_id == 115590097100865541 assert isinstance(interaction, command_interactions.CommandInteraction) def test_deserialize_command_interaction_with_null_attributes( - self, entity_factory_impl, mock_app, command_interaction_payload, user_payload + self, entity_factory_impl, command_interaction_payload, user_payload ): del command_interaction_payload["guild_id"] del command_interaction_payload["member"] @@ -3526,6 +3633,7 @@ def test_deserialize_command_interaction_with_null_attributes( del command_interaction_payload["data"]["resolved"] del command_interaction_payload["data"]["options"] del command_interaction_payload["guild_locale"] + del command_interaction_payload["app_permissions"] interaction = entity_factory_impl.deserialize_command_interaction(command_interaction_payload) @@ -3535,6 +3643,7 @@ def test_deserialize_command_interaction_with_null_attributes( assert interaction.options is None assert interaction.resolved is None assert interaction.guild_locale is None + assert interaction.app_permissions is None @pytest.fixture() def autocomplete_interaction_payload(self, user_payload, interaction_resolved_data_payload): @@ -3556,7 +3665,6 @@ def autocomplete_interaction_payload(self, user_payload, interaction_resolved_da ], }, ], - "resolved": interaction_resolved_data_payload, }, "channel_id": "49949494", "user": user_payload, @@ -3568,7 +3676,7 @@ def autocomplete_interaction_payload(self, user_payload, interaction_resolved_da } def test_deserialize_autocomplete_interaction( - self, entity_factory_impl, mock_app, autocomplete_interaction_payload, interaction_resolved_data_payload + self, entity_factory_impl, mock_app, autocomplete_interaction_payload ): interaction = entity_factory_impl.deserialize_autocomplete_interaction(autocomplete_interaction_payload) @@ -3584,9 +3692,6 @@ def test_deserialize_autocomplete_interaction( assert interaction.locale is locales.Locale.ES_ES assert interaction.guild_locale == "en-US" assert interaction.guild_locale is locales.Locale.EN_US - assert interaction.resolved == entity_factory_impl._deserialize_resolved_option_data( - interaction_resolved_data_payload, guild_id=43123123 - ) # AutocompleteInteractionOption assert len(interaction.options) == 1 @@ -3616,9 +3721,8 @@ def test_deserialize_autocomplete_interaction( assert isinstance(interaction, command_interactions.AutocompleteInteraction) def test_deserialize_autocomplete_interaction_with_null_fields( - self, entity_factory_impl, user_payload, mock_app, autocomplete_interaction_payload + self, entity_factory_impl, user_payload, autocomplete_interaction_payload ): - del autocomplete_interaction_payload["data"]["resolved"] del autocomplete_interaction_payload["guild_locale"] del autocomplete_interaction_payload["guild_id"] @@ -3627,89 +3731,43 @@ def test_deserialize_autocomplete_interaction_with_null_fields( assert interaction.guild_id is None assert interaction.member is None assert interaction.user == entity_factory_impl.deserialize_user(user_payload) - assert interaction.resolved is None assert interaction.guild_locale is None - def test_deserialize_interaction_returns_expected_type( - self, entity_factory_impl, command_interaction_payload, component_interaction_payload - ): - for payload, expected_type in [ - (command_interaction_payload, command_interactions.CommandInteraction), - (component_interaction_payload, component_interactions.ComponentInteraction), - ]: - assert type(entity_factory_impl.deserialize_interaction(payload)) is expected_type + @pytest.mark.parametrize( + ("type_", "fn"), + [ + (2, "deserialize_command_interaction"), + (3, "deserialize_component_interaction"), + (4, "deserialize_autocomplete_interaction"), + ], + ) + def test_deserialize_interaction(self, mock_app, type_, fn): + payload = {"type": type_} + + with mock.patch.object(entity_factory.EntityFactoryImpl, fn) as expected_fn: + # We need to instantiate it after the mock so that the functions that are stored in the dicts + # are the ones we mock + entity_factory_impl = entity_factory.EntityFactoryImpl(app=mock_app) + + assert entity_factory_impl.deserialize_interaction(payload) is expected_fn.return_value + + expected_fn.assert_called_once_with(payload) def test_deserialize_interaction_handles_unknown_type(self, entity_factory_impl): with pytest.raises(errors.UnrecognisedEntityError): entity_factory_impl.deserialize_interaction({"type": -999}) - def test_serialize_command_option_with_channel_type(self, entity_factory_impl): + def test_serialize_command_option(self, entity_factory_impl): option = commands.CommandOption( type=commands.OptionType.INTEGER, name="a name", description="go away", is_required=True, - channel_types=[channel_models.ChannelType.GUILD_STAGE, channel_models.ChannelType.GUILD_TEXT, 100], - ) - - result = entity_factory_impl.serialize_command_option(option) - - assert result == { - "type": 4, - "name": "a name", - "description": "go away", - "required": True, - "channel_types": [13, 0, 100], - } - - def test_serialize_command_option_with_min_and_max_value(self, entity_factory_impl): - option = commands.CommandOption( - type=commands.OptionType.FLOAT, - name="a name", - description="go away", - is_required=True, + autocomplete=True, min_value=1.2, max_value=9.999, - ) - - result = entity_factory_impl.serialize_command_option(option) - - assert result == { - "type": 10, - "name": "a name", - "description": "go away", - "required": True, - "min_value": 1.2, - "max_value": 9.999, - } - - def test_serialize_command_option_with_choices(self, entity_factory_impl): - option = commands.CommandOption( - type=commands.OptionType.INTEGER, - name="a name", - description="go away", - is_required=True, + channel_types=[channel_models.ChannelType.GUILD_STAGE, channel_models.ChannelType.GUILD_TEXT, 100], choices=[commands.CommandChoice(name="a", value="choice")], - options=None, - ) - - result = entity_factory_impl.serialize_command_option(option) - - assert result == { - "type": 4, - "name": "a name", - "description": "go away", - "required": True, - "choices": [{"name": "a", "value": "choice"}], - } - - def test_serialize_command_option_with_options(self, entity_factory_impl): - option = commands.CommandOption( - type=commands.OptionType.SUB_COMMAND, - name="a name", - description="go away", - is_required=True, - choices=None, options=[ commands.CommandOption( type=commands.OptionType.STRING, @@ -3725,10 +3783,15 @@ def test_serialize_command_option_with_options(self, entity_factory_impl): result = entity_factory_impl.serialize_command_option(option) assert result == { - "type": 1, + "type": 4, "name": "a name", "description": "go away", "required": True, + "channel_types": [13, 0, 100], + "min_value": 1.2, + "max_value": 9.999, + "autocomplete": True, + "choices": [{"name": "a", "value": "choice"}], "options": [ { "type": 3, @@ -3740,51 +3803,6 @@ def test_serialize_command_option_with_options(self, entity_factory_impl): ], } - def test_serialize_command_option_with_autocomplete(self, entity_factory_impl): - option = commands.CommandOption( - type=commands.OptionType.STRING, - name="a name", - description="go away", - is_required=True, - autocomplete=True, - ) - - result = entity_factory_impl.serialize_command_option(option) - - assert result == { - "type": 3, - "name": "a name", - "description": "go away", - "required": True, - "autocomplete": True, - } - - @pytest.fixture() - def context_menu_command_interaction_payload(self, interaction_member_payload, user_payload): - return { - "id": "3490190239012093", - "type": 4, - "guild_id": "43123123", - "data": { - "id": "43123123", - "name": "okokokok", - "type": 2, - "target_id": "115590097100865541", - "resolved": { - "users": { - "115590097100865541": user_payload, - } - }, - }, - "channel_id": "49949494", - "member": interaction_member_payload, - "token": "moe cat girls", - "locale": "es-ES", - "guild_locale": "en-US", - "version": 69420, - "application_id": "76234234", - } - @pytest.fixture() def context_menu_command_payload(self): return { @@ -3793,16 +3811,13 @@ def context_menu_command_payload(self): "guild_id": "49949494", "type": 2, "name": "good name", - "default_permission": False, + "default_member_permissions": 8, + "dm_permission": True, "version": "123321123", } - def test_deserialize_context_menu_command( - self, - entity_factory_impl, - context_menu_command_payload, - ): - command = entity_factory_impl.deserialize_command(context_menu_command_payload) + def test_deserialize_context_menu_command(self, entity_factory_impl, context_menu_command_payload): + command = entity_factory_impl.deserialize_context_menu_command(context_menu_command_payload) assert isinstance(command, commands.ContextMenuCommand) assert command.id == 1231231231 @@ -3810,19 +3825,28 @@ def test_deserialize_context_menu_command( assert command.guild_id == 49949494 assert command.type == commands.CommandType.USER assert command.name == "good name" - assert command.default_permission is False + assert command.default_member_permissions == permission_models.Permissions.ADMINISTRATOR + assert command.is_dm_enabled is True assert command.version == 123321123 - def test_unknown_command_type( - self, - entity_factory_impl, - command_payload, + def test_deserialize_context_menu_command_with_with_null_and_unset_values( + self, entity_factory_impl, context_menu_command_payload ): - payload = command_payload.copy() - payload["type"] = 4 + del context_menu_command_payload["dm_permission"] - with pytest.raises(errors.UnrecognisedEntityError): - entity_factory_impl.deserialize_command(payload) + command = entity_factory_impl.deserialize_context_menu_command(context_menu_command_payload) + assert isinstance(command, commands.ContextMenuCommand) + + assert command.is_dm_enabled is False + + def test_deserialize_context_menu_command_default_member_permissions( + self, entity_factory_impl, context_menu_command_payload + ): + context_menu_command_payload["default_member_permissions"] = 0 + + command = entity_factory_impl.deserialize_context_menu_command(context_menu_command_payload) + + assert command.default_member_permissions == permission_models.Permissions.ADMINISTRATOR @pytest.fixture() def component_interaction_payload(self, interaction_member_payload, message_payload): @@ -3839,6 +3863,7 @@ def component_interaction_payload(self, interaction_member_payload, message_payl "application_id": "290926444748734465", "locale": "es-ES", "guild_locale": "en-US", + "app_permissions": "5431234", } def test_deserialize_component_interaction( @@ -3866,6 +3891,7 @@ def test_deserialize_component_interaction( assert interaction.locale is locales.Locale.ES_ES assert interaction.guild_locale == "en-US" assert interaction.guild_locale is locales.Locale.EN_US + assert interaction.app_permissions == 5431234 assert isinstance(interaction, component_interactions.ComponentInteraction) def test_deserialize_component_interaction_with_undefined_fields( @@ -3891,6 +3917,7 @@ def test_deserialize_component_interaction_with_undefined_fields( assert interaction.user == entity_factory_impl.deserialize_user(user_payload) assert interaction.values == () assert interaction.guild_locale is None + assert interaction.app_permissions is None assert isinstance(interaction, component_interactions.ComponentInteraction) @pytest.fixture() @@ -4500,13 +4527,25 @@ def test__deserialize_select_menu_partial(self, entity_factory_impl): assert menu.max_values == 1 assert menu.is_disabled is False - def test__deserialize_component(self, entity_factory_impl, action_row_payload, button_payload, select_menu_payload): - for expected_type, payload in [ - (message_models.ActionRowComponent, action_row_payload), - (message_models.ButtonComponent, button_payload), - (message_models.SelectMenuComponent, select_menu_payload), - ]: - assert type(entity_factory_impl._deserialize_component(payload)) is expected_type + @pytest.mark.parametrize( + ("type_", "fn"), + [ + (1, "_deserialize_action_row"), + (2, "_deserialize_button"), + (3, "_deserialize_select_menu"), + ], + ) + def test__deserialize_component(self, mock_app, type_, fn): + payload = {"type": type_} + + with mock.patch.object(entity_factory.EntityFactoryImpl, fn) as expected_fn: + # We need to instantiate it after the mock so that the functions that are stored in the dicts + # are the ones we mock + entity_factory_impl = entity_factory.EntityFactoryImpl(app=mock_app) + + assert entity_factory_impl._deserialize_component(payload) is expected_fn.return_value + + expected_fn.assert_called_once_with(payload) def test__deserialize_component_handles_unknown_type(self, entity_factory_impl): with pytest.raises(errors.UnrecognisedEntityError): @@ -4903,7 +4942,7 @@ def test_deserialize_message( assert message.message_reference.guild_id == 278325129692446720 assert isinstance(message.message_reference, message_models.MessageReference) - assert message.referenced_message == entity_factory_impl.deserialize_message(referenced_message) + assert message.referenced_message == entity_factory_impl.deserialize_partial_message(referenced_message) assert message.flags == message_models.MessageFlag.IS_CROSSPOST # Sticker @@ -5993,16 +6032,25 @@ def test_deserialize_application_webhook_without_optional_fields( assert webhook.avatar_hash is None - def test_deserialize_webhook( - self, entity_factory_impl, incoming_webhook_payload, follower_webhook_payload, application_webhook_payload - ): - for expected_type, payload in [ - (webhook_models.IncomingWebhook, incoming_webhook_payload), - (webhook_models.ChannelFollowerWebhook, follower_webhook_payload), - (webhook_models.ApplicationWebhook, application_webhook_payload), - ]: - result = entity_factory_impl.deserialize_webhook(payload) - assert isinstance(result, expected_type) + @pytest.mark.parametrize( + ("type_", "fn"), + [ + (1, "deserialize_incoming_webhook"), + (2, "deserialize_channel_follower_webhook"), + (3, "deserialize_application_webhook"), + ], + ) + def test_deserialize_webhook(self, mock_app, type_, fn): + payload = {"type": type_} + + with mock.patch.object(entity_factory.EntityFactoryImpl, fn) as expected_fn: + # We need to instantiate it after the mock so that the functions that are stored in the dicts + # are the ones we mock + entity_factory_impl = entity_factory.EntityFactoryImpl(app=mock_app) + + assert entity_factory_impl.deserialize_webhook(payload) is expected_fn.return_value + + expected_fn.assert_called_once_with(payload) def test_deserialize_webhook_for_unexpected_webhook_type(self, entity_factory_impl): with pytest.raises(errors.UnrecognisedEntityError): diff --git a/tests/hikari/impl/test_event_factory.py b/tests/hikari/impl/test_event_factory.py index 1cf4682967..e269c9d0c2 100644 --- a/tests/hikari/impl/test_event_factory.py +++ b/tests/hikari/impl/test_event_factory.py @@ -29,6 +29,7 @@ from hikari import traits from hikari import undefined from hikari import users as user_models +from hikari.events import application_events from hikari.events import channel_events from hikari.events import guild_events from hikari.events import interaction_events @@ -58,13 +59,28 @@ def mock_shard(self): def event_factory(self, mock_app): return event_factory_.EventFactoryImpl(mock_app) + ###################### + # APPLICATION EVENTS # + ###################### + + def test_deserialize_application_command_permission_update_event(self, event_factory, mock_app, mock_shard): + mock_payload = object() + + event = event_factory.deserialize_application_command_permission_update_event(mock_shard, mock_payload) + + mock_app.entity_factory.deserialize_guild_command_permissions.assert_called_once_with(mock_payload) + assert isinstance(event, application_events.ApplicationCommandPermissionsUpdateEvent) + assert event.app is mock_app + assert event.shard is mock_shard + assert event.permissions is mock_app.entity_factory.deserialize_guild_command_permissions.return_value + ################## # CHANNEL EVENTS # ################## def test_deserialize_guild_channel_create_event(self, event_factory, mock_app, mock_shard): mock_app.entity_factory.deserialize_channel.return_value = mock.Mock(spec=channel_models.GuildChannel) - mock_payload = mock.Mock(app=mock_app) + mock_payload = object() event = event_factory.deserialize_guild_channel_create_event(mock_shard, mock_payload) diff --git a/tests/hikari/impl/test_event_manager.py b/tests/hikari/impl/test_event_manager.py index 426d0eaaf5..b82e6c49e1 100644 --- a/tests/hikari/impl/test_event_manager.py +++ b/tests/hikari/impl/test_event_manager.py @@ -147,6 +147,17 @@ async def test_on_resumed(self, event_manager_impl, shard, event_factory): event_factory.deserialize_resumed_event.assert_called_once_with(shard) event_manager_impl.dispatch.assert_awaited_once_with(event_factory.deserialize_resumed_event.return_value) + @pytest.mark.asyncio() + async def test_on_application_command_permissions_update(self, event_manager_impl, shard, event_factory): + payload = {} + + await event_manager_impl.on_application_command_permissions_update(shard, payload) + + event_factory.deserialize_application_command_permission_update_event.assert_called_once_with(shard, payload) + event_manager_impl.dispatch.assert_awaited_once_with( + event_factory.deserialize_application_command_permission_update_event.return_value + ) + @pytest.mark.asyncio() async def test_on_channel_create_stateful(self, event_manager_impl, shard, event_factory): payload = {} @@ -438,6 +449,29 @@ async def test_on_guild_create_when_members_declared_and_enabled_for_member_chun assert mock_event.chunk_nonce == "123.abc" stateless_event_manager_impl.dispatch.assert_awaited_once_with(mock_event) + @pytest.mark.parametrize("cache_enabled", [True, False]) + @pytest.mark.parametrize("large", [True, False]) + @pytest.mark.parametrize("enabled_for_event", [True, False]) + @pytest.mark.asyncio() + async def test_on_guild_create_when_chunk_members_disabled( + self, + stateless_event_manager_impl, + shard, + large, + cache_enabled, + enabled_for_event, + ): + shard.id = 123 + stateless_event_manager_impl._intents = intents.Intents.GUILD_MEMBERS + stateless_event_manager_impl._cache_enabled_for = mock.Mock(return_value=cache_enabled) + stateless_event_manager_impl._enabled_for_event = mock.Mock(return_value=enabled_for_event) + stateless_event_manager_impl._auto_chunk_members = False + + with mock.patch.object(event_manager, "_request_guild_members") as request_guild_members: + await stateless_event_manager_impl.on_guild_create(shard, {"id": 456, "large": large}) + + request_guild_members.assert_not_called() + @pytest.mark.asyncio() async def test_on_guild_update_when_stateless( self, stateless_event_manager_impl, shard, event_factory, entity_factory diff --git a/tests/hikari/impl/test_rest.py b/tests/hikari/impl/test_rest.py index bdb338b7dd..29c959e5f7 100644 --- a/tests/hikari/impl/test_rest.py +++ b/tests/hikari/impl/test_rest.py @@ -23,7 +23,6 @@ import contextlib import datetime import http -import re import warnings import mock @@ -1044,6 +1043,59 @@ def test_unban_member(self, rest_client): assert reason is mock_unban_user.return_value mock_unban_user.assert_called_once_with(123, 321, reason="ayaya") + def test_fetch_bans(self, rest_client: rest.RESTClientImpl): + with mock.patch.object(special_endpoints, "GuildBanIterator") as iterator_cls: + iterator = rest_client.fetch_bans(187, newest_first=True, start_at=StubModel(65652342134)) + + iterator_cls.assert_called_once_with( + rest_client._entity_factory, + rest_client._request, + 187, + True, + "65652342134", + ) + assert iterator is iterator_cls.return_value + + def test_fetch_bans_when_datetime_for_start_at(self, rest_client: rest.RESTClientImpl): + start_at = datetime.datetime(2022, 3, 6, 12, 1, 58, 415625, tzinfo=datetime.timezone.utc) + with mock.patch.object(special_endpoints, "GuildBanIterator") as iterator_cls: + iterator = rest_client.fetch_bans(9000, newest_first=True, start_at=start_at) + + iterator_cls.assert_called_once_with( + rest_client._entity_factory, + rest_client._request, + 9000, + True, + "950000286338908160", + ) + assert iterator is iterator_cls.return_value + + def test_fetch_bans_when_start_at_undefined(self, rest_client: rest.RESTClientImpl): + with mock.patch.object(special_endpoints, "GuildBanIterator") as iterator_cls: + iterator = rest_client.fetch_bans(8844) + + iterator_cls.assert_called_once_with( + rest_client._entity_factory, + rest_client._request, + 8844, + False, + str(snowflakes.Snowflake.min()), + ) + assert iterator is iterator_cls.return_value + + def test_fetch_bans_when_start_at_undefined_and_newest_first(self, rest_client: rest.RESTClientImpl): + with mock.patch.object(special_endpoints, "GuildBanIterator") as iterator_cls: + iterator = rest_client.fetch_bans(3848, newest_first=True) + + iterator_cls.assert_called_once_with( + rest_client._entity_factory, + rest_client._request, + 3848, + True, + str(snowflakes.Snowflake.max()), + ) + assert iterator is iterator_cls.return_value + def test_command_builder(self, rest_client): with warnings.catch_warnings(): warnings.simplefilter("ignore", category=DeprecationWarning) @@ -1334,19 +1386,6 @@ def test__build_message_payload_when_both_single_and_plural_args_passed( ): rest_client._build_message_payload(**{singular_arg: object(), plural_arg: object()}) - @pytest.mark.parametrize( - ("singular_arg", "plural_arg"), - [("attachment", "attachments"), ("component", "components"), ("embed", "embeds")], - ) - def test__build_message_payload_when_non_collection_passed_to_plural(self, rest_client, singular_arg, plural_arg): - expected_error_message = ( - f"You passed a non-collection to '{plural_arg}', but this expects a collection. Maybe you meant to use " - f"'{singular_arg}' (singular) instead?" - ) - - with pytest.raises(TypeError, match=re.escape(expected_error_message)): - rest_client._build_message_payload(**{plural_arg: object()}) - def test_interaction_deferred_builder(self, rest_client): result = rest_client.interaction_deferred_builder(5) @@ -4223,21 +4262,6 @@ async def test_fetch_ban(self, rest_client): rest_client._request.assert_awaited_once_with(expected_route) rest_client._entity_factory.deserialize_guild_member_ban.assert_called_once_with({"id": "789"}) - async def test_fetch_bans(self, rest_client): - ban1 = StubModel(456) - ban2 = StubModel(789) - expected_route = routes.GET_GUILD_BANS.compile(guild=123) - rest_client._request = mock.AsyncMock(return_value=[{"id": "456"}, {"id": "789"}]) - rest_client._entity_factory.deserialize_guild_member_ban = mock.Mock(side_effect=[ban1, ban2]) - - assert await rest_client.fetch_bans(StubModel(123)) == [ban1, ban2] - - rest_client._request.assert_awaited_once_with(expected_route) - assert rest_client._entity_factory.deserialize_guild_member_ban.call_count == 2 - rest_client._entity_factory.deserialize_guild_member_ban.assert_has_calls( - [mock.call({"id": "456"}), mock.call({"id": "789"})] - ) - async def test_fetch_roles(self, rest_client): role1 = StubModel(456) role2 = StubModel(789) @@ -4745,33 +4769,33 @@ async def test_fetch_application_commands_without_guild(self, rest_client): rest_client._request.assert_awaited_once_with(expected_route) rest_client._entity_factory.deserialize_command.assert_called_once_with({"id": "34512312"}, guild_id=None) - async def test_create_application_command_with_optionals(self, rest_client: rest.RESTClientImpl): + async def test__create_application_command_with_optionals(self, rest_client: rest.RESTClientImpl): expected_route = routes.POST_APPLICATION_GUILD_COMMAND.compile(application=4332123, guild=653452134) rest_client._request = mock.AsyncMock(return_value={"id": "29393939"}) mock_option = object() - with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=DeprecationWarning) - result = await rest_client.create_application_command( - application=StubModel(4332123), - guild=StubModel(653452134), - name="okokok", - description="not ok anymore", - options=[mock_option], - ) + result = await rest_client._create_application_command( + application=StubModel(4332123), + type=100, + name="okokok", + description="not ok anymore", + guild=StubModel(653452134), + options=[mock_option], + default_member_permissions=permissions.Permissions.ADMINISTRATOR, + dm_enabled=False, + ) - assert result is rest_client._entity_factory.deserialize_slash_command.return_value + assert result is rest_client._request.return_value rest_client._entity_factory.serialize_command_option.assert_called_once_with(mock_option) - rest_client._entity_factory.deserialize_slash_command.assert_called_once_with( - rest_client._request.return_value, guild_id=653452134 - ) rest_client._request.assert_awaited_once_with( expected_route, json={ - "type": 1, + "type": 100, "name": "okokok", "description": "not ok anymore", "options": [rest_client._entity_factory.serialize_command_option.return_value], + "default_member_permissions": 8, + "dm_permission": False, }, ) @@ -4779,47 +4803,102 @@ async def test_create_application_command_without_optionals(self, rest_client: r expected_route = routes.POST_APPLICATION_COMMAND.compile(application=4332123) rest_client._request = mock.AsyncMock(return_value={"id": "29393939"}) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=DeprecationWarning) - result = await rest_client.create_application_command( - StubModel(4332123), - name="okokok", - description="not ok anymore", - ) - - assert result is rest_client._entity_factory.deserialize_slash_command.return_value - rest_client._entity_factory.deserialize_slash_command.assert_called_once_with( - rest_client._request.return_value, guild_id=None + result = await rest_client._create_application_command( + application=StubModel(4332123), type=100, name="okokok", description="not ok anymore" ) + + assert result is rest_client._request.return_value rest_client._request.assert_awaited_once_with( - expected_route, json={"type": 1, "name": "okokok", "description": "not ok anymore"} + expected_route, + json={ + "type": 100, + "name": "okokok", + "description": "not ok anymore", + }, ) - async def test_create_slash_command(self, rest_client: rest.RESTClientImpl): + async def test__create_application_command_standardizes_default_member_permissions( + self, rest_client: rest.RESTClientImpl + ): expected_route = routes.POST_APPLICATION_COMMAND.compile(application=4332123) rest_client._request = mock.AsyncMock(return_value={"id": "29393939"}) - result = await rest_client.create_slash_command(StubModel(4332123), "okokok", "not ok anymore") + result = await rest_client._create_application_command( + application=StubModel(4332123), + type=100, + name="okokok", + description="not ok anymore", + default_member_permissions=permissions.Permissions.NONE, + ) + + assert result is rest_client._request.return_value + rest_client._request.assert_awaited_once_with( + expected_route, + json={ + "type": 100, + "name": "okokok", + "description": "not ok anymore", + "default_member_permissions": None, + }, + ) + + async def test_create_slash_command(self, rest_client: rest.RESTClientImpl): + rest_client._create_application_command = mock.AsyncMock() + mock_options = object() + mock_application = StubModel(4332123) + mock_guild = StubModel(123123123) + + result = await rest_client.create_slash_command( + mock_application, + "okokok", + "not ok anymore", + guild=mock_guild, + options=mock_options, + default_member_permissions=permissions.Permissions.ADMINISTRATOR, + dm_enabled=False, + ) assert result is rest_client._entity_factory.deserialize_slash_command.return_value rest_client._entity_factory.deserialize_slash_command.assert_called_once_with( - rest_client._request.return_value, guild_id=None + rest_client._create_application_command.return_value, guild_id=123123123 ) - rest_client._request.assert_awaited_once_with( - expected_route, json={"type": 1, "name": "okokok", "description": "not ok anymore"} + rest_client._create_application_command.assert_awaited_once_with( + application=mock_application, + type=commands.CommandType.SLASH, + name="okokok", + description="not ok anymore", + guild=mock_guild, + options=mock_options, + default_member_permissions=permissions.Permissions.ADMINISTRATOR, + dm_enabled=False, ) async def test_create_context_menu_command(self, rest_client: rest.RESTClientImpl): - expected_route = routes.POST_APPLICATION_COMMAND.compile(application=4332123) - rest_client._request = mock.AsyncMock(return_value={"id": "29393939"}) + rest_client._create_application_command = mock.AsyncMock() + mock_application = StubModel(4332123) + mock_guild = StubModel(123123123) - result = await rest_client.create_context_menu_command(StubModel(4332123), 2, "okokok") + result = await rest_client.create_context_menu_command( + mock_application, + commands.CommandType.USER, + "okokok", + guild=mock_guild, + default_member_permissions=permissions.Permissions.ADMINISTRATOR, + dm_enabled=False, + ) assert result is rest_client._entity_factory.deserialize_context_menu_command.return_value rest_client._entity_factory.deserialize_context_menu_command.assert_called_once_with( - rest_client._request.return_value, guild_id=None + rest_client._create_application_command.return_value, guild_id=123123123 + ) + rest_client._create_application_command.assert_awaited_once_with( + application=mock_application, + type=commands.CommandType.USER, + name="okokok", + guild=mock_guild, + default_member_permissions=permissions.Permissions.ADMINISTRATOR, + dm_enabled=False, ) - rest_client._request.assert_awaited_once_with(expected_route, json={"type": 2, "name": "okokok"}) async def test_set_application_commands_with_guild(self, rest_client): expected_route = routes.PUT_APPLICATION_GUILD_COMMANDS.compile(application=4321231, guild=6543234) @@ -4861,6 +4940,8 @@ async def test_edit_application_command_with_optionals(self, rest_client): name="ok sis", description="cancelled", options=[mock_option], + default_member_permissions=permissions.Permissions.BAN_MEMBERS, + dm_enabled=True, ) assert result is rest_client._entity_factory.deserialize_command.return_value @@ -4873,6 +4954,8 @@ async def test_edit_application_command_with_optionals(self, rest_client): "name": "ok sis", "description": "cancelled", "options": [rest_client._entity_factory.serialize_command_option.return_value], + "default_member_permissions": 4, + "dm_permission": True, }, ) rest_client._entity_factory.serialize_command_option.assert_called_once_with(mock_option) @@ -4892,6 +4975,27 @@ async def test_edit_application_command_without_optionals(self, rest_client): ) rest_client._request.assert_awaited_once_with(expected_route, json={}) + async def test_edit_application_command_standardizes_default_member_permissions( + self, rest_client: rest.RESTClientImpl + ): + expected_route = routes.PATCH_APPLICATION_COMMAND.compile(application=1235432, command=3451231) + rest_client._request = mock.AsyncMock(return_value={"id": "94594994"}) + + result = await rest_client.edit_application_command( + StubModel(1235432), + StubModel(3451231), + default_member_permissions=permissions.Permissions.NONE, + ) + + assert result is rest_client._entity_factory.deserialize_command.return_value + rest_client._entity_factory.deserialize_command.assert_called_once_with( + rest_client._request.return_value, guild_id=None + ) + rest_client._request.assert_awaited_once_with( + expected_route, + json={"default_member_permissions": None}, + ) + async def test_delete_application_command_with_guild(self, rest_client): expected_route = routes.DELETE_APPLICATION_GUILD_COMMAND.compile( application=312312, command=65234323, guild=5421312 @@ -4934,29 +5038,6 @@ async def test_fetch_application_command_permissions(self, rest_client): rest_client._entity_factory.deserialize_guild_command_permissions.assert_called_once_with(mock_command_payload) rest_client._request.assert_awaited_once_with(expected_route) - async def test_set_application_guild_commands_permissions(self, rest_client): - expected_route = routes.PUT_APPLICATION_GUILD_COMMANDS_PERMISSIONS.compile(application=321123, guild=542123) - mock_command_payload = object() - mock_permission = object() - rest_client._request = mock.AsyncMock(return_value=[mock_command_payload]) - - result = await rest_client.set_application_guild_commands_permissions( - 321123, 542123, {564123123: [mock_permission]} - ) - - assert result == [rest_client._entity_factory.deserialize_guild_command_permissions.return_value] - rest_client._entity_factory.serialize_command_permission.assert_called_once_with(mock_permission) - rest_client._entity_factory.deserialize_guild_command_permissions.assert_called_once_with(mock_command_payload) - rest_client._request.assert_awaited_once_with( - expected_route, - json=[ - { - "id": "564123123", - "permissions": [rest_client._entity_factory.serialize_command_permission.return_value], - } - ], - ) - async def test_set_application_command_permissions(self, rest_client): route = routes.PUT_APPLICATION_COMMAND_PERMISSIONS.compile(application=2321, guild=431, command=666666) mock_permission = object() diff --git a/tests/hikari/impl/test_shard.py b/tests/hikari/impl/test_shard.py index 18d3f1fcc8..b964581d26 100644 --- a/tests/hikari/impl/test_shard.py +++ b/tests/hikari/impl/test_shard.py @@ -1115,9 +1115,9 @@ async def test__identify(self, client): "compress": False, "large_threshold": 123, "properties": { - "$os": "Potato PC ARM64", - "$browser": "hikari (v1.0.0, aiohttp v0.0.1)", - "$device": "hikari v1.0.0", + "os": "Potato PC ARM64", + "browser": "hikari (v1.0.0, aiohttp v0.0.1)", + "device": "hikari v1.0.0", }, "shard": [0, 1], "intents": 131071, diff --git a/tests/hikari/impl/test_special_endpoints.py b/tests/hikari/impl/test_special_endpoints.py index 9004ffe2e2..d0bd8a5288 100644 --- a/tests/hikari/impl/test_special_endpoints.py +++ b/tests/hikari/impl/test_special_endpoints.py @@ -27,6 +27,7 @@ from hikari import emojis from hikari import files from hikari import messages +from hikari import permissions from hikari import snowflakes from hikari import undefined from hikari.impl import special_endpoints @@ -165,6 +166,133 @@ async def test_aiter_when_empty_chunk(self, newest_first: bool): mock_request.assert_awaited_once_with(compiled_route=expected_route, query=query) +class TestGuildBanIterator: + @pytest.mark.asyncio() + async def test_aiter(self): + expected_route = routes.GET_GUILD_BANS.compile(guild=10000) + mock_entity_factory = mock.Mock() + mock_payload_1 = {"user": {"id": "45234"}} + mock_payload_2 = {"user": {"id": "452745"}} + mock_payload_3 = {"user": {"id": "45237656"}} + mock_payload_4 = {"user": {"id": "452345666"}} + mock_payload_5 = {"user": {"id": "4523456744"}} + mock_result_1 = mock.Mock() + mock_result_2 = mock.Mock() + mock_result_3 = mock.Mock() + mock_result_4 = mock.Mock() + mock_result_5 = mock.Mock() + mock_entity_factory.deserialize_guild_member_ban.side_effect = [ + mock_result_1, + mock_result_2, + mock_result_3, + mock_result_4, + mock_result_5, + ] + mock_request = mock.AsyncMock( + side_effect=[[mock_payload_1, mock_payload_2, mock_payload_3], [mock_payload_4, mock_payload_5], []] + ) + iterator = special_endpoints.GuildBanIterator( + entity_factory=mock_entity_factory, + request_call=mock_request, + guild=10000, + newest_first=False, + first_id="0", + ) + + result = await iterator + + assert result == [mock_result_1, mock_result_2, mock_result_3, mock_result_4, mock_result_5] + mock_entity_factory.deserialize_guild_member_ban.assert_has_calls( + [ + mock.call(mock_payload_1), + mock.call(mock_payload_2), + mock.call(mock_payload_3), + mock.call(mock_payload_4), + mock.call(mock_payload_5), + ] + ) + mock_request.assert_has_awaits( + [ + mock.call(compiled_route=expected_route, query={"after": "0", "limit": "1000"}), + mock.call(compiled_route=expected_route, query={"after": "45237656", "limit": "1000"}), + mock.call(compiled_route=expected_route, query={"after": "4523456744", "limit": "1000"}), + ] + ) + + @pytest.mark.asyncio() + async def test_aiter_when_newest_first(self): + expected_route = routes.GET_GUILD_BANS.compile(guild=10000) + mock_entity_factory = mock.Mock() + mock_payload_1 = {"user": {"id": "432234"}} + mock_payload_2 = {"user": {"id": "1233211"}} + mock_payload_3 = {"user": {"id": "12332112"}} + mock_payload_4 = {"user": {"id": "1233"}} + mock_payload_5 = {"user": {"id": "54334"}} + mock_result_1 = mock.Mock() + mock_result_2 = mock.Mock() + mock_result_3 = mock.Mock() + mock_result_4 = mock.Mock() + mock_result_5 = mock.Mock() + mock_entity_factory.deserialize_guild_member_ban.side_effect = [ + mock_result_1, + mock_result_2, + mock_result_3, + mock_result_4, + mock_result_5, + ] + mock_request = mock.AsyncMock( + side_effect=[[mock_payload_1, mock_payload_2, mock_payload_3], [mock_payload_4, mock_payload_5], []] + ) + iterator = special_endpoints.GuildBanIterator( + entity_factory=mock_entity_factory, + request_call=mock_request, + guild=10000, + newest_first=True, + first_id="321123321", + ) + + result = await iterator + + assert result == [mock_result_1, mock_result_2, mock_result_3, mock_result_4, mock_result_5] + mock_entity_factory.deserialize_guild_member_ban.assert_has_calls( + [ + mock.call(mock_payload_3), + mock.call(mock_payload_2), + mock.call(mock_payload_1), + mock.call(mock_payload_5), + mock.call(mock_payload_4), + ] + ) + mock_request.assert_has_awaits( + [ + mock.call(compiled_route=expected_route, query={"before": "321123321", "limit": "1000"}), + mock.call(compiled_route=expected_route, query={"before": "432234", "limit": "1000"}), + mock.call(compiled_route=expected_route, query={"before": "1233", "limit": "1000"}), + ] + ) + + @pytest.mark.parametrize("newest_first", [True, False]) + @pytest.mark.asyncio() + async def test_aiter_when_empty_chunk(self, newest_first: bool): + expected_route = routes.GET_GUILD_BANS.compile(guild=10000) + mock_entity_factory = mock.Mock() + mock_request = mock.AsyncMock(return_value=[]) + iterator = special_endpoints.GuildBanIterator( + entity_factory=mock_entity_factory, + request_call=mock_request, + guild=10000, + newest_first=newest_first, + first_id="54234123123", + ) + + result = await iterator + + assert result == [] + mock_entity_factory.deserialize_guild_member_ban.assert_not_called() + query = {"before" if newest_first else "after": "54234123123", "limit": "1000"} + mock_request.assert_awaited_once_with(compiled_route=expected_route, query=query) + + class TestScheduledEventUserIterator: @pytest.mark.asyncio() async def test_aiter(self): @@ -564,10 +692,17 @@ def test_id_property(self): assert builder.id == 3212123 - def test_default_permission(self): - builder = special_endpoints.SlashCommandBuilder("oksksksk", "kfdkodfokfd").set_default_permission(True) + def test_default_member_permissions(self): + builder = special_endpoints.SlashCommandBuilder("oksksksk", "kfdkodfokfd").set_default_member_permissions( + permissions.Permissions.ADMINISTRATOR + ) - assert builder.default_permission is True + assert builder.default_member_permissions == permissions.Permissions.ADMINISTRATOR + + def test_is_dm_enabled(self): + builder = special_endpoints.SlashCommandBuilder("oksksksk", "kfdkodfokfd").set_is_dm_enabled(True) + + assert builder.is_dm_enabled is True def test_build_with_optional_data(self): mock_entity_factory = mock.Mock() @@ -576,7 +711,8 @@ def test_build_with_optional_data(self): special_endpoints.SlashCommandBuilder("we are number", "one") .add_option(mock_option) .set_id(3412312) - .set_default_permission(False) + .set_default_member_permissions(permissions.Permissions.ADMINISTRATOR) + .set_is_dm_enabled(True) ) result = builder.build(mock_entity_factory) @@ -586,7 +722,8 @@ def test_build_with_optional_data(self): "name": "we are number", "description": "one", "type": 1, - "default_permission": False, + "dm_permission": True, + "default_member_permissions": 8, "options": [mock_entity_factory.serialize_command_option.return_value], "id": "3412312", } @@ -603,8 +740,8 @@ async def test_create(self): builder = ( special_endpoints.SlashCommandBuilder("we are number", "one") .add_option(mock.Mock()) - .set_id(3412312) - .set_default_permission(False) + .set_default_member_permissions(permissions.Permissions.BAN_MEMBERS) + .set_is_dm_enabled(True) ) mock_rest = mock.AsyncMock() @@ -616,13 +753,18 @@ async def test_create(self): builder.name, builder.description, guild=undefined.UNDEFINED, - default_permission=builder.default_permission, options=builder.options, + default_member_permissions=permissions.Permissions.BAN_MEMBERS, + dm_enabled=True, ) @pytest.mark.asyncio() async def test_create_with_guild(self): - builder = special_endpoints.SlashCommandBuilder("we are number", "one") + builder = ( + special_endpoints.SlashCommandBuilder("we are number", "one") + .set_default_member_permissions(permissions.Permissions.BAN_MEMBERS) + .set_is_dm_enabled(True) + ) mock_rest = mock.AsyncMock() result = await builder.create(mock_rest, 54455445, guild=54123123321) @@ -633,8 +775,9 @@ async def test_create_with_guild(self): builder.name, builder.description, guild=54123123321, - default_permission=builder.default_permission, options=builder.options, + default_member_permissions=permissions.Permissions.BAN_MEMBERS, + dm_enabled=True, ) @@ -643,7 +786,8 @@ def test_build_with_optional_data(self): builder = ( special_endpoints.ContextMenuCommandBuilder(commands.CommandType.USER, "we are number") .set_id(3412312) - .set_default_permission(False) + .set_default_member_permissions(permissions.Permissions.ADMINISTRATOR) + .set_is_dm_enabled(True) ) result = builder.build(mock.Mock()) @@ -651,7 +795,8 @@ def test_build_with_optional_data(self): assert result == { "name": "we are number", "type": 2, - "default_permission": False, + "dm_permission": True, + "default_member_permissions": 8, "id": "3412312", } @@ -666,8 +811,8 @@ def test_build_without_optional_data(self): async def test_create(self): builder = ( special_endpoints.ContextMenuCommandBuilder(commands.CommandType.USER, "we are number") - .set_id(3412312) - .set_default_permission(False) + .set_default_member_permissions(permissions.Permissions.BAN_MEMBERS) + .set_is_dm_enabled(True) ) mock_rest = mock.AsyncMock() @@ -679,12 +824,17 @@ async def test_create(self): builder.type, builder.name, guild=undefined.UNDEFINED, - default_permission=builder.default_permission, + default_member_permissions=permissions.Permissions.BAN_MEMBERS, + dm_enabled=True, ) @pytest.mark.asyncio() async def test_create_with_guild(self): - builder = special_endpoints.ContextMenuCommandBuilder(commands.CommandType.MESSAGE, "we are number") + builder = ( + special_endpoints.ContextMenuCommandBuilder(commands.CommandType.USER, "we are number") + .set_default_member_permissions(permissions.Permissions.BAN_MEMBERS) + .set_is_dm_enabled(True) + ) mock_rest = mock.AsyncMock() result = await builder.create(mock_rest, 4444444, guild=765234123) @@ -695,7 +845,8 @@ async def test_create_with_guild(self): builder.type, builder.name, guild=765234123, - default_permission=builder.default_permission, + default_member_permissions=permissions.Permissions.BAN_MEMBERS, + dm_enabled=True, ) diff --git a/tests/hikari/interactions/test_command_interactions.py b/tests/hikari/interactions/test_command_interactions.py index 0cd383da77..a579c3e169 100644 --- a/tests/hikari/interactions/test_command_interactions.py +++ b/tests/hikari/interactions/test_command_interactions.py @@ -56,6 +56,7 @@ def mock_command_interaction(self, mock_app): resolved=None, locale="es-ES", guild_locale="en-US", + app_permissions=543123, ) def test_build_response(self, mock_command_interaction, mock_app): @@ -113,7 +114,6 @@ def mock_autocomplete_interaction(self, mock_app): command_name="OKOKOK", command_type=1, options=[], - resolved=None, ) @pytest.fixture() diff --git a/tests/hikari/interactions/test_component_interactions.py b/tests/hikari/interactions/test_component_interactions.py index d96358c62d..19b2ac4e37 100644 --- a/tests/hikari/interactions/test_component_interactions.py +++ b/tests/hikari/interactions/test_component_interactions.py @@ -55,6 +55,7 @@ def mock_component_interaction(self, mock_app): message=object(), locale="es-ES", guild_locale="en-US", + app_permissions=123321, ) def test_build_response(self, mock_component_interaction, mock_app): diff --git a/tests/hikari/interactions/test_modal_interactions.py b/tests/hikari/interactions/test_modal_interactions.py index 7dc9fd3eb5..5adf37dc12 100644 --- a/tests/hikari/interactions/test_modal_interactions.py +++ b/tests/hikari/interactions/test_modal_interactions.py @@ -54,6 +54,7 @@ def mock_modal_interaction(self, mock_app): message=object(), locale="es-ES", guild_locale="en-US", + app_permissions=543123, components=special_endpoints.ActionRowBuilder( components=[ modal_interactions.InteractionTextInput( diff --git a/tests/hikari/internal/test_aio.py b/tests/hikari/internal/test_aio.py index bdf532fb68..d50317470c 100644 --- a/tests/hikari/internal/test_aio.py +++ b/tests/hikari/internal/test_aio.py @@ -44,7 +44,7 @@ def __await__(self): def __repr__(self): args = ", ".join(map(repr, self.args)) - kwargs = ", ".join(map(lambda k, v: f"{k!s}={v!r}", self.kwargs.items())) + kwargs = ", ".join(f"{key!s}={value!r}" for key, value in self.kwargs.items()) return f"({args}, {kwargs})" diff --git a/tests/hikari/internal/test_collections.py b/tests/hikari/internal/test_collections.py index a6b9802609..4b4cdac12a 100644 --- a/tests/hikari/internal/test_collections.py +++ b/tests/hikari/internal/test_collections.py @@ -80,35 +80,6 @@ def test___setitem__(self): assert mock_map == {"hmm": "forearm", "cat": "bag", "ok": "bye", "bye": 4} -class TestFrozenDict: - def test___init__(self): - mock_map = collections._FrozenDict({"foo": (0.432, "bar"), "blam": (0.111, "okok")}) - assert mock_map == {"foo": "bar", "blam": "okok"} - - def test___getitem__(self): - mock_map = collections._FrozenDict({"blam": (0.432, "bar"), "obar": (0.111, "okok")}) - assert mock_map["obar"] == "okok" - - def test___iter__(self): - mock_map = collections._FrozenDict({"bye": (0.33, "bye"), "111": (0.2, "222"), "45949": (0.5, "020202")}) - assert list(mock_map) == ["bye", "111", "45949"] - - def test___len__(self): - mock_map = collections._FrozenDict({"wsw": (0.3, "3"), "fdsa": (0.55, "ewqwe"), "45949": (0.23, "fsasd")}) - assert len(mock_map) == 3 - - def test___delitem__(self): - mock_map = collections._FrozenDict({"rororo": (0.55, "bye bye"), "raw": (0.999, "ywywyw")}) - del mock_map["raw"] - assert mock_map == {"rororo": "bye bye"} - - def test___setitem__(self): - mock_map = collections._FrozenDict({"rororo": (0.55, "bye 3231"), "2121": (0.999, "4321")}) - mock_map["foo bar"] = 42 - - assert mock_map == {"rororo": "bye 3231", "2121": "4321", "foo bar": 42} - - class TestLimitedCapacityCacheMap: def test___init___with_source(self): raw_map = {"voo": "doo", "blam": "blast", "foo": "bye"} @@ -380,8 +351,3 @@ def test_get_index_or_slice_with_index_outside_range(): def test_get_index_or_slice_with_slice(): test_map = {"o": "b", "b": "o", "a": "m", "arara": "blam", "oof": "no", "rika": "may"} assert collections.get_index_or_slice(test_map, slice(1, 5, 2)) == ("o", "blam") - - -def test_get_index_or_slice_with_invalid_type(): - with pytest.raises(TypeError): - collections.get_index_or_slice({}, object()) diff --git a/tests/hikari/internal/test_net.py b/tests/hikari/internal/test_net.py index 4d7312fb8d..62891653b6 100644 --- a/tests/hikari/internal/test_net.py +++ b/tests/hikari/internal/test_net.py @@ -70,6 +70,52 @@ async def json(self): assert returned is error() +@pytest.mark.parametrize( + ("status_", "expected_error"), + [ + # The following internal server non-conforming status codes are used by cloudflare. + # Source I made it up... + # jk https://en.wikipedia.org/wiki/List_of_HTTP_status_codes + (520, "InternalServerError"), + (521, "InternalServerError"), + (522, "InternalServerError"), + (523, "InternalServerError"), + (524, "InternalServerError"), + (525, "InternalServerError"), + (526, "InternalServerError"), + (527, "InternalServerError"), + (530, "InternalServerError"), + # These non-conforming bad requests status codes are sent by NGINX. + # Same source as cloudflare status codes. + (494, "ClientHTTPResponseError"), + (495, "ClientHTTPResponseError"), + (496, "ClientHTTPResponseError"), + (497, "ClientHTTPResponseError"), + # This non-conforming status code is made up. + (694, "HTTPResponseError"), + ], +) +@pytest.mark.asyncio() +async def test_generate_error_response_with_non_conforming_status_code(status_, expected_error): + class StubResponse: + real_url = "https://some.url" + status = status_ + headers = {} + + async def read(self): + return "some raw body" + + async def json(self): + return {"message": "raw message", "code": 123} + + with mock.patch.object(errors, expected_error) as error: + returned = await net.generate_error_response(StubResponse()) + + error.assert_called_once_with("https://some.url", status_, {}, "some raw body") + + assert returned is error() + + @pytest.mark.parametrize( ("status_", "expected_error"), [ diff --git a/tests/hikari/internal/test_ux.py b/tests/hikari/internal/test_ux.py index d5c6b96fc0..fbbcb56397 100644 --- a/tests/hikari/internal/test_ux.py +++ b/tests/hikari/internal/test_ux.py @@ -19,6 +19,7 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import builtins import contextlib import importlib import logging @@ -180,9 +181,10 @@ def test_when_supports_color(self, mock_args): supports_color = stack.enter_context(mock.patch.object(ux, "supports_color", return_value=True)) read_text = stack.enter_context(mock.patch.object(importlib.resources, "read_text")) template = stack.enter_context(mock.patch.object(string, "Template")) - write = stack.enter_context(mock.patch.object(sys.stdout, "write")) + builtins_open = stack.enter_context(mock.patch.object(builtins, "open")) abspath = stack.enter_context(mock.patch.object(os.path, "abspath", return_value="some path")) dirname = stack.enter_context(mock.patch.object(os.path, "dirname")) + fileno = stack.enter_context(mock.patch.object(sys.stdout, "fileno")) with stack: ux.print_banner("hikari", True, False) @@ -207,7 +209,8 @@ def test_when_supports_color(self, mock_args): template.assert_called_once_with(read_text()) template().safe_substitute.assert_called_once_with(args) - write.assert_called_once_with(template().safe_substitute()) + builtins_open.assert_called_once_with(fileno.return_value, "w", encoding="utf-8", closefd=False) + builtins_open.return_value.__enter__.return_value.write.assert_called_once_with(template().safe_substitute()) dirname.assert_called_once_with("~/hikari") abspath.assert_called_once_with(dirname()) supports_color.assert_called_once_with(True, False) @@ -221,9 +224,10 @@ def test_when_doesnt_supports_color(self, mock_args): supports_color = stack.enter_context(mock.patch.object(ux, "supports_color", return_value=False)) read_text = stack.enter_context(mock.patch.object(importlib.resources, "read_text")) template = stack.enter_context(mock.patch.object(string, "Template")) - write = stack.enter_context(mock.patch.object(sys.stdout, "write")) abspath = stack.enter_context(mock.patch.object(os.path, "abspath", return_value="some path")) dirname = stack.enter_context(mock.patch.object(os.path, "dirname")) + builtins_open = stack.enter_context(mock.patch.object(builtins, "open")) + fileno = stack.enter_context(mock.patch.object(sys.stdout, "fileno")) with stack: ux.print_banner("hikari", True, False) @@ -248,10 +252,11 @@ def test_when_doesnt_supports_color(self, mock_args): template.assert_called_once_with(read_text()) template().safe_substitute.assert_called_once_with(args) - write.assert_called_once_with(template().safe_substitute()) dirname.assert_called_once_with("~/hikari") abspath.assert_called_once_with(dirname()) supports_color.assert_called_once_with(True, False) + builtins_open.assert_called_once_with(fileno.return_value, "w", encoding="utf-8", closefd=False) + builtins_open.return_value.__enter__.return_value.write.assert_called_once_with(template().safe_substitute()) def test_use_extra_args(self, mock_args): stack = contextlib.ExitStack() @@ -259,8 +264,9 @@ def test_use_extra_args(self, mock_args): stack.enter_context(mock.patch.object(time, "sleep")) read_text = stack.enter_context(mock.patch.object(importlib.resources, "read_text")) template = stack.enter_context(mock.patch.object(string, "Template")) - write = stack.enter_context(mock.patch.object(sys.stdout, "write")) + builtins_open = stack.enter_context(mock.patch.object(builtins, "open")) stack.enter_context(mock.patch.object(os.path, "abspath", return_value="some path")) + fileno = stack.enter_context(mock.patch.object(sys.stdout, "fileno")) extra_args = { "extra_argument_1": "one", @@ -289,7 +295,8 @@ def test_use_extra_args(self, mock_args): template.assert_called_once_with(read_text()) template().safe_substitute.assert_called_once_with(args) - write.assert_called_once_with(template().safe_substitute()) + builtins_open.assert_called_once_with(fileno.return_value, "w", encoding="utf-8", closefd=False) + builtins_open.return_value.__enter__.return_value.write.assert_called_once_with(template().safe_substitute()) def test_overwrite_args_raises_error(self, mock_args): stack = contextlib.ExitStack() diff --git a/tests/hikari/test_audit_logs.py b/tests/hikari/test_audit_logs.py index b76c44e784..b14ae12604 100644 --- a/tests/hikari/test_audit_logs.py +++ b/tests/hikari/test_audit_logs.py @@ -158,15 +158,6 @@ def test_get_item_with_slice(self): ) assert audit_log[1:5:2] == (entry_1, entry_2) - def test_get_item_with_ivalid_type(self): - with pytest.raises(TypeError): - audit_logs.AuditLog( - entries=[object(), object()], - integrations={}, - users={}, - webhooks={}, - )["OK"] - def test_len(self): audit_log = audit_logs.AuditLog( entries={ diff --git a/tests/hikari/test_colors.py b/tests/hikari/test_colors.py index 2a99976cb5..1f178fb5e9 100644 --- a/tests/hikari/test_colors.py +++ b/tests/hikari/test_colors.py @@ -276,7 +276,7 @@ def test_Color_to_bytes(self): [ (colors.Color(0xFF051A), colors.Color(0xFF051A)), (0xFF051A, colors.Color(0xFF051A)), - ((1, 0.5, 0), colors.Color(0xFF7F00)), + ((1.0, 0.5, 0.0), colors.Color(0xFF7F00)), ([0xFF, 0x5, 0x1A], colors.Color(0xFF051A)), ("#1a2b3c", colors.Color(0x1A2B3C)), ("#123", colors.Color(0x112233)), @@ -305,9 +305,9 @@ def test_Color_of_happy_path(self, input, expected_result): (NotImplemented, r"Could not transform NotImplemented into a Color object"), ((1, 1, 1, 1), r"Color must be an RGB triplet if set to a tuple type"), ((1, "a", 1), r"Could not transform \(1, 'a', 1\) into a Color object"), - ((1.1, 1, 1), r"Expected red channel to be in the inclusive range of 0.0 and 1.0"), - ((1, 1.1, 1), r"Expected green channel to be in the inclusive range of 0.0 and 1.0"), - ((1, 1, 1.1), r"Expected blue channel to be in the inclusive range of 0.0 and 1.0"), + ((1.1, 1.0, 1.0), r"Expected red channel to be in the inclusive range of 0.0 and 1.0"), + ((1.0, 1.1, 1.0), r"Expected green channel to be in the inclusive range of 0.0 and 1.0"), + ((1.0, 1.0, 1.1), r"Expected blue channel to be in the inclusive range of 0.0 and 1.0"), ((), r"Color must be an RGB triplet if set to a tuple type"), ({}, r"Could not transform \{\} into a Color object"), ([], r"Color must be an RGB triplet if set to a list type"), diff --git a/tests/hikari/test_commands.py b/tests/hikari/test_commands.py index 19b4861d60..76117ec999 100644 --- a/tests/hikari/test_commands.py +++ b/tests/hikari/test_commands.py @@ -43,7 +43,8 @@ def mock_command(self, mock_app): type=commands.CommandType.SLASH, application_id=snowflakes.Snowflake(65234123), name="Name", - default_permission=False, + default_member_permissions=None, + is_dm_enabled=False, guild_id=snowflakes.Snowflake(31231235), version=snowflakes.Snowflake(43123123), ) diff --git a/tests/hikari/test_errors.py b/tests/hikari/test_errors.py index 4d6d955839..9b71a1098f 100644 --- a/tests/hikari/test_errors.py +++ b/tests/hikari/test_errors.py @@ -87,6 +87,10 @@ def error(self): def test_str(self, error): assert str(error) == "Bad Request 400: (12345) 'message' for https://some.url" + def test_str_when_int_status_code(self, error): + error.status = 699 + assert str(error) == "Unknown Status 699: (12345) 'message' for https://some.url" + def test_str_when_message_is_None(self, error): error.message = None assert str(error) == "Bad Request 400: (12345) 'raw body' for https://some.url" diff --git a/tests/hikari/test_messages.py b/tests/hikari/test_messages.py index 1e95ddfeca..afd5a561ae 100644 --- a/tests/hikari/test_messages.py +++ b/tests/hikari/test_messages.py @@ -127,13 +127,11 @@ def message(): timestamp=datetime.datetime.now().astimezone(), edited_timestamp=None, is_tts=False, - mentions=messages.Mentions( - message=mock.Mock(), - users={}, - role_ids=[], - channels={}, - everyone=False, - ), + mentions=messages.Mentions(message=mock.Mock()), + user_mentions={}, + role_mention_ids=[], + channel_mentions={}, + mentions_everyone=False, attachments=(), embeds=(), reactions=(), diff --git a/tests/hikari/test_snowflake.py b/tests/hikari/test_snowflake.py index c5ed3f6552..c10c72d3f9 100644 --- a/tests/hikari/test_snowflake.py +++ b/tests/hikari/test_snowflake.py @@ -88,14 +88,10 @@ def test_from_datetime(self): assert isinstance(result, snowflakes.Snowflake) def test_min(self): - sf = snowflakes.Snowflake.min() - assert sf == 0 - assert snowflakes.Snowflake.min() is sf + assert snowflakes.Snowflake.min() == 0 def test_max(self): - sf = snowflakes.Snowflake.max() - assert sf == (1 << 63) - 1 - assert snowflakes.Snowflake.max() is sf + assert snowflakes.Snowflake.max() == (1 << 63) - 1 class TestUnique: