diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py index cd435bf0228..4cff9f45822 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py @@ -4,6 +4,7 @@ from ._coding_assistant_agent import CodingAssistantAgent from ._society_of_mind_agent import SocietyOfMindAgent from ._tool_use_assistant_agent import ToolUseAssistantAgent +from ._user_proxy_agent import UserProxyAgent __all__ = [ "BaseChatAgent", @@ -13,4 +14,5 @@ "CodingAssistantAgent", "ToolUseAssistantAgent", "SocietyOfMindAgent", + "UserProxyAgent", ] diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_user_proxy_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_user_proxy_agent.py new file mode 100644 index 00000000000..bdaca53ddc6 --- /dev/null +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_user_proxy_agent.py @@ -0,0 +1,89 @@ +import asyncio +from inspect import iscoroutinefunction +from typing import Awaitable, Callable, List, Optional, Sequence, Union, cast + +from autogen_core.base import CancellationToken + +from ..base import Response +from ..messages import ChatMessage, HandoffMessage, TextMessage +from ._base_chat_agent import BaseChatAgent + +# Define input function types more precisely +SyncInputFunc = Callable[[str], str] +AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]] +InputFuncType = Union[SyncInputFunc, AsyncInputFunc] + + +class UserProxyAgent(BaseChatAgent): + """An agent that can represent a human user in a chat.""" + + def __init__( + self, + name: str, + description: str = "a human user", + input_func: Optional[InputFuncType] = None, + ) -> None: + """Initialize the UserProxyAgent.""" + super().__init__(name=name, description=description) + self.input_func = input_func or input + self._is_async = iscoroutinefunction(self.input_func) + + @property + def produced_message_types(self) -> List[type[ChatMessage]]: + """Message types this agent can produce.""" + return [TextMessage, HandoffMessage] + + def _get_latest_handoff(self, messages: Sequence[ChatMessage]) -> Optional[HandoffMessage]: + """Find the most recent HandoffMessage in the message sequence.""" + for message in reversed(messages): + if isinstance(message, HandoffMessage): + return message + return None + + async def _get_input(self, prompt: str, cancellation_token: Optional[CancellationToken]) -> str: + """Handle input based on function signature.""" + try: + if self._is_async: + # Cast to AsyncInputFunc for proper typing + async_func = cast(AsyncInputFunc, self.input_func) + return await async_func(prompt, cancellation_token) + else: + # Cast to SyncInputFunc for proper typing + sync_func = cast(SyncInputFunc, self.input_func) + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, sync_func, prompt) + + except asyncio.CancelledError: + raise + except Exception as e: + raise RuntimeError(f"Failed to get user input: {str(e)}") from e + + async def on_messages( + self, messages: Sequence[ChatMessage], cancellation_token: Optional[CancellationToken] = None + ) -> Response: + """Handle incoming messages by requesting user input.""" + try: + # Check for handoff first + handoff = self._get_latest_handoff(messages) + prompt = ( + f"Handoff received from {handoff.source}. Enter your response: " if handoff else "Enter your response: " + ) + + user_input = await self._get_input(prompt, cancellation_token) + + # Return appropriate message type based on handoff presence + if handoff: + return Response( + chat_message=HandoffMessage(content=user_input, target=handoff.source, source=self.name) + ) + else: + return Response(chat_message=TextMessage(content=user_input, source=self.name)) + + except asyncio.CancelledError: + raise + except Exception as e: + raise RuntimeError(f"Failed to get user input: {str(e)}") from e + + async def on_reset(self, cancellation_token: Optional[CancellationToken] = None) -> None: + """Reset agent state.""" + pass diff --git a/python/packages/autogen-agentchat/tests/test_userproxy_agent.py b/python/packages/autogen-agentchat/tests/test_userproxy_agent.py new file mode 100644 index 00000000000..2ef3053f09b --- /dev/null +++ b/python/packages/autogen-agentchat/tests/test_userproxy_agent.py @@ -0,0 +1,103 @@ +import asyncio +from typing import Optional, Sequence + +import pytest +from autogen_agentchat.agents import UserProxyAgent +from autogen_agentchat.base import Response +from autogen_agentchat.messages import ChatMessage, HandoffMessage, TextMessage +from autogen_core.base import CancellationToken + + +@pytest.mark.asyncio +async def test_basic_input() -> None: + """Test basic message handling with custom input""" + + def custom_input(prompt: str) -> str: + return "The height of the eiffel tower is 324 meters. Aloha!" + + agent = UserProxyAgent(name="test_user", input_func=custom_input) + messages = [TextMessage(content="What is the height of the eiffel tower?", source="assistant")] + + response = await agent.on_messages(messages, CancellationToken()) + + assert isinstance(response, Response) + assert isinstance(response.chat_message, TextMessage) + assert response.chat_message.content == "The height of the eiffel tower is 324 meters. Aloha!" + assert response.chat_message.source == "test_user" + + +@pytest.mark.asyncio +async def test_async_input() -> None: + """Test handling of async input function""" + + async def async_input(prompt: str, token: Optional[CancellationToken] = None) -> str: + await asyncio.sleep(0.1) + return "async response" + + agent = UserProxyAgent(name="test_user", input_func=async_input) + messages = [TextMessage(content="test prompt", source="assistant")] + + response = await agent.on_messages(messages, CancellationToken()) + + assert isinstance(response.chat_message, TextMessage) + assert response.chat_message.content == "async response" + assert response.chat_message.source == "test_user" + + +@pytest.mark.asyncio +async def test_handoff_handling() -> None: + """Test handling of handoff messages""" + + def custom_input(prompt: str) -> str: + return "handoff response" + + agent = UserProxyAgent(name="test_user", input_func=custom_input) + + messages: Sequence[ChatMessage] = [ + TextMessage(content="Initial message", source="assistant"), + HandoffMessage(content="Handing off to user for confirmation", source="assistant", target="test_user"), + ] + + response = await agent.on_messages(messages, CancellationToken()) + + assert isinstance(response.chat_message, HandoffMessage) + assert response.chat_message.content == "handoff response" + assert response.chat_message.source == "test_user" + assert response.chat_message.target == "assistant" + + +@pytest.mark.asyncio +async def test_cancellation() -> None: + """Test cancellation during message handling""" + + async def cancellable_input(prompt: str, token: Optional[CancellationToken] = None) -> str: + await asyncio.sleep(0.1) + if token and token.is_cancelled(): + raise asyncio.CancelledError() + return "cancellable response" + + agent = UserProxyAgent(name="test_user", input_func=cancellable_input) + messages = [TextMessage(content="test prompt", source="assistant")] + token = CancellationToken() + + async def cancel_after_delay() -> None: + await asyncio.sleep(0.05) + token.cancel() + + with pytest.raises(asyncio.CancelledError): + await asyncio.gather(agent.on_messages(messages, token), cancel_after_delay()) + + +@pytest.mark.asyncio +async def test_error_handling() -> None: + """Test error handling with problematic input function""" + + def failing_input(_: str) -> str: + raise ValueError("Input function failed") + + agent = UserProxyAgent(name="test_user", input_func=failing_input) + messages = [TextMessage(content="test prompt", source="assistant")] + + with pytest.raises(RuntimeError) as exc_info: + await agent.on_messages(messages, CancellationToken()) + assert "Failed to get user input" in str(exc_info.value)