From 1510d858066743e15347c519e5c84c1938833c28 Mon Sep 17 00:00:00 2001 From: killian <63927363+KillianLucas@users.noreply.github.com> Date: Tue, 3 Dec 2024 13:17:02 -0800 Subject: [PATCH] Better tool output for non Anthropic models --- interpreter/commands.py | 3 + interpreter/interpreter.py | 114 ++++++++++++++++++++-------------- interpreter/misc/get_input.py | 102 ++++++++++++++++-------------- interpreter/profiles.py | 2 +- 4 files changed, 127 insertions(+), 94 deletions(-) diff --git a/interpreter/commands.py b/interpreter/commands.py index 7da4f6de5..06fba98a4 100644 --- a/interpreter/commands.py +++ b/interpreter/commands.py @@ -160,6 +160,9 @@ def _handle_set_command(self, parts: list[str]) -> bool: value_str = parts[2] type_hint, _ = SETTINGS[param] try: + self.interpreter._client = ( + None # Reset client, in case they changed API key or API base + ) value = parse_value(value_str, type_hint) setattr(self.interpreter, param, value) print(f"Set {param} = {value}") diff --git a/interpreter/interpreter.py b/interpreter/interpreter.py index 8a6d72372..771fcc655 100644 --- a/interpreter/interpreter.py +++ b/interpreter/interpreter.py @@ -249,8 +249,8 @@ async def async_respond(self, user_input=None): provider = self.provider # Keep existing provider if set max_tokens = self.max_tokens # Keep existing max_tokens if set - if self.model == "claude-3-5-sonnet-latest": - # For some reason, Litellm can't find the model info for claude-3-5-sonnet-latest + if self.model in ["claude-3-5-sonnet-latest", "claude-3-5-sonnet-20241022"]: + # For some reason, Litellm can't find the model info for these provider = "anthropic" # Only try to get model info if we need either provider or max_tokens @@ -294,33 +294,33 @@ async def async_respond(self, user_input=None): self._spinner.start() - enable_prompt_caching = False betas = [COMPUTER_USE_BETA_FLAG] - if enable_prompt_caching: - betas.append(PROMPT_CACHING_BETA_FLAG) - image_truncation_threshold = 50 - system["cache_control"] = {"type": "ephemeral"} - edit = ToolRenderer() if ( provider == "anthropic" and not self.serve ): # Server can't handle Anthropic yet if self._client is None: - if self.api_key: - self._client = Anthropic(api_key=self.api_key) - else: - self._client = Anthropic() + anthropic_params = {} + if self.api_key is not None: + anthropic_params["api_key"] = self.api_key + if self.api_base is not None: + anthropic_params["base_url"] = self.api_base + self._client = Anthropic(**anthropic_params) if self.debug: print("Sending messages:", self.messages, "\n") + model = self.model + if model.startswith("anthropic/"): + model = model[len("anthropic/") :] + # Use Anthropic API which supports betas raw_response = self._client.beta.messages.create( max_tokens=max_tokens, messages=self.messages, - model=self.model, + model=model, system=system["text"], tools=tool_collection.to_params(), betas=betas, @@ -698,7 +698,7 @@ async def async_respond(self, user_input=None): "temperature": self.temperature, "api_key": self.api_key, "api_version": self.api_version, - "parallel_tool_calls": False, + # "parallel_tool_calls": True, } if self.tool_calling: @@ -707,13 +707,32 @@ async def async_respond(self, user_input=None): params["stream"] = False stream = False - if self.debug: - print(params) + if provider == "anthropic" and self.tool_calling: + params["tools"] = tool_collection.to_params() + for t in params["tools"]: + t["function"] = {"name": t["name"]} + if t["name"] == "computer": + t["function"]["parameters"] = { + "display_height_px": t["display_height_px"], + "display_width_px": t["display_width_px"], + "display_number": t["display_number"], + } + params["extra_headers"] = { + "anthropic-beta": "computer-use-2024-10-22" + } - if self.debug: - print("Sending request...", params) + # if self.debug: + # print("Sending request...", params) + # time.sleep(3) - time.sleep(3) + if self.debug: + print("Messages:") + for m in self.messages: + if len(str(m)) > 1000: + print(str(m)[:1000] + "...") + else: + print(str(m)) + print() raw_response = litellm.completion(**params) @@ -856,6 +875,8 @@ async def async_respond(self, user_input=None): else: user_approval = input("\nRun tool(s)? (y/n): ").lower().strip() + user_content_to_add = [] + for tool_call in message.tool_calls: function_arguments = json.loads(tool_call.function.arguments) @@ -869,43 +890,46 @@ async def async_respond(self, user_input=None): if self.tool_calling: if result.base64_image: - # Add image to tool result self.messages.append( { "role": "tool", - "content": "The user will reply with the image outputted by the tool.", + "content": "The user will reply with the tool's image output.", "tool_call_id": tool_call.id, } ) - self.messages.append( + user_content_to_add.append( { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": f"data:image/png;base64,{result.base64_image}", - }, - } - ], - } - ) - else: - self.messages.append( - { - "role": "tool", - "content": json.dumps(dataclasses.asdict(result)), - "tool_call_id": tool_call.id, + "type": "image_url", + "image_url": { + "url": f"data:image/png;base64,{result.base64_image}", + }, } ) else: - self.messages.append( - { - "role": "user", - "content": "This was the output of the tool call. What does it mean/what's next?" - + json.dumps(dataclasses.asdict(result)), - } + text_content = ( + "This was the output of the tool call. What does it mean/what's next?\n" + + (result.output or "") ) + if result.base64_image: + content = [ + {"type": "text", "text": text_content}, + { + "type": "image", + "image_url": { + "url": "data:image/png;base64," + + result.base64_image + }, + }, + ] + else: + content = text_content + + self.messages.append({"role": "user", "content": content}) + + if user_content_to_add: + self.messages.append( + {"role": "user", "content": user_content_to_add} + ) def _ask_user_approval(self) -> str: """Ask user for approval to run a tool""" diff --git a/interpreter/misc/get_input.py b/interpreter/misc/get_input.py index a973b5f3b..1409afac5 100644 --- a/interpreter/misc/get_input.py +++ b/interpreter/misc/get_input.py @@ -11,7 +11,8 @@ async def async_get_input( placeholder_color: str = "gray", multiline_support: bool = True, ) -> str: - placeholder_text = "Describe command" + # placeholder_text = "Describe command" + placeholder_text = 'Use """ for multi-line input' history = InMemoryHistory() session = PromptSession( history=history, @@ -27,11 +28,16 @@ async def async_get_input( def _(event): current_line = event.current_buffer.document.current_line.rstrip() - if current_line == '"""': - multiline[0] = not multiline[0] + if not multiline[0] and current_line.endswith('"""'): + # Enter multiline mode + multiline[0] = True event.current_buffer.insert_text("\n") - if not multiline[0]: # If exiting multiline mode, submit - event.current_buffer.validate_and_handle() + return + + if multiline[0] and current_line.startswith('"""'): + # Exit multiline mode and submit + multiline[0] = False + event.current_buffer.validate_and_handle() return if multiline[0]: @@ -55,50 +61,50 @@ def _(event): return result -def get_input( - placeholder_text: Optional[str] = None, - placeholder_color: str = "gray", - multiline_support: bool = True, -) -> str: - placeholder_text = "Describe command" - history = InMemoryHistory() - session = PromptSession( - history=history, - enable_open_in_editor=False, - enable_history_search=False, - auto_suggest=None, - multiline=True, - ) - kb = KeyBindings() - multiline = [False] +# def get_input( +# placeholder_text: Optional[str] = None, +# placeholder_color: str = "gray", +# multiline_support: bool = True, +# ) -> str: +# placeholder_text = "Describe command" +# history = InMemoryHistory() +# session = PromptSession( +# history=history, +# enable_open_in_editor=False, +# enable_history_search=False, +# auto_suggest=None, +# multiline=True, +# ) +# kb = KeyBindings() +# multiline = [False] - @kb.add("enter") - def _(event): - current_line = event.current_buffer.document.current_line.rstrip() +# @kb.add("enter") +# def _(event): +# current_line = event.current_buffer.document.current_line.rstrip() - if current_line == '"""': - multiline[0] = not multiline[0] - event.current_buffer.insert_text("\n") - if not multiline[0]: # If exiting multiline mode, submit - event.current_buffer.validate_and_handle() - return +# if current_line == '"""': +# multiline[0] = not multiline[0] +# event.current_buffer.insert_text("\n") +# if not multiline[0]: # If exiting multiline mode, submit +# event.current_buffer.validate_and_handle() +# return - if multiline[0]: - event.current_buffer.insert_text("\n") - else: - event.current_buffer.validate_and_handle() +# if multiline[0]: +# event.current_buffer.insert_text("\n") +# else: +# event.current_buffer.validate_and_handle() - result = session.prompt( - "> ", - placeholder=HTML(f'') - if placeholder_text - else None, - key_bindings=kb, - complete_while_typing=False, - enable_suspend=False, - search_ignore_case=True, - include_default_pygments_style=False, - input_processors=[], - enable_system_prompt=False, - ) - return result +# result = session.prompt( +# "> ", +# placeholder=HTML(f'') +# if placeholder_text +# else None, +# key_bindings=kb, +# complete_while_typing=False, +# enable_suspend=False, +# search_ignore_case=True, +# include_default_pygments_style=False, +# input_processors=[], +# enable_system_prompt=False, +# ) +# return result diff --git a/interpreter/profiles.py b/interpreter/profiles.py index 6a1d6f53c..140361c14 100644 --- a/interpreter/profiles.py +++ b/interpreter/profiles.py @@ -32,7 +32,7 @@ class Profile: def __init__(self): # Default values if no profile exists # Model configuration - self.model = "claude-3-5-sonnet-latest" # The LLM model to use + self.model = "claude-3-5-sonnet-20241022" # The LLM model to use self.provider = ( None # The model provider (e.g. anthropic, openai) None will auto-detect )