-
Notifications
You must be signed in to change notification settings - Fork 76
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use edit distance to find the best node namespace for TFLite model
The previous method for associating TFLite model nodes with the most appropriate namespace relied on matching the final part of the candidate name with the node label. This approach can sometimes be unreliable, leading to potential confusion in the resulting graph representation. We are introducing an edit distance approach for more robust namespace identification, even when candidate names don't precisely match node labels. Key Design Points: - Backward Iteration: Prioritizing candidate names discovered later in the search often yields better results. Fused nodes typically appear later and tend to more closely mirror the original server model's node hierarchy. - Distance Threshold: To prevent irrelevant namespaces from being associated with node labels, a threshold is set at three times the length of the node label. If the edit distance exceeds this limit, the namespace is deemed irrelevant. - Default Behavior: In cases where no suitable namespace is found (all candidates are irrelevant), the first candidate name is used as the default. PiperOrigin-RevId: 671922889
- Loading branch information
1 parent
f8a66c6
commit 265b7cd
Showing
6 changed files
with
185 additions
and
80 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
#include "tools/namespace_heuristics.h" | ||
|
||
#include <algorithm> | ||
#include <limits> | ||
#include <string> | ||
#include <vector> | ||
|
||
#include "absl/strings/ascii.h" | ||
#include "absl/strings/str_replace.h" | ||
#include "absl/strings/string_view.h" | ||
#include "absl/types/span.h" | ||
|
||
namespace tooling { | ||
namespace visualization_client { | ||
namespace { | ||
|
||
// Gets the edit distance between two strings with the optimized memory usage. | ||
int EditDistance(absl::string_view shorter_str, absl::string_view longer_str) { | ||
// Ensure that 'shorter_str' is indeed the shorter of the two strings | ||
// This helps optimize space usage in the DP table | ||
const int short_len = shorter_str.size(); | ||
const int long_len = longer_str.size(); | ||
if (short_len > long_len) { | ||
return EditDistance(longer_str, shorter_str); | ||
} | ||
|
||
// 'prev_diag' stores the value from the previous diagonal in the DP table. | ||
int prev_diag; | ||
|
||
// 'curr_row' represents the current row in the DP table. Initialize it with | ||
// increasing values from 0 to 'short_len', representing the edit distance | ||
// when the longer string is empty. | ||
std::vector<int> curr_row(short_len + 1, 0); | ||
for (int j = 0; j <= short_len; j++) { | ||
curr_row[j] = j; | ||
} | ||
|
||
for (int i = 1; i <= long_len; i++) { | ||
prev_diag = curr_row[0]; | ||
// The first element in each row represents the edit distance when the | ||
// shorter string is empty. | ||
curr_row[0] = i; | ||
for (int j = 1; j <= short_len; j++) { | ||
int temp = curr_row[j]; | ||
// If the characters match, the edit distance is the same as the value on | ||
// the previous diagonal. | ||
if (longer_str[i - 1] == shorter_str[j - 1]) { | ||
curr_row[j] = prev_diag; | ||
} else { | ||
// If the characters don't match, the edit distance is 1 plus the | ||
// minimum of: | ||
// 1. Insertion: 'curr_row[j]' (current cell) | ||
// 2. Deletion: 'prev_diag' (top-left diagonal) | ||
// 3. Substitution: 'curr_row[j - 1]' (left cell) | ||
curr_row[j] = 1 + std::min({curr_row[j - 1], prev_diag, curr_row[j]}); | ||
} | ||
prev_diag = temp; | ||
} | ||
} | ||
// The final edit distance is stored in the last element of 'curr_row'. | ||
return curr_row[short_len]; | ||
} | ||
|
||
// Preprocesses the candidate name by obtaining the last chunk of the substring | ||
// separated by '/', removing the non-alphabetic characters and converting to | ||
// lower case. | ||
std::string PreprocessCandidateName(absl::string_view name) { | ||
const int start_pos = name.find_last_of('/'); | ||
std::string last_substr; | ||
if (start_pos != std::string::npos) { | ||
last_substr = name.substr(start_pos + 1, name.size()); | ||
} else { | ||
last_substr = name; | ||
} | ||
// Removes the non-alphabetic characters and converts to lower case. | ||
last_substr.erase( | ||
std::remove_if(last_substr.begin(), last_substr.end(), | ||
[](unsigned char c) { return !absl::ascii_isalpha(c); }), | ||
last_substr.end()); | ||
last_substr = absl::AsciiStrToLower(last_substr); | ||
return last_substr; | ||
} | ||
|
||
} // namespace | ||
|
||
std::string TfliteNodeNamespaceHeuristic( | ||
absl::string_view node_label, | ||
absl::Span<const std::string> candidate_names) { | ||
if (candidate_names.empty()) return ""; | ||
if (candidate_names.size() == 1) { | ||
return candidate_names[0]; | ||
} | ||
|
||
// Removes any underscores in `node_label`. | ||
const std::string node_label_substr = | ||
absl::StrReplaceAll(node_label, {{"_", ""}}); | ||
|
||
// Default the name to the first candidate name. | ||
std::string result_name = candidate_names[0]; | ||
int min_distance = std::numeric_limits<int>::max(); | ||
// Sets the max distance threshold to be three times the length of the node | ||
// label substring. If the distance is larger than the threshold, it's | ||
// considered as irrelevant. | ||
const int max_distance_threshold = 3 * node_label_substr.length(); | ||
// Iterates backwards is critical in finding a better match. | ||
for (auto name_it = std::rbegin(candidate_names); | ||
name_it != std::rend(candidate_names); ++name_it) { | ||
const std::string last_substr = PreprocessCandidateName(*name_it); | ||
// Skips the empty string to avoid false matching. | ||
if (last_substr.empty()) { | ||
continue; | ||
} | ||
int cur_distance = EditDistance(node_label_substr, last_substr); | ||
if (cur_distance > max_distance_threshold) { | ||
continue; | ||
} | ||
if (cur_distance < min_distance) { | ||
min_distance = cur_distance; | ||
result_name = *name_it; | ||
} | ||
} | ||
return result_name; | ||
} | ||
|
||
} // namespace visualization_client | ||
} // namespace tooling |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_GOOGLE_TOOLING_TOOLS_NAMESPACE_HEURISTICS_H_ | ||
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_GOOGLE_TOOLING_TOOLS_NAMESPACE_HEURISTICS_H_ | ||
|
||
#include <string> | ||
|
||
#include "absl/strings/string_view.h" | ||
#include "absl/types/span.h" | ||
namespace tooling { | ||
namespace visualization_client { | ||
|
||
// Obtains the best matching namespace for the TFLite node based on the provided | ||
// node label and candidate names. | ||
// | ||
// The candidate names are obtained from the tensor names. The node namespace is | ||
// obtained by the following steps: | ||
// 1. If there are no candidate names, returns an empty string. | ||
// 2. If there is only one candidate name, returns the candidate name. | ||
// 3. If there are multiple candidate names, iterates backwards and returns the | ||
// candidate name with the minimum edit distance to the node label. | ||
std::string TfliteNodeNamespaceHeuristic( | ||
absl::string_view node_label, | ||
absl::Span<const std::string> candidate_names); | ||
|
||
} // namespace visualization_client | ||
} // namespace tooling | ||
|
||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_GOOGLE_TOOLING_TOOLS_NAMESPACE_HEURISTICS_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters