Skip to content

Commit

Permalink
Merge pull request #337 from PrefectHQ/decorator
Browse files Browse the repository at this point in the history
Support async tasks with task decorator
  • Loading branch information
jlowin authored Sep 26, 2024
2 parents d62e6b3 + b4d859c commit a222edd
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 9 deletions.
24 changes: 15 additions & 9 deletions src/controlflow/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,18 +201,24 @@ def _get_task(*args, **kwargs) -> Task:
**task_kwargs,
)

@functools.wraps(fn)
@prefect_task(
if asyncio.iscoroutinefunction(fn):

@functools.wraps(fn)
async def wrapper(*args, **kwargs):
task = _get_task(*args, **kwargs)
return await task.run_async()
else:

@functools.wraps(fn)
def wrapper(*args, **kwargs):
task = _get_task(*args, **kwargs)
return task.run()

wrapper = prefect_task(
timeout_seconds=timeout_seconds,
retries=retries,
retry_delay_seconds=retry_delay_seconds,
)
def wrapper(
*args,
**kwargs,
):
task = _get_task(*args, **kwargs)
return task.run()
)(wrapper)

# store the `as_task` method for loading the task object
wrapper.as_task = _get_task
Expand Down
36 changes: 36 additions & 0 deletions tests/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,39 @@ def partial_flow():

result = partial_flow()
assert result == 10


class TestTaskDecorator:
def test_task_decorator_sync_as_task(self):
@controlflow.task
def write_poem(topic: str) -> str:
"""write a two-line poem about `topic`"""

task = write_poem.as_task("AI")
assert task.name == "write_poem"
assert task.objective == "write a two-line poem about `topic`"
assert task.result_type is str

def test_task_decorator_async_as_task(self):
@controlflow.task
async def write_poem(topic: str) -> str:
"""write a two-line poem about `topic`"""

task = write_poem.as_task("AI")
assert task.name == "write_poem"
assert task.objective == "write a two-line poem about `topic`"
assert task.result_type is str

def test_task_decorator_sync(self):
@controlflow.task
def write_poem(topic: str) -> str:
"""write a two-line poem about `topic`"""

assert write_poem("AI")

async def test_task_decorator_async(self):
@controlflow.task
async def write_poem(topic: str) -> str:
"""write a two-line poem about `topic`"""

assert await write_poem("AI")

0 comments on commit a222edd

Please sign in to comment.