[Mlir-commits] [mlir] 384049a - [mlir][sparse] completed codegen environment privatization
Aart Bik
llvmlistbot at llvm.org
Thu Dec 22 10:34:51 PST 2022
Author: Aart Bik
Date: 2022-12-22T10:34:43-08:00
New Revision: 384049a755fa9f3a58655599f86e4fdf4633c390
URL: https://github.com/llvm/llvm-project/commit/384049a755fa9f3a58655599f86e4fdf4633c390
DIFF: https://github.com/llvm/llvm-project/commit/384049a755fa9f3a58655599f86e4fdf4633c390.diff
LOG: [mlir][sparse] completed codegen environment privatization
All members are now private; access goes through delegate
or convenience methods only (except for the loop emitter, which
is still under refactoring).
Reviewed By: Peiming
Differential Revision: https://reviews.llvm.org/D140519
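A minimal, hypothetical sketch (C++, assuming the CodegenEnv.h from this patch is available; the snippet itself is not part of the commit) of how call sites change once the members are private: reads and writes that used to touch public fields now go through the delegate and convenience methods added here.

  #include "CodegenEnv.h"  // mlir/lib/Dialect/SparseTensor/Transforms
  using namespace mlir;
  using namespace mlir::sparse_tensor;

  // Old style (public members):        New style (accessors only):
  //   env.merger.getNumLoops()    ->     env.merger().getNumLoops()
  //   env.topSort.push_back(i)    ->     env.topSortPushBack(i)
  //   env.insChain = chain        ->     env.updateInsertionChain(chain)
  static void exampleUse(CodegenEnv &env, unsigned i, Value chain) {
    unsigned n = env.merger().getNumLoops(); // delegate into the lattice merger
    env.topSortPushBack(i);                  // topological-sort bookkeeping
    env.updateInsertionChain(chain);         // asserts a sparse output is set
    (void)n;
  }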
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
index 52d6d5822e577..0be15d6995366 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -12,27 +12,84 @@ using namespace mlir;
using namespace mlir::sparse_tensor;
//===----------------------------------------------------------------------===//
-// Code generation environment constructor and setup
+// Code generation environment constructor and general methods
//===----------------------------------------------------------------------===//
CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
unsigned numTensors, unsigned numLoops,
unsigned numFilterLoops)
- : linalgOp(linop), options(opts), topSort(),
- merger(numTensors, numLoops, numFilterLoops), loopEmitter(nullptr),
- sparseOut(nullptr), redVal(nullptr), redExp(-1u), redCustom(-1u) {}
+ : linalgOp(linop), sparseOptions(opts),
+ latticeMerger(numTensors, numLoops, numFilterLoops), loopEmitter(nullptr),
+ topSort(), sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(),
+ expFilled(), expAdded(), expCount(), redVal(), redExp(-1u),
+ redCustom(-1u) {}
-void CodegenEnv::startEmit(SparseTensorLoopEmitter *le) {
- assert(!loopEmitter && "must only start emitting once");
+void CodegenEnv::startEmit(OpOperand *so, unsigned lv,
+ SparseTensorLoopEmitter *le) {
+ assert(sparseOut == nullptr && loopEmitter == nullptr &&
+ insChain == nullptr && "must only start emitting once");
+ sparseOut = so;
+ outerParNest = lv;
loopEmitter = le;
if (sparseOut) {
insChain = sparseOut->get();
- merger.setHasSparseOut(true);
+ latticeMerger.setHasSparseOut(true);
}
}
//===----------------------------------------------------------------------===//
-// Code generation environment methods
+// Code generation environment topological sort methods
+//===----------------------------------------------------------------------===//
+
+ArrayRef<unsigned> CodegenEnv::getTopSortSlice(size_t n, size_t m) const {
+ return ArrayRef<unsigned>(topSort).slice(n, m);
+}
+
+ArrayRef<unsigned> CodegenEnv::getLoopCurStack() const {
+ return getTopSortSlice(0, loopEmitter->getCurrentDepth());
+}
+
+Value CodegenEnv::getLoopIdxValue(size_t loopIdx) const {
+ for (unsigned lv = 0, lve = topSort.size(); lv < lve; lv++)
+ if (topSort[lv] == loopIdx)
+ return loopEmitter->getLoopIV(lv);
+ llvm_unreachable("invalid loop index");
+}
+
+//===----------------------------------------------------------------------===//
+// Code generation environment sparse tensor output and expansion methods
+//===----------------------------------------------------------------------===//
+
+void CodegenEnv::updateInsertionChain(Value chain) {
+ assert(sparseOut != nullptr && insChain != nullptr);
+ insChain = chain;
+}
+
+bool CodegenEnv::atExpandLevel(OpOperand *o, unsigned rank, unsigned lv) const {
+ return sparseOut == o && outerParNest == rank - 1 && outerParNest == lv;
+}
+
+void CodegenEnv::startExpand(Value values, Value filled, Value added,
+ Value count) {
+ assert(sparseOut != nullptr && expValues == nullptr);
+ expValues = values;
+ expFilled = filled;
+ expAdded = added;
+ expCount = count;
+}
+
+void CodegenEnv::updateExpandCount(Value count) {
+ assert(sparseOut != nullptr && expValues != nullptr);
+ expCount = count;
+}
+
+void CodegenEnv::endExpand() {
+ assert(sparseOut != nullptr && expValues != nullptr);
+ expValues = expFilled = expAdded = expCount = Value();
+}
+
+//===----------------------------------------------------------------------===//
+// Code generation environment reduction methods
//===----------------------------------------------------------------------===//
void CodegenEnv::startReduc(unsigned exp, Value val) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
index 1b0362203f2fd..cb5ba99a93065 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -28,59 +28,87 @@ namespace sparse_tensor {
/// sparsification. This environment simplifies passing around such
/// data during sparsification (rather than passing around all the
/// individual components where needed). Furthermore, it provides
-/// a number of delegate and convience methods that keep some of the
-/// implementation details transparent to sparsification.
+/// convenience methods that keep implementation details transparent
+/// to sparsification while asserting on internal consistency.
class CodegenEnv {
public:
+ /// Constructs a code generation environment which can be
+ /// passed around during sparsification for bookkeeping
+ /// together with some consistency asserts.
CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
unsigned numTensors, unsigned numLoops, unsigned numFilterLoops);
- // Start emitting.
- void startEmit(SparseTensorLoopEmitter *le);
+ //
+ // General methods.
+ //
+
+ linalg::GenericOp op() const { return linalgOp; }
+ const SparsificationOptions &options() const { return sparseOptions; }
+ Merger &merger() { return latticeMerger; }
+ SparseTensorLoopEmitter *emitter() { return loopEmitter; }
+
+ void startEmit(OpOperand *so, unsigned lv, SparseTensorLoopEmitter *le);
- // Delegate methods to merger.
- TensorExp &exp(unsigned e) { return merger.exp(e); }
- LatPoint &lat(unsigned l) { return merger.lat(l); }
- SmallVector<unsigned> &set(unsigned s) { return merger.set(s); }
- DimLevelType dimLevelType(unsigned t, unsigned i) const {
- return merger.getDimLevelType(t, i);
+ //
+ // Merger delegates.
+ //
+
+ TensorExp &exp(unsigned e) { return latticeMerger.exp(e); }
+ LatPoint &lat(unsigned l) { return latticeMerger.lat(l); }
+ SmallVector<unsigned> &set(unsigned s) { return latticeMerger.set(s); }
+ DimLevelType dlt(unsigned t, unsigned i) const {
+ return latticeMerger.getDimLevelType(t, i);
}
- DimLevelType dimLevelType(unsigned b) const {
- return merger.getDimLevelType(b);
+ DimLevelType dlt(unsigned b) const {
+ return latticeMerger.getDimLevelType(b);
}
- bool isFilterLoop(unsigned i) const { return merger.isFilterLoop(i); }
- // Delegate methods to loop emitter.
- Value getLoopIV(unsigned i) const { return loopEmitter->getLoopIV(i); }
- const std::vector<Value> &getValBuffer() const {
- return loopEmitter->getValBuffer();
- }
+ //
+ // Topological delegate and sort methods.
+ //
- // Convenience method to slice topsort.
- ArrayRef<unsigned> getTopSortSlice(size_t n, size_t m) const {
- return ArrayRef<unsigned>(topSort).slice(n, m);
- }
+ // TODO: get rid of this one!
+ std::vector<unsigned> &topSortRef() { return topSort; }
- // Convenience method to get current loop stack.
- ArrayRef<unsigned> getLoopCurStack() const {
- return getTopSortSlice(0, loopEmitter->getCurrentDepth());
+ size_t topSortSize() const { return topSort.size(); }
+ unsigned topSortAt(unsigned i) const { return topSort.at(i); }
+ void topSortPushBack(unsigned i) { topSort.push_back(i); }
+ void topSortClear(unsigned capacity = 0) {
+ topSort.clear();
+ topSort.reserve(capacity);
}
- // Convenience method to get the IV of the given loop index.
- Value getLoopIdxValue(size_t loopIdx) const {
- for (unsigned lv = 0, lve = topSort.size(); lv < lve; lv++)
- if (topSort[lv] == loopIdx)
- return getLoopIV(lv);
- llvm_unreachable("invalid loop index");
- }
+ ArrayRef<unsigned> getTopSortSlice(size_t n, size_t m) const;
+ ArrayRef<unsigned> getLoopCurStack() const;
+ Value getLoopIdxValue(size_t loopIdx) const;
+
+ //
+ // Sparse tensor output and expansion methods.
+ //
+
+ bool hasSparseOutput() const { return sparseOut != nullptr; }
+ bool isSparseOutput(OpOperand *o) const { return sparseOut == o; }
+
+ Value getInsertionChain() const { return insChain; }
+ void updateInsertionChain(Value chain);
+
+ bool atExpandLevel(OpOperand *o, unsigned rank, unsigned lv) const;
+ void startExpand(Value values, Value filled, Value added, Value count);
+ bool isExpand() const { return expValues != nullptr; }
+ void updateExpandCount(Value count);
+ Value getExpandValues() const { return expValues; }
+ Value getExpandFilled() const { return expFilled; }
+ Value getExpandAdded() const { return expAdded; }
+ Value getExpandCount() const { return expCount; }
+ void endExpand();
//
- // Reductions.
+ // Reduction methods.
//
void startReduc(unsigned exp, Value val);
- void updateReduc(Value val);
bool isReduc() const { return redExp != -1u; }
+ void updateReduc(Value val);
Value getReduc() const { return redVal; }
Value endReduc();
@@ -89,39 +117,34 @@ class CodegenEnv {
Value getCustomRedId();
void endCustomReduc();
-public:
- //
- // TODO make this section private too, using similar refactoring as for reduc
- //
-
+private:
// Linalg operation.
linalg::GenericOp linalgOp;
// Sparsification options.
- SparsificationOptions options;
-
- // Topological sort.
- std::vector<unsigned> topSort;
+ SparsificationOptions sparseOptions;
// Merger helper class.
- Merger merger;
+ Merger latticeMerger;
// Loop emitter helper class (keep reference in scope!).
// TODO: move emitter constructor up in time?
SparseTensorLoopEmitter *loopEmitter;
+ // Topological sort.
+ std::vector<unsigned> topSort;
+
// Sparse tensor as output. Implemented either through direct injective
- // insertion in lexicographic index order or through access pattern expansion
- // in the innermost loop nest (`expValues` through `expCount`).
+ // insertion in lexicographic index order or through access pattern
+ // expansion in the innermost loop nest (`expValues` through `expCount`).
OpOperand *sparseOut;
unsigned outerParNest;
- Value insChain; // bookkeeping for insertion chain
+ Value insChain;
Value expValues;
Value expFilled;
Value expAdded;
Value expCount;
-private:
// Bookkeeping for reductions (up-to-date value of the reduction, and indices
// into the merger's expression tree. When the indices of a tensor reduction
// expression are exhausted, all inner loops can use a scalarized reduction.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index a6c031486c7e3..eb71c4cb8ff0e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -127,8 +127,8 @@ static bool isInvariantAffine(CodegenEnv &env, AffineExpr a, unsigned ldx,
/// Helper method to construct a permuted dimension ordering
/// that adheres to the given topological sort.
static AffineMap permute(CodegenEnv &env, AffineMap m) {
- assert(m.getNumDims() + env.merger.getNumFilterLoops() ==
- env.topSort.size() &&
+ assert(m.getNumDims() + env.merger().getNumFilterLoops() ==
+ env.topSortSize() &&
"size mismatch");
// Construct the inverse of `m`; to avoid the asymptotic complexity
// of calling `m.getPermutedPosition` repeatedly.
@@ -138,14 +138,14 @@ static AffineMap permute(CodegenEnv &env, AffineMap m) {
unsigned loopDepth = 1;
// Construct the permutation.
- while (worklist.any() && loopDepth <= env.topSort.size()) {
+ while (worklist.any() && loopDepth <= env.topSortSize()) {
unsigned preSize = perm.size();
for (auto dim : worklist.set_bits()) {
bool atLevel = false;
if (m.getResult(dim).isa<AffineConstantExpr>() ||
(isInvariantAffine(m.getResult(dim),
env.getTopSortSlice(0, loopDepth),
- env.topSort[loopDepth - 1], atLevel) &&
+ env.topSortAt(loopDepth - 1), atLevel) &&
atLevel)) {
// If the matching affine is a constant expression or just became
// invariant, we can visit the dimension now without breaking the
@@ -163,7 +163,7 @@ static AffineMap permute(CodegenEnv &env, AffineMap m) {
}
assert(perm.size() == numResults);
- return AffineMap::getPermutationMap(perm, env.linalgOp.getContext());
+ return AffineMap::getPermutationMap(perm, env.op().getContext());
}
/// Helper method to inspect affine expressions. Rejects cases where the
@@ -255,22 +255,22 @@ static unsigned getNumCompoundAffineOnSparseDims(linalg::GenericOp op) {
/// no annotations are found or inadmissible constructs occur.
static bool findSparseAnnotations(CodegenEnv &env) {
bool annotated = false;
- unsigned filterLdx = env.merger.getFilterLoopStartingIdx();
- for (OpOperand &t : env.linalgOp->getOpOperands()) {
- auto map = env.linalgOp.getMatchingIndexingMap(&t);
+ unsigned filterLdx = env.merger().getFilterLoopStartingIdx();
+ for (OpOperand &t : env.op()->getOpOperands()) {
+ auto map = env.op().getMatchingIndexingMap(&t);
auto enc = getSparseTensorEncoding(t.get().getType());
if (enc)
annotated = true;
- assert(map.getNumResults() == env.linalgOp.getRank(&t));
+ assert(map.getNumResults() == env.op().getRank(&t));
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
unsigned tensor = t.getOperandNumber();
AffineExpr a = map.getResult(toOrigDim(enc, d));
- if (!findAffine(env.merger, tensor, d, a, getDimLevelType(enc, d),
+ if (!findAffine(env.merger(), tensor, d, a, getDimLevelType(enc, d),
filterLdx))
return false; // inadmissible affine expression
}
}
- assert(filterLdx == env.merger.getNumLoops());
+ assert(filterLdx == env.merger().getNumLoops());
return annotated;
}
@@ -287,7 +287,7 @@ static bool topSortOptimal(CodegenEnv &env, unsigned n,
std::vector<unsigned> filterIt; // filter loop with 0 degree
for (unsigned i = 0; i < n; i++) {
if (inDegree[i] == 0) {
- if (env.isFilterLoop(i))
+ if (env.merger().isFilterLoop(i))
filterIt.push_back(i);
else if (linalg::isReductionIterator(iteratorTypes[i]))
redIt.push_back(i);
@@ -318,12 +318,12 @@ static bool topSortOptimal(CodegenEnv &env, unsigned n,
// O(X) computation => O(NK+NMX) time complexity
auto &it = !filterIt.empty() ? filterIt : (!parIt.empty() ? parIt : redIt);
auto src = it.back();
- env.topSort.push_back(src);
+ env.topSortPushBack(src);
it.pop_back();
// Update in-degree, and push 0-degree node into worklist.
for (unsigned dst = 0; dst < n; dst++) {
if (adjM[src][dst] && --inDegree[dst] == 0) {
- if (env.isFilterLoop(dst))
+ if (env.merger().isFilterLoop(dst))
filterIt.push_back(dst);
else if (linalg::isReductionIterator(iteratorTypes[dst]))
redIt.push_back(dst);
@@ -332,7 +332,7 @@ static bool topSortOptimal(CodegenEnv &env, unsigned n,
}
}
}
- return env.topSort.size() == n;
+ return env.topSortSize() == n;
}
/// Helper method to add all constraints from the indices in one affine
@@ -428,17 +428,16 @@ static bool computeIterationGraph(CodegenEnv &env, unsigned mask,
OpOperand *skip = nullptr) {
// Set up an n x n from/to adjacency matrix of the iteration graph
// for the implicit loop indices i_0 .. i_n-1.
- unsigned n = env.merger.getNumLoops();
+ unsigned n = env.merger().getNumLoops();
std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
std::vector<unsigned> inDegree(n, 0); // in-degree of each node.
- auto iteratorTypes = env.linalgOp.getIteratorTypesArray();
+ auto iteratorTypes = env.op().getIteratorTypesArray();
// Iterate over the indexing maps of every tensor in the tensor expression.
- for (OpOperand &t : env.linalgOp->getOpOperands()) {
+ for (OpOperand &t : env.op()->getOpOperands()) {
// Get map and encoding.
- auto map = env.linalgOp.getMatchingIndexingMap(&t);
+ auto map = env.op().getMatchingIndexingMap(&t);
auto enc = getSparseTensorEncoding(t.get().getType());
- assert(map.getNumDims() + getNumCompoundAffineOnSparseDims(env.linalgOp) ==
- n);
+ assert(map.getNumDims() + getNumCompoundAffineOnSparseDims(env.op()) == n);
// Skip dense tensor constraints when not requested.
if (!(mask & SortMask::kIncludeDense) && !enc)
continue;
@@ -448,11 +447,12 @@ static bool computeIterationGraph(CodegenEnv &env, unsigned mask,
// on the loop indices if no explicit dimension ordering is given.
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
AffineExpr ta = map.getResult(toOrigDim(enc, d));
- Optional<unsigned> tldx = env.merger.getLoopIdx(t.getOperandNumber(), d);
+ Optional<unsigned> tldx =
+ env.merger().getLoopIdx(t.getOperandNumber(), d);
// Filter loops should be constructed after all the dependent loops,
// i.e., d0 + d1 < filter_loop(d0 + d1)
- if (tldx && env.isFilterLoop(*tldx)) {
+ if (tldx && env.merger().isFilterLoop(*tldx)) {
assert(!ta.isa<AffineDimExpr>() &&
!isDenseDLT(getDimLevelType(enc, d)));
addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt,
@@ -472,7 +472,7 @@ static bool computeIterationGraph(CodegenEnv &env, unsigned mask,
if (d > 0) {
AffineExpr fa = map.getResult(toOrigDim(enc, d - 1));
Optional<unsigned> fldx =
- env.merger.getLoopIdx(t.getOperandNumber(), d - 1);
+ env.merger().getLoopIdx(t.getOperandNumber(), d - 1);
// Applying order constraints on every pair of dimExpr between two
// compound affine expressions can sometimes be too strict:
@@ -480,7 +480,7 @@ static bool computeIterationGraph(CodegenEnv &env, unsigned mask,
// It is totally fine to have loop sequence d0->d2->d1->d3 instead of
// requiring d0 < d2, d1 < d2, d0 < d3, d1 < d3.
if (!(mask & SortMask::kIncludeDense))
- tryLoosenAffineDenseConstraints(env.linalgOp, fldx, fa, tldx, ta);
+ tryLoosenAffineDenseConstraints(env.op(), fldx, fa, tldx, ta);
// (d0 + d1) < (d2 + d3), or
// filter_loop_d-1 < (d2 + d3), or
@@ -495,23 +495,22 @@ static bool computeIterationGraph(CodegenEnv &env, unsigned mask,
if (mask & SortMask::kIncludeUndef) {
unsigned tensor = t.getOperandNumber();
for (unsigned i = 0; i < n; i++)
- if (isCompressedDLT(env.dimLevelType(tensor, i)) ||
- isSingletonDLT(env.dimLevelType(tensor, i))) {
+ if (isCompressedDLT(env.dlt(tensor, i)) ||
+ isSingletonDLT(env.dlt(tensor, i))) {
for (unsigned j = 0; j < n; j++)
- if (isUndefDLT(env.dimLevelType(tensor, j))) {
+ if (isUndefDLT(env.dlt(tensor, j))) {
adjM[i][j] = true;
inDegree[j]++;
}
} else {
- assert(isDenseDLT(env.dimLevelType(tensor, i)) ||
- isUndefDLT(env.dimLevelType(tensor, i)));
+ assert(isDenseDLT(env.dlt(tensor, i)) ||
+ isUndefDLT(env.dlt(tensor, i)));
}
}
}
// Topologically sort the iteration graph to determine loop order.
// Report failure for a cyclic iteration graph.
- env.topSort.clear();
- env.topSort.reserve(n);
+ env.topSortClear(n);
return topSortOptimal(env, n, iteratorTypes, inDegree, adjM);
}
@@ -526,20 +525,22 @@ static bool isMaterializing(Value val) {
/// whether the out tensor in the tensor expression codegen is admissible.
/// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective
/// nesting depth when a "truly dynamic" sparse tensor output occurs.
-static bool isAdmissibleTensorExp(CodegenEnv &env, unsigned exp) {
+static bool isAdmissibleTensorExp(CodegenEnv &env, unsigned exp,
+ OpOperand **sparseOut,
+ unsigned *outerParNest) {
// We reject any expression that makes a reduction from `-outTensor`, as those
// expressions create a dependency between the current iteration (i) and the
// previous iteration (i-1). It would require iterating over the whole
// coordinate space, which prevents exploiting sparsity for faster code.
- for (utils::IteratorType it : env.linalgOp.getIteratorTypesArray()) {
+ for (utils::IteratorType it : env.op().getIteratorTypesArray()) {
if (it == utils::IteratorType::reduction) {
- if (env.merger.hasNegateOnOut(exp))
+ if (env.merger().hasNegateOnOut(exp))
return false;
break;
}
}
- OpOperand *lhs = env.linalgOp.getDpsInitOperand(0);
+ OpOperand *lhs = env.op().getDpsInitOperand(0);
unsigned tensor = lhs->getOperandNumber();
auto enc = getSparseTensorEncoding(lhs->get().getType());
// A non-annotated output tensor is assumed dense, and becomes a random
@@ -550,40 +551,39 @@ static bool isAdmissibleTensorExp(CodegenEnv &env, unsigned exp) {
// access 1-dim memref. Also admissible since insertions cannot occur.
bool allDense = true;
unsigned numLoops =
- env.merger.getNumLoops(); // numNativeLoops + numFilterLoops
- for (unsigned i = 0; i < env.merger.getNumLoops(); i++)
- if (isCompressedDLT(env.dimLevelType(tensor, i)) ||
- isSingletonDLT(env.dimLevelType(tensor, i))) {
+ env.merger().getNumLoops(); // numNativeLoops + numFilterLoops
+ for (unsigned i = 0; i < env.merger().getNumLoops(); i++)
+ if (isCompressedDLT(env.dlt(tensor, i)) ||
+ isSingletonDLT(env.dlt(tensor, i))) {
allDense = false;
break;
} else {
- assert(isDenseDLT(env.dimLevelType(tensor, i)) ||
- isUndefDLT(env.dimLevelType(tensor, i)));
+ assert(isDenseDLT(env.dlt(tensor, i)) || isUndefDLT(env.dlt(tensor, i)));
}
if (allDense)
return true;
// TODO: support compound affine expression on sparse output.
- if (getNumCompoundAffineOnSparseDims(env.linalgOp.getMatchingIndexingMap(lhs),
+ if (getNumCompoundAffineOnSparseDims(env.op().getMatchingIndexingMap(lhs),
lhs->get()) != 0)
return false;
// A tensor expression with a sparse output tensor that changes its values
// but not its nonzero structure, an operation called "simply dynamic" in
// [Bik96,Ch9], is also admissible without special env.
- if (env.merger.isSingleCondition(tensor, exp))
+ if (env.merger().isSingleCondition(tensor, exp))
return true;
// Accept "truly dynamic" if the output tensor materializes uninitialized
// into the computation and insertions occur in lexicographic index order.
if (isMaterializing(lhs->get())) {
- auto iteratorTypes = env.linalgOp.getIteratorTypesArray();
+ auto iteratorTypes = env.op().getIteratorTypesArray();
unsigned nest = 0;
for (unsigned i = 0; i < numLoops; i++) {
- if (!env.isFilterLoop(env.topSort[i])) {
+ if (!env.merger().isFilterLoop(env.topSortAt(i))) {
// We only count non-filter loops as filter loops should be considered
// as a special type of parallel loops.
- if (linalg::isReductionIterator(iteratorTypes[env.topSort[i]]))
+ if (linalg::isReductionIterator(iteratorTypes[env.topSortAt(i)]))
break; // terminate at first reduction
nest++;
}
@@ -591,9 +591,9 @@ static bool isAdmissibleTensorExp(CodegenEnv &env, unsigned exp) {
// Determine admissible dynamic insertion situations:
// (1) fully injective, since there are no reductions,
// (2) admissible 1-d expansion in innermost dimension.
- if (nest >= env.linalgOp.getRank(lhs) - 1) {
- env.sparseOut = lhs;
- env.outerParNest = nest;
+ if (nest >= env.op().getRank(lhs) - 1) {
+ *sparseOut = lhs;
+ *outerParNest = nest;
return true;
}
}
@@ -613,10 +613,10 @@ static Optional<Operation *> genLoopBoundary(
SmallVector<Value> reduc;
if (env.isReduc())
reduc.push_back(env.getReduc());
- if (env.expValues)
- reduc.push_back(env.expCount);
- if (env.insChain)
- reduc.push_back(env.insChain);
+ if (env.isExpand())
+ reduc.push_back(env.getExpandCount());
+ if (env.getInsertionChain())
+ reduc.push_back(env.getInsertionChain());
auto r = callback(reduc);
@@ -624,21 +624,21 @@ static Optional<Operation *> genLoopBoundary(
unsigned i = 0;
if (env.isReduc())
env.updateReduc(reduc[i++]);
- if (env.expValues)
- env.expCount = reduc[i++];
- if (env.insChain)
- env.insChain = reduc[i];
+ if (env.isExpand())
+ env.updateExpandCount(reduc[i++]);
+ if (env.getInsertionChain())
+ env.updateInsertionChain(reduc[i]);
return r;
}
/// Local bufferization of all dense and sparse data structures.
static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
- linalg::GenericOp op = env.linalgOp;
+ linalg::GenericOp op = env.op();
Location loc = op.getLoc();
assert(op.getNumOperands() == op.getNumDpsInputs() + 1);
- env.loopEmitter->initializeLoopEmit(
+ env.emitter()->initializeLoopEmit(
builder, loc,
/// Generates buffer for the output tensor.
/// Note that all sparse kernels assume that when all elements are written
@@ -676,7 +676,7 @@ static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
/// Generates index for load/store on sparse tensor.
static Value genIndex(CodegenEnv &env, OpOperand *t) {
- auto map = env.linalgOp.getMatchingIndexingMap(t);
+ auto map = env.op().getMatchingIndexingMap(t);
auto enc = getSparseTensorEncoding(t->get().getType());
AffineExpr a = map.getResult(toOrigDim(enc, map.getNumResults() - 1));
assert(a.getKind() == AffineExprKind::DimId);
@@ -687,69 +687,73 @@ static Value genIndex(CodegenEnv &env, OpOperand *t) {
/// Generates subscript for load/store on a dense or sparse tensor.
static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
SmallVectorImpl<Value> &args) {
- linalg::GenericOp op = env.linalgOp;
+ linalg::GenericOp op = env.op();
unsigned tensor = t->getOperandNumber();
auto map = op.getMatchingIndexingMap(t);
auto enc = getSparseTensorEncoding(t->get().getType());
unsigned rank = map.getNumResults();
if (enc) {
- Value pidx = env.loopEmitter->getPidxs()[tensor].back();
+ Value pidx = env.emitter()->getPidxs()[tensor].back();
assert(pidx);
args.push_back(pidx); // position index
} else {
for (unsigned d = 0; d < rank; d++) {
AffineExpr a = map.getResult(d);
- args.push_back(env.loopEmitter->genAffine(builder, a, op.getLoc()));
+ args.push_back(env.emitter()->genAffine(builder, a, op.getLoc()));
}
}
- return env.getValBuffer()[tensor];
+ return env.emitter()->getValBuffer()[tensor];
}
/// Generates insertion code to implement dynamic tensor load.
static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder,
OpOperand *t) {
- linalg::GenericOp op = env.linalgOp;
+ linalg::GenericOp op = env.op();
Location loc = op.getLoc();
// Direct lexicographic index order, tensor loads as zero.
- if (!env.expValues) {
+ if (!env.isExpand()) {
Type tp = getElementTypeOrSelf(t->get().getType());
return constantZero(builder, loc, tp);
}
// Load from expanded access pattern.
Value index = genIndex(env, t);
- return builder.create<memref::LoadOp>(loc, env.expValues, index);
+ return builder.create<memref::LoadOp>(loc, env.getExpandValues(), index);
}
/// Generates insertion code to implement dynamic tensor load for reduction.
static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
OpOperand *t) {
- linalg::GenericOp op = env.linalgOp;
+ linalg::GenericOp op = env.op();
Location loc = op.getLoc();
Value identity = env.getCustomRedId();
// Direct lexicographic index order, tensor loads as identity.
- if (!env.expValues)
+ if (!env.isExpand())
return identity;
// Load from expanded access pattern if filled, identity otherwise.
+ Value values = env.getExpandValues();
+ Value filled = env.getExpandFilled();
Value index = genIndex(env, t);
- Value isFilled = builder.create<memref::LoadOp>(loc, env.expFilled, index);
- Value valAtIndex = builder.create<memref::LoadOp>(loc, env.expValues, index);
+ Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
+ Value valAtIndex = builder.create<memref::LoadOp>(loc, values, index);
return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
}
/// Generates insertion code to implement dynamic tensor store.
static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
Value rhs) {
- linalg::GenericOp op = env.linalgOp;
+ linalg::GenericOp op = env.op();
Location loc = op.getLoc();
// Direct insertion in lexicographic index order.
- if (!env.expValues) {
+ if (!env.isExpand()) {
unsigned rank = op.getRank(t);
SmallVector<Value> indices;
for (unsigned i = 0; i < rank; i++) {
- assert(env.getLoopIV(i));
- indices.push_back(env.getLoopIV(i));
+ assert(env.emitter()->getLoopIV(i));
+ indices.push_back(env.emitter()->getLoopIV(i));
}
- env.insChain = builder.create<InsertOp>(loc, rhs, env.insChain, indices);
+ Value chain = env.getInsertionChain();
+ env.updateInsertionChain(
+ builder.create<InsertOp>(loc, rhs, chain, indices));
return;
}
// Generates insertion code along expanded access pattern.
@@ -758,29 +762,33 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
// expAdded[inserts++] = i
// endif
// values[i] = rhs
+ Value values = env.getExpandValues();
+ Value filled = env.getExpandFilled();
+ Value added = env.getExpandAdded();
+ Value count = env.getExpandCount();
Value index = genIndex(env, t);
Value fval = constantI1(builder, loc, false);
Value tval = constantI1(builder, loc, true);
// If statement.
- Value filled = builder.create<memref::LoadOp>(loc, env.expFilled, index);
+ Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
- filled, fval);
+ isFilled, fval);
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond,
/*else=*/true);
// True branch.
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- builder.create<memref::StoreOp>(loc, tval, env.expFilled, index);
- builder.create<memref::StoreOp>(loc, index, env.expAdded, env.expCount);
+ builder.create<memref::StoreOp>(loc, tval, filled, index);
+ builder.create<memref::StoreOp>(loc, index, added, count);
Value one = constantIndex(builder, loc, 1);
- Value add = builder.create<arith::AddIOp>(loc, env.expCount, one);
+ Value add = builder.create<arith::AddIOp>(loc, count, one);
builder.create<scf::YieldOp>(loc, add);
// False branch.
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- builder.create<scf::YieldOp>(loc, env.expCount);
+ builder.create<scf::YieldOp>(loc, count);
builder.setInsertionPointAfter(ifOp);
// Value assignment.
- env.expCount = ifOp.getResult(0);
- builder.create<memref::StoreOp>(loc, rhs, env.expValues, index);
+ env.updateExpandCount(ifOp.getResult(0));
+ builder.create<memref::StoreOp>(loc, rhs, values, index);
}
/// Generates a load on a dense or sparse tensor.
@@ -791,23 +799,23 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, unsigned exp) {
return val;
// Load during insertion.
- linalg::GenericOp op = env.linalgOp;
- OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
- if (&t == env.sparseOut) {
+ linalg::GenericOp op = env.op();
+ OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
+ if (env.isSparseOutput(t)) {
if (env.isCustomReduc())
- return genInsertionLoadReduce(env, builder, &t);
- return genInsertionLoad(env, builder, &t);
+ return genInsertionLoadReduce(env, builder, t);
+ return genInsertionLoad(env, builder, t);
}
// Actual load.
SmallVector<Value> args;
- Value ptr = genSubscript(env, builder, &t, args);
+ Value ptr = genSubscript(env, builder, t, args);
return builder.create<memref::LoadOp>(op.getLoc(), ptr, args);
}
/// Generates a store on a dense or sparse tensor.
static void genTensorStore(CodegenEnv &env, OpBuilder &builder, unsigned exp,
Value rhs) {
- linalg::GenericOp op = env.linalgOp;
+ linalg::GenericOp op = env.op();
Location loc = op.getLoc();
// Test if this is a scalarized reduction.
if (env.isReduc()) {
@@ -816,17 +824,16 @@ static void genTensorStore(CodegenEnv &env, OpBuilder &builder, unsigned exp,
}
// Store during insertion.
OpOperand *t = op.getDpsInitOperand(0);
- if (t == env.sparseOut) {
+ if (env.isSparseOutput(t)) {
if (!rhs) {
// Only unary and binary are allowed to return uninitialized rhs
// to indicate missing output.
assert(env.exp(exp).kind == kUnary || env.exp(exp).kind == kBinary);
} else if (env.exp(exp).kind == kSelect) {
// Select operation insertion.
- Value insChain = env.insChain;
- assert(insChain);
- scf::IfOp ifOp = builder.create<scf::IfOp>(loc, insChain.getType(), rhs,
- /*else=*/true);
+ Value chain = env.getInsertionChain();
+ scf::IfOp ifOp =
+ builder.create<scf::IfOp>(loc, chain.getType(), rhs, /*else=*/true);
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
// Existing value was preserved to be used here.
assert(env.exp(exp).val);
@@ -834,12 +841,13 @@ static void genTensorStore(CodegenEnv &env, OpBuilder &builder, unsigned exp,
genInsertionStore(env, builder, t, v0);
env.exp(exp).val = Value();
// Yield modified insertion chain along true branch.
- builder.create<scf::YieldOp>(op.getLoc(), env.insChain);
+ Value mchain = env.getInsertionChain();
+ builder.create<scf::YieldOp>(op.getLoc(), mchain);
// Yield original insertion chain along false branch.
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- builder.create<scf::YieldOp>(loc, insChain);
+ builder.create<scf::YieldOp>(loc, chain);
// Done with if statement.
- env.insChain = ifOp->getResult(0);
+ env.updateInsertionChain(ifOp->getResult(0));
builder.setInsertionPointAfter(ifOp);
} else {
genInsertionStore(env, builder, t, rhs);
@@ -884,7 +892,7 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
/// Recursively generates tensor expression.
static Value genExp(CodegenEnv &env, RewriterBase &rewriter, unsigned exp,
unsigned ldx) {
- linalg::GenericOp op = env.linalgOp;
+ linalg::GenericOp op = env.op();
Location loc = op.getLoc();
if (exp == -1u)
@@ -901,7 +909,7 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, unsigned exp,
Value v0 = genExp(env, rewriter, env.exp(exp).children.e0, ldx);
Value v1 = genExp(env, rewriter, env.exp(exp).children.e1, ldx);
- Value ee = env.merger.buildExp(rewriter, loc, exp, v0, v1);
+ Value ee = env.merger().buildExp(rewriter, loc, exp, v0, v1);
if (ee && (env.exp(exp).kind == Kind::kUnary ||
env.exp(exp).kind == Kind::kBinary ||
env.exp(exp).kind == Kind::kBinaryBranch ||
@@ -928,14 +936,15 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, unsigned exp,
if (env.exp(exp).kind == Kind::kTensor) {
// Inspect tensor indices.
bool atLevel = ldx == -1u;
- linalg::GenericOp op = env.linalgOp;
+ linalg::GenericOp op = env.op();
OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
auto map = op.getMatchingIndexingMap(&t);
auto enc = getSparseTensorEncoding(t.get().getType());
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
AffineExpr a = map.getResult(toOrigDim(enc, d));
- Optional<unsigned> sldx = env.merger.getLoopIdx(t.getOperandNumber(), d);
- if (sldx && env.isFilterLoop(*sldx)) {
+ Optional<unsigned> sldx =
+ env.merger().getLoopIdx(t.getOperandNumber(), d);
+ if (sldx && env.merger().isFilterLoop(*sldx)) {
if (!env.getLoopIdxValue(*sldx))
// The filter loop has not been constructed.
return;
@@ -978,11 +987,11 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, unsigned exp,
}
/// Generates an expanded access pattern in innermost dimension.
-static void genExpansion(CodegenEnv &env, OpBuilder &builder, unsigned at,
- bool atStart) {
- linalg::GenericOp op = env.linalgOp;
- OpOperand *lhs = env.sparseOut;
- if (!lhs || env.outerParNest != op.getRank(lhs) - 1 || at != env.outerParNest)
+static void genExpand(CodegenEnv &env, OpBuilder &builder, unsigned at,
+ bool atStart) {
+ linalg::GenericOp op = env.op();
+ OpOperand *lhs = op.getDpsInitOperand(0);
+ if (!env.atExpandLevel(lhs, op.getRank(lhs), at))
return; // not needed at this level
assert(!env.isReduc());
// Generate start or end of an expanded access pattern. Note that because
@@ -999,25 +1008,23 @@ static void genExpansion(CodegenEnv &env, OpBuilder &builder, unsigned at,
Type t2 = MemRefType::get(dynShape, builder.getI1Type());
Type t3 = MemRefType::get(dynShape, builder.getIndexType());
Type t4 = builder.getIndexType();
- auto res =
- builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
- assert(res.getNumResults() == 4);
- assert(!env.expValues);
- env.expValues = res.getResult(0);
- env.expFilled = res.getResult(1);
- env.expAdded = res.getResult(2);
- env.expCount = res.getResult(3);
+ auto r = builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
+ assert(r.getNumResults() == 4);
+ env.startExpand(r.getResult(0), r.getResult(1), r.getResult(2),
+ r.getResult(3));
} else {
- assert(env.expValues);
SmallVector<Value> indices;
- for (unsigned i = 0; i < at; i++) {
- assert(env.getLoopIV(i));
- indices.push_back(env.getLoopIV(i));
- }
- env.insChain = builder.create<CompressOp>(loc, env.expValues, env.expFilled,
- env.expAdded, env.expCount,
- env.insChain, indices);
- env.expValues = env.expFilled = env.expAdded = env.expCount = Value();
+ for (unsigned i = 0; i < at; i++)
+ indices.push_back(env.emitter()->getLoopIV(i));
+ Value values = env.getExpandValues();
+ Value filled = env.getExpandFilled();
+ Value added = env.getExpandAdded();
+ Value count = env.getExpandCount();
+ Value chain = env.getInsertionChain();
+ Value compress = builder.create<CompressOp>(loc, values, filled, added,
+ count, chain, indices);
+ env.updateInsertionChain(compress);
+ env.endExpand();
}
}
@@ -1026,13 +1033,13 @@ static void genExpansion(CodegenEnv &env, OpBuilder &builder, unsigned at,
/// converted to a parallel operation depends on the requested strategy.
static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
// Reject parallelization of sparse output.
- if (env.sparseOut)
+ if (env.hasSparseOutput())
return false;
// Parallel loops on tensor expansion can cause data races.
- if (env.expCount)
+ if (env.isExpand())
return false;
// Inspect strategy.
- switch (env.options.parallelizationStrategy) {
+ switch (env.options().parallelizationStrategy) {
case SparseParallelizationStrategy::kNone:
return false;
case SparseParallelizationStrategy::kDenseOuterLoop:
@@ -1052,15 +1059,15 @@ static Operation *genFor(CodegenEnv &env, OpBuilder &builder, bool isOuter,
bool isInner, unsigned idx, size_t tid, size_t dim,
ArrayRef<size_t> extraTids,
ArrayRef<size_t> extraDims) {
- linalg::GenericOp op = env.linalgOp;
+ linalg::GenericOp op = env.op();
Location loc = op.getLoc();
auto iteratorTypes = op.getIteratorTypesArray();
- bool isSparse = isCompressedDLT(env.dimLevelType(tid, idx)) ||
- isSingletonDLT(env.dimLevelType(tid, idx));
+ bool isSparse =
+ isCompressedDLT(env.dlt(tid, idx)) || isSingletonDLT(env.dlt(tid, idx));
bool isParallel = isParallelFor(env, isOuter, isSparse);
Operation *loop = *genLoopBoundary(env, [&](MutableArrayRef<Value> reduc) {
- if (env.isFilterLoop(idx)) {
+ if (env.merger().isFilterLoop(idx)) {
// extraTids/extraDims must be empty because filter loops only
// correspond to the one and only sparse tensor level.
assert(isSparse && extraTids.empty() && extraDims.empty());
@@ -1069,10 +1076,10 @@ static Operation *genFor(CodegenEnv &env, OpBuilder &builder, bool isOuter,
// Retrieves the affine expression for the filter loop.
AffineExpr a =
op.getMatchingIndexingMap(t).getResult(toOrigDim(enc, dim));
- return env.loopEmitter->enterFilterLoopOverTensorAtDim(builder, loc, tid,
- dim, a, reduc);
+ return env.emitter()->enterFilterLoopOverTensorAtDim(builder, loc, tid,
+ dim, a, reduc);
}
- return env.loopEmitter->enterLoopOverTensorAtDim(
+ return env.emitter()->enterLoopOverTensorAtDim(
builder, loc, tid, dim, reduc, isParallel, extraTids, extraDims);
});
assert(loop);
@@ -1088,8 +1095,8 @@ static Operation *genWhile(CodegenEnv &env, OpBuilder &builder, unsigned idx,
Operation *loop = *genLoopBoundary(env, [&](MutableArrayRef<Value> reduc) {
// Construct the while-loop with a parameter for each
// index.
- return env.loopEmitter->enterCoIterationOverTensorsAtDims(
- builder, env.linalgOp.getLoc(), condTids, condDims, needsUniv, reduc,
+ return env.emitter()->enterCoIterationOverTensorsAtDims(
+ builder, env.op().getLoc(), condTids, condDims, needsUniv, reduc,
extraTids, extraDims);
});
assert(loop);
@@ -1104,10 +1111,10 @@ static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, unsigned at,
ArrayRef<size_t> extraDims) {
assert(condTids.size() == condDims.size());
assert(extraTids.size() == extraDims.size());
- unsigned idx = env.topSort[at];
+ unsigned idx = env.topSortAt(at);
if (condTids.size() == 1) {
bool isOuter = at == 0;
- bool isInner = at == env.topSort.size() - 1;
+ bool isInner = at == env.topSortSize() - 1;
return genFor(env, builder, isOuter, isInner, idx, condTids.front(),
condDims.front(), extraTids, extraDims);
}
@@ -1119,9 +1126,9 @@ static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, unsigned at,
static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, unsigned idx,
bool needsUniv, BitVector &induction,
scf::WhileOp whileOp) {
- Location loc = env.linalgOp.getLoc();
+ Location loc = env.op().getLoc();
// Finalize each else branch of all if statements.
- if (env.isReduc() || env.expValues || env.insChain) {
+ if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
builder.getInsertionBlock()->getParentOp())) {
unsigned y = 0;
@@ -1130,13 +1137,13 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, unsigned idx,
yields.push_back(env.getReduc());
env.updateReduc(ifOp.getResult(y++));
}
- if (env.expValues) {
- yields.push_back(env.expCount);
- env.expCount = ifOp->getResult(y++);
+ if (env.isExpand()) {
+ yields.push_back(env.getExpandCount());
+ env.updateExpandCount(ifOp->getResult(y++));
}
- if (env.insChain) {
- yields.push_back(env.insChain);
- env.insChain = ifOp->getResult(y++);
+ if (env.getInsertionChain()) {
+ yields.push_back(env.getInsertionChain());
+ env.updateInsertionChain(ifOp->getResult(y++));
}
assert(y == yields.size());
builder.create<scf::YieldOp>(loc, yields);
@@ -1149,35 +1156,34 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, unsigned idx,
/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, unsigned idx,
BitVector &conditions) {
- Location loc = env.linalgOp.getLoc();
+ Location loc = env.op().getLoc();
SmallVector<Type> types;
Value cond;
for (unsigned b = 0, be = conditions.size(); b < be; b++) {
if (!conditions[b])
continue;
- unsigned tensor = env.merger.tensor(b);
- assert(idx == env.merger.index(b));
+ unsigned tensor = env.merger().tensor(b);
+ assert(idx == env.merger().index(b));
Value clause;
- if (isCompressedDLT(env.dimLevelType(b)) ||
- isSingletonDLT(env.dimLevelType(b))) {
- auto dim = *env.merger.getDimNum(tensor, idx);
- Value op1 = env.loopEmitter->getCoord()[tensor][dim];
+ if (isCompressedDLT(env.dlt(b)) || isSingletonDLT(env.dlt(b))) {
+ auto dim = *env.merger().getDimNum(tensor, idx);
+ Value op1 = env.emitter()->getCoord()[tensor][dim];
Value op2 = env.getLoopIdxValue(idx);
clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, op1,
op2);
} else {
- assert(isDenseDLT(env.dimLevelType(b)) ||
- isUndefDLT(env.dimLevelType(b)));
+ assert(isDenseDLT(env.merger().getDimLevelType(b)) ||
+ isUndefDLT(env.merger().getDimLevelType(b)));
clause = constantI1(builder, loc, true);
}
cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
}
if (env.isReduc())
types.push_back(env.getReduc().getType());
- if (env.expValues)
+ if (env.isExpand())
types.push_back(builder.getIndexType());
- if (env.insChain)
- types.push_back(env.insChain.getType());
+ if (env.getInsertionChain())
+ types.push_back(env.getInsertionChain().getType());
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
return ifOp;
@@ -1192,16 +1198,16 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
operands.push_back(env.getReduc());
env.updateReduc(redInput);
}
- if (env.expValues) {
- operands.push_back(env.expCount);
- env.expCount = cntInput;
+ if (env.isExpand()) {
+ operands.push_back(env.getExpandCount());
+ env.updateExpandCount(cntInput);
}
- if (env.insChain) {
- operands.push_back(env.insChain);
- env.insChain = insInput;
+ if (env.getInsertionChain()) {
+ operands.push_back(env.getInsertionChain());
+ env.updateInsertionChain(insInput);
}
if (!operands.empty())
- builder.create<scf::YieldOp>(env.linalgOp.getLoc(), operands);
+ builder.create<scf::YieldOp>(env.op().getLoc(), operands);
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
}
@@ -1218,17 +1224,17 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
// Emit invariants at this loop sequence level.
genInvariants(env, builder, exp, ldx, /*atStart=*/true);
// Emit access pattern expansion for sparse tensor output.
- genExpansion(env, builder, at, /*atStart=*/true);
+ genExpand(env, builder, at, /*atStart=*/true);
// Emit further initialization at this loop sequence level.
unsigned l0 = env.set(lts)[0];
bool needsUniv = false;
SmallVector<size_t> tids;
SmallVector<size_t> dims;
- env.merger.foreachTidDimPairInBits(
+ env.merger().foreachTidDimPairInBits(
env.lat(l0).bits,
[&](unsigned b, unsigned tid, Optional<unsigned> dim, DimLevelType dlt) {
- assert(env.merger.index(b) == idx);
+ assert(env.merger().index(b) == idx);
if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
needsUniv = true;
} else {
@@ -1238,7 +1244,7 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
}
});
- env.loopEmitter->enterNewLoopSeq(builder, env.linalgOp.getLoc(), tids, dims);
+ env.emitter()->enterNewLoopSeq(builder, env.op().getLoc(), tids, dims);
// Maintain the universal index only if it is actually
// consumed by a subsequent lattice point.
@@ -1246,7 +1252,7 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
unsigned lsize = env.set(lts).size();
for (unsigned i = 1; i < lsize; i++) {
unsigned li = env.set(lts)[i];
- if (!env.merger.hasAnySparse(env.lat(li).simple))
+ if (!env.merger().hasAnySparse(env.lat(li).simple))
return true;
}
}
@@ -1257,7 +1263,7 @@ static void genConstantDenseAddressFromLevel(CodegenEnv &env,
OpBuilder &builder, unsigned tid,
unsigned lvl) {
// TODO: Handle affine expression on output tensor.
- linalg::GenericOp op = env.linalgOp;
+ linalg::GenericOp op = env.op();
assert(tid < op.getNumDpsInputs());
OpOperand *input = op.getDpsInputOperands()[tid];
ArrayRef<AffineExpr> affines = op.getMatchingIndexingMap(input).getResults();
@@ -1267,7 +1273,7 @@ static void genConstantDenseAddressFromLevel(CodegenEnv &env,
AffineExpr affine = affines[toOrigDim(enc, i)];
if (isDenseDLT(getDimLevelType(enc, i)) &&
affine.isa<AffineConstantExpr>())
- env.loopEmitter->genDenseAffineAddressAtCurLevel(
+ env.emitter()->genDenseAffineAddressAtCurLevel(
builder, op.getLoc(), input->getOperandNumber(), i, affine);
else
return; // break on first non-dense non-constant level
@@ -1281,7 +1287,7 @@ static void genInitConstantDenseAddress(CodegenEnv &env,
// starting from the first level as they do not depend on any thing.
// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
// levels can be determined before loops.
- for (unsigned tid = 0, e = env.linalgOp.getNumDpsInputs(); tid < e; tid++)
+ for (unsigned tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
}
@@ -1295,9 +1301,9 @@ static void translateBitsToTidDimPairs(
const BitVector &simple = env.lat(li).simple;
// Converts bits to array + dim pair
- env.merger.foreachTidDimPairInBits(all, [&, idx](unsigned b, unsigned tid,
- Optional<unsigned> dim,
- DimLevelType dlt) {
+ env.merger().foreachTidDimPairInBits(all, [&, idx](unsigned b, unsigned tid,
+ Optional<unsigned> dim,
+ DimLevelType dlt) {
if (simple.test(b)) {
if (isUndefDLT(dlt)) {
// An undefined dlt in the lattices, we probably mean to iterate based
@@ -1306,8 +1312,8 @@ static void translateBitsToTidDimPairs(
// output tensor).
// out[i][j] = invariant; or a broadcast
// out[i][j] = in[i] (j is undef for input)
- tid = env.merger.getOutTensorID();
- dim = env.merger.getDimNum(tid, idx);
+ tid = env.merger().getOutTensorID();
+ dim = env.merger().getDimNum(tid, idx);
// Skips invalid dim (e.g., when this is a zero ranked tensor).
if (!dim)
return;
@@ -1320,7 +1326,7 @@ static void translateBitsToTidDimPairs(
extraDims.push_back(*dim);
} else {
assert(isUndefDLT(dlt));
- linalg::GenericOp op = env.linalgOp;
+ linalg::GenericOp op = env.op();
if (tid >= op.getNumDpsInputs())
// We only handle affine expression on input tensors (for now).
return;
@@ -1361,12 +1367,12 @@ static void translateBitsToTidDimPairs(
}
});
- if (isDenseDLT(env.dimLevelType(env.merger.getOutTensorID(), idx))) {
+ if (isDenseDLT(env.dlt(env.merger().getOutTensorID(), idx))) {
// Note that we generate dense indices of the output tensor
// unconditionally, since they may not appear in the lattice, but may be
- // needed for linearized codegen.
- auto dim = *env.merger.getDimNum(env.merger.getOutTensorID(), idx);
- extraTids.push_back(env.merger.getOutTensorID());
+ // needed for linearized env.
+ auto dim = *env.merger().getDimNum(env.merger().getOutTensorID(), idx);
+ extraTids.push_back(env.merger().getOutTensorID());
extraDims.push_back(dim);
}
}
@@ -1384,7 +1390,7 @@ static Operation *startLoop(CodegenEnv &env, OpBuilder &builder, unsigned at,
// level.
SmallVector<size_t> affineTids, affineDims;
SmallVector<AffineExpr> affines;
- translateBitsToTidDimPairs(env, li, env.topSort[at], condTids, condDims,
+ translateBitsToTidDimPairs(env, li, env.topSortAt(at), condTids, condDims,
extraTids, extraDims, affineTids, affineDims,
affines);
@@ -1392,8 +1398,8 @@ static Operation *startLoop(CodegenEnv &env, OpBuilder &builder, unsigned at,
Operation *loop = genLoop(env, builder, at, needsUniv, condTids, condDims,
extraTids, extraDims);
for (auto [tid, dim, exp] : llvm::zip(affineTids, affineDims, affines)) {
- env.loopEmitter->genDenseAffineAddressAtCurLevel(
- builder, env.linalgOp.getLoc(), tid, dim, exp);
+ env.emitter()->genDenseAffineAddressAtCurLevel(builder, env.op().getLoc(),
+ tid, dim, exp);
}
// Until now, we have entered every <tid, dim> pair in {cond, extra,
@@ -1402,7 +1408,7 @@ static Operation *startLoop(CodegenEnv &env, OpBuilder &builder, unsigned at,
auto allTids = llvm::concat<size_t>(condTids, extraTids, affineTids);
auto allDims = llvm::concat<size_t>(condDims, extraDims, affineDims);
for (auto [tid, dim] : llvm::zip(allTids, allDims)) {
- if (tid != env.merger.getOutTensorID())
+ if (tid != env.merger().getOutTensorID())
genConstantDenseAddressFromLevel(env, builder, tid, dim + 1);
}
@@ -1420,7 +1426,7 @@ static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
}
genLoopBoundary(env, [&](MutableArrayRef<Value> reduc) {
- env.loopEmitter->exitCurrentLoop(rewriter, env.linalgOp.getLoc(), reduc);
+ env.emitter()->exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
return std::nullopt;
});
@@ -1431,11 +1437,11 @@ static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
unsigned at, unsigned idx, unsigned ldx) {
assert(env.getLoopIdxValue(idx) == nullptr);
- env.loopEmitter->exitCurrentLoopSeq();
+ env.emitter()->exitCurrentLoopSeq();
// Unmark bookkeeping of invariants and loop index.
genInvariants(env, builder, exp, ldx, /*atStart=*/false);
// Finalize access pattern expansion for sparse tensor output.
- genExpansion(env, builder, at, /*atStart=*/false);
+ genExpand(env, builder, at, /*atStart=*/false);
}
/// Recursively generates code while computing iteration lattices in order
@@ -1444,17 +1450,17 @@ static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
static void genStmt(CodegenEnv &env, RewriterBase &rewriter, unsigned exp,
unsigned at) {
// At each leaf, assign remaining tensor (sub)expression to output tensor.
- if (at == env.topSort.size()) {
- unsigned ldx = env.topSort[at - 1];
+ if (at == env.topSortSize()) {
+ unsigned ldx = env.topSortAt(at - 1);
Value rhs = genExp(env, rewriter, exp, ldx);
genTensorStore(env, rewriter, exp, rhs);
return;
}
// Construct iteration lattices for current loop index, with L0 at top.
- unsigned idx = env.topSort[at];
- unsigned ldx = at == 0 ? -1u : env.topSort[at - 1];
- unsigned lts = env.merger.optimizeSet(env.merger.buildLattices(exp, idx));
+ unsigned idx = env.topSortAt(at);
+ unsigned ldx = at == 0 ? -1u : env.topSortAt(at - 1);
+ unsigned lts = env.merger().optimizeSet(env.merger().buildLattices(exp, idx));
// TODO: sort
// TODO: dedup
@@ -1472,13 +1478,13 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, unsigned exp,
// Visit all lattices points with Li >= Lj to generate the
// loop-body, possibly with if statements for coiteration.
Value redInput = env.getReduc();
- Value cntInput = env.expCount;
- Value insInput = env.insChain;
+ Value cntInput = env.getExpandCount();
+ Value insInput = env.getInsertionChain();
bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
for (unsigned j = 0; j < lsize; j++) {
unsigned lj = env.set(lts)[j];
unsigned ej = env.lat(lj).exp;
- if (li == lj || env.merger.latGT(li, lj)) {
+ if (li == lj || env.merger().latGT(li, lj)) {
// Recurse into body of each branch.
if (isWhile) {
scf::IfOp ifOp = genIf(env, rewriter, idx, env.lat(lj).simple);
@@ -1500,7 +1506,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, unsigned exp,
/// Converts the result computed by the sparse kernel into the required form.
static void genResult(CodegenEnv &env, RewriterBase &rewriter) {
- linalg::GenericOp op = env.linalgOp;
+ linalg::GenericOp op = env.op();
OpOperand *lhs = op.getDpsInitOperand(0);
Value tensor = lhs->get();
Type resType = tensor.getType();
@@ -1508,14 +1514,16 @@ static void genResult(CodegenEnv &env, RewriterBase &rewriter) {
// The sparse tensor rematerializes from the original sparse tensor's
// underlying sparse storage format. For an insertion chain, the
// tensor materializes from the chain with 'hasInserts' enabled.
- bool hasInserts = env.sparseOut == lhs;
- if (hasInserts)
- tensor = env.insChain;
+ bool hasInserts = false;
+ if (Value chain = env.getInsertionChain()) {
+ hasInserts = true;
+ tensor = chain;
+ }
rewriter.replaceOpWithNewOp<LoadOp>(op, resType, tensor, hasInserts);
} else {
// To rematerialize a non-annotated tensor, simply load it
// from the bufferized value.
- Value val = env.getValBuffer().back(); // value array
+ Value val = env.emitter()->getValBuffer().back(); // value array
rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
}
}
@@ -1550,8 +1558,8 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
return failure();
// Builds the tensor expression for the Linalg operation in SSA form.
- Optional<unsigned> optExp = env.merger.buildTensorExpFromLinalg(op);
- if (!optExp.has_value())
+ Optional<unsigned> optExp = env.merger().buildTensorExpFromLinalg(op);
+ if (!optExp)
return failure();
unsigned exp = *optExp;
@@ -1562,6 +1570,8 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
// to resolve cycles by inserting a conversion.
bool isAdmissible = false;
bool hasCycle = true;
+ OpOperand *sparseOut = nullptr;
+ unsigned outerParNest = -1u;
// A const list of all masks that we used for iteration graph
// computation. Must be ordered from more strict to less strict.
const auto allMask = {SortMask::kIncludeAll, SortMask::kIncludeUndef,
@@ -1569,7 +1579,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
for (auto mask : allMask)
if (computeIterationGraph(env, mask)) {
hasCycle = false;
- if (isAdmissibleTensorExp(env, exp)) {
+ if (isAdmissibleTensorExp(env, exp, &sparseOut, &outerParNest)) {
isAdmissible = true;
break;
}
@@ -1589,9 +1599,9 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
SparseTensorLoopEmitter lpe(
tensors,
StringAttr::get(op.getContext(), linalg::GenericOp::getOperationName()),
- /*hasOutput=*/true, /*isSparseOut=*/env.sparseOut != nullptr,
- env.topSort);
- env.startEmit(&lpe);
+ /*hasOutput=*/true, /*isSparseOut=*/sparseOut != nullptr,
+ env.topSortRef());
+ env.startEmit(sparseOut, outerParNest, &lpe);
// Recursively generates code if admissible.
genBuffers(env, rewriter);
@@ -1607,7 +1617,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
// Compute topological sort while leaving out every
// sparse input tensor in succession until an acyclic
// iteration graph results.
- for (OpOperand *t : env.linalgOp.getDpsInputOperands()) {
+ for (OpOperand *t : env.op().getDpsInputOperands()) {
unsigned tensor = t->getOperandNumber();
Value tval = t->get();
auto srcEnc = getSparseTensorEncoding(tval.getType());
@@ -1623,14 +1633,14 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
auto srcTp = tval.getType().cast<RankedTensorType>();
auto dstEnc = SparseTensorEncodingAttr::get(
getContext(), srcEnc.getDimLevelType(),
- permute(env, env.linalgOp.getMatchingIndexingMap(t)), // new order
+ permute(env, env.op().getMatchingIndexingMap(t)), // new order
srcEnc.getHigherOrdering(), srcEnc.getPointerBitWidth(),
srcEnc.getIndexBitWidth());
auto dstTp = RankedTensorType::get(srcTp.getShape(),
srcTp.getElementType(), dstEnc);
auto convert = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
- env.linalgOp->setOperand(tensor, convert);
- rewriter.setInsertionPointAfter(env.linalgOp);
+ env.op()->setOperand(tensor, convert);
+ rewriter.setInsertionPointAfter(env.op());
rewriter.create<bufferization::DeallocTensorOp>(tval.getLoc(), convert);
return success();
}