Skip to content

Commit

Permalink
Fix OpenAIFunctionsAgent function call message content retrieving (la…
Browse files Browse the repository at this point in the history
…ngchain-ai#10488)

`langchain.agents.openai_functions[_multi]_agent._parse_ai_message()`
incorrectly extracts AI message content, thus LLM response ("thoughts")
is lost and can't be logged or processed by callbacks.

This PR fixes function call message content retrieving.
  • Loading branch information
skozlovf authored Sep 13, 2023
1 parent 2dc3c64 commit 0a0276b
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]:
else:
tool_input = _tool_input

content_msg = "responded: {content}\n" if message.content else "\n"
content_msg = f"responded: {message.content}\n" if message.content else "\n"

return _FunctionsAgentAction(
tool=function_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def _parse_ai_message(message: BaseMessage) -> Union[List[AgentAction], AgentFin
else:
tool_input = _tool_input

content_msg = "responded: {content}\n" if message.content else "\n"
content_msg = f"responded: {message.content}\n" if message.content else "\n"
log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n"
_tool = _FunctionsAgentAction(
tool=function_name,
Expand Down
76 changes: 76 additions & 0 deletions libs/langchain/tests/unit_tests/agents/test_openai_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import pytest

from langchain.agents.openai_functions_agent.base import (
_FunctionsAgentAction,
_parse_ai_message,
)
from langchain.schema import AgentFinish, OutputParserException
from langchain.schema.messages import AIMessage, SystemMessage


# Test: _parse_ai_message() function.
class TestParseAIMessage:
# Test: Pass Non-AIMessage.
def test_not_an_ai(self) -> None:
err = f"Expected an AI message got {str(SystemMessage)}"
with pytest.raises(TypeError, match=err):
_parse_ai_message(SystemMessage(content="x"))

# Test: Model response (not a function call).
def test_model_response(self) -> None:
msg = AIMessage(content="Model response.")
result = _parse_ai_message(msg)

assert isinstance(result, AgentFinish)
assert result.return_values == {"output": "Model response."}
assert result.log == "Model response."

# Test: Model response with a function call.
def test_func_call(self) -> None:
msg = AIMessage(
content="LLM thoughts.",
additional_kwargs={
"function_call": {"name": "foo", "arguments": '{"param": 42}'}
},
)
result = _parse_ai_message(msg)

assert isinstance(result, _FunctionsAgentAction)
assert result.tool == "foo"
assert result.tool_input == {"param": 42}
assert result.log == (
"\nInvoking: `foo` with `{'param': 42}`\nresponded: LLM thoughts.\n\n"
)
assert result.message_log == [msg]

# Test: Model response with a function call (old style tools).
def test_func_call_oldstyle(self) -> None:
msg = AIMessage(
content="LLM thoughts.",
additional_kwargs={
"function_call": {"name": "foo", "arguments": '{"__arg1": "42"}'}
},
)
result = _parse_ai_message(msg)

assert isinstance(result, _FunctionsAgentAction)
assert result.tool == "foo"
assert result.tool_input == "42"
assert result.log == (
"\nInvoking: `foo` with `42`\nresponded: LLM thoughts.\n\n"
)
assert result.message_log == [msg]

# Test: Invalid function call args.
def test_func_call_invalid(self) -> None:
msg = AIMessage(
content="LLM thoughts.",
additional_kwargs={"function_call": {"name": "foo", "arguments": "{42]"}},
)

err = (
"Could not parse tool input: {'name': 'foo', 'arguments': '{42]'} "
"because the `arguments` is not valid JSON."
)
with pytest.raises(OutputParserException, match=err):
_parse_ai_message(msg)
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import json

import pytest

from langchain.agents.openai_functions_multi_agent.base import (
_FunctionsAgentAction,
_parse_ai_message,
)
from langchain.schema import AgentFinish, OutputParserException
from langchain.schema.messages import AIMessage, SystemMessage


# Test: _parse_ai_message() function.
class TestParseAIMessage:
# Test: Pass Non-AIMessage.
def test_not_an_ai(self) -> None:
err = f"Expected an AI message got {str(SystemMessage)}"
with pytest.raises(TypeError, match=err):
_parse_ai_message(SystemMessage(content="x"))

# Test: Model response (not a function call).
def test_model_response(self) -> None:
msg = AIMessage(content="Model response.")
result = _parse_ai_message(msg)

assert isinstance(result, AgentFinish)
assert result.return_values == {"output": "Model response."}
assert result.log == "Model response."

# Test: Model response with a function call.
def test_func_call(self) -> None:
act = json.dumps([{"action_name": "foo", "action": {"param": 42}}])

msg = AIMessage(
content="LLM thoughts.",
additional_kwargs={
"function_call": {"name": "foo", "arguments": f'{{"actions": {act}}}'}
},
)
result = _parse_ai_message(msg)

assert isinstance(result, list)
assert len(result) == 1

action = result[0]
assert isinstance(action, _FunctionsAgentAction)
assert action.tool == "foo"
assert action.tool_input == {"param": 42}
assert action.log == (
"\nInvoking: `foo` with `{'param': 42}`\nresponded: LLM thoughts.\n\n"
)
assert action.message_log == [msg]

# Test: Model response with a function call (old style tools).
def test_func_call_oldstyle(self) -> None:
act = json.dumps([{"action_name": "foo", "action": {"__arg1": "42"}}])

msg = AIMessage(
content="LLM thoughts.",
additional_kwargs={
"function_call": {"name": "foo", "arguments": f'{{"actions": {act}}}'}
},
)
result = _parse_ai_message(msg)

assert isinstance(result, list)
assert len(result) == 1

action = result[0]
assert isinstance(action, _FunctionsAgentAction)
assert action.tool == "foo"
assert action.tool_input == "42"
assert action.log == (
"\nInvoking: `foo` with `42`\nresponded: LLM thoughts.\n\n"
)
assert action.message_log == [msg]

# Test: Invalid function call args.
def test_func_call_invalid(self) -> None:
msg = AIMessage(
content="LLM thoughts.",
additional_kwargs={"function_call": {"name": "foo", "arguments": "{42]"}},
)

err = (
"Could not parse tool input: {'name': 'foo', 'arguments': '{42]'} "
"because the `arguments` is not valid JSON."
)
with pytest.raises(OutputParserException, match=err):
_parse_ai_message(msg)

0 comments on commit 0a0276b

Please sign in to comment.