From 1510d858066743e15347c519e5c84c1938833c28 Mon Sep 17 00:00:00 2001
From: killian <63927363+KillianLucas@users.noreply.github.com>
Date: Tue, 3 Dec 2024 13:17:02 -0800
Subject: [PATCH] Better tool output for non Anthropic models
---
interpreter/commands.py | 3 +
interpreter/interpreter.py | 114 ++++++++++++++++++++--------------
interpreter/misc/get_input.py | 102 ++++++++++++++++--------------
interpreter/profiles.py | 2 +-
4 files changed, 127 insertions(+), 94 deletions(-)
diff --git a/interpreter/commands.py b/interpreter/commands.py
index 7da4f6de5..06fba98a4 100644
--- a/interpreter/commands.py
+++ b/interpreter/commands.py
@@ -160,6 +160,9 @@ def _handle_set_command(self, parts: list[str]) -> bool:
value_str = parts[2]
type_hint, _ = SETTINGS[param]
try:
+ self.interpreter._client = (
+ None # Reset client, in case they changed API key or API base
+ )
value = parse_value(value_str, type_hint)
setattr(self.interpreter, param, value)
print(f"Set {param} = {value}")
diff --git a/interpreter/interpreter.py b/interpreter/interpreter.py
index 8a6d72372..771fcc655 100644
--- a/interpreter/interpreter.py
+++ b/interpreter/interpreter.py
@@ -249,8 +249,8 @@ async def async_respond(self, user_input=None):
provider = self.provider # Keep existing provider if set
max_tokens = self.max_tokens # Keep existing max_tokens if set
- if self.model == "claude-3-5-sonnet-latest":
- # For some reason, Litellm can't find the model info for claude-3-5-sonnet-latest
+ if self.model in ["claude-3-5-sonnet-latest", "claude-3-5-sonnet-20241022"]:
+ # For some reason, Litellm can't find the model info for these
provider = "anthropic"
# Only try to get model info if we need either provider or max_tokens
@@ -294,33 +294,33 @@ async def async_respond(self, user_input=None):
self._spinner.start()
- enable_prompt_caching = False
betas = [COMPUTER_USE_BETA_FLAG]
- if enable_prompt_caching:
- betas.append(PROMPT_CACHING_BETA_FLAG)
- image_truncation_threshold = 50
- system["cache_control"] = {"type": "ephemeral"}
-
edit = ToolRenderer()
if (
provider == "anthropic" and not self.serve
): # Server can't handle Anthropic yet
if self._client is None:
- if self.api_key:
- self._client = Anthropic(api_key=self.api_key)
- else:
- self._client = Anthropic()
+ anthropic_params = {}
+ if self.api_key is not None:
+ anthropic_params["api_key"] = self.api_key
+ if self.api_base is not None:
+ anthropic_params["base_url"] = self.api_base
+ self._client = Anthropic(**anthropic_params)
if self.debug:
print("Sending messages:", self.messages, "\n")
+ model = self.model
+ if model.startswith("anthropic/"):
+ model = model[len("anthropic/") :]
+
# Use Anthropic API which supports betas
raw_response = self._client.beta.messages.create(
max_tokens=max_tokens,
messages=self.messages,
- model=self.model,
+ model=model,
system=system["text"],
tools=tool_collection.to_params(),
betas=betas,
@@ -698,7 +698,7 @@ async def async_respond(self, user_input=None):
"temperature": self.temperature,
"api_key": self.api_key,
"api_version": self.api_version,
- "parallel_tool_calls": False,
+ # "parallel_tool_calls": True,
}
if self.tool_calling:
@@ -707,13 +707,32 @@ async def async_respond(self, user_input=None):
params["stream"] = False
stream = False
- if self.debug:
- print(params)
+ if provider == "anthropic" and self.tool_calling:
+ params["tools"] = tool_collection.to_params()
+ for t in params["tools"]:
+ t["function"] = {"name": t["name"]}
+ if t["name"] == "computer":
+ t["function"]["parameters"] = {
+ "display_height_px": t["display_height_px"],
+ "display_width_px": t["display_width_px"],
+ "display_number": t["display_number"],
+ }
+ params["extra_headers"] = {
+ "anthropic-beta": "computer-use-2024-10-22"
+ }
- if self.debug:
- print("Sending request...", params)
+ # if self.debug:
+ # print("Sending request...", params)
+ # time.sleep(3)
- time.sleep(3)
+ if self.debug:
+ print("Messages:")
+ for m in self.messages:
+ if len(str(m)) > 1000:
+ print(str(m)[:1000] + "...")
+ else:
+ print(str(m))
+ print()
raw_response = litellm.completion(**params)
@@ -856,6 +875,8 @@ async def async_respond(self, user_input=None):
else:
user_approval = input("\nRun tool(s)? (y/n): ").lower().strip()
+ user_content_to_add = []
+
for tool_call in message.tool_calls:
function_arguments = json.loads(tool_call.function.arguments)
@@ -869,43 +890,46 @@ async def async_respond(self, user_input=None):
if self.tool_calling:
if result.base64_image:
- # Add image to tool result
self.messages.append(
{
"role": "tool",
- "content": "The user will reply with the image outputted by the tool.",
+ "content": "The user will reply with the tool's image output.",
"tool_call_id": tool_call.id,
}
)
- self.messages.append(
+ user_content_to_add.append(
{
- "role": "user",
- "content": [
- {
- "type": "image_url",
- "image_url": {
- "url": f"data:image/png;base64,{result.base64_image}",
- },
- }
- ],
- }
- )
- else:
- self.messages.append(
- {
- "role": "tool",
- "content": json.dumps(dataclasses.asdict(result)),
- "tool_call_id": tool_call.id,
+ "type": "image_url",
+ "image_url": {
+ "url": f"data:image/png;base64,{result.base64_image}",
+ },
}
)
else:
- self.messages.append(
- {
- "role": "user",
- "content": "This was the output of the tool call. What does it mean/what's next?"
- + json.dumps(dataclasses.asdict(result)),
- }
+ text_content = (
+ "This was the output of the tool call. What does it mean/what's next?\n"
+ + (result.output or "")
)
+ if result.base64_image:
+ content = [
+ {"type": "text", "text": text_content},
+ {
+ "type": "image",
+ "image_url": {
+ "url": "data:image/png;base64,"
+ + result.base64_image
+ },
+ },
+ ]
+ else:
+ content = text_content
+
+ self.messages.append({"role": "user", "content": content})
+
+ if user_content_to_add:
+ self.messages.append(
+ {"role": "user", "content": user_content_to_add}
+ )
def _ask_user_approval(self) -> str:
"""Ask user for approval to run a tool"""
diff --git a/interpreter/misc/get_input.py b/interpreter/misc/get_input.py
index a973b5f3b..1409afac5 100644
--- a/interpreter/misc/get_input.py
+++ b/interpreter/misc/get_input.py
@@ -11,7 +11,8 @@ async def async_get_input(
placeholder_color: str = "gray",
multiline_support: bool = True,
) -> str:
- placeholder_text = "Describe command"
+ # placeholder_text = "Describe command"
+ placeholder_text = 'Use """ for multi-line input'
history = InMemoryHistory()
session = PromptSession(
history=history,
@@ -27,11 +28,16 @@ async def async_get_input(
def _(event):
current_line = event.current_buffer.document.current_line.rstrip()
- if current_line == '"""':
- multiline[0] = not multiline[0]
+ if not multiline[0] and current_line.endswith('"""'):
+ # Enter multiline mode
+ multiline[0] = True
event.current_buffer.insert_text("\n")
- if not multiline[0]: # If exiting multiline mode, submit
- event.current_buffer.validate_and_handle()
+ return
+
+ if multiline[0] and current_line.startswith('"""'):
+ # Exit multiline mode and submit
+ multiline[0] = False
+ event.current_buffer.validate_and_handle()
return
if multiline[0]:
@@ -55,50 +61,50 @@ def _(event):
return result
-def get_input(
- placeholder_text: Optional[str] = None,
- placeholder_color: str = "gray",
- multiline_support: bool = True,
-) -> str:
- placeholder_text = "Describe command"
- history = InMemoryHistory()
- session = PromptSession(
- history=history,
- enable_open_in_editor=False,
- enable_history_search=False,
- auto_suggest=None,
- multiline=True,
- )
- kb = KeyBindings()
- multiline = [False]
+# def get_input(
+# placeholder_text: Optional[str] = None,
+# placeholder_color: str = "gray",
+# multiline_support: bool = True,
+# ) -> str:
+# placeholder_text = "Describe command"
+# history = InMemoryHistory()
+# session = PromptSession(
+# history=history,
+# enable_open_in_editor=False,
+# enable_history_search=False,
+# auto_suggest=None,
+# multiline=True,
+# )
+# kb = KeyBindings()
+# multiline = [False]
- @kb.add("enter")
- def _(event):
- current_line = event.current_buffer.document.current_line.rstrip()
+# @kb.add("enter")
+# def _(event):
+# current_line = event.current_buffer.document.current_line.rstrip()
- if current_line == '"""':
- multiline[0] = not multiline[0]
- event.current_buffer.insert_text("\n")
- if not multiline[0]: # If exiting multiline mode, submit
- event.current_buffer.validate_and_handle()
- return
+# if current_line == '"""':
+# multiline[0] = not multiline[0]
+# event.current_buffer.insert_text("\n")
+# if not multiline[0]: # If exiting multiline mode, submit
+# event.current_buffer.validate_and_handle()
+# return
- if multiline[0]:
- event.current_buffer.insert_text("\n")
- else:
- event.current_buffer.validate_and_handle()
+# if multiline[0]:
+# event.current_buffer.insert_text("\n")
+# else:
+# event.current_buffer.validate_and_handle()
- result = session.prompt(
- "> ",
- placeholder=HTML(f'')
- if placeholder_text
- else None,
- key_bindings=kb,
- complete_while_typing=False,
- enable_suspend=False,
- search_ignore_case=True,
- include_default_pygments_style=False,
- input_processors=[],
- enable_system_prompt=False,
- )
- return result
+# result = session.prompt(
+# "> ",
+# placeholder=HTML(f'')
+# if placeholder_text
+# else None,
+# key_bindings=kb,
+# complete_while_typing=False,
+# enable_suspend=False,
+# search_ignore_case=True,
+# include_default_pygments_style=False,
+# input_processors=[],
+# enable_system_prompt=False,
+# )
+# return result
diff --git a/interpreter/profiles.py b/interpreter/profiles.py
index 6a1d6f53c..140361c14 100644
--- a/interpreter/profiles.py
+++ b/interpreter/profiles.py
@@ -32,7 +32,7 @@ class Profile:
def __init__(self):
# Default values if no profile exists
# Model configuration
- self.model = "claude-3-5-sonnet-latest" # The LLM model to use
+ self.model = "claude-3-5-sonnet-20241022" # The LLM model to use
self.provider = (
None # The model provider (e.g. anthropic, openai) None will auto-detect
)