Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Iterators] Test cases for folding a stream into a memref. #672

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,63 @@ def Iterators_PrintOp : Iterators_Base_Op<"print", [
// High-level iterators
//===----------------------------------------------------------------------===//

def Iterators_AccumulateOp : Iterators_Op<"accumulate", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Accumulate the elements of a stream into one element";
let description = [{
Accumulate the elements of the input stream into a single element, i.e.,
compute their generalized sum. This is similar to
[`std::accumulate`](https://en.cppreference.com/w/cpp/algorithm/accumulate)
in C++ and
[`functools.reduce`](https://docs.python.org/3/library/functools.html#functools.reduce)
with *initializer* in Python. The accumulator is initialized with thevalue
provided value (which must be of the same type as the elements of the result
stream); the logic of the accumulation is given by the provided accumulate
function.

Pseudo-code:
```
accumulator = initVal
while (next = upstream->next()):
accumulator = @accumulateFuncRef(accumulator, next->value())
return accumulator

Example:
```mlir
%input = ...
%zero_tuple = ...
%0 = iterators.accumulate(%input, %zero_tuple) with @sum
: (!iterators.stream<i64>) -> !iterators.stream<tuple<i64>>
```
}];
let arguments = (ins
Iterators_Stream:$input,
AnyType:$initVal,
FlatSymbolRefAttr:$accumulateFuncRef
);
let results = (outs Iterators_Stream:$result);
let assemblyFormat = [{
`(` $input `,` $initVal `)` `with` $accumulateFuncRef attr-dict `:`
`(` qualified(type($input)) `)` `->` qualified(type($result))
custom<AccumulateInitValType>(type($initVal), ref(type($result)))
}];
let extraClassDeclaration = [{
/// Return the accumulate function op that the accumulateFuncRef refers to.
func::FuncOp getAccumulateFunc() {
return SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
*this, getAccumulateFuncRefAttr());
}
}];
let extraClassDefinition = [{
/// Implement OpAsmOpInterface.
void $cppClass::getAsmResultNames(
llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
setNameFn(getResult(), "accumulated");
}
}];
}

def Iterators_ConstantStreamOp : Iterators_Op<"constantstream", [
PredOpTrait<"element type of return type must be tuple with matching types",
CPred<[{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,23 @@ class StateTypeComputer {
TypeConverter typeConverter;
};

/// The state of AccumulateOp consists of the state of its upstream iterator,
/// i.e., the state of the iterator that produces its input stream, the initial
/// value of the accumulator, and a Boolean indicating whether the iterator has
/// returned a result already (which is initialized to false and set to true in
/// the first call to next in order to ensure that only a single result is
/// returned).
template <>
StateType
StateTypeComputer::operator()(AccumulateOp op,
llvm::SmallVector<StateType> upstreamStateTypes) {
MLIRContext *context = op->getContext();
Type hasReturned = IntegerType::get(context, /*width=*/1);
Type initValType = op.getInitVal().getType();
return StateType::get(context,
{upstreamStateTypes[0], initValType, hasReturned});
}

/// The state of ConstantStreamOp consists of a single number that corresponds
/// to the index of the next struct returned by the iterator.
template <>
Expand Down Expand Up @@ -180,6 +197,7 @@ mlir::iterators::IteratorAnalysis::IteratorAnalysis(
// TODO: Verify that operands do not come from bbArgs.
.Case<
// clang-format off
AccumulateOp,
ConstantStreamOp,
FilterOp,
MapOp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,244 @@ struct PrintOpLowering : public OpConversionPattern<PrintOp> {
}
};

//===----------------------------------------------------------------------===//
// AccumulateOp.
//===----------------------------------------------------------------------===//

/// Builds IR that opens the nested upstream iterator and sets `hasReturned` to
/// false. Possible output:
///
/// %0 = iterators.extractvalue %arg0[0] :
/// <!upstream_state, i1> -> !upstream_state
/// %1 = call @iterators.upstream.open.0(%0) :
/// (!upstream_state) -> !upstream_state
/// %2 = iterators.insertvalue %arg0[0] (%1 : !upstream_state) :
/// <!upstream_state, i1>
/// %false = arith.constant false
/// %3 = iterators.insertvalue %false into %2[1] :
/// !iterators.state<!upstream_state, i1>
static Value buildOpenBody(AccumulateOp op, OpBuilder &builder,
Value initialState,
ArrayRef<IteratorInfo> upstreamInfos) {
Location loc = op.getLoc();
ImplicitLocOpBuilder b(loc, builder);

Type upstreamStateType = upstreamInfos[0].stateType;

// Extract upstream state.
Value initialUpstreamState = b.create<iterators::ExtractValueOp>(
upstreamStateType, initialState, b.getIndexAttr(0));

// Call Open on upstream.
SymbolRefAttr openFunc = upstreamInfos[0].openFunc;
auto openCallOp =
b.create<func::CallOp>(openFunc, upstreamStateType, initialUpstreamState);

// Update upstream state.
Value updatedUpstreamState = openCallOp->getResult(0);
Value updatedState = b.create<iterators::InsertValueOp>(
initialState, b.getIndexAttr(0), updatedUpstreamState);

// Reset hasReturned to false.
Value constFalse = b.create<arith::ConstantIntOp>(/*value=*/0, /*width=*/1);
updatedState = b.create<iterators::InsertValueOp>(
updatedState, b.getIndexAttr(2), constFalse);

return updatedState;
}

/// Builds IR that consumes all elements of the upstream iterator and combines
/// them into a single one using the given accumulate function. Pseudo-code:
///
/// if hasReturned: return {}
/// hasReturned = True
/// accumulator = initVal
/// while (next = upstream->Next()):
/// accumulator = accumulate(accumulator, next)
/// return accumulator
///
/// Possible output:
///
/// %upstream_state = iterators.extractvalue %arg0[0] : !state_type
/// %init_val = iterators.extractvalue %arg0[1] : !state_type
/// %has_returned = iterators.extractvalue %arg0[2] : !state_type
/// %2:2 = scf.if %2 -> (!upstream_state, !element_type) {
/// scf.yield %upstream_state, %init_val : !upstream_state, !element_type
/// } else {
/// %5:3 = scf.while (%arg1 = %upsteram_state, %arg2 = %init_val) :
/// (!upstream_state, !element_type) ->
/// (!upstream_state, !element_type, !element_type) {
/// %6:3 = func.call @iterators.upstream.next.0(%arg1) :
/// (!upstream_state) -> (!upstream_state, i1, !element_type)
/// scf.condition(%6#1) %8#0, %arg2, %8#2 :
/// !upstream_state, !element_type, !element_type
//// } do {
/// ^bb0(%arg1: !upstream_state, %arg2: !element_type, %arg3: !element_type):
/// %8 = func.call @accumulate_func(%arg2, %arg3) :
/// (!element_type, !element_type) -> !element_type
/// scf.yield %arg1, %8 : !upstream_state, !element_type
/// }
/// scf.yield %7#0, %7#1 : !upstream_state, !element_type
/// }
/// %true = arith.constant true
/// %4 = arith.xori %true, %1 : i1
/// %state_0 = iterators.insertvalue %3#0 into %arg0[0] : !state_type
/// %state_1 = iterators.insertvalue %true into %state_0[1] : !state_type
static llvm::SmallVector<Value, 4>
buildNextBody(AccumulateOp op, OpBuilder &builder, Value initialState,
ArrayRef<IteratorInfo> upstreamInfos, Type elementType) {
Location loc = op.getLoc();
ImplicitLocOpBuilder b(loc, builder);
Type i1 = b.getI1Type();

// Extract input element type.
StreamType inputStreamType = op.getInput().getType().cast<StreamType>();
Type inputElementType = inputStreamType.getElementType();

// Extract upstream state and init value.
Type upstreamStateType = upstreamInfos[0].stateType;
Value initialUpstreamState = b.create<iterators::ExtractValueOp>(
upstreamStateType, initialState, b.getIndexAttr(0));
Value initValue = b.create<iterators::ExtractValueOp>(
elementType, initialState, b.getIndexAttr(1));

// Check if the iterator has returned an element already (since it should
// return one only in the first call to next).
Value hasReturned =
b.create<iterators::ExtractValueOp>(i1, initialState, b.getIndexAttr(2));
SmallVector<Type> ifReturnTypes{upstreamStateType, elementType};
auto ifOp = b.create<scf::IfOp>(
hasReturned,
/*thenBuilder=*/
[&](OpBuilder &builder, Location loc) {
ImplicitLocOpBuilder b(loc, builder);

// Don't modify state; return init value.
b.create<scf::YieldOp>(ValueRange{initialUpstreamState, initValue});
},
/*elseBuilder=*/
[&](OpBuilder &builder, Location loc) {
ImplicitLocOpBuilder b(loc, builder);

// Create while loop using init value as initial accumulator.
SmallVector<Value> whileInputs = {initialUpstreamState, initValue};
SmallVector<Type> whileResultTypes = {
upstreamStateType, // Updated upstream state.
elementType, // Accumulator.
inputElementType // Element from last next call.
};
scf::WhileOp whileOp = b.create<scf::WhileOp>(
whileResultTypes, whileInputs,
/*beforeBuilder=*/
[&](OpBuilder &builder, Location loc, ValueRange args) {
ImplicitLocOpBuilder b(loc, builder);

Value upstreamState = args[0];
Value accumulator = args[1];

// Call next function.
SmallVector<Type> nextResultTypes = {upstreamStateType, i1,
inputElementType};
SymbolRefAttr nextFunc = upstreamInfos[0].nextFunc;
auto nextCall = b.create<func::CallOp>(nextFunc, nextResultTypes,
upstreamState);

Value updatedUpstreamState = nextCall->getResult(0);
Value hasNext = nextCall->getResult(1);
Value maybeNextElement = nextCall->getResult(2);
b.create<scf::ConditionOp>(
hasNext, ValueRange{updatedUpstreamState, accumulator,
maybeNextElement});
},
/*afterBuilder=*/
[&](OpBuilder &builder, Location loc, ValueRange args) {
ImplicitLocOpBuilder b(loc, builder);

Value upstreamState = args[0];
Value accumulator = args[1];
Value nextElement = args[2];

// Call accumulate function.
auto accumulateCall =
b.create<func::CallOp>(elementType, op.getAccumulateFuncRef(),
ValueRange{accumulator, nextElement});
Value newAccumulator = accumulateCall->getResult(0);

b.create<scf::YieldOp>(ValueRange{upstreamState, newAccumulator});
});

Value updatedState = whileOp->getResult(0);
Value accumulator = whileOp->getResult(1);

b.create<scf::YieldOp>(ValueRange{updatedState, accumulator});
});

// Compute hasNext: we have an element iff we have not returned before, i.e.,
// iff "not hasReturend". We simulate "not" with "xor true".
Value constTrue = b.create<arith::ConstantIntOp>(/*value=*/1, /*width=*/1);
Value hasNext = b.create<arith::XOrIOp>(constTrue, hasReturned);

// Update state.
Value finalUpstreamState = ifOp->getResult(0);
Value finalState = b.create<iterators::InsertValueOp>(
initialState, b.getIndexAttr(0), finalUpstreamState); // upstreamState
finalState = b.create<iterators::InsertValueOp>(finalState, b.getIndexAttr(2),
constTrue); // hasReturned
Value nextElement = ifOp->getResult(1);

return {finalState, hasNext, nextElement};
}

/// Builds IR that closes the nested upstream iterator. Possible output:
///
/// %0 = iterators.extractvalue %arg0[0] :
/// !iterators.state<!upstream_state, i1> -> !upstream_state
/// %1 = call @iterators.upstream.close.0(%0) :
/// (!upstream_state) -> !upstream_state
/// %2 = iterators.insertvalue %arg0[0] (%1 : !upstream_state) :
/// !iterators.state<!upstream_state, i1>
static Value buildCloseBody(AccumulateOp op, OpBuilder &builder,
Value initialState,
ArrayRef<IteratorInfo> upstreamInfos) {
Location loc = op.getLoc();
ImplicitLocOpBuilder b(loc, builder);

Type upstreamStateType = upstreamInfos[0].stateType;

// Extract upstream state.
Value initialUpstreamState = b.create<iterators::ExtractValueOp>(
upstreamStateType, initialState, b.getIndexAttr(0));

// Call Close on upstream.
SymbolRefAttr closeFunc = upstreamInfos[0].closeFunc;
auto closeCallOp = b.create<func::CallOp>(closeFunc, upstreamStateType,
initialUpstreamState);

// Update upstream state.
Value updatedUpstreamState = closeCallOp->getResult(0);
return b
.create<iterators::InsertValueOp>(initialState, b.getIndexAttr(0),
updatedUpstreamState)
.getResult();
}

/// Builds IR that initializes the iterator state with the state of the upstream
/// iterator. Possible output:
///
/// %0 = ...
/// %1 = arith.constant false
/// %2 = iterators.createstate(%0, %1) : !iterators.state<!upstream_state, i1>
static Value buildStateCreation(AccumulateOp op, AccumulateOp::Adaptor adaptor,
OpBuilder &builder, StateType stateType) {
Location loc = op.getLoc();
ImplicitLocOpBuilder b(loc, builder);
Value upstreamState = adaptor.getInput();
Value initVal = adaptor.getInitVal();
Value constFalse = b.create<arith::ConstantIntOp>(/*value=*/0, /*width=*/1);
return b.create<iterators::CreateStateOp>(
stateType, ValueRange{upstreamState, initVal, constFalse});
}

//===----------------------------------------------------------------------===//
// ConstantStreamOp.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1543,6 +1781,7 @@ static Value buildOpenBody(Operation *op, OpBuilder &builder,
return llvm::TypeSwitch<Operation *, Value>(op)
.Case<
// clang-format off
AccumulateOp,
ConstantStreamOp,
FilterOp,
MapOp,
Expand All @@ -1563,6 +1802,7 @@ buildNextBody(Operation *op, OpBuilder &builder, Value initialState,
return llvm::TypeSwitch<Operation *, llvm::SmallVector<Value, 4>>(op)
.Case<
// clang-format off
AccumulateOp,
ConstantStreamOp,
FilterOp,
MapOp,
Expand All @@ -1584,6 +1824,7 @@ static Value buildCloseBody(Operation *op, OpBuilder &builder,
return llvm::TypeSwitch<Operation *, Value>(op)
.Case<
// clang-format off
AccumulateOp,
ConstantStreamOp,
FilterOp,
MapOp,
Expand All @@ -1603,6 +1844,7 @@ static Value buildStateCreation(IteratorOpInterface op, OpBuilder &builder,
return llvm::TypeSwitch<Operation *, Value>(op)
.Case<
// clang-format off
AccumulateOp,
ConstantStreamOp,
FilterOp,
MapOp,
Expand Down
Loading