Skip to content

Commit

Permalink
Fixed precompute transformation and attempt at fixing tensor-compiler…
Browse files Browse the repository at this point in the history
…#355. Also generate more optimized attribute query code for parallel sparse tensor addition
  • Loading branch information
stephenchouca committed Apr 30, 2021
1 parent 45ca20e commit 27e898c
Show file tree
Hide file tree
Showing 6 changed files with 585 additions and 110 deletions.
70 changes: 44 additions & 26 deletions include/taco/index_notation/index_notation.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class WindowedIndexVar;
class IndexSetVar;
class TensorVar;

class IndexStmt;
class IndexExpr;
class Assignment;
class Access;
Expand Down Expand Up @@ -63,6 +64,14 @@ struct SuchThatNode;
class IndexExprVisitorStrict;
class IndexStmtVisitorStrict;

/// Return true if the index statement is of the given subtype. The subtypes
/// are Assignment, Forall, Where, Sequence, and Multi.
template <typename SubType> bool isa(IndexExpr);

/// Casts the index statement to the given subtype. Assumes S is a subtype and
/// the subtypes are Assignment, Forall, Where, Sequence, and Multi.
template <typename SubType> SubType to(IndexExpr);

/// A tensor index expression describes a tensor computation as a scalar
/// expression where tensors are indexed by index variables (`IndexVar`). The
/// index variables range over the tensor dimensions they index, and the scalar
Expand Down Expand Up @@ -161,6 +170,12 @@ class IndexExpr : public util::IntrusivePtr<const IndexExprNode> {
/// Returns the schedule of the index expression.
const Schedule& getSchedule() const;

/// Casts index expression to specified subtype.
template <typename SubType>
SubType as() {
return to<SubType>(*this);
}

/// Visit the index expression's sub-expressions.
void accept(IndexExprVisitorStrict *) const;

Expand Down Expand Up @@ -204,14 +219,6 @@ IndexExpr operator*(const IndexExpr&, const IndexExpr&);
/// ```
IndexExpr operator/(const IndexExpr&, const IndexExpr&);

/// Return true if the index statement is of the given subtype. The subtypes
/// are Assignment, Forall, Where, Sequence, and Multi.
template <typename SubType> bool isa(IndexExpr);

/// Casts the index statement to the given subtype. Assumes S is a subtype and
/// the subtypes are Assignment, Forall, Where, Sequence, and Multi.
template <typename SubType> SubType to(IndexExpr);


/// An index expression that represents a tensor access, such as `A(i,j))`.
/// Access expressions are returned when calling the overloaded operator() on
Expand Down Expand Up @@ -514,6 +521,14 @@ class Reduction : public IndexExpr {
/// Create a summation index expression.
Reduction sum(IndexVar i, IndexExpr expr);

/// Return true if the index statement is of the given subtype. The subtypes
/// are Assignment, Forall, Where, Multi, and Sequence.
template <typename SubType> bool isa(IndexStmt);

/// Casts the index statement to the given subtype. Assumes S is a subtype and
/// the subtypes are Assignment, Forall, Where, Multi, and Sequence.
template <typename SubType> SubType to(IndexStmt);

/// A an index statement computes a tensor. The index statements are
/// assignment, forall, where, multi, and sequence.
class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
Expand Down Expand Up @@ -633,9 +648,9 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
///
/// Preconditions:
/// The index variable supplied to the coord transformation must be in
/// position space. The index variable supplied to the pos transformation
/// must be in coordinate space. The pos transformation also takes an
/// input to indicate which position space to use. This input must appear in the computation
/// position space. The index variable supplied to the pos transformation must
/// be in coordinate space. The pos transformation also takes an input to
/// indicate which position space to use. This input must appear in the computation
/// expression and also be indexed by this index variable. In the case that this
/// index variable is derived from multiple index variables, these variables must appear
/// directly nested in the mode ordering of this datastructure. This allows for
Expand All @@ -661,28 +676,38 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
/// to the pos transformation.
IndexStmt fuse(IndexVar i, IndexVar j, IndexVar f) const;

/// The precompute transformation is described in kjolstad2019
/// allows us to leverage scratchpad memories and
/// reorder computations to increase locality
/// The precompute transformation is described in kjolstad2019
/// allows us to leverage scratchpad memories and
/// reorder computations to increase locality
IndexStmt precompute(IndexExpr expr, IndexVar i, IndexVar iw, TensorVar workspace) const;

/// bound specifies a compile-time constraint on an index variable's
/// iteration space that allows knowledge of the
/// size or structured sparsity pattern of the inputs to be
/// incorporated during bounds propagatio
/// incorporated during bounds propagation
///
/// Preconditions:
/// The precondition for bound is that the computation bounds supplied are correct
/// given the inputs that this code will be run on.
/// The precondition for bound is that the computation bounds supplied are
/// correct given the inputs that this code will be run on.
IndexStmt bound(IndexVar i, IndexVar i1, size_t bound, BoundType bound_type) const;

/// The unroll
/// primitive unrolls the corresponding loop by a statically-known
/// The unroll primitive unrolls the corresponding loop by a statically-known
/// integer number of iterations
/// Preconditions: unrollFactor is a positive nonzero integer
IndexStmt unroll(IndexVar i, size_t unrollFactor) const;

/// The assemble primitive specifies whether a result tensor should be
/// assembled by appending or inserting nonzeros into the result tensor.
/// In the latter case, the transformation inserts additional loops to
/// precompute statistics about the result tensor that are required for
/// preallocating memory and coordinating insertions of nonzeros.
IndexStmt assemble(TensorVar result, AssembleStrategy strategy) const;

/// Casts index statement to specified subtype.
template <typename SubType>
SubType as() {
return to<SubType>(*this);
}
};

/// Check if two index statements are isomorphic.
Expand All @@ -694,13 +719,6 @@ bool equals(IndexStmt, IndexStmt);
/// Print the index statement.
std::ostream& operator<<(std::ostream&, const IndexStmt&);

/// Return true if the index statement is of the given subtype. The subtypes
/// are Assignment, Forall, Where, Multi, and Sequence.
template <typename SubType> bool isa(IndexStmt);

/// Casts the index statement to the given subtype. Assumes S is a subtype and
/// the subtypes are Assignment, Forall, Where, Multi, and Sequence.
template <typename SubType> SubType to(IndexStmt);

/// An assignment statement assigns an index expression to the locations in a
/// tensor given by an lhs access expression.
Expand Down
159 changes: 116 additions & 43 deletions src/index_notation/transformations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
#include <iostream>
#include <algorithm>
#include <limits>
#include <set>
#include <map>
#include <vector>

using namespace std;

Expand Down Expand Up @@ -171,6 +174,12 @@ static bool containsExpr(Assignment assignment, IndexExpr expr) {
IndexExpr expr;
bool contains = false;

void visit(const AccessNode* node) {
if (equals(IndexExpr(node), expr)) {
contains = true;
}
}

void visit(const UnaryExprNode* node) {
if (equals(IndexExpr(node), expr)) {
contains = true;
Expand Down Expand Up @@ -213,6 +222,60 @@ static Assignment getAssignmentContainingExpr(IndexStmt stmt, IndexExpr expr) {
return assignment;
}

static IndexStmt eliminateRedundantReductions(IndexStmt stmt,
const std::set<TensorVar>* const candidates = nullptr) {

struct ReduceToAssign : public IndexNotationRewriter {
using IndexNotationRewriter::visit;

const std::set<TensorVar>* const candidates;
std::map<TensorVar,std::set<IndexVar>> availableVars;

ReduceToAssign(const std::set<TensorVar>* const candidates) :
candidates(candidates) {}

IndexStmt rewrite(IndexStmt stmt) {
for (const auto& result : getResults(stmt)) {
availableVars[result] = {};
}
return IndexNotationRewriter::rewrite(stmt);
}

void visit(const ForallNode* op) {
for (auto& it : availableVars) {
it.second.insert(op->indexVar);
}
IndexNotationRewriter::visit(op);
for (auto& it : availableVars) {
it.second.erase(op->indexVar);
}
}

void visit(const WhereNode* op) {
const auto workspaces = getResults(op->producer);
for (const auto& workspace : workspaces) {
availableVars[workspace] = {};
}
IndexNotationRewriter::visit(op);
for (const auto& workspace : workspaces) {
availableVars.erase(workspace);
}
}

void visit(const AssignmentNode* op) {
const auto result = op->lhs.getTensorVar();
if (op->op.defined() &&
util::toSet(op->lhs.getIndexVars()) == availableVars[result] &&
(!candidates || util::contains(*candidates, result))) {
stmt = Assignment(op->lhs, op->rhs);
return;
}
stmt = op;
}
};
return ReduceToAssign(candidates).rewrite(stmt);
}

IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
INIT_REASON(reason);

Expand All @@ -229,30 +292,68 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {

Precompute precompute;

void visit(const ForallNode* node) {
Forall foralli(node);
void visit(const ForallNode* op) {
Forall foralli(op);
IndexVar i = precompute.geti();
IndexVar j = foralli.getIndexVar();

if (foralli.getIndexVar() == i) {
Assignment assign = getAssignmentContainingExpr(foralli,
precompute.getExpr());
if (j == i && assign.defined()) {
IndexStmt s = foralli.getStmt();
TensorVar ws = precompute.getWorkspace();
IndexExpr e = precompute.getExpr();
IndexVar iw = precompute.getiw();

IndexStmt consumer = forall(i, replace(s, {{e, ws(i)}}));
IndexStmt producer = forall(iw, ws(iw) = replace(e, {{i,iw}}));
IndexStmt producer = forall(iw, Assignment(ws(iw), replace(e, {{i,iw}}),
assign.getOperator()));
Where where(consumer, producer);

stmt = where;
return;
}
IndexNotationRewriter::visit(node);
}

IndexStmt s = rewrite(op->stmt);
if (s == op->stmt) {
stmt = op;
return;
} else if (isa<Where>(s)) {
Where body = to<Where>(s);
const auto consumerHasJ =
util::contains(body.getConsumer().getIndexVars(), j);
const auto producerHasJ =
util::contains(body.getProducer().getIndexVars(), j);
if (consumerHasJ && !producerHasJ) {
const auto producer = body.getProducer();
const auto consumer = Forall(op->indexVar, body.getConsumer(),
op->parallel_unit,
op->output_race_strategy,
op->unrollFactor);
stmt = Where(consumer, producer);
return;
} else if (producerHasJ && !consumerHasJ) {
const auto producer = Forall(op->indexVar, body.getProducer(),
op->parallel_unit,
op->output_race_strategy,
op->unrollFactor);
const auto consumer = body.getConsumer();
stmt = Where(consumer, producer);
return;
}
}
stmt = Forall(op->indexVar, s, op->parallel_unit,
op->output_race_strategy, op->unrollFactor);
}
};
PrecomputeRewriter rewriter;
rewriter.precompute = *this;
return rewriter.rewrite(stmt);
stmt = rewriter.rewrite(stmt);

// Convert redundant reductions to assignments
stmt = eliminateRedundantReductions(stmt);

return stmt;
}

void Precompute::print(std::ostream& os) const {
Expand Down Expand Up @@ -506,23 +607,24 @@ IndexStmt Parallelize::apply(IndexStmt stmt, std::string* reason) const {
Iterators iterators(foralli, tensorVars);
definedIndexVars.insert(foralli.getIndexVar());
MergeLattice lattice = MergeLattice::make(foralli, iterators, provGraph, definedIndexVars);
// Precondition 3: No parallelization of variables under a reduction
// Precondition 1: No parallelization of variables under a reduction
// variable (ie MergePoint has at least 1 result iterators)
if (parallelize.getOutputRaceStrategy() == OutputRaceStrategy::NoRaces && lattice.results().empty()
&& lattice != MergeLattice({MergePoint({iterators.modeIterator(foralli.getIndexVar())}, {}, {})})) {
if (parallelize.getOutputRaceStrategy() == OutputRaceStrategy::NoRaces &&
(lattice.results().empty() || lattice.results()[0].getIndexVar() != foralli.getIndexVar()) &&
lattice != MergeLattice({MergePoint({iterators.modeIterator(foralli.getIndexVar())}, {}, {})})) {
reason = "Precondition failed: Free variables cannot be dominated by reduction variables in the iteration graph, "
"as this causes scatter behavior and we do not yet emit parallel synchronization constructs";
return;
}

if (foralli.getIndexVar() == i) {
// Precondition 1: No coiteration of node (ie Merge Lattice has only 1 iterator)
// Precondition 2: No coiteration of mode (ie Merge Lattice has only 1 iterator)
if (lattice.iterators().size() != 1) {
reason = "Precondition failed: The loop must not merge tensor dimensions, that is, it must be a for loop;";
return;
}

// Precondition 2: Every result iterator must have insert capability
// Precondition 3: Every result iterator must have insert capability
for (Iterator iterator : lattice.results()) {
if (util::contains(assembledByUngroupedInsert, iterator.getTensor())) {
for (Iterator it = iterator; !it.isRoot(); it = it.getParent()) {
Expand Down Expand Up @@ -923,37 +1025,8 @@ IndexStmt SetAssembleStrategy::apply(IndexStmt stmt, string* reason) const {
}

// Convert redundant reductions to assignments
struct ReduceToAssign : public IndexNotationRewriter {
using IndexNotationRewriter::visit;

const std::set<TensorVar>& insertedResults;
std::set<IndexVar> availableVars;

ReduceToAssign(const std::set<TensorVar>& insertedResults) :
insertedResults(insertedResults) {}

void visit(const ForallNode* op) {
availableVars.insert(op->indexVar);
IndexNotationRewriter::visit(op);
availableVars.erase(op->indexVar);
}

void visit(const AssignmentNode* op) {
std::set<IndexVar> accessVars;
for (const auto& index : op->lhs.getIndexVars()) {
accessVars.insert(index);
}

if (op->op.defined() && accessVars == availableVars &&
util::contains(insertedResults, op->lhs.getTensorVar())) {
stmt = new AssignmentNode(op->lhs, op->rhs, IndexExpr());
return;
}

stmt = op;
}
};
loweredQueries = ReduceToAssign(insertedResults).rewrite(loweredQueries);
loweredQueries = eliminateRedundantReductions(loweredQueries,
&insertedResults);

// Inline definitions of temporaries into their corresponding uses, as long
// as the temporaries are not the results of reductions
Expand Down

0 comments on commit 27e898c

Please sign in to comment.