Skip to content

Commit

Permalink
feat: names for hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
tandemdude committed Sep 18, 2024
1 parent 2f7dc07 commit d37e5ed
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 10 deletions.
1 change: 1 addition & 0 deletions fragments/+hook_name.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `name` argument to the `hook()` decorator and `name` attribute to `ExecutionHook`.
10 changes: 8 additions & 2 deletions lightbulb/commands/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ class ExecutionHook:
"""The step that this hook should be run during."""
skip_when_failed: bool
"""Whether this hook should be skipped if the pipeline has already failed."""
name: str
"""The name of this hook."""
func: ExecutionHookFunc
"""The function that this hook executes."""

Expand Down Expand Up @@ -263,7 +265,9 @@ async def _run(self) -> None:
)


def hook(step: ExecutionStep, skip_when_failed: bool = False) -> Callable[[ExecutionHookFunc], ExecutionHook]:
def hook(
step: ExecutionStep, skip_when_failed: bool = False, name: str = ""
) -> Callable[[ExecutionHookFunc], ExecutionHook]:
"""
Second order decorator to convert a function into an execution hook for the given
step. Also enables dependency injection on the decorated function.
Expand All @@ -283,6 +287,8 @@ def example_hook(pl: lightbulb.ExecutionPipeline, ctx: lightbulb.Context) -> Non
step: The step that this hook should be run during.
skip_when_failed: Whether this hook should be skipped if the :obj:`~ExecutionPipeline`
has already failed due to a different hook or command invocation exception. Defaults to :obj:`False`.
name: The name of the hook. If not specified (an empty string), this will be set to the name of the
hook function.
Returns:
:obj:`~ExecutionHook`: The created execution hook.
Expand All @@ -303,7 +309,7 @@ def only_on_mondays(pl: lightbulb.ExecutionPipeline, _: lightbulb.Context) -> No
raise ValueError("hooks cannot be registered for the 'INVOKE' execution step")

def inner(func: ExecutionHookFunc) -> ExecutionHook:
return ExecutionHook(step, skip_when_failed, di.with_di(func)) # type: ignore[reportArgumentType]
return ExecutionHook(step, skip_when_failed, name or func.__name__, di.with_di(func)) # type: ignore[reportArgumentType]

return inner

Expand Down
8 changes: 4 additions & 4 deletions lightbulb/prefab/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class NotOwner(Exception):
"""Exception raised when a user that does not own the bot attempts to invoke a protected command."""


@execution.hook(execution.ExecutionSteps.CHECKS, skip_when_failed=True)
@execution.hook(execution.ExecutionSteps.CHECKS, skip_when_failed=True, name="owner_only")
async def owner_only(_: execution.ExecutionPipeline, ctx: context.Context) -> None:
"""
Hook that checks whether the user invoking the command is an owner of the bot. This takes into account
Expand Down Expand Up @@ -97,7 +97,7 @@ class YourCommand(
...
"""

@execution.hook(execution.ExecutionSteps.CHECKS, skip_when_failed=True)
@execution.hook(execution.ExecutionSteps.CHECKS, skip_when_failed=True, name="has_permissions")
def _has_permissions(_: execution.ExecutionPipeline, ctx: context.Context) -> None:
if ctx.member is None:
if fail_in_dm:
Expand Down Expand Up @@ -143,7 +143,7 @@ class YourCommand(
...
"""

@execution.hook(execution.ExecutionSteps.CHECKS, skip_when_failed=True)
@execution.hook(execution.ExecutionSteps.CHECKS, skip_when_failed=True, name="bot_has_permissions")
def _bot_has_permissions(_: execution.ExecutionPipeline, ctx: context.Context) -> None:
if ctx.interaction.app_permissions is None:
if fail_in_dm:
Expand Down Expand Up @@ -194,7 +194,7 @@ class YourCommand(
"""
flattened_role_ids = [elem for item in role_ids for elem in (item if isinstance(item, Iterable) else [item])]

@execution.hook(execution.ExecutionSteps.CHECKS, skip_when_failed=True)
@execution.hook(execution.ExecutionSteps.CHECKS, skip_when_failed=True, name="has_roles")
def _has_roles(_: execution.ExecutionPipeline, ctx: context.Context) -> None:
if ctx.member is None:
if fail_in_dm:
Expand Down
4 changes: 2 additions & 2 deletions lightbulb/prefab/concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,14 @@ class YourCommand(
invocations: dict[hikari.Snowflakeish, int] = collections.defaultdict(lambda: 0)
bucket_callable = _PROVIDED_BUCKETS[bucket] if isinstance(bucket, str) else bucket

@execution.hook(execution.ExecutionSteps.MAX_CONCURRENCY)
@execution.hook(execution.ExecutionSteps.MAX_CONCURRENCY, name="incr_concurrency")
async def _increment_invocation_count(_: execution.ExecutionPipeline, ctx: context.Context) -> None:
if invocations[hash := await utils.maybe_await(bucket_callable(ctx))] >= n_invocations:
raise MaxConcurrencyReached

invocations[hash] += 1

@execution.hook(execution.ExecutionSteps.POST_INVOKE)
@execution.hook(execution.ExecutionSteps.POST_INVOKE, name="decr_concurrency")
async def _decrement_invocation_count(_: execution.ExecutionPipeline, ctx: context.Context) -> None:
invocations[hash] = min(invocations[hash := await utils.maybe_await(bucket_callable(ctx))] - 1, 0)

Expand Down
4 changes: 2 additions & 2 deletions lightbulb/prefab/cooldowns.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def fixed_window(
Returns:
The created hook.
"""
return execution.hook(execution.ExecutionSteps.COOLDOWNS, skip_when_failed=True)(
return execution.hook(execution.ExecutionSteps.COOLDOWNS, skip_when_failed=True, name="fixed_window")(
_FixedWindow(
window_length, allowed_invocations, _PROVIDED_BUCKETS[bucket] if isinstance(bucket, str) else bucket
)
Expand Down Expand Up @@ -165,7 +165,7 @@ def sliding_window(
Returns:
The created hook.
"""
return execution.hook(execution.ExecutionSteps.COOLDOWNS, skip_when_failed=True)(
return execution.hook(execution.ExecutionSteps.COOLDOWNS, skip_when_failed=True, name="sliding_window")(
_SlidingWindow(
window_length, allowed_invocations, _PROVIDED_BUCKETS[bucket] if isinstance(bucket, str) else bucket
)
Expand Down

0 comments on commit d37e5ed

Please sign in to comment.