Skip to content

Commit

Permalink
Merge pull request #307 from PrefectHQ/handlers
Browse files Browse the repository at this point in the history
Expose and document handlers
  • Loading branch information
jlowin authored Sep 12, 2024
2 parents dfc74c4 + 170d268 commit d2ebf6c
Show file tree
Hide file tree
Showing 12 changed files with 300 additions and 47 deletions.
43 changes: 43 additions & 0 deletions docs/patterns/running-tasks.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -369,3 +369,46 @@ The orchestrator is instantiated with the following arguments:

You can then use the orchestrator's `run()` method to step through the loop manually. If you call `run()` with no arguments, it will continue until all of the provided tasks are complete. You can provide `max_llm_calls` and `max_agent_turns` to further limit the behavior.


## Using handlers

Handlers in ControlFlow provide a way to observe and react to events that occur during task execution. They allow you to customize logging, monitoring, or take specific actions based on the orchestration process.

Handlers implement the `Handler` interface, which defines methods for various events that can occur during task execution, including agent messages (and message deltas), user messages, tool calls, tool results, orchestrator sessions starting or stopping, and more.

ControlFlow includes a built-in `PrintHandler` that pretty-prints agent responses and tool calls to the terminal. It's used by default if `controlflow.settings.pretty_print_agent_events=True` and no other handlers are provided.

### How handlers work

Whenever an event is generated by ControlFlow, the orchestrator will pass it to all of its registered handlers. Each handler will dispatch to one of its methods based on the type of event. For example, an `AgentMessage` event will be handled by the handler's `on_agent_message` method. The `on_event` method is always called for every event. This table describes all event types and the methods they are dispatched to:

| Event Type | Method |
|------------|--------|
| `Event` (all events) | `on_event` |
| `UserMessage` | `on_user_message` |
| `OrchestratorMessage` | `on_orchestrator_message` |
| `AgentMessage` | `on_agent_message` |
| `AgentMessageDelta` | `on_agent_message_delta` |
| `ToolCall` | `on_tool_call` |
| `ToolResult` | `on_tool_result` |
| `OrchestratorStart` | `on_orchestrator_start` |
| `OrchestratorEnd` | `on_orchestrator_end` |
| `OrchestratorError` | `on_orchestrator_error` |
| `EndTurn` | `on_end_turn` |


### Writing a custom handler

To create a custom handler, subclass the `Handler` class and implement the methods for the events you're interested in. Here's a simple example that logs agent messages:

```python
import controlflow as cf
from controlflow.orchestration.handler import Handler
from controlflow.events.events import AgentMessage

class LoggingHandler(Handler):
def on_agent_message(self, event: AgentMessage):
print(f"Agent {event.agent.name} said: {event.ai_message.content}")

cf.run("Write a short poem about AI", handlers=[LoggingHandler()])
```
5 changes: 5 additions & 0 deletions src/controlflow/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from .memory import Memory

if TYPE_CHECKING:
from controlflow.orchestration.handler import Handler
from controlflow.orchestration.turn_strategies import TurnStrategy
from controlflow.tasks import Task
from controlflow.tools.tools import Tool
Expand Down Expand Up @@ -196,12 +197,14 @@ def run(
objective: str,
*,
turn_strategy: "TurnStrategy" = None,
handlers: list["Handler"] = None,
**task_kwargs,
):
return controlflow.run(
objective=objective,
agents=[self],
turn_strategy=turn_strategy,
handlers=handlers,
**task_kwargs,
)

Expand All @@ -210,12 +213,14 @@ async def run_async(
objective: str,
*,
turn_strategy: "TurnStrategy" = None,
handlers: list["Handler"] = None,
**task_kwargs,
):
return await controlflow.run_async(
objective=objective,
agents=[self],
turn_strategy=turn_strategy,
handlers=handlers,
**task_kwargs,
)

Expand Down
15 changes: 15 additions & 0 deletions src/controlflow/events/orchestrator_events.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Literal

from controlflow.agents.agent import Agent
from controlflow.events.base import UnpersistedEvent
from controlflow.orchestration.orchestrator import Orchestrator

Expand All @@ -21,3 +22,17 @@ class OrchestratorError(UnpersistedEvent):
persist: bool = False
orchestrator: Orchestrator
error: Exception


class AgentTurnStart(UnpersistedEvent):
event: Literal["agent-turn-start"] = "agent-turn-start"
persist: bool = False
orchestrator: Orchestrator
agent: Agent


class AgentTurnEnd(UnpersistedEvent):
event: Literal["agent-turn-end"] = "agent-turn-end"
persist: bool = False
orchestrator: Orchestrator
agent: Agent
1 change: 1 addition & 0 deletions src/controlflow/orchestration/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .orchestrator import Orchestrator
from .handler import Handler
62 changes: 60 additions & 2 deletions src/controlflow/orchestration/handler.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,77 @@
from typing import Callable
from typing import TYPE_CHECKING, Callable

from controlflow.events.base import Event

if TYPE_CHECKING:
from controlflow.events.events import (
AgentMessage,
AgentMessageDelta,
EndTurn,
OrchestratorMessage,
ToolCallEvent,
ToolResultEvent,
UserMessage,
)
from controlflow.events.orchestrator_events import (
OrchestratorEnd,
OrchestratorError,
OrchestratorStart,
)


class Handler:
def handle(self, event: Event):
"""
Handle is called whenever an event is emitted.
By default, it dispatches to a method named after the event type e.g.
`self.on_{event_type}(event=event)`.
The `on_event` method is always called for every event.
"""
self.on_event(event=event)
event_type = event.event.replace("-", "_")
method = getattr(self, f"on_{event_type}", None)
if method:
method(event=event)

