Skip to content

Commit

Permalink
Optimize algorithm efficiency.
Browse files Browse the repository at this point in the history
  • Loading branch information
aiqinxuancai committed Apr 8, 2024
1 parent d6188fe commit 572ef55
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 21 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,11 @@ public int TiktokenSharp()

</details>


| Method | Job | Runtime | Mean | Error | StdDev | Gen0 | Allocated |
|-------------- |--------- |--------- |----------:|---------:|---------:|----------:|-----------:|
| SharpToken | .NET 8.0 | .NET 8.0 | 112.86 ms | 0.712 ms | 0.595 ms | 2600.0000 | 23202285 B |
| TiktokenSharp | .NET 8.0 | .NET 8.0 | 99.40 ms | 0.179 ms | 0.149 ms | 9800.0000 | 82321296 B |
| SharpToken | .NET 8.0 | .NET 8.0 | 116.38 ms | 1.026 ms | 0.909 ms | 2000.0000 | 23201696 B |
| TiktokenSharp | .NET 8.0 | .NET 8.0 | 98.34 ms | 0.198 ms | 0.176 ms | 9833.3333 | 82321080 B |

## Update

Expand Down
2 changes: 1 addition & 1 deletion TiktokenSharp.Test/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ static void GPT4()
Debug.Assert(i.IsEqualTo(new List<int>() { 15339, 1917 }));
Debug.Assert(tikToken.Decode(new List<int>() { 15339, 1917 }) == "hello world");

var c = tikToken.Encode("hello <|endoftext|>");
var c = tikToken.Encode("hello <|endoftext|>", new HashSet<string>() { "<|endoftext|>" });
Debug.Assert(c.IsEqualTo(new List<int>() { 15339, 220, 100257 }));

