Skip to content

Commit

Permalink
Fix Await fiasco
Browse files Browse the repository at this point in the history
  • Loading branch information
majdyz committed May 16, 2024
1 parent ea134c7 commit 7b5272f
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 16 deletions.
1 change: 1 addition & 0 deletions autogpts/autogpt/autogpt/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ async def execute(
# Get commands
self.commands = await self.run_pipeline(CommandProvider.get_commands)
self._remove_disabled_commands()
self.code_flow_executor.set_available_functions(self.commands)

try:
return_value = await self._execute_tool(tool)
Expand Down
39 changes: 25 additions & 14 deletions autogpts/autogpt/autogpt/agents/prompt_strategies/code_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

from pydantic import BaseModel, Field

from autogpt.agents.base import BaseAgentActionProposal
from autogpt.agents.prompt_strategies.one_shot import OneShotAgentPromptConfiguration, AssistantThoughts
from autogpt.agents.prompt_strategies.one_shot import OneShotAgentPromptConfiguration, AssistantThoughts, OneShotAgentActionProposal
from autogpt.config.ai_directives import AIDirectives
from autogpt.config.ai_profile import AIProfile
from autogpt.core.configuration.schema import SystemConfiguration
Expand Down Expand Up @@ -59,7 +58,8 @@ class CodeFlowAgentActionProposal(BaseModel):
"arguments can't be determined yet. Reduce the amount of unnecessary data passed into "
"these magic functions where possible, because magic costs money and magically "
"processing large amounts of data is expensive. If you think are done with the task, "
"you can simply call finish(reason='your reason') to end the task. "
"you can simply call finish(reason='your reason') to end the task, "
"a function that has one `finish` command, don't mix finish with other functions. "
)


Expand Down Expand Up @@ -187,7 +187,7 @@ def _generate_function_headers(self, funcs: list[CompletionModelFunction]) -> st
async def parse_response_content(
self,
response: AssistantChatMessage,
) -> BaseAgentActionProposal:
) -> OneShotAgentActionProposal:
if not response.content:
raise InvalidAgentResponseError("Assistant response has no text content")

Expand All @@ -210,6 +210,7 @@ async def parse_response_content(
name=f.name,
arg_types=[(name, p.python_type) for name, p in f.parameters.items()],
arg_descs={name: p.description for name, p in f.parameters.items()},
arg_defaults={name: p.default or "None" for name, p in f.parameters.items() if p.default or not p.required},
return_type="str",
return_desc="Output of the function",
function_desc=f.description,
Expand All @@ -235,14 +236,24 @@ async def parse_response_content(
available_functions=available_functions,
).validate_code(parsed_response.python_code)

result = BaseAgentActionProposal(
thoughts=parsed_response.thoughts,
use_tool=AssistantFunctionCall(
name="execute_code_flow",
arguments={
"python_code": code_validation.functionCode,
"plan_text": parsed_response.immediate_plan,
},
),
)
if re.search(r"finish\((.*?)\)", code_validation.functionCode):
finish_reason = re.search(r"finish\((reason=)?(.*?)\)", code_validation.functionCode).group(2)
result = OneShotAgentActionProposal(
thoughts=parsed_response.thoughts,
use_tool=AssistantFunctionCall(
name="finish",
arguments={"reason": finish_reason[1:-1]},
),
)
else:
result = OneShotAgentActionProposal(
thoughts=parsed_response.thoughts,
use_tool=AssistantFunctionCall(
name="execute_code_flow",
arguments={
"python_code": code_validation.functionCode,
"plan_text": parsed_response.immediate_plan,
},
),
)
return result
31 changes: 30 additions & 1 deletion autogpts/autogpt/autogpt/utils/function/code_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,6 @@ async def validate_code(
validation_errors=validation_errors,
)
function_template = main_func.function_template
function_code = main_func.function_code
else:
function_template = None

Expand Down Expand Up @@ -397,6 +396,7 @@ async def __execute_pyright_commands(code: str) -> list[str]:

# read code from code.py. split the code into imports and raw code
code = open(f"{temp_dir}/code.py").read()
code, error_messages = await __fix_async_calls(code, validation_errors)
func.imports, func.rawCode = __unpack_import_and_function_code(code)

return validation_errors
Expand Down Expand Up @@ -450,6 +450,35 @@ async def find_module_dist_and_source(
AUTO_IMPORT_TYPES[t] = f"from collections import {t}"


async def __fix_async_calls(code: str, errors: list[str]) -> tuple[str, list[str]]:
"""
Fix the async calls in the code
Args:
code (str): The code snippet
errors (list[str]): The list of errors
func (ValidationResponse): The function to fix the async calls
Returns:
tuple[str, list[str]]: The fixed code snippet and the list of errors
"""
async_calls = set()
new_errors = []
for error in errors:
pattern = '"__await__" is not present. reportGeneralTypeIssues -> (.+)'
match = re.search(pattern, error)
if match:
async_calls.add(match.group(1))
else:
new_errors.append(error)

for async_call in async_calls:
func_call = re.search(r"await ([a-zA-Z0-9_]+)", async_call)
if func_call:
func_name = func_call.group(1)
code = code.replace(f"await {func_name}", f"{func_name}")

return code, new_errors


async def __fix_missing_imports(
errors: list[str], func: ValidationResponse
) -> tuple[set[str], list[str]]:
Expand Down
6 changes: 5 additions & 1 deletion autogpts/autogpt/autogpt/utils/function/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class ObjectField(BaseModel):
class FunctionDef(BaseModel):
name: str
arg_types: list[tuple[str, str]]
arg_defaults: dict[str, str] = {}
arg_descs: dict[str, str]
return_type: str | None = None
return_desc: str
Expand All @@ -46,7 +47,10 @@ class FunctionDef(BaseModel):
is_async: bool = False

def __generate_function_template(f) -> str:
args_str = ", ".join([f"{name}: {type}" for name, type in f.arg_types])
args_str = ", ".join([
f"{name}: {type}" + (f" = {f.arg_defaults.get(name, '')}" if name in f.arg_defaults else "")
for name, type in f.arg_types
])
arg_desc = f"\n{' '*4}".join(
[
f'{name} ({type}): {f.arg_descs.get(name, "-")}'
Expand Down
6 changes: 6 additions & 0 deletions autogpts/autogpt/tests/unit/test_function_code_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,22 @@ def crawl_info(url: str, query: str) -> str | None:
return None
def hehe():
return 'hehe'
def main() -> str:
query = "Find the number of contributors to the autogpt github repository, or if any, list of urls that can be crawled to find the number of contributors"
for title, url in ("autogpt github contributor page"):
info = await crawl_info(url, query)
if info:
return info
x = await hehe()
return "No info found"
""",
packages=[],
)
assert response.functionCode is not None
assert "async def crawl_info" in response.functionCode # async is added
assert "async def main" in response.functionCode
assert "x = hehe()" in response.functionCode # await is removed

0 comments on commit 7b5272f

Please sign in to comment.