diff --git a/.github/workflows/android-build.yml b/.github/workflows/android-build.yml index 6fa3bd530..34b844359 100644 --- a/.github/workflows/android-build.yml +++ b/.github/workflows/android-build.yml @@ -72,16 +72,23 @@ jobs: with: submodules: true + - name: Install Rust Toolchain + uses: dtolnay/rust-toolchain@1.82.0 + + - name: Install Rust Android Toolchain + run: | + rustup target add --toolchain 1.82.0-x86_64-unknown-linux-gnu x86_64-linux-android + - name: Create Android build run: | set -e -x rm -rf build - ./build.sh --android --android_api=27 --android_ndk_path=${ANDROID_NDK_LATEST_HOME} --config=RelWithDebInfo --android_abi=${{ env.ANDROID_ABI }} --parallel --build_java --update + ./build.sh --android --android_api=27 --android_ndk_path=${ANDROID_NDK_LATEST_HOME} --config=RelWithDebInfo --android_abi=${{ env.ANDROID_ABI }} --parallel --build_java --update --use_guidance - name: Run Android build run: | set -e -x - ./build.sh --android --android_api=27 --android_ndk_path=${ANDROID_NDK_LATEST_HOME} --config=RelWithDebInfo --android_abi=${{ env.ANDROID_ABI }} --parallel --build_java --build + ./build.sh --android --android_api=27 --android_ndk_path=${ANDROID_NDK_LATEST_HOME} --config=RelWithDebInfo --android_abi=${{ env.ANDROID_ABI }} --parallel --build_java --build --use_guidance - name: Enable KVM group perms so Android emulator can run run: | diff --git a/.github/workflows/linux-cpu-x64-build.yml b/.github/workflows/linux-cpu-x64-build.yml index 5fc97369d..3e41ba070 100644 --- a/.github/workflows/linux-cpu-x64-build.yml +++ b/.github/workflows/linux-cpu-x64-build.yml @@ -39,6 +39,9 @@ jobs: with: gradle-version: '8.6' + - name: Install Rust Toolchain + uses: dtolnay/rust-toolchain@1.82.0 + - name: Get the Latest OnnxRuntime Nightly Version shell: pwsh run: | @@ -74,8 +77,8 @@ jobs: run: | set -e -x rm -rf build - cmake --preset linux_gcc_cpu_release - cmake --build --preset linux_gcc_cpu_release + cmake --preset linux_gcc_cpu_release -DUSE_GUIDANCE=ON + cmake --build --preset linux_gcc_cpu_release -DUSE_GUIDANCE=ON - name: Install the python wheel and test dependencies run: | diff --git a/.github/workflows/linux-cpu-x64-nightly-build.yml b/.github/workflows/linux-cpu-x64-nightly-build.yml index 61be5eb6f..c78e8bc8d 100644 --- a/.github/workflows/linux-cpu-x64-nightly-build.yml +++ b/.github/workflows/linux-cpu-x64-nightly-build.yml @@ -22,7 +22,8 @@ jobs: - name: Checkout OnnxRuntime GenAI repo uses: actions/checkout@v2 - + - name: Install Rust Toolchain + uses: dtolnay/rust-toolchain@1.82.0 - name: Download OnnxRuntime run: | @@ -45,8 +46,8 @@ jobs: run: | set -e -x rm -rf build - cmake --preset linux_gcc_cpu_release - cmake --build --preset linux_gcc_cpu_release + cmake --preset linux_gcc_cpu_release -DUSE_GUIDANCE=ON + cmake --build --preset linux_gcc_cpu_release -DUSE_GUIDANCE=ON - name: Install the python wheel and test dependencies run: | diff --git a/.github/workflows/mac-cpu-arm64-build.yml b/.github/workflows/mac-cpu-arm64-build.yml index 658f7a660..86fdde374 100644 --- a/.github/workflows/mac-cpu-arm64-build.yml +++ b/.github/workflows/mac-cpu-arm64-build.yml @@ -52,13 +52,16 @@ jobs: mv ${{ env.ORT_PACKAGE_NAME }}/build/native/include ort/ mv ${{ env.ORT_PACKAGE_NAME }}/runtimes/osx-arm64/native/* ort/lib/ + - name: Install Rust Toolchain + uses: dtolnay/rust-toolchain@1.82.0 + - name: Configure CMake run: | cmake --preset macos_arm64_cpu_release - name: Build with CMake run: | - cmake --build --preset macos_arm64_cpu_release --parallel + cmake --build --preset 
macos_arm64_cpu_release --parallel -DUSE_GUIDANCE=ON continue-on-error: false - name: Install the python wheel and test dependencies diff --git a/.github/workflows/win-cpu-arm64-build.yml b/.github/workflows/win-cpu-arm64-build.yml index b6b20cfc6..fca5599cd 100644 --- a/.github/workflows/win-cpu-arm64-build.yml +++ b/.github/workflows/win-cpu-arm64-build.yml @@ -64,15 +64,22 @@ jobs: move ${{ env.ORT_PACKAGE_NAME }}/build/native/include ort/ move ${{ env.ORT_PACKAGE_NAME }}/runtimes/win-arm64/native/* ort/lib/ + - name: Install Rust Toolchain + run: | + $exePath = "$env:TEMP\rustup-init.exe" + (New-Object Net.WebClient).DownloadFile('https://static.rust-lang.org/rustup/dist/x86_64-pc-windows-msvc/rustup-init.exe', $exePath) + & $exePath -y --default-toolchain=1.82.0 + Add-Content $env:GITHUB_PATH "$env:USERPROFILE\.cargo\bin" + - name: Configure CMake run: | python -m pip install wheel requests - cmake --preset windows_arm64_cpu_release + cmake --preset windows_arm64_cpu_release -DUSE_GUIDANCE=ON - name: Build with CMake run: | - cmake --build --preset windows_arm64_cpu_release --parallel + cmake --build --preset windows_arm64_cpu_release --parallel -DUSE_GUIDANCE=ON - name: Install the Python Wheel and Test Dependencies run: | diff --git a/.github/workflows/win-cpu-x64-build.yml b/.github/workflows/win-cpu-x64-build.yml index 3374a3b6d..0a801237c 100644 --- a/.github/workflows/win-cpu-x64-build.yml +++ b/.github/workflows/win-cpu-x64-build.yml @@ -53,6 +53,13 @@ jobs: with: gradle-version: '8.6' + - name: Install Rust Toolchain + run: | + $exePath = "$env:TEMP\rustup-init.exe" + (New-Object Net.WebClient).DownloadFile('https://static.rust-lang.org/rustup/dist/x86_64-pc-windows-msvc/rustup-init.exe', $exePath) + & $exePath -y --default-toolchain=1.82.0 + Add-Content $env:GITHUB_PATH "$env:USERPROFILE\.cargo\bin" + - name: Download OnnxRuntime Nightly shell: pwsh run: | @@ -78,11 +85,11 @@ jobs: - name: Configure CMake run: | - cmake --preset windows_x64_cpu_release + cmake --preset windows_x64_cpu_release -DUSE_GUIDANCE=ON - name: Build with CMake run: | - cmake --build --preset windows_x64_cpu_release --parallel + cmake --build --preset windows_x64_cpu_release --parallel -DUSE_GUIDANCE=ON - name: Install the python wheel and test dependencies run: | diff --git a/.github/workflows/win-cuda-x64-build.yml b/.github/workflows/win-cuda-x64-build.yml index 5cda163d1..acc2ec95a 100644 --- a/.github/workflows/win-cuda-x64-build.yml +++ b/.github/workflows/win-cuda-x64-build.yml @@ -59,15 +59,22 @@ jobs: run: | mkdir ort/lib move ${{ env.ORT_PACKAGE_NAME }}/buildTransitive/native/include ort/ - move ${{ env.ORT_PACKAGE_NAME }}/runtimes/win-x64/native/* ort/lib/ + move ${{ env.ORT_PACKAGE_NAME }}/runtimes/win-x64/native/* ort/lib/ + + - name: Install Rust Toolchain + run: | + $exePath = "$env:TEMP\rustup-init.exe" + (New-Object Net.WebClient).DownloadFile('https://static.rust-lang.org/rustup/dist/x86_64-pc-windows-msvc/rustup-init.exe', $exePath) + & $exePath -y --default-toolchain=1.82.0 + Add-Content $env:GITHUB_PATH "$env:USERPROFILE\.cargo\bin" - name: Configure CMake run: | - cmake --preset windows_x64_cuda_release -T cuda=${{ env.cuda_dir }}\\v${{ env.cuda_version }} + cmake --preset windows_x64_cuda_release -T cuda=${{ env.cuda_dir }}\\v${{ env.cuda_version }} -DUSE_GUIDANCE=ON - name: Build with CMake run: | - cmake --build --preset windows_x64_cuda_release --parallel + cmake --build --preset windows_x64_cuda_release --parallel -DUSE_GUIDANCE=ON - name: Add CUDA to PATH run: | diff --git 
a/.github/workflows/win-directml-x64-build.yml b/.github/workflows/win-directml-x64-build.yml index 678573606..37b213e8b 100644 --- a/.github/workflows/win-directml-x64-build.yml +++ b/.github/workflows/win-directml-x64-build.yml @@ -78,13 +78,20 @@ jobs: mv $env:d3d12_dir\build\native\bin\x64\D3D12Core.dll ort\lib mv $env:dml_dir\include\DirectML.h ort\include + - name: Install Rust Toolchain + run: | + $exePath = "$env:TEMP\rustup-init.exe" + (New-Object Net.WebClient).DownloadFile('https://static.rust-lang.org/rustup/dist/x86_64-pc-windows-msvc/rustup-init.exe', $exePath) + & $exePath -y --default-toolchain=1.82.0 + Add-Content $env:GITHUB_PATH "$env:USERPROFILE\.cargo\bin" + - name: Configure CMake run: | - cmake --preset windows_x64_directml_release -DTEST_PHI2=False + cmake --preset windows_x64_directml_release -DTEST_PHI2=False -DUSE_GUIDANCE=ON - name: Build with CMake run: | - cmake --build --preset windows_x64_directml_release --parallel + cmake --build --preset windows_x64_directml_release --parallel -DUSE_GUIDANCE=ON - name: Install the Python Wheel and Test Dependencies run: | diff --git a/.gitignore b/.gitignore index 3b68e21f3..8a756f884 100644 --- a/.gitignore +++ b/.gitignore @@ -30,6 +30,7 @@ examples/csharp/HelloPhi/models !test/test_models/hf-internal-testing/ !test/test_models/hf-internal-testing/tiny-random-gpt2*/*.onnx +!test/test_models/grammars/ .ipynb_checkpoints/ /src/java/.gradle diff --git a/CMakeLists.txt b/CMakeLists.txt index 4b9f83ec1..b70214e0d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -47,6 +47,13 @@ include(cmake/check_webgpu.cmake) include(cmake/cxx_standard.cmake) add_compile_definitions(BUILDING_ORT_GENAI_C) + +if(USE_GUIDANCE) + add_compile_definitions(USE_GUIDANCE=1) +else() + add_compile_definitions(USE_GUIDANCE=0) +endif() + if(MSVC) # set updated value for __cplusplus macro instead of 199711L add_compile_options($<$:/Zc:__cplusplus>) @@ -142,6 +149,19 @@ if(USE_CUDA AND CMAKE_CUDA_COMPILER) endif() endif() + +if(USE_GUIDANCE) + target_include_directories(onnxruntime-genai PUBLIC ${llguidance_SOURCE_DIR}/parser/) + target_include_directories(onnxruntime-genai-static PUBLIC ${llguidance_SOURCE_DIR}/parser/) + target_link_libraries(onnxruntime-genai PRIVATE llguidance_parser) + target_link_libraries(onnxruntime-genai-static PUBLIC llguidance_parser) + if (WIN32) + # bcrypt is needed for the rust std lib + target_link_libraries(onnxruntime-genai PRIVATE bcrypt) + target_link_libraries(onnxruntime-genai-static PRIVATE bcrypt) + endif() +endif() + if(CMAKE_GENERATOR_TOOLSET MATCHES "Visual Studio") target_link_options(onnxruntime-genai PRIVATE "/CETCOMPAT") target_compile_options(onnxruntime-genai PRIVATE "/sdl") diff --git a/build.py b/build.py index b83c18e11..87c2c1a39 100644 --- a/build.py +++ b/build.py @@ -130,6 +130,8 @@ class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescript parser.add_argument("--use_dml", action="store_true", help="Whether to use DML. Default is to not use DML.") + parser.add_argument("--use_guidance", action="store_true", help="Whether to add guidance support. Default is False.") + # The following options are mutually exclusive (cross compiling options such as android, ios, etc.) 
platform_group = parser.add_mutually_exclusive_group() platform_group.add_argument("--android", action="store_true", help="Build for Android") @@ -477,6 +479,7 @@ def update(args: argparse.Namespace, env: dict[str, str]): f"-DUSE_DML={'ON' if args.use_dml else 'OFF'}", f"-DENABLE_JAVA={'ON' if args.build_java else 'OFF'}", f"-DBUILD_WHEEL={build_wheel}", + f"-DUSE_GUIDANCE={'ON' if args.use_guidance else 'OFF'}", ] if args.ort_home: @@ -535,6 +538,8 @@ def _get_opencv_toolchain_file(): "-DENABLE_TESTS=OFF", "-DENABLE_MODEL_BENCHMARK=OFF", ] + if args.use_guidance: + command += ["-DRust_CARGO_TARGET=aarch64-apple-ios-sim"] if args.macos == "Catalyst": if args.cmake_generator == "Xcode": diff --git a/cmake/deps.txt b/cmake/deps.txt index 6e504c9a7..49058600a 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -15,3 +15,5 @@ googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583e microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e onnxruntime_extensions;https://github.com/microsoft/onnxruntime-extensions.git;2c3e936cfc3401ba7ebb79d02b9e52a50439ffc3 +llguidance;https://github.com/microsoft/llguidance.git;4dc358feef3cdf0542a5f95b5f4e92761887a25d +corrosion;https://github.com/corrosion-rs/corrosion.git;64289b1d79d6d19cd2e241db515381a086bb8407 \ No newline at end of file diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 8dde244c7..aeb561c1e 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -93,3 +93,19 @@ list(APPEND EXTERNAL_LIBRARIES ocos_operators noexcep_operators ) + +if(USE_GUIDANCE) + FetchContent_Declare( + Corrosion + GIT_REPOSITORY ${DEP_URL_corrosion} + GIT_TAG ${DEP_SHA1_corrosion} + ) + onnxruntime_fetchcontent_makeavailable(Corrosion) + FetchContent_Declare( + llguidance + GIT_REPOSITORY ${DEP_URL_llguidance} + GIT_TAG ${DEP_SHA1_llguidance} + ) + onnxruntime_fetchcontent_makeavailable(llguidance) + corrosion_import_crate(MANIFEST_PATH ${llguidance_SOURCE_DIR}/parser/Cargo.toml) +endif() \ No newline at end of file diff --git a/cmake/options.cmake b/cmake/options.cmake index fcb8454bb..d57ad6ac7 100644 --- a/cmake/options.cmake +++ b/cmake/options.cmake @@ -5,6 +5,7 @@ option(USE_CUDA "Build with CUDA support" ON) option(USE_ROCM "Build with ROCm support" ON) option(USE_DML "Build with DML support" OFF) option(USE_WEBGPU "Build with WEBGPU support" ON) +option(USE_GUIDANCE "Build with guidance support" ON) # bindings option(ENABLE_JAVA "Build the Java API." 
OFF) diff --git a/src/cuda/interface.cpp b/src/cuda/interface.cpp index b225604f1..958e4a888 100644 --- a/src/cuda/interface.cpp +++ b/src/cuda/interface.cpp @@ -160,6 +160,10 @@ struct CudaInterfaceImpl : CudaInterface { cuda::LaunchHandleEOSArray(batch_logits, batch_beam_size, vocab_size, eos_token_ids, eos_token_ids_count, stream); } + void LaunchAddLogitsMask(float* batch_logits, int batch_beam_size, int vocab_size, const uint32_t* logits_mask, cudaStream_t stream) override { + cuda::LaunchAddLogitsMask(batch_logits, batch_beam_size, vocab_size, logits_mask, stream); + } + void UpdateCacheIndirectionKernelLauncher(int32_t* tgt_indir_cache, const int32_t* src_indir_cache, const int32_t* beam_ids, int batch_size, int beam_width, int input_seq_length, int max_seq_length, int current_length, cudaStream_t stream) override { cuda::UpdateCacheIndirectionKernelLauncher(tgt_indir_cache, src_indir_cache, beam_ids, batch_size, beam_width, input_seq_length, max_seq_length, current_length, stream); } diff --git a/src/cuda/interface.h b/src/cuda/interface.h index d664277cc..00cbd8fa6 100644 --- a/src/cuda/interface.h +++ b/src/cuda/interface.h @@ -36,6 +36,7 @@ struct CudaInterface : DeviceInterface { virtual void Launch_UpdateAttentionMask(int32_t* mask_data, const int32_t* old_data, int batch_beam_size, int new_kv_length, int total_length, int max_length, bool update_only, cudaStream_t stream) = 0; virtual void Launch_UpdateAttentionMask(int64_t* mask_data, const int64_t* old_data, int batch_beam_size, int new_kv_length, int total_length, int max_length, bool update_only, cudaStream_t stream) = 0; virtual void LaunchHandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_size, const int32_t* eos_token_ids, int eos_token_ids_count, cudaStream_t stream) = 0; + virtual void LaunchAddLogitsMask(float* batch_logits, int batch_beam_size, int vocab_size, const uint32_t* logits_mask, cudaStream_t stream) = 0; virtual void UpdateCacheIndirectionKernelLauncher(int32_t* tgt_indir_cache, const int32_t* src_indir_cache, const int32_t* beam_ids, int batch_size, int beam_width, int input_seq_length, int max_seq_length, int current_length, cudaStream_t stream) = 0; virtual void ReorderPastStatesKernelLauncher(void* out_buffer, const void* in_buffer, int batch_size, int num_heads, int max_length, int head_size, int chunk_size, cudaStream_t stream) = 0; virtual void LaunchCopyCrossQKSingleDecodeStep(cudaStream_t stream, float* cross_qk_buffer_data, float** qk_layer_pointers, int token_index, int batch_beam_size, int num_layers, int num_heads, int num_alignment_heads, const int* alignment_heads, int frames, int max_length) = 0; diff --git a/src/cuda/model_kernels.cu b/src/cuda/model_kernels.cu index 0eb316383..5369d44a1 100644 --- a/src/cuda/model_kernels.cu +++ b/src/cuda/model_kernels.cu @@ -100,6 +100,22 @@ void LaunchHandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_si HandleEOSArray<<<(batch_beam_size + 255) / 256, 256, 0, stream>>>(batch_logits, batch_beam_size, vocab_size, eos_token_ids, eos_token_ids_count); } +__global__ void AddLogitsMask(float* batch_logits, int batch_beam_size, int vocab_size, const uint32_t* logits_mask) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= batch_beam_size * vocab_size) + return; + int batch_index = index / vocab_size; + int vocab_index = index % vocab_size; + if (!(logits_mask[(batch_index * vocab_size + vocab_index) / 32] & (1 << (vocab_index % 32)))) + batch_logits[index] = std::numeric_limits<float>::lowest(); +} + +void
LaunchAddLogitsMask(float* batch_logits, int batch_beam_size, int vocab_size, const uint32_t* logits_mask, cudaStream_t stream) { + int block_size = 256; + int num_blocks = (batch_beam_size * vocab_size + block_size - 1) / block_size; + AddLogitsMask<<<num_blocks, block_size, 0, stream>>>(batch_logits, batch_beam_size, vocab_size, logits_mask); +} + __global__ void ConvertFp16ToFp32(const half* src, float* dst, int count) { int idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx < count) diff --git a/src/generators.cpp b/src/generators.cpp index 9415cb1db..8a8b08a7a 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -6,6 +6,7 @@ #include "models/env_utils.h" #include "models/model.h" #include "models/decoder_only.h" +#include "logits_processor.h" #include "search.h" #include "cpu/interface.h" #include "cuda/interface.h" @@ -158,6 +159,7 @@ void Launch_UpdateAttentionMask(int32_t* mask_data, const int32_t* old_ template <> void Launch_UpdateAttentionMask(int64_t* mask_data, const int64_t* old_data, int batch_beam_size, int new_kv_length, int total_length, int max_length, bool update_only, cudaStream_t stream) { GetCudaInterface()->Launch_UpdateAttentionMask(mask_data, old_data, batch_beam_size, new_kv_length, total_length, max_length, update_only, stream); } void LaunchHandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_size, const int32_t* eos_token_ids, int eos_token_ids_count, cudaStream_t stream) { GetCudaInterface()->LaunchHandleEOSArray(batch_logits, batch_beam_size, vocab_size, eos_token_ids, eos_token_ids_count, stream); } +void LaunchAddLogitsMask(float* batch_logits, int batch_beam_size, int vocab_size, const uint32_t* logits_mask, cudaStream_t stream) { GetCudaInterface()->LaunchAddLogitsMask(batch_logits, batch_beam_size, vocab_size, logits_mask, stream); } void UpdateCacheIndirectionKernelLauncher(int32_t* tgt_indir_cache, const int32_t* src_indir_cache, const int32_t* beam_ids, int batch_size, int beam_width, int input_seq_length, int max_seq_length, int current_length, cudaStream_t stream) { GetCudaInterface()->UpdateCacheIndirectionKernelLauncher(tgt_indir_cache, src_indir_cache, beam_ids, batch_size, beam_width, input_seq_length, max_seq_length, current_length, stream); } void ReorderPastStatesKernelLauncher(void* out_buffer, const void* in_buffer, int batch_size, int num_heads, int max_length, int head_size, int chunk_size, cudaStream_t stream) { GetCudaInterface()->ReorderPastStatesKernelLauncher(out_buffer, in_buffer, batch_size, num_heads, max_length, head_size, chunk_size, stream); } template <> @@ -246,6 +248,11 @@ void GeneratorParams::SetInputs(const NamedTensors& named_tensors) { } } +void GeneratorParams::SetGuidance(std::string_view type, std::string_view data) { + guidance_type = type; + guidance_data = data; +} + std::unique_ptr<Generator> CreateGenerator(const Model& model, const GeneratorParams& params) { return std::make_unique<Generator>(model, params); } @@ -269,6 +276,7 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_ search_ = CreateSearch(params); state_ = model.CreateState(search_->GetSequenceLengths(), params); // Search sequence lengths set when creating state + logits_processor_ = CreateLogitsProcessor(*state_); // Could be nullptr if no logits processor is used // Temporary solution for multimodal and whisper models if (!params.aux_input_ids.empty() && params.aux_input_ids.data() != nullptr) { AuxAppendTokens(params.aux_input_ids); @@ -333,7 +341,10 @@ void Generator::AppendTokens(cpu_span<const int32_t> input_ids) { void Generator::ComputeLogits(DeviceSpan<int32_t>
next_tokens) { if (computed_logits_) throw std::runtime_error("ComputeLogits called again without calling AppendTokens or GenerateNextToken first"); - + if (last_action_ == Action::generated && logits_processor_) { + auto next_tokens_span = next_tokens.CopyDeviceToCpu(); + logits_processor_->CommitTokens(next_tokens_span); + } auto logits = state_->Run(search_->GetSequenceLength(), next_tokens, search_->GetNextIndices()); if (g_log.enabled && g_log.model_logits) { auto& stream = Log("model_logits"); @@ -395,6 +406,10 @@ void Generator::GenerateNextToken() { search_->AppendTokens(next_tokens); ComputeLogits(next_tokens); } + if (logits_processor_) { + auto logits = GetLogits(); + logits_processor_->ProcessLogits(logits); + } computed_logits_ = false; auto& search = search_->params_->search; search_->ApplyMinLength(search.min_length); @@ -448,6 +463,9 @@ void Generator::RewindToLength(size_t new_length) { throw std::runtime_error("RewindToLength must be called with new_length=0 when batch_size > 1"); search_->RewindTo(new_length); state_->RewindTo(new_length); + if (logits_processor_) { + logits_processor_->Reset(); + } computed_logits_ = false; last_action_ = Action::rewound; } diff --git a/src/generators.h b/src/generators.h index 28a71e580..723fbb2e8 100644 --- a/src/generators.h +++ b/src/generators.h @@ -46,6 +46,7 @@ struct Model; struct State; struct Search; struct Tokenizer; +struct LogitsProcessor; template <typename T> DeviceSpan<T> WrapTensor(DeviceInterface& device, OrtValue& value) { @@ -104,6 +105,10 @@ struct GeneratorParams : std::enable_shared_from_this<GeneratorParams>, LeakChecked<GeneratorParams> void SetInputs(const NamedTensors& inputs); + std::string guidance_type; // e.g. json_schema or regex + std::string guidance_data; // e.g. rules data in json_schema or regex + void SetGuidance(std::string_view type, std::string_view data); + private: bool is_cuda_graph_enabled_{}; }; @@ -125,6 +130,7 @@ struct Generator : LeakChecked<Generator> { std::shared_ptr<const Model> model_; std::unique_ptr<State> state_; std::unique_ptr<Search> search_; + std::unique_ptr<LogitsProcessor> logits_processor_; bool computed_logits_{}; // Set to true in ComputeLogits() and false after appending a token to ensure a 1 to 1 call ratio private: diff --git a/src/logits_processor.cpp b/src/logits_processor.cpp new file mode 100644 index 000000000..3768aa93d --- /dev/null +++ b/src/logits_processor.cpp @@ -0,0 +1,242 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License.
+#include <algorithm> +#include <cstdint> +#include <fstream> +#include <memory> +#include <sstream> +#include <string> +#include <vector> + +#include "generators.h" +#if USE_GUIDANCE +#include "llguidance.h" +#endif + +#if USE_CUDA +#include "cuda/cuda_common.h" +#include "models/kernels.h" +#endif + +#include "logits_processor.h" + +namespace Generators { + +#if USE_GUIDANCE +GuidanceLogitsProcessor::GuidanceLogitsProcessor(const State& state) + : vocab_size_(state.params_->config.model.vocab_size), + eos_token_(state.params_->config.model.eos_token_id), + batch_size_(state.params_->search.batch_size), + device_type_(state.params_->device_type), + guidance_type_(state.params_->guidance_type), + guidance_data_(state.params_->guidance_data) { + if (guidance_type_.empty() || guidance_data_.empty()) { + throw std::runtime_error("Guidance type and data must be provided together"); + } + + if (guidance_type_ != "json_schema" && guidance_type_ != "regex") { + throw std::runtime_error("Unsupported guidance type: " + std::string(guidance_type_) + " (only json_schema and regex are supported)"); + } + + auto tokenize_fn = (LlgTokenizeFn) + [](const void* user_data, const uint8_t* bytes, + size_t bytes_len, uint32_t* output_tokens, size_t output_tokens_len) + -> unsigned long { + const TokenizeData* tokenize_data = reinterpret_cast<const TokenizeData*>(user_data); + auto output_ids = tokenize_partial(reinterpret_cast<const Tokenizer*>(tokenize_data->tokenizer), tokenize_data->prefix_len, bytes, bytes_len); + size_t output_size = std::min(output_tokens_len, output_ids.size()); + for (size_t i = 0; i < output_size; i++) { + output_tokens[i] = output_ids[i]; + } + return static_cast<size_t>(output_ids.size()); + }; + + auto tokenizer_path = state.params_->config.config_path.string(); + fs::path tokenizer_path_fs(tokenizer_path); + fs::path json_path(tokenizer_path_fs / kDefaultVocabFile); + std::ifstream json_file(json_path.string()); + std::stringstream json_buffer; + json_buffer << json_file.rdbuf(); + std::string json_data = json_buffer.str(); + tokenizer_ = state.model_.CreateTokenizer(); + auto prefix_len = tokenizer_->Encode(kTokenizePrefixStr).size(); + tokenize_data_ = {tokenizer_.get(), prefix_len}; + LlgTokenizerInit tokenizer_init = { + static_cast<uint32_t>(vocab_size_), // vocab_size + eos_token_, // eos_token + nullptr, // token_lens + nullptr, // token_bytes + json_data.c_str(), // tokenizer_json config data + false, // tokenize_assumes_string + tokenize_fn, // tokenize_fn + false, // use_approximate_greedy_tokenize_fn + &tokenize_data_, // user_data + }; + + char error_buf[128]; + llg_tokenizer_ = std::unique_ptr<LlgTokenizer, LlgTokenizerDeleter>(llg_new_tokenizer(&tokenizer_init, error_buf, sizeof(error_buf))); + if (!llg_tokenizer_) { + throw std::runtime_error("Error creating llg_tokenizer: " + std::string(error_buf)); + } + + llg_constraints_.resize(batch_size_); + for (int i = 0; i < batch_size_; i++) { + LlgConstraintInit constraint_init; + llg_constraint_init_set_defaults(&constraint_init, llg_tokenizer_.get()); + LlgConstraint* constraint_ptr; + if (guidance_type_ == "json_schema") { + constraint_ptr = llg_new_constraint_json(&constraint_init, guidance_data_.data()); + } else if (guidance_type_ == "regex") { + constraint_ptr = llg_new_constraint_regex(&constraint_init, guidance_data_.data()); + } else { + throw std::runtime_error("Unsupported guidance type: " + std::string(guidance_type_) + " (only json_schema and regex are supported)"); + } + if (llg_get_error(constraint_ptr) != nullptr) { + std::string error_message = llg_get_error(constraint_ptr); + llg_free_constraint(constraint_ptr); + throw std::runtime_error("Error creating
grammar: " + error_message); + } + llg_constraints_[i] = std::unique_ptr(constraint_ptr); + } + + // Compute the mask asynchronously to avoid blocking the model inference on device + mask_future_ = std::async(std::launch::async, [&]() { + return ComputeMask(); + }); + +#if USE_CUDA + if (state.params_->device_type == DeviceType::CUDA) { + cuda_logits_mask_ptr_ = state.params_->p_device->Allocate(batch_size_ * vocab_size_ / 32); + } + cuda_stream_ = state.params_->cuda_stream; +#endif +} + +std::vector> GuidanceLogitsProcessor::ComputeMask() { + std::vector> masks; + for (int batch_idx = 0; batch_idx < batch_size_; batch_idx++) { // renamed 'i' to 'batch_idx' + LlgMaskResult mask_result; + auto error = llg_compute_mask(llg_constraints_[batch_idx].get(), &mask_result); + if (error != 0) { + std::string error_message = llg_get_error(llg_constraints_[batch_idx].get()); + throw std::runtime_error("Error computing mask: " + error_message); + } + + std::vector mask; + if (mask_result.is_stop) { + // when logits processor decides to stop, we mask all tokens except the EOS token + mask = std::vector((vocab_size_ - 1) / 32 + 1, 0); + uint32_t eos_mask32 = 1 << (eos_token_ % 32); + mask[eos_token_ / 32] = eos_mask32; + } else { + mask.reserve((vocab_size_ - 1) / 32 + 1); + for (int i = 0; i < (vocab_size_ - 1) / 32 + 1; i++) { + mask.push_back(mask_result.sample_mask[i]); + } + } + masks.push_back(mask); + } + return masks; +} + +void GuidanceLogitsProcessor::CommitTokens(std::span tokens) { + for (int i = 0; i < batch_size_; i++) { + LlgCommitResult commit_result; + auto error = llg_commit_token(llg_constraints_[i].get(), static_cast(tokens[i]), &commit_result); + if (error != 0) { + std::string error_message = llg_get_error(llg_constraints_[i].get()); + throw std::runtime_error("Error committing tokens: " + error_message); + } + } + mask_future_ = std::async(std::launch::async, [&]() { + return ComputeMask(); + }); + masks_.clear(); +} + +std::vector> GuidanceLogitsProcessor::GetMask() { + if (masks_.empty()) { + masks_ = mask_future_.get(); + } + return masks_; +} + +void GuidanceLogitsProcessor::ProcessLogits(DeviceSpan logits) { + auto masks = GetMask(); + +#if USE_CUDA + if (device_type_ == DeviceType::CUDA) { + for (int i = 0; i < static_cast(masks.size()); i++) { + cudaMemcpyAsync(cuda_logits_mask_ptr_.Span().data() + (i * vocab_size_ / 32), masks.at(i).data(), + static_cast(masks.at(i).size() * sizeof(uint32_t)), ::cudaMemcpyHostToDevice, cuda_stream_); + } + cuda::LaunchAddLogitsMask(logits.Span().data(), batch_size_, vocab_size_, cuda_logits_mask_ptr_.Span().data(), cuda_stream_); + return; + } +#endif + size_t vocab_index = 0; + + auto logits_span = logits.CpuSpan(); + for (int index = 0; index < batch_size_; index++) { + auto subspan = logits_span.subspan(vocab_index, vocab_size_); + auto& mask = masks[index]; + for (size_t i = 0; i < vocab_size_; i++) { + // mask is a 32-bit integer, where each bit corresponds to a token in the vocabulary. + // If the bit is set, the corresponding token is masked (i.e., its logit is set to the lowest possible value). + subspan[i] = mask[i / 32] & (1 << (i % 32)) ? 
subspan[i] : std::numeric_limits<float>::lowest(); + } + vocab_index += vocab_size_; + } +} + +void GuidanceLogitsProcessor::Reset() { + masks_.clear(); + llg_constraints_.clear(); + llg_constraints_.resize(batch_size_); + for (int i = 0; i < batch_size_; i++) { + LlgConstraintInit constraint_init; + llg_constraint_init_set_defaults(&constraint_init, llg_tokenizer_.get()); + LlgConstraint* constraint_ptr; + if (guidance_type_ == "json_schema") { + constraint_ptr = llg_new_constraint_json(&constraint_init, guidance_data_.data()); + } else if (guidance_type_ == "regex") { + constraint_ptr = llg_new_constraint_regex(&constraint_init, guidance_data_.data()); + } else { + throw std::runtime_error("Unsupported guidance type: " + std::string(guidance_type_) + " (only json_schema and regex are supported)"); + } + if (llg_get_error(constraint_ptr) != nullptr) { + std::string error_message = llg_get_error(constraint_ptr); + llg_free_constraint(constraint_ptr); + throw std::runtime_error("Error creating grammar: " + error_message); + } + llg_constraints_[i] = std::unique_ptr<LlgConstraint, LlgConstraintDeleter>(constraint_ptr); + } + + mask_future_ = std::async(std::launch::async, [&]() { + return ComputeMask(); + }); +} + +std::vector<int32_t> GuidanceLogitsProcessor::tokenize_partial(const Tokenizer* tokenizer, const size_t prefix_len, + const uint8_t* bytes, size_t bytes_len) { + // add a prefix before tokenizing partial input; this produces more stable token ids + std::string input_string = kTokenizePrefixStr; + input_string.reserve(bytes_len + 2); + for (size_t i = 0; i < bytes_len; i++) { + input_string.push_back(bytes[i]); + } + std::vector<int32_t> output_ids = tokenizer->Encode(input_string.c_str()); + return std::vector<int32_t>(output_ids.begin() + prefix_len, output_ids.end()); +} + +#endif + +std::unique_ptr<LogitsProcessor> CreateLogitsProcessor(const State& state) { + if (!state.params_->guidance_type.empty() && !state.params_->guidance_data.empty()) { +#if USE_GUIDANCE + return std::make_unique<GuidanceLogitsProcessor>(state); +#endif + Log("warning", "No supported LogitsProcessor found. e.g. to use guidance, build with use_guidance=true"); + } + return nullptr; +} +} // namespace Generators diff --git a/src/logits_processor.h b/src/logits_processor.h new file mode 100644 index 000000000..66f08e779 --- /dev/null +++ b/src/logits_processor.h @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License.
+#pragma once + +#include <cstdint> +#include <future> +#include <memory> +#include <span> +#include <string_view> +#include <vector> + +#if USE_GUIDANCE +#include <llguidance.h> +#endif + +#include "models/model.h" + +namespace Generators { + +struct LogitsProcessor { + LogitsProcessor() = default; + virtual ~LogitsProcessor() = default; + // CommitTokens is used to commit the generated tokens to the logits processor + virtual void CommitTokens(std::span<int32_t> tokens) = 0; + // ProcessLogits is used to apply the logits mask to the logits + virtual void ProcessLogits(DeviceSpan<float> logits) = 0; + // Reset is used to reset the logits processor after rewinding + virtual void Reset() = 0; +}; + +#if USE_GUIDANCE +struct LlgConstraintDeleter { + void operator()(LlgConstraint* lc) const { + llg_free_constraint(lc); + } +}; + +struct LlgTokenizerDeleter { + void operator()(LlgTokenizer* lt) const { + llg_free_tokenizer(lt); + } +}; + +struct GuidanceLogitsProcessor : public LogitsProcessor { + // llguidance needs tokenizer.json to add special tokens + static constexpr const char* kDefaultVocabFile = "tokenizer.json"; + // the tokenizer needs to tokenize tokens with a special prefix + static constexpr const char* kTokenizePrefixStr = "\x02"; + + GuidanceLogitsProcessor(const State& state); + void ProcessLogits(DeviceSpan<float> logits) override; + void CommitTokens(std::span<int32_t> tokens) override; + void Reset() override; + // GetMask is used to get the logits mask + std::vector<std::vector<uint32_t>> GetMask(); + // tokenize_partial tokenizes the input bytes with a special prefix; this produces stable + // token ids. + static std::vector<int32_t> tokenize_partial(const Tokenizer* tokenizer, const size_t prefix_len, + const uint8_t* bytes, size_t bytes_len); + + private: + std::vector<std::vector<uint32_t>> ComputeMask(); + + int vocab_size_; + uint32_t eos_token_; + int batch_size_; + DeviceType device_type_; + std::string_view guidance_type_; + std::string_view guidance_data_; + std::vector<std::vector<uint32_t>> masks_; + std::vector<std::unique_ptr<LlgConstraint, LlgConstraintDeleter>> llg_constraints_; + std::unique_ptr<LlgTokenizer, LlgTokenizerDeleter> llg_tokenizer_; + std::shared_ptr<Tokenizer> tokenizer_; + + std::future<std::vector<std::vector<uint32_t>>> mask_future_; + std::vector<std::vector<uint32_t>> logits_masks_; + +#if USE_CUDA + DeviceSpan<uint32_t> cuda_logits_mask_ptr_; + cudaStream_t cuda_stream_; +#endif + + struct TokenizeData { + Tokenizer* tokenizer; + size_t prefix_len; + }; + TokenizeData tokenize_data_; +}; +#endif + +std::unique_ptr<LogitsProcessor> CreateLogitsProcessor(const State& state); + +} // namespace Generators diff --git a/src/models/kernels.h b/src/models/kernels.h index fece7dadf..d6c2ebcc7 100644 --- a/src/models/kernels.h +++ b/src/models/kernels.h @@ -11,6 +11,7 @@ template <typename T> void Launch_UpdateAttentionMask(T* mask_data, const T* old_data, int batch_beam_size, int new_kv_length, int total_length, int max_length, bool update_only, cudaStream_t stream); void LaunchHandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_size, const int32_t* eos_token_ids, int eos_token_ids_count, cudaStream_t stream); +void LaunchAddLogitsMask(float* batch_logits, int batch_beam_size, int vocab_size, const uint32_t* logits_mask, cudaStream_t stream); void LaunchFp16ToFp32(const uint16_t* fp16, float* fp32, int count, cudaStream_t stream); void LaunchFp32ToFp16(const float* fp32, uint16_t* fp16, int count, cudaStream_t stream); diff --git a/src/ort_genai.h b/src/ort_genai.h index 8b4a026d5..93366ec39 100644 --- a/src/ort_genai.h +++ b/src/ort_genai.h @@ -262,6 +262,10 @@ struct OgaGeneratorParams : OgaAbstract { OgaCheckResult(OgaGeneratorParamsTryGraphCaptureWithMaxBatchSize(this, max_batch_size)); } + void SetGuidance(const char* type, const char* data) { +
OgaCheckResult(OgaGeneratorParamsSetGuidance(this, type, data)); + } + static void operator delete(void* p) { OgaDestroyGeneratorParams(reinterpret_cast<OgaGeneratorParams*>(p)); } }; diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index 3f2c9d750..6d8bfbca5 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -7,6 +7,7 @@ #include "span.h" #include "ort_genai_c.h" #include "generators.h" +#include "logits_processor.h" #include "models/model.h" #include "runtime_settings.h" #include "search.h" @@ -261,6 +262,14 @@ OgaResult* OGA_API_CALL OgaGeneratorParamsSetWhisperInputFeatures(OgaGeneratorPa OGA_CATCH } +OgaResult* OGA_API_CALL OgaGeneratorParamsSetGuidance(OgaGeneratorParams* oga_params, const char* type, const char* data) { + OGA_TRY + auto& params = *reinterpret_cast<Generators::GeneratorParams*>(oga_params); + params.SetGuidance(type, data); + return nullptr; + OGA_CATCH +} + OgaResult* OgaCreateGenerator(const OgaModel* model, const OgaGeneratorParams* generator_params, OgaGenerator** out) { OGA_TRY *out = reinterpret_cast<OgaGenerator*>(CreateGenerator(*reinterpret_cast<const Generators::Model*>(model), *reinterpret_cast<const Generators::GeneratorParams*>(generator_params)).release()); diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h index b970ecbaf..31874b8c8 100644 --- a/src/ort_genai_c.h +++ b/src/ort_genai_c.h @@ -287,6 +287,8 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetModelInput(OgaGeneratorP OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetWhisperInputFeatures(OgaGeneratorParams*, OgaTensor* tensor); +OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetGuidance(OgaGeneratorParams*, const char* type, const char* data); + /** * \brief Creates a generator from the given model and generator params. * \param[in] model The model to use for generation. diff --git a/src/python/python.cpp b/src/python/python.cpp index 4ee8f87ed..2ac7ce877 100644 --- a/src/python/python.cpp +++ b/src/python/python.cpp @@ -8,6 +8,7 @@ #include "../ort_genai.h" #include "../json.h" #include "../search.h" +#include "../logits_processor.h" #include "../models/model.h" #include "../logging.h" #include "../smartptrs.h" @@ -265,6 +266,9 @@ struct PyGeneratorParams { params_->TryGraphCapture(max_batch_size.cast<int>()); } + void SetGuidance(const std::string& type, const std::string& data) { + params_->SetGuidance(type, data); + } pybind11::array py_whisper_input_features_; pybind11::array py_alignment_heads_; @@ -396,7 +400,8 @@ PYBIND11_MODULE(onnxruntime_genai, m) { .def("set_model_input", &PyGeneratorParams::SetModelInput) .def("set_search_options", &PyGeneratorParams::SetSearchOptions) // See config.h 'struct Search' for the options .def("try_use_cuda_graph_with_max_batch_size", &PyGeneratorParams::TryUseCudaGraphWithMaxBatchSize) // will be deprecated - .def("try_graph_capture_with_max_batch_size", &PyGeneratorParams::TryGraphCaptureWithMaxBatchSize); + .def("try_graph_capture_with_max_batch_size", &PyGeneratorParams::TryGraphCaptureWithMaxBatchSize) + .def("set_guidance", &PyGeneratorParams::SetGuidance); pybind11::class_<TokenizerStream>(m, "TokenizerStream") .def("decode", [](TokenizerStream& t, int32_t token) { return t.Decode(token); }); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 75070cbd7..40ef5e15e 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -13,6 +13,7 @@ add_executable(unit_tests model_tests.cpp sampling_tests.cpp sampling_benchmark.cpp + logits_processor_tests.cpp ) target_include_directories(unit_tests PRIVATE @@ -21,6 +22,7 @@ target_include_directories(unit_tests PRIVATE ) target_link_directories(unit_tests PRIVATE ${ORT_LIB_DIR}) +
target_link_libraries(unit_tests PRIVATE onnxruntime-genai-static GTest::gtest_main @@ -56,3 +58,4 @@ if (NOT MSVC) endif() include(GoogleTest) + diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index e84d25855..485ea5656 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -7,6 +7,7 @@ #include "../src/span.h" #include #include +#include <regex> #ifndef MODEL_PATH #define MODEL_PATH "../../test/test_models/" #endif @@ -743,3 +744,31 @@ TEST(CAPITests, RewindGptFp32CAPI) { expected_output_start = &expected_output[0]; EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence_data, sequence_length * sizeof(int32_t))); } + +#if USE_GUIDANCE +TEST(CAPITests, SetGuidance) { +#if TEST_PHI2 + + auto model = OgaModel::Create(PHI2_PATH); + auto tokenizer = OgaTokenizer::Create(*model); + auto tokenizer_stream = OgaTokenizerStream::Create(*tokenizer); + + const char* input_string = "who are you?"; + auto input_sequences = OgaSequences::Create(); + tokenizer->Encode(input_string, *input_sequences); + auto params = OgaGeneratorParams::Create(*model); + params->SetSearchOption("max_length", 32); + params->SetGuidance("regex", "answer: .*"); + + auto generator = OgaGenerator::Create(*model, *params); + generator->AppendTokenSequences(*input_sequences); + while (!generator->IsDone()) { + generator->GenerateNextToken(); + } + auto out_string = tokenizer->Decode(generator->GetSequenceData(0), generator->GetSequenceCount(0)); + auto output = std::string(out_string).substr(std::string(input_string).size()); + EXPECT_TRUE(std::regex_match(output, std::regex("answer: .*"))); + +#endif +} +#endif \ No newline at end of file diff --git a/test/logits_processor_tests.cpp b/test/logits_processor_tests.cpp new file mode 100644 index 000000000..b2aef4826 --- /dev/null +++ b/test/logits_processor_tests.cpp @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License.
+ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef MODEL_PATH +#define MODEL_PATH "../../test/test_models/" +#endif +#ifndef PHI2_PATH +#if USE_CUDA +#define PHI2_PATH MODEL_PATH "phi-2/int4/cuda" +#else +#define PHI2_PATH MODEL_PATH "phi-2/int4/cpu" +#endif +#endif +#ifndef SCHEMA_PATH +#define SCHEMA_PATH MODEL_PATH "grammars/blog.schema.json" +#endif + +std::string read_file(const char* filePath) { + std::ifstream file(filePath); + std::stringstream buffer; + buffer << file.rdbuf(); + return buffer.str(); +} +#if USE_GUIDANCE +TEST(LogitsProcessorTests, TestRegex) { + std::string regex = "answer: .*"; + std::string text = "answer: I am a robot"; + auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32"); + auto tokenizer = model->CreateTokenizer(); + auto params = Generators::CreateGeneratorParams(*model); + params->SetGuidance("regex", regex); + auto generator = Generators::CreateGenerator(*model, *params); + auto processor = std::make_unique(*generator->state_); + auto target_ids = Generators::GuidanceLogitsProcessor::tokenize_partial(tokenizer.get(), tokenizer->Encode(Generators::GuidanceLogitsProcessor::kTokenizePrefixStr).size(), + reinterpret_cast(text.c_str()), text.size()); + for (auto id : target_ids) { + auto mask = processor->GetMask(); + auto tokens = std::vector{static_cast(id)}; + processor->CommitTokens(std::span(tokens)); + } +} + +TEST(LogitsProcessorTests, TestJsonSchema) { + std::string json_schema = read_file(MODEL_PATH "grammars/blog.schema.json"); + std::string text = read_file(MODEL_PATH "grammars/blog.sample.json"); + auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32"); + + auto tokenizer = model->CreateTokenizer(); + auto params = Generators::CreateGeneratorParams(*model); + params->SetGuidance("json_schema", json_schema); + auto generator = Generators::CreateGenerator(*model, *params); + auto processor = std::make_unique(*generator->state_); + auto target_ids = Generators::GuidanceLogitsProcessor::tokenize_partial(tokenizer.get(), tokenizer->Encode(Generators::GuidanceLogitsProcessor::kTokenizePrefixStr).size(), + reinterpret_cast(text.c_str()), text.size()); + for (auto id : target_ids) { + auto mask = processor->GetMask(); + auto tokens = std::vector{static_cast(id)}; + processor->CommitTokens(std::span(tokens)); + } +} + +#endif \ No newline at end of file diff --git a/test/model_tests.cpp b/test/model_tests.cpp index 321d1ac46..35e4b9b53 100644 --- a/test/model_tests.cpp +++ b/test/model_tests.cpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include diff --git a/test/sampling_benchmark.cpp b/test/sampling_benchmark.cpp index eb7f04cd9..cc1b151e6 100644 --- a/test/sampling_benchmark.cpp +++ b/test/sampling_benchmark.cpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include diff --git a/test/sampling_tests.cpp b/test/sampling_tests.cpp index a71910b15..6d95875f1 100644 --- a/test/sampling_tests.cpp +++ b/test/sampling_tests.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include diff --git a/test/test_models/grammars/blog.sample.json b/test/test_models/grammars/blog.sample.json new file mode 100644 index 000000000..592499a6e --- /dev/null +++ b/test/test_models/grammars/blog.sample.json @@ -0,0 +1,20 @@ +{ + "title": "New Blog Post", + "content": "This is the content of the blog post...", + "publishedDate": 
"2023-08-25T15:00:00Z", + "author": { + "username": "authoruser", + "email": "author@example.com", + "fullName": "Author User", + "age": 30, + "location": "Earth", + "interests": [ + "Technology", + "Foo" + ] + }, + "tags": [ + "Technology", + "Programming" + ] +} \ No newline at end of file diff --git a/test/test_models/grammars/blog.schema.json b/test/test_models/grammars/blog.schema.json new file mode 100644 index 000000000..11e042c29 --- /dev/null +++ b/test/test_models/grammars/blog.schema.json @@ -0,0 +1,54 @@ +{ + "description": "A representation of a blog post", + "type": "object", + "required": [ + "title", + "content", + "author" + ], + "additionalProperties": false, + "properties": { + "title": { + "type": "string" + }, + "content": { + "type": "string" + }, + "publishedDate": { + "type": "string" + }, + "author": { + "type": "object", + "properties": { + "username": { + "type": "string" + }, + "email": { + "type": "string" + }, + "fullName": { + "type": "string" + }, + "age": { + "type": "integer" + }, + "location": { + "type": "string" + }, + "interests": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "additionalProperties": false + }, + "tags": { + "type": "array", + "items": { + "type": "string" + } + } + } +} \ No newline at end of file diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile index efb417e51..a1405b097 100644 --- a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile @@ -6,6 +6,9 @@ FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aar ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts +ENV PATH="/usr/.cargo/bin:$PATH" +ENV RUSTUP_HOME="/usr/.rustup" +ENV CARGO_HOME="/usr/.cargo" ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev RUN adduser --uid $BUILD_UID $BUILD_USER diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh index efbe3ef40..9718081d7 100755 --- a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh +++ b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh @@ -65,5 +65,12 @@ fi GetFile https://nodejs.org/dist/v18.17.1/node-v18.17.1-linux-${NODEJS_ARCH}.tar.gz /tmp/src/node-v18.17.1-linux-${NODEJS_ARCH}.tar.gz tar --strip 1 -xf /tmp/src/node-v18.17.1-linux-${NODEJS_ARCH}.tar.gz -C /usr +# Install Rust +export RUSTUP_HOME=/usr/.rustup +export CARGO_HOME=/usr/.cargo +curl --proto '=https' --tlsv1.2 https://sh.rustup.rs -sSf | sh -s -- -y --default-toolchain=1.82.0 +chmod -R 777 /usr/.rustup +chmod -R 777 /usr/.cargo + cd / rm -rf /tmp/src diff --git a/tools/ci_build/github/linux/docker/manylinux/Dockerfile.manylinux2_28_cpu b/tools/ci_build/github/linux/docker/manylinux/Dockerfile.manylinux2_28_cpu index 6af5a8a21..af71099df 100644 --- a/tools/ci_build/github/linux/docker/manylinux/Dockerfile.manylinux2_28_cpu +++ b/tools/ci_build/github/linux/docker/manylinux/Dockerfile.manylinux2_28_cpu @@ -2,7 +2,9 @@ FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64 ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_centos_gcc12.sh && /tmp/scripts/install_deps.sh 
&& rm -rf /tmp/scripts - +ENV PATH="/usr/.cargo/bin:$PATH" +ENV RUSTUP_HOME="/usr/.rustup" +ENV CARGO_HOME="/usr/.cargo" ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev RUN adduser --uid $BUILD_UID $BUILD_USER diff --git a/tools/ci_build/github/linux/docker/manylinux/Dockerfile.manylinux2_28_cuda_11.8 b/tools/ci_build/github/linux/docker/manylinux/Dockerfile.manylinux2_28_cuda_11.8 index 6df955c02..673ba0300 100644 --- a/tools/ci_build/github/linux/docker/manylinux/Dockerfile.manylinux2_28_cuda_11.8 +++ b/tools/ci_build/github/linux/docker/manylinux/Dockerfile.manylinux2_28_cuda_11.8 @@ -10,7 +10,9 @@ else \ echo "Using default gcc because CUDA version is less than 12"; \ cd /tmp/scripts && /tmp/scripts/install_centos.sh && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts; \ fi - +ENV PATH="/usr/.cargo/bin:$PATH" +ENV RUSTUP_HOME="/usr/.rustup" +ENV CARGO_HOME="/usr/.cargo" ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev RUN adduser --uid $BUILD_UID $BUILD_USER diff --git a/tools/ci_build/github/linux/docker/manylinux/Dockerfile.manylinux2_28_cuda_12.2 b/tools/ci_build/github/linux/docker/manylinux/Dockerfile.manylinux2_28_cuda_12.2 index c865579bd..add1e4ef2 100644 --- a/tools/ci_build/github/linux/docker/manylinux/Dockerfile.manylinux2_28_cuda_12.2 +++ b/tools/ci_build/github/linux/docker/manylinux/Dockerfile.manylinux2_28_cuda_12.2 @@ -10,6 +10,9 @@ else \ echo "Using default gcc because CUDA version is less than 12"; \ cd /tmp/scripts && /tmp/scripts/install_centos.sh && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts; \ fi +ENV PATH="/usr/.cargo/bin:$PATH" +ENV RUSTUP_HOME="/usr/.rustup" +ENV CARGO_HOME="/usr/.cargo" ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev RUN adduser --uid $BUILD_UID $BUILD_USER diff --git a/tools/ci_build/github/linux/docker/manylinux/Dockerfile.manylinux2_28_rocm b/tools/ci_build/github/linux/docker/manylinux/Dockerfile.manylinux2_28_rocm index 6af5a8a21..af71099df 100644 --- a/tools/ci_build/github/linux/docker/manylinux/Dockerfile.manylinux2_28_rocm +++ b/tools/ci_build/github/linux/docker/manylinux/Dockerfile.manylinux2_28_rocm @@ -2,7 +2,9 @@ FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64 ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_centos_gcc12.sh && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts - +ENV PATH="/usr/.cargo/bin:$PATH" +ENV RUSTUP_HOME="/usr/.rustup" +ENV CARGO_HOME="/usr/.cargo" ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev RUN adduser --uid $BUILD_UID $BUILD_USER diff --git a/tools/ci_build/github/linux/docker/manylinux/scripts/install_deps.sh b/tools/ci_build/github/linux/docker/manylinux/scripts/install_deps.sh index 2a4fd31f2..fa488debe 100755 --- a/tools/ci_build/github/linux/docker/manylinux/scripts/install_deps.sh +++ b/tools/ci_build/github/linux/docker/manylinux/scripts/install_deps.sh @@ -80,5 +80,12 @@ cmake --build build-cmake mv ./build-cmake/ninja /usr/bin popd +# Install Rust +export RUSTUP_HOME=/usr/.rustup +export CARGO_HOME=/usr/.cargo +curl --proto '=https' --tlsv1.2 https://sh.rustup.rs -sSf | sh -s -- -y --default-toolchain=1.82.0 +chmod -R 777 /usr/.rustup +chmod -R 777 /usr/.cargo + cd / rm -rf /tmp/src
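
For reference, a minimal sketch of how the guidance support added in this PR might be exercised from Python once a wheel built with --use_guidance is installed. The model directory and prompt below are placeholders, and the surrounding generation-loop calls assume the existing onnxruntime_genai Python API; only set_guidance is introduced by this change, mirroring the C API test above.

import onnxruntime_genai as og

model = og.Model("path/to/model")  # placeholder model directory
tokenizer = og.Tokenizer(model)

params = og.GeneratorParams(model)
params.set_search_options(max_length=64)
# Constrain decoding with the new binding; a ("json_schema", schema_string) pair works the same way.
params.set_guidance("regex", "answer: .*")

generator = og.Generator(model, params)
generator.append_tokens(tokenizer.encode("who are you?"))
while not generator.is_done():
    generator.generate_next_token()
print(tokenizer.decode(generator.get_sequence(0)))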