From 170d2683268ec34dfd936697187352cd4701f9d8 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 11 Sep 2024 19:53:22 -0400 Subject: [PATCH] Expose and document handlers --- docs/patterns/running-tasks.mdx | 43 +++++++++++ src/controlflow/agents/agent.py | 5 ++ src/controlflow/events/orchestrator_events.py | 15 ++++ src/controlflow/orchestration/__init__.py | 1 + src/controlflow/orchestration/handler.py | 62 ++++++++++++++- src/controlflow/orchestration/orchestrator.py | 77 +++++++++++-------- .../orchestration/print_handler.py | 28 +++---- src/controlflow/run.py | 9 +++ src/controlflow/tasks/task.py | 5 ++ tests/agents/test_agents.py | 38 +++++++++ tests/tasks/test_tasks.py | 33 ++++++++ tests/test_run.py | 31 ++++++++ 12 files changed, 300 insertions(+), 47 deletions(-) diff --git a/docs/patterns/running-tasks.mdx b/docs/patterns/running-tasks.mdx index b594f088..7a14da7e 100644 --- a/docs/patterns/running-tasks.mdx +++ b/docs/patterns/running-tasks.mdx @@ -369,3 +369,46 @@ The orchestrator is instantiated with the following arguments: You can then use the orchestrator's `run()` method to step through the loop manually. If you call `run()` with no arguments, it will continue until all of the provided tasks are complete. You can provide `max_llm_calls` and `max_agent_turns` to further limit the behavior. + +## Using handlers + +Handlers in ControlFlow provide a way to observe and react to events that occur during task execution. They allow you to customize logging, monitoring, or take specific actions based on the orchestration process. + +Handlers implement the `Handler` interface, which defines methods for various events that can occur during task execution, including agent messages (and message deltas), user messages, tool calls, tool results, orchestrator sessions starting or stopping, and more. + +ControlFlow includes a built-in `PrintHandler` that pretty-prints agent responses and tool calls to the terminal. It's used by default if `controlflow.settings.pretty_print_agent_events=True` and no other handlers are provided. + +### How handlers work + +Whenever an event is generated by ControlFlow, the orchestrator will pass it to all of its registered handlers. Each handler will dispatch to one of its methods based on the type of event. For example, an `AgentMessage` event will be handled by the handler's `on_agent_message` method. The `on_event` method is always called for every event. This table describes all event types and the methods they are dispatched to: + +| Event Type | Method | +|------------|--------| +| `Event` (all events) | `on_event` | +| `UserMessage` | `on_user_message` | +| `OrchestratorMessage` | `on_orchestrator_message` | +| `AgentMessage` | `on_agent_message` | +| `AgentMessageDelta` | `on_agent_message_delta` | +| `ToolCall` | `on_tool_call` | +| `ToolResult` | `on_tool_result` | +| `OrchestratorStart` | `on_orchestrator_start` | +| `OrchestratorEnd` | `on_orchestrator_end` | +| `OrchestratorError` | `on_orchestrator_error` | +| `EndTurn` | `on_end_turn` | + + +### Writing a custom handler + +To create a custom handler, subclass the `Handler` class and implement the methods for the events you're interested in. Here's a simple example that logs agent messages: + +```python +import controlflow as cf +from controlflow.orchestration.handler import Handler +from controlflow.events.events import AgentMessage + +class LoggingHandler(Handler): + def on_agent_message(self, event: AgentMessage): + print(f"Agent {event.agent.name} said: {event.ai_message.content}") + +cf.run("Write a short poem about AI", handlers=[LoggingHandler()]) +``` diff --git a/src/controlflow/agents/agent.py b/src/controlflow/agents/agent.py index c42d9eb0..a1b2165b 100644 --- a/src/controlflow/agents/agent.py +++ b/src/controlflow/agents/agent.py @@ -34,6 +34,7 @@ from .memory import Memory if TYPE_CHECKING: + from controlflow.orchestration.handler import Handler from controlflow.orchestration.turn_strategies import TurnStrategy from controlflow.tasks import Task from controlflow.tools.tools import Tool @@ -196,12 +197,14 @@ def run( objective: str, *, turn_strategy: "TurnStrategy" = None, + handlers: list["Handler"] = None, **task_kwargs, ): return controlflow.run( objective=objective, agents=[self], turn_strategy=turn_strategy, + handlers=handlers, **task_kwargs, ) @@ -210,12 +213,14 @@ async def run_async( objective: str, *, turn_strategy: "TurnStrategy" = None, + handlers: list["Handler"] = None, **task_kwargs, ): return await controlflow.run_async( objective=objective, agents=[self], turn_strategy=turn_strategy, + handlers=handlers, **task_kwargs, ) diff --git a/src/controlflow/events/orchestrator_events.py b/src/controlflow/events/orchestrator_events.py index 98d0297e..6c07fef1 100644 --- a/src/controlflow/events/orchestrator_events.py +++ b/src/controlflow/events/orchestrator_events.py @@ -1,5 +1,6 @@ from typing import Literal +from controlflow.agents.agent import Agent from controlflow.events.base import UnpersistedEvent from controlflow.orchestration.orchestrator import Orchestrator @@ -21,3 +22,17 @@ class OrchestratorError(UnpersistedEvent): persist: bool = False orchestrator: Orchestrator error: Exception + + +class AgentTurnStart(UnpersistedEvent): + event: Literal["agent-turn-start"] = "agent-turn-start" + persist: bool = False + orchestrator: Orchestrator + agent: Agent + + +class AgentTurnEnd(UnpersistedEvent): + event: Literal["agent-turn-end"] = "agent-turn-end" + persist: bool = False + orchestrator: Orchestrator + agent: Agent diff --git a/src/controlflow/orchestration/__init__.py b/src/controlflow/orchestration/__init__.py index 5f7b6691..8f3ed651 100644 --- a/src/controlflow/orchestration/__init__.py +++ b/src/controlflow/orchestration/__init__.py @@ -1 +1,2 @@ from .orchestrator import Orchestrator +from .handler import Handler diff --git a/src/controlflow/orchestration/handler.py b/src/controlflow/orchestration/handler.py index e4bb2d31..935cd926 100644 --- a/src/controlflow/orchestration/handler.py +++ b/src/controlflow/orchestration/handler.py @@ -1,19 +1,77 @@ -from typing import Callable +from typing import TYPE_CHECKING, Callable from controlflow.events.base import Event +if TYPE_CHECKING: + from controlflow.events.events import ( + AgentMessage, + AgentMessageDelta, + EndTurn, + OrchestratorMessage, + ToolCallEvent, + ToolResultEvent, + UserMessage, + ) + from controlflow.events.orchestrator_events import ( + OrchestratorEnd, + OrchestratorError, + OrchestratorStart, + ) + class Handler: def handle(self, event: Event): + """ + Handle is called whenever an event is emitted. + + By default, it dispatches to a method named after the event type e.g. + `self.on_{event_type}(event=event)`. + + The `on_event` method is always called for every event. + """ + self.on_event(event=event) event_type = event.event.replace("-", "_") method = getattr(self, f"on_{event_type}", None) if method: method(event=event) + def on_event(self, event: Event): + pass + + def on_orchestrator_start(self, event: "OrchestratorStart"): + pass + + def on_orchestrator_end(self, event: "OrchestratorEnd"): + pass + + def on_orchestrator_error(self, event: "OrchestratorError"): + pass + + def on_agent_message(self, event: "AgentMessage"): + pass + + def on_agent_message_delta(self, event: "AgentMessageDelta"): + pass + + def on_tool_call(self, event: "ToolCallEvent"): + pass + + def on_tool_result(self, event: "ToolResultEvent"): + pass + + def on_orchestrator_message(self, event: "OrchestratorMessage"): + pass + + def on_user_message(self, event: "UserMessage"): + pass + + def on_end_turn(self, event: "EndTurn"): + pass + class CallbackHandler(Handler): def __init__(self, callback: Callable[[Event], None]): self.callback = callback - def handle(self, event: Event): + def on_event(self, event: Event): self.callback(event) diff --git a/src/controlflow/orchestration/orchestrator.py b/src/controlflow/orchestration/orchestrator.py index eeadd582..b11c970b 100644 --- a/src/controlflow/orchestration/orchestrator.py +++ b/src/controlflow/orchestration/orchestrator.py @@ -160,24 +160,18 @@ def run( if max_llm_calls is not None and call_count >= max_llm_calls: break + self.handle_event( + controlflow.events.orchestrator_events.AgentTurnStart( + orchestrator=self, agent=self.agent + ) + ) turn_count += 1 - self.turn_strategy.begin_turn() - - # Mark assigned tasks as running - for task in self.get_tasks("assigned"): - if not task.is_running(): - task.mark_running() - self.flow.add_events( - [ - OrchestratorMessage( - content=f"Starting task {task.name} (ID {task.id}) " - f"with objective: {task.objective}" - ) - ] - ) - - # Run the agent's turn call_count += self.run_agent_turn(max_llm_calls - call_count) + self.handle_event( + controlflow.events.orchestrator_events.AgentTurnEnd( + orchestrator=self, agent=self.agent + ) + ) # Select the next agent for the following turn if available_agents := self.get_available_agents(): @@ -244,25 +238,20 @@ async def run_async( if max_llm_calls is not None and call_count >= max_llm_calls: break + self.handle_event( + controlflow.events.orchestrator_events.AgentTurnStart( + orchestrator=self, agent=self.agent + ) + ) turn_count += 1 - self.turn_strategy.begin_turn() - - # Mark assigned tasks as running - for task in self.get_tasks("assigned"): - if not task.is_running(): - task.mark_running() - self.flow.add_events( - [ - OrchestratorMessage( - content=f"Starting task {task.name} (ID {task.id}) with objective: {task.objective}" - ) - ] - ) - - # Run the agent's turn call_count += await self.run_agent_turn_async( max_llm_calls - call_count ) + self.handle_event( + controlflow.events.orchestrator_events.AgentTurnEnd( + orchestrator=self, agent=self.agent + ) + ) # Select the next agent for the following turn if available_agents := self.get_available_agents(): @@ -300,6 +289,19 @@ def run_agent_turn(self, max_llm_calls: Optional[int]) -> int: call_count = 0 assigned_tasks = self.get_tasks("assigned") + self.turn_strategy.begin_turn() + + # Mark assigned tasks as running + for task in assigned_tasks: + if not task.is_running(): + task.mark_running() + self.handle_event( + OrchestratorMessage( + content=f"Starting task {task.name} (ID {task.id}) " + f"with objective: {task.objective}" + ) + ) + while not self.turn_strategy.should_end_turn(): for task in assigned_tasks: if task.max_llm_calls and task._llm_calls >= task.max_llm_calls: @@ -340,6 +342,19 @@ async def run_agent_turn_async(self, max_llm_calls: Optional[int]) -> int: call_count = 0 assigned_tasks = self.get_tasks("assigned") + self.turn_strategy.begin_turn() + + # Mark assigned tasks as running + for task in assigned_tasks: + if not task.is_running(): + task.mark_running() + self.handle_event( + OrchestratorMessage( + content=f"Starting task {task.name} (ID {task.id}) " + f"with objective: {task.objective}" + ) + ) + while not self.turn_strategy.should_end_turn(): for task in assigned_tasks: if task.max_llm_calls and task._llm_calls >= task.max_llm_calls: diff --git a/src/controlflow/orchestration/print_handler.py b/src/controlflow/orchestration/print_handler.py index a06fb675..f76523aa 100644 --- a/src/controlflow/orchestration/print_handler.py +++ b/src/controlflow/orchestration/print_handler.py @@ -35,20 +35,6 @@ def __init__(self): self.paused_id: str = None super().__init__() - def on_orchestrator_start(self, event: OrchestratorStart): - self.live: Live = Live(auto_refresh=False, console=cf_console) - self.events.clear() - try: - self.live.start() - except rich.errors.LiveError: - pass - - def on_orchestrator_end(self, event: OrchestratorEnd): - self.live.stop() - - def on_orchestrator_error(self, event: OrchestratorError): - self.live.stop() - def update_live(self, latest: BaseMessage = None): events = sorted(self.events.items(), key=lambda e: (e[1].timestamp, e[0])) content = [] @@ -72,6 +58,20 @@ def update_live(self, latest: BaseMessage = None): elif latest: cf_console.print(format_event(latest)) + def on_orchestrator_start(self, event: OrchestratorStart): + self.live: Live = Live(auto_refresh=False, console=cf_console) + self.events.clear() + try: + self.live.start() + except rich.errors.LiveError: + pass + + def on_orchestrator_end(self, event: OrchestratorEnd): + self.live.stop() + + def on_orchestrator_error(self, event: OrchestratorError): + self.live.stop() + def on_agent_message_delta(self, event: AgentMessageDelta): self.events[event.snapshot_message.id] = event self.update_live() diff --git a/src/controlflow/run.py b/src/controlflow/run.py index 6ec82169..1f01a350 100644 --- a/src/controlflow/run.py +++ b/src/controlflow/run.py @@ -4,6 +4,7 @@ from controlflow.agents.agent import Agent from controlflow.flows import Flow, get_flow +from controlflow.orchestration.handler import Handler from controlflow.orchestration.orchestrator import Orchestrator, TurnStrategy from controlflow.tasks.task import Task from controlflow.utilities.prefect import prefect_task @@ -25,6 +26,7 @@ def run_tasks( raise_on_error: bool = True, max_llm_calls: int = None, max_agent_turns: int = None, + handlers: list[Handler] = None, ) -> list[Any]: """ Run a list of tasks. @@ -38,6 +40,7 @@ def run_tasks( flow=flow, agent=agent, turn_strategy=turn_strategy, + handlers=handlers, ) orchestrator.run( max_llm_calls=max_llm_calls, @@ -64,6 +67,7 @@ async def run_tasks_async( raise_on_error: bool = True, max_llm_calls: int = None, max_agent_turns: int = None, + handlers: list[Handler] = None, ): """ Run a list of tasks. @@ -74,6 +78,7 @@ async def run_tasks_async( flow=flow, agent=agent, turn_strategy=turn_strategy, + handlers=handlers, ) await orchestrator.run_async( max_llm_calls=max_llm_calls, @@ -98,6 +103,7 @@ def run( max_llm_calls: int = None, max_agent_turns: int = None, raise_on_error: bool = True, + handlers: list[Handler] = None, **task_kwargs, ) -> Any: task = Task(objective=objective, **task_kwargs) @@ -107,6 +113,7 @@ def run( turn_strategy=turn_strategy, max_llm_calls=max_llm_calls, max_agent_turns=max_agent_turns, + handlers=handlers, ) return results[0] @@ -120,6 +127,7 @@ async def run_async( max_llm_calls: int = None, max_agent_turns: int = None, raise_on_error: bool = True, + handlers: list[Handler] = None, **task_kwargs, ) -> Any: task = Task(objective=objective, **task_kwargs) @@ -131,5 +139,6 @@ async def run_async( max_llm_calls=max_llm_calls, max_agent_turns=max_agent_turns, raise_on_error=raise_on_error, + handlers=handlers, ) return results[0] diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index b166527e..ff0f6221 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -43,6 +43,7 @@ if TYPE_CHECKING: from controlflow.flows import Flow + from controlflow.orchestration.handler import Handler from controlflow.orchestration.turn_strategies import TurnStrategy T = TypeVar("T") @@ -351,6 +352,7 @@ def run( turn_strategy: "TurnStrategy" = None, max_llm_calls: int = None, max_agent_turns: int = None, + handlers: list["Handler"] = None, ) -> T: """ Run the task @@ -364,6 +366,7 @@ def run( max_llm_calls=max_llm_calls, max_agent_turns=max_agent_turns, raise_on_error=False, + handlers=handlers, ) if self.is_successful(): @@ -379,6 +382,7 @@ async def run_async( turn_strategy: "TurnStrategy" = None, max_llm_calls: int = None, max_agent_turns: int = None, + handlers: list["Handler"] = None, ) -> T: """ Run the task @@ -392,6 +396,7 @@ async def run_async( max_llm_calls=max_llm_calls, max_agent_turns=max_agent_turns, raise_on_error=False, + handlers=handlers, ) if self.is_successful(): diff --git a/tests/agents/test_agents.py b/tests/agents/test_agents.py index d44f15b7..9a82b158 100644 --- a/tests/agents/test_agents.py +++ b/tests/agents/test_agents.py @@ -1,9 +1,13 @@ +import pytest from langchain_openai import ChatOpenAI import controlflow from controlflow.agents import Agent +from controlflow.events.base import Event +from controlflow.events.events import AgentMessage from controlflow.instructions import instructions from controlflow.llm.rules import LLMRules +from controlflow.orchestration.handler import Handler from controlflow.tasks.task import Task @@ -138,3 +142,37 @@ def test_context_manager(self): from controlflow.utilities.context import ctx assert ctx.get("agent") is agent + + +class TestHandlers: + class TestHandler(Handler): + def __init__(self): + self.events = [] + self.agent_messages = [] + + def on_event(self, event: Event): + self.events.append(event) + + def on_agent_message(self, event: AgentMessage): + self.agent_messages.append(event) + + def test_agent_run_with_handlers(self, default_fake_llm): + handler = self.TestHandler() + agent = Agent() + agent.run( + "Calculate 2 + 2", result_type=int, handlers=[handler], max_llm_calls=1 + ) + + assert len(handler.events) > 0 + assert len(handler.agent_messages) == 1 + + @pytest.mark.asyncio + async def test_agent_run_async_with_handlers(self, default_fake_llm): + handler = self.TestHandler() + agent = Agent() + await agent.run_async( + "Calculate 2 + 2", result_type=int, handlers=[handler], max_llm_calls=1 + ) + + assert len(handler.events) > 0 + assert len(handler.agent_messages) == 1 diff --git a/tests/tasks/test_tasks.py b/tests/tasks/test_tasks.py index 5fda4081..f2da7123 100644 --- a/tests/tasks/test_tasks.py +++ b/tests/tasks/test_tasks.py @@ -5,8 +5,11 @@ import controlflow from controlflow.agents import Agent +from controlflow.events.base import Event +from controlflow.events.events import AgentMessage from controlflow.flows import Flow from controlflow.instructions import instructions +from controlflow.orchestration.handler import Handler from controlflow.tasks.task import ( COMPLETE_STATUSES, INCOMPLETE_STATUSES, @@ -423,3 +426,33 @@ class Person(BaseModel): tool.run(input=dict(result=1)) assert task.result == Person(name="Bob", age=35) assert isinstance(task.result, Person) + + +class TestHandlers: + class TestHandler(Handler): + def __init__(self): + self.events = [] + self.agent_messages = [] + + def on_event(self, event: Event): + self.events.append(event) + + def on_agent_message(self, event: AgentMessage): + self.agent_messages.append(event) + + def test_task_run_with_handlers(self, default_fake_llm): + handler = self.TestHandler() + task = Task(objective="Calculate 2 + 2", result_type=int) + task.run(handlers=[handler], max_llm_calls=1) + + assert len(handler.events) > 0 + assert len(handler.agent_messages) == 1 + + @pytest.mark.asyncio + async def test_task_run_async_with_handlers(self, default_fake_llm): + handler = self.TestHandler() + task = Task(objective="Calculate 2 + 2", result_type=int) + await task.run_async(handlers=[handler], max_llm_calls=1) + + assert len(handler.events) > 0 + assert len(handler.agent_messages) == 1 diff --git a/tests/test_run.py b/tests/test_run.py index 007d5562..41f5d470 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,6 +1,37 @@ +from controlflow.events.base import Event +from controlflow.events.events import AgentMessage +from controlflow.orchestration.handler import Handler from controlflow.run import run, run_async +class TestHandlers: + class TestHandler(Handler): + def __init__(self): + self.events = [] + self.agent_messages = [] + + def on_event(self, event: Event): + self.events.append(event) + + def on_agent_message(self, event: AgentMessage): + self.agent_messages.append(event) + + def test_run_with_handlers(self, default_fake_llm): + handler = self.TestHandler() + run("what's 2 + 2", result_type=int, handlers=[handler], max_llm_calls=1) + assert len(handler.events) > 0 + assert len(handler.agent_messages) == 1 + + async def test_run_async_with_handlers(self, default_fake_llm): + handler = self.TestHandler() + await run_async( + "what's 2 + 2", result_type=int, handlers=[handler], max_llm_calls=1 + ) + + assert len(handler.events) > 0 + assert len(handler.agent_messages) == 1 + + def test_run(): result = run("what's 2 + 2", result_type=int) assert result == 4