diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
index 3b92a4a51b0..ea4d3d2b2a5 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
@@ -245,6 +245,7 @@ def __init__(
         system_message: (
             str | None
         ) = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
+        token_callback: Callable | None = None,
         reflect_on_tool_use: bool = False,
         tool_call_summary_format: str = "{result}",
     ):
@@ -255,6 +256,7 @@ def __init__(
         else:
             self._system_messages = [SystemMessage(content=system_message)]
         self._tools: List[Tool] = []
+        self._token_callback = token_callback
         if tools is not None:
             if model_client.capabilities["function_calling"] is False:
                 raise ValueError("The model does not support function calling.")
@@ -334,9 +336,26 @@ async def on_messages_stream(
 
         # Generate an inference result based on the current model context.
         llm_messages = self._system_messages + await self._model_context.get_messages()
-        result = await self._model_client.create(
-            llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
-        )
+
+        # if token_callback is set, use create_stream to get the tokens as they are
+        # generated and call the token_callback with the tokens
+        if self._token_callback is not None:
+            async for result in self._model_client.create_stream(
+                llm_messages,
+                tools=self._tools + self._handoff_tools,
+                cancellation_token=cancellation_token,
+            ):
+                # if the result is a string, it is a token to be streamed back
+                if isinstance(result, str):
+                    await self._token_callback(result)
+                else:
+                    break
+        else:
+            result = await self._model_client.create(
+                llm_messages,
+                tools=self._tools + self._handoff_tools,
+                cancellation_token=cancellation_token,
+            )
 
         # Add the response to the model context.
         await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))
@@ -387,7 +406,24 @@ async def on_messages_stream(
         if self._reflect_on_tool_use:
             # Generate another inference result based on the tool call and result.
             llm_messages = self._system_messages + await self._model_context.get_messages()
-            result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
+
+            # if token_callback is set, use create_stream to get the tokens as they are
+            # generated and call the token_callback with the tokens
+            if self._token_callback is not None:
+                async for result in self._model_client.create_stream(
+                    llm_messages,
+                    cancellation_token=cancellation_token,
+                ):
+                    # if the result is a string, it is a token to be streamed back
+                    if isinstance(result, str):
+                        await self._token_callback(result)
+                    else:
+                        break
+            else:
+                result = await self._model_client.create(
+                    llm_messages,
+                    cancellation_token=cancellation_token,
+                )
             assert isinstance(result.content, str)
             # Add the response to the model context.
             await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
index a804250b3aa..685a86d0a0c 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
@@ -588,7 +588,7 @@ async def create_stream(
         json_output: Optional[bool] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
-        max_consecutive_empty_chunk_tolerance: int = 0,
+        max_consecutive_empty_chunk_tolerance: int = 10,
     ) -> AsyncGenerator[Union[str, CreateResult], None]:
         """
         Creates an AsyncGenerator that will yield a stream of chat completions based on the provided messages and tools.
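
For reviewers, a minimal usage sketch of the new `token_callback` hook. The `print_token` coroutine, the `OpenAIChatCompletionClient` setup, and the import paths are illustrative assumptions (they may differ across AutoGen versions) and are not part of the patch itself:

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import TextMessage
from autogen_core import CancellationToken
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def print_token(token: str) -> None:
    # Hypothetical callback: forward each streamed token to stdout as it arrives.
    print(token, end="", flush=True)


async def main() -> None:
    # Assumed model client; any client whose create_stream yields str chunks
    # followed by a final CreateResult should behave the same way.
    model_client = OpenAIChatCompletionClient(model="gpt-4o")
    agent = AssistantAgent(
        name="assistant",
        model_client=model_client,
        token_callback=print_token,  # parameter added by this patch
    )
    response = await agent.on_messages(
        [TextMessage(content="Write a haiku about streaming tokens.", source="user")],
        cancellation_token=CancellationToken(),
    )
    print("\nFinal response:", response.chat_message.content)


if __name__ == "__main__":
    asyncio.run(main())
```

As the diff shows, the callback only receives the `str` chunks yielded by `create_stream`; the loop breaks on the final `CreateResult`, which then flows through the rest of `on_messages_stream` unchanged.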