def on_event(self, event: Event):
pass

def on_orchestrator_start(self, event: "OrchestratorStart"):
pass

def on_orchestrator_end(self, event: "OrchestratorEnd"):
pass

def on_orchestrator_error(self, event: "OrchestratorError"):
pass

def on_agent_message(self, event: "AgentMessage"):
pass

def on_agent_message_delta(self, event: "AgentMessageDelta"):
pass

def on_tool_call(self, event: "ToolCallEvent"):
pass

def on_tool_result(self, event: "ToolResultEvent"):
pass

def on_orchestrator_message(self, event: "OrchestratorMessage"):
pass

def on_user_message(self, event: "UserMessage"):
pass

def on_end_turn(self, event: "EndTurn"):
pass


class CallbackHandler(Handler):
def __init__(self, callback: Callable[[Event], None]):
self.callback = callback

def handle(self, event: Event):
def on_event(self, event: Event):
self.callback(event)
77 changes: 46 additions & 31 deletions src/controlflow/orchestration/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,24 +160,18 @@ def run(
if max_llm_calls is not None and call_count >= max_llm_calls:
break

self.handle_event(
controlflow.events.orchestrator_events.AgentTurnStart(
orchestrator=self, agent=self.agent
)
)
turn_count += 1
self.turn_strategy.begin_turn()

# Mark assigned tasks as running
for task in self.get_tasks("assigned"):
if not task.is_running():
task.mark_running()
self.flow.add_events(
[
OrchestratorMessage(
content=f"Starting task {task.name} (ID {task.id}) "
f"with objective: {task.objective}"
)
]
)

# Run the agent's turn
call_count += self.run_agent_turn(max_llm_calls - call_count)
self.handle_event(
controlflow.events.orchestrator_events.AgentTurnEnd(
orchestrator=self, agent=self.agent
)
)

# Select the next agent for the following turn
if available_agents := self.get_available_agents():
Expand Down Expand Up @@ -244,25 +238,20 @@ async def run_async(
if max_llm_calls is not None and call_count >= max_llm_calls:
break

self.handle_event(
controlflow.events.orchestrator_events.AgentTurnStart(
orchestrator=self, agent=self.agent
)
)
turn_count += 1
self.turn_strategy.begin_turn()

# Mark assigned tasks as running
for task in self.get_tasks("assigned"):
if not task.is_running():
task.mark_running()
self.flow.add_events(
[
OrchestratorMessage(
content=f"Starting task {task.name} (ID {task.id}) with objective: {task.objective}"
)
]
)

# Run the agent's turn
call_count += await self.run_agent_turn_async(
max_llm_calls - call_count
)
self.handle_event(
controlflow.events.orchestrator_events.AgentTurnEnd(
orchestrator=self, agent=self.agent
)
)

# Select the next agent for the following turn
if available_agents := self.get_available_agents():
Expand Down Expand Up @@ -300,6 +289,19 @@ def run_agent_turn(self, max_llm_calls: Optional[int]) -> int:
call_count = 0
assigned_tasks = self.get_tasks("assigned")

self.turn_strategy.begin_turn()

# Mark assigned tasks as running
for task in assigned_tasks:
if not task.is_running():
task.mark_running()
self.handle_event(
OrchestratorMessage(
content=f"Starting task {task.name} (ID {task.id}) "
f"with objective: {task.objective}"
)
)

while not self.turn_strategy.should_end_turn():
for task in assigned_tasks:
if task.max_llm_calls and task._llm_calls >= task.max_llm_calls:
Expand Down Expand Up @@ -340,6 +342,19 @@ async def run_agent_turn_async(self, max_llm_calls: Optional[int]) -> int:
call_count = 0
assigned_tasks = self.get_tasks("assigned")

self.turn_strategy.begin_turn()

# Mark assigned tasks as running
for task in assigned_tasks:
if not task.is_running():
task.mark_running()
self.handle_event(
OrchestratorMessage(
content=f"Starting task {task.name} (ID {task.id}) "
f"with objective: {task.objective}"
)
)

while not self.turn_strategy.should_end_turn():
for task in assigned_tasks:
if task.max_llm_calls and task._llm_calls >= task.max_llm_calls:
Expand Down
28 changes: 14 additions & 14 deletions src/controlflow/orchestration/print_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,6 @@ def __init__(self):
self.paused_id: str = None
super().__init__()

def on_orchestrator_start(self, event: OrchestratorStart):
self.live: Live = Live(auto_refresh=False, console=cf_console)
self.events.clear()
try:
self.live.start()
except rich.errors.LiveError:
pass

def on_orchestrator_end(self, event: OrchestratorEnd):
self.live.stop()

def on_orchestrator_error(self, event: OrchestratorError):
self.live.stop()

def update_live(self, latest: BaseMessage = None):
events = sorted(self.events.items(), key=lambda e: (e[1].timestamp, e[0]))
content = []
Expand All @@ -72,6 +58,20 @@ def update_live(self, latest: BaseMessage = None):
elif latest:
cf_console.print(format_event(latest))

def on_orchestrator_start(self, event: OrchestratorStart):
self.live: Live = Live(auto_refresh=False, console=cf_console)
self.events.clear()
try:
self.live.start()
except rich.errors.LiveError:
pass

def on_orchestrator_end(self, event: OrchestratorEnd):
self.live.stop()

def on_orchestrator_error(self, event: OrchestratorError):
self.live.stop()

def on_agent_message_delta(self, event: AgentMessageDelta):
self.events[event.snapshot_message.id] = event
self.update_live()
Expand Down
Loading

0 comments on commit d2ebf6c

Please sign in to comment.