Skip to content

Commit

Permalink
fix: tool calling cohere (#3355)
Browse files Browse the repository at this point in the history
* Add support for tool calling cohere

* update tool calling code

* make client name configurable with default

* formatting nits

* update docs

---------

Co-authored-by: Mark Sze <[email protected]>
Co-authored-by: Li Jiang <[email protected]>
  • Loading branch information
3 people authored Aug 28, 2024
1 parent 2e77d3b commit 5861bd9
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 27 deletions.
71 changes: 45 additions & 26 deletions autogen/oai/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"api_type": "cohere",
"model": "command-r-plus",
"api_key": os.environ.get("COHERE_API_KEY")
"client_name": "autogen-cohere", # Optional parameter
}
]}
Expand Down Expand Up @@ -144,7 +145,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
def create(self, params: Dict) -> ChatCompletion:

messages = params.get("messages", [])

client_name = params.get("client_name") or "autogen-cohere"
# Parse parameters to the Cohere API's parameters
cohere_params = self.parse_params(params)

Expand All @@ -156,7 +157,7 @@ def create(self, params: Dict) -> ChatCompletion:
cohere_params["preamble"] = preamble

# We use chat model by default
client = Cohere(api_key=self.api_key)
client = Cohere(api_key=self.api_key, client_name=client_name)

# Token counts will be returned
prompt_tokens = 0
Expand Down Expand Up @@ -285,6 +286,23 @@ def create(self, params: Dict) -> ChatCompletion:
return response_oai


def extract_to_cohere_tool_results(tool_call_id: str, content_output: str, all_tool_calls) -> List[Dict[str, Any]]:
temp_tool_results = []

for tool_call in all_tool_calls:
if tool_call["id"] == tool_call_id:

call = {
"name": tool_call["function"]["name"],
"parameters": json.loads(
tool_call["function"]["arguments"] if not tool_call["function"]["arguments"] == "" else "{}"
),
}
output = [{"value": content_output}]
temp_tool_results.append(ToolResult(call=call, outputs=output))
return temp_tool_results


def oai_messages_to_cohere_messages(
messages: list[Dict[str, Any]], params: Dict[str, Any], cohere_params: Dict[str, Any]
) -> tuple[list[dict[str, Any]], str, str]:
Expand Down Expand Up @@ -352,7 +370,8 @@ def oai_messages_to_cohere_messages(
# 'content' field renamed to 'message'
# tools go into tools parameter
# tool_results go into tool_results parameter
for message in messages:
messages_length = len(messages)
for index, message in enumerate(messages):

if "role" in message and message["role"] == "system":
# System message
Expand All @@ -369,34 +388,34 @@ def oai_messages_to_cohere_messages(
new_message = {
"role": "CHATBOT",
"message": message["content"],
# Not including tools in this message, may need to. Testing required.
"tool_calls": [
{
"name": tool_call_.get("function", {}).get("name"),
"parameters": json.loads(tool_call_.get("function", {}).get("arguments") or "null"),
}
for tool_call_ in message["tool_calls"]
],
}

cohere_messages.append(new_message)
elif "role" in message and message["role"] == "tool":
if "tool_call_id" in message:
# Convert the tool call to a result
if not (tool_call_id := message.get("tool_call_id")):
continue

# Convert the tool call to a result
content_output = message["content"]
tool_results_chat_turn = extract_to_cohere_tool_results(tool_call_id, content_output, tool_calls)
if (index == messages_length - 1) or (messages[index + 1].get("role", "").lower() in ("user", "tool")):
# If the tool call is the last message or the next message is a user/tool message, this is a recent tool call.
# So, we pass it into tool_results.
tool_results.extend(tool_results_chat_turn)
continue

tool_call_id = message["tool_call_id"]
content_output = message["content"]

# Find the original tool
for tool_call in tool_calls:
if tool_call["id"] == tool_call_id:

call = {
"name": tool_call["function"]["name"],
"parameters": json.loads(
tool_call["function"]["arguments"]
if not tool_call["function"]["arguments"] == ""
else "{}"
),
}
output = [{"value": content_output}]

tool_results.append(ToolResult(call=call, outputs=output))
else:
# If its not the current tool call, we pass it as a tool message in the chat history.
new_message = {"role": "TOOL", "tool_results": tool_results_chat_turn}
cohere_messages.append(new_message)

break
elif "content" in message and isinstance(message["content"], str):
# Standard text message
new_message = {
Expand All @@ -416,7 +435,7 @@ def oai_messages_to_cohere_messages(
# If we're adding tool_results, like we are, the last message can't be a USER message
# So, we add a CHATBOT 'continue' message, if so.
# Changed key from "content" to "message" (jaygdesai/autogen_Jay)
if cohere_messages[-1]["role"] == "USER":
if cohere_messages[-1]["role"].lower() == "user":
cohere_messages.append({"role": "CHATBOT", "message": "Please continue."})

# We return a blank message when we have tool results
Expand Down
4 changes: 3 additions & 1 deletion website/docs/topics/non-openai-models/cloud-cohere.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
"- seed (null, integer)\n",
"- frequency_penalty (number 0..1)\n",
"- presence_penalty (number 0..1)\n",
"- client_name (null, string)\n",
"\n",
"Example:\n",
"```python\n",
Expand All @@ -108,6 +109,7 @@
" \"model\": \"command-r\",\n",
" \"api_key\": \"your Cohere API Key goes here\",\n",
" \"api_type\": \"cohere\",\n",
" \"client_name\": \"autogen-cohere\",\n",
" \"temperature\": 0.5,\n",
" \"p\": 0.2,\n",
" \"k\": 100,\n",
Expand Down Expand Up @@ -526,7 +528,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.12.5"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 5861bd9

Please sign in to comment.