Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Iterators] Add TensorToStreamOp. #649

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion experimental/iterators/benchmarks/inner_product/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,8 @@ def compile(self):
emit_benchmarking_function('main_bench', main_func)
pm = PassManager.parse( # (Comment for better formatting.)
'convert-iterators-to-llvm,'
'convert-states-to-llvm,'
'decompose-iterator-states,'
'canonicalize,'
'expand-strided-metadata,'
'finalize-memref-to-llvm,'
'convert-scf-to-cf,'
Expand Down
2 changes: 2 additions & 0 deletions experimental/iterators/include/iterators-c/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ extern "C" {

#include "iterators/Conversion/Passes.capi.h.inc" // IWYU pragma: export

#include "iterators/Dialect/Iterators/Transforms/Passes.capi.h.inc" // IWYU pragma: export

#ifdef __cplusplus
}
#endif
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,49 @@ def Iterators_TabularViewToStreamOp : Iterators_Op<"tabular_view_to_stream",
}];
}

def Iterators_TensorToStreamOp : Iterators_Op<"tensor_to_stream",
[AllMatch<["$input.getType().cast<TensorType>().getRank()",
"$result.getType().cast<StreamType>().getElementType()"
" .cast<TensorType>().getRank()"],
"the input tensor and those in the result stream "
"have the same rank">,
AllMatch<["$input.getType().cast<TensorType>().getElementType()",
"$result.getType().cast<StreamType>().getElementType()"
" .cast<TensorType>().getElementType()"],
"the input tensor and those in the result stream "
"have the same element type">,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let summary = "Produces a stream of slices from a given tensor";
// TODO(ingomueller): allow for a padding value for non-dividing sizes?
// TODO(ingomueller): allow for strides? what would that even mean exactly?
// TODO(ingomueller): allow for defining the iteration order?
// TODO(ingomueller): expose iteration indices?
let description = [{
Produces a stream of all non-overlapping, statically shaped, unit-strided
slices from a given tensor in an implementation-defined order. The slices
are tensors of the same rank (and element type) as the input tensor (i.e.,
no rank reduction is done). In the dimensions where the resulting tensor
size divides the input tensor size, the resulting tensors fully cover the
input tensor; if the sizes do not divide, the remainders are dropped.

Example:
```mlir
%stream_of_tensors = iterators.tensor_to_stream %input :
tensor<?xi32> to !iterators.stream<tensor<4xi32>>
```
}];
let arguments = (ins AnyRankedTensor:$input);
let results = (outs Iterators_StreamOf<AnyStaticShapeTensor>:$result);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)";
let extraClassDefinition = [{
/// Implement OpAsmOpInterface.
void $cppClass::getAsmResultNames(
llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
setNameFn(getResult(), "fromtensor");
}
}];
}

/// The sink op is a special op that only consumes a stream of values and
/// produces nothing.
/// It is not marked with Iterators_IteratorOpInterface.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Iterators)
mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix Iterators)
mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix Iterators)
add_public_tablegen_target(MLIRIteratorsPassIncGen)
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//===- DecomposeIteratorStates.h - Pass Utilities ---------------*- C++ -*-===//
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef ITERATORS_DIALECT_ITERATORS_TRANSFORMS_DECOMPOSEITERATORSTATES_H
#define ITERATORS_DIALECT_ITERATORS_TRANSFORMS_DECOMPOSEITERATORSTATES_H

namespace mlir {
class RewritePatternSet;
class TypeConverter;
} // namespace mlir

namespace mlir {
namespace iterators {

void populateDecomposeIteratorStatesPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns);

} // namespace iterators
} // namespace mlir

#endif // ITERATORS_DIALECT_ITERATORS_TRANSFORMS_DECOMPOSEITERATORSTATES_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//===- Passes.h - Transform Pass Construction and Registration --*- C++ -*-===//
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines prototypes that expose pass constructors.
//
//===----------------------------------------------------------------------===//

#ifndef ITERATORS_DIALECT_ITERATORS_TRANSFORMS_PASSES_H
#define ITERATORS_DIALECT_ITERATORS_TRANSFORMS_PASSES_H

#include "mlir/Pass/Pass.h"

namespace mlir {

//===----------------------------------------------------------------------===//
// Construction
//===----------------------------------------------------------------------===//

/// Generate pass declarations.
#define GEN_PASS_DECL
#include "iterators/Dialect/Iterators/Transforms/Passes.h.inc"

/// Creates a pass that decomposes iterator states into individual values.
std::unique_ptr<Pass> createDecomposeIteratorStatesPass();

//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//

/// Generate the code for registering conversion passes.
#define GEN_PASS_REGISTRATION
#include "iterators/Dialect/Iterators/Transforms/Passes.h.inc"

} // namespace mlir

