Skip to content

Commit

Permalink
Merge pull request #1069 from Pythagora-io/relevant
Browse files Browse the repository at this point in the history
Relevant
  • Loading branch information
LeonOstrez authored Aug 6, 2024
2 parents b33282f + 1c5ece7 commit 82411d8
Show file tree
Hide file tree
Showing 23 changed files with 227 additions and 101 deletions.
4 changes: 4 additions & 0 deletions core/agents/convo.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,7 @@ def remove_defs(d):
f"YOU MUST NEVER add any additional fields to your response, and NEVER add additional preamble like 'Here is your JSON'."
)
return self

def remove_last_x_messages(self, x: int) -> "AgentConvo":
    """
    Remove the last `x` messages from the conversation.

    :param x: Number of trailing messages to drop (0 leaves the list unchanged).
    :return: self, for method chaining.
    """
    # Guard the x == 0 case: messages[:-0] == messages[:0] == [], which
    # would accidentally wipe the whole conversation instead of keeping it.
    if x > 0:
        self.messages = self.messages[:-x]
    return self
81 changes: 48 additions & 33 deletions core/agents/developer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional
from enum import Enum
from typing import Annotated, Literal, Optional, Union
from uuid import uuid4

from pydantic import BaseModel, Field

from core.agents.base import BaseAgent
from core.agents.convo import AgentConvo
from core.agents.mixins import TaskSteps
from core.agents.mixins import RelevantFilesMixin
from core.agents.response import AgentResponse, ResponseType
from core.config import TASK_BREAKDOWN_AGENT_NAME
from core.db.models.project_state import IterationStatus, TaskStatus
Expand All @@ -17,11 +18,48 @@
log = get_logger(__name__)


class RelevantFiles(BaseModel):
relevant_files: list[str] = Field(description="List of relevant files for the current task.")
class StepType(str, Enum):
    """Discriminator values for the kinds of steps in a task breakdown."""

    COMMAND = "command"
    SAVE_FILE = "save_file"
    HUMAN_INTERVENTION = "human_intervention"


class Developer(BaseAgent):
class CommandOptions(BaseModel):
    """Options for a "command" step: the shell command the agent should run."""

    command: str = Field(description="Command to run")
    timeout: int = Field(description="Timeout in seconds")
    # Optional message associated with a successful run (defaults to empty);
    # NOTE(review): exact semantics are defined by the step runner — confirm there.
    success_message: str = ""


class SaveFileOptions(BaseModel):
    """Options for a "save_file" step."""

    # Path of the file to save (presumably relative to the project root — confirm with caller).
    path: str


class SaveFileStep(BaseModel):
    """A step that saves (creates or modifies) a single file."""

    # Discriminator field, fixed to SAVE_FILE for this variant.
    type: Literal[StepType.SAVE_FILE] = StepType.SAVE_FILE
    save_file: SaveFileOptions


class CommandStep(BaseModel):
    """A step that runs a shell command."""

    # Discriminator field, fixed to COMMAND for this variant.
    type: Literal[StepType.COMMAND] = StepType.COMMAND
    command: CommandOptions


class HumanInterventionStep(BaseModel):
    """A step that requires the human user to do something manually."""

    # Discriminator field, fixed to HUMAN_INTERVENTION for this variant.
    type: Literal[StepType.HUMAN_INTERVENTION] = StepType.HUMAN_INTERVENTION
    # Free-form description of what the user is asked to do.
    human_intervention_description: str


# Discriminated union of all step variants; pydantic dispatches on the
# "type" field to select the correct model during validation.
Step = Annotated[
    Union[SaveFileStep, CommandStep, HumanInterventionStep],
    Field(discriminator="type"),
]


class TaskSteps(BaseModel):
    """Schema for the LLM response listing the steps of a task breakdown."""

    steps: list[Step]


class Developer(RelevantFilesMixin, BaseAgent):
agent_type = "developer"
display_name = "Developer"

