Add an IChatClient implementation to OnnxRuntimeGenAI

stephentoub committed Nov 4, 2024
1 parent c9ffcb9 commit e3beb9c

Showing 2 changed files with 282 additions and 0 deletions.
278 changes: 278 additions & 0 deletions src/csharp/ChatClient.cs
@@ -0,0 +1,278 @@
using Microsoft.Extensions.AI;
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.ML.OnnxRuntimeGenAI;

/// <summary>An <see cref="IChatClient"/> implementation based on ONNX Runtime GenAI.</summary>
public sealed class ChatClient : IChatClient, IDisposable
{
    /// <summary>The wrapped <see cref="Model"/>.</summary>
    private readonly Model _model;
    /// <summary>The wrapped <see cref="Tokenizer"/>.</summary>
    private readonly Tokenizer _tokenizer;
    /// <summary>Whether to dispose of <see cref="_model"/> when this instance is disposed.</summary>
    private readonly bool _ownsModel;

    /// <summary>Initializes a new instance of the <see cref="ChatClient"/> class.</summary>
    /// <param name="modelPath">The file path to the model to load.</param>
    /// <exception cref="ArgumentNullException"><paramref name="modelPath"/> is null.</exception>
    public ChatClient(string modelPath)
    {
        if (modelPath is null)
        {
            throw new ArgumentNullException(nameof(modelPath));
        }

        _ownsModel = true;
        _model = new Model(modelPath);
        _tokenizer = new Tokenizer(_model);

        Metadata = new(typeof(ChatClient).Namespace, new Uri($"file://{modelPath}"), modelPath);
    }

    /// <summary>Initializes a new instance of the <see cref="ChatClient"/> class.</summary>
    /// <param name="model">The model to employ.</param>
    /// <param name="ownsModel">
    /// <see langword="true"/> if this <see cref="IChatClient"/> owns the <paramref name="model"/> and should
    /// dispose of it when this <see cref="IChatClient"/> is disposed; otherwise, <see langword="false"/>.
    /// The default is <see langword="true"/>.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="model"/> is null.</exception>
    public ChatClient(Model model, bool ownsModel = true)
    {
        if (model is null)
        {
            throw new ArgumentNullException(nameof(model));
        }

        _ownsModel = ownsModel;
        _model = model;
        _tokenizer = new Tokenizer(_model);

        Metadata = new("Microsoft.ML.OnnxRuntimeGenAI");
    }

    /// <inheritdoc/>
    public ChatClientMetadata Metadata { get; }

    /// <summary>
    /// Gets or sets stop sequences to use during generation.
    /// </summary>
    /// <remarks>
    /// These apply in addition to any stop sequences supplied via <see cref="ChatOptions.StopSequences"/>.
    /// </remarks>
    public IList<string> StopSequences { get; set; } =
    [
        // Default stop sequences based on Phi3
        "<|system|>",
        "<|user|>",
        "<|assistant|>",
        "<|end|>"
    ];

    /// <summary>
    /// Gets or sets a function that creates a prompt string from the chat history.
    /// </summary>
    public Func<IEnumerable<ChatMessage>, string> PromptFormatter { get; set; }

    /// <inheritdoc/>
    public void Dispose()
    {
        _tokenizer.Dispose();

        if (_ownsModel)
        {
            _model.Dispose();
        }
    }

    /// <inheritdoc/>
    public async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions options = null, CancellationToken cancellationToken = default)
    {
        if (chatMessages is null)
        {
            throw new ArgumentNullException(nameof(chatMessages));
        }

        StringBuilder text = new();
        await Task.Run(() =>
        {
            using Sequences tokens = _tokenizer.Encode(CreatePrompt(chatMessages));
            using GeneratorParams generatorParams = new(_model);
            UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);
            generatorParams.SetInputSequences(tokens);

            using Generator generator = new(_model, generatorParams);
            using var tokenizerStream = _tokenizer.CreateStream();

            while (!generator.IsDone())
            {
                cancellationToken.ThrowIfCancellationRequested();

                generator.ComputeLogits();
                generator.GenerateNextToken();

                ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
                string next = tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);

                if (IsStop(next, options))
                {
                    break;
                }

                text.Append(next);
            }
        }, cancellationToken);

        return new ChatCompletion(new ChatMessage(ChatRole.Assistant, text.ToString()))
        {
            CompletionId = Guid.NewGuid().ToString(),
            CreatedAt = DateTimeOffset.UtcNow,
            ModelId = Metadata.ModelId,
        };
    }

    /// <inheritdoc/>
    public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
        IList<ChatMessage> chatMessages, ChatOptions options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        if (chatMessages is null)
        {
            throw new ArgumentNullException(nameof(chatMessages));
        }

        using Sequences tokens = _tokenizer.Encode(CreatePrompt(chatMessages));
        using GeneratorParams generatorParams = new(_model);
        UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);
        generatorParams.SetInputSequences(tokens);

        using Generator generator = new(_model, generatorParams);
        using var tokenizerStream = _tokenizer.CreateStream();

        var completionId = Guid.NewGuid().ToString();
        while (!generator.IsDone())
        {
            string next = await Task.Run(() =>
            {
                generator.ComputeLogits();
                generator.GenerateNextToken();

                ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
                return tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);
            }, cancellationToken);

            if (IsStop(next, options))
            {
                break;
            }

            yield return new StreamingChatCompletionUpdate
            {
                CompletionId = completionId,
                CreatedAt = DateTimeOffset.UtcNow,
                Role = ChatRole.Assistant,
                Text = next,
            };
        }
    }

    /// <inheritdoc/>
    public TService GetService<TService>(object key = null) where TService : class =>
        typeof(TService) == typeof(Model) ? (TService)(object)_model :
        typeof(TService) == typeof(Tokenizer) ? (TService)(object)_tokenizer :
        this as TService;

    /// <summary>Gets whether the specified token is a stop sequence.</summary>
    private bool IsStop(string token, ChatOptions options) =>
        options?.StopSequences?.Contains(token) is true ||
        StopSequences?.Contains(token) is true;

    /// <summary>Creates a prompt string from the supplied chat history.</summary>
    private string CreatePrompt(IEnumerable<ChatMessage> messages)
    {
        if (messages is null)
        {
            throw new ArgumentNullException(nameof(messages));
        }

        if (PromptFormatter is not null)
        {
            return PromptFormatter(messages) ?? string.Empty;
        }

        // Default formatting based on Phi3.
        StringBuilder prompt = new();

        foreach (var message in messages)
        {
            foreach (var content in message.Contents)
            {
                switch (content)
                {
                    case TextContent tc when !string.IsNullOrWhiteSpace(tc.Text):
                        prompt.Append("<|").Append(message.Role.Value).Append("|>\n")
                              .Append(tc.Text.Replace("<|end|>\n", ""))
                              .Append("<|end|>\n");
                        break;
                }
            }
        }

        prompt.Append("<|assistant|>");

        return prompt.ToString();
    }

    /// <summary>Updates the <paramref name="generatorParams"/> based on the supplied <paramref name="options"/>.</summary>
    private static void UpdateGeneratorParamsFromOptions(int numInputTokens, GeneratorParams generatorParams, ChatOptions options)
    {
        if (options is null)
        {
            return;
        }

        if (options.MaxOutputTokens.HasValue)
        {
            generatorParams.SetSearchOption("max_length", numInputTokens + options.MaxOutputTokens.Value);
        }

        if (options.Temperature.HasValue)
        {
            generatorParams.SetSearchOption("temperature", options.Temperature.Value);
        }

        if (options.TopP.HasValue)
        {
            generatorParams.SetSearchOption("top_p", options.TopP.Value);
        }

        if (options.TopK.HasValue)
        {
            generatorParams.SetSearchOption("top_k", options.TopK.Value);
        }

        if (options.AdditionalProperties is { } props)
        {
            foreach (var entry in props)
            {
                switch (entry.Value)
                {
                    case int i: generatorParams.SetSearchOption(entry.Key, i); break;
                    case long l: generatorParams.SetSearchOption(entry.Key, l); break;
                    case float f: generatorParams.SetSearchOption(entry.Key, f); break;
                    case double d: generatorParams.SetSearchOption(entry.Key, d); break;
                    case bool b: generatorParams.SetSearchOption(entry.Key, b); break;
                }
            }
        }
    }
}
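
Note: the commit itself ships no sample. As a minimal consumption sketch of the new API (the model path is hypothetical, and the types come from the Microsoft.Extensions.AI.Abstractions preview referenced in the csproj change below), usage would look roughly like this:

using System;
using System.Collections.Generic;
using Microsoft.Extensions.AI;
using Microsoft.ML.OnnxRuntimeGenAI;

// Hypothetical model directory; point this at any local ONNX Runtime GenAI model.
using var client = new ChatClient("models/phi-3-mini");

List<ChatMessage> messages = [new(ChatRole.User, "What is ONNX Runtime?")];

// One-shot completion, capped via ChatOptions.MaxOutputTokens.
ChatCompletion completion = await client.CompleteAsync(
    messages, new ChatOptions { MaxOutputTokens = 256 });
Console.WriteLine(completion.Message.Text);

// Streaming: updates surface one decoded token at a time.
await foreach (var update in client.CompleteStreamingAsync(messages))
{
    Console.Write(update.Text);
}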
4 changes: 4 additions & 0 deletions src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj
@@ -121,4 +121,8 @@
<PackageReference Include="System.Memory" Version="4.5.5" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.0.0-preview.9.24525.1" />
</ItemGroup>

</Project>
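
The Phi3 chat template and stop tokens in ChatClient.cs are only defaults: the StopSequences and PromptFormatter properties exist so other templates can be swapped in. A sketch under that assumption — the ChatML-style markers and model path below are illustrative, not part of the commit:

using System.Text;
using Microsoft.Extensions.AI;
using Microsoft.ML.OnnxRuntimeGenAI;

// Hypothetical: adapt the client to a ChatML-style template instead of the Phi3 default.
using var client = new ChatClient("models/some-chatml-model")
{
    StopSequences = ["<|im_end|>"],
    PromptFormatter = messages =>
    {
        StringBuilder prompt = new();
        foreach (ChatMessage message in messages)
        {
            prompt.Append("<|im_start|>").Append(message.Role.Value).Append('\n')
                  .Append(message.Text).Append("<|im_end|>\n");
        }

        // Leave the prompt open for the assistant's reply, mirroring CreatePrompt above.
        return prompt.Append("<|im_start|>assistant\n").ToString();
    }
};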