var t1 = tikToken.Encode("我很抱歉,我不能提供任何非法或不道德的建议。快速赚钱是不容易的,需要耐心、刻苦努力和经验。如果您想增加收入,请考虑增加工作时间、寻找其他业务机会、学习新技能或提高自己的价值等方法。请记住,通过合法而道德的方式来获得收入,才是长期稳定的解决方案。");
Expand Down
41 changes: 29 additions & 12 deletions TiktokenSharp/CoreBPE.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Linq;
Expand All @@ -12,9 +13,12 @@ public class CoreBPE
{
private Dictionary<string, int> _specialTokensEncoder { get; set; }

// TODO private max_token_value ??
private Dictionary<ReadOnlyMemory<byte>, int> _encoder { get; set; }

// TODO Cache?
//private ConcurrentDictionary<char[], List<int>> _cache { get; set; }
//private MemoryCache _cache = MemoryCache.Default;

private Regex _specialRegex { get; set; }

private Regex _regex { get; set; }
Expand All @@ -37,6 +41,7 @@ public class CoreBPE
public CoreBPE(Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<string, int> specialTokensEncoder, string pattern)
{
_encoder = encoder;
//_cache = new ConcurrentDictionary<char[], List<int>>(new ReadOnlyMemoryComparer());
_regex = new Regex(pattern, RegexOptions.Compiled);
_specialRegex = new Regex(string.Join("|", specialTokensEncoder.Keys.Select(s => Regex.Escape(s))), RegexOptions.Compiled);
_specialTokensEncoder = specialTokensEncoder;
Expand All @@ -61,30 +66,27 @@ public CoreBPE(Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<string,
#if NET7_0_OR_GREATER
public (List<int>, int) EncodeNative(string text, HashSet<string> allowedSpecial, HashSet<string> disallowedSpecial)
{
Regex specialRegex = _specialRegex;
Regex regex = _regex;
var ret = new List<int>();

ReadOnlySpan<char> textSpan = text.AsSpan();
var textSpan = text.AsMemory();
int lastPieceTokenLen = 0;
int currentIndex = 0;

var enumerator = specialRegex.EnumerateMatches(textSpan);

var enumerator = _specialRegex.EnumerateMatches(textSpan.Span);

while (currentIndex < text.Length)
{
int nextMatchStart = text.Length;

if (enumerator.MoveNext())
{

var current = enumerator.Current;

var currentText = textSpan.Slice(current.Index, current.Length).ToString();

if (disallowedSpecial != null && disallowedSpecial.Contains(currentText))
{
throw new InvalidOperationException(currentText);
throw new InvalidOperationException(currentText.ToString());
}
if (allowedSpecial != null && allowedSpecial.Contains(currentText))
{
Expand All @@ -93,20 +95,35 @@ public CoreBPE(Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<string,

}

ReadOnlySpan<char> currentSpan = textSpan.Slice(currentIndex, nextMatchStart - currentIndex);
foreach (var match in regex.EnumerateMatches(currentSpan))
//read only

ReadOnlyMemory<char> currentSpan = textSpan.Slice(currentIndex, nextMatchStart - currentIndex);
foreach (var match in _regex.EnumerateMatches(currentSpan.Span))
{
var piece = Encoding.UTF8.GetBytes(currentSpan.Slice(match.Index, match.Length).ToString());
var charSpan = currentSpan.Slice(match.Index, match.Length);
//var byteSpan = ByteHelper.ConvertReadOnlyMemoryCharToByte(charSpan);

var piece = Encoding.UTF8.GetBytes(charSpan.ToString()); //TODO remove ToString
if (_encoder.TryGetValue(piece, out int token))
{
lastPieceTokenLen = 1;
ret.Add(token);
}
else
{
//TODO Cache?

//if (_cache.TryGetValue(piece, out List<int> cacheToken))
//{
// ret.AddRange(cacheToken);
// continue;
//}

var tokens = BytePairEncoding.BytePairEncode(piece, _encoder);
lastPieceTokenLen = tokens.Count;
ret.AddRange(tokens);

//_cache[piece] = tokens;
}
}

Expand All @@ -116,7 +133,7 @@ public CoreBPE(Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<string,
{
var match = enumerator.Current;
var pieceSpan = textSpan.Slice(currentIndex, match.Length);
if (_specialTokensEncoder.TryGetValue(pieceSpan.ToString(), out int token))
if (_specialTokensEncoder.TryGetValue(pieceSpan.ToString(), out int token)) //TODO remove ToString
{
ret.Add(token);
currentIndex += match.Length;
Expand Down
50 changes: 44 additions & 6 deletions TiktokenSharp/TikToken.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using Microsoft.VisualBasic;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
Expand Down Expand Up @@ -77,8 +78,15 @@ public TikToken(EncodingSettingModel setting)
{
if (setting.ExplicitNVocab != null)
{
Debug.Assert(setting.SpecialTokens.Count + setting.MergeableRanks.Count == setting.ExplicitNVocab);
Debug.Assert(setting.MaxTokenValue == setting.ExplicitNVocab - 1);
if (setting.SpecialTokens.Count + setting.MergeableRanks.Count != setting.ExplicitNVocab)
{
throw new ArgumentException("SpecialTokens + MergeableRanks counts must equal ExplicitNVocab.");
}

if (setting.MaxTokenValue != setting.ExplicitNVocab - 1)
{
throw new ArgumentException("MaxTokenValue must be equal to ExplicitNVocab - 1.");
}
}

_corePBE = new CoreBPE(setting.MergeableRanks, setting.SpecialTokens, setting.PatStr);
Expand All @@ -87,6 +95,39 @@ public TikToken(EncodingSettingModel setting)

public List<int> Encode(string text, HashSet<string> allowedSpecial = null, HashSet<string> disallowedSpecial = null)

Check warning on line 96 in TiktokenSharp/TikToken.cs

View workflow job for this annotation

GitHub Actions / build-and-publish

Cannot convert null literal to non-nullable reference type.

Check warning on line 96 in TiktokenSharp/TikToken.cs

View workflow job for this annotation

GitHub Actions / build-and-publish

Cannot convert null literal to non-nullable reference type.

Check warning on line 96 in TiktokenSharp/TikToken.cs

View workflow job for this annotation

GitHub Actions / build-and-publish

Cannot convert null literal to non-nullable reference type.

Check warning on line 96 in TiktokenSharp/TikToken.cs

View workflow job for this annotation

GitHub Actions / build-and-publish

Cannot convert null literal to non-nullable reference type.
{
//#if NET7_0_OR_GREATER
// HashSet<ReadOnlyMemory<char>>? allowedSpecialMemory = null;
// HashSet<ReadOnlyMemory<char>>? disallowedSpecialMemory = null;

//#else
// List<ReadOnlyMemory<char>>? allowedSpecialMemory = null;
// List<ReadOnlyMemory<char>>? disallowedSpecialMemory = null;
//#endif
// if (allowedSpecial != null)
// {
// allowedSpecialMemory = allowedSpecial
// .Select(str => (ReadOnlyMemory<char>)str.AsMemory())
//#if NET7_0_OR_GREATER
// .ToHashSet();
//#else
// .ToList();
//#endif
// }

// if (disallowedSpecial != null )
// {
// var disallowedSpecialMemcpy = disallowedSpecial
// .Select(str => (ReadOnlyMemory<char>)str.AsMemory())
//#if NET7_0_OR_GREATER
// .ToHashSet();
//#else
// .ToList();
//#endif
// }




return _corePBE.EncodeNative(text, allowedSpecial, disallowedSpecial).Item1;
}

Expand All @@ -98,8 +139,5 @@ public string Decode(List<int> tokens)
}





}
}
9 changes: 9 additions & 0 deletions TiktokenSharp/Utils/ByteHelper.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading.Tasks;

Expand All @@ -20,5 +21,13 @@ public static string ConvertByteListToString(List<ReadOnlyMemory<byte>> byteList
}
return Encoding.UTF8.GetString(allBytes);
}

public static ReadOnlySpan<byte> ConvertReadOnlyMemoryCharToByte(ReadOnlyMemory<char> charMemory)
{
var charSpan = charMemory.Span;
var bytes = MemoryMarshal.AsBytes(charSpan);
return bytes;
}

}
}

0 comments on commit 572ef55

Please sign in to comment.