Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DRAFT for Feedback - Support for token streaming for more dynamic UX #4443

Draft
wants to merge 33 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
963e409
AssistantAgent support for streaming tokens
jspv Nov 22, 2024
8b9295e
Merge branch 'main' into stream_token_0.4
jspv Dec 1, 2024
c20298d
Merge branch 'microsoft:main' into stream_token_0.4
jspv Dec 4, 2024
cf9a1eb
Merge branch 'microsoft:main' into stream_token_0.4
jspv Dec 9, 2024
b854dae
Updates for token handling
jspv Dec 9, 2024
6054cd9
Merge branch 'stream_token_0.4' of https://github.com/jspv/autogen in…
jspv Dec 9, 2024
f43f831
Fixed formatting
jspv Dec 9, 2024
9a763f4
Set default to fix Azure clients
jspv Dec 9, 2024
9d4d58c
refresh from upstream
jspv Dec 18, 2024
4be98d0
Adding ToolCallResultSummaryMessage
jspv Dec 18, 2024
c36c4f2
Merge remote-tracking branch 'origin/main' into stream_token_0.4
jspv Dec 18, 2024
9ff27cf
revert typo
jspv Dec 18, 2024
36fd8a5
Fix merge error
jspv Dec 18, 2024
16b6f5a
Merge branch 'ToolCallResultSummaryMessage' into stream_token_0.4
jspv Dec 18, 2024
c740229
ToolCallResultSummaryMessage support
jspv Dec 18, 2024
bd3ee60
Support for ToolCallResultSummaryMessage
jspv Dec 18, 2024
d17dff4
Merge branch 'main' into ToolCallResultSummaryMessage
jspv Dec 19, 2024
b26bcb0
Added ToolCallSummaryMessage
jspv Dec 19, 2024
d9c1d32
Merge branch 'ToolCallResultSummaryMessage' into stream_token_0.4
jspv Dec 19, 2024
f0e9be2
Merge branch 'main' into ToolCallResultSummaryMessage
jspv Dec 19, 2024
aef903e
ruff format
jspv Dec 19, 2024
b4eb3ab
Add ToolCallSummaryMessage to ChatMessage
jspv Dec 19, 2024
8a9561c
Merge branch 'ToolCallResultSummaryMessage' into stream_token_0.4
jspv Dec 19, 2024
4b72194
Merge branch 'main' into ToolCallResultSummaryMessage
ekzhu Dec 19, 2024
285228a
typing and tests for ToolCallSummaryMessage
jspv Dec 19, 2024
22ff506
Merge branch 'ToolCallResultSummaryMessage' of https://github.com/jsp…
jspv Dec 19, 2024
98a2feb
Merge branch 'ToolCallResultSummaryMessage' into stream_token_0.4
jspv Dec 19, 2024
386db37
Merge branch 'main' into ToolCallResultSummaryMessage
jspv Dec 19, 2024
a6cab09
PR Feedback
jspv Dec 19, 2024
2622fdc
Merge branch 'ToolCallResultSummaryMessage' into stream_token_0.4
jspv Dec 19, 2024
26ec534
Merge branch 'main' into stream_token_0.4
jspv Dec 21, 2024
58a322a
Merge branch 'main' into stream_token_0.4
jspv Dec 25, 2024
f5a14c6
Merge branch 'main' into stream_token_0.4
jspv Dec 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def __init__(
system_message: (
str | None
) = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
token_callback: Callable | None = None,
reflect_on_tool_use: bool = False,
tool_call_summary_format: str = "{result}",
):
Expand All @@ -255,6 +256,7 @@ def __init__(
else:
self._system_messages = [SystemMessage(content=system_message)]
self._tools: List[Tool] = []
self._token_callback = token_callback
if tools is not None:
if model_client.capabilities["function_calling"] is False:
raise ValueError("The model does not support function calling.")
Expand Down Expand Up @@ -334,9 +336,26 @@ async def on_messages_stream(

# Generate an inference result based on the current model context.
llm_messages = self._system_messages + await self._model_context.get_messages()
result = await self._model_client.create(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
)

# if token_callback is set, use create_stream to get the tokens as they are
# generated and call the token_callback with the tokens
if self._token_callback is not None:
async for result in self._model_client.create_stream(
llm_messages,
tools=self._tools + self._handoff_tools,
cancellation_token=cancellation_token,
):
# if the result is a string, it is a token to be streamed back
if isinstance(result, str):
await self._token_callback(result)
else:
break
else:
result = await self._model_client.create(
llm_messages,
tools=self._tools + self._handoff_tools,
cancellation_token=cancellation_token,
)

# Add the response to the model context.
await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))
Expand Down Expand Up @@ -387,7 +406,24 @@ async def on_messages_stream(
if self._reflect_on_tool_use:
# Generate another inference result based on the tool call and result.
llm_messages = self._system_messages + await self._model_context.get_messages()
result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)

# if token_callback is set, use create_stream to get the tokens as they are
# generated and call the token_callback with the tokens
if self._token_callback is not None:
async for result in self._model_client.create_stream(
llm_messages,
cancellation_token=cancellation_token,
):
# if the result is a string, it is a token to be streamed back
if isinstance(result, str):
await self._token_callback(result)
else:
break
else:
result = await self._model_client.create(
llm_messages,
cancellation_token=cancellation_token,
)
assert isinstance(result.content, str)
# Add the response to the model context.
await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ async def create_stream(
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
max_consecutive_empty_chunk_tolerance: int = 0,
max_consecutive_empty_chunk_tolerance: int = 10,
) -> AsyncGenerator[Union[str, CreateResult], None]:
"""
Creates an AsyncGenerator that will yield a stream of chat completions based on the provided messages and tools.
Expand Down