Skip to content

Commit

Permalink
lower: fix a bug causing undefined variables when applying fuse
Browse files Browse the repository at this point in the history
Fixes tensor-compiler#355.

This commit fixes a bug where the fuse transformation would not generate
the necessary locator variables when applied to an iteration over two
dense index variables.
  • Loading branch information
rohany committed Jan 13, 2021
1 parent cb4731d commit dcc52a7
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 22 deletions.
2 changes: 1 addition & 1 deletion src/lower/iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ std::ostream& operator<<(std::ostream& os, const Iterator& iterator) {
if (iterator.isDimensionIterator()) {
return os << "\u0394" << iterator.getIndexVar().getName();
}
return os << iterator.getTensor();
return os << iterator.getTensor() << " " << iterator.getIndexVar();
}


Expand Down
59 changes: 38 additions & 21 deletions src/lower/lowerer_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,8 @@ Stmt LowererImpl::lowerForallDimension(Forall forall,
{
Expr coordinate = getCoordinateVar(forall.getIndexVar());

// cout << "Lowering forall dimension? " << forall << endl;

if (forall.getParallelUnit() != ParallelUnit::NotParallel && forall.getOutputRaceStrategy() == OutputRaceStrategy::Atomics) {
markAssignsAtomicDepth++;
atomicParallelUnit = forall.getParallelUnit();
Expand Down Expand Up @@ -2250,35 +2252,50 @@ Stmt LowererImpl::declLocatePosVars(vector<Iterator> locators) {
for (Iterator& locator : locators) {
accessibleIterators.insert(locator);

bool doLocate = true;
// Pull out some logic for constructing the locators for a given iterator.
auto addLocator = [&](const Iterator& iter) {
ModeFunction locate = iter.locate(coordinates(iter));
taco_iassert(isValue(locate.getResults()[1], true));
Stmt declarePosVar = VarDecl::make(iter.getPosVar(), locate.getResults()[0]);
result.push_back(declarePosVar);
};

// Look through all of the parent iterators. If any of these iterators
// are not accessible, we need to construct their accessors before emitting
// locator's accessors. This is because locator may use the ancestor's
// variables in its accessors. We add these ancestors into a vector and reverse
// it so that the highest parent in the tree's accessors get declared first.
std::vector<Iterator> ancestors;
for (Iterator ancestorIterator = locator.getParent();
!ancestorIterator.isRoot() && ancestorIterator.hasLocate();
ancestorIterator = ancestorIterator.getParent()) {
if (!accessibleIterators.contains(ancestorIterator)) {
doLocate = false;
// Since we're going to emit the locators for this iterator, add it to
// accessibleIterators so that other locators with this as an ancestor
// don't do the same.
accessibleIterators.insert(ancestorIterator);
ancestors.push_back(ancestorIterator);
}
}
for (auto it = ancestors.rbegin(); it != ancestors.rend(); it++) addLocator(*it);

if (doLocate) {
Iterator locateIterator = locator;
if (locateIterator.hasPosIter()) {
taco_iassert(!provGraph.isUnderived(locateIterator.getIndexVar()));
continue; // these will be recovered with separate procedure
}
do {
ModeFunction locate = locateIterator.locate(coordinates(locateIterator));
taco_iassert(isValue(locate.getResults()[1], true));
Stmt declarePosVar = VarDecl::make(locateIterator.getPosVar(),
locate.getResults()[0]);
result.push_back(declarePosVar);

if (locateIterator.isLeaf()) {
break;
}

locateIterator = locateIterator.getChild();
} while (accessibleIterators.contains(locateIterator));
Iterator locateIterator = locator;
// Position iterators will be recovered with a separate procedure, so
// don't emit anything if locator is one.
if (locateIterator.hasPosIter()) {
taco_iassert(!provGraph.isUnderived(locateIterator.getIndexVar()));
continue;
}

// Once all parent locators have been declared, add the target and all
// children locators.
do {
addLocator(locateIterator);
if (locateIterator.isLeaf()) {
break;
}
locateIterator = locateIterator.getChild();
} while (accessibleIterators.contains(locateIterator));
}
return result.empty() ? Stmt() : Block::make(result);
}
Expand Down
50 changes: 50 additions & 0 deletions test/tests-scheduling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,56 @@ TEST(scheduling, splitIndexStmt) {
ASSERT_TRUE(equals(a(i) = b(i), i2Forall.getStmt()));
}

// Regression test for tensor-compiler/taco#355: fusing loops that iterate
// over dense index variables must still emit the locator variables that
// accesses of the pre-fused variables depend on.
TEST(scheduling, fuseDenseLoops) {
  auto dim = 4;
  Tensor<int> A("A", {dim, dim, dim}, {Dense, Dense, Dense});
  Tensor<int> B("B", {dim, dim, dim}, {Dense, Dense, Dense});
  Tensor<int> expected("expected", {dim, dim, dim}, {Dense, Dense, Dense});
  for (int i = 0; i < dim; i++) {
    for (int j = 0; j < dim; j++) {
      for (int k = 0; k < dim; k++) {
        A.insert({i, j, k}, i + j + k);
        B.insert({i, j, k}, i + j + k);
        expected.insert({i, j, k}, 2 * (i + j + k));
      }
    }
  }
  A.pack();
  B.pack();
  expected.pack();

  // Helper function to evaluate the target statement and verify the results.
  // It takes in a function that applies some scheduling transforms to the
  // input IndexStmt, and applies it to the point-wise tensor addition below.
  // The test is structured this way as TACO does its best to avoid
  // re-compilation whenever possible; i.e. changing the stmt that a tensor is
  // compiled with doesn't cause compilation to occur again, so a fresh output
  // tensor is constructed for every schedule.
  auto testFn = [&](IndexStmt modifier (IndexStmt)) {
    Tensor<int> C("C", {dim, dim, dim}, {Dense, Dense, Dense});
    C(i, j, k) = A(i, j, k) + B(i, j, k);
    auto stmt = C.getAssignment().concretize();
    C.compile(modifier(stmt));
    C.evaluate();
    ASSERT_TRUE(equals(C, expected)) << endl << C << endl << expected << endl;
  };

  // First, a sanity check with no transformations.
  testFn([](IndexStmt stmt) { return stmt; });
  // Next, fuse the outer two loops. This tests the original bug in #355.
  testFn([](IndexStmt stmt) {
    IndexVar f("f");
    return stmt.fuse(i, j, f);
  });
  // Lastly, fuse all of the loops into a single loop. This ensures that
  // locators with a chain of ancestors have all of their dependencies
  // generated in a valid ordering.
  testFn([](IndexStmt stmt) {
    IndexVar f("f"), g("g");
    return stmt.fuse(i, j, f).fuse(f, k, g);
  });
}

TEST(scheduling, lowerDenseMatrixMul) {
Tensor<double> A("A", {4, 4}, {Dense, Dense});
Tensor<double> B("B", {4, 4}, {Dense, Dense});
Expand Down

0 comments on commit dcc52a7

Please sign in to comment.