Skip to content

Commit

Permalink
Merge pull request #369 from PrefectHQ/success-kwarg
Browse files Browse the repository at this point in the history
Pass basemodel attributes directly as kwargs
  • Loading branch information
jlowin authored Oct 29, 2024
2 parents 970c9ed + dbcf7a2 commit 0d4100b
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 20 deletions.
1 change: 0 additions & 1 deletion .github/ai-labeler.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
labels:
# Simple form: just the name
- bug
- breaking change
- documentation
Expand Down
60 changes: 41 additions & 19 deletions src/controlflow/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from prefect.context import TaskRunContext
from pydantic import (
BaseModel,
Field,
PydanticSchemaGenerationError,
RootModel,
Expand All @@ -44,6 +45,7 @@
NOTSET,
ControlFlowModel,
hash_objects,
safe_issubclass,
unwrap,
)
from controlflow.utilities.logging import get_logger
Expand Down Expand Up @@ -624,25 +626,45 @@ def get_success_tool(self) -> Tool:
"Please use a custom type or add compatibility."
)

@tool(
name=f"mark_task_{self.id}_successful",
description=f"Mark task {self.id} as successful.",
instructions=instructions,
include_return_description=False,
)
def succeed(result: result_schema) -> str: # type: ignore
if self.is_successful():
raise ValueError(
f"{self.friendly_name()} is already marked successful."
)
if options:
if result not in options:
raise ValueError(f"Invalid option. Please choose one of {options}")
result = options[result]
self.mark_successful(result=result)
return f"{self.friendly_name()} marked successful."

return succeed
# for basemodel subclasses, we accept the model properties directly as kwargs
if safe_issubclass(result_schema, BaseModel):

def succeed(**kwargs) -> str:
self.mark_successful(result=result_schema(**kwargs))
return f"{self.friendly_name()} marked successful."

return Tool(
fn=succeed,
name=f"mark_task_{self.id}_successful",
description=f"Mark task {self.id} as successful.",
instructions=instructions,
parameters=result_schema.model_json_schema(),
)

# for all other results, we create a single `result` kwarg to capture the result
else:

@tool(
name=f"mark_task_{self.id}_successful",
description=f"Mark task {self.id} as successful.",
instructions=instructions,
include_return_description=False,
)
def succeed(result: result_schema) -> str: # type: ignore
if self.is_successful():
raise ValueError(
f"{self.friendly_name()} is already marked successful."
)
if options:
if result not in options:
raise ValueError(
f"Invalid option. Please choose one of {options}"
)
result = options[result]
self.mark_successful(result=result)
return f"{self.friendly_name()} marked successful."

return succeed

def get_fail_tool(self) -> Tool:
"""
Expand Down
12 changes: 12 additions & 0 deletions src/controlflow/utilities/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,15 @@ class PandasSeries(ControlFlowModel):
index: Optional[list[str]] = None
name: Optional[str] = None
dtype: Optional[str] = None


def safe_issubclass(cls: type, subclass: type) -> bool:
"""
`issubclass` raises a TypeError if cls is not a type. This helper function
safely checks if cls is a type and then checks if it is a subclass of
subclass.
"""
try:
return isinstance(cls, type) and issubclass(cls, subclass)
except TypeError:
return False

0 comments on commit 0d4100b

Please sign in to comment.