Expand Down Expand Up @@ -96,7 +134,8 @@ async def breakdown_current_iteration(self, task_review_feedback: Optional[str]
log.debug(f"Breaking down the iteration {description}")
await self.send_message("Breaking down the current task iteration ...")

await self.get_relevant_files(user_feedback, description)
if self.current_state.files and self.current_state.relevant_files is None:
return await self.get_relevant_files(user_feedback, description)

await self.ui.send_task_progress(
n_tasks, # iterations and reviews can be created only one at a time, so we are always on last one
Expand All @@ -114,7 +153,6 @@ async def breakdown_current_iteration(self, task_review_feedback: Optional[str]
AgentConvo(self)
.template(
"iteration",
current_task=current_task,
user_feedback=user_feedback,
user_feedback_qa=None,
next_solution_to_try=None,
Expand Down Expand Up @@ -175,7 +213,7 @@ async def breakdown_current_task(self) -> AgentResponse:
log.debug(f"Current state files: {len(self.current_state.files)}, relevant {self.current_state.relevant_files}")
# Check which files are relevant to the current task
if self.current_state.files and self.current_state.relevant_files is None:
await self.get_relevant_files()
return await self.get_relevant_files()

current_task_index = self.current_state.tasks.index(current_task)

Expand All @@ -189,6 +227,8 @@ async def breakdown_current_task(self) -> AgentResponse:
)
response: str = await llm(convo)

await self.get_relevant_files(None, response)

self.next_state.tasks[current_task_index] = {
**current_task,
"instructions": response,
Expand All @@ -214,31 +254,6 @@ async def breakdown_current_task(self) -> AgentResponse:
)
return AgentResponse.done(self)

async def get_relevant_files(
    self, user_feedback: Optional[str] = None, solution_description: Optional[str] = None
) -> AgentResponse:
    """
    Ask the LLM which project files are relevant for the current task and
    store the validated list on the next project state.

    :param user_feedback: Optional user feedback to include in the prompt.
    :param solution_description: Optional solution description to include in the prompt.
    :return: AgentResponse.done(self)
    """
    log.debug("Getting relevant files for the current task")
    await self.send_message("Figuring out which project files are relevant for the next task ...")

    llm = self.get_llm()
    convo = (
        AgentConvo(self)
        .template(
            "filter_files",
            current_task=self.current_state.current_task,
            user_feedback=user_feedback,
            solution_description=solution_description,
        )
        .require_schema(RelevantFiles)
    )

    # The parser returns a RelevantFiles model instance, not a plain list.
    llm_response: RelevantFiles = await llm(convo, parser=JSONParser(RelevantFiles), temperature=0)

    # Keep only paths that actually exist in the project, in case the LLM invented any.
    existing_files = {file.path for file in self.current_state.files}
    self.next_state.relevant_files = [path for path in llm_response.relevant_files if path in existing_files]

    return AgentResponse.done(self)

def set_next_steps(self, response: TaskSteps, source: str):
# For logging/debugging purposes, we don't want to remove the finished steps
# until we're done with the task.
Expand Down
106 changes: 65 additions & 41 deletions core/agents/mixins.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,21 @@
from enum import Enum
from typing import Annotated, Literal, Optional, Union
from typing import Optional

from pydantic import BaseModel, Field

from core.agents.convo import AgentConvo
from core.agents.response import AgentResponse
from core.config import GET_RELEVANT_FILES_AGENT_NAME
from core.llm.parser import JSONParser
from core.log import get_logger

log = get_logger(__name__)

class StepType(str, Enum):
COMMAND = "command"
SAVE_FILE = "save_file"
HUMAN_INTERVENTION = "human_intervention"


class CommandOptions(BaseModel):
command: str = Field(description="Command to run")
timeout: int = Field(description="Timeout in seconds")
success_message: str = ""


class SaveFileOptions(BaseModel):
path: str


class SaveFileStep(BaseModel):
type: Literal[StepType.SAVE_FILE] = StepType.SAVE_FILE
save_file: SaveFileOptions


class CommandStep(BaseModel):
type: Literal[StepType.COMMAND] = StepType.COMMAND
command: CommandOptions


class HumanInterventionStep(BaseModel):
type: Literal[StepType.HUMAN_INTERVENTION] = StepType.HUMAN_INTERVENTION
human_intervention_description: str


Step = Annotated[
Union[SaveFileStep, CommandStep, HumanInterventionStep],
Field(discriminator="type"),
]


class TaskSteps(BaseModel):
steps: list[Step]
class RelevantFiles(BaseModel):
    """Schema for one round of the iterative relevant-file selection loop."""

    read_files: list[str] = Field(description="List of files you want to read.")
    add_files: list[str] = Field(description="List of files you want to add to the list of relevant files.")
    remove_files: list[str] = Field(description="List of files you want to remove from the list of relevant files.")
    done: bool = Field(description="Boolean flag to indicate that you are done selecting relevant files.")


class IterationPromptMixin:
Expand Down Expand Up @@ -74,11 +45,64 @@ async def find_solution(
llm = self.get_llm()
convo = AgentConvo(self).template(
"iteration",
current_task=self.current_state.current_task,
user_feedback=user_feedback,
user_feedback_qa=user_feedback_qa,
next_solution_to_try=next_solution_to_try,
bug_hunting_cycles=bug_hunting_cycles,
)
llm_solution: str = await llm(convo)
return llm_solution


class RelevantFilesMixin:
    """
    Provides a method to get relevant files for the current task.
    """

    # Hard cap on the conversation length so a model that never reports
    # done=True cannot loop forever.
    MAX_CONVO_MESSAGES = 13

    async def get_relevant_files(
        self, user_feedback: Optional[str] = None, solution_description: Optional[str] = None
    ) -> AgentResponse:
        """
        Iteratively ask the LLM to select files relevant to the current task.

        In each round the LLM may request file contents (read_files) and
        add/remove paths from the running selection, until it reports done
        or the conversation cap is reached. The validated selection is
        stored on self.next_state.relevant_files.

        :param user_feedback: Optional user feedback to include in the prompt.
        :param solution_description: Optional solution description to include in the prompt.
        :return: AgentResponse.done(self)
        """
        log.debug("Getting relevant files for the current task")
        await self.send_message("Figuring out which project files are relevant for the next task ...")

        done = False
        relevant_files: set[str] = set()
        llm = self.get_llm(GET_RELEVANT_FILES_AGENT_NAME)
        convo = (
            AgentConvo(self)
            .template(
                "filter_files",
                user_feedback=user_feedback,
                solution_description=solution_description,
                relevant_files=relevant_files,
            )
            .require_schema(RelevantFiles)
        )

        while not done and len(convo.messages) < self.MAX_CONVO_MESSAGES:
            llm_response: RelevantFiles = await llm(convo, parser=JSONParser(RelevantFiles), temperature=0)

            # Apply the requested additions and removals. Set operations are
            # no-ops for empty lists and ignore duplicates, so no guards or
            # membership pre-filtering are needed.
            relevant_files.update(llm_response.add_files)
            relevant_files.difference_update(llm_response.remove_files)

            # Collect only the files the LLM asked to read; use a set for
            # O(1) path membership tests instead of scanning a list per file.
            requested_paths = set(llm_response.read_files)
            read_files = [file for file in self.current_state.files if file.path in requested_paths]

            # Keep the raw assistant reply in the history (it carries
            # original_response), then prompt for the next selection round.
            convo.remove_last_x_messages(1)
            convo.assistant(llm_response.original_response)
            convo.template("filter_files_loop", read_files=read_files, relevant_files=relevant_files).require_schema(
                RelevantFiles
            )
            done = llm_response.done

        # Keep only paths that actually exist, in case the LLM invented any;
        # use a distinct name so the working set isn't rebound to a list.
        existing_files = {file.path for file in self.current_state.files}
        self.next_state.relevant_files = [path for path in relevant_files if path in existing_files]

        return AgentResponse.done(self)
1 change: 1 addition & 0 deletions core/agents/spec_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ async def initialize_spec(self) -> AgentResponse:
},
)

reviewed_spec = user_description
if len(user_description) < ANALYZE_THRESHOLD and complexity != Complexity.SIMPLE:
initial_spec = await self.analyze_spec(user_description)
reviewed_spec = await self.review_spec(desc=user_description, spec=initial_spec)
Expand Down
7 changes: 6 additions & 1 deletion core/agents/task_reviewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ async def review_code_changes(self) -> AgentResponse:
# Some iterations are created by the task reviewer and have no user feedback
if iteration["user_feedback"]
]
bug_hunter_instructions = [
iteration["bug_hunting_cycles"][-1]["human_readable_instructions"].replace("```", "").strip()
for iteration in self.current_state.iterations
if iteration["bug_hunting_cycles"]
]

files_before_modification = self.current_state.modified_files
files_after_modification = [
Expand All @@ -40,10 +45,10 @@ async def review_code_changes(self) -> AgentResponse:
# TODO instead of sending files before and after maybe add nice way to show diff for multiple files
convo = AgentConvo(self).template(
"review_task",
current_task=self.current_state.current_task,
all_feedbacks=all_feedbacks,
files_before_modification=files_before_modification,
files_after_modification=files_after_modification,
bug_hunter_instructions=bug_hunter_instructions,
)
llm_response: str = await llm(convo, temperature=0.7)

Expand Down
5 changes: 3 additions & 2 deletions core/agents/troubleshooter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from core.agents.base import BaseAgent
from core.agents.convo import AgentConvo
from core.agents.mixins import IterationPromptMixin
from core.agents.mixins import IterationPromptMixin, RelevantFilesMixin
from core.agents.response import AgentResponse
from core.db.models.file import File
from core.db.models.project_state import IterationStatus, TaskStatus
Expand All @@ -28,7 +28,7 @@ class RouteFilePaths(BaseModel):
files: list[str] = Field(description="List of paths for files that contain routes")


class Troubleshooter(IterationPromptMixin, BaseAgent):
class Troubleshooter(IterationPromptMixin, RelevantFilesMixin, BaseAgent):
agent_type = "troubleshooter"
display_name = "Troubleshooter"

Expand Down Expand Up @@ -102,6 +102,7 @@ async def create_iteration(self) -> AgentResponse:
else:
# should be - elif change_description is not None: - but to prevent bugs with the extension
# this might be caused if we show the input field instead of buttons
await self.get_relevant_files(user_feedback)
iteration_status = IterationStatus.NEW_FEATURE_REQUESTED

self.next_state.iterations = self.current_state.iterations + [
Expand Down
2 changes: 2 additions & 0 deletions core/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
CHECK_LOGS_AGENT_NAME = "BugHunter.check_logs"
TASK_BREAKDOWN_AGENT_NAME = "Developer.breakdown_current_task"
SPEC_WRITER_AGENT_NAME = "SpecWriter"
GET_RELEVANT_FILES_AGENT_NAME = "get_relevant_files"

# Endpoint for the external documentation
EXTERNAL_DOCUMENTATION_API = "http://docs-pythagora-io-439719575.us-east-1.elb.amazonaws.com"
Expand Down Expand Up @@ -330,6 +331,7 @@ class Config(_StrictModel):
temperature=0.5,
),
SPEC_WRITER_AGENT_NAME: AgentLLMConfig(model="gpt-4-0125-preview", temperature=0.0),
GET_RELEVANT_FILES_AGENT_NAME: AgentLLMConfig(model="claude-3-5-sonnet-20240620", temperature=0.0),
}
)
prompt: PromptConfig = PromptConfig()
Expand Down
1 change: 1 addition & 0 deletions core/db/models/project_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def complete_iteration(self):

