-
Notifications
You must be signed in to change notification settings - Fork 5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1e0b254
commit 2f7a0f7
Showing
3 changed files
with
323 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
133 changes: 133 additions & 0 deletions
133
dotnet/src/AutoGen.OpenAI/Orchestrator/RolePlayToolCallOrchestrator.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// RolePlayToolCallOrchestrator.cs | ||
|
||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using System.Text; | ||
using System.Threading; | ||
using System.Threading.Tasks; | ||
using AutoGen.OpenAI.Extension; | ||
using OpenAI.Chat; | ||
|
||
namespace AutoGen.OpenAI.Orchestrator; | ||
|
||
/// <summary> | ||
/// Orchestrating group chat using role play tool call | ||
/// </summary> | ||
public partial class RolePlayToolCallOrchestrator : IOrchestrator | ||
{ | ||
public readonly ChatClient chatClient; | ||
private readonly Graph? workflow; | ||
|
||
public RolePlayToolCallOrchestrator(ChatClient chatClient, Graph? workflow = null) | ||
{ | ||
this.chatClient = chatClient; | ||
this.workflow = workflow; | ||
} | ||
|
||
public async Task<IAgent?> GetNextSpeakerAsync( | ||
OrchestrationContext context, | ||
CancellationToken cancellationToken = default) | ||
{ | ||
var candidates = context.Candidates.ToList(); | ||
|
||
if (candidates.Count == 0) | ||
{ | ||
return null; | ||
} | ||
|
||
if (candidates.Count == 1) | ||
{ | ||
return candidates.First(); | ||
} | ||
|
||
// if there's a workflow | ||
// and the next available agent from the workflow is in the group chat | ||
// then return the next agent from the workflow | ||
if (this.workflow != null) | ||
{ | ||
var lastMessage = context.ChatHistory.LastOrDefault(); | ||
if (lastMessage == null) | ||
{ | ||
return null; | ||
} | ||
var currentSpeaker = candidates.First(candidates => candidates.Name == lastMessage.From); | ||
var nextAgents = await this.workflow.TransitToNextAvailableAgentsAsync(currentSpeaker, context.ChatHistory, cancellationToken); | ||
nextAgents = nextAgents.Where(nextAgent => candidates.Any(candidate => candidate.Name == nextAgent.Name)); | ||
candidates = nextAgents.ToList(); | ||
if (!candidates.Any()) | ||
{ | ||
return null; | ||
} | ||
|
||
if (candidates is { Count: 1 }) | ||
{ | ||
return candidates.First(); | ||
} | ||
} | ||
|
||
// In this case, since there are more than one available agents from the workflow for the next speaker | ||
// We need to invoke LLM to select the next speaker via select next speaker function | ||
|
||
var chatHistoryStringBuilder = new StringBuilder(); | ||
foreach (var message in context.ChatHistory) | ||
{ | ||
var chatHistoryPrompt = $"{message.From}: {message.GetContent()}"; | ||
|
||
chatHistoryStringBuilder.AppendLine(chatHistoryPrompt); | ||
} | ||
|
||
var chatHistory = chatHistoryStringBuilder.ToString(); | ||
|
||
var prompt = $""" | ||
# Task: Select the next speaker | ||
|
||
You are in a role-play game. Carefully read the conversation history and select the next speaker from the available roles. | ||
|
||
# Conversation | ||
{chatHistory} | ||
|
||
# Available roles | ||
- {string.Join(",", candidates.Select(candidate => candidate.Name))} | ||
|
||
Select the next speaker from the available roles and provide a reason for your selection. | ||
"""; | ||
|
||
// enforce the next speaker to be selected by the LLM | ||
var option = new ChatCompletionOptions | ||
{ | ||
ToolChoice = ChatToolChoice.CreateFunctionChoice(this.SelectNextSpeakerFunctionContract.Name), | ||
}; | ||
|
||
option.Tools.Add(this.SelectNextSpeakerFunctionContract.ToChatTool()); | ||
var toolCallMiddleware = new FunctionCallMiddleware( | ||
functions: [this.SelectNextSpeakerFunctionContract], | ||
functionMap: new Dictionary<string, Func<string, Task<string>>> | ||
{ | ||
[this.SelectNextSpeakerFunctionContract.Name] = this.SelectNextSpeakerWrapper, | ||
}); | ||
|
||
var selectAgent = new OpenAIChatAgent( | ||
chatClient, | ||
"admin", | ||
option) | ||
.RegisterMessageConnector() | ||
.RegisterMiddleware(toolCallMiddleware); | ||
|
||
var reply = await selectAgent.SendAsync(prompt); | ||
|
||
var nextSpeaker = candidates.FirstOrDefault(candidate => candidate.Name == reply.GetContent()); | ||
|
||
return nextSpeaker; | ||
} | ||
|
||
/// <summary> | ||
/// Select the next speaker by name and reason | ||
/// </summary> | ||
[Function] | ||
public async Task<string> SelectNextSpeaker(string name, string reason) | ||
{ | ||
return name; | ||
} | ||
} |
189 changes: 189 additions & 0 deletions
189
dotnet/test/AutoGen.OpenAI.Tests/RolePlayToolCallOrchestratorTests.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// RolePlayToolCallOrchestratorTests.cs | ||
|
||
using System; | ||
using System.Collections.Generic; | ||
using System.Threading.Tasks; | ||
using AutoGen.OpenAI.Orchestrator; | ||
using AutoGen.Tests; | ||
using Azure.AI.OpenAI; | ||
using FluentAssertions; | ||
using Moq; | ||
using OpenAI; | ||
using OpenAI.Chat; | ||
using Xunit; | ||
|
||
namespace AutoGen.OpenAI.Tests; | ||
|
||
public class RolePlayToolCallOrchestratorTests | ||
{ | ||
[Fact] | ||
public async Task ItReturnNullWhenNoCandidateIsAvailableAsync() | ||
{ | ||
var chatClient = Mock.Of<ChatClient>(); | ||
var orchestrator = new RolePlayToolCallOrchestrator(chatClient); | ||
var context = new OrchestrationContext | ||
{ | ||
Candidates = [], | ||
ChatHistory = [], | ||
}; | ||
|
||
var speaker = await orchestrator.GetNextSpeakerAsync(context); | ||
speaker.Should().BeNull(); | ||
} | ||
|
||
[Fact] | ||
public async Task ItReturnCandidateWhenOnlyOneCandidateIsAvailableAsync() | ||
{ | ||
var chatClient = Mock.Of<ChatClient>(); | ||
var alice = new EchoAgent("Alice"); | ||
var orchestrator = new RolePlayToolCallOrchestrator(chatClient); | ||
var context = new OrchestrationContext | ||
{ | ||
Candidates = [alice], | ||
ChatHistory = [], | ||
}; | ||
|
||
var speaker = await orchestrator.GetNextSpeakerAsync(context); | ||
speaker.Should().Be(alice); | ||
} | ||
|
||
[Fact] | ||
public async Task ItSelectNextSpeakerFromWorkflowIfProvided() | ||
{ | ||
var workflow = new Graph(); | ||
var alice = new EchoAgent("Alice"); | ||
var bob = new EchoAgent("Bob"); | ||
var charlie = new EchoAgent("Charlie"); | ||
workflow.AddTransition(Transition.Create(alice, bob)); | ||
workflow.AddTransition(Transition.Create(bob, charlie)); | ||
workflow.AddTransition(Transition.Create(charlie, alice)); | ||
|
||
var client = Mock.Of<ChatClient>(); | ||
var orchestrator = new RolePlayToolCallOrchestrator(client, workflow); | ||
var context = new OrchestrationContext | ||
{ | ||
Candidates = [alice, bob, charlie], | ||
ChatHistory = | ||
[ | ||
new TextMessage(Role.User, "Hello, Bob", from: "Alice"), | ||
], | ||
}; | ||
|
||
var speaker = await orchestrator.GetNextSpeakerAsync(context); | ||
speaker.Should().Be(bob); | ||
} | ||
|
||
[Fact] | ||
public async Task ItReturnNullIfNoAvailableAgentFromWorkflowAsync() | ||
{ | ||
var workflow = new Graph(); | ||
var alice = new EchoAgent("Alice"); | ||
var bob = new EchoAgent("Bob"); | ||
workflow.AddTransition(Transition.Create(alice, bob)); | ||
|
||
var client = Mock.Of<ChatClient>(); | ||
var orchestrator = new RolePlayToolCallOrchestrator(client, workflow); | ||
var context = new OrchestrationContext | ||
{ | ||
Candidates = [alice, bob], | ||
ChatHistory = | ||
[ | ||
new TextMessage(Role.User, "Hello, Alice", from: "Bob"), | ||
], | ||
}; | ||
|
||
var speaker = await orchestrator.GetNextSpeakerAsync(context); | ||
speaker.Should().BeNull(); | ||
} | ||
|
||
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")] | ||
public async Task GPT_3_5_CoderReviewerRunnerTestAsync() | ||
{ | ||
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable."); | ||
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable."); | ||
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable."); | ||
var openaiClient = new AzureOpenAIClient(new Uri(endpoint), new System.ClientModel.ApiKeyCredential(key)); | ||
var chatClient = openaiClient.GetChatClient(deployName); | ||
|
||
await CoderReviewerRunnerTestAsync(chatClient); | ||
} | ||
|
||
[ApiKeyFact("OPENAI_API_KEY")] | ||
public async Task GPT_4o_CoderReviewerRunnerTestAsync() | ||
{ | ||
var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new InvalidOperationException("OPENAI_API_KEY is not set"); | ||
var model = "gpt-4o"; | ||
var openaiClient = new OpenAIClient(apiKey); | ||
var chatClient = openaiClient.GetChatClient(model); | ||
|
||
await CoderReviewerRunnerTestAsync(chatClient); | ||
} | ||
|
||
[ApiKeyFact("OPENAI_API_KEY")] | ||
public async Task GPT_4o_mini_CoderReviewerRunnerTestAsync() | ||
{ | ||
var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new InvalidOperationException("OPENAI_API_KEY is not set"); | ||
var model = "gpt-4o-mini"; | ||
var openaiClient = new OpenAIClient(apiKey); | ||
var chatClient = openaiClient.GetChatClient(model); | ||
|
||
await CoderReviewerRunnerTestAsync(chatClient); | ||
} | ||
|
||
/// <summary> | ||
/// This test is to mimic the conversation among coder, reviewer and runner. | ||
/// The coder will write the code, the reviewer will review the code, and the runner will run the code. | ||
/// </summary> | ||
/// <param name="client"></param> | ||
/// <returns></returns> | ||
private async Task CoderReviewerRunnerTestAsync(ChatClient client) | ||
{ | ||
var coder = new EchoAgent("Coder"); | ||
var reviewer = new EchoAgent("Reviewer"); | ||
var runner = new EchoAgent("Runner"); | ||
var user = new EchoAgent("User"); | ||
var initializeMessage = new List<IMessage> | ||
{ | ||
new TextMessage(Role.User, "Hello, I am user, I will provide the coding task, please write the code first, then review and run it", from: "User"), | ||
new TextMessage(Role.User, "Hello, I am coder, I will write the code", from: "Coder"), | ||
new TextMessage(Role.User, "Hello, I am reviewer, I will review the code", from: "Reviewer"), | ||
new TextMessage(Role.User, "Hello, I am runner, I will run the code", from: "Runner"), | ||
new TextMessage(Role.User, "how to print 'hello world' using C#", from: user.Name), | ||
}; | ||
|
||
var chatHistory = new List<IMessage>() | ||
{ | ||
new TextMessage(Role.User, """ | ||
```csharp | ||
Console.WriteLine("Hello World"); | ||
``` | ||
""", from: coder.Name), | ||
new TextMessage(Role.User, "The code looks good", from: reviewer.Name), | ||
new TextMessage(Role.User, "The code runs successfully, the output is 'Hello World'", from: runner.Name), | ||
}; | ||
|
||
var orchestrator = new RolePlayToolCallOrchestrator(client); | ||
foreach (var message in chatHistory) | ||
{ | ||
var context = new OrchestrationContext | ||
{ | ||
Candidates = [coder, reviewer, runner, user], | ||
ChatHistory = initializeMessage, | ||
}; | ||
|
||
var speaker = await orchestrator.GetNextSpeakerAsync(context); | ||
speaker!.Name.Should().Be(message.From); | ||
initializeMessage.Add(message); | ||
} | ||
|
||
// the last next speaker should be the user | ||
var lastSpeaker = await orchestrator.GetNextSpeakerAsync(new OrchestrationContext | ||
{ | ||
Candidates = [coder, reviewer, runner, user], | ||
ChatHistory = initializeMessage, | ||
}); | ||
|
||
lastSpeaker!.Name.Should().Be(user.Name); | ||
} | ||
} |