Skip to content

Commit

Permalink
Merge pull request #387 from PrefectHQ/schema-instructions
Browse files Browse the repository at this point in the history
simplify result instructions
  • Loading branch information
jlowin authored Nov 14, 2024
2 parents ec453cc + fa18e8a commit 9e483fc
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 17 deletions.
4 changes: 3 additions & 1 deletion src/controlflow/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@

from controlflow.utilities.general import unwrap

CONTROLFLOW_ENV_FILE = os.getenv("CONTROLFLOW_ENV_FILE", "~/.controlflow/.env")
CONTROLFLOW_ENV_FILE = os.path.expanduser(
os.path.expandvars(os.getenv("CONTROLFLOW_ENV_FILE", "~/.controlflow/.env"))
)


class ControlFlowSettings(BaseSettings):
Expand Down
16 changes: 9 additions & 7 deletions src/controlflow/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,8 @@ def get_success_tool(self) -> Tool:
instructions.append(
unwrap(
f"""
Use this tool to mark the task as successful and provide a result. The result schema is: {result_schema}
Use this tool to mark the task as successful and provide a
result. The result schema is: {result_schema}
"""
)
)
Expand All @@ -696,8 +697,9 @@ def succeed(**kwargs) -> str:
instructions.append(
unwrap(
f"""
Use this tool to mark the task as successful and provide a result with the `task_result` kwarg.
The `task_result` schema is: {{"task_result": {result_schema}}}
Use this tool to mark the task as successful and provide a
`result` value. The `result` value has the following schema:
{result_schema}.
"""
)
)
Expand All @@ -709,18 +711,18 @@ def succeed(**kwargs) -> str:
include_return_description=False,
metadata=metadata,
)
def succeed(task_result: result_schema) -> str: # type: ignore
def succeed(result: result_schema) -> str: # type: ignore
if self.is_successful():
raise ValueError(
f"{self.friendly_name()} is already marked successful."
)
if options:
if task_result not in options:
if result not in options:
raise ValueError(
f"Invalid option. Please choose one of {options}"
)
task_result = options[task_result]
self.mark_successful(result=task_result)
result = options[result]
self.mark_successful(result=result)
return f"{self.friendly_name()} marked successful."

return succeed
Expand Down
12 changes: 6 additions & 6 deletions tests/tasks/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,27 +485,27 @@ class TestSuccessTool:
def test_success_tool(self):
task = Task("choose 5", result_type=int)
tool = task.get_success_tool()
tool.run(input=dict(task_result=5))
tool.run(input=dict(result=5))
assert task.is_successful()
assert task.result == 5

def test_success_tool_with_list_of_options(self):
task = Task('choose "good"', result_type=["bad", "good", "medium"])
tool = task.get_success_tool()
tool.run(input=dict(task_result=1))
tool.run(input=dict(result=1))
assert task.is_successful()
assert task.result == "good"

def test_success_tool_with_list_of_options_requires_int(self):
task = Task('choose "good"', result_type=["bad", "good", "medium"])
tool = task.get_success_tool()
with pytest.raises(ValueError):
tool.run(input=dict(task_result="good"))
tool.run(input=dict(result="good"))

def test_tuple_of_ints_result(self):
task = Task("choose 5", result_type=(4, 5, 6))
tool = task.get_success_tool()
tool.run(input=dict(task_result=1))
tool.run(input=dict(result=1))
assert task.result == 5

def test_tuple_of_pydantic_models_result(self):
Expand All @@ -518,7 +518,7 @@ class Person(BaseModel):
result_type=(Person(name="Alice", age=30), Person(name="Bob", age=35)),
)
tool = task.get_success_tool()
tool.run(input=dict(task_result=1))
tool.run(input=dict(result=1))
assert task.result == Person(name="Bob", age=35)
assert isinstance(task.result, Person)

Expand Down Expand Up @@ -604,7 +604,7 @@ def test_invalid_completion_tool(self):
def test_manual_success_tool(self):
task = Task(objective="Test task", completion_tools=[], result_type=int)
success_tool = task.get_success_tool()
success_tool.run(input=dict(task_result=5))
success_tool.run(input=dict(result=5))
assert task.is_successful()
assert task.result == 5

Expand Down
2 changes: 1 addition & 1 deletion tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def task(self, default_fake_llm):
tool_calls=[
{
"name": "mark_task_12345_successful",
"args": {"task_result": "Hello!"},
"args": {"result": "Hello!"},
"id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe",
"type": "tool_call",
}
Expand Down
4 changes: 2 additions & 2 deletions tests/utilities/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_record_task_events(default_fake_llm):
tool_calls=[
{
"name": "mark_task_12345_successful",
"args": {"task_result": "Hello!"},
"args": {"result": "Hello!"},
"id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe",
"type": "tool_call",
}
Expand All @@ -39,7 +39,7 @@ def test_record_task_events(default_fake_llm):
assert events[3].event == "tool-result"
assert events[3].tool_result.tool_call == {
"name": "mark_task_12345_successful",
"args": {"task_result": "Hello!"},
"args": {"result": "Hello!"},
"id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe",
"type": "tool_call",
}
Expand Down

0 comments on commit 9e483fc

Please sign in to comment.