Skip to content

Commit

Permalink
improve AsyncCallbackManager (langchain-ai#2410)
Browse files Browse the repository at this point in the history
  • Loading branch information
agola11 authored Apr 5, 2023
1 parent af7f20f commit 4d730a9
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 143 deletions.
214 changes: 73 additions & 141 deletions langchain/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import asyncio
import functools
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union

from langchain.schema import AgentAction, AgentFinish, LLMResult

Expand Down Expand Up @@ -328,6 +328,25 @@ async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""


async def _handle_event_for_handler(
handler: BaseCallbackHandler,
event_name: str,
ignore_condition_name: Optional[str],
verbose: bool,
*args: Any,
**kwargs: Any
) -> None:
if ignore_condition_name is None or not getattr(handler, ignore_condition_name):
if verbose or handler.always_verbose:
event = getattr(handler, event_name)
if asyncio.iscoroutinefunction(event):
await event(*args, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None, functools.partial(event, *args, **kwargs)
)


class AsyncCallbackManager(BaseCallbackManager):
"""Async callback manager that can be used to handle callbacks from LangChain."""

Expand All @@ -340,6 +359,24 @@ def __init__(self, handlers: List[BaseCallbackHandler]) -> None:
"""Initialize callback manager."""
self.handlers: List[BaseCallbackHandler] = handlers

async def _handle_event(
self,
event_name: str,
ignore_condition_name: Optional[str],
verbose: bool,
*args: Any,
**kwargs: Any
) -> None:
"""Generic event handler for AsyncCallbackManager."""
await asyncio.gather(
*(
_handle_event_for_handler(
handler, event_name, ignore_condition_name, verbose, *args, **kwargs
)
for handler in self.handlers
)
)

async def on_llm_start(
self,
serialized: Dict[str, Any],
Expand All @@ -348,50 +385,25 @@ async def on_llm_start(
**kwargs: Any
) -> None:
"""Run when LLM starts running."""
for handler in self.handlers:
if not handler.ignore_llm:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_llm_start):
await handler.on_llm_start(serialized, prompts, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(
handler.on_llm_start, serialized, prompts, **kwargs
),
)
await self._handle_event(
"on_llm_start", "ignore_llm", verbose, serialized, prompts, **kwargs
)

async def on_llm_new_token(
self, token: str, verbose: bool = False, **kwargs: Any
) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
for handler in self.handlers:
if not handler.ignore_llm:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_llm_new_token):
await handler.on_llm_new_token(token, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(
handler.on_llm_new_token, token, **kwargs
),
)
await self._handle_event(
"on_llm_new_token", "ignore_llm", verbose, token, **kwargs
)

async def on_llm_end(
self, response: LLMResult, verbose: bool = False, **kwargs: Any
) -> None:
"""Run when LLM ends running."""
for handler in self.handlers:
if not handler.ignore_llm:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_llm_end):
await handler.on_llm_end(response, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(handler.on_llm_end, response, **kwargs),
)
await self._handle_event(
"on_llm_end", "ignore_llm", verbose, response, **kwargs
)

async def on_llm_error(
self,
Expand All @@ -400,16 +412,7 @@ async def on_llm_error(
**kwargs: Any
) -> None:
"""Run when LLM errors."""
for handler in self.handlers:
if not handler.ignore_llm:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_llm_error):
await handler.on_llm_error(error, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(handler.on_llm_error, error, **kwargs),
)
await self._handle_event("on_llm_error", "ignore_llm", verbose, error, **kwargs)

async def on_chain_start(
self,
Expand All @@ -419,33 +422,17 @@ async def on_chain_start(
**kwargs: Any
) -> None:
"""Run when chain starts running."""
for handler in self.handlers:
if not handler.ignore_chain:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_chain_start):
await handler.on_chain_start(serialized, inputs, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(
handler.on_chain_start, serialized, inputs, **kwargs
),
)
await self._handle_event(
"on_chain_start", "ignore_chain", verbose, serialized, inputs, **kwargs
)

