Skip to content

Commit

Permalink
support coreml model cache
Browse files Browse the repository at this point in the history
  • Loading branch information
wejoncy committed Dec 10, 2024
1 parent b14b4ec commit fc9db07
Show file tree
Hide file tree
Showing 11 changed files with 244 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ static const char* const kCoremlProviderOption_SpecializationStrategy = "Special
static const char* const kCoremlProviderOption_ProfileComputePlan = "ProfileComputePlan";
// please refer to https://developer.apple.com/documentation/coreml/mlmodelconfiguration/allowlowprecisionaccumulationongpu
static const char* const kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU = "AllowLowPrecisionAccumulationOnGPU";
// Specify the path to cache the model.
// CoreML EP will convert onnx subgraph to CoreML model and save to disk.
// If this path is not specified, the model will be saved to a temp directory and deleted after the session is closed.
// otherwise, the model will be saved to the specified path and User should manage to delete the model.
// The basic logic is:
// if (ModelCachePath != nullptr && ModelCachePath/cache_coreml.exists()) {
// // load from cache_coreml
// } else {
// // save to ModelCachePath
// }
// we wound not detect if the cached model match the onnx subgraph, so User should carefully manage the cache for a new model.
static const char* const kCoremlProviderOption_ModelCachePath = "ModelCachePath";

