Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CoreML] support coreml model cache #23065

Merged
merged 36 commits into from
Dec 31, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
1d1c874
support coreml model cache
Dec 10, 2024
b7888c4
improve
wejoncy Dec 11, 2024
f492fee
fix
wejoncy Dec 11, 2024
7b11848
better hash
Dec 16, 2024
4a5772f
refactor output -path
Dec 16, 2024
b57aa28
address comments
Dec 16, 2024
723b2dd
remove extra check
Dec 16, 2024
4f0ac2a
Apply suggestions from code review
wejoncy Dec 16, 2024
781e42e
improve doc
Dec 16, 2024
26775b4
typo
Dec 16, 2024
89317c6
check cache-key
Dec 16, 2024
773dce0
validate alpha-number
Dec 16, 2024
e82f3e4
fix
Dec 16, 2024
d3d25b9
format
Dec 16, 2024
d053dc0
fix bug
Dec 16, 2024
2779e3d
format
Dec 16, 2024
c7194ad
renaming
Dec 17, 2024
8204e64
max 64 chars
Dec 17, 2024
9c9374c
polish cache path
Dec 18, 2024
8faf178
fix
Dec 18, 2024
e4e3547
Update include/onnxruntime/core/providers/coreml/coreml_provider_fact…
wejoncy Dec 19, 2024
728fbee
Update include/onnxruntime/core/providers/coreml/coreml_provider_fact…
wejoncy Dec 19, 2024
e49112c
Update include/onnxruntime/core/providers/coreml/coreml_provider_fact…
wejoncy Dec 19, 2024
d7b867c
disable caching in runtime.
Dec 19, 2024
7c466a1
Apply suggestions from code review
wejoncy Dec 20, 2024
d1e7633
address comments
Dec 20, 2024
a5ffe03
fix
Dec 20, 2024
5518e38
format
Dec 20, 2024
70075e5
format
Dec 20, 2024
dc52361
Update onnxruntime/core/providers/coreml/coreml_execution_provider.cc
wejoncy Dec 24, 2024
31e6c68
add test for model cache
Dec 24, 2024
78b2a4b
ut
Dec 24, 2024
5f7bddc
ut
Dec 24, 2024
f9db65f
lint
Dec 24, 2024
391f914
Merge branch 'main' into jicwen/coreml_cache
wejoncy Dec 30, 2024
201368d
fix test
wejoncy Dec 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ static const char* const kCoremlProviderOption_SpecializationStrategy = "Special
static const char* const kCoremlProviderOption_ProfileComputePlan = "ProfileComputePlan";
// please refer to https://developer.apple.com/documentation/coreml/mlmodelconfiguration/allowlowprecisionaccumulationongpu
static const char* const kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU = "AllowLowPrecisionAccumulationOnGPU";
// Specify the path to cache the model.
wejoncy marked this conversation as resolved.
Show resolved Hide resolved
// CoreML EP will convert onnx subgraph to CoreML model and save to disk.
// If this path is not specified, the model will be saved to a temp directory and deleted after the session is closed.
// otherwise, the model will be saved to the specified path and User should manage to delete the model.
// The basic logic is:
// if (ModelCachePath != nullptr && ModelCachePath/cache_coreml.exists()) {
// // load from cache_coreml
// } else {
// // save to ModelCachePath
// }
// we wound not detect if the cached model match the onnx subgraph, so User should carefully manage the cache for a new model.
wejoncy marked this conversation as resolved.
Show resolved Hide resolved
static const char* const kCoremlProviderOption_ModelCachePath = "ModelCachePath";
wejoncy marked this conversation as resolved.
Show resolved Hide resolved

