From 0a0276bcdb4b195baf64d515ed68031529881e22 Mon Sep 17 00:00:00 2001 From: Sergey Kozlov Date: Thu, 14 Sep 2023 05:19:25 +0600 Subject: [PATCH] Fix OpenAIFunctionsAgent function call message content retrieving (#10488) `langchain.agents.openai_functions[_multi]_agent._parse_ai_message()` incorrectly extracts AI message content, thus LLM response ("thoughts") is lost and can't be logged or processed by callbacks. This PR fixes function call message content retrieving. --- .../agents/openai_functions_agent/base.py | 2 +- .../openai_functions_multi_agent/base.py | 2 +- .../agents/test_openai_functions.py | 76 ++++++++++++++++ .../agents/test_openai_functions_multi.py | 90 +++++++++++++++++++ 4 files changed, 168 insertions(+), 2 deletions(-) create mode 100644 libs/langchain/tests/unit_tests/agents/test_openai_functions.py create mode 100644 libs/langchain/tests/unit_tests/agents/test_openai_functions_multi.py diff --git a/libs/langchain/langchain/agents/openai_functions_agent/base.py b/libs/langchain/langchain/agents/openai_functions_agent/base.py index 19d5ebbc43380..52aa91b7c8ef5 100644 --- a/libs/langchain/langchain/agents/openai_functions_agent/base.py +++ b/libs/langchain/langchain/agents/openai_functions_agent/base.py @@ -127,7 +127,7 @@ def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]: else: tool_input = _tool_input - content_msg = "responded: {content}\n" if message.content else "\n" + content_msg = f"responded: {message.content}\n" if message.content else "\n" return _FunctionsAgentAction( tool=function_name, diff --git a/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py b/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py index fcc51227fdac8..7469f303895a5 100644 --- a/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py +++ b/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py @@ -129,7 +129,7 @@ def _parse_ai_message(message: BaseMessage) -> Union[List[AgentAction], AgentFin else: tool_input = _tool_input - content_msg = "responded: {content}\n" if message.content else "\n" + content_msg = f"responded: {message.content}\n" if message.content else "\n" log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n" _tool = _FunctionsAgentAction( tool=function_name, diff --git a/libs/langchain/tests/unit_tests/agents/test_openai_functions.py b/libs/langchain/tests/unit_tests/agents/test_openai_functions.py new file mode 100644 index 0000000000000..046f8d0a509d4 --- /dev/null +++ b/libs/langchain/tests/unit_tests/agents/test_openai_functions.py @@ -0,0 +1,76 @@ +import pytest + +from langchain.agents.openai_functions_agent.base import ( + _FunctionsAgentAction, + _parse_ai_message, +) +from langchain.schema import AgentFinish, OutputParserException +from langchain.schema.messages import AIMessage, SystemMessage + + +# Test: _parse_ai_message() function. +class TestParseAIMessage: + # Test: Pass Non-AIMessage. + def test_not_an_ai(self) -> None: + err = f"Expected an AI message got {str(SystemMessage)}" + with pytest.raises(TypeError, match=err): + _parse_ai_message(SystemMessage(content="x")) + + # Test: Model response (not a function call). + def test_model_response(self) -> None: + msg = AIMessage(content="Model response.") + result = _parse_ai_message(msg) + + assert isinstance(result, AgentFinish) + assert result.return_values == {"output": "Model response."} + assert result.log == "Model response." + + # Test: Model response with a function call. + def test_func_call(self) -> None: + msg = AIMessage( + content="LLM thoughts.", + additional_kwargs={ + "function_call": {"name": "foo", "arguments": '{"param": 42}'} + }, + ) + result = _parse_ai_message(msg) + + assert isinstance(result, _FunctionsAgentAction) + assert result.tool == "foo" + assert result.tool_input == {"param": 42} + assert result.log == ( + "\nInvoking: `foo` with `{'param': 42}`\nresponded: LLM thoughts.\n\n" + ) + assert result.message_log == [msg] + + # Test: Model response with a function call (old style tools). + def test_func_call_oldstyle(self) -> None: + msg = AIMessage( + content="LLM thoughts.", + additional_kwargs={ + "function_call": {"name": "foo", "arguments": '{"__arg1": "42"}'} + }, + ) + result = _parse_ai_message(msg) + + assert isinstance(result, _FunctionsAgentAction) + assert result.tool == "foo" + assert result.tool_input == "42" + assert result.log == ( + "\nInvoking: `foo` with `42`\nresponded: LLM thoughts.\n\n" + ) + assert result.message_log == [msg] + + # Test: Invalid function call args. + def test_func_call_invalid(self) -> None: + msg = AIMessage( + content="LLM thoughts.", + additional_kwargs={"function_call": {"name": "foo", "arguments": "{42]"}}, + ) + + err = ( + "Could not parse tool input: {'name': 'foo', 'arguments': '{42]'} " + "because the `arguments` is not valid JSON." + ) + with pytest.raises(OutputParserException, match=err): + _parse_ai_message(msg) diff --git a/libs/langchain/tests/unit_tests/agents/test_openai_functions_multi.py b/libs/langchain/tests/unit_tests/agents/test_openai_functions_multi.py new file mode 100644 index 0000000000000..a76f790a626a0 --- /dev/null +++ b/libs/langchain/tests/unit_tests/agents/test_openai_functions_multi.py @@ -0,0 +1,90 @@ +import json + +import pytest + +from langchain.agents.openai_functions_multi_agent.base import ( + _FunctionsAgentAction, + _parse_ai_message, +) +from langchain.schema import AgentFinish, OutputParserException +from langchain.schema.messages import AIMessage, SystemMessage + + +# Test: _parse_ai_message() function. +class TestParseAIMessage: + # Test: Pass Non-AIMessage. + def test_not_an_ai(self) -> None: + err = f"Expected an AI message got {str(SystemMessage)}" + with pytest.raises(TypeError, match=err): + _parse_ai_message(SystemMessage(content="x")) + + # Test: Model response (not a function call). + def test_model_response(self) -> None: + msg = AIMessage(content="Model response.") + result = _parse_ai_message(msg) + + assert isinstance(result, AgentFinish) + assert result.return_values == {"output": "Model response."} + assert result.log == "Model response." + + # Test: Model response with a function call. + def test_func_call(self) -> None: + act = json.dumps([{"action_name": "foo", "action": {"param": 42}}]) + + msg = AIMessage( + content="LLM thoughts.", + additional_kwargs={ + "function_call": {"name": "foo", "arguments": f'{{"actions": {act}}}'} + }, + ) + result = _parse_ai_message(msg) + + assert isinstance(result, list) + assert len(result) == 1 + + action = result[0] + assert isinstance(action, _FunctionsAgentAction) + assert action.tool == "foo" + assert action.tool_input == {"param": 42} + assert action.log == ( + "\nInvoking: `foo` with `{'param': 42}`\nresponded: LLM thoughts.\n\n" + ) + assert action.message_log == [msg] + + # Test: Model response with a function call (old style tools). + def test_func_call_oldstyle(self) -> None: + act = json.dumps([{"action_name": "foo", "action": {"__arg1": "42"}}]) + + msg = AIMessage( + content="LLM thoughts.", + additional_kwargs={ + "function_call": {"name": "foo", "arguments": f'{{"actions": {act}}}'} + }, + ) + result = _parse_ai_message(msg) + + assert isinstance(result, list) + assert len(result) == 1 + + action = result[0] + assert isinstance(action, _FunctionsAgentAction) + assert action.tool == "foo" + assert action.tool_input == "42" + assert action.log == ( + "\nInvoking: `foo` with `42`\nresponded: LLM thoughts.\n\n" + ) + assert action.message_log == [msg] + + # Test: Invalid function call args. + def test_func_call_invalid(self) -> None: + msg = AIMessage( + content="LLM thoughts.", + additional_kwargs={"function_call": {"name": "foo", "arguments": "{42]"}}, + ) + + err = ( + "Could not parse tool input: {'name': 'foo', 'arguments': '{42]'} " + "because the `arguments` is not valid JSON." + ) + with pytest.raises(OutputParserException, match=err): + _parse_ai_message(msg)