-
Notifications
You must be signed in to change notification settings - Fork 5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add UserProxyAgent in AgentChat API (#4255)
* initial addition of a user proxy agent in agentchat, related to #3614 * fix typing/mypy errors * format fixes * format and pyright checks * update, add support for returning handoff message, add tests --------- Co-authored-by: Ryan Sweet <[email protected]> Co-authored-by: Hussein Mozannar <[email protected]>
- Loading branch information
1 parent
c9835f3
commit 0ff1687
Showing
3 changed files
with
194 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
89 changes: 89 additions & 0 deletions
89
python/packages/autogen-agentchat/src/autogen_agentchat/agents/_user_proxy_agent.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import asyncio | ||
from inspect import iscoroutinefunction | ||
from typing import Awaitable, Callable, List, Optional, Sequence, Union, cast | ||
|
||
from autogen_core.base import CancellationToken | ||
|
||
from ..base import Response | ||
from ..messages import ChatMessage, HandoffMessage, TextMessage | ||
from ._base_chat_agent import BaseChatAgent | ||
|
||
# Define input function types more precisely | ||
SyncInputFunc = Callable[[str], str] | ||
AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]] | ||
InputFuncType = Union[SyncInputFunc, AsyncInputFunc] | ||
|
||
|
||
class UserProxyAgent(BaseChatAgent): | ||
"""An agent that can represent a human user in a chat.""" | ||
|
||
def __init__( | ||
self, | ||
name: str, | ||
description: str = "a human user", | ||
input_func: Optional[InputFuncType] = None, | ||
) -> None: | ||
"""Initialize the UserProxyAgent.""" | ||
super().__init__(name=name, description=description) | ||
self.input_func = input_func or input | ||
self._is_async = iscoroutinefunction(self.input_func) | ||
|
||
@property | ||
def produced_message_types(self) -> List[type[ChatMessage]]: | ||
"""Message types this agent can produce.""" | ||
return [TextMessage, HandoffMessage] | ||
|
||
def _get_latest_handoff(self, messages: Sequence[ChatMessage]) -> Optional[HandoffMessage]: | ||
"""Find the most recent HandoffMessage in the message sequence.""" | ||
for message in reversed(messages): | ||
if isinstance(message, HandoffMessage): | ||
return message | ||
return None | ||
|
||
async def _get_input(self, prompt: str, cancellation_token: Optional[CancellationToken]) -> str: | ||
"""Handle input based on function signature.""" | ||
try: | ||
if self._is_async: | ||
# Cast to AsyncInputFunc for proper typing | ||
async_func = cast(AsyncInputFunc, self.input_func) | ||
return await async_func(prompt, cancellation_token) | ||
else: | ||
# Cast to SyncInputFunc for proper typing | ||
sync_func = cast(SyncInputFunc, self.input_func) | ||
loop = asyncio.get_event_loop() | ||
return await loop.run_in_executor(None, sync_func, prompt) | ||
|
||
except asyncio.CancelledError: | ||
raise | ||
except Exception as e: | ||
raise RuntimeError(f"Failed to get user input: {str(e)}") from e | ||
|
||
async def on_messages( | ||
self, messages: Sequence[ChatMessage], cancellation_token: Optional[CancellationToken] = None | ||
) -> Response: | ||
"""Handle incoming messages by requesting user input.""" | ||
try: | ||
# Check for handoff first | ||
handoff = self._get_latest_handoff(messages) | ||
prompt = ( | ||
f"Handoff received from {handoff.source}. Enter your response: " if handoff else "Enter your response: " | ||
) | ||
|
||
user_input = await self._get_input(prompt, cancellation_token) | ||
|
||
# Return appropriate message type based on handoff presence | ||
if handoff: | ||
return Response( | ||
chat_message=HandoffMessage(content=user_input, target=handoff.source, source=self.name) | ||
) | ||
else: | ||
return Response(chat_message=TextMessage(content=user_input, source=self.name)) | ||
|
||
except asyncio.CancelledError: | ||
raise | ||
except Exception as e: | ||
raise RuntimeError(f"Failed to get user input: {str(e)}") from e | ||
|
||
async def on_reset(self, cancellation_token: Optional[CancellationToken] = None) -> None: | ||
"""Reset agent state.""" | ||
pass |
103 changes: 103 additions & 0 deletions
103
python/packages/autogen-agentchat/tests/test_userproxy_agent.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import asyncio | ||
from typing import Optional, Sequence | ||
|
||
import pytest | ||
from autogen_agentchat.agents import UserProxyAgent | ||
from autogen_agentchat.base import Response | ||
from autogen_agentchat.messages import ChatMessage, HandoffMessage, TextMessage | ||
from autogen_core.base import CancellationToken | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_basic_input() -> None: | ||
"""Test basic message handling with custom input""" | ||
|
||
def custom_input(prompt: str) -> str: | ||
return "The height of the eiffel tower is 324 meters. Aloha!" | ||
|
||
agent = UserProxyAgent(name="test_user", input_func=custom_input) | ||
messages = [TextMessage(content="What is the height of the eiffel tower?", source="assistant")] | ||
|
||
response = await agent.on_messages(messages, CancellationToken()) | ||
|
||
assert isinstance(response, Response) | ||
assert isinstance(response.chat_message, TextMessage) | ||
assert response.chat_message.content == "The height of the eiffel tower is 324 meters. Aloha!" | ||
assert response.chat_message.source == "test_user" | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_async_input() -> None: | ||
"""Test handling of async input function""" | ||
|
||
async def async_input(prompt: str, token: Optional[CancellationToken] = None) -> str: | ||
await asyncio.sleep(0.1) | ||
return "async response" | ||
|
||
agent = UserProxyAgent(name="test_user", input_func=async_input) | ||
messages = [TextMessage(content="test prompt", source="assistant")] | ||
|
||
response = await agent.on_messages(messages, CancellationToken()) | ||
|
||
assert isinstance(response.chat_message, TextMessage) | ||
assert response.chat_message.content == "async response" | ||
assert response.chat_message.source == "test_user" | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_handoff_handling() -> None: | ||
"""Test handling of handoff messages""" | ||
|
||
def custom_input(prompt: str) -> str: | ||
return "handoff response" | ||
|
||
agent = UserProxyAgent(name="test_user", input_func=custom_input) | ||
|
||
messages: Sequence[ChatMessage] = [ | ||
TextMessage(content="Initial message", source="assistant"), | ||
HandoffMessage(content="Handing off to user for confirmation", source="assistant", target="test_user"), | ||
] | ||
|
||
response = await agent.on_messages(messages, CancellationToken()) | ||
|
||
assert isinstance(response.chat_message, HandoffMessage) | ||
assert response.chat_message.content == "handoff response" | ||
assert response.chat_message.source == "test_user" | ||
assert response.chat_message.target == "assistant" | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_cancellation() -> None: | ||
"""Test cancellation during message handling""" | ||
|
||
async def cancellable_input(prompt: str, token: Optional[CancellationToken] = None) -> str: | ||
await asyncio.sleep(0.1) | ||
if token and token.is_cancelled(): | ||
raise asyncio.CancelledError() | ||
return "cancellable response" | ||
|
||
agent = UserProxyAgent(name="test_user", input_func=cancellable_input) | ||
messages = [TextMessage(content="test prompt", source="assistant")] | ||
token = CancellationToken() | ||
|
||
async def cancel_after_delay() -> None: | ||
await asyncio.sleep(0.05) | ||
token.cancel() | ||
|
||
with pytest.raises(asyncio.CancelledError): | ||
await asyncio.gather(agent.on_messages(messages, token), cancel_after_delay()) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_error_handling() -> None: | ||
"""Test error handling with problematic input function""" | ||
|
||
def failing_input(_: str) -> str: | ||
raise ValueError("Input function failed") | ||
|
||
agent = UserProxyAgent(name="test_user", input_func=failing_input) | ||
messages = [TextMessage(content="test prompt", source="assistant")] | ||
|
||
with pytest.raises(RuntimeError) as exc_info: | ||
await agent.on_messages(messages, CancellationToken()) | ||
assert "Failed to get user input" in str(exc_info.value) |