Skip to content

Commit

Permalink
Better tool output for non-Anthropic models
Browse files Browse the repository at this point in the history
  • Loading branch information
KillianLucas committed Dec 3, 2024
1 parent c365866 commit 1510d85
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 94 deletions.
3 changes: 3 additions & 0 deletions interpreter/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ def _handle_set_command(self, parts: list[str]) -> bool:
value_str = parts[2]
type_hint, _ = SETTINGS[param]
try:
self.interpreter._client = (
None # Reset client, in case they changed API key or API base
)
value = parse_value(value_str, type_hint)
setattr(self.interpreter, param, value)
print(f"Set {param} = {value}")
Expand Down
114 changes: 69 additions & 45 deletions interpreter/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,8 @@ async def async_respond(self, user_input=None):
provider = self.provider # Keep existing provider if set
max_tokens = self.max_tokens # Keep existing max_tokens if set

if self.model == "claude-3-5-sonnet-latest":
# For some reason, Litellm can't find the model info for claude-3-5-sonnet-latest
if self.model in ["claude-3-5-sonnet-latest", "claude-3-5-sonnet-20241022"]:
# For some reason, Litellm can't find the model info for these
provider = "anthropic"

# Only try to get model info if we need either provider or max_tokens
Expand Down Expand Up @@ -294,33 +294,33 @@ async def async_respond(self, user_input=None):

self._spinner.start()

enable_prompt_caching = False
betas = [COMPUTER_USE_BETA_FLAG]

if enable_prompt_caching:
betas.append(PROMPT_CACHING_BETA_FLAG)
image_truncation_threshold = 50
system["cache_control"] = {"type": "ephemeral"}

edit = ToolRenderer()

if (
provider == "anthropic" and not self.serve
): # Server can't handle Anthropic yet
if self._client is None:
if self.api_key:
self._client = Anthropic(api_key=self.api_key)
else:
self._client = Anthropic()
anthropic_params = {}
if self.api_key is not None:
anthropic_params["api_key"] = self.api_key
if self.api_base is not None:
anthropic_params["base_url"] = self.api_base
self._client = Anthropic(**anthropic_params)

if self.debug:
print("Sending messages:", self.messages, "\n")

model = self.model
if model.startswith("anthropic/"):
model = model[len("anthropic/") :]

# Use Anthropic API which supports betas
raw_response = self._client.beta.messages.create(
max_tokens=max_tokens,
messages=self.messages,
model=self.model,
model=model,
system=system["text"],
tools=tool_collection.to_params(),
betas=betas,
Expand Down Expand Up @@ -698,7 +698,7 @@ async def async_respond(self, user_input=None):
"temperature": self.temperature,
"api_key": self.api_key,
"api_version": self.api_version,
"parallel_tool_calls": False,
# "parallel_tool_calls": True,
}

if self.tool_calling:
Expand All @@ -707,13 +707,32 @@ async def async_respond(self, user_input=None):
params["stream"] = False
stream = False

if self.debug:
print(params)
if provider == "anthropic" and self.tool_calling:
params["tools"] = tool_collection.to_params()
for t in params["tools"]:
t["function"] = {"name": t["name"]}
if t["name"] == "computer":
t["function"]["parameters"] = {
"display_height_px": t["display_height_px"],
"display_width_px": t["display_width_px"],
"display_number": t["display_number"],
}
params["extra_headers"] = {
"anthropic-beta": "computer-use-2024-10-22"
}

if self.debug:
print("Sending request...", params)
# if self.debug:
# print("Sending request...", params)
# time.sleep(3)

time.sleep(3)
if self.debug:
print("Messages:")
for m in self.messages:
if len(str(m)) > 1000:
print(str(m)[:1000] + "...")
else:
print(str(m))
print()

raw_response = litellm.completion(**params)

Expand Down Expand Up @@ -856,6 +875,8 @@ async def async_respond(self, user_input=None):
else:
user_approval = input("\nRun tool(s)? (y/n): ").lower().strip()

user_content_to_add = []

for tool_call in message.tool_calls:
function_arguments = json.loads(tool_call.function.arguments)

Expand All @@ -869,43 +890,46 @@ async def async_respond(self, user_input=None):