async def on_chain_end(
self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any
) -> None:
"""Run when chain ends running."""
for handler in self.handlers:
if not handler.ignore_chain:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_chain_end):
await handler.on_chain_end(outputs, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(handler.on_chain_end, outputs, **kwargs),
)
await self._handle_event(
"on_chain_end", "ignore_chain", verbose, outputs, **kwargs
)

async def on_chain_error(
self,
Expand All @@ -454,16 +441,9 @@ async def on_chain_error(
**kwargs: Any
) -> None:
"""Run when chain errors."""
for handler in self.handlers:
if not handler.ignore_chain:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_chain_error):
await handler.on_chain_error(error, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(handler.on_chain_error, error, **kwargs),
)
await self._handle_event(
"on_chain_error", "ignore_chain", verbose, error, **kwargs
)

async def on_tool_start(
self,
Expand All @@ -473,33 +453,17 @@ async def on_tool_start(
**kwargs: Any
) -> None:
"""Run when tool starts running."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_tool_start):
await handler.on_tool_start(serialized, input_str, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(
handler.on_tool_start, serialized, input_str, **kwargs
),
)
await self._handle_event(
"on_tool_start", "ignore_agent", verbose, serialized, input_str, **kwargs
)

async def on_tool_end(
self, output: str, verbose: bool = False, **kwargs: Any
) -> None:
"""Run when tool ends running."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_tool_end):
await handler.on_tool_end(output, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(handler.on_tool_end, output, **kwargs),
)
await self._handle_event(
"on_tool_end", "ignore_agent", verbose, output, **kwargs
)

async def on_tool_error(
self,
Expand All @@ -508,61 +472,29 @@ async def on_tool_error(
**kwargs: Any
) -> None:
"""Run when tool errors."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_tool_error):
await handler.on_tool_error(error, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(handler.on_tool_error, error, **kwargs),
)
await self._handle_event(
"on_tool_error", "ignore_agent", verbose, error, **kwargs
)

async def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None:
"""Run when text is printed."""
for handler in self.handlers:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_text):
await handler.on_text(text, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None, functools.partial(handler.on_text, text, **kwargs)
)
await self._handle_event("on_text", None, verbose, text, **kwargs)

async def on_agent_action(
self, action: AgentAction, verbose: bool = False, **kwargs: Any
) -> None:
"""Run on agent action."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_agent_action):
await handler.on_agent_action(action, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(
handler.on_agent_action, action, **kwargs
),
)
await self._handle_event(
"on_agent_action", "ignore_agent", verbose, action, **kwargs
)

async def on_agent_finish(
self, finish: AgentFinish, verbose: bool = False, **kwargs: Any
) -> None:
"""Run when agent finishes."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_agent_finish):
await handler.on_agent_finish(finish, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(
handler.on_agent_finish, finish, **kwargs
),
)
await self._handle_event(
"on_agent_finish", "ignore_agent", verbose, finish, **kwargs
)

def add_handler(self, handler: BaseCallbackHandler) -> None:
"""Add a handler to the callback manager."""
Expand Down
5 changes: 3 additions & 2 deletions tests/unit_tests/callbacks/test_callback_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,5 +176,6 @@ async def test_async_callback_manager_sync_handler() -> None:
"""Test the AsyncCallbackManager."""
handler1 = FakeCallbackHandler(always_verbose_=True)
handler2 = FakeAsyncCallbackHandler()
manager = AsyncCallbackManager([handler1, handler2])
await _test_callback_manager_async(manager, handler1, handler2)
handler3 = FakeAsyncCallbackHandler(always_verbose_=True)
manager = AsyncCallbackManager([handler1, handler2, handler3])
await _test_callback_manager_async(manager, handler1, handler2, handler3)

0 comments on commit 4d730a9

Please sign in to comment.