From 23dbb6a632232d867b8482e18d25dfc05f337480 Mon Sep 17 00:00:00 2001 From: Leon De Andrade Date: Sun, 29 Dec 2024 07:50:54 +0100 Subject: [PATCH] Add missing model context attribute (#4848) * Add missing model context attribute * fix type * Add test * imports --------- Co-authored-by: Eric Zhu --- .../agents/_assistant_agent.py | 4 +- .../tests/test_assistant_agent.py | 44 +++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index 9be3adcdc994..3b92a4a51b03 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -296,7 +296,9 @@ def __init__( raise ValueError( f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}" ) - if not model_context: + if model_context is not None: + self._model_context = model_context + else: self._model_context = UnboundedChatCompletionContext() self._reflect_on_tool_use = reflect_on_tool_use self._tool_call_summary_format = tool_call_summary_format diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index 9065d5139180..48f51c4712ed 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -17,6 +17,8 @@ ToolCallSummaryMessage, ) from autogen_core import Image +from autogen_core.model_context import BufferedChatCompletionContext +from autogen_core.models import LLMMessage from autogen_core.tools import FunctionTool from autogen_ext.models.openai import OpenAIChatCompletionClient from openai.resources.chat.completions import AsyncCompletions @@ -39,10 +41,12 @@ class _MockChatCompletion: def __init__(self, chat_completions: List[ChatCompletion]) -> None: self._saved_chat_completions = chat_completions self.curr_index = 0 + self.calls: List[List[LLMMessage]] = [] async def mock_create( self, *args: Any, **kwargs: Any ) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]: + self.calls.append(kwargs["messages"]) # Save the call await asyncio.sleep(0.1) completion = self._saved_chat_completions[self.curr_index] self.curr_index += 1 @@ -468,3 +472,43 @@ async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None: else: assert message == result.messages[index] index += 1 + + +@pytest.mark.asyncio +async def test_model_context(monkeypatch: pytest.MonkeyPatch) -> None: + model = "gpt-4o-2024-05-13" + chat_completions = [ + ChatCompletion( + id="id1", + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage(content="Response to message 3", role="assistant"), + ) + ], + created=0, + model=model, + object="chat.completion", + usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15), + ), + ] + mock = _MockChatCompletion(chat_completions) + monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create) + model_context = BufferedChatCompletionContext(buffer_size=2) + agent = AssistantAgent( + "test_agent", + model_client=OpenAIChatCompletionClient(model=model, api_key=""), + model_context=model_context, + ) + + messages = [ + TextMessage(content="Message 1", source="user"), + TextMessage(content="Message 2", source="user"), + TextMessage(content="Message 3", source="user"), + ] + await agent.run(task=messages) + + # Check if the mock client is called with only the last two messages. + assert len(mock.calls) == 1 + assert len(mock.calls[0]) == 3 # 2 message from the context + 1 system message