From b57c24ac2b22a150cca299989b3030f956bf957e Mon Sep 17 00:00:00 2001 From: Eric Yang Date: Sun, 12 May 2024 18:39:58 -0700 Subject: [PATCH] Add `quant_params_count_limit` to VisualizeConfig PiperOrigin-RevId: 633054360 --- ...direct_flatbuffer_to_json_graph_convert.cc | 36 ++++++++++++------- src/builtin-adapter/models_to_json_main.cc | 10 +++++- src/builtin-adapter/visualize_config.h | 11 ++++-- 3 files changed, 41 insertions(+), 16 deletions(-) diff --git a/src/builtin-adapter/direct_flatbuffer_to_json_graph_convert.cc b/src/builtin-adapter/direct_flatbuffer_to_json_graph_convert.cc index a44264d5..bb7a10d5 100644 --- a/src/builtin-adapter/direct_flatbuffer_to_json_graph_convert.cc +++ b/src/builtin-adapter/direct_flatbuffer_to_json_graph_convert.cc @@ -15,6 +15,8 @@ limitations under the License. #include "direct_flatbuffer_to_json_graph_convert.h" +#include +#include #include #include #include @@ -649,6 +651,7 @@ absl::Status AddTensorTags(const OperatorT& op, absl::string_view op_label, } void AddQuantizationParameters(const std::unique_ptr& tensor, + const size_t size_limit, const EdgeType edge_type, const int rel_idx, GraphNodeBuilder& builder) { if (tensor->quantization == nullptr) return; @@ -662,9 +665,13 @@ void AddQuantizationParameters(const std::unique_ptr& tensor, } if (quant->scale.empty()) return; + const unsigned num_params = (size_limit < 0) + ? quant->scale.size() + : std::min(quant->scale.size(), size_limit); + if (num_params == 0) return; std::vector parameters; - parameters.reserve(quant->scale.size()); - for (int i = 0; i < quant->scale.size(); ++i) { + parameters.reserve(num_params); + for (int i = 0; i < num_params; ++i) { // Parameters will be shown as "[scale] * (q + [zero_point])" parameters.push_back( absl::StrCat(quant->scale[i], " * (q + ", quant->zero_point[i], ")")); @@ -680,7 +687,7 @@ absl::Status AddNode( const Buffers& buffers, const std::vector& func_names, const std::optional& signature_name_map, const OpdefsMap& op_defs, const std::unique_ptr& model_ptr, - const int const_element_count_limit, std::vector& node_ids, + const VisualizeConfig& config, std::vector& node_ids, EdgeMap& edge_map, mlir::Builder mlir_builder, Subgraph& subgraph) { if (op.opcode_index >= op_names.size()) { return absl::InvalidArgumentError( @@ -713,14 +720,16 @@ absl::Status AddNode( // when the input tensor is constant and not an output of a node. Thus we // create an auxiliary constant node to align with graph structure. if (EdgeInfoIncomplete(edge_map.at(tensor_index))) { - RETURN_IF_ERROR(AddAuxiliaryNode( - NodeType::kConstNode, std::vector{tensor_index}, tensors, - buffers, signature_name_map, model_ptr, const_element_count_limit, - node_ids, edge_map, mlir_builder, subgraph)); + RETURN_IF_ERROR( + AddAuxiliaryNode(NodeType::kConstNode, std::vector{tensor_index}, + tensors, buffers, signature_name_map, model_ptr, + config.const_element_count_limit, node_ids, edge_map, + mlir_builder, subgraph)); } AppendIncomingEdge(edge_map.at(tensor_index), builder); - AddQuantizationParameters(tensors[tensor_index], EdgeType::kInput, i, - builder); + AddQuantizationParameters(tensors[tensor_index], + config.quant_params_count_limit, EdgeType::kInput, + i, builder); } for (int i = 0; i < op.outputs.size(); ++i) { @@ -732,8 +741,9 @@ absl::Status AddNode( .source_node_output_id = absl::StrCat(i)}, edge_map); - AddQuantizationParameters(tensors[tensor_index], EdgeType::kOutput, i, - builder); + AddQuantizationParameters(tensors[tensor_index], + config.quant_params_count_limit, + EdgeType::kOutput, i, builder); } status = AddTensorTags(op, node_label, op_defs, builder); @@ -802,8 +812,8 @@ absl::Status AddSubgraph( const Tensors& tensors = subgraph_t.tensors; RETURN_IF_ERROR(AddNode(i, *op, op_codes, op_names, tensors, buffers, func_names, signature_name_map, op_defs, model_ptr, - config.const_element_count_limit, node_ids, - edge_map, mlir_builder, subgraph)); + config, node_ids, edge_map, mlir_builder, + subgraph)); } // Adds GraphOutputs node to the subgraph. diff --git a/src/builtin-adapter/models_to_json_main.cc b/src/builtin-adapter/models_to_json_main.cc index d6a31571..57d54b98 100644 --- a/src/builtin-adapter/models_to_json_main.cc +++ b/src/builtin-adapter/models_to_json_main.cc @@ -29,6 +29,7 @@ constexpr char kInputFileFlag[] = "i"; constexpr char kOutputFileFlag[] = "o"; constexpr char kConstElementCountLimitFlag[] = "const_element_count_limit"; constexpr char kDisableMlirFlag[] = "disable_mlir"; +constexpr char kQuantParamsCountLimitFlag[] = "quant_params_count_limit"; namespace { @@ -42,6 +43,7 @@ int main(int argc, char* argv[]) { // Creates and parses flags. std::string input_file, output_file; int const_element_count_limit = 16; + int quant_params_count_limit = 16; bool disable_mlir = false; std::vector flag_list = { @@ -61,6 +63,12 @@ int main(int argc, char* argv[]) { "Disable the MLIR-based conversion. If set to true, the conversion " "becomes from model directly to graph json", mlir::Flag::kOptional), + mlir::Flag::CreateFlag( + kQuantParamsCountLimitFlag, &quant_params_count_limit, + "The maximum number of quant parameters. If the number exceeds this " + "threshold, the rest of data will be elided. If the flag is not set, " + "the default threshold is 16 (use -1 to print all)", + mlir::Flag::kOptional), }; mlir::Flags::Parse(&argc, const_cast(argv), flag_list); @@ -75,7 +83,7 @@ int main(int argc, char* argv[]) { // Creates visualization config. tooling::visualization_client::VisualizeConfig config( - const_element_count_limit); + const_element_count_limit, quant_params_count_limit); const absl::StatusOr json_output = ConvertModelToJson(config, input_file, disable_mlir); diff --git a/src/builtin-adapter/visualize_config.h b/src/builtin-adapter/visualize_config.h index be4b1fbf..2dfa67c0 100644 --- a/src/builtin-adapter/visualize_config.h +++ b/src/builtin-adapter/visualize_config.h @@ -21,13 +21,20 @@ namespace visualization_client { struct VisualizeConfig { VisualizeConfig() = default; - explicit VisualizeConfig(const int const_element_count_limit) - : const_element_count_limit(const_element_count_limit) {} + explicit VisualizeConfig(const int const_element_count_limit, + const int quant_params_count_limit) + : const_element_count_limit(const_element_count_limit), + quant_params_count_limit(quant_params_count_limit) {} // The maximum number of constant elements to be displayed. If the number // exceeds this threshold, the rest of data will be elided. The default // threshold is set to 16 (use -1 to print all). int const_element_count_limit = 16; + + // The maximum number of quantization parameters to be displayed. If the + // number exceeds this threshold, the rest of data will be elided. The default + // threshold is set to 16 (use -1 to print all). + int quant_params_count_limit = 16; }; } // namespace visualization_client