#endif // ITERATORS_DIALECT_ITERATORS_TRANSFORMS_PASSES_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
//===-- Passes.td - Transform pass definition file ---------*- tablegen -*-===//
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef ITERATORS_TRANSFORMS_PASSES
#define ITERATORS_TRANSFORMS_PASSES

include "mlir/Pass/PassBase.td"

//===----------------------------------------------------------------------===//
// IteratorsToLLVM
//===----------------------------------------------------------------------===//

def DecomposeIteratorStates : Pass<"decompose-iterator-states", "ModuleOp"> {
let summary = "Decompose iterator states into their constituent values";
let description = [{
Iterator state "bundle" values that constitute the current state of
iterators, which often includes the state of nested iterators. This pass
decomposes these bundles into their constituent values such that the
`iterators.state` type is completely eliminated. In particular, the
creation, field access, and field updates now simply forward SSA values,
which are then carried as individual arguments through `scf` and `func` ops.
This decomposition allows further passes to run without knowing anything
about iterators, i.e., it makes iterators composable with other passes.

Example:

```mlir
func.func @example(%arg : !iterators.state<i1, i32>) -> (!iterators.state<i1, i32>) {
%i1 = iterators.extractvalue %arg[0] : !iterators.state<i1, i32>
%result = scf.if %i1 -> !iterators.state<i1, i32> {
scf.yield %arg : !iterators.state<i1, i32>
} else {
%true = arith.constant 1 : i1
%updated = iterators.insertvalue %true into %arg[0] : !iterators.state<i1, i32>
scf.yield %updated : !iterators.state<i1, i32>
}
return %result : !iterators.state<i1, i32>
}
```

gets decomposed into

```mlir
func.func @example(%arg0: i1, %arg1: i32) -> (i1, i32) {
%true = arith.constant true
%0:2 = scf.if %arg0 -> (i1, i32) {
scf.yield %arg0, %arg1 : i1, i32
} else {
scf.yield %true, %arg1 : i1, i32
}
return %0#0, %0#1 : i1, i32
}
```
}];
let constructor = "mlir::createDecomposeIteratorStatesPass()";
let dependentDialects = [
"scf::SCFDialect",
"func::FuncDialect",
];
}

#endif // ITERATORS_TRANSFORMS_PASSES
2 changes: 2 additions & 0 deletions experimental/iterators/lib/CAPI/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
add_mlir_public_c_api_library(IteratorsCAPI
Dialects.cpp
Passes.cpp
Transforms.cpp
LINK_LIBS PUBLIC
MLIRIterators
MLIRIteratorsToLLVM
MLIRIteratorsTransforms
MLIRTabular
MLIRTabularToLLVM
MLIRPass
Expand Down
27 changes: 27 additions & 0 deletions experimental/iterators/lib/CAPI/Transforms.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//===- Transforms.cpp - C API for Transformations Passes ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// #include "iterators-c/Passes.h"
#include "iterators/Dialect/Iterators/Transforms/Passes.h"
#include "mlir/CAPI/Pass.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

// Must include the declarations as they carry important visibility attributes.
#include "iterators/Dialect/Iterators/Transforms/Passes.capi.h.inc"

#ifdef __cplusplus
extern "C" {
#endif

#include "iterators/Dialect/Iterators/Transforms/Passes.capi.cpp.inc"

#ifdef __cplusplus
}
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,21 @@ StateType StateTypeComputer::operator()(
return StateType::get(context, {indexType, viewType});
}

/// The state of TensorToStreamOp consists of a single number that corresponds
/// to the index of the next tensor slice returned by the iterator and the input
/// tensor. Pseudocode:
///
/// template <typename TensorType>
/// struct { size_t currentIndex; TensorType view; }
template <>
StateType StateTypeComputer::operator()(
TensorToStreamOp op, llvm::SmallVector<StateType> /*upstreamStateTypes*/) {
MLIRContext *context = op->getContext();
Type indexType = IndexType::get(context);
Type tensorType = op.getInput().getType();
return StateType::get(context, {indexType, tensorType});
}

/// The state of ValueToStreamOp consists a Boolean indicating whether it has
/// already returned its value (which is initialized to false and set to true in
/// the first call to next) and the value it converts to a stream.
Expand Down Expand Up @@ -173,6 +188,7 @@ mlir::iterators::IteratorAnalysis::IteratorAnalysis(
MapOp,
ReduceOp,
TabularViewToStreamOp,
TensorToStreamOp,
ValueToStreamOp
// clang-format on
>([&](auto op) {
Expand Down
Loading