log.debug(f"Completing iteration {self.unfinished_iterations[0]}")
self.unfinished_iterations[0]["status"] = IterationStatus.DONE
self.relevant_files = None
self.flag_iterations_as_modified()

def flag_iterations_as_modified(self):
Expand Down
18 changes: 15 additions & 3 deletions core/llm/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from enum import Enum
from typing import Optional, Union

from pydantic import BaseModel, ValidationError
from pydantic import BaseModel, ValidationError, create_model


class MultiCodeBlockParser:
Expand Down Expand Up @@ -86,6 +86,7 @@ class JSONParser:
def __init__(self, spec: Optional[BaseModel] = None, strict: bool = True):
    """
    Create a JSON parser, optionally validating against a pydantic schema.

    :param spec: Optional pydantic model to validate the parsed JSON against.
    :param strict: Whether parsing failures should be treated as errors.
        Providing a spec always forces strict mode, regardless of this flag.
    """
    self.spec = spec
    # A schema spec implies strict parsing; otherwise honor the caller's flag.
    self.strict = True if spec is not None else strict
    # Raw LLM response text, captured on each __call__ for later reference.
    self.original_response: Optional[str] = None

@property
def schema(self):
Expand All @@ -102,7 +103,8 @@ def errors_to_markdown(errors: list) -> str:
return "\n".join(error_txt)

def __call__(self, text: str) -> Union[BaseModel, dict, None]:
text = text.strip()
self.original_response = text.strip() # Store the original text
text = self.original_response
if text.startswith("```"):
try:
text = CodeBlockParser()(text)
Expand Down Expand Up @@ -130,7 +132,17 @@ def __call__(self, text: str) -> Union[BaseModel, dict, None]:
except Exception as err:
raise ValueError(f"Error parsing JSON: {err}") from err

return model
# Create a new model that includes the original model fields and the original text
ExtendedModel = create_model(
f"Extended{self.spec.__name__}",
original_response=(str, ...),
**{field_name: (field.annotation, field.default) for field_name, field in self.spec.__fields__.items()},
)

# Instantiate the extended model
extended_model = ExtendedModel(original_response=self.original_response, **model.dict())

return extended_model


class EnumParser:
Expand Down
Loading

0 comments on commit 82411d8

Please sign in to comment.