Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CoreML] support coreml model cache #23065

Merged
merged 36 commits into from
Dec 31, 2024
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
1d1c874
support coreml model cache
Dec 10, 2024
b7888c4
improve
wejoncy Dec 11, 2024
f492fee
fix
wejoncy Dec 11, 2024
7b11848
better hash
Dec 16, 2024
4a5772f
refactor output -path
Dec 16, 2024
b57aa28
address comments
Dec 16, 2024
723b2dd
remove extra check
Dec 16, 2024
4f0ac2a
Apply suggestions from code review
wejoncy Dec 16, 2024
781e42e
improve doc
Dec 16, 2024
26775b4
typo
Dec 16, 2024
89317c6
check cache-key
Dec 16, 2024
773dce0
validate alpha-number
Dec 16, 2024
e82f3e4
fix
Dec 16, 2024
d3d25b9
format
Dec 16, 2024
d053dc0
fix bug
Dec 16, 2024
2779e3d
format
Dec 16, 2024
c7194ad
renaming
Dec 17, 2024
8204e64
max 64 chars
Dec 17, 2024
9c9374c
polish cache path
Dec 18, 2024
8faf178
fix
Dec 18, 2024
e4e3547
Update include/onnxruntime/core/providers/coreml/coreml_provider_fact…
wejoncy Dec 19, 2024
728fbee
Update include/onnxruntime/core/providers/coreml/coreml_provider_fact…
wejoncy Dec 19, 2024
e49112c
Update include/onnxruntime/core/providers/coreml/coreml_provider_fact…
wejoncy Dec 19, 2024
d7b867c
disable caching in runtime.
Dec 19, 2024
7c466a1
Apply suggestions from code review
wejoncy Dec 20, 2024
d1e7633
address comments
Dec 20, 2024
a5ffe03
fix
Dec 20, 2024
5518e38
format
Dec 20, 2024
70075e5
format
Dec 20, 2024
dc52361
Update onnxruntime/core/providers/coreml/coreml_execution_provider.cc
wejoncy Dec 24, 2024
31e6c68
add test for model cache
Dec 24, 2024
78b2a4b
ut
Dec 24, 2024
5f7bddc
ut
Dec 24, 2024
f9db65f
lint
Dec 24, 2024
391f914
Merge branch 'main' into jicwen/coreml_cache
wejoncy Dec 30, 2024
201368d
fix test
wejoncy Dec 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,22 @@ static const char* const kCoremlProviderOption_SpecializationStrategy = "Special
static const char* const kCoremlProviderOption_ProfileComputePlan = "ProfileComputePlan";
// please refer to https://developer.apple.com/documentation/coreml/mlmodelconfiguration/allowlowprecisionaccumulationongpu
static const char* const kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU = "AllowLowPrecisionAccumulationOnGPU";
// Specify the directory to cache any CoreML models created from the ONNX model in.
// CoreML EP will convert onnx subgraph to CoreML model and save to disk.
// If this path is not specified, the model will be saved to a temp directory and deleted after the session is closed.
// otherwise, the model will be saved to the specified path and User should manage to delete the model.

// we do NOT detect if the onnx model has changed and no longer matches the cached model.
// the user should carefully manage the cache if modifying/replacing a model.
// The cache key is generated by
// 1. User provided key in metadata_props if found (preferred)
// 2. Hash of the model url the inference session was created with
// 3. Hash of the input/output names of the model
// Please find out how to set metadata_props in the onnxruntime API documentation. https://onnxruntime.ai/docs/execution-providers/CoreML-ExecutionProvider.html#configuration-options
static const char* const kCoremlProviderOption_ModelCacheDirectory = "ModelCacheDirectory";

// User provided cache-key in metadata_props.
static const char* const kCOREML_CACHE_KEY = "COREML_CACHE_KEY";

