From 2f7a0f7e4cec0e52a47bac32b50435db1c77c010 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Fri, 22 Nov 2024 13:25:09 -0800 Subject: [PATCH] add roleplay tool call orchestrator --- .../src/AutoGen.OpenAI/AutoGen.OpenAI.csproj | 1 + .../RolePlayToolCallOrchestrator.cs | 133 ++++++++++++ .../RolePlayToolCallOrchestratorTests.cs | 189 ++++++++++++++++++ 3 files changed, 323 insertions(+) create mode 100644 dotnet/src/AutoGen.OpenAI/Orchestrator/RolePlayToolCallOrchestrator.cs create mode 100644 dotnet/test/AutoGen.OpenAI.Tests/RolePlayToolCallOrchestratorTests.cs diff --git a/dotnet/src/AutoGen.OpenAI/AutoGen.OpenAI.csproj b/dotnet/src/AutoGen.OpenAI/AutoGen.OpenAI.csproj index 7f00b63be86..70c0f2b0d1c 100644 --- a/dotnet/src/AutoGen.OpenAI/AutoGen.OpenAI.csproj +++ b/dotnet/src/AutoGen.OpenAI/AutoGen.OpenAI.csproj @@ -18,6 +18,7 @@ + diff --git a/dotnet/src/AutoGen.OpenAI/Orchestrator/RolePlayToolCallOrchestrator.cs b/dotnet/src/AutoGen.OpenAI/Orchestrator/RolePlayToolCallOrchestrator.cs new file mode 100644 index 00000000000..f088e1748e6 --- /dev/null +++ b/dotnet/src/AutoGen.OpenAI/Orchestrator/RolePlayToolCallOrchestrator.cs @@ -0,0 +1,133 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// RolePlayToolCallOrchestrator.cs + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using AutoGen.OpenAI.Extension; +using OpenAI.Chat; + +namespace AutoGen.OpenAI.Orchestrator; + +/// +/// Orchestrating group chat using role play tool call +/// +public partial class RolePlayToolCallOrchestrator : IOrchestrator +{ + public readonly ChatClient chatClient; + private readonly Graph? workflow; + + public RolePlayToolCallOrchestrator(ChatClient chatClient, Graph? workflow = null) + { + this.chatClient = chatClient; + this.workflow = workflow; + } + + public async Task GetNextSpeakerAsync( + OrchestrationContext context, + CancellationToken cancellationToken = default) + { + var candidates = context.Candidates.ToList(); + + if (candidates.Count == 0) + { + return null; + } + + if (candidates.Count == 1) + { + return candidates.First(); + } + + // if there's a workflow + // and the next available agent from the workflow is in the group chat + // then return the next agent from the workflow + if (this.workflow != null) + { + var lastMessage = context.ChatHistory.LastOrDefault(); + if (lastMessage == null) + { + return null; + } + var currentSpeaker = candidates.First(candidates => candidates.Name == lastMessage.From); + var nextAgents = await this.workflow.TransitToNextAvailableAgentsAsync(currentSpeaker, context.ChatHistory, cancellationToken); + nextAgents = nextAgents.Where(nextAgent => candidates.Any(candidate => candidate.Name == nextAgent.Name)); + candidates = nextAgents.ToList(); + if (!candidates.Any()) + { + return null; + } + + if (candidates is { Count: 1 }) + { + return candidates.First(); + } + } + + // In this case, since there are more than one available agents from the workflow for the next speaker + // We need to invoke LLM to select the next speaker via select next speaker function + + var chatHistoryStringBuilder = new StringBuilder(); + foreach (var message in context.ChatHistory) + { + var chatHistoryPrompt = $"{message.From}: {message.GetContent()}"; + + chatHistoryStringBuilder.AppendLine(chatHistoryPrompt); + } + + var chatHistory = chatHistoryStringBuilder.ToString(); + + var prompt = $""" + # Task: Select the next speaker + + You are in a role-play game. Carefully read the conversation history and select the next speaker from the available roles. + + # Conversation + {chatHistory} + + # Available roles + - {string.Join(",", candidates.Select(candidate => candidate.Name))} + + Select the next speaker from the available roles and provide a reason for your selection. + """; + + // enforce the next speaker to be selected by the LLM + var option = new ChatCompletionOptions + { + ToolChoice = ChatToolChoice.CreateFunctionChoice(this.SelectNextSpeakerFunctionContract.Name), + }; + + option.Tools.Add(this.SelectNextSpeakerFunctionContract.ToChatTool()); + var toolCallMiddleware = new FunctionCallMiddleware( + functions: [this.SelectNextSpeakerFunctionContract], + functionMap: new Dictionary>> + { + [this.SelectNextSpeakerFunctionContract.Name] = this.SelectNextSpeakerWrapper, + }); + + var selectAgent = new OpenAIChatAgent( + chatClient, + "admin", + option) + .RegisterMessageConnector() + .RegisterMiddleware(toolCallMiddleware); + + var reply = await selectAgent.SendAsync(prompt); + + var nextSpeaker = candidates.FirstOrDefault(candidate => candidate.Name == reply.GetContent()); + + return nextSpeaker; + } + + /// + /// Select the next speaker by name and reason + /// + [Function] + public async Task SelectNextSpeaker(string name, string reason) + { + return name; + } +} diff --git a/dotnet/test/AutoGen.OpenAI.Tests/RolePlayToolCallOrchestratorTests.cs b/dotnet/test/AutoGen.OpenAI.Tests/RolePlayToolCallOrchestratorTests.cs new file mode 100644 index 00000000000..d3a50170207 --- /dev/null +++ b/dotnet/test/AutoGen.OpenAI.Tests/RolePlayToolCallOrchestratorTests.cs @@ -0,0 +1,189 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// RolePlayToolCallOrchestratorTests.cs + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using AutoGen.OpenAI.Orchestrator; +using AutoGen.Tests; +using Azure.AI.OpenAI; +using FluentAssertions; +using Moq; +using OpenAI; +using OpenAI.Chat; +using Xunit; + +namespace AutoGen.OpenAI.Tests; + +public class RolePlayToolCallOrchestratorTests +{ + [Fact] + public async Task ItReturnNullWhenNoCandidateIsAvailableAsync() + { + var chatClient = Mock.Of(); + var orchestrator = new RolePlayToolCallOrchestrator(chatClient); + var context = new OrchestrationContext + { + Candidates = [], + ChatHistory = [], + }; + + var speaker = await orchestrator.GetNextSpeakerAsync(context); + speaker.Should().BeNull(); + } + + [Fact] + public async Task ItReturnCandidateWhenOnlyOneCandidateIsAvailableAsync() + { + var chatClient = Mock.Of(); + var alice = new EchoAgent("Alice"); + var orchestrator = new RolePlayToolCallOrchestrator(chatClient); + var context = new OrchestrationContext + { + Candidates = [alice], + ChatHistory = [], + }; + + var speaker = await orchestrator.GetNextSpeakerAsync(context); + speaker.Should().Be(alice); + } + + [Fact] + public async Task ItSelectNextSpeakerFromWorkflowIfProvided() + { + var workflow = new Graph(); + var alice = new EchoAgent("Alice"); + var bob = new EchoAgent("Bob"); + var charlie = new EchoAgent("Charlie"); + workflow.AddTransition(Transition.Create(alice, bob)); + workflow.AddTransition(Transition.Create(bob, charlie)); + workflow.AddTransition(Transition.Create(charlie, alice)); + + var client = Mock.Of(); + var orchestrator = new RolePlayToolCallOrchestrator(client, workflow); + var context = new OrchestrationContext + { + Candidates = [alice, bob, charlie], + ChatHistory = + [ + new TextMessage(Role.User, "Hello, Bob", from: "Alice"), + ], + }; + + var speaker = await orchestrator.GetNextSpeakerAsync(context); + speaker.Should().Be(bob); + } + + [Fact] + public async Task ItReturnNullIfNoAvailableAgentFromWorkflowAsync() + { + var workflow = new Graph(); + var alice = new EchoAgent("Alice"); + var bob = new EchoAgent("Bob"); + workflow.AddTransition(Transition.Create(alice, bob)); + + var client = Mock.Of(); + var orchestrator = new RolePlayToolCallOrchestrator(client, workflow); + var context = new OrchestrationContext + { + Candidates = [alice, bob], + ChatHistory = + [ + new TextMessage(Role.User, "Hello, Alice", from: "Bob"), + ], + }; + + var speaker = await orchestrator.GetNextSpeakerAsync(context); + speaker.Should().BeNull(); + } + + [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")] + public async Task GPT_3_5_CoderReviewerRunnerTestAsync() + { + var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable."); + var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable."); + var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable."); + var openaiClient = new AzureOpenAIClient(new Uri(endpoint), new System.ClientModel.ApiKeyCredential(key)); + var chatClient = openaiClient.GetChatClient(deployName); + + await CoderReviewerRunnerTestAsync(chatClient); + } + + [ApiKeyFact("OPENAI_API_KEY")] + public async Task GPT_4o_CoderReviewerRunnerTestAsync() + { + var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new InvalidOperationException("OPENAI_API_KEY is not set"); + var model = "gpt-4o"; + var openaiClient = new OpenAIClient(apiKey); + var chatClient = openaiClient.GetChatClient(model); + + await CoderReviewerRunnerTestAsync(chatClient); + } + + [ApiKeyFact("OPENAI_API_KEY")] + public async Task GPT_4o_mini_CoderReviewerRunnerTestAsync() + { + var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new InvalidOperationException("OPENAI_API_KEY is not set"); + var model = "gpt-4o-mini"; + var openaiClient = new OpenAIClient(apiKey); + var chatClient = openaiClient.GetChatClient(model); + + await CoderReviewerRunnerTestAsync(chatClient); + } + + /// + /// This test is to mimic the conversation among coder, reviewer and runner. + /// The coder will write the code, the reviewer will review the code, and the runner will run the code. + /// + /// + /// + private async Task CoderReviewerRunnerTestAsync(ChatClient client) + { + var coder = new EchoAgent("Coder"); + var reviewer = new EchoAgent("Reviewer"); + var runner = new EchoAgent("Runner"); + var user = new EchoAgent("User"); + var initializeMessage = new List + { + new TextMessage(Role.User, "Hello, I am user, I will provide the coding task, please write the code first, then review and run it", from: "User"), + new TextMessage(Role.User, "Hello, I am coder, I will write the code", from: "Coder"), + new TextMessage(Role.User, "Hello, I am reviewer, I will review the code", from: "Reviewer"), + new TextMessage(Role.User, "Hello, I am runner, I will run the code", from: "Runner"), + new TextMessage(Role.User, "how to print 'hello world' using C#", from: user.Name), + }; + + var chatHistory = new List() + { + new TextMessage(Role.User, """ + ```csharp + Console.WriteLine("Hello World"); + ``` + """, from: coder.Name), + new TextMessage(Role.User, "The code looks good", from: reviewer.Name), + new TextMessage(Role.User, "The code runs successfully, the output is 'Hello World'", from: runner.Name), + }; + + var orchestrator = new RolePlayToolCallOrchestrator(client); + foreach (var message in chatHistory) + { + var context = new OrchestrationContext + { + Candidates = [coder, reviewer, runner, user], + ChatHistory = initializeMessage, + }; + + var speaker = await orchestrator.GetNextSpeakerAsync(context); + speaker!.Name.Should().Be(message.From); + initializeMessage.Add(message); + } + + // the last next speaker should be the user + var lastSpeaker = await orchestrator.GetNextSpeakerAsync(new OrchestrationContext + { + Candidates = [coder, reviewer, runner, user], + ChatHistory = initializeMessage, + }); + + lastSpeaker!.Name.Should().Be(user.Name); + } +}