Skip to content

Commit

Permalink
Add RetryingCaller and AsyncRetryingCaller (#56)
Browse files Browse the repository at this point in the history
* Add RetryingCaller and AsyncRetryingCaller

Implements #45

* Add PR link

* Update CHANGELOG.md
  • Loading branch information
hynek committed Jan 28, 2024
1 parent 6ddcc7d commit b5314a2
Show file tree
Hide file tree
Showing 8 changed files with 210 additions and 2 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ You can find our backwards-compatibility policy [here](https://github.com/hynek/

## [Unreleased](https://github.com/hynek/stamina/compare/24.1.0...HEAD)

### Added

- `stamina.RetryingCaller` and `stamina.AsyncRetryingCaller` that allow even easier retries of single callables.
[#56](https://github.com/hynek/stamina/pull/56)


## [24.1.0](https://github.com/hynek/stamina/compare/23.3.0...24.1.0) - 2024-01-03

Expand Down
2 changes: 2 additions & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
.. autofunction:: retry_context
.. autoclass:: Attempt
:members: num
.. autoclass:: RetryingCaller
.. autoclass:: AsyncRetryingCaller
```


Expand Down
15 changes: 15 additions & 0 deletions docs/tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,21 @@ for attempt in stamina.retry_context(on=httpx.HTTPError):
resp.raise_for_status()
```

If you want to retry just one function call, *stamina* comes with an even easier way in the shape of {class}`stamina.RetryingCaller` and {class}`stamina.AsyncRetryingCaller`:

```python
def do_something_with_url(url, some_kw):
resp = httpx.get(url)
resp.raise_for_status()
...

rc = stamina.RetryingCaller(on=httpx.HTTPError)

rc(do_something_with_url, f"https://httpbin.org/status/404", some_kw=42)
```

The last line calls `do_something_with_url(f"https://httpbin.org/status/404", some_kw=42)` and retries on `httpx.HTTPError`.


## Async

Expand Down
10 changes: 9 additions & 1 deletion src/stamina/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@

from . import instrumentation
from ._config import is_active, set_active
from ._core import Attempt, retry, retry_context
from ._core import (
AsyncRetryingCaller,
Attempt,
RetryingCaller,
retry,
retry_context,
)


__all__ = [
Expand All @@ -14,6 +20,8 @@
"is_active",
"set_active",
"instrumentation",
"RetryingCaller",
"AsyncRetryingCaller",
]


Expand Down
91 changes: 90 additions & 1 deletion src/stamina/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from functools import wraps
from inspect import iscoroutinefunction
from types import TracebackType
from typing import AsyncIterator, Iterator, TypeVar
from typing import AsyncIterator, Awaitable, Iterator, TypedDict, TypeVar

import tenacity as _t

Expand Down Expand Up @@ -126,6 +126,95 @@ def __exit__(
)


class RetryKWs(TypedDict):
on: type[Exception] | tuple[type[Exception], ...]
attempts: int | None
timeout: float | dt.timedelta | None
wait_initial: float | dt.timedelta
wait_max: float | dt.timedelta
wait_jitter: float | dt.timedelta
wait_exp_base: float


class BaseRetryingCaller:
"""
.. versionadded:: 24.2.0
"""

__slots__ = ("_context_kws",)

_context_kws: RetryKWs

def __init__(
self,
on: type[Exception] | tuple[type[Exception], ...],
attempts: int | None = 10,
timeout: float | dt.timedelta | None = 45.0,
wait_initial: float | dt.timedelta = 0.1,
wait_max: float | dt.timedelta = 5.0,
wait_jitter: float | dt.timedelta = 1.0,
wait_exp_base: float = 2.0,
):
self._context_kws = {
"on": on,
"attempts": attempts,
"timeout": timeout,
"wait_initial": wait_initial,
"wait_max": wait_max,
"wait_jitter": wait_jitter,
"wait_exp_base": wait_exp_base,
}

def __repr__(self) -> str:
on = guess_name(self._context_kws["on"])
kws = ", ".join(
f"{k}={self._context_kws[k]!r}" # type: ignore[literal-required]
for k in sorted(self._context_kws)
if k != "on"
)
return f"<{self.__class__.__name__}(on={on}, {kws})>"


class RetryingCaller(BaseRetryingCaller):
"""
An object that will call your callable with retries.
Instances of `RetryingCaller` may be reused because they create a new
:func:`retry_context` iterator on each call.
.. versionadded:: 24.2.0
"""

def __call__(
self, func: Callable[P, T], /, *args: P.args, **kw: P.kwargs
) -> T:
for attempt in retry_context(**self._context_kws):
with attempt:
return func(*args, **kw)

raise SystemError("unreachable") # pragma: no cover # noqa: EM101


class AsyncRetryingCaller(BaseRetryingCaller):
"""
An object that will call your async callable with retries.
Instances of `AsyncRetryingCaller` may be reused because they create a new
:func:`retry_context` iterator on each call.
.. versionadded:: 24.2.0
"""

async def __call__(
self, func: Callable[P, Awaitable[T]], /, *args: P.args, **kw: P.kwargs
) -> T:
async for attempt in retry_context(**self._context_kws):
with attempt:
return await func(*args, **kw)

raise SystemError("unreachable") # pragma: no cover # noqa: EM101


_STOP_NO_RETRY = _t.stop_after_attempt(1)


Expand Down
24 changes: 24 additions & 0 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,27 @@ async def test_retry_blocks_can_be_disabled():
raise Exception("passed")

assert 1 == num_called


class TestAsyncRetryingCaller:
async def test_retries(self):
"""
Retries if the specific error is raised. Arguments are passed through.
"""
i = 0

async def f(*args, **kw):
nonlocal i
if i < 1:
i += 1
raise ValueError

return args, kw

arc = stamina.AsyncRetryingCaller(on=ValueError)

args, kw = await arc(f, 42, foo="bar")

assert 1 == i
assert (42,) == args
assert {"foo": "bar"} == kw
44 changes: 44 additions & 0 deletions tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,47 @@ def test_never(self):
If all conditions are None, return stop_never.
"""
assert tenacity.stop_never is _make_stop(attempts=None, timeout=None)


class TestRetryingCaller:
def test_retries(self):
"""
Retries if the specific error is raised. Arguments are passed through.
"""
i = 0

def f(*args, **kw):
nonlocal i
if i < 1:
i += 1
raise ValueError

return args, kw

rc = stamina.RetryingCaller(on=ValueError)

args, kw = rc(f, 42, foo="bar")

assert 1 == i
assert (42,) == args
assert {"foo": "bar"} == kw

def test_repr(self):
"""
repr() is useful.
"""
rc = stamina.RetryingCaller(
on=ValueError,
attempts=42,
timeout=13.0,
wait_initial=23,
wait_max=123,
wait_jitter=0.42,
wait_exp_base=666,
)

assert (
"<RetryingCaller(on=ValueError, attempts=42, timeout=13.0, "
"wait_exp_base=666, wait_initial=23, wait_jitter=0.42, "
"wait_max=123)>"
) == repr(rc)
21 changes: 21 additions & 0 deletions tests/typing/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import datetime as dt

from stamina import (
AsyncRetryingCaller,
RetryingCaller,
is_active,
retry,
retry_context,
Expand Down Expand Up @@ -125,3 +127,22 @@ async def f() -> None:
):
with attempt:
pass


def sync_f(x: int, foo: str) -> bool:
return True


rc = RetryingCaller(on=ValueError, timeout=13.0, attempts=10)
b: bool = rc(sync_f, 1, foo="bar")


async def async_f(x: int, foo: str) -> bool:
return True


arc = AsyncRetryingCaller(on=ValueError, timeout=13.0, attempts=10)


async def g() -> bool:
return await arc(async_f, 1, foo="bar")

0 comments on commit b5314a2

Please sign in to comment.