[StablehloShapeRefinement] Skip constant folding of convert operations with dynamic shapes.

PiperOrigin-RevId: 629956742
gnecula authored and tensorflower-gardener committed May 2, 2024
1 parent 0aed1f1 commit 953bd2b
Showing 17 changed files with 294 additions and 72 deletions.
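
For orientation before the file-by-file diff: a minimal sketch of the check this commit adds to the `ConvertOp` constant-folding pattern in `StablehloRefineShapes.cpp`. The header paths, the helper name `checkFoldableConvert`, and the stated rationale (that the fold would otherwise have to materialize a `DenseElementsAttr`, which needs a statically shaped tensor type) are illustrative assumptions, not part of the commit; the authoritative change is the `hasStaticShape()` guard shown in the hunks below.

```cpp
// Compilable sketch only; assumes the MLIR/StableHLO headers below and the
// surrounding pattern boilerplate from StablehloRefineShapes.cpp.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "stablehlo/dialect/StablehloOps.h"

// Returns success only for cases the folder is now allowed to touch: an
// integer element type *and* a fully static result shape. Dynamic result
// shapes are rejected and left for shape refinement to resolve first.
mlir::LogicalResult checkFoldableConvert(mlir::stablehlo::ConvertOp op,
                                         mlir::PatternRewriter& rewriter) {
  auto resultType = op.getType();
  if (!mlir::isa<mlir::IntegerType>(resultType.getElementType()) ||
      !resultType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        op, "expected integer result tensor type with static shapes");
  return mlir::success();
}
```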
39 changes: 39 additions & 0 deletions third_party/stablehlo/temporary.patch
@@ -2564,4 +2564,43 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c
+} // namespace experimental
+} // namespace stablehlo
+} // namespace mlir
diff --ruN a/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir
--- stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir
+++ stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir
@@ -356,6 +356,18 @@
%0 = stablehlo.constant dense<[true, false]> : tensor<2xi1>
%1 = stablehlo.convert %0 : (tensor<2xi1>) -> tensor<2xi64>
return %1 : tensor<2xi64>
+}
+
+// -----
+
+// CHECK-LABEL: func @eval_convert_dynamic_shape
+func.func @eval_convert_dynamic_shape() -> tensor<?xi32> {
+ // CHECK-NOT: stablehlo.convert
+ // CHECK: [[RESULT:%.*]] = stablehlo.constant dense<[3, 4]> : tensor<2xi32>
+ // CHECK: return [[RESULT]]
+ %0 = stablehlo.constant dense<[3, 4]> : tensor<2xi32>
+ %1 = stablehlo.convert %0 : (tensor<2xi32>) -> tensor<?xi32>
+ return %1 : tensor<?xi32>
}

// -----
diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
--- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
@@ -482,9 +482,10 @@
LogicalResult matchAndRewrite(ConvertOp op,
PatternRewriter& rewriter) const override {
auto resultType = op.getType();
- if (!isa<IntegerType>(resultType.getElementType()))
- return rewriter.notifyMatchFailure(op,
- "expected integer result tensor type");
+ if (!isa<IntegerType>(resultType.getElementType()) ||
+ !resultType.hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ op, "expected integer result tensor type with static shapes");
auto resultBitWidth = resultType.getElementType().getIntOrFloatBitWidth();
return evalElementwise(rewriter, op, [&](APSInt operand) {
return operand.extOrTrunc(resultBitWidth);

39 changes: 39 additions & 0 deletions third_party/xla/third_party/stablehlo/temporary.patch
@@ -2564,4 +2564,43 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c
+} // namespace experimental
+} // namespace stablehlo
+} // namespace mlir
diff --ruN a/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir
--- stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir
+++ stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir
@@ -356,6 +356,18 @@
%0 = stablehlo.constant dense<[true, false]> : tensor<2xi1>
%1 = stablehlo.convert %0 : (tensor<2xi1>) -> tensor<2xi64>
return %1 : tensor<2xi64>
+}
+
+// -----
+
+// CHECK-LABEL: func @eval_convert_dynamic_shape
+func.func @eval_convert_dynamic_shape() -> tensor<?xi32> {
+ // CHECK-NOT: stablehlo.convert
+ // CHECK: [[RESULT:%.*]] = stablehlo.constant dense<[3, 4]> : tensor<2xi32>
+ // CHECK: return [[RESULT]]
+ %0 = stablehlo.constant dense<[3, 4]> : tensor<2xi32>
+ %1 = stablehlo.convert %0 : (tensor<2xi32>) -> tensor<?xi32>
+ return %1 : tensor<?xi32>
}

// -----
diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
--- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
@@ -482,9 +482,10 @@
LogicalResult matchAndRewrite(ConvertOp op,
PatternRewriter& rewriter) const override {
auto resultType = op.getType();
- if (!isa<IntegerType>(resultType.getElementType()))
- return rewriter.notifyMatchFailure(op,
- "expected integer result tensor type");
+ if (!isa<IntegerType>(resultType.getElementType()) ||
+ !resultType.hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ op, "expected integer result tensor type with static shapes");
auto resultBitWidth = resultType.getElementType().getIntOrFloatBitWidth();
return evalElementwise(rewriter, op, [&](APSInt operand) {
return operand.extOrTrunc(resultBitWidth);

5 changes: 3 additions & 2 deletions third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc
@@ -83,7 +83,6 @@ absl::Status MlirConcatenateFusion::EmitEntryFunction(
const HloFusionInstruction& fusion) const {
const auto& root_computation = computations.FindPartitionedComputation(
fusion.fused_instructions_computation());
const auto* concat = &analysis_.fusion_hero(0).instruction();
mlir::ImplicitLocOpBuilder builder(entry_function.getLoc(), entry_function);
builder.setInsertionPointToStart(entry_function.addEntryBlock());
auto* ctx = entry_function.getContext();
@@ -101,8 +100,10 @@ absl::Status MlirConcatenateFusion::EmitEntryFunction(
ComputeThreadIdToInputIndexing(
/*root_index=*/0, /*hero_operand_index=*/0, ctx)
.value();
auto epilogue_indexing = ComputeEpilogueInputToOutputIndexing(concat, ctx);
auto epilogue_indexing = ComputeEpilogueInputToOutputIndexing(
analysis_.fusion_hero(0), analysis_.fusion_root(0), ctx);

const auto* concat = &analysis_.fusion_hero(0).instruction();
for (auto [operand_index, operand] : llvm::enumerate(concat->operands())) {
auto input_to_output_map =
*ComputeInputToOutputIndexing(concat, /*input_id=*/operand_index, ctx)
21 changes: 21 additions & 0 deletions third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc
@@ -417,6 +417,27 @@ TEST_F(MlirLoopFusionTest, TupleBitcast) {
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
}

TEST_F(MlirLoopFusionTest, DynamicSliceWith64BitInput) {
// Lowering this kernel with 32 bit indices causes an underflow of `c`,
// resulting in slicing the last four elements instead of the first four.
constexpr auto kHloString = R"(
%fused_computation {
%p0 = s64[] parameter(0)
%p1 = f64[5] parameter(1)
ROOT slice = f64[4] dynamic-slice(%p1, %p0), dynamic_slice_sizes={4}
}
ENTRY main {
%c = s64[] constant(-1000000000000)
%p0 = f64[5] parameter(0)
ROOT %fusion = f64[4]{0} fusion(%c, %p0), kind=kInput, calls=%fused_computation
})";
TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
// CHECK: dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 64 : i32>>
)"));
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
}

} // namespace
} // namespace gpu
} // namespace xla
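
A short aside on the `DynamicSliceWith64BitInput` test added above: the sketch below (plain C++, not from the commit) works through the index arithmetic the test comment alludes to, assuming the usual HLO dynamic-slice behavior of clamping the start index into `[0, dim_size - slice_size]`. With 64-bit indices the large negative start clamps to 0 (first four elements); truncated to 32 bits it wraps to a positive value that clamps to 1 (last four elements).

```cpp
// Rough illustration of the comment in DynamicSliceWith64BitInput; assumes
// HLO dynamic-slice clamps the start index to [0, dim_size - slice_size].
#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t start = -1000000000000;  // the constant %c from the test
  const int64_t dim_size = 5, slice_size = 4;

  // 64-bit indexing: clamp(-1e12, 0, 1) == 0 -> slice covers elements [0, 4).
  int64_t start64 = std::clamp<int64_t>(start, 0, dim_size - slice_size);

  // 32-bit indexing: on a two's-complement target the truncation wraps -1e12
  // to 727379968, which then clamps to 1 -> slice covers elements [1, 5),
  // i.e. the last four elements.
  int32_t truncated = static_cast<int32_t>(start);
  int64_t start32 = std::clamp<int64_t>(truncated, 0, dim_size - slice_size);

  std::cout << "64-bit start: " << start64 << ", 32-bit start: " << start32
            << " (truncated value " << truncated << ")\n";
  return 0;
}
```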
@@ -89,6 +89,7 @@ absl::flat_hash_map<const HloInstruction*, int> PartitionGraphByIndexing(
for (auto* user : instr->users()) {
auto user_indexing = indexing_for_instr(user);
if (user->opcode() == HloOpcode::kConcatenate ||
user->opcode() == HloOpcode::kSelect ||
user->opcode() == HloOpcode::kTuple ||
(instr_indexing && user_indexing != *instr_indexing)) {
instr_indexing = std::nullopt;
@@ -151,10 +152,9 @@ EpilogueSpecification EpilogueSpecification::FromOutputIndexing(
result.index_ranges.push_back(sym.upper + 1);
}
}

auto* hero = root_to_hero[root];
auto epilogue_indexing =
ComputeEpilogueInputToOutputIndexing(hero, mlir_context);
auto epilogue_indexing = ComputeEpilogueInputToOutputIndexing(
{*hero, &analysis.fusion()}, {*root, &analysis.fusion()}, mlir_context);
auto root_indexing = ComposeIndexingMaps(*indexing, epilogue_indexing);

result.root_indexing.push_back(root_indexing.GetAffineMap());
@@ -102,7 +102,8 @@ class PartitionedComputation {
// The roots (return values of the function).
std::vector<const HloInstruction*> roots;

// The ranges of the indices that the subgraph is called with.
// The ranges of the indices that the subgraph is called with (dimensions
// and symbols).
std::vector<int64_t> index_ranges;

// Maps from raw indices to root indices.
@@ -1174,27 +1174,29 @@ absl::StatusOr<SmallVector<Value>> SubgraphToMlir(

emit_instr = [&](const HloInstruction* instr,
ValueRange indices) -> absl::StatusOr<SmallVector<Value>> {
// TODO(jreiffers): Check dominance, e.g.:
//
// padding_value = log(param)
// pad = pad(bar, padding_value)
// broadcast = broadcast(padding_value)
// pad + broadcast
//
// If padding_value was first emitted in the context of pad, it'll be
// inside an scf.if. For now this doesn't matter, because the indexing
// is considered to be different, but once the partitioner is smarter,
// it will matter.
//
// Also, this caching should be combined with parameter caching.
std::vector<void*> indices_ptrs;
indices_ptrs.reserve(indices.size());
for (auto index : indices) {
indices_ptrs.push_back(index.getAsOpaquePointer());
}
auto& entry = cached_instructions[std::make_pair(instr, indices_ptrs)];
// Only use the entry if its parent block is still in scope. Note that this
// should always be the case normally - if not, we risk exponential code
// size.
if (!entry.empty()) {
return entry;
auto* entry_block = entry.front().getParentBlock();
auto* insertion_block = builder.getInsertionBlock();
while (insertion_block != nullptr) {
if (insertion_block == entry_block) return entry;
if (insertion_block->getParentOp()) {
insertion_block = insertion_block->getParentOp()->getBlock();
} else {
insertion_block = nullptr;
VLOG(2) << "Failed dominance check while looking up cache for "
<< instr->ToShortString()
<< ". This is a bug in the computation partitioner.";
}
}
}

TF_ASSIGN_OR_RETURN(auto lowered_instr,
@@ -149,8 +149,31 @@ bool Needs64Bits(const Shape& shape) {
: absl::c_any_of(shape.tuple_shapes(), Needs64Bits);
}

bool Is64BitIndex(const HloInstruction* instr, int operand) {
const auto& shape = instr->operand(operand)->shape();
return shape.element_type() == PrimitiveType::S64 ||
shape.element_type() == PrimitiveType::U64;
}

bool Needs64BitIndices(const HloComputation* computation) {
for (auto* instr : computation->instructions()) {
// Check if any HLO instructions directly take 64 bit indices as operands.
switch (instr->opcode()) {
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
for (int i = 1; i < instr->operand_count(); ++i) {
if (Is64BitIndex(instr, i)) return true;
}
break;
case HloOpcode::kGather:
case HloOpcode::kScatter:
CHECK(instr->shape().IsArray()) << "Variadic scatter is unsupported.";
if (Is64BitIndex(instr, 1)) return true;
break;
default:
break;
}

if (Needs64Bits(instr->shape()) ||
absl::c_any_of(instr->called_computations(), Needs64BitIndices)) {
return true;
54 changes: 37 additions & 17 deletions third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc
@@ -104,6 +104,10 @@ struct MlirReductionFusion::EmitterState {
MlirReductionFusion::MlirReductionFusion(const HloFusionAnalysis& analysis)
: ReductionFusionBase(analysis) {
absl::flat_hash_set<const HloInstruction*> seen_heroes;
const auto& is_reduction_root =
reduction_info().GetGroups().is_reduction_root;
first_reduction_root_index_ = std::distance(
is_reduction_root.begin(), absl::c_find(is_reduction_root, true));
for (auto [root, hero, is_reduction] :
llvm::zip(analysis.fusion_roots(), analysis.fusion_heroes(),
reduction_info().GetGroups().is_reduction_root)) {
@@ -116,8 +120,7 @@ MlirReductionFusion::MlirReductionFusion(const HloFusionAnalysis& analysis)

bool MlirReductionFusion::IsSupported(const HloFusionAnalysis& analysis) {
auto info = ReductionInfo::Create(analysis);
return info.GetGroups().grouped_roots.size() == 1 && info.IsRaceFree() &&
!absl::c_linear_search(info.GetGroups().is_reduction_root, false);
return info.GetGroups().grouped_roots.size() == 1 && info.IsRaceFree();
}

std::vector<mlir_converter::EpilogueSpecification>
@@ -201,25 +204,24 @@ absl::Status MlirReductionFusion::EmitReduction(EmitterState& state) const {
}
bool use_shared = !shared_tile_size.empty();

auto thread_has_output =
mlir_converter::CheckConstraints(*ComputeThreadIdToOutputIndexing(0, ctx),
thread_and_block_indices, {}, builder);
auto thread_has_output = mlir_converter::CheckConstraints(
*ComputeThreadIdToOutputIndexing(first_reduction_root_index_, ctx),
thread_and_block_indices, {}, builder);

HloValueMap inits;
llvm::SmallVector<Value> outputs =
mlir::ValueRange(state.entry_function.getArguments().drop_front(
state.fusion.fused_parameters().size()));
HloValueMap root_output_indices;
llvm::SmallVector<Value> epilogue_input_indices;
llvm::SmallVector<Value> epilogue_input_dims;
const auto& epilogue = state.computations.epilogues().front();
epilogue_input_indices = EmitThreadAndBlockIds(builder);
int num_symbols = epilogue.root_indexing.front().getNumSymbols();
for (int i = 0; i < num_symbols; ++i) {
epilogue_input_indices.push_back(zero);
}
epilogue_input_dims = EmitThreadAndBlockIds(builder);
llvm::SmallVector<Value> epilogue_input_symbols(
epilogue.root_indexing.front().getNumSymbols(), zero);
for (auto [index, root] : llvm::enumerate(epilogue.roots)) {
root_output_indices[root] = mlir_converter::ApplyAffineMap(
epilogue.root_indexing[index], epilogue_input_indices, {}, builder);
epilogue.root_indexing[index], epilogue_input_dims,
epilogue_input_symbols, builder);
}

for (auto [index, hero] : llvm::enumerate(reduction_heroes_)) {
Expand All @@ -233,9 +235,11 @@ absl::Status MlirReductionFusion::EmitReduction(EmitterState& state) const {

auto evaluate_epilogue = [&](const HloValueMap& results,
llvm::SmallVector<Value> outputs) {
auto values = EmitEpilogue(/*epilogue_index=*/0, state.computations,
state.entry_function, results,
epilogue_input_indices, builder);
auto epilogue_indices = epilogue_input_dims;
epilogue_indices.append(epilogue_input_symbols);
auto values =
EmitEpilogue(/*epilogue_index=*/0, state.computations,
state.entry_function, results, epilogue_indices, builder);
const auto& epilogue = state.computations.epilogues().front();
for (auto root : epilogue.roots) {
for (auto [result_index, result] : llvm::enumerate(values.at(root))) {
@@ -348,6 +352,13 @@ HloValueMap MlirReductionFusion::EmitterState::EmitPerThreadReducedElements(
tile_indexing.GetAffineMap(), dim_values, symbol_values, builder);

llvm::SmallVector<Value> results;
struct SideOutput {
Value tensor;
llvm::SmallVector<Value> indices;
Value scalar;
int result_index;
};
llvm::SmallVector<SideOutput> side_outputs;
int start = 0;
for (auto [is_reduction, hero] :
llvm::zip(owner.reduction_info().GetGroups().is_reduction_root,
@@ -374,11 +385,20 @@ HloValueMap MlirReductionFusion::EmitterState::EmitPerThreadReducedElements(
Value value = mlir_converter::ProvideParameter(
computation, root_tuple, root_tuple->operand_index(hero),
input_indices, call_target, entry_function, builder);
results.push_back(builder.create<mlir::tensor::InsertOp>(
value, iter_args[start], input_indices));
// Tensor insertions turn into writes, so they have to happen in the
// end. This could be considered a bug in the lowering, but since we
// don't have bufferization, we need to handle it here.
side_outputs.push_back(
{iter_args[start], std::move(input_indices), value, start});
results.push_back(nullptr);
++start;
}
}
for (auto& side_output : side_outputs) {
results[side_output.result_index] =
builder.create<mlir::tensor::InsertOp>(
side_output.scalar, side_output.tensor, side_output.indices);
}
return results;
};

1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/fusions/reduction_mlir.h
@@ -59,6 +59,7 @@ class MlirReductionFusion : public ReductionFusionBase<MlirFusionEmitterBase> {
// The roots that have reduction heroes.
std::vector<const HloInstruction*> reduction_roots_;
std::vector<const HloInstruction*> side_output_roots_;
int first_reduction_root_index_;
};

} // namespace gpu
25 changes: 25 additions & 0 deletions third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc
@@ -372,6 +372,31 @@ TEST_F(ReductionTest, SideOutput) {
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
}

TEST_F(ReductionTest, BroadcastSideOutput) {
constexpr auto kHloString = R"(
%add {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
ROOT add = f32[] add(p0, p1)
}
%fusion {
%p0 = f32[6,6] parameter(0)
%c0 = f32[] constant(0)
%reduce = f32[] reduce(%p0, %c0), dimensions={0,1}, to_apply=%add
%broadcast = f32[6,6] broadcast(%reduce), dimensions={}
ROOT %tuple = (f32[6,6], f32[]) tuple(%broadcast, %reduce)
}
ENTRY main {
%p0 = f32[6,6] parameter(0)
ROOT %fusion = (f32[6,6], f32[]) fusion(%p0), kind=kInput, calls=%fusion
})";

TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
// CHECK: @fused_computation
)"));
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
}

} // namespace
} // namespace gpu
} // namespace xla
