diff --git a/tanjun/__init__.py b/tanjun/__init__.py index 96cd4726c..96ccf0944 100644 --- a/tanjun/__init__.py +++ b/tanjun/__init__.py @@ -121,6 +121,7 @@ async def main() -> None: "with_owner_check", "with_author_permission_check", "with_own_permission_check", + "with_any_role_check", # clients.py "clients", "as_loader", diff --git a/tanjun/checks.py b/tanjun/checks.py index 3c3cfa92e..fb89d00a7 100644 --- a/tanjun/checks.py +++ b/tanjun/checks.py @@ -48,6 +48,7 @@ "with_owner_check", "with_author_permission_check", "with_own_permission_check", + "with_any_role_check", "DmCheck", "GuildCheck", "NsfwCheck", @@ -55,6 +56,7 @@ "OwnerCheck", "AuthorPermissionCheck", "OwnPermissionCheck", + "HasAnyRoleCheck", ] import typing @@ -509,6 +511,43 @@ async def __call__( return self._handle_result((permissions & self._permissions) == self._permissions) +class HasAnyRoleCheck(_Check): + __slots__ = ("required_roles", "ids_only") + + def __init__( + self, + roles: collections.Sequence[typing.Union[hikari.SnowflakeishOr[hikari.Role], str]] = [], + *, + error_message: typing.Optional[str] = "You do not have the required roles to use this command!", + halt_execution: bool = True, + ) -> None: + super().__init__(error_message, halt_execution) + self.required_roles = roles + self.ids_only = all(isinstance(role, int) for role in self.required_roles) + + async def __call__(self, ctx: tanjun_abc.Context, /) -> bool: + if not ctx.member: + return self._handle_result(False) + + if not self.ids_only: + guild_roles = ctx.cache.get_roles_view_for_guild(ctx.member.guild_id) if ctx.cache else None + if not guild_roles: + guild_roles = await ctx.rest.fetch_roles(ctx.member.guild_id) + member_roles = [role for role in guild_roles if role.id in ctx.member.role_ids] + else: + member_roles = [guild_roles.get(role) for role in ctx.member.role_ids] + else: + member_roles = ctx.member.role_ids + + return self._handle_result(any(map(self._check_roles, member_roles))) + + def _check_roles(self, member_role: typing.Union[int, hikari.Role]) -> bool: + if isinstance(member_role, int): + return any(member_role == check for check in self.required_roles) + + return any(member_role.id == check or member_role.name == check for check in self.required_roles) + + @typing.overload def with_dm_check(command: CommandT, /) -> CommandT: ... @@ -862,6 +901,41 @@ def with_own_permission_check( ) +def with_any_role_check( + roles: collections.Sequence[typing.Union[hikari.SnowflakeishOr[hikari.Role], int, str]] = [], + *, + error_message: typing.Optional[str] = "You do not have the required roles to use this command!", + halt_execution: bool = False, +) -> collections.Callable[[CommandT], CommandT]: + """Only let a command run if the author has a specific role and the command is called in a guild. + + Parameters + ---------- + roles: collections.Sequence[Union[SnowflakeishOr[Role], int, str]] + The author must have at least one (1) role in this list. (Role.name and Role.id are checked) + + Other Parameters + ---------------- + error_message: Optional[str] + The error message raised if the member does not have a required role. + + Defaults to 'You do not have the required roles to use this command!' + halt_execution: bool + Whether this check should raise `tanjun.errors.HaltExecution` to + end the execution search when it fails instead of returning `False`. + + Defaults to `False`. + + Returns + ------- + collections.abc.Callable[[CommandT], CommandT] + A command decorator callback which adds the check. + """ + return lambda command: command.add_check( + HasAnyRoleCheck(roles, error_message=error_message, halt_execution=halt_execution) + ) + + def with_check(check: tanjun_abc.CheckSig, /) -> collections.Callable[[CommandT], CommandT]: """Add a generic check to a command. diff --git a/tests/test_checks.py b/tests/test_checks.py index 9ba2efb65..d33dc38b8 100644 --- a/tests/test_checks.py +++ b/tests/test_checks.py @@ -558,6 +558,29 @@ def test_with_own_permission_check(command: mock.Mock): own_permission_check.assert_called_once_with(5412312, halt_execution=True, error_message="hi") +def test_with_has_any_role_check(command: mock.Mock): + with mock.patch.object(tanjun.checks, "HasAnyRoleCheck") as any_role_check: + assert ( + tanjun.checks.with_any_role_check( + [ + "Admin", + ], + halt_execution=True, + error_message="hi", + )(command) + is command + ) + + command.add_check.assert_called_once_with(any_role_check.return_value) + any_role_check.assert_called_once_with( + [ + "Admin", + ], + halt_execution=True, + error_message="hi", + ) + + def test_with_check(command: mock.Mock): mock_check = mock.Mock()