#ifdef __cplusplus
extern "C" {
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/platform/env.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ class Env {
#ifdef _WIN32
/// \brief Returns true if the directory exists.
virtual bool FolderExists(const std::wstring& path) const = 0;
virtual bool FileExists(const std::wstring& path) const = 0;
/// \brief Recursively creates the directory, if it doesn't exist.
virtual common::Status CreateFolder(const std::wstring& path) const = 0;
// Mainly for use with protobuf library
Expand All @@ -206,6 +207,7 @@ class Env {
#endif
/// \brief Returns true if the directory exists.
virtual bool FolderExists(const std::string& path) const = 0;
virtual bool FileExists(const std::string& path) const = 0;
/// \brief Recursively creates the directory, if it doesn't exist.
virtual common::Status CreateFolder(const std::string& path) const = 0;
// Recursively deletes the directory and its contents.
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/core/platform/posix/env.cc
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,14 @@ class PosixEnv : public Env {
return S_ISDIR(sb.st_mode);
}

bool FileExists(const std::string& path) const override {
struct stat sb;
if (stat(path.c_str(), &sb)) {
return false;
}
return S_ISREG(sb.st_mode);
}

common::Status CreateFolder(const std::string& path) const override {
size_t pos = 0;
do {
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/platform/windows/env.cc
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,16 @@ bool WindowsEnv::FolderExists(const std::string& path) const {
return (attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY);
}

bool WindowsEnv::FileExists(const std::wstring& path) const {
DWORD attributes = GetFileAttributesW(path.c_str());
return (attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_NORMAL);
}

bool WindowsEnv::FileExists(const std::string& path) const {
DWORD attributes = GetFileAttributesA(path.c_str());
return (attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_NORMAL);
}

common::Status WindowsEnv::CreateFolder(const std::wstring& path) const {
size_t pos = 0;
do {
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/platform/windows/env.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ class WindowsEnv : public Env {
MappedMemoryPtr& mapped_memory) const override;
bool FolderExists(const std::wstring& path) const override;
bool FolderExists(const std::string& path) const override;
bool FileExists(const std::wstring& path) const override;
bool FileExists(const std::string& path) const override;
common::Status CreateFolder(const std::wstring& path) const override;
common::Status CreateFolder(const std::string& path) const override;
common::Status DeleteFolder(const PathString& path) const override;
Expand Down
42 changes: 39 additions & 3 deletions onnxruntime/core/providers/coreml/builders/model_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -410,10 +410,37 @@ ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logge
coreml_version_(coreml_version),
coreml_options_(coreml_options),
create_ml_program_(coreml_options.CreateMLProgram()),
model_output_path_(GetModelOutputPath(create_ml_program_)),
onnx_input_names_(std::move(onnx_input_names)),
onnx_output_names_(std::move(onnx_output_names)),
coreml_model_(std::make_unique<CoreML::Specification::Model>()) {
if (coreml_options.ModelCachePath().empty()) {
model_output_path_ = GetModelOutputPath(create_ml_program_);
} else {
// input names in onnx are unique. so we can use them as the key in the cache.
std::string inputs_collections = std::accumulate(
onnx_input_names_.begin(), onnx_input_names_.end(), std::string(),
[](const std::string& a, const std::string& b) { return a + "," + b; });
std::hash<std::string> hasher;
// different subgraph has different folders. so we need to hash the inputs.
model_output_path_ = std::string(coreml_options.ModelCachePath()) +
"/" + std::to_string(hasher(inputs_collections));
if (!coreml_options_.CreateMLProgram()) {
ORT_THROW_IF_ERROR(Env::Default().CreateFolder(model_output_path_));
model_output_path_ += "/mlmodel";
}
}

// GetModelOutputPath(create_ml_program_) always produce a unique path for the model and this is not existed
// Mlprogram will create a folder while NN create a file
if (Env::Default().FolderExists(ToPathString(model_output_path_)) ||
Env::Default().FileExists(ToPathString(model_output_path_))) {
is_model_cached_ = true;
LOGS(logger, WARNING) << "Model is already cached in " << model_output_path_
<< " and will be reused. If you want to update the model or hit other issues, "
<< "please consider to clear the cache and retry.";
return;
}

if (create_ml_program_) {
#if defined(COREML_ENABLE_MLPROGRAM)
coreml_model_->set_specificationversion(CoreMLSpecVersion());
Expand Down Expand Up @@ -847,6 +874,10 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i

input_output_info_.emplace(name, OnnxTensorInfo{data_type, shape});

if (is_model_cached_) {
return Status::OK();
}

#if defined(COREML_ENABLE_MLPROGRAM)
if (create_ml_program_) {
if (is_input) {
Expand Down Expand Up @@ -1056,8 +1087,13 @@ Status ModelBuilder::Build(const GraphViewer& graph_viewer, const logging::Logge
ModelBuilder builder(graph_viewer, logger, coreml_version, coreml_options,
std::move(onnx_input_names), std::move(onnx_output_names));

ORT_RETURN_IF_ERROR(builder.CreateModel());
ORT_RETURN_IF_ERROR(builder.SaveModel());
if (!builder.IsModelCached()) {
ORT_RETURN_IF_ERROR(builder.CreateModel());
ORT_RETURN_IF_ERROR(builder.SaveModel());
} else {
ORT_RETURN_IF_ERROR(builder.RegisterModelInputs());
ORT_RETURN_IF_ERROR(builder.RegisterModelOutputs());
}

return builder.LoadModel(model);
}
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/coreml/builders/model_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class ModelBuilder {
// We only support CoreML 3 and later so the spec version is always version + 1.
int32_t CoreMLVersion() const { return coreml_version_; }
int32_t CoreMLSpecVersion() const { return coreml_version_ + 1; }
bool IsModelCached() const { return is_model_cached_; }

// Returns true if we are creating an ML Program
bool CreateMLProgram() const {
Expand Down Expand Up @@ -218,8 +219,9 @@ class ModelBuilder {
const logging::Logger& logger_;
const int32_t coreml_version_;
CoreMLOptions coreml_options_;
const bool create_ml_program_; // ML Program (CoreML5, iOS 15+, macOS 12+) or NeuralNetwork (old)
const std::string model_output_path_; // create_ml_program_ ? dir for mlpackage : filename for mlmodel
const bool create_ml_program_; // ML Program (CoreML5, iOS 15+, macOS 12+) or NeuralNetwork (old)
std::string model_output_path_; // create_ml_program_ ? dir for mlpackage : filename for mlmodel
bool is_model_cached_{false};

std::vector<std::string> onnx_input_names_;
std::vector<std::string> onnx_output_names_;
Expand Down
20 changes: 20 additions & 0 deletions onnxruntime/core/providers/coreml/coreml_options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "core/providers/coreml/coreml_provider_factory.h" // defines flags
#include "core/providers/coreml/model/host_utils.h"
#include "core/providers/coreml/builders/helper.h"
#include "core/platform/env.h"

namespace onnxruntime {

Expand Down Expand Up @@ -71,6 +72,7 @@ void CoreMLOptions::ValidateAndParseProviderOption(const ProviderOptions& option
kCoremlProviderOption_SpecializationStrategy,
kCoremlProviderOption_ProfileComputePlan,
kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU,
kCoremlProviderOption_ModelCachePath,
};
// Validate the options
for (const auto& option : options) {
Expand Down Expand Up @@ -103,7 +105,25 @@ void CoreMLOptions::ValidateAndParseProviderOption(const ProviderOptions& option
profile_compute_plan_ = option.second == "1";
} else if (kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU == option.first) {
allow_low_precision_accumulation_on_gpu_ = option.second == "1";
} else if (kCoremlProviderOption_ModelCachePath == option.first) {
model_cache_path_ = option.second;
}
}

// Set the model cache path with equireStaticShape and ModelFormat
if (model_cache_path_.size()) {
if (require_static_shape_) {
model_cache_path_ += "/static_shape";
} else {
model_cache_path_ += "/dynamic_shape";
}

if (create_mlprogram_) {
model_cache_path_ += "/mlpackage";
} else {
model_cache_path_ += "/mlnnmodel";
}
ORT_THROW_IF_ERROR(Env::Default().CreateFolder(model_cache_path_));
}
}
} // namespace onnxruntime
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/coreml/coreml_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class CoreMLOptions {
std::string strategy_;
bool profile_compute_plan_{false};
bool allow_low_precision_accumulation_on_gpu_{false};
// path to store the converted coreml model
std::string model_cache_path_;

public:
explicit CoreMLOptions(uint32_t coreml_flags);
Expand All @@ -32,6 +34,8 @@ class CoreMLOptions {
bool UseStrategy(std::string_view strategy) const { return strategy_ == strategy; }
bool ProfileComputePlan() const { return profile_compute_plan_ && create_mlprogram_; }

std::string_view ModelCachePath() const { return model_cache_path_; }

private:
void ValidateAndParseProviderOption(const ProviderOptions& options);
};
Expand Down
Loading

0 comments on commit fc9db07

Please sign in to comment.