[StablehloShapeRefinement] Skip constant folding of convert operations with dynamic shapes.

PiperOrigin-RevId: 629956742
gnecula authored and tensorflower-gardener committed May 2, 2024
1 parent bd1c3bf commit 05bc7e5
Showing 5 changed files with 109 additions and 8 deletions.
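
The fix itself is a small guard in StableHLO's ConvertOp folding pattern (see the StablehloRefineShapes.cpp hunks below): fold only when the result is an integer tensor with a fully static shape. A minimal sketch of that predicate, assuming MLIR's RankedTensorType/ShapedType API; isFoldableConvertResult is a hypothetical helper, not part of StableHLO:

#include "mlir/IR/BuiltinTypes.h"

// Constant folding a convert is safe only when the result element type is
// an integer and every dimension of the result shape is static.
static bool isFoldableConvertResult(mlir::RankedTensorType resultType) {
  return mlir::isa<mlir::IntegerType>(resultType.getElementType()) &&
         resultType.hasStaticShape();
}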
2 changes: 1 addition & 1 deletion tensorflow/core/ops/ops.pbtxt
@@ -1,4 +1,4 @@
-go/debugproto
+go/debugproto
 op {
   name: "Abort"
   attr {
39 changes: 39 additions & 0 deletions third_party/stablehlo/temporary.patch
@@ -2564,4 +2564,43 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c
+} // namespace experimental
+} // namespace stablehlo
+} // namespace mlir
diff --ruN a/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir
--- stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir
+++ stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir
@@ -356,6 +356,18 @@
%0 = stablehlo.constant dense<[true, false]> : tensor<2xi1>
%1 = stablehlo.convert %0 : (tensor<2xi1>) -> tensor<2xi64>
return %1 : tensor<2xi64>
+}
+
+// -----
+
+// CHECK-LABEL: func @eval_convert_dynamic_shape
+func.func @eval_convert_dynamic_shape() -> tensor<?xi32> {
+ // CHECK-NOT: stablehlo.convert
+ // CHECK: [[RESULT:%.*]] = stablehlo.constant dense<[3, 4]> : tensor<2xi32>
+ // CHECK: return [[RESULT]]
+ %0 = stablehlo.constant dense<[3, 4]> : tensor<2xi32>
+ %1 = stablehlo.convert %0 : (tensor<2xi32>) -> tensor<?xi32>
+ return %1 : tensor<?xi32>
}

// -----
diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
--- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
@@ -482,9 +482,10 @@
LogicalResult matchAndRewrite(ConvertOp op,
PatternRewriter& rewriter) const override {
auto resultType = op.getType();
- if (!isa<IntegerType>(resultType.getElementType()))
- return rewriter.notifyMatchFailure(op,
- "expected integer result tensor type");
+ if (!isa<IntegerType>(resultType.getElementType()) ||
+ !resultType.hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ op, "expected integer result tensor type with static shapes");
auto resultBitWidth = resultType.getElementType().getIntOrFloatBitWidth();
return evalElementwise(rewriter, op, [&](APSInt operand) {
return operand.extOrTrunc(resultBitWidth);
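
For context, APSInt::extOrTrunc is the whole fold kernel here: evalElementwise applies it per element to resize each value to the result bit width (zero-extending unsigned values, sign-extending signed ones, truncating when narrowing). A self-contained illustration using only LLVM's ADT headers; the values are made up for the example:

#include <cassert>
#include "llvm/ADT/APSInt.h"

int main() {
  // An unsigned 8-bit value, 200 (0xC8).
  llvm::APSInt v(llvm::APInt(8, 200), /*isUnsigned=*/true);
  // Widening to 64 bits zero-extends: the value stays 200.
  assert(v.extOrTrunc(64).getZExtValue() == 200);
  // Narrowing to 4 bits truncates: only the low nibble, 0x8, survives.
  assert(v.extOrTrunc(4).getZExtValue() == 8);
  return 0;
}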

39 changes: 39 additions & 0 deletions third_party/xla/third_party/stablehlo/temporary.patch
@@ -2564,4 +2564,43 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c
+} // namespace experimental
+} // namespace stablehlo
+} // namespace mlir
diff --ruN a/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir
--- stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir
+++ stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir
@@ -356,6 +356,18 @@
%0 = stablehlo.constant dense<[true, false]> : tensor<2xi1>
%1 = stablehlo.convert %0 : (tensor<2xi1>) -> tensor<2xi64>
return %1 : tensor<2xi64>
+}
+
+// -----
+
+// CHECK-LABEL: func @eval_convert_dynamic_shape
+func.func @eval_convert_dynamic_shape() -> tensor<?xi32> {
+ // CHECK-NOT: stablehlo.convert
+ // CHECK: [[RESULT:%.*]] = stablehlo.constant dense<[3, 4]> : tensor<2xi32>
+ // CHECK: return [[RESULT]]
+ %0 = stablehlo.constant dense<[3, 4]> : tensor<2xi32>
+ %1 = stablehlo.convert %0 : (tensor<2xi32>) -> tensor<?xi32>
+ return %1 : tensor<?xi32>
}

// -----
diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
--- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
@@ -482,9 +482,10 @@
LogicalResult matchAndRewrite(ConvertOp op,
PatternRewriter& rewriter) const override {
auto resultType = op.getType();
- if (!isa<IntegerType>(resultType.getElementType()))
- return rewriter.notifyMatchFailure(op,
- "expected integer result tensor type");
+ if (!isa<IntegerType>(resultType.getElementType()) ||
+ !resultType.hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ op, "expected integer result tensor type with static shapes");
auto resultBitWidth = resultType.getElementType().getIntOrFloatBitWidth();
return evalElementwise(rewriter, op, [&](APSInt operand) {
return operand.extOrTrunc(resultBitWidth);

11 changes: 6 additions & 5 deletions third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"

#include <cstdint>
#include <utility>
#include <vector>

#include "llvm/ADT/STLExtras.h"
@@ -311,13 +312,13 @@ IndexingMap ApplyIndexingOp::getIndexingMap() {
   unsigned num_dimensions = affine_map.getNumDims();
   std::vector<DimVar> dim_vars;
   dim_vars.reserve(num_dimensions);
-  for (int id = 0; id < num_dimensions; ++id) {
+  for (unsigned id = 0; id < num_dimensions; ++id) {
     dim_vars.push_back(DimVar{Interval{lower_bounds[id], upper_bounds[id]}});
   }
   unsigned num_symbols = affine_map.getNumSymbols();
   std::vector<RangeVar> range_vars;
   range_vars.reserve(num_symbols);
-  for (int id = 0; id < num_symbols; ++id) {
+  for (unsigned id = num_dimensions; id < num_symbols + num_dimensions; ++id) {
     range_vars.push_back(
         RangeVar{Interval{lower_bounds[id], upper_bounds[id]}});
   }
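
The corrected second loop encodes the layout of the op's bound arrays: dimension bounds occupy flat indices [0, num_dimensions) and symbol bounds follow at [num_dimensions, num_dimensions + num_symbols), so the old loop was reading dimension bounds for the symbols. A standalone sketch of that layout; the types here are stand-ins, not XLA's:

#include <cstdint>
#include <vector>

struct Interval { int64_t lo, hi; };

// lower/upper hold dimension bounds first, then symbol bounds, so
// symbol i lives at flat index num_dims + i.
static void splitBounds(const std::vector<int64_t>& lower,
                        const std::vector<int64_t>& upper,
                        unsigned num_dims, unsigned num_syms,
                        std::vector<Interval>& dim_bounds,
                        std::vector<Interval>& sym_bounds) {
  for (unsigned id = 0; id < num_dims; ++id)
    dim_bounds.push_back({lower[id], upper[id]});
  for (unsigned id = num_dims; id < num_dims + num_syms; ++id)
    sym_bounds.push_back({lower[id], upper[id]});
}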
@@ -338,16 +339,16 @@ struct SimplifyIndexingMap : public mlir::OpRewritePattern<ApplyIndexingOp> {
     bool is_simplified = indexing_map.Simplify(GetIndexingMapForInstruction);
 
     // Remove unused symbols.
-    auto unused_symbols_bit_vector = indexing_map.RemoveUnusedSymbols();
+    auto unused_symbols_bit_vector = indexing_map.RemoveUnusedVars();
     bool symbols_removed = !unused_symbols_bit_vector.empty();
 
     if (!is_simplified || !symbols_removed) {
       return rewriter.notifyMatchFailure(indexing_op,
                                          "IndexingMap stayed unchanged");
     }
     if (!unused_symbols_bit_vector.empty()) {
-      SmallVector<Value, 4> operands = indexing_op.getOperands().take_front(
-          indexing_map.GetDimensionCount());
+      SmallVector<Value, 4> operands;
+      operands.reserve(unused_symbols_bit_vector.count());
       for (int i = 0; i < unused_symbols_bit_vector.size(); ++i) {
         if (!unused_symbols_bit_vector[i]) {
           operands.push_back(indexing_op.getOperand(i));
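
The rewritten operand handling above keeps only the operands whose bit is clear in the mask returned by RemoveUnusedVars(). A self-contained sketch of that filtering step, with int standing in for mlir::Value; filterUnused is a hypothetical helper, not the XLA code:

#include <cstddef>
#include <vector>
#include "llvm/ADT/SmallBitVector.h"

// Drops every element whose index is set in `unused`, preserving order.
static std::vector<int> filterUnused(const std::vector<int>& operands,
                                     const llvm::SmallBitVector& unused) {
  std::vector<int> kept;
  kept.reserve(operands.size() - unused.count());
  for (size_t i = 0; i < operands.size(); ++i)
    if (!unused[i]) kept.push_back(operands[i]);
  return kept;
}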
26 changes: 24 additions & 2 deletions
@@ -1,18 +1,40 @@
 // RUN: mlir_fusions_opt %s --split-input-file -canonicalize | FileCheck %s
 
 #map0 = affine_map<()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2)>
-func.func @apply_indexing_no_dims(%s0: index, %s1: index) -> (index, index) {
+func.func @simplify_apply_indexing(%s0: index, %s1: index) -> (index, index) {
   %0:2 = xla_gpu.apply_indexing #map0 [%s0 in [-10, 10], %s1 in [0, 2]]
   func.return %0#0, %0#1 : index, index
 }
 // CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 + 1, s0 mod 2)>
 
-// CHECK-LABEL: func.func @apply_indexing_no_dims
+// CHECK-LABEL: func.func @simplify_apply_indexing
 // CHECK-SAME:  %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index)
 // CHECK:       xla_gpu.apply_indexing #[[$MAP]][%[[ARG_0]] in [-10, 10]]
 
 // -----
+
+#map0 = affine_map<(d0, d1, d2)[s0, s1] -> (1 + s0 + s1 mod 4 - s1, s0 mod 2, d0 + d2)>
+func.func @simplify_apply_indexing_remove_dims(%d0: index, %d1: index,
+    %d2: index, %s0: index, %s1: index) -> (index, index, index) {
+  %0:3 = xla_gpu.apply_indexing #map0
+    (%d0 in [0, 1], %d1 in [0, 2], %d2 in [0, 3])
+    [%s0 in [-11, 11], %s1 in [0, 3]]
+  func.return %0#0, %0#1, %0#2 : index, index, index
+}
+// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1)[s0] -> (s0 + 1, s0 mod 2, d0 + d1)>
+
+// CHECK-LABEL: func.func @simplify_apply_indexing_remove_dims
+// CHECK-SAME:  %[[ARG_0:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME:  %[[ARG_1:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME:  %[[ARG_2:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME:  %[[ARG_3:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME:  %[[ARG_4:[a-zA-Z0-9_]+]]: index)
+// CHECK:       xla_gpu.apply_indexing #[[$MAP]]
+// CHECK-SAME:  (%[[ARG_0]] in [0, 1], %[[ARG_2]] in [0, 3])
+// CHECK-SAME:  [%[[ARG_3]] in [-11, 11]]
+
+// -----
 
 #map0 = affine_map<(d0, d1)[s0] -> (d0 + s0, 4, d1, 1, s0)>
 func.func @fold_indexing_map_results(%d0: index, %d1: index, %s0: index)
     -> (index, index, index, index, index)
