From ea1132402ce0204727b8f7b0f578176360efddd2 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Thu, 14 Nov 2024 18:41:06 -0500 Subject: [PATCH 1/2] simplify result instructions --- src/controlflow/settings.py | 4 +++- src/controlflow/tasks/task.py | 16 +++++++++------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/controlflow/settings.py b/src/controlflow/settings.py index 4f4c8b0..387753b 100644 --- a/src/controlflow/settings.py +++ b/src/controlflow/settings.py @@ -12,7 +12,9 @@ from controlflow.utilities.general import unwrap -CONTROLFLOW_ENV_FILE = os.getenv("CONTROLFLOW_ENV_FILE", "~/.controlflow/.env") +CONTROLFLOW_ENV_FILE = os.path.expanduser( + os.path.expandvars(os.getenv("CONTROLFLOW_ENV_FILE", "~/.controlflow/.env")) +) class ControlFlowSettings(BaseSettings): diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index bf50249..44a5507 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -673,7 +673,8 @@ def get_success_tool(self) -> Tool: instructions.append( unwrap( f""" - Use this tool to mark the task as successful and provide a result. The result schema is: {result_schema} + Use this tool to mark the task as successful and provide a + result. The result schema is: {result_schema} """ ) ) @@ -696,8 +697,9 @@ def succeed(**kwargs) -> str: instructions.append( unwrap( f""" - Use this tool to mark the task as successful and provide a result with the `task_result` kwarg. - The `task_result` schema is: {{"task_result": {result_schema}}} + Use this tool to mark the task as successful and provide a + `result` value. The `result` value has the following schema: + {result_schema}. """ ) ) @@ -709,18 +711,18 @@ def succeed(**kwargs) -> str: include_return_description=False, metadata=metadata, ) - def succeed(task_result: result_schema) -> str: # type: ignore + def succeed(result: result_schema) -> str: # type: ignore if self.is_successful(): raise ValueError( f"{self.friendly_name()} is already marked successful." ) if options: - if task_result not in options: + if result not in options: raise ValueError( f"Invalid option. Please choose one of {options}" ) - task_result = options[task_result] - self.mark_successful(result=task_result) + result = options[result] + self.mark_successful(result=result) return f"{self.friendly_name()} marked successful." return succeed From fa18e8acf152325bfc1412ca74f83c4005a3e69c Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Thu, 14 Nov 2024 18:44:44 -0500 Subject: [PATCH 2/2] Update task result references --- tests/tasks/test_tasks.py | 12 ++++++------ tests/test_run.py | 2 +- tests/utilities/test_testing.py | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/tasks/test_tasks.py b/tests/tasks/test_tasks.py index 5248023..fa6b2fb 100644 --- a/tests/tasks/test_tasks.py +++ b/tests/tasks/test_tasks.py @@ -485,14 +485,14 @@ class TestSuccessTool: def test_success_tool(self): task = Task("choose 5", result_type=int) tool = task.get_success_tool() - tool.run(input=dict(task_result=5)) + tool.run(input=dict(result=5)) assert task.is_successful() assert task.result == 5 def test_success_tool_with_list_of_options(self): task = Task('choose "good"', result_type=["bad", "good", "medium"]) tool = task.get_success_tool() - tool.run(input=dict(task_result=1)) + tool.run(input=dict(result=1)) assert task.is_successful() assert task.result == "good" @@ -500,12 +500,12 @@ def test_success_tool_with_list_of_options_requires_int(self): task = Task('choose "good"', result_type=["bad", "good", "medium"]) tool = task.get_success_tool() with pytest.raises(ValueError): - tool.run(input=dict(task_result="good")) + tool.run(input=dict(result="good")) def test_tuple_of_ints_result(self): task = Task("choose 5", result_type=(4, 5, 6)) tool = task.get_success_tool() - tool.run(input=dict(task_result=1)) + tool.run(input=dict(result=1)) assert task.result == 5 def test_tuple_of_pydantic_models_result(self): @@ -518,7 +518,7 @@ class Person(BaseModel): result_type=(Person(name="Alice", age=30), Person(name="Bob", age=35)), ) tool = task.get_success_tool() - tool.run(input=dict(task_result=1)) + tool.run(input=dict(result=1)) assert task.result == Person(name="Bob", age=35) assert isinstance(task.result, Person) @@ -604,7 +604,7 @@ def test_invalid_completion_tool(self): def test_manual_success_tool(self): task = Task(objective="Test task", completion_tools=[], result_type=int) success_tool = task.get_success_tool() - success_tool.run(input=dict(task_result=5)) + success_tool.run(input=dict(result=5)) assert task.is_successful() assert task.result == 5 diff --git a/tests/test_run.py b/tests/test_run.py index c8f1cab..a78e37a 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -191,7 +191,7 @@ def task(self, default_fake_llm): tool_calls=[ { "name": "mark_task_12345_successful", - "args": {"task_result": "Hello!"}, + "args": {"result": "Hello!"}, "id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe", "type": "tool_call", } diff --git a/tests/utilities/test_testing.py b/tests/utilities/test_testing.py index 380bdd5..e4acefd 100644 --- a/tests/utilities/test_testing.py +++ b/tests/utilities/test_testing.py @@ -20,7 +20,7 @@ def test_record_task_events(default_fake_llm): tool_calls=[ { "name": "mark_task_12345_successful", - "args": {"task_result": "Hello!"}, + "args": {"result": "Hello!"}, "id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe", "type": "tool_call", } @@ -39,7 +39,7 @@ def test_record_task_events(default_fake_llm): assert events[3].event == "tool-result" assert events[3].tool_result.tool_call == { "name": "mark_task_12345_successful", - "args": {"task_result": "Hello!"}, + "args": {"result": "Hello!"}, "id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe", "type": "tool_call", }