Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

with_any_role_check updated #145

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions tanjun/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ async def main() -> None:
"with_owner_check",
"with_author_permission_check",
"with_own_permission_check",
"with_any_role_check",
# clients.py
"clients",
"as_loader",
Expand Down
74 changes: 74 additions & 0 deletions tanjun/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,15 @@
"with_owner_check",
"with_author_permission_check",
"with_own_permission_check",
"with_any_role_check",
"DmCheck",
"GuildCheck",
"NsfwCheck",
"SfwCheck",
"OwnerCheck",
"AuthorPermissionCheck",
"OwnPermissionCheck",
"HasAnyRoleCheck",
]

import typing
Expand Down Expand Up @@ -509,6 +511,43 @@ async def __call__(
return self._handle_result((permissions & self._permissions) == self._permissions)


class HasAnyRoleCheck(_Check):
patchwork-systems marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One style thing, since this was originally written doc strings have been added to the check classes and their inits so that would prob be added for this as well

__slots__ = ("required_roles", "ids_only")

def __init__(
self,
roles: collections.Sequence[typing.Union[hikari.SnowflakeishOr[hikari.Role], str]] = [],
*,
error_message: typing.Optional[str] = "You do not have the required roles to use this command!",
halt_execution: bool = True,
) -> None:
super().__init__(error_message, halt_execution)
self.required_roles = roles
patchwork-systems marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another style thing, these attributes should be private now (so _required_roles and _ids_only

self.ids_only = all(isinstance(role, int) for role in self.required_roles)

async def __call__(self, ctx: tanjun_abc.Context, /) -> bool:
if not ctx.member:
return self._handle_result(False)

if not self.ids_only:
guild_roles = ctx.cache.get_roles_view_for_guild(ctx.member.guild_id) if ctx.cache else None
if not guild_roles:
guild_roles = await ctx.rest.fetch_roles(ctx.member.guild_id)
member_roles = [role for role in guild_roles if role.id in ctx.member.role_ids]
else:
member_roles = [guild_roles.get(role) for role in ctx.member.role_ids]
else:
member_roles = ctx.member.role_ids

return self._handle_result(any(map(self._check_roles, member_roles)))

def _check_roles(self, member_role: typing.Union[int, hikari.Role]) -> bool:
if isinstance(member_role, int):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rather than instance check here could the different checks be performed in the separate parts of the if not self._ids_only else statement to avoid the instance checks all together (if type checking doesn't quite like this you can just cast but i don't think it should be too strict for equality checks)

return any(member_role == check for check in self.required_roles)

return any(member_role.id == check or member_role.name == check for check in self.required_roles)


@typing.overload
def with_dm_check(command: CommandT, /) -> CommandT:
...
Expand Down Expand Up @@ -862,6 +901,41 @@ def with_own_permission_check(
)


def with_any_role_check(
patchwork-systems marked this conversation as resolved.
Show resolved Hide resolved
roles: collections.Sequence[typing.Union[hikari.SnowflakeishOr[hikari.Role], int, str]] = [],
*,
error_message: typing.Optional[str] = "You do not have the required roles to use this command!",
halt_execution: bool = False,
) -> collections.Callable[[CommandT], CommandT]:
"""Only let a command run if the author has a specific role and the command is called in a guild.

patchwork-systems marked this conversation as resolved.
Show resolved Hide resolved
Parameters
----------
roles: collections.Sequence[Union[SnowflakeishOr[Role], int, str]]
The author must have at least one (1) role in this list. (Role.name and Role.id are checked)

Other Parameters
----------------
error_message: Optional[str]
The error message raised if the member does not have a required role.

Defaults to 'You do not have the required roles to use this command!'
halt_execution: bool
Whether this check should raise `tanjun.errors.HaltExecution` to
end the execution search when it fails instead of returning `False`.

Defaults to `False`.
patchwork-systems marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
collections.abc.Callable[[CommandT], CommandT]
A command decorator callback which adds the check.
"""
return lambda command: command.add_check(
HasAnyRoleCheck(roles, error_message=error_message, halt_execution=halt_execution)
)


def with_check(check: tanjun_abc.CheckSig, /) -> collections.Callable[[CommandT], CommandT]:
"""Add a generic check to a command.

Expand Down
23 changes: 23 additions & 0 deletions tests/test_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,29 @@ def test_with_own_permission_check(command: mock.Mock):
own_permission_check.assert_called_once_with(5412312, halt_execution=True, error_message="hi")


def test_with_has_any_role_check(command: mock.Mock):
with mock.patch.object(tanjun.checks, "HasAnyRoleCheck") as any_role_check:
assert (
tanjun.checks.with_any_role_check(
[
"Admin",
],
halt_execution=True,
error_message="hi",
)(command)
is command
)

command.add_check.assert_called_once_with(any_role_check.return_value)
any_role_check.assert_called_once_with(
[
"Admin",
],
halt_execution=True,
error_message="hi",
)


def test_with_check(command: mock.Mock):
mock_check = mock.Mock()

Expand Down