Skip to content

Commit

Permalink
Add UserProxyAgent in AgentChat API (#4255)
Browse files Browse the repository at this point in the history
* initial addition of a user proxy agent in agentchat, related to #3614

* fix typing/mypy errors

* format fixes

* format and pyright checks

* update, add support for returning handoff message, add tests

---------

Co-authored-by: Ryan Sweet <[email protected]>
Co-authored-by: Hussein Mozannar <[email protected]>
  • Loading branch information
3 people authored Nov 24, 2024
1 parent c9835f3 commit 0ff1687
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ._coding_assistant_agent import CodingAssistantAgent
from ._society_of_mind_agent import SocietyOfMindAgent
from ._tool_use_assistant_agent import ToolUseAssistantAgent
from ._user_proxy_agent import UserProxyAgent

__all__ = [
"BaseChatAgent",
Expand All @@ -13,4 +14,5 @@
"CodingAssistantAgent",
"ToolUseAssistantAgent",
"SocietyOfMindAgent",
"UserProxyAgent",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import asyncio
from inspect import iscoroutinefunction
from typing import Awaitable, Callable, List, Optional, Sequence, Union, cast

from autogen_core.base import CancellationToken

from ..base import Response
from ..messages import ChatMessage, HandoffMessage, TextMessage
from ._base_chat_agent import BaseChatAgent

# Define input function types more precisely
SyncInputFunc = Callable[[str], str]
AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]]
InputFuncType = Union[SyncInputFunc, AsyncInputFunc]


class UserProxyAgent(BaseChatAgent):
"""An agent that can represent a human user in a chat."""

def __init__(
self,
name: str,
description: str = "a human user",
input_func: Optional[InputFuncType] = None,
) -> None:
"""Initialize the UserProxyAgent."""
super().__init__(name=name, description=description)
self.input_func = input_func or input
self._is_async = iscoroutinefunction(self.input_func)

@property
def produced_message_types(self) -> List[type[ChatMessage]]:
"""Message types this agent can produce."""
return [TextMessage, HandoffMessage]

def _get_latest_handoff(self, messages: Sequence[ChatMessage]) -> Optional[HandoffMessage]:
"""Find the most recent HandoffMessage in the message sequence."""
for message in reversed(messages):
if isinstance(message, HandoffMessage):
return message
return None

async def _get_input(self, prompt: str, cancellation_token: Optional[CancellationToken]) -> str:
"""Handle input based on function signature."""
try:
if self._is_async:
# Cast to AsyncInputFunc for proper typing
async_func = cast(AsyncInputFunc, self.input_func)
return await async_func(prompt, cancellation_token)
else:
# Cast to SyncInputFunc for proper typing
sync_func = cast(SyncInputFunc, self.input_func)
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, sync_func, prompt)

except asyncio.CancelledError:
raise
except Exception as e:
raise RuntimeError(f"Failed to get user input: {str(e)}") from e

async def on_messages(
self, messages: Sequence[ChatMessage], cancellation_token: Optional[CancellationToken] = None
) -> Response:
"""Handle incoming messages by requesting user input."""
try:
# Check for handoff first
handoff = self._get_latest_handoff(messages)
prompt = (
f"Handoff received from {handoff.source}. Enter your response: " if handoff else "Enter your response: "
)

user_input = await self._get_input(prompt, cancellation_token)

# Return appropriate message type based on handoff presence
if handoff:
return Response(
chat_message=HandoffMessage(content=user_input, target=handoff.source, source=self.name)
)
else:
return Response(chat_message=TextMessage(content=user_input, source=self.name))

except asyncio.CancelledError:
raise
except Exception as e:
raise RuntimeError(f"Failed to get user input: {str(e)}") from e

async def on_reset(self, cancellation_token: Optional[CancellationToken] = None) -> None:
"""Reset agent state."""
pass
103 changes: 103 additions & 0 deletions python/packages/autogen-agentchat/tests/test_userproxy_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import asyncio
from typing import Optional, Sequence

import pytest
from autogen_agentchat.agents import UserProxyAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import ChatMessage, HandoffMessage, TextMessage
from autogen_core.base import CancellationToken


@pytest.mark.asyncio
async def test_basic_input() -> None:
"""Test basic message handling with custom input"""

def custom_input(prompt: str) -> str:
return "The height of the eiffel tower is 324 meters. Aloha!"

agent = UserProxyAgent(name="test_user", input_func=custom_input)
messages = [TextMessage(content="What is the height of the eiffel tower?", source="assistant")]

response = await agent.on_messages(messages, CancellationToken())

assert isinstance(response, Response)
assert isinstance(response.chat_message, TextMessage)
assert response.chat_message.content == "The height of the eiffel tower is 324 meters. Aloha!"
assert response.chat_message.source == "test_user"


@pytest.mark.asyncio
async def test_async_input() -> None:
"""Test handling of async input function"""

async def async_input(prompt: str, token: Optional[CancellationToken] = None) -> str:
await asyncio.sleep(0.1)
return "async response"

agent = UserProxyAgent(name="test_user", input_func=async_input)
messages = [TextMessage(content="test prompt", source="assistant")]

response = await agent.on_messages(messages, CancellationToken())

assert isinstance(response.chat_message, TextMessage)
assert response.chat_message.content == "async response"
assert response.chat_message.source == "test_user"


@pytest.mark.asyncio
async def test_handoff_handling() -> None:
"""Test handling of handoff messages"""

def custom_input(prompt: str) -> str:
return "handoff response"

agent = UserProxyAgent(name="test_user", input_func=custom_input)

messages: Sequence[ChatMessage] = [
TextMessage(content="Initial message", source="assistant"),
HandoffMessage(content="Handing off to user for confirmation", source="assistant", target="test_user"),
]

response = await agent.on_messages(messages, CancellationToken())

assert isinstance(response.chat_message, HandoffMessage)
assert response.chat_message.content == "handoff response"
assert response.chat_message.source == "test_user"
assert response.chat_message.target == "assistant"


@pytest.mark.asyncio
async def test_cancellation() -> None:
"""Test cancellation during message handling"""

async def cancellable_input(prompt: str, token: Optional[CancellationToken] = None) -> str:
await asyncio.sleep(0.1)
if token and token.is_cancelled():
raise asyncio.CancelledError()
return "cancellable response"

agent = UserProxyAgent(name="test_user", input_func=cancellable_input)
messages = [TextMessage(content="test prompt", source="assistant")]
token = CancellationToken()

async def cancel_after_delay() -> None:
await asyncio.sleep(0.05)
token.cancel()

with pytest.raises(asyncio.CancelledError):
await asyncio.gather(agent.on_messages(messages, token), cancel_after_delay())


@pytest.mark.asyncio
async def test_error_handling() -> None:
"""Test error handling with problematic input function"""

def failing_input(_: str) -> str:
raise ValueError("Input function failed")

agent = UserProxyAgent(name="test_user", input_func=failing_input)
messages = [TextMessage(content="test prompt", source="assistant")]

with pytest.raises(RuntimeError) as exc_info:
await agent.on_messages(messages, CancellationToken())
assert "Failed to get user input" in str(exc_info.value)

0 comments on commit 0ff1687

Please sign in to comment.