diff --git a/src/builtin-adapter/BUILD b/src/builtin-adapter/BUILD index e50c1c72..9b06b58f 100644 --- a/src/builtin-adapter/BUILD +++ b/src/builtin-adapter/BUILD @@ -26,6 +26,7 @@ cc_library( "//formats:schema_structs", "//tools:attribute_printer", "//tools:load_opdefs", + "//tools:namespace_heuristics", "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_types", # copybara:uncomment "@org_tensorflow//tensorflow_text:ops_lib", ], @@ -181,6 +182,7 @@ cc_library( "//tools:attribute_printer", "//tools:convert_type", "//tools:load_opdefs", + "//tools:namespace_heuristics", "@org_tensorflow//tensorflow/compiler/mlir/lite/schema:schema_fbs", "@org_tensorflow//tensorflow/compiler/mlir/lite/schema:schema_utils", "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_types", diff --git a/src/builtin-adapter/direct_flatbuffer_to_json_graph_convert.cc b/src/builtin-adapter/direct_flatbuffer_to_json_graph_convert.cc index f80b1e63..a44264d5 100644 --- a/src/builtin-adapter/direct_flatbuffer_to_json_graph_convert.cc +++ b/src/builtin-adapter/direct_flatbuffer_to_json_graph_convert.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include #include @@ -32,7 +31,6 @@ limitations under the License. #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" -#include "absl/strings/str_replace.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "flatbuffers/flexbuffers.h" @@ -57,6 +55,7 @@ limitations under the License. 
#include "tools/attribute_printer.h" #include "tools/convert_type.h" #include "tools/load_opdefs.h" +#include "tools/namespace_heuristics.h" #include "visualize_config.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" @@ -199,15 +198,8 @@ std::string StringifyTensorShape(const TensorT& tensor) { return absl::StrCat(TensorTypeToString(tensor.type), "[", shape_str, "]"); } -// Generates the node name based on the provided tensor indices. -// -// In TFLite, a single tensor name could still contain several hierarchical info -// concatenated together with semicolons. In this case, we will find the last -// candidate node name that contains this node label. If no match is found, we -// will return the first candidate node name by default. This method also echos -// the MLIR-based conversion for Flatbuffer. -std::string GenerateNodeName(absl::string_view node_id_str, - absl::string_view node_label, +// Obtains the node namespace based on the node label and related tensor names. +std::string GenerateNodeName(absl::string_view node_label, const std::vector& tensor_indices, const Tensors& tensors) { if (tensor_indices.empty()) return ""; @@ -224,33 +216,7 @@ std::string GenerateNodeName(absl::string_view node_id_str, candidate_names.push_back(std::string(name)); } } - if (candidate_names.empty()) return ""; - if (candidate_names.size() == 1) { - return candidate_names[0]; - } - - // Removes any underscores in `node_label`. - const std::string node_label_substr = - absl::StrReplaceAll(node_label, {{"_", ""}}); - - // Iterates backwards to find if the last chunk of candidate_name contains the - // node label in the end hierarchy. 
- for (auto name_it = std::rbegin(candidate_names); - name_it != std::rend(candidate_names); ++name_it) { - const auto start_pos = name_it->find_last_of('/'); - std::string last_substr; - if (start_pos != std::string::npos) { - last_substr = name_it->substr(start_pos, name_it->size()); - } else { - last_substr = *name_it; - } - if (absl::AsciiStrToLower(last_substr).find(node_label_substr) != - std::string::npos) { - return *name_it; - } - } - - return candidate_names[0]; + return TfliteNodeNamespaceHeuristic(node_label, candidate_names); } void AppendMetadata( @@ -410,8 +376,7 @@ absl::Status AddAuxiliaryNode( case NodeType::kConstNode: { edge_type = EdgeType::kOutput; node_label = kPseudoConst; - node_name = - GenerateNodeName(node_id_str, node_label, tensor_indices, tensors); + node_name = GenerateNodeName(node_label, tensor_indices, tensors); break; } default: { @@ -725,7 +690,7 @@ absl::Status AddNode( const std::string node_id_str = node_ids[node_index]; absl::string_view node_label = op_names[op.opcode_index]; const std::string node_name = - GenerateNodeName(node_id_str, node_label, op.outputs, tensors); + GenerateNodeName(node_label, op.outputs, tensors); GraphNodeBuilder builder; builder.SetNodeInfo(node_id_str, node_label, node_name); // Logs the error and continues to add the node to the graph. 
diff --git a/src/builtin-adapter/tools/BUILD b/src/builtin-adapter/tools/BUILD index f206ea81..0672a35d 100644 --- a/src/builtin-adapter/tools/BUILD +++ b/src/builtin-adapter/tools/BUILD @@ -31,3 +31,14 @@ cc_library( hdrs = ["load_opdefs.h"], deps = ["@com_google_absl//absl/container:flat_hash_map"], ) + +cc_library( + name = "namespace_heuristics", + srcs = ["namespace_heuristics.cc"], + hdrs = ["namespace_heuristics.h"], + deps = [ + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) diff --git a/src/builtin-adapter/tools/namespace_heuristics.cc b/src/builtin-adapter/tools/namespace_heuristics.cc new file mode 100644 index 00000000..906d2c8b --- /dev/null +++ b/src/builtin-adapter/tools/namespace_heuristics.cc @@ -0,0 +1,126 @@ +#include "tools/namespace_heuristics.h" + +#include +#include +#include +#include + +#include "absl/strings/ascii.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" + +namespace tooling { +namespace visualization_client { +namespace { + +// Gets the edit distance between two strings with the optimized memory usage. +int EditDistance(absl::string_view shorter_str, absl::string_view longer_str) { + // Ensure that 'shorter_str' is indeed the shorter of the two strings + // This helps optimize space usage in the DP table + const int short_len = shorter_str.size(); + const int long_len = longer_str.size(); + if (short_len > long_len) { + return EditDistance(longer_str, shorter_str); + } + + // 'prev_diag' stores the value from the previous diagonal in the DP table. + int prev_diag; + + // 'curr_row' represents the current row in the DP table. Initialize it with + // increasing values from 0 to 'short_len', representing the edit distance + // when the longer string is empty. 
+ std::vector curr_row(short_len + 1, 0); + for (int j = 0; j <= short_len; j++) { + curr_row[j] = j; + } + + for (int i = 1; i <= long_len; i++) { + prev_diag = curr_row[0]; + // The first element in each row represents the edit distance when the + // shorter string is empty. + curr_row[0] = i; + for (int j = 1; j <= short_len; j++) { + int temp = curr_row[j]; + // If the characters match, the edit distance is the same as the value on + // the previous diagonal. + if (longer_str[i - 1] == shorter_str[j - 1]) { + curr_row[j] = prev_diag; + } else { + // If the characters don't match, the edit distance is 1 plus the + // minimum of: + // 1. Drop from longer string: 'curr_row[j]' (cell above) + // 2. Substitution: 'prev_diag' (top-left diagonal) + // 3. Drop from shorter string: 'curr_row[j - 1]' (left cell) + curr_row[j] = 1 + std::min({curr_row[j - 1], prev_diag, curr_row[j]}); + } + prev_diag = temp; + } + } + // The final edit distance is stored in the last element of 'curr_row'. + return curr_row[short_len]; +} + +// Preprocesses the candidate name by obtaining the last chunk of the substring +// separated by '/', removing the non-alphabetic characters and converting to +// lower case. +std::string PreprocessCandidateName(absl::string_view name) { + const int start_pos = name.find_last_of('/'); + std::string last_substr; + if (start_pos != std::string::npos) { + last_substr = name.substr(start_pos + 1, name.size()); + } else { + last_substr = name; + } + // Removes the non-alphabetic characters and converts to lower case. 
+ last_substr.erase( + std::remove_if(last_substr.begin(), last_substr.end(), + [](unsigned char c) { return !absl::ascii_isalpha(c); }), + last_substr.end()); + last_substr = absl::AsciiStrToLower(last_substr); + return last_substr; +} + +} // namespace + +std::string TfliteNodeNamespaceHeuristic( + absl::string_view node_label, + absl::Span candidate_names) { + if (candidate_names.empty()) return ""; + if (candidate_names.size() == 1) { + return candidate_names[0]; + } + + // Removes any underscores in `node_label`. + const std::string node_label_substr = + absl::StrReplaceAll(node_label, {{"_", ""}}); + + // Defaults the name to the first candidate name. + std::string result_name = candidate_names[0]; + int min_distance = std::numeric_limits::max(); + // Sets the max distance threshold to be three times the length of the node + // label substring. If the distance is larger than the threshold, it's + // considered as irrelevant. + const int max_distance_threshold = 3 * node_label_substr.length(); + // Iterating backwards is critical: on equal distances the later name wins. + for (auto name_it = std::rbegin(candidate_names); + name_it != std::rend(candidate_names); ++name_it) { + const std::string last_substr = PreprocessCandidateName(*name_it); + // Skips the empty string to avoid false matching. 
+ if (last_substr.empty()) { + continue; + } + int cur_distance = EditDistance(node_label_substr, last_substr); + if (cur_distance > max_distance_threshold) { + continue; + } + if (cur_distance < min_distance) { + min_distance = cur_distance; + result_name = *name_it; + } + } + return result_name; +} + +} // namespace visualization_client +} // namespace tooling diff --git a/src/builtin-adapter/tools/namespace_heuristics.h b/src/builtin-adapter/tools/namespace_heuristics.h new file mode 100644 index 00000000..dd31ca5c --- /dev/null +++ b/src/builtin-adapter/tools/namespace_heuristics.h @@ -0,0 +1,27 @@ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_GOOGLE_TOOLING_TOOLS_NAMESPACE_HEURISTICS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_GOOGLE_TOOLING_TOOLS_NAMESPACE_HEURISTICS_H_ + +#include + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +namespace tooling { +namespace visualization_client { + +// Obtains the best matching namespace for the TFLite node based on the provided +// node label and candidate names. +// +// The candidate names are obtained from the tensor names. The node namespace is +// obtained by the following steps: +// 1. If there are no candidate names, returns an empty string. +// 2. If there is only one candidate name, returns the candidate name. +// 3. Otherwise, iterates backwards and returns the candidate with the minimum +// edit distance to the node label, or the first candidate if none is close. 
+std::string TfliteNodeNamespaceHeuristic( + absl::string_view node_label, + absl::Span candidate_names); + +} // namespace visualization_client +} // namespace tooling + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_GOOGLE_TOOLING_TOOLS_NAMESPACE_HEURISTICS_H_ diff --git a/src/builtin-adapter/translate_helpers.cc b/src/builtin-adapter/translate_helpers.cc index 8de0c277..7e1dc27c 100644 --- a/src/builtin-adapter/translate_helpers.cc +++ b/src/builtin-adapter/translate_helpers.cc @@ -15,7 +15,6 @@ limitations under the License. #include "translate_helpers.h" -#include #include #include #include @@ -25,7 +24,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_replace.h" +#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -50,6 +49,7 @@ limitations under the License. #include "status_macros.h" #include "tools/attribute_printer.h" #include "tools/load_opdefs.h" +#include "tools/namespace_heuristics.h" #include "visualize_config.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" @@ -266,13 +266,12 @@ llvm::StringRef GetTfNodeName(Operation& operation) { // Generates the node name (the hierarchical information of the node) from a tfl // dialect operation. -llvm::StringRef GenerateTfliteNodeName(llvm::StringRef node_id_str, - llvm::StringRef node_label, - Operation& operation) { +std::string GenerateTfliteNodeName(llvm::StringRef node_label, + Operation& operation) { auto fusedLoc = operation.getLoc()->findInstanceOf(); auto nameLoc = operation.getLoc()->findInstanceOf(); if (fusedLoc == nullptr && nameLoc == nullptr) { - return kEmptyString; + return ""; } // In TFLite, we store op's output tensor names in location attribute. 
So it // could be either a simple NameLoc of the original node_name; or a special @@ -290,39 +289,15 @@ llvm::StringRef GenerateTfliteNodeName(llvm::StringRef node_id_str, // concatenated together with semicolons. In this case, we will find the last // single node name that contains this node label. If no matching found, we // will return the first single node name by default. - llvm::SmallVector candidate_names; - for (const llvm::StringRef tensor_name : tensor_names) { - llvm::SmallVector tmp_names; - tensor_name.split(tmp_names, kSemicolonSeparator, /*MaxSplit=*/-1, - /*KeepEmpty=*/false); - for (const llvm::StringRef name : tmp_names) { - candidate_names.push_back(name); + std::vector candidate_names; + for (absl::string_view tensor_name : tensor_names) { + std::vector tmp_names = + absl::StrSplit(tensor_name, ';', absl::SkipEmpty()); + for (absl::string_view name : tmp_names) { + candidate_names.push_back(std::string(name)); } } - if (candidate_names.empty()) { - return kEmptyString; - } - if (candidate_names.size() == 1) { - return candidate_names.front(); - } - // Removes any underscores in `node_label`. - const std::string node_label_substr = - absl::StrReplaceAll(node_label, {{"_", ""}}); - // We iterate backwards to find if a single node name contains the node - // label in the end hierarchy. - for (auto it = candidate_names.rbegin(); it != candidate_names.rend(); ++it) { - llvm::StringRef name = *it; - llvm::StringRef last_substr = name; - const size_t start_pos = name.find_last_of('/'); - if (start_pos != std::string::npos) { - last_substr = name.substr(start_pos); - } - if (last_substr.contains_insensitive(node_label_substr)) { - return name; - } - } - - return candidate_names.front(); + return TfliteNodeNamespaceHeuristic(node_label, candidate_names); } // Gets a list of output tensor name(s) of an TFLite operation. 
Returns empty @@ -505,8 +480,7 @@ absl::StatusOr TfliteFunctionToSubgraph(const VisualizeConfig& config, node_label = kGraphOutputs.str(); } seen_ops.insert({&operation, node_id}); - llvm::StringRef node_name = - GenerateTfliteNodeName(node_id, node_label, operation); + std::string node_name = GenerateTfliteNodeName(node_label, operation); GraphNodeBuilder builder; builder.SetNodeInfo(node_id, node_label, node_name); AppendNodeAttrs(config.const_element_count_limit, operation, builder);