Skip to content

Commit

Permalink
Use edit distance to find the best node namespace for TFLite model
Browse files Browse the repository at this point in the history
The previous method for associating TFLite model nodes with the most appropriate namespace relied on matching the final part of the candidate name with the node label. This approach can sometimes be unreliable, leading to potential confusion in the resulting graph representation.

We are introducing an edit distance approach for more robust namespace identification, even when candidate names don't precisely match node labels.

Key Design Points:
- Backward Iteration: Prioritizing candidate names discovered later in the search often yields better results. Fused nodes typically appear later and tend to more closely mirror the original server model's node hierarchy.
- Distance Threshold: To prevent irrelevant namespaces from being associated with node labels, a threshold is set at three times the length of the node label. If the edit distance exceeds this limit, the namespace is deemed irrelevant.
- Default Behavior: In cases where no suitable namespace is found (all candidates are irrelevant), the first candidate name is used as the default.

PiperOrigin-RevId: 671922889
  • Loading branch information
yijie-yang authored and copybara-github committed Sep 6, 2024
1 parent f8a66c6 commit 265b7cd
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 80 deletions.
2 changes: 2 additions & 0 deletions src/builtin-adapter/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ cc_library(
"//formats:schema_structs",
"//tools:attribute_printer",
"//tools:load_opdefs",
"//tools:namespace_heuristics",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
# copybara:uncomment "@org_tensorflow//tensorflow_text:ops_lib",
],
Expand Down Expand Up @@ -181,6 +182,7 @@ cc_library(
"//tools:attribute_printer",
"//tools:convert_type",
"//tools:load_opdefs",
"//tools:namespace_heuristics",
"@org_tensorflow//tensorflow/compiler/mlir/lite/schema:schema_fbs",
"@org_tensorflow//tensorflow/compiler/mlir/lite/schema:schema_utils",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
Expand Down
47 changes: 6 additions & 41 deletions src/builtin-adapter/direct_flatbuffer_to_json_graph_convert.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.

#include <cstdint>
#include <cstring>
#include <iterator>
#include <memory>
#include <optional>
#include <string>
Expand All @@ -32,7 +31,6 @@ limitations under the License.
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "flatbuffers/flexbuffers.h"
Expand All @@ -57,6 +55,7 @@ limitations under the License.
#include "tools/attribute_printer.h"
#include "tools/convert_type.h"
#include "tools/load_opdefs.h"
#include "tools/namespace_heuristics.h"
#include "visualize_config.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
Expand Down Expand Up @@ -199,15 +198,8 @@ std::string StringifyTensorShape(const TensorT& tensor) {
return absl::StrCat(TensorTypeToString(tensor.type), "[", shape_str, "]");
}

// Generates the node name based on the provided tensor indices.
//
// In TFLite, a single tensor name could still contain several hierarchical info
// concatenated together with semicolons. In this case, we will find the last
// candidate node name that contains this node label. If no match is found, we
// will return the first candidate node name by default. This method also echos
// the MLIR-based conversion for Flatbuffer.
std::string GenerateNodeName(absl::string_view node_id_str,
absl::string_view node_label,
// Obtains the node namespace based on the node label and related tensor names.
std::string GenerateNodeName(absl::string_view node_label,
const std::vector<int>& tensor_indices,
const Tensors& tensors) {
if (tensor_indices.empty()) return "";
Expand All @@ -224,33 +216,7 @@ std::string GenerateNodeName(absl::string_view node_id_str,
candidate_names.push_back(std::string(name));
}
}
if (candidate_names.empty()) return "";
if (candidate_names.size() == 1) {
return candidate_names[0];
}

// Removes any underscores in `node_label`.
const std::string node_label_substr =
absl::StrReplaceAll(node_label, {{"_", ""}});

// Iterates backwards to find if the last chunk of candidate_name contains the
// node label in the end hierarchy.
for (auto name_it = std::rbegin(candidate_names);
name_it != std::rend(candidate_names); ++name_it) {
const auto start_pos = name_it->find_last_of('/');
std::string last_substr;
if (start_pos != std::string::npos) {
last_substr = name_it->substr(start_pos, name_it->size());
} else {
last_substr = *name_it;
}
if (absl::AsciiStrToLower(last_substr).find(node_label_substr) !=
std::string::npos) {
return *name_it;
}
}

return candidate_names[0];
return TfliteNodeNamespaceHeuristic(node_label, candidate_names);
}

void AppendMetadata(
Expand Down Expand Up @@ -410,8 +376,7 @@ absl::Status AddAuxiliaryNode(
case NodeType::kConstNode: {
edge_type = EdgeType::kOutput;
node_label = kPseudoConst;
node_name =
GenerateNodeName(node_id_str, node_label, tensor_indices, tensors);
node_name = GenerateNodeName(node_label, tensor_indices, tensors);
break;
}
default: {
Expand Down Expand Up @@ -725,7 +690,7 @@ absl::Status AddNode(
const std::string node_id_str = node_ids[node_index];
absl::string_view node_label = op_names[op.opcode_index];
const std::string node_name =
GenerateNodeName(node_id_str, node_label, op.outputs, tensors);
GenerateNodeName(node_label, op.outputs, tensors);
GraphNodeBuilder builder;
builder.SetNodeInfo(node_id_str, node_label, node_name);
// Logs the error and continues to add the node to the graph.
Expand Down
11 changes: 11 additions & 0 deletions src/builtin-adapter/tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,14 @@ cc_library(
hdrs = ["load_opdefs.h"],
deps = ["@com_google_absl//absl/container:flat_hash_map"],
)

cc_library(
name = "namespace_heuristics",
srcs = ["namespace_heuristics.cc"],
hdrs = ["namespace_heuristics.h"],
deps = [
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
],
)
126 changes: 126 additions & 0 deletions src/builtin-adapter/tools/namespace_heuristics.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
#include "tools/namespace_heuristics.h"

#include <algorithm>
#include <limits>
#include <string>
#include <vector>

#include "absl/strings/ascii.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"

namespace tooling {
namespace visualization_client {
namespace {

// Gets the edit distance between two strings with the optimized memory usage.
int EditDistance(absl::string_view shorter_str, absl::string_view longer_str) {
// Ensure that 'shorter_str' is indeed the shorter of the two strings
// This helps optimize space usage in the DP table
const int short_len = shorter_str.size();
const int long_len = longer_str.size();
if (short_len > long_len) {
return EditDistance(longer_str, shorter_str);
}

// 'prev_diag' stores the value from the previous diagonal in the DP table.
int prev_diag;

// 'curr_row' represents the current row in the DP table. Initialize it with
// increasing values from 0 to 'short_len', representing the edit distance
// when the longer string is empty.
std::vector<int> curr_row(short_len + 1, 0);
for (int j = 0; j <= short_len; j++) {
curr_row[j] = j;
}

for (int i = 1; i <= long_len; i++) {
prev_diag = curr_row[0];
// The first element in each row represents the edit distance when the
// shorter string is empty.
curr_row[0] = i;
for (int j = 1; j <= short_len; j++) {
int temp = curr_row[j];
// If the characters match, the edit distance is the same as the value on
// the previous diagonal.
if (longer_str[i - 1] == shorter_str[j - 1]) {
curr_row[j] = prev_diag;
} else {
// If the characters don't match, the edit distance is 1 plus the
// minimum of:
// 1. Insertion: 'curr_row[j]' (current cell)
// 2. Deletion: 'prev_diag' (top-left diagonal)
// 3. Substitution: 'curr_row[j - 1]' (left cell)
curr_row[j] = 1 + std::min({curr_row[j - 1], prev_diag, curr_row[j]});
}
prev_diag = temp;
}
}
// The final edit distance is stored in the last element of 'curr_row'.
return curr_row[short_len];
}

// Preprocesses the candidate name by obtaining the last chunk of the substring
// separated by '/', removing the non-alphabetic characters and converting to
// lower case.
std::string PreprocessCandidateName(absl::string_view name) {
const int start_pos = name.find_last_of('/');
std::string last_substr;
if (start_pos != std::string::npos) {
last_substr = name.substr(start_pos + 1, name.size());
} else {
last_substr = name;
}
// Removes the non-alphabetic characters and converts to lower case.
last_substr.erase(
std::remove_if(last_substr.begin(), last_substr.end(),
[](unsigned char c) { return !absl::ascii_isalpha(c); }),
last_substr.end());
last_substr = absl::AsciiStrToLower(last_substr);
return last_substr;
}

} // namespace

std::string TfliteNodeNamespaceHeuristic(
absl::string_view node_label,
absl::Span<const std::string> candidate_names) {
if (candidate_names.empty()) return "";
if (candidate_names.size() == 1) {
return candidate_names[0];
}

// Removes any underscores in `node_label`.
const std::string node_label_substr =
absl::StrReplaceAll(node_label, {{"_", ""}});

// Default the name to the first candidate name.
std::string result_name = candidate_names[0];
int min_distance = std::numeric_limits<int>::max();
// Sets the max distance threshold to be three times the length of the node
// label substring. If the distance is larger than the threshold, it's
// considered as irrelevant.
const int max_distance_threshold = 3 * node_label_substr.length();
// Iterates backwards is critical in finding a better match.
for (auto name_it = std::rbegin(candidate_names);
name_it != std::rend(candidate_names); ++name_it) {
const std::string last_substr = PreprocessCandidateName(*name_it);
// Skips the empty string to avoid false matching.
if (last_substr.empty()) {
continue;
}
int cur_distance = EditDistance(node_label_substr, last_substr);
if (cur_distance > max_distance_threshold) {
continue;
}
if (cur_distance < min_distance) {
min_distance = cur_distance;
result_name = *name_it;
}
}
return result_name;
}

} // namespace visualization_client
} // namespace tooling
27 changes: 27 additions & 0 deletions src/builtin-adapter/tools/namespace_heuristics.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_GOOGLE_TOOLING_TOOLS_NAMESPACE_HEURISTICS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_GOOGLE_TOOLING_TOOLS_NAMESPACE_HEURISTICS_H_

#include <string>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
namespace tooling {
namespace visualization_client {

// Obtains the best matching namespace for the TFLite node based on the provided
// node label and candidate names.
//
// The candidate names are obtained from the tensor names. The node namespace is
// obtained by the following steps:
// 1. If there are no candidate names, returns an empty string.
// 2. If there is only one candidate name, returns the candidate name.
// 3. If there are multiple candidate names, iterates backwards and returns the
// candidate name with the minimum edit distance to the node label.
std::string TfliteNodeNamespaceHeuristic(
absl::string_view node_label,
absl::Span<const std::string> candidate_names);

} // namespace visualization_client
} // namespace tooling

#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_GOOGLE_TOOLING_TOOLS_NAMESPACE_HEURISTICS_H_
52 changes: 13 additions & 39 deletions src/builtin-adapter/translate_helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ limitations under the License.

#include "translate_helpers.h"

#include <cstddef>
#include <string>
#include <utility>
#include <vector>
Expand All @@ -25,7 +24,7 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
Expand All @@ -50,6 +49,7 @@ limitations under the License.
#include "status_macros.h"
#include "tools/attribute_printer.h"
#include "tools/load_opdefs.h"
#include "tools/namespace_heuristics.h"
#include "visualize_config.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
Expand Down Expand Up @@ -266,13 +266,12 @@ llvm::StringRef GetTfNodeName(Operation& operation) {

// Generates the node name (the hierarchical information of the node) from a tfl
// dialect operation.
llvm::StringRef GenerateTfliteNodeName(llvm::StringRef node_id_str,
llvm::StringRef node_label,
Operation& operation) {
std::string GenerateTfliteNodeName(llvm::StringRef node_label,
Operation& operation) {
auto fusedLoc = operation.getLoc()->findInstanceOf<mlir::FusedLoc>();
auto nameLoc = operation.getLoc()->findInstanceOf<mlir::NameLoc>();
if (fusedLoc == nullptr && nameLoc == nullptr) {
return kEmptyString;
return "";
}
// In TFLite, we store op's output tensor names in location attribute. So it
// could be either a simple NameLoc of the original node_name; or a special
Expand All @@ -290,39 +289,15 @@ llvm::StringRef GenerateTfliteNodeName(llvm::StringRef node_id_str,
// concatenated together with semicolons. In this case, we will find the last
// single node name that contains this node label. If no matching found, we
// will return the first single node name by default.
llvm::SmallVector<llvm::StringRef, 4> candidate_names;
for (const llvm::StringRef tensor_name : tensor_names) {
llvm::SmallVector<llvm::StringRef, 4> tmp_names;
tensor_name.split(tmp_names, kSemicolonSeparator, /*MaxSplit=*/-1,
/*KeepEmpty=*/false);
for (const llvm::StringRef name : tmp_names) {
candidate_names.push_back(name);
std::vector<std::string> candidate_names;
for (absl::string_view tensor_name : tensor_names) {
std::vector<std::string> tmp_names =
absl::StrSplit(tensor_name, ';', absl::SkipEmpty());
for (absl::string_view name : tmp_names) {
candidate_names.push_back(std::string(name));
}
}
if (candidate_names.empty()) {
return kEmptyString;
}
if (candidate_names.size() == 1) {
return candidate_names.front();
}
// Removes any underscores in `node_label`.
const std::string node_label_substr =
absl::StrReplaceAll(node_label, {{"_", ""}});
// We iterate backwards to find if a single node name contains the node
// label in the end hierarchy.
for (auto it = candidate_names.rbegin(); it != candidate_names.rend(); ++it) {
llvm::StringRef name = *it;
llvm::StringRef last_substr = name;
const size_t start_pos = name.find_last_of('/');
if (start_pos != std::string::npos) {
last_substr = name.substr(start_pos);
}
if (last_substr.contains_insensitive(node_label_substr)) {
return name;
}
}

return candidate_names.front();
return TfliteNodeNamespaceHeuristic(node_label, candidate_names);
}

// Gets a list of output tensor name(s) of an TFLite operation. Returns empty
Expand Down Expand Up @@ -505,8 +480,7 @@ absl::StatusOr<Subgraph> TfliteFunctionToSubgraph(const VisualizeConfig& config,
node_label = kGraphOutputs.str();
}
seen_ops.insert({&operation, node_id});
llvm::StringRef node_name =
GenerateTfliteNodeName(node_id, node_label, operation);
std::string node_name = GenerateTfliteNodeName(node_label, operation);
GraphNodeBuilder builder;
builder.SetNodeInfo(node_id, node_label, node_name);
AppendNodeAttrs(config.const_element_count_limit, operation, builder);
Expand Down

0 comments on commit 265b7cd

Please sign in to comment.