if self.tool_calling:
if result.base64_image:
# Add image to tool result
self.messages.append(
{
"role": "tool",
"content": "The user will reply with the image outputted by the tool.",
"content": "The user will reply with the tool's image output.",
"tool_call_id": tool_call.id,
}
)
self.messages.append(
user_content_to_add.append(
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{result.base64_image}",
},
}
],
}
)
else:
self.messages.append(
{
"role": "tool",
"content": json.dumps(dataclasses.asdict(result)),
"tool_call_id": tool_call.id,
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{result.base64_image}",
},
}
)
else:
self.messages.append(
{
"role": "user",
"content": "This was the output of the tool call. What does it mean/what's next?"
+ json.dumps(dataclasses.asdict(result)),
}
text_content = (
"This was the output of the tool call. What does it mean/what's next?\n"
+ (result.output or "")
)
if result.base64_image:
content = [
{"type": "text", "text": text_content},
{
"type": "image",
"image_url": {
"url": "data:image/png;base64,"
+ result.base64_image
},
},
]
else:
content = text_content

self.messages.append({"role": "user", "content": content})

if user_content_to_add:
self.messages.append(
{"role": "user", "content": user_content_to_add}
)

def _ask_user_approval(self) -> str:
"""Ask user for approval to run a tool"""
Expand Down
102 changes: 54 additions & 48 deletions interpreter/misc/get_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ async def async_get_input(
placeholder_color: str = "gray",
multiline_support: bool = True,
) -> str:
placeholder_text = "Describe command"
# placeholder_text = "Describe command"
placeholder_text = 'Use """ for multi-line input'
history = InMemoryHistory()
session = PromptSession(
history=history,
Expand All @@ -27,11 +28,16 @@ async def async_get_input(
def _(event):
current_line = event.current_buffer.document.current_line.rstrip()

if current_line == '"""':
multiline[0] = not multiline[0]
if not multiline[0] and current_line.endswith('"""'):
# Enter multiline mode
multiline[0] = True
event.current_buffer.insert_text("\n")
if not multiline[0]: # If exiting multiline mode, submit
event.current_buffer.validate_and_handle()
return

if multiline[0] and current_line.startswith('"""'):
# Exit multiline mode and submit
multiline[0] = False
event.current_buffer.validate_and_handle()
return

if multiline[0]:
Expand All @@ -55,50 +61,50 @@ def _(event):
return result


def get_input(
placeholder_text: Optional[str] = None,
placeholder_color: str = "gray",
multiline_support: bool = True,
) -> str:
placeholder_text = "Describe command"
history = InMemoryHistory()
session = PromptSession(
history=history,
enable_open_in_editor=False,
enable_history_search=False,
auto_suggest=None,
multiline=True,
)
kb = KeyBindings()
multiline = [False]
# def get_input(
# placeholder_text: Optional[str] = None,
# placeholder_color: str = "gray",
# multiline_support: bool = True,
# ) -> str:
# placeholder_text = "Describe command"
# history = InMemoryHistory()
# session = PromptSession(
# history=history,
# enable_open_in_editor=False,
# enable_history_search=False,
# auto_suggest=None,
# multiline=True,
# )
# kb = KeyBindings()
# multiline = [False]

@kb.add("enter")
def _(event):
current_line = event.current_buffer.document.current_line.rstrip()
# @kb.add("enter")
# def _(event):
# current_line = event.current_buffer.document.current_line.rstrip()

if current_line == '"""':
multiline[0] = not multiline[0]
event.current_buffer.insert_text("\n")
if not multiline[0]: # If exiting multiline mode, submit
event.current_buffer.validate_and_handle()
return
# if current_line == '"""':
# multiline[0] = not multiline[0]
# event.current_buffer.insert_text("\n")
# if not multiline[0]: # If exiting multiline mode, submit
# event.current_buffer.validate_and_handle()
# return

if multiline[0]:
event.current_buffer.insert_text("\n")
else:
event.current_buffer.validate_and_handle()
# if multiline[0]:
# event.current_buffer.insert_text("\n")
# else:
# event.current_buffer.validate_and_handle()

result = session.prompt(
"> ",
placeholder=HTML(f'<style fg="{placeholder_color}">{placeholder_text}</style>')
if placeholder_text
else None,
key_bindings=kb,
complete_while_typing=False,
enable_suspend=False,
search_ignore_case=True,
include_default_pygments_style=False,
input_processors=[],
enable_system_prompt=False,
)
return result
# result = session.prompt(
# "> ",
# placeholder=HTML(f'<style fg="{placeholder_color}">{placeholder_text}</style>')
# if placeholder_text
# else None,
# key_bindings=kb,
# complete_while_typing=False,
# enable_suspend=False,
# search_ignore_case=True,
# include_default_pygments_style=False,
# input_processors=[],
# enable_system_prompt=False,
# )
# return result
2 changes: 1 addition & 1 deletion interpreter/profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class Profile:
def __init__(self):
# Default values if no profile exists
# Model configuration
self.model = "claude-3-5-sonnet-latest" # The LLM model to use
self.model = "claude-3-5-sonnet-20241022" # The LLM model to use
self.provider = (
None # The model provider (e.g. anthropic, openai) None will auto-detect
)
Expand Down

0 comments on commit 1510d85

Please sign in to comment.