Add regex loading from tokenizer.json and code refinement #863

Draft. Wants to merge 5 commits into main.
2 changes: 1 addition & 1 deletion cgmanifest.json
@@ -62,7 +62,7 @@
"component": {
"type": "git",
"git": {
"commitHash": "9406a60c7839052e4944ea4dbc8344762a89f9bd",
"commitHash": "e39786088138f2749d64e9e90e0f9902daa77c40",
"repositoryUrl": "https://github.com/google/googletest.git"
}
}
4 changes: 2 additions & 2 deletions cmake/externals/googletest.cmake
@@ -1,7 +1,7 @@
FetchContent_Declare(
googletest
- URL https://github.com/google/googletest/archive/9406a60c7839052e4944ea4dbc8344762a89f9bd.zip
- URL_HASH SHA1=06096d3900c356e468ba060a609642c635131106
+ URL https://github.com/google/googletest/archive/refs/tags/v1.15.0.zip
+ URL_HASH SHA1=9d2d0af8d77ac726ea55d44a8fa727ec98311349
EXCLUDE_FROM_ALL
)

4 changes: 3 additions & 1 deletion onnxruntime_extensions/pp_api.py
@@ -49,10 +49,12 @@ def __init__(self, tokenizer_dir):
        self.tokenizer = create_tokenizer(tokenizer_dir)

    def tokenize(self, text):
+       if isinstance(text, (list, tuple)):
+           return batch_tokenize(self.tokenizer, text)
        return batch_tokenize(self.tokenizer, [text])[0]

    def detokenize(self, tokens):
-       return batch_detokenize(self.tokenizer, [tokens])[0]
+       return batch_detokenize(self.tokenizer, [tokens])

    def __del__(self):
        if delete_object and self.tokenizer:
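For context, a minimal usage sketch of the updated Python wrapper. The import path and wrapper class name (Tokenizer) are assumptions based on this file, not verified against the published package:

from onnxruntime_extensions.pp_api import Tokenizer  # class name assumed

tok = Tokenizer("/path/to/tokenizer_dir")      # directory containing tokenizer.json

ids = tok.tokenize("hello world")              # single string -> one list of token ids
batch = tok.tokenize(["hello", "world"])       # list/tuple input -> list of token-id lists (new in this PR)
text = tok.detokenize(ids)                     # decode back to text

Note that detokenize now returns the raw result of batch_detokenize instead of indexing its first element.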
20 changes: 7 additions & 13 deletions operators/tokenizer/bpe_kernels.cc
@@ -262,7 +262,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,

// Parse input
auto special_token_split_res = bbpe_tokenizer_->SplitByAddedAndSpecial(input);
- bpe::TokenWithRegularExp regcmp;
+ bpe::PreTokenizerWithRegEx reg_splitter;

for (auto& seg_id : special_token_split_res) {
if (static_cast<int64_t>(res.size()) >= max_length) break;
@@ -274,7 +274,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,

// Note: keep ptr to make sure the string_view is valid in the following process
std::u32string str(seg_id.first);
- regcmp.Set(str.c_str());
+ reg_splitter.Set(str.c_str());

size_t offset = 0;
OffsetMappingType offset_mapping;
@@ -287,14 +287,8 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
}

while (static_cast<int64_t>(res.size()) < max_length) {
std::string regex_expr = "";
if (ModelName() == kModel_Llama){
regex_expr = regcmp.LLAMA_REGEX_PATTERN;
} else {
// default to GPT2 regex
regex_expr = regcmp.GPT2_REGEX_PATTERN;
}
auto [b, tok] = regcmp.GetNextToken(regex_expr);
std::string regex_expr = bbpe_tokenizer_->GetPreTokenizerRegex(ModelName());
auto [b, tok] = reg_splitter.GetNextToken(regex_expr);

if (!b) break;

@@ -742,9 +736,9 @@ OrtxStatus JsonFastTokenizer::Load(const ort_extensions::TokenJsonConfig& config
}

bbpe_tokenizer_ = std::make_unique<BpeModel>();
- status = bbpe_tokenizer_->Load(*model_node,
-                                bpe_conf_.get().GetSpecialTokens().c_str(),
-                                bpe_conf_.get().spm_model_);
+ status = bbpe_tokenizer_->Load(*model_node, tok_json,
+                                bpe_conf_.get().GetSpecialTokens().c_str(),
+                                bpe_conf_.get().spm_model_);
if (status.IsOk()) {
UpdateTokenizer(config, tok_json);
}
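The behavioral change in this file is that the kernel no longer hard-codes the choice between the Llama and GPT-2 patterns; it asks the BPE model for the pattern loaded from tokenizer.json (see GetPreTokenizerRegex in bpe_tokenizer_model.hpp below) and only falls back to the built-in defaults. A rough Python sketch of that resolution order, illustrative only and not the real API:

def resolve_pretokenizer_regex(loaded_regex, model_name, llama_pattern, gpt2_pattern):
    if loaded_regex:               # pattern parsed from tokenizer.json's pre_tokenizer section
        return loaded_regex
    if model_name == "Llama":      # built-in Llama default
        return llama_pattern
    return gpt2_pattern            # everything else defaults to the GPT-2 pattern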
57 changes: 55 additions & 2 deletions operators/tokenizer/bpe_tokenizer_model.hpp
@@ -20,8 +20,12 @@
#include "trietree.hpp"
#include "tokenizer_common.h"

-namespace ort_extensions {
+#define ORTX_JSON_RETURN_IF_NULL(node_iter, name, var) \
+  auto var = (node_iter)->find(name); \
+  if (var == (node_iter)->end() || var->is_null()) { return {}; }
+
+
+namespace ort_extensions {
class BpeModel {
using json = nlohmann::json;

@@ -44,6 +48,39 @@ class BpeModel {
}
}

OrtxStatus LoadPreTokenizer(const json& bpe_model) {
auto root_node = &bpe_model;
ORTX_JSON_RETURN_IF_NULL(root_node, "pre_tokenizer", node_pre_tokenizer);
ORTX_JSON_RETURN_IF_NULL(node_pre_tokenizer, "type", iter_type);

auto pre_token_type = iter_type->get<std::string>();
if (pre_token_type == "ByteLevel") {
// need to add more flag support here in the future
return {};
} else if (pre_token_type != "Sequence") {
return {kOrtxErrorNotImplemented, std::string("Unsupported pretokenizer type!") + pre_token_type};
}

ORTX_JSON_RETURN_IF_NULL(node_pre_tokenizer, "pretokenizers", iter_node_list);

for (const auto& node : *iter_node_list) {
ORTX_JSON_RETURN_IF_NULL(&node, "type", iter_type);
auto pre_type = iter_type->get<std::string>();
if (pre_type == "Split") {
ORTX_JSON_RETURN_IF_NULL(&node, "pattern", iter_pattern);
ORTX_JSON_RETURN_IF_NULL(iter_pattern, "Regex", regex_str);
pre_tokenizer_regex_ = regex_str->get<std::string>();
} else if (pre_type == "ByteLevel") {
// need to add more flag support here in the future
} else {
return {kOrtxErrorNotImplemented, "Unsupported pretokenizer type!"};
}
}

return {};
}

OrtxStatus Load(std::istream& vocab_stream, std::istream& merges_stream, const char* unk_token,
const char* special_tokens, bool spm_converted) {
nlohmann::json tok_json;
@@ -120,7 +157,9 @@ class BpeModel {
return {};
}

- OrtxStatus Load(const json& bpe_model, const char* /* special_tokens */, bool spm_converted) {
+ OrtxStatus Load(const json& bpe_model, const json& tok_json, const char* /* special_tokens */, bool spm_converted) {
+   ORTX_RETURN_IF_ERROR(LoadPreTokenizer(tok_json));

const json& vocab_json = bpe_model["vocab"];
const json& merges_json = bpe_model["merges"];
vocab_json.get_to(vocab_map_);
@@ -358,6 +397,19 @@ class BpeModel {

const std::string& GetEndOfWordSuffix() const { return end_of_word_suffix_; }

std::string GetPreTokenizerRegex(const std::string& model_name) const {
if (!pre_tokenizer_regex_.empty()) {
return pre_tokenizer_regex_;
}

if (model_name == "Llama") {
return bpe::PreTokenizerWithRegEx::LLAMA_REGEX_PATTERN;
}

// by default, use the GPT2 pretokenizer regex
return bpe::PreTokenizerWithRegEx::GPT2_REGEX_PATTERN;
}

private:
struct BpeNode {
uint32_t id;
@@ -379,6 +431,7 @@ class BpeModel {
uint32_t unk_id_ = (std::numeric_limits<uint32_t>::max)();
bpe::SpecialTokenMap special_tokens_;
TrieTree<char32_t> added_tokens_;
std::string pre_tokenizer_regex_;
};

} // namespace ort_extensions
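For reference, the tokenizer.json fragment that LoadPreTokenizer walks looks roughly like the snippet below; the layout follows the Hugging Face tokenizers convention the code above expects, and the pattern shown is a shortened, made-up example. The Python sketch is illustrative only and simply mirrors where the regex is picked up from:

import json

tok_json = json.loads(r'''
{
  "pre_tokenizer": {
    "type": "Sequence",
    "pretokenizers": [
      {"type": "Split", "pattern": {"Regex": "\\p{N}{1,3}|\\p{L}+|\\s+"}, "behavior": "Isolated"},
      {"type": "ByteLevel", "add_prefix_space": false}
    ]
  }
}
''')

pre = tok_json["pre_tokenizer"]
regex_str = next(p["pattern"]["Regex"] for p in pre["pretokenizers"] if p["type"] == "Split")
print(regex_str)   # what GetPreTokenizerRegex returns when a pattern was loaded

When no Split pattern is present (for example a plain ByteLevel pre-tokenizer), pre_tokenizer_regex_ stays empty and GetPreTokenizerRegex falls back to the built-in Llama or GPT-2 pattern.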
10 changes: 5 additions & 5 deletions operators/tokenizer/bpe_utils.hpp
@@ -97,8 +97,12 @@ class SpecialTokenMap {
std::unordered_map<ustring, int> token_map_;
};

-class TokenWithRegularExp {
+class PreTokenizerWithRegEx {
 public:
+  static constexpr const char* GPT2_REGEX_PATTERN = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+";
+  static constexpr const char* LLAMA_REGEX_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
+  static constexpr const char* LLAMA_REGEX_PATTERN_2 = "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
+
void Set(std::u32string_view val) {
m_text = val;
}
@@ -115,10 +119,6 @@ class TokenWithRegularExp {
return {false, {}};
}

- const std::string GPT2_REGEX_PATTERN = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+";
- const std::string LLAMA_REGEX_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
- const std::string LLAMA_REGEX_PATTERN_2 = "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
-
public:

// Although we have RegexMatchGeneral which performs regex matching given any general regex string
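To illustrate what these patterns do, the sketch below runs the Llama pattern over the same input used by RegexTest in test_tokenizer.cc, using Python's third-party regex package (which, unlike re, supports \p{L} and \p{N}). findall only approximates GetNextToken's scan, but on this input the pieces line up with the test's expected tokens:

import regex  # third-party "regex" package (pip install regex)

LLAMA_PATTERN = (
    r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}"
    r"| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
)

print(regex.findall(LLAMA_PATTERN, "CAN'T \r\n 2413m"))
# ['CAN', "'T", ' \r\n', ' ', '241', '3', 'm']  -- the same pieces as the RegexTest expectation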
1 change: 1 addition & 0 deletions operators/tokenizer/tokenizer_jsconfig.hpp
@@ -26,6 +26,7 @@ constexpr std::pair<const char*, TokenType> kTokenizerDict[] = {
{"GPT2Tokenizer", TokenType::kBPE},
{"Qwen2Tokenizer", TokenType::kBPE},
{"BaichuanTokenizer", TokenType::kBPE},
{"GPTNeoXTokenizer", TokenType::kBPE},

{"", TokenType::kUnigram},
{"T5Tokenizer", TokenType::kUnigram},
1 change: 1 addition & 0 deletions operators/tokenizer/trietree.hpp
@@ -97,6 +97,7 @@ class TrieTree {
tok_idx += 1;
if (tok_id == invalid_id) {
if (tok_idx < input.length()) {
tok_idx -= tok_len; // backtrack to the last token
continue;
} else {
tok_idx += 1; // Assign tok_idx to input.length()
2 changes: 1 addition & 1 deletion pyop/py_c_api.cc
@@ -121,7 +121,7 @@ void AddGlobalMethodsCApi(pybind11::module& m) {
OrtxTokenizer* tokenizer = nullptr;
auto err = OrtxCreateTokenizer(&tokenizer, tokenizer_def_json.c_str());
if (err != kOrtxOK) {
- throw std::runtime_error(std::string("Failed to create tokenizer") + OrtxGetLastErrorMessage());
+ throw std::runtime_error(std::string("Failed to create tokenizer\n") + OrtxGetLastErrorMessage());
}
return reinterpret_cast<std::uintptr_t>(tokenizer);
},
7 changes: 2 additions & 5 deletions pyop/pyfunc.cc
@@ -258,10 +258,9 @@ void PyCustomOpKernel::Compute(OrtKernelContext* context) {
inputs.push_back(InputInformation{input_X, i_dtype, i_dimensions});
}

- /* Acquire GIL before calling Python code, due to it was released in sess.run */
- py::gil_scoped_acquire acquire;
-
  {
+   /* Acquire GIL before calling Python C API, due to it was released in sess.run */
+   py::gil_scoped_acquire acquire;
py::list pyinputs;
for (auto it = inputs.begin(); it != inputs.end(); ++it) {
py::object input0 = PyCustomOpDefImpl::BuildPyArrayFromTensor(
@@ -349,8 +348,6 @@ void PyCustomOpKernel::Compute(OrtKernelContext* context) {
memcpy(out, retval.data(), size * retval.size());
}
}

- py::gil_scoped_release release;
}
}

77 changes: 28 additions & 49 deletions shared/api/tokenizer_impl.cc
@@ -11,33 +11,15 @@

namespace ort_extensions {

std::set<std::string> TokenizerImpl::supported_bpe_models_ = {
"PreTrainedTokenizerFast",
"CLIPTokenizer",
"WhisperTokenizer",
"GemmaTokenizer",
"LlamaTokenizer",
"Phi3Tokenizer",
"CodeLlamaTokenizer",
"CodeGenTokenizer",
"GPT2Tokenizer",
"Qwen2Tokenizer",
"BaichuanTokenizer"
};

std::set<std::string> TokenizerImpl::supported_ugm_models_ = {
"XLMRobertaTokenizer",
"T5Tokenizer",
"ChatGLMTokenizer"
};

TokenizerImpl::TokenizerImpl()
: OrtxObjectImpl(extObjectKind_t::kOrtxKindTokenizer) {};
TokenizerImpl::~TokenizerImpl() {};

OrtxStatus TokenizerImpl::LoadTokenizer(const OrtxTokenizerBlob* blob) {
if (tok_config_->tokenizer_class_.empty() ||
supported_ugm_models_.count(tok_config_->tokenizer_class_)) {

auto type = TokenJsonConfig::GetTokenType(tok_config_->tokenizer_class_);
if (type == TokenType::kUnigram) {
auto tokenizer = std::make_unique<SpmUgmTokenizer>();
auto status = tokenizer->Load(*tok_config_);
if (!status.IsOk()) {
@@ -53,42 +35,39 @@ OrtxStatus TokenizerImpl::LoadTokenizer(const OrtxTokenizerBlob* blob) {
tokenizer_ = std::move(tokenizer);
detokenizer_ = std::move(detok);
}

return status;
}

if (!supported_bpe_models_.count(tok_config_->tokenizer_class_)) {
return OrtxStatus(kOrtxErrorNotImplemented, "Unsupported tokenizer class");
}

auto tokenizer = std::make_unique<JsonFastTokenizer>();
auto fx_load = &JsonFastTokenizer::Load;
if (blob == nullptr) {
auto vocab_file_path = ortx::path(tok_config_->GetVocabDataFile());
// vocab file is checked in TokenJsonConfig::Load
if (vocab_file_path.extension() != ".json") {
fx_load = &JsonFastTokenizer::LoadTikTokenBase64;
} else if (type == TokenType::kBPE) {
auto tokenizer = std::make_unique<JsonFastTokenizer>();
auto fx_load = &JsonFastTokenizer::Load;
if (blob == nullptr) {
auto vocab_file_path = ortx::path(tok_config_->GetVocabDataFile());
// vocab file is checked in TokenJsonConfig::Load
if (vocab_file_path.extension() != ".json") {
fx_load = &JsonFastTokenizer::LoadTikTokenBase64;
}
} else {
if (blob->raw_model_blob_len > 0) {
fx_load = &JsonFastTokenizer::LoadTikTokenBase64;
}
}
} else {
if (blob->raw_model_blob_len > 0) {
fx_load = &JsonFastTokenizer::LoadTikTokenBase64;

auto status = (tokenizer.get()->*fx_load)(*tok_config_);
if (!status.IsOk()) {
return status;
}
}

auto status = (tokenizer.get()->*fx_load)(*tok_config_);
if (!status.IsOk()) {
return status;
}
auto detok = std::make_unique<BpeStreamingDecoder>();
status = detok->Load(tok_config_, *tokenizer);

auto detok = std::make_unique<BpeStreamingDecoder>();
status = detok->Load(tok_config_, *tokenizer);
if (status.IsOk()) {
tokenizer_ = std::move(tokenizer);
detokenizer_ = std::move(detok);
}

if (status.IsOk()) {
tokenizer_ = std::move(tokenizer);
detokenizer_ = std::move(detok);
return status;
}

return status;
return OrtxStatus(kOrtxErrorNotImplemented, "Unsupported tokenizer class");
}

OrtxStatus TokenizerImpl::Load(const OrtxTokenizerBlob& blob) {
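The net effect of this refactor is that LoadTokenizer now dispatches on the token type resolved from the tokenizer class name (TokenJsonConfig::GetTokenType, backed by the kTokenizerDict table shown earlier) instead of consulting two hard-coded class sets. A rough, illustrative Python sketch of the new control flow, not the real API and with the class-to-type mapping abridged:

TOKEN_TYPES = {
    # BPE-style classes (GPTNeoXTokenizer is the entry newly added in this PR)
    "GPT2Tokenizer": "BPE", "Qwen2Tokenizer": "BPE", "BaichuanTokenizer": "BPE", "GPTNeoXTokenizer": "BPE",
    # Unigram / SentencePiece-style classes
    "T5Tokenizer": "Unigram", "XLMRobertaTokenizer": "Unigram", "ChatGLMTokenizer": "Unigram",
}

def load_tokenizer(tokenizer_class):
    kind = TOKEN_TYPES.get(tokenizer_class)   # stands in for TokenJsonConfig::GetTokenType
    if kind == "Unigram":
        return "SpmUgmTokenizer"              # the Unigram/SentencePiece path
    if kind == "BPE":
        return "JsonFastTokenizer"            # the BPE path; the tiktoken loader is used when the vocab file is not JSON
    raise NotImplementedError("Unsupported tokenizer class")

print(load_tokenizer("GPTNeoXTokenizer"))     # -> JsonFastTokenizer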
3 changes: 0 additions & 3 deletions shared/api/tokenizer_impl.h
@@ -15,9 +15,6 @@ namespace ort_extensions {

class TokenizerImpl : public OrtxObjectImpl {
public:
static std::set<std::string> supported_bpe_models_;
static std::set<std::string> supported_ugm_models_;

TokenizerImpl();
virtual ~TokenizerImpl();

8 changes: 4 additions & 4 deletions test/pp_api_test/test_tokenizer.cc
@@ -67,7 +67,7 @@ TEST(CApiTest, StreamApiTest) {

TEST(OrtxTokenizerTest, RegexTest) {
std::u32string str = U"CAN'T \r\n 2413m";
- auto regcmp = std::make_unique<ort_extensions::bpe::TokenWithRegularExp>();
+ auto regcmp = std::make_unique<ort_extensions::bpe::PreTokenizerWithRegEx>();

std::vector<std::u32string> res;
std::vector<std::u32string> out_tokens = {U"CAN", U"'T", U" \r\n", U" ", U"241", U"3", U"m"};
@@ -91,7 +91,7 @@ TEST(OrtxTokenizerTest, RegexMatchSTDTest) {
std::vector<std::u32string> input_strings = {U"not its, or IT'S, but it's",
U" ",
U"AbCd"};
- auto regcmp = std::make_unique<ort_extensions::bpe::TokenWithRegularExp>();
+ auto regcmp = std::make_unique<ort_extensions::bpe::PreTokenizerWithRegEx>();

std::vector<std::vector<std::u32string>> res_vector;
std::vector<std::vector<std::u32string>> out_tokens = {{U"'s"},
@@ -118,7 +118,7 @@ TEST(OrtxTokenizerTest, WrapStandaloneCategoriesTest) {
"\\p{rn}\\p{L}\\p{N}\\p{L}",
"\\p{Z}*[\\p{rn}]+",
"\\p{Z}+"};
- auto regcmp = std::make_unique<ort_extensions::bpe::TokenWithRegularExp>();
+ auto regcmp = std::make_unique<ort_extensions::bpe::PreTokenizerWithRegEx>();

std::vector<std::string> res;
std::vector<std::string> out_regex = {"[^\\p{rn}\\p{L}\\p{N}]?[\\p{L}]+",
@@ -152,7 +152,7 @@ TEST(OrtxTokenizerTest, RegexMatchGeneralTest) {
U"241356m",
U"Ich liebe München <3 \r\n ",
U"生活的真谛是"};
- auto regcmp = std::make_unique<ort_extensions::bpe::TokenWithRegularExp>();
+ auto regcmp = std::make_unique<ort_extensions::bpe::PreTokenizerWithRegEx>();

std::vector<std::vector<std::u32string>> res_vector;
std::vector<std::vector<std::u32string>> out_tokens = {{U"CAN", U"'T", U"", U""},