Skip to content

Commit

Permalink
add roleplay tool call orchestrator
Browse files Browse the repository at this point in the history
  • Loading branch information
LittleLittleCloud committed Nov 22, 2024
1 parent 1e0b254 commit 2f7a0f7
Show file tree
Hide file tree
Showing 3 changed files with 323 additions and 0 deletions.
1 change: 1 addition & 0 deletions dotnet/src/AutoGen.OpenAI/AutoGen.OpenAI.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

<ItemGroup>
<PackageReference Include="OpenAI" Version="$(OpenAISDKVersion)" />
<ProjectReference Include="..\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
</ItemGroup>

<ItemGroup>
Expand Down
133 changes: 133 additions & 0 deletions dotnet/src/AutoGen.OpenAI/Orchestrator/RolePlayToolCallOrchestrator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// RolePlayToolCallOrchestrator.cs

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.OpenAI.Extension;
using OpenAI.Chat;

namespace AutoGen.OpenAI.Orchestrator;

/// <summary>
/// Orchestrating group chat using role play tool call
/// </summary>
public partial class RolePlayToolCallOrchestrator : IOrchestrator
{
public readonly ChatClient chatClient;
private readonly Graph? workflow;

public RolePlayToolCallOrchestrator(ChatClient chatClient, Graph? workflow = null)
{
this.chatClient = chatClient;
this.workflow = workflow;
}

public async Task<IAgent?> GetNextSpeakerAsync(
OrchestrationContext context,
CancellationToken cancellationToken = default)
{
var candidates = context.Candidates.ToList();

if (candidates.Count == 0)
{
return null;
}

if (candidates.Count == 1)
{
return candidates.First();
}

// if there's a workflow
// and the next available agent from the workflow is in the group chat
// then return the next agent from the workflow
if (this.workflow != null)
{
var lastMessage = context.ChatHistory.LastOrDefault();
if (lastMessage == null)
{
return null;
}
var currentSpeaker = candidates.First(candidates => candidates.Name == lastMessage.From);
var nextAgents = await this.workflow.TransitToNextAvailableAgentsAsync(currentSpeaker, context.ChatHistory, cancellationToken);
nextAgents = nextAgents.Where(nextAgent => candidates.Any(candidate => candidate.Name == nextAgent.Name));
candidates = nextAgents.ToList();
if (!candidates.Any())
{
return null;
}

if (candidates is { Count: 1 })
{
return candidates.First();
}
}

// In this case, since there are more than one available agents from the workflow for the next speaker
// We need to invoke LLM to select the next speaker via select next speaker function

var chatHistoryStringBuilder = new StringBuilder();
foreach (var message in context.ChatHistory)
{
var chatHistoryPrompt = $"{message.From}: {message.GetContent()}";

chatHistoryStringBuilder.AppendLine(chatHistoryPrompt);
}

var chatHistory = chatHistoryStringBuilder.ToString();

var prompt = $"""
# Task: Select the next speaker

You are in a role-play game. Carefully read the conversation history and select the next speaker from the available roles.

# Conversation
{chatHistory}

# Available roles
- {string.Join(",", candidates.Select(candidate => candidate.Name))}

Select the next speaker from the available roles and provide a reason for your selection.
""";

// enforce the next speaker to be selected by the LLM
var option = new ChatCompletionOptions
{
ToolChoice = ChatToolChoice.CreateFunctionChoice(this.SelectNextSpeakerFunctionContract.Name),
};

option.Tools.Add(this.SelectNextSpeakerFunctionContract.ToChatTool());
var toolCallMiddleware = new FunctionCallMiddleware(
functions: [this.SelectNextSpeakerFunctionContract],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
[this.SelectNextSpeakerFunctionContract.Name] = this.SelectNextSpeakerWrapper,
});

var selectAgent = new OpenAIChatAgent(
chatClient,
"admin",
option)
.RegisterMessageConnector()
.RegisterMiddleware(toolCallMiddleware);

var reply = await selectAgent.SendAsync(prompt);

var nextSpeaker = candidates.FirstOrDefault(candidate => candidate.Name == reply.GetContent());

return nextSpeaker;
}

/// <summary>
/// Select the next speaker by name and reason
/// </summary>
[Function]
public async Task<string> SelectNextSpeaker(string name, string reason)
{
return name;
}
}
189 changes: 189 additions & 0 deletions dotnet/test/AutoGen.OpenAI.Tests/RolePlayToolCallOrchestratorTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// RolePlayToolCallOrchestratorTests.cs

using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using AutoGen.OpenAI.Orchestrator;
using AutoGen.Tests;
using Azure.AI.OpenAI;
using FluentAssertions;
using Moq;
using OpenAI;
using OpenAI.Chat;
using Xunit;

namespace AutoGen.OpenAI.Tests;

