From a271708a97bf754a5d94860554d78677f85a78fe Mon Sep 17 00:00:00 2001 From: jspv Date: Fri, 20 Dec 2024 00:23:18 -0500 Subject: [PATCH] Tool call result summary message (#4755) * Adding ToolCallResultSummaryMessage * Support for ToolCallResultSummaryMessage * Added ToolCallSummaryMessage * ruff format * Add ToolCallSummaryMessage to ChatMessage * typing and tests for ToolCallSummaryMessage * PR Feedback --------- Co-authored-by: Eric Zhu Co-authored-by: Hussein Mozannar --- .../agents/_assistant_agent.py | 12 ++++++---- .../src/autogen_agentchat/messages.py | 22 +++++++++++++++++-- .../_magentic_one_orchestrator.py | 3 ++- .../teams/_group_chat/_selector_group_chat.py | 3 ++- .../tests/test_assistant_agent.py | 3 ++- .../tests/test_group_chat.py | 4 +++- 6 files changed, 37 insertions(+), 10 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index 7afa7f48a51..d421d7ccb2c 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -28,6 +28,7 @@ TextMessage, ToolCallExecutionEvent, ToolCallRequestEvent, + ToolCallSummaryMessage, ) from ..state import AssistantAgentState from ._base_chat_agent import BaseChatAgent @@ -62,7 +63,7 @@ class AssistantAgent(BaseChatAgent): * If the model returns no tool call, then the response is immediately returned as a :class:`~autogen_agentchat.messages.TextMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`. * When the model returns tool calls, they will be executed right away: - - When `reflect_on_tool_use` is False (default), the tool call results are returned as a :class:`~autogen_agentchat.messages.TextMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`. `tool_call_summary_format` can be used to customize the tool call summary. + - When `reflect_on_tool_use` is False (default), the tool call results are returned as a :class:`~autogen_agentchat.messages.ToolCallSummaryMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`. `tool_call_summary_format` can be used to customize the tool call summary. - When `reflect_on_tool_use` is True, the another model inference is made using the tool calls and results, and the text response is returned as a :class:`~autogen_agentchat.messages.TextMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`. Hand off behavior: @@ -280,9 +281,12 @@ def __init__( @property def produced_message_types(self) -> List[type[ChatMessage]]: """The types of messages that the assistant agent produces.""" + message_types: List[type[ChatMessage]] = [TextMessage] if self._handoffs: - return [TextMessage, HandoffMessage] - return [TextMessage] + message_types.append(HandoffMessage) + if self._tools: + message_types.append(ToolCallSummaryMessage) + return message_types async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response: async for message in self.on_messages_stream(messages, cancellation_token): @@ -379,7 +383,7 @@ async def on_messages_stream( ) tool_call_summary = "\n".join(tool_call_summaries) yield Response( - chat_message=TextMessage(content=tool_call_summary, source=self.name), + chat_message=ToolCallSummaryMessage(content=tool_call_summary, source=self.name), inner_messages=inner_messages, ) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py index 547b4a7fab9..7237812ca72 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py @@ -101,7 +101,18 @@ class ToolCallExecutionEvent(BaseMessage): type: Literal["ToolCallExecutionEvent"] = "ToolCallExecutionEvent" -ChatMessage = Annotated[TextMessage | MultiModalMessage | StopMessage | HandoffMessage, Field(discriminator="type")] +class ToolCallSummaryMessage(BaseMessage): + """A message signaling the summary of tool call results.""" + + content: str + """Summary of the the tool call results.""" + + type: Literal["ToolCallSummaryMessage"] = "ToolCallSummaryMessage" + + +ChatMessage = Annotated[ + TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type") +] """Messages for agent-to-agent communication only.""" @@ -110,7 +121,13 @@ class ToolCallExecutionEvent(BaseMessage): AgentMessage = Annotated[ - TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ToolCallRequestEvent | ToolCallExecutionEvent, + TextMessage + | MultiModalMessage + | StopMessage + | HandoffMessage + | ToolCallRequestEvent + | ToolCallExecutionEvent + | ToolCallSummaryMessage, Field(discriminator="type"), ] """(Deprecated, will be removed in 0.4.0) All message and event types.""" @@ -126,6 +143,7 @@ class ToolCallExecutionEvent(BaseMessage): "ToolCallExecutionEvent", "ToolCallMessage", "ToolCallResultMessage", + "ToolCallSummaryMessage", "ChatMessage", "AgentEvent", "AgentMessage", diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py index 6e0d12f7f75..d405bab5b13 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py @@ -21,6 +21,7 @@ TextMessage, ToolCallExecutionEvent, ToolCallRequestEvent, + ToolCallSummaryMessage, ) from ....state import MagenticOneOrchestratorState from .._base_group_chat_manager import BaseGroupChatManager @@ -433,7 +434,7 @@ def _thread_to_context(self) -> List[LLMMessage]: elif isinstance(m, StopMessage | HandoffMessage): context.append(UserMessage(content=m.content, source=m.source)) elif m.source == self._name: - assert isinstance(m, TextMessage) + assert isinstance(m, TextMessage | ToolCallSummaryMessage) context.append(AssistantMessage(content=m.content, source=m.source)) else: assert isinstance(m, TextMessage) or isinstance(m, MultiModalMessage) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index 735e533c908..be8ec726c30 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -15,6 +15,7 @@ TextMessage, ToolCallExecutionEvent, ToolCallRequestEvent, + ToolCallSummaryMessage, ) from ...state import SelectorManagerState from ._base_group_chat import BaseGroupChat @@ -100,7 +101,7 @@ async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str: continue # The agent type must be the same as the topic type, which we use as the agent name. message = f"{msg.source}:" - if isinstance(msg, TextMessage | StopMessage | HandoffMessage): + if isinstance(msg, TextMessage | StopMessage | HandoffMessage | ToolCallSummaryMessage): message += f" {msg.content}" elif isinstance(msg, MultiModalMessage): for item in msg.content: diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index 67969cfce8a..9065d513918 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -14,6 +14,7 @@ TextMessage, ToolCallExecutionEvent, ToolCallRequestEvent, + ToolCallSummaryMessage, ) from autogen_core import Image from autogen_core.tools import FunctionTool @@ -142,7 +143,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: assert result.messages[1].models_usage.prompt_tokens == 10 assert isinstance(result.messages[2], ToolCallExecutionEvent) assert result.messages[2].models_usage is None - assert isinstance(result.messages[3], TextMessage) + assert isinstance(result.messages[3], ToolCallSummaryMessage) assert result.messages[3].content == "pass" assert result.messages[3].models_usage is None diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index d3a0f2e56f9..6d2fe29beb8 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -22,6 +22,7 @@ TextMessage, ToolCallExecutionEvent, ToolCallRequestEvent, + ToolCallSummaryMessage, ) from autogen_agentchat.teams import ( RoundRobinGroupChat, @@ -325,7 +326,8 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch assert isinstance(result.messages[0], TextMessage) # task assert isinstance(result.messages[1], ToolCallRequestEvent) # tool call assert isinstance(result.messages[2], ToolCallExecutionEvent) # tool call result - assert isinstance(result.messages[3], TextMessage) # tool use agent response + assert isinstance(result.messages[3], ToolCallSummaryMessage) # tool use agent response + assert result.messages[3].content == "pass" # ensure the tool call was executed assert isinstance(result.messages[4], TextMessage) # echo agent response assert isinstance(result.messages[5], TextMessage) # tool use agent response assert isinstance(result.messages[6], TextMessage) # echo agent response