Skip to content

Commit

Permalink
Merge branch 'main' into rysweet-python-xlang
Browse files Browse the repository at this point in the history
  • Loading branch information
rysweet authored Nov 18, 2024
2 parents dedbb39 + f1daff1 commit c591c9e
Show file tree
Hide file tree
Showing 11 changed files with 242 additions and 105 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from autogen_core.base import CancellationToken

from ..base import ChatAgent, Response, TaskResult
from ..messages import AgentMessage, ChatMessage, MultiModalMessage, TextMessage
from ..messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage


class BaseChatAgent(ChatAgent, ABC):
Expand Down Expand Up @@ -54,21 +54,25 @@ async def on_messages_stream(
async def run(
self,
*,
task: str | TextMessage | MultiModalMessage | None = None,
task: str | ChatMessage | None = None,
cancellation_token: CancellationToken | None = None,
) -> TaskResult:
"""Run the agent with the given task and return the result."""
if cancellation_token is None:
cancellation_token = CancellationToken()
input_messages: List[ChatMessage] = []
output_messages: List[AgentMessage] = []
if isinstance(task, str):
if task is None:
pass
elif isinstance(task, str):
text_msg = TextMessage(content=task, source="user")
input_messages.append(text_msg)
output_messages.append(text_msg)
elif isinstance(task, TextMessage | MultiModalMessage):
elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage):
input_messages.append(task)
output_messages.append(task)
else:
raise ValueError(f"Invalid task type: {type(task)}")
response = await self.on_messages(input_messages, cancellation_token)
if response.inner_messages is not None:
output_messages += response.inner_messages
Expand All @@ -78,7 +82,7 @@ async def run(
async def run_stream(
self,
*,
task: str | TextMessage | MultiModalMessage | None = None,
task: str | ChatMessage | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
"""Run the agent with the given task and return a stream of messages
Expand All @@ -87,15 +91,19 @@ async def run_stream(
cancellation_token = CancellationToken()
input_messages: List[ChatMessage] = []
output_messages: List[AgentMessage] = []
if isinstance(task, str):
if task is None:
pass
elif isinstance(task, str):
text_msg = TextMessage(content=task, source="user")
input_messages.append(text_msg)
output_messages.append(text_msg)
yield text_msg
elif isinstance(task, TextMessage | MultiModalMessage):
elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage):
input_messages.append(task)
output_messages.append(task)
yield task
else:
raise ValueError(f"Invalid task type: {type(task)}")
async for message in self.on_messages_stream(input_messages, cancellation_token):
if isinstance(message, Response):
yield message.chat_message
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from autogen_core.base import CancellationToken

from ..messages import AgentMessage, MultiModalMessage, TextMessage
from ..messages import AgentMessage, ChatMessage


@dataclass
Expand All @@ -23,7 +23,7 @@ class TaskRunner(Protocol):
async def run(
self,
*,
task: str | TextMessage | MultiModalMessage | None = None,
task: str | ChatMessage | None = None,
cancellation_token: CancellationToken | None = None,
) -> TaskResult:
"""Run the task and return the result.
Expand All @@ -36,7 +36,7 @@ async def run(
def run_stream(
self,
*,
task: str | TextMessage | MultiModalMessage | None = None,
task: str | ChatMessage | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
"""Run the task and produces a stream of messages and the final result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@
AgentType,
CancellationToken,
MessageContext,
TopicId,
)
from autogen_core.components import ClosureAgent, TypeSubscription

from ... import EVENT_LOGGER_NAME
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
from ...messages import AgentMessage, MultiModalMessage, TextMessage
from ...messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
from ._chat_agent_container import ChatAgentContainer
from ._events import GroupChatMessage, GroupChatReset, GroupChatStart, GroupChatTermination
from ._sequential_routed_agent import SequentialRoutedAgent
Expand Down Expand Up @@ -164,7 +163,7 @@ async def collect_output_messages(
async def run(
self,
*,
task: str | TextMessage | MultiModalMessage | None = None,
task: str | ChatMessage | None = None,
cancellation_token: CancellationToken | None = None,
) -> TaskResult:
"""Run the team and return the result. The base implementation uses
Expand Down Expand Up @@ -215,7 +214,7 @@ async def main() -> None:
async def run_stream(
self,
*,
task: str | TextMessage | MultiModalMessage | None = None,
task: str | ChatMessage | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
"""Run the team and produces a stream of messages and the final result
Expand Down Expand Up @@ -253,6 +252,16 @@ async def main() -> None:
asyncio.run(main())
"""
# Create the first chat message if the task is a string or a chat message.
first_chat_message: ChatMessage | None = None
if task is None:
pass
elif isinstance(task, str):
first_chat_message = TextMessage(content=task, source="user")
elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage):
first_chat_message = task
else:
raise ValueError(f"Invalid task type: {type(task)}")

if self._is_running:
raise ValueError("The team is already running, it cannot run again until it is stopped.")
Expand All @@ -265,42 +274,44 @@ async def main() -> None:
if not self._initialized:
await self._init(self._runtime)

# Run the team by publishing the start message.
first_chat_message: TextMessage | MultiModalMessage | None = None
if isinstance(task, str):
first_chat_message = TextMessage(content=task, source="user")
elif isinstance(task, TextMessage | MultiModalMessage):
first_chat_message = task
await self._runtime.publish_message(
GroupChatStart(message=first_chat_message),
topic_id=TopicId(type=self._group_topic_type, source=self._team_id),
)

# Start a coroutine to stop the runtime and signal the output message queue is complete.
async def stop_runtime() -> None:
await self._runtime.stop_when_idle()
await self._output_message_queue.put(None)

shutdown_task = asyncio.create_task(stop_runtime())

# Collect the output messages in order.
output_messages: List[AgentMessage] = []
# Yield the messsages until the queue is empty.
while True:
message = await self._output_message_queue.get()
if message is None:
break
yield message
output_messages.append(message)

# Wait for the shutdown task to finish.
await shutdown_task

# Yield the final result.
yield TaskResult(messages=output_messages, stop_reason=self._stop_reason)

# Indicate that the team is no longer running.
self._is_running = False
try:
# Run the team by sending the start message to the group chat manager.
# The group chat manager will start the group chat by relaying the message to the participants
# and the closure agent.
await self._runtime.send_message(
GroupChatStart(message=first_chat_message),
recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id),
)
# Collect the output messages in order.
output_messages: List[AgentMessage] = []
# Yield the messsages until the queue is empty.
while True:
message = await self._output_message_queue.get()
if message is None:
break
yield message
output_messages.append(message)

# Yield the final result.
yield TaskResult(messages=output_messages, stop_reason=self._stop_reason)

finally:
# Wait for the shutdown task to finish.
await shutdown_task

# Clear the output message queue.
while not self._output_message_queue.empty():
self._output_message_queue.get_nowait()

# Indicate that the team is no longer running.
self._is_running = False

async def reset(self) -> None:
"""Reset the team and its participants to their initial state.
Expand Down Expand Up @@ -352,19 +363,26 @@ async def main() -> None:
# Start the runtime.
self._runtime.start()

# Send a reset message to the group chat.
await self._runtime.publish_message(
GroupChatReset(),
topic_id=TopicId(type=self._group_topic_type, source=self._team_id),
)

# Stop the runtime.
await self._runtime.stop_when_idle()
try:
# Send a reset messages to all participants.
for participant_topic_type in self._participant_topic_types:
await self._runtime.send_message(
GroupChatReset(),
recipient=AgentId(type=participant_topic_type, key=self._team_id),
)
# Send a reset message to the group chat manager.
await self._runtime.send_message(
GroupChatReset(),
recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id),
)
finally:
# Stop the runtime.
await self._runtime.stop_when_idle()

# Reset the output message queue.
self._stop_reason = None
while not self._output_message_queue.empty():
self._output_message_queue.get_nowait()
# Reset the output message queue.
self._stop_reason = None
while not self._output_message_queue.empty():
self._output_message_queue.get_nowait()

# Indicate that the team is no longer running.
self._is_running = False
# Indicate that the team is no longer running.
self._is_running = False
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from typing import Any, List

from autogen_core.base import MessageContext
from autogen_core.components import DefaultTopicId, event
from autogen_core.components import DefaultTopicId, event, rpc

from ...base import TerminationCondition
from ...messages import AgentMessage, StopMessage
from ...messages import AgentMessage, ChatMessage, StopMessage
from ._events import (
GroupChatAgentResponse,
GroupChatRequestPublish,
Expand Down Expand Up @@ -55,7 +55,7 @@ def __init__(
self._max_turns = max_turns
self._current_turn = 0

@event
@rpc
async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None:
"""Handle the start of a group chat by selecting a speaker to start the conversation."""

Expand All @@ -70,10 +70,16 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No
# Stop the group chat.
return

# Validate the group state given the start message.
await self.validate_group_state(message.message)

if message.message is not None:
# Log the start message.
await self.publish_message(message, topic_id=DefaultTopicId(type=self._output_topic_type))

# Relay the start message to the participants.
await self.publish_message(message, topic_id=DefaultTopicId(type=self._group_topic_type))

# Append the user message to the message thread.
self._message_thread.append(message.message)

Expand Down Expand Up @@ -137,11 +143,16 @@ async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: Mess
speaker_topic_type = await self.select_speaker(self._message_thread)
await self.publish_message(GroupChatRequestPublish(), topic_id=DefaultTopicId(type=speaker_topic_type))

@event
@rpc
async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> None:
# Reset the group chat manager.
await self.reset()

@abstractmethod
async def validate_group_state(self, message: ChatMessage | None) -> None:
"""Validate the state of the group chat given the start message. This is executed when the group chat manager receives a GroupChatStart event."""
...

@abstractmethod
async def select_speaker(self, thread: List[AgentMessage]) -> str:
"""Select a speaker from the participants and return the
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, List

from autogen_core.base import MessageContext
from autogen_core.components import DefaultTopicId, event
from autogen_core.components import DefaultTopicId, event, rpc

from ...base import ChatAgent, Response
from ...messages import ChatMessage
Expand Down Expand Up @@ -38,7 +38,7 @@ async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: Mess
"""Handle an agent response event by appending the content to the buffer."""
self._message_buffer.append(message.agent_response.chat_message)

@event
@rpc
async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> None:
"""Handle a reset event by resetting the agent."""
self._message_buffer.clear()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, List

from autogen_core.base import MessageContext
from autogen_core.components import DefaultTopicId, Image, event
from autogen_core.components import DefaultTopicId, Image, event, rpc
from autogen_core.components.models import (
AssistantMessage,
ChatCompletionClient,
Expand Down Expand Up @@ -102,7 +102,7 @@ def _get_task_ledger_plan_update_prompt(self, team: str) -> str:
def _get_final_answer_prompt(self, task: str) -> str:
return ORCHESTRATOR_FINAL_ANSWER_PROMPT.format(task=task)

@event
@rpc
async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None:
"""Handle the start of a group chat by selecting a speaker to start the conversation."""
assert message is not None and message.message is not None
Expand Down Expand Up @@ -145,7 +145,7 @@ async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: Mess
self._message_thread.append(message.agent_response.chat_message)
await self._orchestrate_step()

@event
@rpc
async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> None:
# Reset the group chat manager.
await self.reset()
Expand Down
Loading

0 comments on commit c591c9e

Please sign in to comment.