public class RolePlayToolCallOrchestratorTests
{
[Fact]
public async Task ItReturnNullWhenNoCandidateIsAvailableAsync()
{
var chatClient = Mock.Of<ChatClient>();
var orchestrator = new RolePlayToolCallOrchestrator(chatClient);
var context = new OrchestrationContext
{
Candidates = [],
ChatHistory = [],
};

var speaker = await orchestrator.GetNextSpeakerAsync(context);
speaker.Should().BeNull();
}

[Fact]
public async Task ItReturnCandidateWhenOnlyOneCandidateIsAvailableAsync()
{
var chatClient = Mock.Of<ChatClient>();
var alice = new EchoAgent("Alice");
var orchestrator = new RolePlayToolCallOrchestrator(chatClient);
var context = new OrchestrationContext
{
Candidates = [alice],
ChatHistory = [],
};

var speaker = await orchestrator.GetNextSpeakerAsync(context);
speaker.Should().Be(alice);
}

[Fact]
public async Task ItSelectNextSpeakerFromWorkflowIfProvided()
{
var workflow = new Graph();
var alice = new EchoAgent("Alice");
var bob = new EchoAgent("Bob");
var charlie = new EchoAgent("Charlie");
workflow.AddTransition(Transition.Create(alice, bob));
workflow.AddTransition(Transition.Create(bob, charlie));
workflow.AddTransition(Transition.Create(charlie, alice));

var client = Mock.Of<ChatClient>();
var orchestrator = new RolePlayToolCallOrchestrator(client, workflow);
var context = new OrchestrationContext
{
Candidates = [alice, bob, charlie],
ChatHistory =
[
new TextMessage(Role.User, "Hello, Bob", from: "Alice"),
],
};

var speaker = await orchestrator.GetNextSpeakerAsync(context);
speaker.Should().Be(bob);
}

[Fact]
public async Task ItReturnNullIfNoAvailableAgentFromWorkflowAsync()
{
var workflow = new Graph();
var alice = new EchoAgent("Alice");
var bob = new EchoAgent("Bob");
workflow.AddTransition(Transition.Create(alice, bob));

var client = Mock.Of<ChatClient>();
var orchestrator = new RolePlayToolCallOrchestrator(client, workflow);
var context = new OrchestrationContext
{
Candidates = [alice, bob],
ChatHistory =
[
new TextMessage(Role.User, "Hello, Alice", from: "Bob"),
],
};

var speaker = await orchestrator.GetNextSpeakerAsync(context);
speaker.Should().BeNull();
}

[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task GPT_3_5_CoderReviewerRunnerTestAsync()
{
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable.");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
var openaiClient = new AzureOpenAIClient(new Uri(endpoint), new System.ClientModel.ApiKeyCredential(key));
var chatClient = openaiClient.GetChatClient(deployName);

await CoderReviewerRunnerTestAsync(chatClient);
}

[ApiKeyFact("OPENAI_API_KEY")]
public async Task GPT_4o_CoderReviewerRunnerTestAsync()
{
var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new InvalidOperationException("OPENAI_API_KEY is not set");
var model = "gpt-4o";
var openaiClient = new OpenAIClient(apiKey);
var chatClient = openaiClient.GetChatClient(model);

await CoderReviewerRunnerTestAsync(chatClient);
}

[ApiKeyFact("OPENAI_API_KEY")]
public async Task GPT_4o_mini_CoderReviewerRunnerTestAsync()
{
var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new InvalidOperationException("OPENAI_API_KEY is not set");
var model = "gpt-4o-mini";
var openaiClient = new OpenAIClient(apiKey);
var chatClient = openaiClient.GetChatClient(model);

await CoderReviewerRunnerTestAsync(chatClient);
}

/// <summary>
/// This test is to mimic the conversation among coder, reviewer and runner.
/// The coder will write the code, the reviewer will review the code, and the runner will run the code.
/// </summary>
/// <param name="client"></param>
/// <returns></returns>
private async Task CoderReviewerRunnerTestAsync(ChatClient client)
{
var coder = new EchoAgent("Coder");
var reviewer = new EchoAgent("Reviewer");
var runner = new EchoAgent("Runner");
var user = new EchoAgent("User");
var initializeMessage = new List<IMessage>
{
new TextMessage(Role.User, "Hello, I am user, I will provide the coding task, please write the code first, then review and run it", from: "User"),
new TextMessage(Role.User, "Hello, I am coder, I will write the code", from: "Coder"),
new TextMessage(Role.User, "Hello, I am reviewer, I will review the code", from: "Reviewer"),
new TextMessage(Role.User, "Hello, I am runner, I will run the code", from: "Runner"),
new TextMessage(Role.User, "how to print 'hello world' using C#", from: user.Name),
};

var chatHistory = new List<IMessage>()
{
new TextMessage(Role.User, """
```csharp
Console.WriteLine("Hello World");
```
""", from: coder.Name),
new TextMessage(Role.User, "The code looks good", from: reviewer.Name),
new TextMessage(Role.User, "The code runs successfully, the output is 'Hello World'", from: runner.Name),
};

var orchestrator = new RolePlayToolCallOrchestrator(client);
foreach (var message in chatHistory)
{
var context = new OrchestrationContext
{
Candidates = [coder, reviewer, runner, user],
ChatHistory = initializeMessage,
};

var speaker = await orchestrator.GetNextSpeakerAsync(context);
speaker!.Name.Should().Be(message.From);
initializeMessage.Add(message);
}

// the last next speaker should be the user
var lastSpeaker = await orchestrator.GetNextSpeakerAsync(new OrchestrationContext
{
Candidates = [coder, reviewer, runner, user],
ChatHistory = initializeMessage,
});

lastSpeaker!.Name.Should().Be(user.Name);
}
}

0 comments on commit 2f7a0f7

Please sign in to comment.