#ifdef __cplusplus
extern "C" {
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/platform/env.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ class Env {
#ifdef _WIN32
/// \brief Returns true if the directory exists.
virtual bool FolderExists(const std::wstring& path) const = 0;
virtual bool FileExists(const std::wstring& path) const = 0;
/// \brief Recursively creates the directory, if it doesn't exist.
virtual common::Status CreateFolder(const std::wstring& path) const = 0;
// Mainly for use with protobuf library
Expand All @@ -206,6 +207,7 @@ class Env {
#endif
/// \brief Returns true if the directory exists.
virtual bool FolderExists(const std::string& path) const = 0;
virtual bool FileExists(const std::string& path) const = 0;
/// \brief Recursively creates the directory, if it doesn't exist.
virtual common::Status CreateFolder(const std::string& path) const = 0;
// Recursively deletes the directory and its contents.
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/core/platform/posix/env.cc
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,14 @@ class PosixEnv : public Env {
return S_ISDIR(sb.st_mode);
}

bool FileExists(const std::string& path) const override {
struct stat sb;
if (stat(path.c_str(), &sb)) {
return false;
}
return S_ISREG(sb.st_mode);
}

common::Status CreateFolder(const std::string& path) const override {
size_t pos = 0;
do {
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/platform/windows/env.cc
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,16 @@ bool WindowsEnv::FolderExists(const std::string& path) const {
return (attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY);
}

bool WindowsEnv::FileExists(const std::wstring& path) const {
DWORD attributes = GetFileAttributesW(path.c_str());
return (attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_NORMAL);
}

bool WindowsEnv::FileExists(const std::string& path) const {
DWORD attributes = GetFileAttributesA(path.c_str());
return (attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_NORMAL);
}

common::Status WindowsEnv::CreateFolder(const std::wstring& path) const {
size_t pos = 0;
do {
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/platform/windows/env.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ class WindowsEnv : public Env {
MappedMemoryPtr& mapped_memory) const override;
bool FolderExists(const std::wstring& path) const override;
bool FolderExists(const std::string& path) const override;
bool FileExists(const std::wstring& path) const override;
bool FileExists(const std::string& path) const override;
common::Status CreateFolder(const std::wstring& path) const override;
common::Status CreateFolder(const std::string& path) const override;
common::Status DeleteFolder(const PathString& path) const override;
Expand Down
95 changes: 86 additions & 9 deletions onnxruntime/core/providers/coreml/builders/model_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -390,13 +390,69 @@ void CreateEmptyFile(const std::string& filename) {

#endif // defined(COREML_ENABLE_MLPROGRAM)

std::string GetModelOutputPath(bool create_ml_program) {
// path is used to create the ML Package directory for ML Program, and for the model directly otherwise.
auto path = util::GetTemporaryFilePath();
if (!create_ml_program) {
path += ".model.mlmodel";
}
std::string GetModelOutputPath(const CoreMLOptions& coreml_options,
const GraphViewer& graph_viewer,
const logging::Logger& logger) {
const std::string& subgraph_name = graph_viewer.Name();
std::string path;
if (coreml_options.ModelCacheDirectory().empty()) {
// path is used to create the ML Package directory for ML Program, and for the model directly otherwise.
path = util::GetTemporaryFilePath();
if (!coreml_options.CreateMLProgram()) {
path += ".model.mlmodel";
}
} else {
// subgraph_name is uniquely generated by
// onnxruntime/core/providers/coreml/coreml_execution_provider.cc::gen_metadef_name
// int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash);
// MakeString(user_provide_key, "_", COREML, "_", model_hash, "_", metadef_id);
std::string_view cache_key = std::string_view(subgraph_name)
.substr(0, subgraph_name.find_first_of("_"));
skottmckay marked this conversation as resolved.
Show resolved Hide resolved
// subgraph_short_name is metadef_id
std::string_view subgraph_short_name = std::string_view(subgraph_name)
.substr(subgraph_name.find_last_of("_") + 1);
path = MakeString(std::string(coreml_options.ModelCacheDirectory()), "/", cache_key);
if (!Env::Default().CreateFolder(path).IsOK()) {
LOGS(logger, ERROR) << "Failed to create cache directory " << path << ". Model caching is disabled.";
coreml_options.DisableModelCache();
return GetModelOutputPath(coreml_options, graph_viewer, logger);
}
// Write the model path to a file in the cache directory.
// This is for developers to know what the cached model is as we used a hash for the directory name.
if (!Env::Default().FileExists(ToPathString(path + "/model.txt"))) {
wejoncy marked this conversation as resolved.
Show resolved Hide resolved
const Graph* main_graph = &graph_viewer.GetGraph();
while (main_graph->IsSubgraph()) {
main_graph = main_graph->ParentGraph();
}
std::ofstream file(path + "/model.txt");
if (!file.is_open()) {
LOGS(logger, ERROR) << "Failed to open file " << path + "/model.txt";
} else {
file << main_graph->ModelPath().string();
file.close();
}
}

path = MakeString(path, "/", subgraph_short_name);
// Set the model cache path with setting of RequireStaticShape and ModelFormat
if (coreml_options.RequireStaticShape()) {
path += "_static";
} else {
path += "_dynamic";
}

if (coreml_options.CreateMLProgram()) {
path += "_mlprogram";
} else {
path += "_nn";
}
if (!Env::Default().CreateFolder(path).IsOK()) {
LOGS(logger, ERROR) << "Failed to create cache directory " << path << ". Model caching is disabled.";
coreml_options.DisableModelCache();
return GetModelOutputPath(coreml_options, graph_viewer, logger);
}
path += "/model";
}
return path;
}
} // namespace
Expand All @@ -410,10 +466,21 @@ ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logge
coreml_version_(coreml_version),
coreml_options_(coreml_options),
create_ml_program_(coreml_options.CreateMLProgram()),
model_output_path_(GetModelOutputPath(create_ml_program_)),
model_output_path_(GetModelOutputPath(coreml_options_, graph_viewer, logger)), // coreml_options_ must be set before this
onnx_input_names_(std::move(onnx_input_names)),
onnx_output_names_(std::move(onnx_output_names)),
coreml_model_(std::make_unique<CoreML::Specification::Model>()) {
// GetTemporaryFilePath() always produce a unique path for the model and this is not existed
// Mlprogram will create a folder while NN create a file
if (Env::Default().FolderExists(ToPathString(model_output_path_)) ||
Env::Default().FileExists(ToPathString(model_output_path_))) {
is_model_cached_ = true;
LOGS(logger, INFO) << "Model is already cached in " << model_output_path_
<< " and will be reused. If you want to update the model or hit other issues, "
<< "please consider to clear the cache and retry.";
return;
}

if (create_ml_program_) {
#if defined(COREML_ENABLE_MLPROGRAM)
coreml_model_->set_specificationversion(CoreMLSpecVersion());
Expand Down Expand Up @@ -847,6 +914,10 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i

input_output_info_.emplace(name, OnnxTensorInfo{data_type, shape});

if (IsModelCached()) {
return Status::OK();
}

#if defined(COREML_ENABLE_MLPROGRAM)
if (create_ml_program_) {
if (is_input) {
Expand Down Expand Up @@ -1056,8 +1127,14 @@ Status ModelBuilder::Build(const GraphViewer& graph_viewer, const logging::Logge
ModelBuilder builder(graph_viewer, logger, coreml_version, coreml_options,
std::move(onnx_input_names), std::move(onnx_output_names));

ORT_RETURN_IF_ERROR(builder.CreateModel());
ORT_RETURN_IF_ERROR(builder.SaveModel());
if (!builder.IsModelCached()) {
ORT_RETURN_IF_ERROR(builder.CreateModel());
ORT_RETURN_IF_ERROR(builder.SaveModel());
} else {
// runtime requires the input/output names to be passed
ORT_RETURN_IF_ERROR(builder.RegisterModelInputs());
ORT_RETURN_IF_ERROR(builder.RegisterModelOutputs());
}

return builder.LoadModel(model);
}
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/coreml/builders/model_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class ModelBuilder {
// We only support CoreML 3 and later so the spec version is always version + 1.
int32_t CoreMLVersion() const { return coreml_version_; }
int32_t CoreMLSpecVersion() const { return coreml_version_ + 1; }
bool IsModelCached() const { return is_model_cached_; }

// Returns true if we are creating an ML Program
bool CreateMLProgram() const {
Expand Down Expand Up @@ -218,8 +219,9 @@ class ModelBuilder {
const logging::Logger& logger_;
const int32_t coreml_version_;
CoreMLOptions coreml_options_;
const bool create_ml_program_; // ML Program (CoreML5, iOS 15+, macOS 12+) or NeuralNetwork (old)
const std::string model_output_path_; // create_ml_program_ ? dir for mlpackage : filename for mlmodel
const bool create_ml_program_; // ML Program (CoreML5, iOS 15+, macOS 12+) or NeuralNetwork (old)
std::string model_output_path_; // create_ml_program_ ? dir for mlpackage : filename for mlmodel
bool is_model_cached_{false};

std::vector<std::string> onnx_input_names_;
std::vector<std::string> onnx_output_names_;
Expand Down
29 changes: 27 additions & 2 deletions onnxruntime/core/providers/coreml/coreml_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "core/providers/coreml/model/host_utils.h"
#include "core/providers/coreml/model/model.h"
#include "core/providers/coreml/shape_utils.h"
#include "core/graph/model.h"

namespace onnxruntime {

Expand Down Expand Up @@ -52,12 +53,36 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
const auto builder_params = coreml::MakeOpBuilderParams(graph_viewer, coreml_version_,
coreml_options_.RequireStaticShape(), coreml_options_.CreateMLProgram());
const auto supported_nodes = coreml::GetSupportedNodes(graph_viewer, builder_params, logger);

const Graph* main_graph = &graph_viewer.GetGraph();
while (main_graph->IsSubgraph()) {
main_graph = main_graph->ParentGraph();
}
const auto& metadata = main_graph->GetModel().MetaData();

std::string user_provided_key = metadata.count(kCOREML_CACHE_KEY) > 0
? metadata.at(kCOREML_CACHE_KEY)
: "";
if (user_provided_key.size() > 64 ||
std::any_of(user_provided_key.begin(), user_provided_key.end(),
[](unsigned char c) { return !std::isalnum(c); })) {
LOGS(logger, ERROR) << "[" << kCOREML_CACHE_KEY << ":" << user_provided_key << "] is not a valid cache key."
<< " It should be alphanumeric and less than 64 characters.";
wejoncy marked this conversation as resolved.
Show resolved Hide resolved
}
const auto gen_metadef_name =
[&]() {
HashValue model_hash;
int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash);
return MakeString(COREML, "_", model_hash, "_", metadef_id);
// use model_hash as the key if user doesn't provide one
if (user_provided_key.empty()) {
// user passed a empty string
// model_hash is a 64-bit hash value of model_path if model_path is not empty,
// otherwise it hashes the graph input names and all the node output names.
// it can't guarantee the uniqueness of the key, so user should manager the key for the best.
user_provided_key = std::to_string(model_hash);
}
// The string format is used by onnxruntime/core/providers/coreml/builders/model_builder.cc::GetModelOutputPath
// If the format changes, the function should be updated accordingly.
return MakeString(user_provided_key, "_", COREML, "_", model_hash, "_", metadef_id);
};

result = utils::CreateSupportedPartitions(graph_viewer, supported_nodes, {},
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/coreml/coreml_options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "core/providers/coreml/coreml_provider_factory.h" // defines flags
#include "core/providers/coreml/model/host_utils.h"
#include "core/providers/coreml/builders/helper.h"
#include "core/platform/env.h"

namespace onnxruntime {

Expand Down Expand Up @@ -71,6 +72,7 @@ void CoreMLOptions::ValidateAndParseProviderOption(const ProviderOptions& option
kCoremlProviderOption_SpecializationStrategy,
kCoremlProviderOption_ProfileComputePlan,
kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU,
kCoremlProviderOption_ModelCacheDirectory,
};
// Validate the options
for (const auto& option : options) {
Expand Down Expand Up @@ -103,6 +105,8 @@ void CoreMLOptions::ValidateAndParseProviderOption(const ProviderOptions& option
profile_compute_plan_ = option.second == "1";
} else if (kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU == option.first) {
allow_low_precision_accumulation_on_gpu_ = option.second == "1";
} else if (kCoremlProviderOption_ModelCacheDirectory == option.first) {
model_cache_directory_ = option.second;
}
}
}
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/core/providers/coreml/coreml_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
std::string strategy_;
bool profile_compute_plan_{false};
bool allow_low_precision_accumulation_on_gpu_{false};
// path to store the converted coreml model
// we may run DisableModelCache() to disable model caching
mutable std::string model_cache_directory_;

Check warning on line 22 in onnxruntime/core/providers/coreml/coreml_options.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/coreml/coreml_options.h:22: Add #include <string> for string [build/include_what_you_use] [4]

public:
explicit CoreMLOptions(uint32_t coreml_flags);
Expand All @@ -32,6 +35,11 @@
bool UseStrategy(std::string_view strategy) const { return strategy_ == strategy; }
bool ProfileComputePlan() const { return profile_compute_plan_ && create_mlprogram_; }

std::string_view ModelCacheDirectory() const { return model_cache_directory_; }
// The options specified by the user are const, but if there's an error setting up caching we disable it
// so that the EP can still be used. The error is logged for the user to investigate.
void DisableModelCache() const { model_cache_directory_.clear(); }

private:
void ValidateAndParseProviderOption(const ProviderOptions& options);
};
Expand Down
Loading
Loading