Add an IChatClient implementation to OnnxRuntimeGenAI #987

Open · wants to merge 1 commit into main
247 changes: 247 additions & 0 deletions src/csharp/ChatClient.cs
@@ -0,0 +1,247 @@
using Microsoft.Extensions.AI;
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.ML.OnnxRuntimeGenAI;

/// <summary>Provides an <see cref="IChatClient"/> implementation for interacting with a <see cref="Model"/>.</summary>
public sealed partial class ChatClient : IChatClient
{
    /// <summary>The options used to configure the instance.</summary>
    private readonly ChatClientConfiguration _config;
    /// <summary>The wrapped <see cref="Model"/>.</summary>
    private readonly Model _model;
    /// <summary>The wrapped <see cref="Tokenizer"/>.</summary>
    private readonly Tokenizer _tokenizer;
    /// <summary>Whether to dispose of <see cref="_model"/> when this instance is disposed.</summary>
    private readonly bool _ownsModel;

    /// <summary>Initializes a new instance of the <see cref="ChatClient"/> class.</summary>
    /// <param name="configuration">Options used to configure the client instance.</param>
    /// <param name="modelPath">The file path to the model to load.</param>
    /// <exception cref="ArgumentNullException"><paramref name="configuration"/> is null.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="modelPath"/> is null.</exception>
    public ChatClient(ChatClientConfiguration configuration, string modelPath)
    {
        if (configuration is null)
        {
            throw new ArgumentNullException(nameof(configuration));
        }

        if (modelPath is null)
        {
            throw new ArgumentNullException(nameof(modelPath));
        }

        _config = configuration;

        _ownsModel = true;
        _model = new Model(modelPath);
        _tokenizer = new Tokenizer(_model);

        Metadata = new("onnxruntime-genai", new Uri($"file://{modelPath}"), modelPath);
    }

    /// <summary>Initializes a new instance of the <see cref="ChatClient"/> class.</summary>
    /// <param name="configuration">Options used to configure the client instance.</param>
    /// <param name="model">The model to employ.</param>
    /// <param name="ownsModel">
    /// <see langword="true"/> if this <see cref="IChatClient"/> owns the <paramref name="model"/> and should
    /// dispose of it when this <see cref="IChatClient"/> is disposed; otherwise, <see langword="false"/>.
    /// The default is <see langword="true"/>.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="configuration"/> is null.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="model"/> is null.</exception>
    public ChatClient(ChatClientConfiguration configuration, Model model, bool ownsModel = true)
    {
        if (configuration is null)
        {
            throw new ArgumentNullException(nameof(configuration));
        }

        if (model is null)
        {
            throw new ArgumentNullException(nameof(model));
        }

        _config = configuration;

        _ownsModel = ownsModel;
        _model = model;
        _tokenizer = new Tokenizer(_model);

        Metadata = new("onnxruntime-genai");
    }

    /// <inheritdoc/>
    public ChatClientMetadata Metadata { get; }

    /// <inheritdoc/>
    public void Dispose()
    {
        _tokenizer.Dispose();

        if (_ownsModel)
        {
            _model.Dispose();
        }
    }

    /// <inheritdoc/>
    public async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions options = null, CancellationToken cancellationToken = default)
    {
        if (chatMessages is null)
        {
            throw new ArgumentNullException(nameof(chatMessages));
        }

        string completionId = Guid.NewGuid().ToString();
        StringBuilder text = new();
        await Task.Run(() =>
        {
            using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages));
            using GeneratorParams generatorParams = new(_model);
            UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);

            using Generator generator = new(_model, generatorParams);
            generator.AppendTokenSequences(tokens);

            using var tokenizerStream = _tokenizer.CreateStream();

            while (!generator.IsDone())
            {
                cancellationToken.ThrowIfCancellationRequested();

                generator.GenerateNextToken();

                ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
                string next = tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);

                if (IsStop(next, options))
                {
                    break;
                }

                text.Append(next);
            }
        }, cancellationToken);

        return new ChatCompletion(new ChatMessage(ChatRole.Assistant, text.ToString()))
        {
            CompletionId = completionId,
            CreatedAt = DateTimeOffset.UtcNow,
            ModelId = Metadata.ModelId,
        };
    }

    /// <inheritdoc/>
    public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
        IList<ChatMessage> chatMessages, ChatOptions options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        if (chatMessages is null)
        {
            throw new ArgumentNullException(nameof(chatMessages));
        }

        using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages));
        using GeneratorParams generatorParams = new(_model);
        UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);

        using Generator generator = new(_model, generatorParams);
        generator.AppendTokenSequences(tokens);

        using var tokenizerStream = _tokenizer.CreateStream();

        var completionId = Guid.NewGuid().ToString();
        while (!generator.IsDone())
        {
            string next = await Task.Run(() =>
            {
                generator.GenerateNextToken();

                ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
                return tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);
            }, cancellationToken);

            if (IsStop(next, options))
            {
                break;
            }

            yield return new StreamingChatCompletionUpdate
            {
                CompletionId = completionId,
                CreatedAt = DateTimeOffset.UtcNow,
                Role = ChatRole.Assistant,
                Text = next,
            };
        }
    }

    /// <inheritdoc/>
    public object GetService(Type serviceType, object key = null) =>
        key is not null ? null :
        serviceType == typeof(Model) ? _model :
        serviceType == typeof(Tokenizer) ? _tokenizer :
        serviceType?.IsInstanceOfType(this) is true ? this :
        null;

    /// <summary>Gets whether the specified token is a stop sequence.</summary>
    private bool IsStop(string token, ChatOptions options) =>
        options?.StopSequences?.Contains(token) is true ||
        Array.IndexOf(_config.StopSequences, token) >= 0;

    /// <summary>Updates the <paramref name="generatorParams"/> based on the supplied <paramref name="options"/>.</summary>
    private static void UpdateGeneratorParamsFromOptions(int numInputTokens, GeneratorParams generatorParams, ChatOptions options)
    {
        if (options is null)
        {
            return;
        }

        if (options.MaxOutputTokens.HasValue)
        {
            generatorParams.SetSearchOption("max_length", numInputTokens + options.MaxOutputTokens.Value);
        }

        if (options.Temperature.HasValue)
        {
            generatorParams.SetSearchOption("temperature", options.Temperature.Value);
        }

        if (options.TopP.HasValue)
        {
            generatorParams.SetSearchOption("top_p", options.TopP.Value);
        }

        if (options.TopK.HasValue)
        {
            generatorParams.SetSearchOption("top_k", options.TopK.Value);
        }

        if (options.Seed.HasValue)
        {
            generatorParams.SetSearchOption("random_seed", options.Seed.Value);
        }

        if (options.AdditionalProperties is { } props)
        {
            foreach (var entry in props)
            {
                switch (entry.Value)
                {
                    case int i: generatorParams.SetSearchOption(entry.Key, i); break;
                    case long l: generatorParams.SetSearchOption(entry.Key, l); break;
                    case float f: generatorParams.SetSearchOption(entry.Key, f); break;
                    case double d: generatorParams.SetSearchOption(entry.Key, d); break;
                    case bool b: generatorParams.SetSearchOption(entry.Key, b); break;
                }
            }
        }
    }
}
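
For context, a minimal consumption sketch (not part of the diff). It assumes a hypothetical local model directory and the CreateForPhi3() helper shown in the ChatClientConfiguration remarks below; CompleteAsync(string, ...) and CompleteStreamingAsync(string, ...) are the Microsoft.Extensions.AI convenience extensions over the interface methods implemented above.

using System;
using Microsoft.Extensions.AI;
using Microsoft.ML.OnnxRuntimeGenAI;

// Top-level statements. CreateForPhi3 is the helper from the ChatClientConfiguration remarks.
ChatClientConfiguration config = CreateForPhi3();

// The client loads and owns the Model here, disposing it when the client is disposed.
using IChatClient client = new ChatClient(config, @"d:\models\phi3"); // hypothetical path

// One-shot completion. Entries in AdditionalProperties flow through
// UpdateGeneratorParamsFromOptions into GeneratorParams.SetSearchOption.
var completion = await client.CompleteAsync("What is AI?", new ChatOptions
{
    MaxOutputTokens = 128,
    AdditionalProperties = new() { ["repetition_penalty"] = 1.1f },
});
Console.WriteLine(completion);

// Streaming: each update carries the next token decoded by the TokenizerStream.
await foreach (var update in client.CompleteStreamingAsync("Tell me a joke."))
{
    Console.Write(update.Text);
}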
73 changes: 73 additions & 0 deletions src/csharp/ChatClientConfiguration.cs
@@ -0,0 +1,73 @@
using Microsoft.Extensions.AI;
using System;
using System.Collections.Generic;

namespace Microsoft.ML.OnnxRuntimeGenAI;

/// <summary>Provides configuration options used when constructing a <see cref="ChatClient"/>.</summary>
/// <remarks>
/// Every model has different requirements for stop sequences and prompt formatting. For best results,
/// the configuration should be tailored to the exact nature of the model being used. For example,
/// when using a Phi3 model, a configuration like the following may be used:
/// <code>
/// static ChatClientConfiguration CreateForPhi3() =&gt;
///     new(["&lt;|system|&gt;", "&lt;|user|&gt;", "&lt;|assistant|&gt;", "&lt;|end|&gt;"],
///         (IEnumerable&lt;ChatMessage&gt; messages) =&gt;
///         {
///             StringBuilder prompt = new();
///
///             foreach (var message in messages)
///                 foreach (var content in message.Contents.OfType&lt;TextContent&gt;())
///                     prompt.Append("&lt;|").Append(message.Role.Value).Append("|&gt;\n").Append(content.Text).Append("&lt;|end|&gt;\n");
///
///             return prompt.Append("&lt;|assistant|&gt;\n").ToString();
///         });
/// </code>
/// </remarks>
public sealed class ChatClientConfiguration
{
    private string[] _stopSequences;
    private Func<IEnumerable<ChatMessage>, string> _promptFormatter;

    /// <summary>Initializes a new instance of the <see cref="ChatClientConfiguration"/> class.</summary>
    /// <param name="stopSequences">The stop sequences used by the model.</param>
    /// <param name="promptFormatter">The function to use to format a list of messages for input into the model.</param>
    /// <exception cref="ArgumentNullException"><paramref name="stopSequences"/> is null.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="promptFormatter"/> is null.</exception>
    public ChatClientConfiguration(
        string[] stopSequences,
        Func<IEnumerable<ChatMessage>, string> promptFormatter)
    {
        if (stopSequences is null)
        {
            throw new ArgumentNullException(nameof(stopSequences));
        }

        if (promptFormatter is null)
        {
            throw new ArgumentNullException(nameof(promptFormatter));
        }

        StopSequences = stopSequences;
        PromptFormatter = promptFormatter;
    }

    /// <summary>
    /// Gets or sets stop sequences to use during generation.
    /// </summary>
    /// <remarks>
    /// These apply in addition to any stop sequences specified in <see cref="ChatOptions.StopSequences"/>.
    /// </remarks>
    public string[] StopSequences
    {
        get => _stopSequences;
        set => _stopSequences = value ?? throw new ArgumentNullException(nameof(value));
    }

    /// <summary>Gets or sets the function that creates a prompt string from the chat history.</summary>
    public Func<IEnumerable<ChatMessage>, string> PromptFormatter
    {
        get => _promptFormatter;
        set => _promptFormatter = value ?? throw new ArgumentNullException(nameof(value));
    }
}
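
The same pattern adapts to other chat templates by swapping the stop sequences and the formatter. As a purely illustrative sketch, assuming a ChatML-style model and the usual System.Linq/System.Text usings:

static ChatClientConfiguration CreateForChatML() =>
    new(["<|im_end|>"],
        (IEnumerable<ChatMessage> messages) =>
        {
            StringBuilder prompt = new();

            // ChatML frames each message as <|im_start|>role\ncontent<|im_end|>\n.
            foreach (var message in messages)
                foreach (var content in message.Contents.OfType<TextContent>())
                    prompt.Append("<|im_start|>").Append(message.Role.Value).Append('\n')
                          .Append(content.Text).Append("<|im_end|>\n");

            // Leave the prompt open at an assistant turn so generation continues from here.
            return prompt.Append("<|im_start|>assistant\n").ToString();
        });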
4 changes: 4 additions & 0 deletions src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj
@@ -121,4 +121,8 @@
<PackageReference Include="System.Memory" Version="4.5.5" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.0.1-preview.1.24570.5" />
</ItemGroup>

</Project>
32 changes: 31 additions & 1 deletion test/csharp/TestOnnxRuntimeGenAIAPI.cs
@@ -2,12 +2,15 @@
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading.Tasks;
using Xunit;
using Xunit.Abstractions;
using Microsoft.Extensions.AI;

namespace Microsoft.ML.OnnxRuntimeGenAI.Tests
{
@@ -349,6 +352,33 @@ public void TestTopKTopPSearch()
            }
        }

        [IgnoreOnModelAbsenceFact(DisplayName = "TestChatClient")]
        public async Task TestChatClient()
        {
            ChatClientConfiguration config = new(
                ["<|system|>", "<|user|>", "<|assistant|>", "<|end|>"],
                (IEnumerable<ChatMessage> messages) =>
                {
                    StringBuilder prompt = new();

                    foreach (var message in messages)
                        foreach (var content in message.Contents.OfType<TextContent>())
                            prompt.Append("<|").Append(message.Role.Value).Append("|>\n").Append(content.Text).Append("<|end|>\n");

                    return prompt.Append("<|assistant|>\n").ToString();
                });

            using var client = new ChatClient(config, _phi2Path);

            var completion = await client.CompleteAsync("What is 2 + 3?", new()
            {
                MaxOutputTokens = 20,
                Temperature = 0f,
            });

            Assert.Contains("5", completion.ToString());
        }

[IgnoreOnModelAbsenceFact(DisplayName = "TestTokenizerBatchEncodeDecode")]
public void TestTokenizerBatchEncodeDecode()
{
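
The new test covers only the non-streaming path. A streaming counterpart is sketched below for illustration; it is not part of this commit, and it assumes the same Phi-2 fixture plus a CreateForPhi3-style helper equivalent to the inline configuration in TestChatClient.

        [IgnoreOnModelAbsenceFact(DisplayName = "TestChatClientStreaming")]
        public async Task TestChatClientStreaming()
        {
            // Hypothetical helper; same shape as the configuration built inline in TestChatClient.
            ChatClientConfiguration config = CreateForPhi3();

            using var client = new ChatClient(config, _phi2Path);

            // Accumulate streamed tokens via the Microsoft.Extensions.AI string-prompt extension.
            StringBuilder text = new();
            await foreach (var update in client.CompleteStreamingAsync("What is 2 + 3?", new()
            {
                MaxOutputTokens = 20,
                Temperature = 0f,
            }))
            {
                text.Append(update.Text);
            }

            Assert.Contains("5", text.ToString());
        }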