Add lowering for TensorToStreamOp to LLVM.
ingomueller-net committed Mar 31, 2023
1 parent 45255aa commit cffa426
Showing 4 changed files with 280 additions and 0 deletions.
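
The op lowered here, iterators.tensor_to_stream, turns a tensor into a stream
of fixed-size slices. A typical use, taken from the tests added below:

  // Chop a static 6-element tensor into a stream of 2-element slices.
  %tensor = arith.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
  %stream = iterators.tensor_to_stream %tensor :
      tensor<6xi32> to !iterators.stream<tensor<2xi32>>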
@@ -113,6 +113,21 @@ StateType StateTypeComputer::operator()(
return StateType::get(context, {indexType, viewType});
}

/// The state of TensorToStreamOp consists of a single number, which corresponds
/// to the index of the next tensor slice returned by the iterator, and of the
/// input tensor. Pseudocode:
///
/// template <typename TensorType>
/// struct { size_t currentIndex; TensorType view; }
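///
/// For example, with an input of type tensor<?xi32> (as in the tests added
/// below), the resulting state type is
///
///   !iterators.state<index, tensor<?xi32>>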
template <>
StateType StateTypeComputer::operator()(
TensorToStreamOp op, llvm::SmallVector<StateType> /*upstreamStateTypes*/) {
MLIRContext *context = op->getContext();
Type indexType = IndexType::get(context);
Type tensorType = op.getInput().getType();
return StateType::get(context, {indexType, tensorType});
}

/// The state of ValueToStreamOp consists of a Boolean indicating whether it has
/// already returned its value (which is initialized to false and set to true in
/// the first call to next) and the value it converts to a stream.
@@ -173,6 +188,7 @@ mlir::iterators::IteratorAnalysis::IteratorAnalysis(
MapOp,
ReduceOp,
TabularViewToStreamOp,
TensorToStreamOp,
ValueToStreamOp
// clang-format on
>([&](auto op) {
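For context, IteratorAnalysis dispatches over the stream-producing ops with an
llvm::TypeSwitch, which the hunk above extends by one case. A minimal sketch of
that pattern (the surrounding code is truncated in this diff, so the names
stateType, stateTypeComputer, and upstreamStateTypes are assumptions):

  // Sketch: dispatch on the concrete op type and call the matching
  // StateTypeComputer overload to compute this op's state type.
  StateType stateType =
      llvm::TypeSwitch<Operation *, StateType>(op)
          .Case<
              // clang-format off
              MapOp,
              ReduceOp,
              TabularViewToStreamOp,
              TensorToStreamOp,
              ValueToStreamOp
              // clang-format on
              >([&](auto op) {
            return stateTypeComputer(op, upstreamStateTypes);
          });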
@@ -20,6 +20,7 @@
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
@@ -1219,6 +1220,149 @@ static Value buildStateCreation(TabularViewToStreamOp op,
ValueRange{initialIndex, tabularView});
}

//===----------------------------------------------------------------------===//
// TensorToStreamOp.
//===----------------------------------------------------------------------===//

/// Builds IR that (re)sets the current index to zero. Possible output:
///
/// %0 = arith.constant 0 : index
/// %1 = iterators.insertvalue %0 into %arg0[0] :
/// !iterators.state<index, !tensor_type>
static Value buildOpenBody(TensorToStreamOp op, OpBuilder &builder,
Value initialState,
ArrayRef<IteratorInfo> upstreamInfos) {
Location loc = op.getLoc();
ImplicitLocOpBuilder b(loc, builder);

// Insert constant zero into state.
Value zeroValue = b.create<arith::ConstantIndexOp>(0);
return b.create<iterators::InsertValueOp>(initialState, b.getIndexAttr(0),
zeroValue);
}

/// Builds IR that extracts a slice of the desired size from the input tensor at
/// the current index and increments that index by the output tensor size.
/// Pseudocode:
///
/// if current_index + output_size > len(input):
/// return {}
/// slice = input[current_index..current_index+output_size-1]
/// current_index += output_size
/// return slice
///
/// Possible output:
///
/// %0 = iterators.extractvalue %arg0[0] : !iterators.state<index, !tensor_type>
/// %1 = iterators.extractvalue %arg0[1] : !iterators.state<index, !tensor_type>
/// %c0 = arith.constant 0 : index
/// %dim = tensor.dim %1, %c0 : !tensor_type
/// %2 = arith.cmpi slt, %0, %dim : index
/// %3:2 = scf.if %2 -> (!iterators.state<index, !tensor_type>, tensor<2xi32>) {
/// %c2 = arith.constant 2 : index
/// %4 = arith.addi %c2, %0 : index
/// %state = iterators.insertvalue %4 into %arg0[0] :
/// !iterators.state<index, !tensor_type>
/// %extracted_slice = tensor.extract_slice %1[%0] [2] [1] :
/// !tensor_type to tensor<2xi32>
/// scf.yield %state, %extracted_slice :
/// !iterators.state<index, !tensor_type>, tensor<2xi32>
/// } else {
/// %extracted_slice = tensor.extract_slice %1[0] [2] [1] :
/// !tensor_type to tensor<2xi32>
/// scf.yield %arg0, %extracted_slice :
/// !iterators.state<index, !tensor_type>, tensor<2xi32>
/// }
static llvm::SmallVector<Value, 4>
buildNextBody(TensorToStreamOp op, OpBuilder &builder, Value initialState,
ArrayRef<IteratorInfo> upstreamInfos, Type elementType) {
Location loc = op->getLoc();
ImplicitLocOpBuilder b(loc, builder);
Type indexType = b.getIndexType();

RankedTensorType elementTensorType = elementType.cast<RankedTensorType>();
int64_t elementTensorSize = elementTensorType.getDimSize(0);

// Extract current index.
Value currentIndex = b.create<iterators::ExtractValueOp>(
indexType, initialState, b.getIndexAttr(0));

// Extract input tensor.
auto stateType = initialState.getType().cast<StateType>();
Type inputTensorType = stateType.getFieldTypes()[1];
Value inputTensor = b.create<iterators::ExtractValueOp>(
inputTensorType, initialState, b.getIndexAttr(1));

// Test if we have reached the end of the range.
Value lastIndex = b.create<tensor::DimOp>(inputTensor, 0);

ArithBuilder ab(b, b.getLoc());
Value hasNext = ab.slt(currentIndex, lastIndex);
auto ifOp = b.create<scf::IfOp>(
/*condition=*/hasNext,
/*thenBuilder=*/
[&](OpBuilder &builder, Location loc) {
ImplicitLocOpBuilder b(loc, builder);

// Increment index and update state.
Value increment = b.create<arith::ConstantIndexOp>(elementTensorSize);
Value updatedCurrentIndex =
b.create<arith::AddIOp>(indexType, increment, currentIndex);
Value updatedState = b.create<iterators::InsertValueOp>(
initialState, b.getIndexAttr(0), updatedCurrentIndex);

// Extract current slice from input tensor.
Value nextElement = b.create<tensor::ExtractSliceOp>(
elementTensorType, inputTensor,
/*offsets=*/ArrayRef<OpFoldResult>{currentIndex},
/*sizes=*/ArrayRef<OpFoldResult>{b.getIndexAttr(elementTensorSize)},
/*strides=*/ArrayRef<OpFoldResult>{b.getIndexAttr(1)});

b.create<scf::YieldOp>(ValueRange{updatedState, nextElement});
},
/*elseBuilder=*/
[&](OpBuilder &builder, Location loc) {
// Don't modify state; return first slice.
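// (Both branches of the scf.if must yield values of the same types, so we
// extract a slice here even though the consumer must not use it when
// hasNext is false.)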
// TODO(ingomueller): Is it always safe to extract that slice?
ImplicitLocOpBuilder b(loc, builder);
Value nextElement = b.create<tensor::ExtractSliceOp>(
elementTensorType, inputTensor,
/*offsets=*/ArrayRef<OpFoldResult>{b.getIndexAttr(0)},
/*sizes=*/ArrayRef<OpFoldResult>{b.getIndexAttr(elementTensorSize)},
/*strides=*/ArrayRef<OpFoldResult>{b.getIndexAttr(1)});
b.create<scf::YieldOp>(ValueRange{initialState, nextElement});
});

Value finalState = ifOp->getResult(0);
Value nextElement = ifOp.getResult(1);
return {finalState, hasNext, nextElement};
}

/// Builds IR that does nothing. The TensorToStreamOp does not need to do
/// anything on close.
static Value buildCloseBody(TensorToStreamOp /*op*/, OpBuilder & /*rewriter*/,
Value initialState,
ArrayRef<IteratorInfo> /*upstreamInfos*/) {
return initialState;
}

/// Builds IR that initializes the iterator state with the input tensor and an
/// initial current index (whose value doesn't matter). Possible output:
///
/// %0 = ...
/// %1 = arith.constant 0 : index
/// %2 = iterators.createstate(%1, %0) : !iterators.state<index, !tensor_type>
static Value buildStateCreation(TensorToStreamOp op,
TensorToStreamOp::Adaptor adaptor,
OpBuilder &builder, StateType stateType) {
Location loc = op.getLoc();
ImplicitLocOpBuilder b(loc, builder);
Value tensor = adaptor.getInput();
Value initialCurrentIndex = b.create<arith::ConstantIndexOp>(0);
return b.create<CreateStateOp>(stateType,
ValueRange{initialCurrentIndex, tensor});
}

//===----------------------------------------------------------------------===//
// ValueToStreamOp.
//===----------------------------------------------------------------------===//
@@ -1363,6 +1507,7 @@ static Value buildOpenBody(Operation *op, OpBuilder &builder,
MapOp,
ReduceOp,
TabularViewToStreamOp,
TensorToStreamOp,
ValueToStreamOp
// clang-format on
>([&](auto op) {
@@ -1382,6 +1527,7 @@ buildNextBody(Operation *op, OpBuilder &builder, Value initialState,
MapOp,
ReduceOp,
TabularViewToStreamOp,
TensorToStreamOp,
ValueToStreamOp
// clang-format on
>([&](auto op) {
@@ -1402,6 +1548,7 @@ static Value buildCloseBody(Operation *op, OpBuilder &builder,
MapOp,
ReduceOp,
TabularViewToStreamOp,
TensorToStreamOp,
ValueToStreamOp
// clang-format on
>([&](auto op) {
@@ -1420,6 +1567,7 @@ static Value buildStateCreation(IteratorOpInterface op, OpBuilder &builder,
MapOp,
ReduceOp,
TabularViewToStreamOp,
TensorToStreamOp,
ValueToStreamOp
// clang-format on
>([&](auto op) {
@@ -0,0 +1,50 @@
// RUN: iterators-opt %s \
// RUN: -convert-iterators-to-llvm \
// RUN: | FileCheck --enable-var-scope %s

// CHECK-LABEL: func private @iterators.tensor_to_stream.close.{{[0-9]+}}(
// CHECK-SAME: %[[ARG0:.*]]: !iterators.state<index, [[tensorType:.*]]>) ->
// CHECK-SAME: !iterators.state<index, [[tensorType]]> {
// CHECK-NEXT: return %[[arg0:.*]] : !iterators.state<index, [[tensorType]]>
// CHECK-NEXT: }

// CHECK-LABEL: func private @iterators.tensor_to_stream.next.{{[0-9]+}}(
// CHECK-SAME: %[[ARG0:.*]]: !iterators.state<index, [[tensorType:.*]]>) ->
// CHECK-SAME: (!iterators.state<index, [[tensorType]]>, i1, [[tensorSliceType:.*]]>) {
// CHECK-NEXT: %[[V0:.*]] = iterators.extractvalue %[[arg0:.*]][0] : !iterators.state<index, [[tensorType]]>
// CHECK-NEXT: %[[V1:.*]] = iterators.extractvalue %[[arg0]][1] : !iterators.state<index, [[tensorType]]>
// CHECK-NEXT: %[[Vx:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[V2:.*]] = tensor.dim %[[V1]], %[[Vx]] : [[tensorType]]
// CHECK-NEXT: %[[V3:.*]] = arith.cmpi slt, %[[V0]], %[[V2]] : index
// CHECK-NEXT: %[[V4:.*]]:2 = scf.if %[[V3]] -> (!iterators.state<index, [[tensorType]]>, [[tensorSliceType]]>) {
// CHECK-NEXT: %[[C1:.*]] = arith.constant 2 : index
// CHECK-NEXT: %[[V5:.*]] = arith.addi %[[C1]], %[[V0]] : index
// CHECK-NEXT: %[[V6:.*]] = iterators.insertvalue %[[V5]] into %[[arg0]][0] : !iterators.state<index, [[tensorType]]>
// CHECK-NEXT: %[[V7:.*]] = tensor.extract_slice %[[V1]][%[[V0]]] [2] [1] : [[tensorType]] to [[tensorSliceType]]>
// CHECK-NEXT: scf.yield %[[V6]], %[[V7]] : !iterators.state<index, [[tensorType]]>, [[tensorSliceType]]>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[V5:.*]] = tensor.extract_slice %[[V1]][0] [2] [1] : [[tensorType]] to [[tensorSliceType]]>
// CHECK-NEXT: scf.yield %[[arg0]], %[[V5]] : !iterators.state<index, [[tensorType]]>, [[tensorSliceType]]>
// CHECK-NEXT: }
// CHECK-NEXT: return %[[V4]]#0, %[[V3]], %[[V4]]#1 : !iterators.state<index, [[tensorType]]>, i1, [[tensorSliceType]]>
// CHECK-NEXT: }

// CHECK-LABEL: func private @iterators.tensor_to_stream.open.{{[0-9]+}}(
// CHECK-SAME: %[[ARG0:.*]]: !iterators.state<index, [[tensorType:.*]]>) ->
// CHECK-SAME: !iterators.state<index, [[tensorType]]> {
// CHECK-NEXT: %[[V0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[V1:.*]] = iterators.insertvalue %[[V0]] into %[[ARG0]][0] : !iterators.state<index, [[tensorType]]>
// CHECK-NEXT: return %[[V1]] : !iterators.state<index, [[tensorType]]>
// CHECK-NEXT: }

func.func @main(%tensor : tensor<?xi32>) {
// CHECK-LABEL: func.func @main(
// CHECK-SAME: %[[arg0:.*]]: [[tensorType:.*]]) {
%stream = iterators.tensor_to_stream %tensor :
tensor<?xi32> to !iterators.stream<tensor<2xi32>>
// CHECK-NEXT: %[[V1:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[V2:.*]] = iterators.createstate(%[[V1]], %[[arg0]]) : !iterators.state<index, [[tensorType]]>
return
// CHECK-NEXT: return
}
// CHECK-NEXT: }
@@ -0,0 +1,66 @@
// RUN: iterators-opt %s \
// RUN: -convert-iterators-to-llvm \
// RUN: -inline -decompose-iterator-states -canonicalize \
// RUN: -one-shot-bufferize=bufferize-function-boundaries \
// RUN: -expand-strided-metadata -finalize-memref-to-llvm \
// RUN: -lower-affine -canonicalize \
// RUN: -convert-scf-to-cf \
// RUN: -convert-func-to-llvm \
// RUN: -canonicalize \
// RUN: -convert-cf-to-llvm \
// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
// RUN: | FileCheck %s

!struct_i32i32 = !llvm.struct<(i32, i32)>

func.func private @tensor2xi32_to_struct(%input : tensor<2xi32>) -> !struct_i32i32 {
%zero = arith.constant 0 : index
%one = arith.constant 1 : index
%i0 = tensor.extract %input[%zero] : tensor<2xi32>
%i1 = tensor.extract %input[%one] : tensor<2xi32>
%undef = llvm.mlir.undef : !struct_i32i32
%inserted = llvm.insertvalue %i0, %undef[0] : !struct_i32i32
%result = llvm.insertvalue %i1, %inserted[1] : !struct_i32i32
return %result : !struct_i32i32
}

func.func @test_tensor_to_stream_simple_static() {
iterators.print("test_tensor_to_stream_simple_static")
%tensor = arith.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
%stream = iterators.tensor_to_stream %tensor :
tensor<6xi32> to !iterators.stream<tensor<2xi32>>
%mapped = "iterators.map"(%stream) {mapFuncRef = @tensor2xi32_to_struct}
: (!iterators.stream<tensor<2xi32>>) -> (!iterators.stream<!struct_i32i32>)
"iterators.sink"(%mapped) : (!iterators.stream<!struct_i32i32>) -> ()
// CHECK-LABEL: test_tensor_to_stream_simple_static
// CHECK-NEXT: (1, 2)
// CHECK-NEXT: (3, 4)
// CHECK-NEXT: (5, 6)
// CHECK-NEXT: -
return
}

func.func @test_tensor_to_stream_simple_dynamic(%tensor : tensor<?xi32>) {
iterators.print("test_tensor_to_stream_simple_dynamic")
%stream = iterators.tensor_to_stream %tensor :
tensor<?xi32> to !iterators.stream<tensor<2xi32>>
%mapped = "iterators.map"(%stream) {mapFuncRef = @tensor2xi32_to_struct}
: (!iterators.stream<tensor<2xi32>>) -> (!iterators.stream<!struct_i32i32>)
"iterators.sink"(%mapped) : (!iterators.stream<!struct_i32i32>) -> ()
// CHECK-LABEL: test_tensor_to_stream_simple_dynamic
// CHECK-NEXT: (11, 12)
// CHECK-NEXT: (13, 14)
// CHECK-NEXT: (15, 16)
// CHECK-NEXT: -
return
}

func.func @main() {
func.call @test_tensor_to_stream_simple_static() : () -> ()

%tensor = arith.constant dense<[11, 12, 13, 14, 15, 16]> : tensor<6xi32>
%dtensor = tensor.cast %tensor : tensor<6xi32> to tensor<?xi32>
func.call @test_tensor_to_stream_simple_dynamic(%dtensor) : (tensor<?xi32>) -> ()

return
}
