forked from langchain-ai/langchain
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix OpenAIFunctionsAgent function call message content retrieving (la…
…ngchain-ai#10488) `langchain.agents.openai_functions[_multi]_agent._parse_ai_message()` incorrectly extracts AI message content, thus LLM response ("thoughts") is lost and can't be logged or processed by callbacks. This PR fixes function call message content retrieving.
- Loading branch information
Showing
4 changed files
with
168 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
76 changes: 76 additions & 0 deletions
76
libs/langchain/tests/unit_tests/agents/test_openai_functions.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import pytest | ||
|
||
from langchain.agents.openai_functions_agent.base import ( | ||
_FunctionsAgentAction, | ||
_parse_ai_message, | ||
) | ||
from langchain.schema import AgentFinish, OutputParserException | ||
from langchain.schema.messages import AIMessage, SystemMessage | ||
|
||
|
||
# Test: _parse_ai_message() function. | ||
class TestParseAIMessage: | ||
# Test: Pass Non-AIMessage. | ||
def test_not_an_ai(self) -> None: | ||
err = f"Expected an AI message got {str(SystemMessage)}" | ||
with pytest.raises(TypeError, match=err): | ||
_parse_ai_message(SystemMessage(content="x")) | ||
|
||
# Test: Model response (not a function call). | ||
def test_model_response(self) -> None: | ||
msg = AIMessage(content="Model response.") | ||
result = _parse_ai_message(msg) | ||
|
||
assert isinstance(result, AgentFinish) | ||
assert result.return_values == {"output": "Model response."} | ||
assert result.log == "Model response." | ||
|
||
# Test: Model response with a function call. | ||
def test_func_call(self) -> None: | ||
msg = AIMessage( | ||
content="LLM thoughts.", | ||
additional_kwargs={ | ||
"function_call": {"name": "foo", "arguments": '{"param": 42}'} | ||
}, | ||
) | ||
result = _parse_ai_message(msg) | ||
|
||
assert isinstance(result, _FunctionsAgentAction) | ||
assert result.tool == "foo" | ||
assert result.tool_input == {"param": 42} | ||
assert result.log == ( | ||
"\nInvoking: `foo` with `{'param': 42}`\nresponded: LLM thoughts.\n\n" | ||
) | ||
assert result.message_log == [msg] | ||
|
||
# Test: Model response with a function call (old style tools). | ||
def test_func_call_oldstyle(self) -> None: | ||
msg = AIMessage( | ||
content="LLM thoughts.", | ||
additional_kwargs={ | ||
"function_call": {"name": "foo", "arguments": '{"__arg1": "42"}'} | ||
}, | ||
) | ||
result = _parse_ai_message(msg) | ||
|
||
assert isinstance(result, _FunctionsAgentAction) | ||
assert result.tool == "foo" | ||
assert result.tool_input == "42" | ||
assert result.log == ( | ||
"\nInvoking: `foo` with `42`\nresponded: LLM thoughts.\n\n" | ||
) | ||
assert result.message_log == [msg] | ||
|
||
# Test: Invalid function call args. | ||
def test_func_call_invalid(self) -> None: | ||
msg = AIMessage( | ||
content="LLM thoughts.", | ||
additional_kwargs={"function_call": {"name": "foo", "arguments": "{42]"}}, | ||
) | ||
|
||
err = ( | ||
"Could not parse tool input: {'name': 'foo', 'arguments': '{42]'} " | ||
"because the `arguments` is not valid JSON." | ||
) | ||
with pytest.raises(OutputParserException, match=err): | ||
_parse_ai_message(msg) |
90 changes: 90 additions & 0 deletions
90
libs/langchain/tests/unit_tests/agents/test_openai_functions_multi.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import json | ||
|
||
import pytest | ||
|
||
from langchain.agents.openai_functions_multi_agent.base import ( | ||
_FunctionsAgentAction, | ||
_parse_ai_message, | ||
) | ||
from langchain.schema import AgentFinish, OutputParserException | ||
from langchain.schema.messages import AIMessage, SystemMessage | ||
|
||
|
||
# Test: _parse_ai_message() function. | ||
class TestParseAIMessage: | ||
# Test: Pass Non-AIMessage. | ||
def test_not_an_ai(self) -> None: | ||
err = f"Expected an AI message got {str(SystemMessage)}" | ||
with pytest.raises(TypeError, match=err): | ||
_parse_ai_message(SystemMessage(content="x")) | ||
|
||
# Test: Model response (not a function call). | ||
def test_model_response(self) -> None: | ||
msg = AIMessage(content="Model response.") | ||
result = _parse_ai_message(msg) | ||
|
||
assert isinstance(result, AgentFinish) | ||
assert result.return_values == {"output": "Model response."} | ||
assert result.log == "Model response." | ||
|
||
# Test: Model response with a function call. | ||
def test_func_call(self) -> None: | ||
act = json.dumps([{"action_name": "foo", "action": {"param": 42}}]) | ||
|
||
msg = AIMessage( | ||
content="LLM thoughts.", | ||
additional_kwargs={ | ||
"function_call": {"name": "foo", "arguments": f'{{"actions": {act}}}'} | ||
}, | ||
) | ||
result = _parse_ai_message(msg) | ||
|
||
assert isinstance(result, list) | ||
assert len(result) == 1 | ||
|
||
action = result[0] | ||
assert isinstance(action, _FunctionsAgentAction) | ||
assert action.tool == "foo" | ||
assert action.tool_input == {"param": 42} | ||
assert action.log == ( | ||
"\nInvoking: `foo` with `{'param': 42}`\nresponded: LLM thoughts.\n\n" | ||
) | ||
assert action.message_log == [msg] | ||
|
||
# Test: Model response with a function call (old style tools). | ||
def test_func_call_oldstyle(self) -> None: | ||
act = json.dumps([{"action_name": "foo", "action": {"__arg1": "42"}}]) | ||
|
||
msg = AIMessage( | ||
content="LLM thoughts.", | ||
additional_kwargs={ | ||
"function_call": {"name": "foo", "arguments": f'{{"actions": {act}}}'} | ||
}, | ||
) | ||
result = _parse_ai_message(msg) | ||
|
||
assert isinstance(result, list) | ||
assert len(result) == 1 | ||
|
||
action = result[0] | ||
assert isinstance(action, _FunctionsAgentAction) | ||
assert action.tool == "foo" | ||
assert action.tool_input == "42" | ||
assert action.log == ( | ||
"\nInvoking: `foo` with `42`\nresponded: LLM thoughts.\n\n" | ||
) | ||
assert action.message_log == [msg] | ||
|
||
# Test: Invalid function call args. | ||
def test_func_call_invalid(self) -> None: | ||
msg = AIMessage( | ||
content="LLM thoughts.", | ||
additional_kwargs={"function_call": {"name": "foo", "arguments": "{42]"}}, | ||
) | ||
|
||
err = ( | ||
"Could not parse tool input: {'name': 'foo', 'arguments': '{42]'} " | ||
"because the `arguments` is not valid JSON." | ||
) | ||
with pytest.raises(OutputParserException, match=err): | ||
_parse_ai_message(msg) |