Skip to content

Commit

Permalink
Add quant_params_count_limit to VisualizeConfig
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 633054360
  • Loading branch information
yijie-yang authored and copybara-github committed Sep 10, 2024
1 parent 265b7cd commit 1eb8f45
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 16 deletions.
36 changes: 23 additions & 13 deletions src/builtin-adapter/direct_flatbuffer_to_json_graph_convert.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ limitations under the License.

#include "direct_flatbuffer_to_json_graph_convert.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
Expand Down Expand Up @@ -649,6 +651,7 @@ absl::Status AddTensorTags(const OperatorT& op, absl::string_view op_label,
}

void AddQuantizationParameters(const std::unique_ptr<TensorT>& tensor,
const size_t size_limit,
const EdgeType edge_type, const int rel_idx,
GraphNodeBuilder& builder) {
if (tensor->quantization == nullptr) return;
Expand All @@ -662,9 +665,13 @@ void AddQuantizationParameters(const std::unique_ptr<TensorT>& tensor,
}
if (quant->scale.empty()) return;

const unsigned num_params = (size_limit < 0)
? quant->scale.size()
: std::min(quant->scale.size(), size_limit);
if (num_params == 0) return;
std::vector<std::string> parameters;
parameters.reserve(quant->scale.size());
for (int i = 0; i < quant->scale.size(); ++i) {
parameters.reserve(num_params);
for (int i = 0; i < num_params; ++i) {
// Parameters will be shown as "[scale] * (q + [zero_point])"
parameters.push_back(
absl::StrCat(quant->scale[i], " * (q + ", quant->zero_point[i], ")"));
Expand All @@ -680,7 +687,7 @@ absl::Status AddNode(
const Buffers& buffers, const std::vector<std::string>& func_names,
const std::optional<const SignatureNameMap>& signature_name_map,
const OpdefsMap& op_defs, const std::unique_ptr<FlatBufferModel>& model_ptr,
const int const_element_count_limit, std::vector<std::string>& node_ids,
const VisualizeConfig& config, std::vector<std::string>& node_ids,
EdgeMap& edge_map, mlir::Builder mlir_builder, Subgraph& subgraph) {
if (op.opcode_index >= op_names.size()) {
return absl::InvalidArgumentError(
Expand Down Expand Up @@ -713,14 +720,16 @@ absl::Status AddNode(
// when the input tensor is constant and not an output of a node. Thus we
// create an auxiliary constant node to align with graph structure.
if (EdgeInfoIncomplete(edge_map.at(tensor_index))) {
RETURN_IF_ERROR(AddAuxiliaryNode(
NodeType::kConstNode, std::vector<int>{tensor_index}, tensors,
buffers, signature_name_map, model_ptr, const_element_count_limit,
node_ids, edge_map, mlir_builder, subgraph));
RETURN_IF_ERROR(
AddAuxiliaryNode(NodeType::kConstNode, std::vector<int>{tensor_index},
tensors, buffers, signature_name_map, model_ptr,
config.const_element_count_limit, node_ids, edge_map,
mlir_builder, subgraph));
}
AppendIncomingEdge(edge_map.at(tensor_index), builder);
AddQuantizationParameters(tensors[tensor_index], EdgeType::kInput, i,
builder);
AddQuantizationParameters(tensors[tensor_index],
config.quant_params_count_limit, EdgeType::kInput,
i, builder);
}

for (int i = 0; i < op.outputs.size(); ++i) {
Expand All @@ -732,8 +741,9 @@ absl::Status AddNode(
.source_node_output_id = absl::StrCat(i)},
edge_map);

AddQuantizationParameters(tensors[tensor_index], EdgeType::kOutput, i,
builder);
AddQuantizationParameters(tensors[tensor_index],
config.quant_params_count_limit,
EdgeType::kOutput, i, builder);
}

status = AddTensorTags(op, node_label, op_defs, builder);
Expand Down Expand Up @@ -802,8 +812,8 @@ absl::Status AddSubgraph(
const Tensors& tensors = subgraph_t.tensors;
RETURN_IF_ERROR(AddNode(i, *op, op_codes, op_names, tensors, buffers,
func_names, signature_name_map, op_defs, model_ptr,
config.const_element_count_limit, node_ids,
edge_map, mlir_builder, subgraph));
config, node_ids, edge_map, mlir_builder,
subgraph));
}

// Adds GraphOutputs node to the subgraph.
Expand Down
10 changes: 9 additions & 1 deletion src/builtin-adapter/models_to_json_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ constexpr char kInputFileFlag[] = "i";
constexpr char kOutputFileFlag[] = "o";
constexpr char kConstElementCountLimitFlag[] = "const_element_count_limit";
constexpr char kDisableMlirFlag[] = "disable_mlir";
constexpr char kQuantParamsCountLimitFlag[] = "quant_params_count_limit";

namespace {

Expand All @@ -42,6 +43,7 @@ int main(int argc, char* argv[]) {
// Creates and parses flags.
std::string input_file, output_file;
int const_element_count_limit = 16;
int quant_params_count_limit = 16;
bool disable_mlir = false;

std::vector<mlir::Flag> flag_list = {
Expand All @@ -61,6 +63,12 @@ int main(int argc, char* argv[]) {
"Disable the MLIR-based conversion. If set to true, the conversion "
"becomes from model directly to graph json",
mlir::Flag::kOptional),
mlir::Flag::CreateFlag(
kQuantParamsCountLimitFlag, &quant_params_count_limit,
"The maximum number of quant parameters. If the number exceeds this "
"threshold, the rest of data will be elided. If the flag is not set, "
"the default threshold is 16 (use -1 to print all)",
mlir::Flag::kOptional),
};
mlir::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);

Expand All @@ -75,7 +83,7 @@ int main(int argc, char* argv[]) {

// Creates visualization config.
tooling::visualization_client::VisualizeConfig config(
const_element_count_limit);
const_element_count_limit, quant_params_count_limit);

const absl::StatusOr<std::string> json_output =
ConvertModelToJson(config, input_file, disable_mlir);
Expand Down
11 changes: 9 additions & 2 deletions src/builtin-adapter/visualize_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,20 @@ namespace visualization_client {

struct VisualizeConfig {
VisualizeConfig() = default;
explicit VisualizeConfig(const int const_element_count_limit)
: const_element_count_limit(const_element_count_limit) {}
explicit VisualizeConfig(const int const_element_count_limit = 16,
const int quant_params_count_limit = 16)
: const_element_count_limit(const_element_count_limit),
quant_params_count_limit(quant_params_count_limit) {}

// The maximum number of constant elements to be displayed. If the number
// exceeds this threshold, the rest of data will be elided. The default
// threshold is set to 16 (use -1 to print all).
int const_element_count_limit = 16;

// The maximum number of quantization parameters to be displayed. If the
// number exceeds this threshold, the rest of data will be elided. The default
// threshold is set to 16 (use -1 to print all).
int quant_params_count_limit = 16;
};

} // namespace visualization_client
Expand Down

0 comments on commit 1eb8f45

Please sign in to comment.