From 1124aadcf70c99259171b9f360a28897e090c9ce Mon Sep 17 00:00:00 2001 From: Patchwork Collective <226386-aster.codes@users.noreply.gitlab.com> Date: Sun, 17 Oct 2021 10:56:06 -0400 Subject: [PATCH 01/10] Added updated with_any_role_check. --- tanjun/checks.py | 70 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/tanjun/checks.py b/tanjun/checks.py index 3c3cfa92e..4765da370 100644 --- a/tanjun/checks.py +++ b/tanjun/checks.py @@ -509,6 +509,42 @@ async def __call__( return self._handle_result((permissions & self._permissions) == self._permissions) +class HasAnyRoleCheck(_Check): + __slots__ = ( + "_halt_execution", + "_error_message", + "required_roles", + ) + + def __init__( + self, + roles: list[hikari.SnowflakeishOr[hikari.Role] | str] = list(), + *, + error_message: typing.Optional[str] = "You do not have the required roles to use this command!", + halt_execution: bool = True, + ) -> None: + super().__init__(error_message, halt_execution) + self.required_roles = roles + + async def __call__(self, ctx: tanjun_abc.Context, /) -> bool: + + if not ctx.member: + return self._handle_result(False) + + member_roles = ctx.member.get_roles() + + result = any(self.check_roles(member_role) for member_role in member_roles) + return self._handle_result(result) + + def check_roles(self, member_role: hikari.Role) -> bool: + for check in self.required_roles: + if isinstance(check, int) and member_role.id == check: + return True + elif isinstance(check, str) and member_role.name == check: + return True + return False + + @typing.overload def with_dm_check(command: CommandT, /) -> CommandT: ... @@ -862,6 +898,40 @@ def with_own_permission_check( ) +def with_any_role_check( + roles: list[hikari.SnowflakeishOr[hikari.Role] | int | str] = list(), + *, + error_message: typing.Optional[str] = "You do not have the required roles to use this command!", + halt_execution: bool = True, +) -> collections.Callable[[CommandT], CommandT]: + """Only let a command run if the author has a specific role. + + Parameters + ---------- + roles: list[Union[SnowflakeishOr[Role], int]] + The author must have at least one (1) role in this list. (Role.name and Role.id are checked) + + Other Parameters + ---------------- + error_message: Optional[str] + The error message raised if the member does not have a required role. + + Defaults to 'You do not have the required roles to use this command!' + halt_execution: bool + Whether this check should raise `tanjun.errors.HaltExecution` to + end the execution search when it fails instead of returning `False`. + + Defaults to `False`. + + Returns + ------- + collections.abc.Callable[[CommandT], CommandT] + A command decorator callback which adds the check.""" + return lambda command: command.add_check( + HasAnyRoleCheck(roles, error_message=error_message, halt_execution=halt_execution) + ) + + def with_check(check: tanjun_abc.CheckSig, /) -> collections.Callable[[CommandT], CommandT]: """Add a generic check to a command. From 15bdc0f99c9b7a25d9370f765b5f5863820abf87 Mon Sep 17 00:00:00 2001 From: Patchwork Collective <226386-aster.codes@users.noreply.gitlab.com> Date: Sun, 17 Oct 2021 11:00:32 -0400 Subject: [PATCH 02/10] Nox cleanups. --- tanjun/checks.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tanjun/checks.py b/tanjun/checks.py index 4765da370..45b91b3e7 100644 --- a/tanjun/checks.py +++ b/tanjun/checks.py @@ -518,7 +518,7 @@ class HasAnyRoleCheck(_Check): def __init__( self, - roles: list[hikari.SnowflakeishOr[hikari.Role] | str] = list(), + roles: list[hikari.SnowflakeishOr[hikari.Role] | str] = [], *, error_message: typing.Optional[str] = "You do not have the required roles to use this command!", halt_execution: bool = True, @@ -899,7 +899,7 @@ def with_own_permission_check( def with_any_role_check( - roles: list[hikari.SnowflakeishOr[hikari.Role] | int | str] = list(), + roles: list[hikari.SnowflakeishOr[hikari.Role] | int | str] = [], *, error_message: typing.Optional[str] = "You do not have the required roles to use this command!", halt_execution: bool = True, @@ -926,7 +926,8 @@ def with_any_role_check( Returns ------- collections.abc.Callable[[CommandT], CommandT] - A command decorator callback which adds the check.""" + A command decorator callback which adds the check. + """ return lambda command: command.add_check( HasAnyRoleCheck(roles, error_message=error_message, halt_execution=halt_execution) ) From c3741291dd03f5e683f1221519950fda9e46152c Mon Sep 17 00:00:00 2001 From: Patchwork Collective <226386-aster.codes@users.noreply.gitlab.com> Date: Sun, 17 Oct 2021 11:23:05 -0400 Subject: [PATCH 03/10] Simple test for with_any_role_check. --- tests/test_checks.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/test_checks.py b/tests/test_checks.py index 9ba2efb65..d33dc38b8 100644 --- a/tests/test_checks.py +++ b/tests/test_checks.py @@ -558,6 +558,29 @@ def test_with_own_permission_check(command: mock.Mock): own_permission_check.assert_called_once_with(5412312, halt_execution=True, error_message="hi") +def test_with_has_any_role_check(command: mock.Mock): + with mock.patch.object(tanjun.checks, "HasAnyRoleCheck") as any_role_check: + assert ( + tanjun.checks.with_any_role_check( + [ + "Admin", + ], + halt_execution=True, + error_message="hi", + )(command) + is command + ) + + command.add_check.assert_called_once_with(any_role_check.return_value) + any_role_check.assert_called_once_with( + [ + "Admin", + ], + halt_execution=True, + error_message="hi", + ) + + def test_with_check(command: mock.Mock): mock_check = mock.Mock() From 186b465c41aa59adf3b73f826453e5eb02701c9d Mon Sep 17 00:00:00 2001 From: Patchwork Collective <226386-aster.codes@users.noreply.gitlab.com> Date: Sun, 17 Oct 2021 13:15:10 -0400 Subject: [PATCH 04/10] Removed Python 3.10 type hinting. Removed unused/leftover slots. Loosened HasAnyRoleCheck and with_any_role_check paramter requirements to a more general type. Changed HasAnyRoleCheck.required_roles to a set type for better performance. --- tanjun/checks.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/tanjun/checks.py b/tanjun/checks.py index 45b91b3e7..df338af3b 100644 --- a/tanjun/checks.py +++ b/tanjun/checks.py @@ -510,31 +510,28 @@ async def __call__( class HasAnyRoleCheck(_Check): - __slots__ = ( - "_halt_execution", - "_error_message", - "required_roles", - ) + __slots__ = ("required_roles",) def __init__( self, - roles: list[hikari.SnowflakeishOr[hikari.Role] | str] = [], + roles: collections.Sequence[typing.Union[hikari.SnowflakeishOr[hikari.Role], str]] = [], *, error_message: typing.Optional[str] = "You do not have the required roles to use this command!", halt_execution: bool = True, ) -> None: super().__init__(error_message, halt_execution) - self.required_roles = roles + self.required_roles = set(roles) async def __call__(self, ctx: tanjun_abc.Context, /) -> bool: - if not ctx.member: return self._handle_result(False) member_roles = ctx.member.get_roles() - result = any(self.check_roles(member_role) for member_role in member_roles) - return self._handle_result(result) + for member_role in member_roles: + if result := self.check_roles(member_role): + return self._handle_result(result) + return self._handle_result(False) def check_roles(self, member_role: hikari.Role) -> bool: for check in self.required_roles: @@ -899,7 +896,7 @@ def with_own_permission_check( def with_any_role_check( - roles: list[hikari.SnowflakeishOr[hikari.Role] | int | str] = [], + roles: collections.Sequence[typing.Union[hikari.SnowflakeishOr[hikari.Role], int, str]] = [], *, error_message: typing.Optional[str] = "You do not have the required roles to use this command!", halt_execution: bool = True, From 8b3998c84ac5b5819bacbaa59910467be2512eca Mon Sep 17 00:00:00 2001 From: Patchwork Collective <226386-aster.codes@users.noreply.gitlab.com> Date: Sun, 17 Oct 2021 13:58:49 -0400 Subject: [PATCH 05/10] Update docs to reflect new typing. Simplified hanlder logic. Reverted internal roles typing from set to list. Simplified role checking logic. --- tanjun/checks.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tanjun/checks.py b/tanjun/checks.py index df338af3b..c9964025d 100644 --- a/tanjun/checks.py +++ b/tanjun/checks.py @@ -520,7 +520,7 @@ def __init__( halt_execution: bool = True, ) -> None: super().__init__(error_message, halt_execution) - self.required_roles = set(roles) + self.required_roles = roles async def __call__(self, ctx: tanjun_abc.Context, /) -> bool: if not ctx.member: @@ -529,15 +529,13 @@ async def __call__(self, ctx: tanjun_abc.Context, /) -> bool: member_roles = ctx.member.get_roles() for member_role in member_roles: - if result := self.check_roles(member_role): - return self._handle_result(result) + if self.check_roles(member_role): + return self._handle_result(True) return self._handle_result(False) def check_roles(self, member_role: hikari.Role) -> bool: for check in self.required_roles: - if isinstance(check, int) and member_role.id == check: - return True - elif isinstance(check, str) and member_role.name == check: + if member_role.id == check or member_role.name == check: return True return False @@ -899,13 +897,13 @@ def with_any_role_check( roles: collections.Sequence[typing.Union[hikari.SnowflakeishOr[hikari.Role], int, str]] = [], *, error_message: typing.Optional[str] = "You do not have the required roles to use this command!", - halt_execution: bool = True, + halt_execution: bool = False, ) -> collections.Callable[[CommandT], CommandT]: """Only let a command run if the author has a specific role. Parameters ---------- - roles: list[Union[SnowflakeishOr[Role], int]] + roles: collections.Sequence[Union[SnowflakeishOr[Role], int, str]] The author must have at least one (1) role in this list. (Role.name and Role.id are checked) Other Parameters From 305ab308386add8761b0ec33b984b8aec28a2628 Mon Sep 17 00:00:00 2001 From: Patchwork Collective <226386-aster.codes@users.noreply.gitlab.com> Date: Wed, 20 Oct 2021 18:48:47 -0400 Subject: [PATCH 06/10] Fixed __all__ in tanjun and checks.py. Updated docs to clarify check locks commands to guilds. --- tanjun/__init__.py | 1 + tanjun/checks.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tanjun/__init__.py b/tanjun/__init__.py index 96cd4726c..96ccf0944 100644 --- a/tanjun/__init__.py +++ b/tanjun/__init__.py @@ -121,6 +121,7 @@ async def main() -> None: "with_owner_check", "with_author_permission_check", "with_own_permission_check", + "with_any_role_check", # clients.py "clients", "as_loader", diff --git a/tanjun/checks.py b/tanjun/checks.py index c9964025d..88c231d7b 100644 --- a/tanjun/checks.py +++ b/tanjun/checks.py @@ -48,6 +48,7 @@ "with_owner_check", "with_author_permission_check", "with_own_permission_check", + "with_any_role_check", "DmCheck", "GuildCheck", "NsfwCheck", @@ -55,6 +56,7 @@ "OwnerCheck", "AuthorPermissionCheck", "OwnPermissionCheck", + "HasAnyRoleCheck", ] import typing @@ -899,7 +901,7 @@ def with_any_role_check( error_message: typing.Optional[str] = "You do not have the required roles to use this command!", halt_execution: bool = False, ) -> collections.Callable[[CommandT], CommandT]: - """Only let a command run if the author has a specific role. + """Only let a command run if the author has a specific role and the command is called in a guild. Parameters ---------- From c40e3b99022d3aedded3ec33dffde6f6c7da2891 Mon Sep 17 00:00:00 2001 From: Patchwork Collective <226386-aster.codes@users.noreply.gitlab.com> Date: Wed, 20 Oct 2021 18:55:25 -0400 Subject: [PATCH 07/10] Changed for-if loop to any-map. --- tanjun/checks.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tanjun/checks.py b/tanjun/checks.py index 88c231d7b..ba657666b 100644 --- a/tanjun/checks.py +++ b/tanjun/checks.py @@ -530,10 +530,7 @@ async def __call__(self, ctx: tanjun_abc.Context, /) -> bool: member_roles = ctx.member.get_roles() - for member_role in member_roles: - if self.check_roles(member_role): - return self._handle_result(True) - return self._handle_result(False) + return self._handle_result(any(map(self.check_roles, member_roles))) def check_roles(self, member_role: hikari.Role) -> bool: for check in self.required_roles: From d87f8cf51dfabca61dfab7b2b2f3b6ed90476e3e Mon Sep 17 00:00:00 2001 From: Patchwork Collective <226386-aster.codes@users.noreply.gitlab.com> Date: Wed, 20 Oct 2021 19:15:56 -0400 Subject: [PATCH 08/10] Condensed HasAnyRoles.check_roles into an any() statement. --- tanjun/checks.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tanjun/checks.py b/tanjun/checks.py index ba657666b..9d3a6902a 100644 --- a/tanjun/checks.py +++ b/tanjun/checks.py @@ -533,10 +533,7 @@ async def __call__(self, ctx: tanjun_abc.Context, /) -> bool: return self._handle_result(any(map(self.check_roles, member_roles))) def check_roles(self, member_role: hikari.Role) -> bool: - for check in self.required_roles: - if member_role.id == check or member_role.name == check: - return True - return False + return any(member_role.id == check or member_role.name == check for check in self.required_roles) @typing.overload From 9bc580d4075468c026cd430fb0762e7090ca1399 Mon Sep 17 00:00:00 2001 From: Patchwork Collective <226386-aster.codes@users.noreply.gitlab.com> Date: Mon, 25 Oct 2021 01:38:40 -0400 Subject: [PATCH 09/10] We tested this some tonight, with caching on and off, and it seems to work. It's not the most elegant solution or the fastes for sure. Snab had mentioned making `HasAnyRole.required_roles` a property and coercing all the roles to ids for faster comparision. We might add that in later. --- tanjun/checks.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tanjun/checks.py b/tanjun/checks.py index 9d3a6902a..aa2760f6c 100644 --- a/tanjun/checks.py +++ b/tanjun/checks.py @@ -528,11 +528,19 @@ async def __call__(self, ctx: tanjun_abc.Context, /) -> bool: if not ctx.member: return self._handle_result(False) - member_roles = ctx.member.get_roles() + guild_roles = ctx.cache.get_roles_view_for_guild(ctx.member.guild_id) if ctx.cache else None + if not guild_roles: + guild_roles = await ctx.rest.fetch_roles(ctx.member.guild_id) + member_roles = [role for role in guild_roles if role.id in ctx.member.role_ids] + else: + member_roles = [guild_roles.get(role) for role in ctx.member.role_ids] + + return self._handle_result(any(map(self._check_roles, member_roles))) - return self._handle_result(any(map(self.check_roles, member_roles))) + def _check_roles(self, member_role: typing.Union[int, hikari.Role]) -> bool: + if isinstance(member_role, int): + return any(member_role == check for check in self.required_roles) - def check_roles(self, member_role: hikari.Role) -> bool: return any(member_role.id == check or member_role.name == check for check in self.required_roles) From df3ba8d4a9ce4704de2ca76119c6556a2adfe801 Mon Sep 17 00:00:00 2001 From: Patchwork Collective <226386-aster.codes@users.noreply.gitlab.com> Date: Tue, 1 Feb 2022 21:33:59 -0500 Subject: [PATCH 10/10] Added HasAnyRole.ids_only as mentioned in PR --- tanjun/checks.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tanjun/checks.py b/tanjun/checks.py index aa2760f6c..fb89d00a7 100644 --- a/tanjun/checks.py +++ b/tanjun/checks.py @@ -512,7 +512,7 @@ async def __call__( class HasAnyRoleCheck(_Check): - __slots__ = ("required_roles",) + __slots__ = ("required_roles", "ids_only") def __init__( self, @@ -523,17 +523,21 @@ def __init__( ) -> None: super().__init__(error_message, halt_execution) self.required_roles = roles + self.ids_only = all(isinstance(role, int) for role in self.required_roles) async def __call__(self, ctx: tanjun_abc.Context, /) -> bool: if not ctx.member: return self._handle_result(False) - guild_roles = ctx.cache.get_roles_view_for_guild(ctx.member.guild_id) if ctx.cache else None - if not guild_roles: - guild_roles = await ctx.rest.fetch_roles(ctx.member.guild_id) - member_roles = [role for role in guild_roles if role.id in ctx.member.role_ids] + if not self.ids_only: + guild_roles = ctx.cache.get_roles_view_for_guild(ctx.member.guild_id) if ctx.cache else None + if not guild_roles: + guild_roles = await ctx.rest.fetch_roles(ctx.member.guild_id) + member_roles = [role for role in guild_roles if role.id in ctx.member.role_ids] + else: + member_roles = [guild_roles.get(role) for role in ctx.member.role_ids] else: - member_roles = [guild_roles.get(role) for role in ctx.member.role_ids] + member_roles = ctx.member.role_ids return self._handle_result(any(map(self._check_roles, member_roles)))