Skip to content

Commit

Permalink
Skip constant folding of convert operations with dynamic shapes.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 629956742
  • Loading branch information
gnecula authored and tensorflower-gardener committed May 2, 2024
1 parent cce3f13 commit ccf07ce
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 0 deletions.
39 changes: 39 additions & 0 deletions third_party/stablehlo/temporary.patch
Original file line number Diff line number Diff line change
Expand Up @@ -2564,4 +2564,43 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c
+} // namespace experimental
+} // namespace stablehlo
+} // namespace mlir
diff --ruN a/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir
--- stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir
+++ stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir
@@ -356,6 +356,18 @@
%0 = stablehlo.constant dense<[true, false]> : tensor<2xi1>
%1 = stablehlo.convert %0 : (tensor<2xi1>) -> tensor<2xi64>
return %1 : tensor<2xi64>
+}
+
+// -----
+
+// CHECK-LABEL: func @eval_convert_dynamic_shape
+func.func @eval_convert_dynamic_shape() -> tensor<?xi32> {
+ // CHECK-NOT: stablehlo.convert
+ // CHECK: [[RESULT:%.*]] = stablehlo.constant dense<[3, 4]> : tensor<2xi32>
+ // CHECK: return [[RESULT]]
+ %0 = stablehlo.constant dense<[3, 4]> : tensor<2xi32>
+ %1 = stablehlo.convert %0 : (tensor<2xi32>) -> tensor<?xi32>
+ return %1 : tensor<?xi32>
}

// -----
diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
--- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
@@ -482,9 +482,10 @@
LogicalResult matchAndRewrite(ConvertOp op,
PatternRewriter& rewriter) const override {
auto resultType = op.getType();
- if (!isa<IntegerType>(resultType.getElementType()))
- return rewriter.notifyMatchFailure(op,
- "expected integer result tensor type");
+ if (!isa<IntegerType>(resultType.getElementType()) ||
+ !resultType.hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ op, "expected integer result tensor type with static shapes");
auto resultBitWidth = resultType.getElementType().getIntOrFloatBitWidth();
return evalElementwise(rewriter, op, [&](APSInt operand) {
return operand.extOrTrunc(resultBitWidth);

39 changes: 39 additions & 0 deletions third_party/xla/third_party/stablehlo/temporary.patch
Original file line number Diff line number Diff line change
Expand Up @@ -2564,4 +2564,43 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c
+} // namespace experimental
+} // namespace stablehlo
+} // namespace mlir
diff --ruN a/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir
--- stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir
+++ stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir
@@ -356,6 +356,18 @@
%0 = stablehlo.constant dense<[true, false]> : tensor<2xi1>
%1 = stablehlo.convert %0 : (tensor<2xi1>) -> tensor<2xi64>
return %1 : tensor<2xi64>
+}
+
+// -----
+
+// CHECK-LABEL: func @eval_convert_dynamic_shape
+func.func @eval_convert_dynamic_shape() -> tensor<?xi32> {
+ // CHECK-NOT: stablehlo.convert
+ // CHECK: [[RESULT:%.*]] = stablehlo.constant dense<[3, 4]> : tensor<2xi32>
+ // CHECK: return [[RESULT]]
+ %0 = stablehlo.constant dense<[3, 4]> : tensor<2xi32>
+ %1 = stablehlo.convert %0 : (tensor<2xi32>) -> tensor<?xi32>
+ return %1 : tensor<?xi32>
}

// -----
diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
--- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
@@ -482,9 +482,10 @@
LogicalResult matchAndRewrite(ConvertOp op,
PatternRewriter& rewriter) const override {
auto resultType = op.getType();
- if (!isa<IntegerType>(resultType.getElementType()))
- return rewriter.notifyMatchFailure(op,
- "expected integer result tensor type");
+ if (!isa<IntegerType>(resultType.getElementType()) ||
+ !resultType.hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ op, "expected integer result tensor type with static shapes");
auto resultBitWidth = resultType.getElementType().getIntOrFloatBitWidth();
return evalElementwise(rewriter, op, [&](APSInt operand) {
return operand.extOrTrunc(resultBitWidth);

0 comments on commit ccf07ce

Please sign in to comment.