#ifdef __cplusplus
extern "C" {
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/platform/env.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ class Env {
#ifdef _WIN32
/// \brief Returns true if the directory exists.
virtual bool FolderExists(const std::wstring& path) const = 0;
virtual bool FileExists(const std::wstring& path) const = 0;
/// \brief Recursively creates the directory, if it doesn't exist.
virtual common::Status CreateFolder(const std::wstring& path) const = 0;
// Mainly for use with protobuf library
Expand All @@ -206,6 +207,7 @@ class Env {
#endif
/// \brief Returns true if the directory exists.
virtual bool FolderExists(const std::string& path) const = 0;
virtual bool FileExists(const std::string& path) const = 0;
/// \brief Recursively creates the directory, if it doesn't exist.
virtual common::Status CreateFolder(const std::string& path) const = 0;
// Recursively deletes the directory and its contents.
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/core/platform/posix/env.cc
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,14 @@ class PosixEnv : public Env {
return S_ISDIR(sb.st_mode);
}

bool FileExists(const std::string& path) const override {
struct stat sb;
if (stat(path.c_str(), &sb)) {
return false;
}
return S_ISREG(sb.st_mode);
}

common::Status CreateFolder(const std::string& path) const override {
size_t pos = 0;
do {
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/platform/windows/env.cc
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,16 @@ bool WindowsEnv::FolderExists(const std::string& path) const {
return (attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY);
}

bool WindowsEnv::FileExists(const std::wstring& path) const {
DWORD attributes = GetFileAttributesW(path.c_str());
return (attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_NORMAL);
}

bool WindowsEnv::FileExists(const std::string& path) const {
DWORD attributes = GetFileAttributesA(path.c_str());
return (attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_NORMAL);
}

common::Status WindowsEnv::CreateFolder(const std::wstring& path) const {
size_t pos = 0;
do {
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/platform/windows/env.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ class WindowsEnv : public Env {
MappedMemoryPtr& mapped_memory) const override;
bool FolderExists(const std::wstring& path) const override;
bool FolderExists(const std::string& path) const override;
bool FileExists(const std::wstring& path) const override;
bool FileExists(const std::string& path) const override;
common::Status CreateFolder(const std::wstring& path) const override;
common::Status CreateFolder(const std::string& path) const override;
common::Status DeleteFolder(const PathString& path) const override;
Expand Down
79 changes: 70 additions & 9 deletions onnxruntime/core/providers/coreml/builders/model_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -390,13 +390,53 @@ void CreateEmptyFile(const std::string& filename) {

#endif // defined(COREML_ENABLE_MLPROGRAM)

std::string GetModelOutputPath(bool create_ml_program) {
// path is used to create the ML Package directory for ML Program, and for the model directly otherwise.
auto path = util::GetTemporaryFilePath();
if (!create_ml_program) {
path += ".model.mlmodel";
}
std::string GetModelOutputPath(const CoreMLOptions& coreml_options,
const GraphViewer& graph_viewer) {
const std::string& subgraph_name = graph_viewer.Name();
std::string path;
if (coreml_options.ModelCachePath().empty()) {
// path is used to create the ML Package directory for ML Program, and for the model directly otherwise.
path = util::GetTemporaryFilePath();
if (!coreml_options.CreateMLProgram()) {
path += ".model.mlmodel";
}
} else {
// subgraph_name is uniquely generated by
// onnxruntime/core/providers/coreml/coreml_execution_provider.cc::gen_metadef_name
// int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash);
// MakeString(user_provide_key, "_", COREML, "_", model_hash, "_", metadef_id);
ORT_ENFORCE(std::count(subgraph_name.begin(), subgraph_name.end(), '_') == 3,
"Unexpected graph name format: ", subgraph_name);
wejoncy marked this conversation as resolved.
Show resolved Hide resolved
std::string_view cache_key = std::string_view(subgraph_name).substr(0, subgraph_name.find_first_of("_"));
path = MakeString(std::string(coreml_options.ModelCachePath()), "/", cache_key);
ORT_THROW_IF_ERROR(Env::Default().CreateFolder(path));
wejoncy marked this conversation as resolved.
Show resolved Hide resolved
if (!Env::Default().FileExists(ToPathString(path + "/model.txt"))) {
wejoncy marked this conversation as resolved.
Show resolved Hide resolved
const Graph* main_graph = &graph_viewer.GetGraph();
while (main_graph->IsSubgraph()) {
main_graph = main_graph->ParentGraph();
}
std::ofstream file(path + "/model.txt");
ORT_ENFORCE(file.is_open(), "Failed to open file ", path + "/model.txt");
wejoncy marked this conversation as resolved.
Show resolved Hide resolved
file << main_graph->ModelPath().string();
file.close();
}

path = MakeString(path, "/", subgraph_name);
// Set the model cache path with equireStaticShape and ModelFormat
wejoncy marked this conversation as resolved.
Show resolved Hide resolved
if (coreml_options.RequireStaticShape()) {
path += "/static_shape";
} else {
path += "/dynamic_shape";
}

if (coreml_options.CreateMLProgram()) {
path += ".mlpackage";
} else {
path += ".mlnnmodel";
}
ORT_THROW_IF_ERROR(Env::Default().CreateFolder(path));
path += "/mlmodel";
}
return path;
}
} // namespace
Expand All @@ -410,10 +450,21 @@ ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logge
coreml_version_(coreml_version),
coreml_options_(coreml_options),
create_ml_program_(coreml_options.CreateMLProgram()),
model_output_path_(GetModelOutputPath(create_ml_program_)),
model_output_path_(GetModelOutputPath(coreml_options, graph_viewer)),
onnx_input_names_(std::move(onnx_input_names)),
onnx_output_names_(std::move(onnx_output_names)),
coreml_model_(std::make_unique<CoreML::Specification::Model>()) {
// GetTemporaryFilePath() always produce a unique path for the model and this is not existed
// Mlprogram will create a folder while NN create a file
if (Env::Default().FolderExists(ToPathString(model_output_path_)) ||
Env::Default().FileExists(ToPathString(model_output_path_))) {
is_model_cached_ = true;
LOGS(logger, WARNING) << "Model is already cached in " << model_output_path_
wejoncy marked this conversation as resolved.
Show resolved Hide resolved
<< " and will be reused. If you want to update the model or hit other issues, "
<< "please consider to clear the cache and retry.";
return;
}

if (create_ml_program_) {
#if defined(COREML_ENABLE_MLPROGRAM)
coreml_model_->set_specificationversion(CoreMLSpecVersion());
Expand Down Expand Up @@ -847,6 +898,10 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i

input_output_info_.emplace(name, OnnxTensorInfo{data_type, shape});

if (IsModelCached()) {
return Status::OK();
}

#if defined(COREML_ENABLE_MLPROGRAM)
if (create_ml_program_) {
if (is_input) {
Expand Down Expand Up @@ -1056,8 +1111,14 @@ Status ModelBuilder::Build(const GraphViewer& graph_viewer, const logging::Logge
ModelBuilder builder(graph_viewer, logger, coreml_version, coreml_options,
std::move(onnx_input_names), std::move(onnx_output_names));

ORT_RETURN_IF_ERROR(builder.CreateModel());
ORT_RETURN_IF_ERROR(builder.SaveModel());
if (!builder.IsModelCached()) {
ORT_RETURN_IF_ERROR(builder.CreateModel());
ORT_RETURN_IF_ERROR(builder.SaveModel());
} else {
// runtime requires the input/output names to be passed
ORT_RETURN_IF_ERROR(builder.RegisterModelInputs());
ORT_RETURN_IF_ERROR(builder.RegisterModelOutputs());
}

return builder.LoadModel(model);
}
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/coreml/builders/model_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class ModelBuilder {
// We only support CoreML 3 and later so the spec version is always version + 1.
int32_t CoreMLVersion() const { return coreml_version_; }
int32_t CoreMLSpecVersion() const { return coreml_version_ + 1; }
bool IsModelCached() const { return is_model_cached_; }

// Returns true if we are creating an ML Program
bool CreateMLProgram() const {
Expand Down Expand Up @@ -218,8 +219,9 @@ class ModelBuilder {
const logging::Logger& logger_;
const int32_t coreml_version_;
CoreMLOptions coreml_options_;
const bool create_ml_program_; // ML Program (CoreML5, iOS 15+, macOS 12+) or NeuralNetwork (old)
const std::string model_output_path_; // create_ml_program_ ? dir for mlpackage : filename for mlmodel
const bool create_ml_program_; // ML Program (CoreML5, iOS 15+, macOS 12+) or NeuralNetwork (old)
std::string model_output_path_; // create_ml_program_ ? dir for mlpackage : filename for mlmodel
bool is_model_cached_{false};

std::vector<std::string> onnx_input_names_;
std::vector<std::string> onnx_output_names_;
Expand Down
14 changes: 13 additions & 1 deletion onnxruntime/core/providers/coreml/coreml_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "core/providers/coreml/model/host_utils.h"
#include "core/providers/coreml/model/model.h"
#include "core/providers/coreml/shape_utils.h"
#include "core/graph/model.h"

namespace onnxruntime {

Expand Down Expand Up @@ -57,7 +58,18 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
[&]() {
HashValue model_hash;
int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash);
return MakeString(COREML, "_", model_hash, "_", metadef_id);
std::string user_provide_key;
wejoncy marked this conversation as resolved.
Show resolved Hide resolved
const Graph* main_graph = &graph_viewer.GetGraph();
while (main_graph->IsSubgraph()) {
main_graph = main_graph->ParentGraph();
}
if (main_graph->GetModel().MetaData().count("CACHE_KEY") > 0) {
user_provide_key = graph_viewer.GetGraph().GetModel().MetaData().at("CACHE_KEY");
wejoncy marked this conversation as resolved.
Show resolved Hide resolved
} else {
// model_hash is a 64-bit hash value of model_path
user_provide_key = std::to_string(model_hash);
wejoncy marked this conversation as resolved.
Show resolved Hide resolved
}
return MakeString(user_provide_key, "_", COREML, "_", model_hash, "_", metadef_id);
};

result = utils::CreateSupportedPartitions(graph_viewer, supported_nodes, {},
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/coreml/coreml_options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "core/providers/coreml/coreml_provider_factory.h" // defines flags
#include "core/providers/coreml/model/host_utils.h"
#include "core/providers/coreml/builders/helper.h"
#include "core/platform/env.h"

namespace onnxruntime {

Expand Down Expand Up @@ -71,6 +72,7 @@ void CoreMLOptions::ValidateAndParseProviderOption(const ProviderOptions& option
kCoremlProviderOption_SpecializationStrategy,
kCoremlProviderOption_ProfileComputePlan,
kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU,
kCoremlProviderOption_ModelCachePath,
wejoncy marked this conversation as resolved.
Show resolved Hide resolved
};
// Validate the options
for (const auto& option : options) {
Expand Down Expand Up @@ -103,6 +105,8 @@ void CoreMLOptions::ValidateAndParseProviderOption(const ProviderOptions& option
profile_compute_plan_ = option.second == "1";
} else if (kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU == option.first) {
allow_low_precision_accumulation_on_gpu_ = option.second == "1";
} else if (kCoremlProviderOption_ModelCachePath == option.first) {
model_cache_path_ = option.second;
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/coreml/coreml_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
std::string strategy_;
bool profile_compute_plan_{false};
bool allow_low_precision_accumulation_on_gpu_{false};
// path to store the converted coreml model
std::string model_cache_path_;

Check warning on line 21 in onnxruntime/core/providers/coreml/coreml_options.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/coreml/coreml_options.h:21: Add #include <string> for string [build/include_what_you_use] [4]

public:
explicit CoreMLOptions(uint32_t coreml_flags);
Expand All @@ -32,6 +34,8 @@
bool UseStrategy(std::string_view strategy) const { return strategy_ == strategy; }
bool ProfileComputePlan() const { return profile_compute_plan_ && create_mlprogram_; }

std::string_view ModelCachePath() const { return model_cache_path_; }

private:
void ValidateAndParseProviderOption(const ProviderOptions& options);
};
Expand Down
Loading
Loading