Skip to content

Commit

Permalink
Return stream of tuples instead of LLVMStructs in ConstantStreamOp.
Browse files Browse the repository at this point in the history
  • Loading branch information
ingomueller-net committed Mar 31, 2023
1 parent 405b0f8 commit b6a3dc3
Show file tree
Hide file tree
Showing 19 changed files with 296 additions and 292 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,46 +79,38 @@ def Iterators_PrintOp : Iterators_Base_Op<"print", [
// High-level iterators
//===----------------------------------------------------------------------===//

/// Verifies that the element types of nested arrays in the $value array
/// correspond to the types of the LLVM-struct element type of the $result
/// Stream.
def Iterators_ValueMatchesElementTypePred
: CPred<[{$value.dyn_cast<ArrayAttr>().size() == 0 ||
$result.getType().dyn_cast<StreamType>().getElementType() ==
::mlir::LLVM::LLVMStructType::getLiteral(
$result.getType().getContext(),
::llvm::SmallVector<Type>(
::llvm::map_range(
$value.dyn_cast<::mlir::ArrayAttr>().begin()->dyn_cast<::mlir::ArrayAttr>(),
[](Attribute attr) { return attr.cast<TypedAttr>().getType(); }
)
)
)}]>;
def Iterators_ValueMatchesElementType
: PredOpTrait<"value type matches return type",
Iterators_ValueMatchesElementTypePred>;

def Iterators_ConstantStreamOp : Iterators_Op<"constantstream",
[Iterators_ValueMatchesElementType,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
def Iterators_ConstantStreamOp : Iterators_Op<"constantstream", [
PredOpTrait<"element type of return type must be tuple with matching types",
CPred<[{
$value.cast<::mlir::ArrayAttr>().size () == 0 ||
TupleType::get(
$value.getContext(),
::llvm::SmallVector<Type>(
::llvm::map_range(
$value.cast<::mlir::ArrayAttr>().begin()->cast<::mlir::ArrayAttr>(),
[](Attribute attr) { return attr.cast<TypedAttr>().getType(); }
))) ==
$result.getType().cast<StreamType>().getElementType()}]>>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let summary = "Produce a statically defined stream of elements";
let description = [{
Produces a stream of LLVM structs given in the array of arrays attribute
(each inner array being returned as a literal LLVM struct with the values
and types of the elements of that array). The inner arrays all have to have
matching types, i.e., the element at position i has to be the same for all
inner arrays, and the element type of the return Stream has to be the
corresponding literal LLVM struct. An empty array is allowed (in which case
the return Stream does not need to match anything).
Produces a stream of tuples given in the array of arrays attribute (each
inner array being returned as a built-in tuple with the values and types of
the elements of that array). The inner arrays all have to have matching
types, i.e., the element at position i has to be the same for all inner
arrays, and the element type of the return Stream has to be the
corresponding tuple tpye. An empty array is allowed (in which case the
return Stream does not need to match anything).

Example:
```mlir
%constantstream = "iterators.constantstream"() { value = [[42 : i32]] } :
() -> (!iterators.stream<!llvm.struct<(i32)>>)
() -> (!iterators.stream<tuple<i32>>)
```
}];
// TODO(ingomueller): Devise a lowering that allows to return non-LLVM types.
let arguments = (ins Iterators_HomogeneouslyTypedLLVMNumericArrayArrayAttr:$value);
let results = (outs Iterators_StreamOfLLVMStructOfNumerics:$result);
let results = (outs Iterators_StreamOfPrintableTuples:$result);
let extraClassDefinition = [{
/// Implement OpAsmOpInterface.
void $cppClass::getAsmResultNames(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,10 @@ class Iterators_StreamOf<Type elementType>
def Iterators_StreamOfLLVMStructOfNumerics
: Iterators_StreamOf<Iterators_LLVMStructOfNumerics>;

/// An Iterators stream of tuples of printable types.
def Iterators_StreamOfPrintableTuples
: Iterators_StreamOf<Iterators_TupleOfPrintableTypes>;

/// An Iterators stream of printable elements.
def Iterators_StreamOfPrintableElements
: Iterators_StreamOf<Iterators_PrintableType>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -408,20 +408,22 @@ static GlobalOp buildGlobalData(ConstantStreamOp op, OpBuilder &builder,
/// %0 = iterators.extractvalue %arg0[0] : !iterators.state<i32>
/// %c4_i32 = arith.constant 4 : i32
/// %1 = arith.cmpi slt, %0, %c4_i32 : i32
/// %2:2 = scf.if %1 -> (!iterators.state<i32>, !element_type) {
/// %2:2 = scf.if %1 -> (!iterators.state<i32>, !struct_tpe) {
/// %c1_i32 = arith.constant 1 : i32
/// %3 = arith.addi %0, %c1_i32 : i32
/// %4 = iterators.insertvalue %3 into %arg0[0] : !iterators.state<i32>
/// %5 = llvm.mlir.addressof @iterators.constant_stream_data.0 : !llvm.ptr
/// %6 = llvm.getelementptr %5[%0, 0] :
/// (!llvm.ptr<array<4 x !element_type>>, i32, i32)
/// -> !llvm.ptr, !element_type
/// %7 = llvm.load %6 : !llvm.ptr -> !element_type
/// scf.yield %4, %7 : !iterators.state<i32>, !element_type
/// (!llvm.ptr<array<4 x !struct_type>>, i32, i32)
/// -> !llvm.ptr, !struct_type
/// %7 = llvm.load %6 : !llvm.ptr -> !struct_type
/// scf.yield %4, %7 : !iterators.state<i32>, !struct_type
/// } else {
/// %3 = llvm.mlir.undef : !element_type
/// scf.yield %arg0, %3 : !iterators.state<i32>, !element_type
/// %4 = llvm.mlir.undef : !struct_tpe
/// scf.yield %arg0, %3 : !iterators.state<i32>, !struct_tpe
/// }
/// %3 = llvm.extractvalue %2#1[0] : !llvm.struct<(i32)>
/// %tuple = tuple.from_elements %3 : tuple<i32>
static llvm::SmallVector<Value, 4>
buildNextBody(ConstantStreamOp op, OpBuilder &builder, Value initialState,
ArrayRef<IteratorInfo> upstreamInfos, Type elementType) {
Expand All @@ -430,6 +432,8 @@ buildNextBody(ConstantStreamOp op, OpBuilder &builder, Value initialState,
MLIRContext *context = builder.getContext();
Type i32 = b.getI32Type();
Type opaquePtrType = LLVMPointerType::get(context);
auto tupleType = elementType.cast<TupleType>();
auto structType = LLVMStructType::getLiteral(context, tupleType.getTypes());

// Extract current index.
Value currentIndex = b.create<iterators::ExtractValueOp>(
Expand All @@ -456,26 +460,34 @@ buildNextBody(ConstantStreamOp op, OpBuilder &builder, Value initialState,
initialState, b.getIndexAttr(0), updatedCurrentIndex);

// Load element from global data at current index.
GlobalOp globalArray = buildGlobalData(op, b, elementType);
GlobalOp globalArray = buildGlobalData(op, b, structType);
Value globalPtr =
b.create<AddressOfOp>(opaquePtrType, globalArray.getName());
Value gep = b.create<GEPOp>(opaquePtrType, elementType, globalPtr,
Value gep = b.create<GEPOp>(opaquePtrType, structType, globalPtr,
ArrayRef<GEPArg>{currentIndex, 0});
Value nextElement = b.create<LoadOp>(elementType, gep);
Value nextStruct = b.create<LoadOp>(structType, gep);

b.create<scf::YieldOp>(ValueRange{updatedState, nextElement});
b.create<scf::YieldOp>(ValueRange{updatedState, nextStruct});
},
/*elseBuilder=*/
[&](OpBuilder &builder, Location loc) {
ImplicitLocOpBuilder b(loc, builder);

// Don't modify state; return undef element.
Value nextElement = b.create<UndefOp>(elementType);
b.create<scf::YieldOp>(ValueRange{initialState, nextElement});
Value nextStruct = b.create<UndefOp>(structType);
b.create<scf::YieldOp>(ValueRange{initialState, nextStruct});
});

// Convert LLVM struct to tuple.
Value nextStruct = ifOp.getResult(1);
SmallVector<Value> elements;
for (auto [i, fieldType] : llvm::enumerate(structType.getBody())) {
auto element = b.create<LLVM::ExtractValueOp>(fieldType, nextStruct, i);
elements.push_back(element);
}
Value nextElement = b.create<tuple::FromElementsOp>(elementType, elements);

Value finalState = ifOp->getResult(0);
Value nextElement = ifOp.getResult(1);
return {finalState, hasNext, nextElement};
}

Expand Down Expand Up @@ -781,8 +793,22 @@ buildNextBody(MapOp op, OpBuilder &builder, Value initialState,
[&](OpBuilder &builder, Location loc) {
// Return undefined value.
ImplicitLocOpBuilder b(loc, builder);
Value undef = b.create<LLVM::UndefOp>(elementType);
b.create<scf::YieldOp>(undef);
// TODO(ingomueller): Find a more extensible design.
Value defaultElement;
if (auto tupleType = elementType.dyn_cast<TupleType>()) {
// Special case for tuples: hope that field types are undef'able.
SmallVector<Value> fieldValues;
for (Type fieldType : tupleType.getTypes()) {
auto fieldValue = b.create<LLVM::UndefOp>(fieldType);
fieldValues.push_back(fieldValue);
}
defaultElement =
b.create<tuple::FromElementsOp>(tupleType, fieldValues);
} else {
// Default case: hope that type is undef'able.
defaultElement = b.create<LLVM::UndefOp>(elementType);
}
b.create<scf::YieldOp>(defaultElement);
});
Value mappedElement = ifOp.getResult(0);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
// RUN: iterators-opt %s -convert-iterators-to-llvm \
// RUN: | FileCheck --enable-var-scope %s

!element_type = !llvm.struct<(i32)>

// CHECK-LABEL: func private @iterators.constantstream.close.{{[0-9]+}}(%{{.*}}: !iterators.state<i32>) -> !iterators.state<i32>
// CHECK-NEXT: return %[[arg0:.*]] : !iterators.state<i32>
// CHECK-NEXT: }
Expand All @@ -28,7 +26,7 @@
// CHECK-NEXT: llvm.return %[[V16]] : !llvm.array<4 x struct<(i32)>>
// CHECK-NEXT: }

// CHECK-LABEL: func private @iterators.constantstream.next.{{[0-9]+}}(%{{.*}}: !iterators.state<i32>) -> (!iterators.state<i32>, i1, !llvm.struct<(i32)>)
// CHECK-LABEL: func private @iterators.constantstream.next.{{[0-9]+}}(%{{.*}}: !iterators.state<i32>) -> (!iterators.state<i32>, i1, tuple<i32>)
// CHECK-NEXT: %[[V0:.*]] = iterators.extractvalue %[[arg0:.*]][0] : !iterators.state<i32>
// CHECK-NEXT: %[[V1:.*]] = arith.constant 4 : i32
// CHECK-NEXT: %[[V2:.*]] = arith.cmpi slt, %[[V0]], %[[V1]] : i32
Expand All @@ -44,7 +42,9 @@
// CHECK-NEXT: %[[Vb:.*]] = llvm.mlir.undef : !llvm.struct<(i32)>
// CHECK-NEXT: scf.yield %[[arg0]], %[[Vb]] : !iterators.state<i32>, !llvm.struct<(i32)>
// CHECK-NEXT: }
// CHECK-NEXT: return %[[V3]]#0, %[[V2]], %[[V3]]#1 : !iterators.state<i32>, i1, !llvm.struct<(i32)>
// CHECK-NEXT: %[[Vc:.*]] = llvm.extractvalue %[[V3]]#1[0] : !llvm.struct<(i32)>
// CHECK-NEXT: %[[Vd:.*]] = tuple.from_elements %[[Vc]] : tuple<i32>
// CHECK-NEXT: return %[[V3]]#0, %[[V2]], %[[Vd]] : !iterators.state<i32>, i1, tuple<i32>
// CHECK-NEXT: }

// CHECK-LABEL: func private @iterators.constantstream.open.{{[0-9]+}}(%{{.*}}: !iterators.state<i32>) -> !iterators.state<i32>
Expand All @@ -57,7 +57,7 @@ func.func @main() {
// CHECK-LABEL: func.func @main()
%input = "iterators.constantstream"()
{ value = [[0 : i32], [1 : i32], [2 : i32], [3 : i32]] }
: () -> (!iterators.stream<!element_type>)
: () -> (!iterators.stream<tuple<i32>>)
// CHECK-NEXT: %[[V0:.*]] = arith.constant 0 : i32
// CHECK-NEXT: %[[V1:.*]] = iterators.createstate(%[[V0]]) : !iterators.state<i32>
return
Expand Down
38 changes: 18 additions & 20 deletions experimental/iterators/test/Conversion/IteratorsToLLVM/filter.mlir
Original file line number Diff line number Diff line change
@@ -1,35 +1,33 @@
// RUN: iterators-opt %s -convert-iterators-to-llvm \
// RUN: | FileCheck --enable-var-scope %s

!element_type = !llvm.struct<(i32)>

// CHECK-LABEL: func.func private @iterators.filter.close.{{[0-9]+}}(%{{.*}}: !iterators.state<!iterators.state<i32>>) -> !iterators.state<!iterators.state<i32>> {
// CHECK-NEXT: %[[V0:.*]] = iterators.extractvalue %[[arg0:.*]][0] : !iterators.state<!iterators.state<i32>>
// CHECK-NEXT: %[[V1:.*]] = call @iterators.{{[a-zA-Z]+}}.close.{{[0-9]+}}(%[[V0]]) : ([[upstreamStateType:.*]]) -> [[upstreamStateType]]
// CHECK-NEXT: %[[V2:.*]] = iterators.insertvalue %[[V1]] into %[[arg0]][0] : !iterators.state<!iterators.state<i32>>
// CHECK-NEXT: return %[[V2]] : !iterators.state<!iterators.state<i32>>
// CHECK-NEXT: }

// CHECK-LABEL: func.func private @iterators.filter.next.{{[0-9]+}}(%{{.*}}: !iterators.state<!iterators.state<i32>>) -> (!iterators.state<!iterators.state<i32>>, i1, !llvm.struct<(i32)>)
// CHECK-LABEL: func.func private @iterators.filter.next.{{[0-9]+}}(%{{.*}}: !iterators.state<!iterators.state<i32>>) -> (!iterators.state<!iterators.state<i32>>, i1, tuple<i32>)
// CHECK-NEXT: %[[V0:.*]] = iterators.extractvalue %[[arg0:.*]][0] : !iterators.state<!iterators.state<i32>>
// CHECK-NEXT: %[[V1:.*]]:3 = scf.while (%[[arg1:.*]] = %[[V0]]) : ([[upstreamStateType:.*]]) -> ([[upstreamStateType]], i1, !llvm.struct<(i32)>) {
// CHECK-NEXT: %[[V3:.*]]:3 = func.call @iterators.{{[a-zA-Z]+}}.next.0(%[[arg1]]) : ([[upstreamStateType]]) -> ([[upstreamStateType]], i1, !llvm.struct<(i32)>)
// CHECK-NEXT: %[[V1:.*]]:3 = scf.while (%[[arg1:.*]] = %[[V0]]) : ([[upstreamStateType:.*]]) -> ([[upstreamStateType]], i1, tuple<i32>) {
// CHECK-NEXT: %[[V3:.*]]:3 = func.call @iterators.{{[a-zA-Z]+}}.next.0(%[[arg1]]) : ([[upstreamStateType]]) -> ([[upstreamStateType]], i1, tuple<i32>)
// CHECK-NEXT: %[[V4:.*]] = scf.if %[[V3]]#1 -> (i1) {
// CHECK-NEXT: %[[V7:.*]] = func.call @is_positive_struct(%[[V3]]#2) : (!llvm.struct<(i32)>) -> i1
// CHECK-NEXT: %[[V7:.*]] = func.call @is_positive_tuple(%[[V3]]#2) : (tuple<i32>) -> i1
// CHECK-NEXT: scf.yield %[[V7]] : i1
// CHECK-NEXT: } else {
// CHECK-NEXT: scf.yield %[[V3]]#1 : i1
// CHECK-NEXT: }
// CHECK-NEXT: %[[Vtrue:.*]] = arith.constant true
// CHECK-NEXT: %[[V5:.*]] = arith.xori %[[V4]], %[[Vtrue]] : i1
// CHECK-NEXT: %[[V6:.*]] = arith.andi %[[V3]]#1, %[[V5]] : i1
// CHECK-NEXT: scf.condition(%[[V6]]) %[[V3]]#0, %[[V3]]#1, %[[V3]]#2 : [[upstreamStateType]], i1, !llvm.struct<(i32)>
// CHECK-NEXT: scf.condition(%[[V6]]) %[[V3]]#0, %[[V3]]#1, %[[V3]]#2 : [[upstreamStateType]], i1, tuple<i32>
// CHECK-NEXT: } do {
// CHECK-NEXT: ^bb0(%[[arg2:.*]]: [[upstreamStateType]], %arg2: i1, %arg3: !llvm.struct<(i32)>):
// CHECK-NEXT: ^bb0(%[[arg2:.*]]: [[upstreamStateType]], %arg2: i1, %arg3: tuple<i32>):
// CHECK-NEXT: scf.yield %[[arg2]] : [[upstreamStateType]]
// CHECK-NEXT: }
// CHECK-NEXT: %[[V2:.*]] = iterators.insertvalue %[[V1]]#0 into %[[arg0]][0] : !iterators.state<!iterators.state<i32>>
// CHECK-NEXT: return %[[V2]], %[[V1]]#1, %[[V1]]#2 : !iterators.state<!iterators.state<i32>>, i1, !llvm.struct<(i32)>
// CHECK-NEXT: return %[[V2]], %[[V1]]#1, %[[V1]]#2 : !iterators.state<!iterators.state<i32>>, i1, tuple<i32>
// CHECK-NEXT: }

// CHECK-LABEL: func.func private @iterators.filter.open.{{[0-9]+}}(%{{.*}}: !iterators.state<!iterators.state<i32>>) -> !iterators.state<!iterators.state<i32>>
Expand All @@ -39,25 +37,25 @@
// CHECK-NEXT: return %[[V2]] : !iterators.state<!iterators.state<i32>>
// CHECK-NEXT: }

func.func private @is_positive_struct(%struct : !element_type) -> i1 {
// CHECK-LABEL: func.func private @is_positive_struct(%{{.*}}: !llvm.struct<(i32)>) -> i1 {
%i = llvm.extractvalue %struct[0] : !element_type
// CHECK-NEXT: %[[i:.*]] = llvm.extractvalue %[[struct:.*]][0] : !llvm.struct<(i32)>
// CHECK-LABEL: func.func private @is_positive_tuple(
// CHECK-SAME: %[[ARG0:.*]]: tuple<i32>) -> i1 {
// CHECK-DAG: %[[V0:.*]] = tuple.to_elements %[[ARG0]] : tuple<i32>
// CHECK-DAG: %[[V1:.*]] = arith.constant 0 : i32
// CHECK-NEXT: %[[V2:.*]] = arith.cmpi sgt, %[[V0]], %[[V1]] : i32
// CHECK-NEXT: return %[[V2]] : i1
func.func private @is_positive_tuple(%tuple : tuple<i32>) -> i1 {
%i = tuple.to_elements %tuple : tuple<i32>
%zero = arith.constant 0 : i32
// CHECK-NEXT: %[[zero:.*]] = arith.constant 0 : i32
%cmp = arith.cmpi "sgt", %i, %zero : i32
// CHECK-NEXT: %[[cmp:.*]] = arith.cmpi sgt, %[[i]], %[[zero]] : i32
return %cmp : i1
// CHECK-NEXT: return %[[cmp]] : i1
}
// CHECK-NEXT: }

func.func @main() {
// CHECK-LABEL: func.func @main()
%input = "iterators.constantstream"() { value = [] } : () -> (!iterators.stream<!element_type>)
%input = "iterators.constantstream"() { value = [] } : () -> (!iterators.stream<tuple<i32>>)
// CHECK: %[[V0:.*]] = iterators.createstate({{.*}}) : [[upstreamStateType:.*]]
%filter = "iterators.filter"(%input) {predicateRef = @is_positive_struct}
: (!iterators.stream<!element_type>) -> (!iterators.stream<!element_type>)
%filter = "iterators.filter"(%input) {predicateRef = @is_positive_tuple}
: (!iterators.stream<tuple<i32>>) -> (!iterators.stream<tuple<i32>>)
// CHECK-NEXT: %[[V1:.*]] = iterators.createstate(%[[V0]]) : !iterators.state<[[upstreamStateType]]>
return
// CHECK-NEXT: return
Expand Down
Loading

0 comments on commit b6a3dc3

Please sign in to comment.