[Mlir-commits] [mlir] 98f93e3 - [mlir][sparse] factorized merger/emitter/codegen into single environment
Aart Bik
llvmlistbot at llvm.org
Tue Dec 20 11:34:21 PST 2022
Author: Aart Bik
Date: 2022-12-20T11:34:12-08:00
New Revision: 98f93e3b726ac061153ff381b0119d6f2711abe4
URL: https://github.com/llvm/llvm-project/commit/98f93e3b726ac061153ff381b0119d6f2711abe4
DIFF: https://github.com/llvm/llvm-project/commit/98f93e3b726ac061153ff381b0119d6f2711abe4.diff
LOG: [mlir][sparse] factorized merger/emitter/codegen into single environment
This cleans up a lot of parameter passing. It also prepares for adding
proper "delegate" functions to the new environment and for moving it
out into its own class with a better OO design (see the sketch below).
Reviewed By: Peiming
Differential Revision: https://reviews.llvm.org/D140257
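
To make the intent of the new CodeGenEnv concrete, here is a minimal standalone
sketch of the same pattern, with hypothetical names rather than the actual MLIR
classes: a couple of helper objects are aggregated into one environment, and a
code-generation helper takes a single `Env &` with delegate methods instead of
separate merger/emitter/codegen parameters.

// Standalone sketch (hypothetical names, not the MLIR classes): fold the
// helpers into one environment and delegate through it.
#include <cassert>
#include <iostream>
#include <vector>

struct Merger {                       // stand-in for the merger helper
  std::vector<int> exps;
  int &exp(unsigned e) { return exps[e]; }
};

struct LoopEmitter {                  // stand-in for the loop emitter
  std::vector<int> ivs;
  int getLoopIV(unsigned i) const { return ivs[i]; }
};

struct Env {                          // aggregates all codegen state
  Merger merger;
  LoopEmitter *emitter = nullptr;     // installed once emission starts

  void startEmit(LoopEmitter *le) {
    assert(!emitter && "must only start emitting once");
    emitter = le;
  }
  // Delegate methods hide which helper actually owns the data.
  int &exp(unsigned e) { return merger.exp(e); }
  int getLoopIV(unsigned i) const { return emitter->getLoopIV(i); }
};

// Before the refactoring this would have taken (Merger &, LoopEmitter &, ...).
static int genExp(Env &env, unsigned e, unsigned loop) {
  return env.exp(e) + env.getLoopIV(loop);
}

int main() {
  Env env;
  env.merger.exps = {10, 20};
  LoopEmitter le{{1, 2, 3}};
  env.startEmit(&le);
  std::cout << genExp(env, 1, 2) << "\n";   // prints 23
  return 0;
}

Passing a single environment around is what lets the hunks below collapse
signatures such as genTensorStore(Merger &, CodeGen &, OpBuilder &,
linalg::GenericOp, ...) down to genTensorStore(CodeGenEnv &, OpBuilder &, ...).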
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 680ef8d00b037..3f67f66a93097 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -41,7 +41,7 @@ using namespace mlir::sparse_tensor;
namespace {
-// Iteration graph sorting.
+/// Iteration graph sorting.
enum SortMask {
kSparseOnly = 0x0,
kIncludeDense = 0x1,
@@ -49,37 +49,91 @@ enum SortMask {
kIncludeAll = 0x3
};
-// Reduction kinds.
+/// Reduction kinds.
enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor, kCustom };
-// Code generation.
-struct CodeGen {
- CodeGen(SparsificationOptions o, MLIRContext *context, ValueRange tensors,
- unsigned numTensors, unsigned numLoops, OpOperand *op, unsigned nest,
- std::vector<unsigned> &ts)
- : options(o),
- loopEmitter(
- tensors,
- StringAttr::get(context, linalg::GenericOp::getOperationName()),
- /*hasOutput=*/true,
- /*isSparseOut=*/op != nullptr, ts),
- sparseOut(op), outerParNest(nest), topSort(ts) {
- if (op)
- insChain = op->get();
+/// Code generation environment. This structure aggregates a number
+/// of data structures needed during code generation. Such an environment
+/// simplifies passing around data during sparsification (rather than
+/// passing around all the individual components where needed).
+//
+// TODO: refactor further, move into own file
+//
+struct CodeGenEnv {
+ CodeGenEnv(linalg::GenericOp linop, SparsificationOptions opts,
+ unsigned numTensors, unsigned numLoops, unsigned numFilterLoops)
+ : linalgOp(linop), options(opts), topSort(),
+ merger(numTensors, numLoops, numFilterLoops), loopEmitter(nullptr),
+ redExp(-1u), redKind(kNoReduc), redCustom(-1u), sparseOut(nullptr) {}
+
+ // Start emitting.
+ void startEmit(SparseTensorLoopEmitter *le) {
+ assert(!loopEmitter && "must only start emitting once");
+ loopEmitter = le;
+ if (sparseOut) {
+ insChain = sparseOut->get();
+ merger.setHasSparseOut(true);
+ }
+ }
+
+ // Delegate methods to merger.
+ TensorExp &exp(unsigned e) { return merger.exp(e); }
+ LatPoint &lat(unsigned l) { return merger.lat(l); }
+ SmallVector<unsigned> &set(unsigned s) { return merger.set(s); }
+ DimLevelType dimLevelType(unsigned t, unsigned i) const {
+ return merger.getDimLevelType(t, i);
+ }
+ DimLevelType dimLevelType(unsigned b) const {
+ return merger.getDimLevelType(b);
+ }
+ bool isFilterLoop(unsigned i) const { return merger.isFilterLoop(i); }
+
+ // Delegate methods to loop emitter.
+ Value getLoopIV(unsigned i) const { return loopEmitter->getLoopIV(i); }
+ const std::vector<Value> &getValBuffer() const {
+ return loopEmitter->getValBuffer();
+ }
+
+ // Convenience method to slice topsort.
+ ArrayRef<unsigned> getTopSortSlice(size_t n, size_t m) const {
+ return ArrayRef<unsigned>(topSort).slice(n, m);
+ }
+
+ // Convenience method to get current loop stack.
+ ArrayRef<unsigned> getLoopCurStack() const {
+ return getTopSortSlice(0, loopEmitter->getCurrentDepth());
+ }
+
+ // Convenience method to get the IV of the given loop index.
+ Value getLoopIdxValue(size_t loopIdx) const {
+ for (unsigned lv = 0, lve = topSort.size(); lv < lve; lv++)
+ if (topSort[lv] == loopIdx)
+ return getLoopIV(lv);
+ llvm_unreachable("invalid loop index");
}
+
+ // TODO: make private
+
+ /// Linalg operation.
+ linalg::GenericOp linalgOp;
/// Sparsification options.
SparsificationOptions options;
- /// Loop emitter helper class.
- SparseTensorLoopEmitter loopEmitter;
+ // Topological sort.
+ std::vector<unsigned> topSort;
+ /// Merger helper class.
+ Merger merger;
+ /// Loop emitter helper class (keep reference in scope!).
+ /// TODO: move emitter constructor up in time?
+ SparseTensorLoopEmitter *loopEmitter;
/// Current reduction, updated during code generation. When indices of a
/// reduction are exhausted, all inner loops can use a scalarized reduction.
- unsigned redExp = -1u;
+ unsigned redExp;
Value redVal;
- Reduction redKind = kNoReduc;
- unsigned redCustom = -1u;
- // Sparse tensor as output. Implemented either through direct injective
- // insertion in lexicographic index order or through access pattern expansion
- // in the innermost loop nest (`expValues` through `expCount`).
+ Reduction redKind;
+ unsigned redCustom;
+ /// Sparse tensor as output. Implemented either through direct injective
+ /// insertion in lexicographic index order or through access pattern expansion
+ /// in the innermost loop nest (`expValues` through `expCount`).
OpOperand *sparseOut;
unsigned outerParNest;
Value insChain; // bookkeeping for insertion chain
@@ -87,21 +141,6 @@ struct CodeGen {
Value expFilled;
Value expAdded;
Value expCount;
- // Topsort (reference should remain in scope).
- std::vector<unsigned> &topSort;
-
- ArrayRef<unsigned> getLoopCurStack() const {
- ArrayRef<unsigned> topSortRef = topSort;
- return topSortRef.slice(0, loopEmitter.getCurrentDepth());
- }
-
- Value getLoopIdxValue(size_t loopIdx) const {
- for (unsigned lv = 0; lv < topSort.size(); lv++)
- if (topSort[lv] == loopIdx)
- return loopEmitter.getLoopIV(lv);
-
- llvm_unreachable("invalid loop index");
- }
};
/// A helper class that visits an affine expression and tries to find an
@@ -133,6 +172,7 @@ class AffineDimFinder : public AffineExprVisitor<AffineDimFinder> {
/// The mapping between dim=>iterator type.
SmallVector<utils::IteratorType> iterTypes;
};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -172,17 +212,17 @@ static bool isInvariantAffine(AffineExpr a, ArrayRef<unsigned> loopStack,
}
/// Determines if affine expression is invariant.
-static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
- unsigned ldx, bool &atLevel) {
- return isInvariantAffine(a, codegen.getLoopCurStack(), ldx, atLevel);
+static bool isInvariantAffine(CodeGenEnv &env, AffineExpr a, unsigned ldx,
+ bool &atLevel) {
+ return isInvariantAffine(a, env.getLoopCurStack(), ldx, atLevel);
}
/// Helper method to construct a permuted dimension ordering
/// that adheres to the given topological sort.
-static AffineMap permute(const Merger &merger, MLIRContext *context,
- AffineMap m, ArrayRef<unsigned> topSort) {
- assert(m.getNumDims() + merger.getNumFilterLoops() == topSort.size() &&
- "TopoSort/AffineMap size mismatch");
+static AffineMap permute(CodeGenEnv &env, AffineMap m) {
+ assert(m.getNumDims() + env.merger.getNumFilterLoops() ==
+ env.topSort.size() &&
+ "size mismatch");
// Construct the inverse of `m`; to avoid the asymptotic complexity
// of calling `m.getPermutedPosition` repeatedly.
SmallVector<unsigned> perm;
@@ -191,13 +231,14 @@ static AffineMap permute(const Merger &merger, MLIRContext *context,
unsigned loopDepth = 1;
// Construct the permutation.
- while (worklist.any() && loopDepth <= topSort.size()) {
+ while (worklist.any() && loopDepth <= env.topSort.size()) {
unsigned preSize = perm.size();
for (auto dim : worklist.set_bits()) {
bool atLevel = false;
if (m.getResult(dim).isa<AffineConstantExpr>() ||
- (isInvariantAffine(m.getResult(dim), topSort.slice(0, loopDepth),
- topSort[loopDepth - 1], atLevel) &&
+ (isInvariantAffine(m.getResult(dim),
+ env.getTopSortSlice(0, loopDepth),
+ env.topSort[loopDepth - 1], atLevel) &&
atLevel)) {
// If the matching affine is constant expression or just become
// invariant. We can visit the dimension now without breaking the
@@ -215,7 +256,7 @@ static AffineMap permute(const Merger &merger, MLIRContext *context,
}
assert(perm.size() == numResults);
- return AffineMap::getPermutationMap(perm, context);
+ return AffineMap::getPermutationMap(perm, env.linalgOp.getContext());
}
/// Helper method to inspect affine expressions. Rejects cases where the
@@ -305,24 +346,24 @@ static unsigned getNumCompoundAffineOnSparseDims(linalg::GenericOp op) {
/// Returns true if the sparse annotations and affine subscript
/// expressions of all tensors are admissible. Returns false if
/// no annotations are found or inadmissible constructs occur.
-static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
+static bool findSparseAnnotations(CodeGenEnv &env) {
bool annotated = false;
- unsigned filterLdx = merger.getFilterLoopStartingIdx();
- for (OpOperand &t : op->getOpOperands()) {
- auto map = op.getMatchingIndexingMap(&t);
+ unsigned filterLdx = env.merger.getFilterLoopStartingIdx();
+ for (OpOperand &t : env.linalgOp->getOpOperands()) {
+ auto map = env.linalgOp.getMatchingIndexingMap(&t);
auto enc = getSparseTensorEncoding(t.get().getType());
if (enc)
annotated = true;
- assert(map.getNumResults() == op.getRank(&t));
-
+ assert(map.getNumResults() == env.linalgOp.getRank(&t));
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
unsigned tensor = t.getOperandNumber();
AffineExpr a = map.getResult(toOrigDim(enc, d));
- if (!findAffine(merger, tensor, d, a, getDimLevelType(enc, d), filterLdx))
+ if (!findAffine(env.merger, tensor, d, a, getDimLevelType(enc, d),
+ filterLdx))
return false; // inadmissible affine expression
}
}
- assert(filterLdx == merger.getNumLoops());
+ assert(filterLdx == env.merger.getNumLoops());
return annotated;
}
@@ -330,9 +371,8 @@ static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
/// as we use adj matrix for the graph.
/// The sorted result will put the first Reduction iterator to the
/// latest possible index.
-static bool topSortOptimal(unsigned n,
+static bool topSortOptimal(CodeGenEnv &env, unsigned n,
ArrayRef<utils::IteratorType> iteratorTypes,
- const Merger &merger, std::vector<unsigned> &topSort,
std::vector<unsigned> &inDegree,
std::vector<std::vector<bool>> &adjM) {
std::vector<unsigned> redIt; // reduce iterator with 0 degree
@@ -340,7 +380,7 @@ static bool topSortOptimal(unsigned n,
std::vector<unsigned> filterIt; // filter loop with 0 degree
for (unsigned i = 0; i < n; i++) {
if (inDegree[i] == 0) {
- if (merger.isFilterLoop(i))
+ if (env.isFilterLoop(i))
filterIt.push_back(i);
else if (linalg::isReductionIterator(iteratorTypes[i]))
redIt.push_back(i);
@@ -371,12 +411,12 @@ static bool topSortOptimal(unsigned n,
// O(X) computation => O(NK+NMX) time complexity
auto &it = !filterIt.empty() ? filterIt : (!parIt.empty() ? parIt : redIt);
auto src = it.back();
- topSort.push_back(src);
+ env.topSort.push_back(src);
it.pop_back();
// Update in-degree, and push 0-degree node into worklist.
for (unsigned dst = 0; dst < n; dst++) {
if (adjM[src][dst] && --inDegree[dst] == 0) {
- if (merger.isFilterLoop(dst))
+ if (env.isFilterLoop(dst))
filterIt.push_back(dst);
else if (linalg::isReductionIterator(iteratorTypes[dst]))
redIt.push_back(dst);
@@ -385,7 +425,7 @@ static bool topSortOptimal(unsigned n,
}
}
}
- return topSort.size() == n;
+ return env.topSort.size() == n;
}
/// Helper method to add all constraints from the indices in one affine
@@ -477,21 +517,21 @@ static void tryLoosenAffineDenseConstraints(linalg::GenericOp op,
/// along fixed dimensions. Even for dense storage formats, however, the
/// natural index order yields innermost unit-stride access with better
/// spatial locality.
-static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
- std::vector<unsigned> &topSort, unsigned mask,
+static bool computeIterationGraph(CodeGenEnv &env, unsigned mask,
OpOperand *skip = nullptr) {
// Set up an n x n from/to adjacency matrix of the iteration graph
// for the implicit loop indices i_0 .. i_n-1.
- unsigned n = merger.getNumLoops();
+ unsigned n = env.merger.getNumLoops();
std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
std::vector<unsigned> inDegree(n, 0); // in-degree of each node.
- auto iteratorTypes = op.getIteratorTypesArray();
+ auto iteratorTypes = env.linalgOp.getIteratorTypesArray();
// Iterate over the indexing maps of every tensor in the tensor expression.
- for (OpOperand &t : op->getOpOperands()) {
+ for (OpOperand &t : env.linalgOp->getOpOperands()) {
// Get map and encoding.
- auto map = op.getMatchingIndexingMap(&t);
+ auto map = env.linalgOp.getMatchingIndexingMap(&t);
auto enc = getSparseTensorEncoding(t.get().getType());
- assert(map.getNumDims() + getNumCompoundAffineOnSparseDims(op) == n);
+ assert(map.getNumDims() + getNumCompoundAffineOnSparseDims(env.linalgOp) ==
+ n);
// Skip dense tensor constraints when not requested.
if (!(mask & SortMask::kIncludeDense) && !enc)
continue;
@@ -501,11 +541,11 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
// on the loop indices if no explicit dimension ordering is given.
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
AffineExpr ta = map.getResult(toOrigDim(enc, d));
- Optional<unsigned> tldx = merger.getLoopIdx(t.getOperandNumber(), d);
+ Optional<unsigned> tldx = env.merger.getLoopIdx(t.getOperandNumber(), d);
// Filter loops should be constructed after all the dependent loops,
// i.e., d0 + d1 < filter_loop(d0 + d1)
- if (tldx && merger.isFilterLoop(*tldx)) {
+ if (tldx && env.isFilterLoop(*tldx)) {
assert(!ta.isa<AffineDimExpr>() &&
!isDenseDLT(getDimLevelType(enc, d)));
addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt,
@@ -525,7 +565,7 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
if (d > 0) {
AffineExpr fa = map.getResult(toOrigDim(enc, d - 1));
Optional<unsigned> fldx =
- merger.getLoopIdx(t.getOperandNumber(), d - 1);
+ env.merger.getLoopIdx(t.getOperandNumber(), d - 1);
// Applying order constraints on every pair of dimExpr between two
// compound affine expressions can sometime too strict:
@@ -533,7 +573,7 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
// It is totally fine to have loop sequence d0->d2->d1->d3 instead of
// requiring d0 < d2, d1 < d2, d0 < d3, d1 < d3.
if (!(mask & SortMask::kIncludeDense))
- tryLoosenAffineDenseConstraints(op, fldx, fa, tldx, ta);
+ tryLoosenAffineDenseConstraints(env.linalgOp, fldx, fa, tldx, ta);
// (d0 + d1) < (d2 + d3), or
// filter_loop_d-1 < (d2 + d3), or
@@ -548,24 +588,24 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
if (mask & SortMask::kIncludeUndef) {
unsigned tensor = t.getOperandNumber();
for (unsigned i = 0; i < n; i++)
- if (isCompressedDLT(merger.getDimLevelType(tensor, i)) ||
- isSingletonDLT(merger.getDimLevelType(tensor, i))) {
+ if (isCompressedDLT(env.dimLevelType(tensor, i)) ||
+ isSingletonDLT(env.dimLevelType(tensor, i))) {
for (unsigned j = 0; j < n; j++)
- if (isUndefDLT(merger.getDimLevelType(tensor, j))) {
+ if (isUndefDLT(env.dimLevelType(tensor, j))) {
adjM[i][j] = true;
inDegree[j]++;
}
} else {
- assert(isDenseDLT(merger.getDimLevelType(tensor, i)) ||
- isUndefDLT(merger.getDimLevelType(tensor, i)));
+ assert(isDenseDLT(env.dimLevelType(tensor, i)) ||
+ isUndefDLT(env.dimLevelType(tensor, i)));
}
}
}
// Topologically sort the iteration graph to determine loop order.
// Report failure for a cyclic iteration graph.
- topSort.clear();
- topSort.reserve(n);
- return topSortOptimal(n, iteratorTypes, merger, topSort, inDegree, adjM);
+ env.topSort.clear();
+ env.topSort.reserve(n);
+ return topSortOptimal(env, n, iteratorTypes, inDegree, adjM);
}
/// Returns true if tensor materializes uninitialized into the computation.
@@ -574,29 +614,25 @@ static bool isMaterializing(Value val) {
val.getDefiningOp<bufferization::AllocTensorOp>();
}
-/// Returns true when the tensor expression is admissible for codegen.
+/// Returns true when the tensor expression is admissible for env.
/// Since all sparse input tensors are admissible, we just need to check
/// whether the out tensor in the tensor expression codegen is admissible.
/// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective
/// nesting depth when a "truly dynamic" sparse tensor output occurs.
-static bool isAdmissibleTensorExp(Merger &merger, linalg::GenericOp op,
- std::vector<unsigned> &topSort, unsigned exp,
- OpOperand **sparseOut,
- unsigned &outerParNest) {
+static bool isAdmissibleTensorExp(CodeGenEnv &env, unsigned exp) {
// We reject any expression that makes a reduction from `-outTensor`, as those
- // expression create dependency between the current iteration (i) and the
- // previous iteration (i-1). It would then require iterating over the whole
- // coordinate space, which prevent us from exploiting sparsity for faster
- // code.
- for (utils::IteratorType it : op.getIteratorTypesArray()) {
+ // expressions create a dependency between the current iteration (i) and the
+ // previous iteration (i-1). It would require iterating over the whole
+ // coordinate space, which prevents exploiting sparsity for faster code.
+ for (utils::IteratorType it : env.linalgOp.getIteratorTypesArray()) {
if (it == utils::IteratorType::reduction) {
- if (merger.hasNegateOnOut(exp))
+ if (env.merger.hasNegateOnOut(exp))
return false;
break;
}
}
- OpOperand *lhs = op.getDpsInitOperand(0);
+ OpOperand *lhs = env.linalgOp.getDpsInitOperand(0);
unsigned tensor = lhs->getOperandNumber();
auto enc = getSparseTensorEncoding(lhs->get().getType());
// An non-annotated output tensor is assumed dense, and becomes a random
@@ -606,40 +642,41 @@ static bool isAdmissibleTensorExp(Merger &merger, linalg::GenericOp op,
// An all-dense annotated "sparse" output tensor becomes a linearized random
// access 1-dim memref. Also admissible since insertions cannot occur.
bool allDense = true;
- unsigned numLoops = merger.getNumLoops(); // numNativeLoops + numFilterLoops
- for (unsigned i = 0; i < merger.getNumLoops(); i++)
- if (isCompressedDLT(merger.getDimLevelType(tensor, i)) ||
- isSingletonDLT(merger.getDimLevelType(tensor, i))) {
+ unsigned numLoops =
+ env.merger.getNumLoops(); // numNativeLoops + numFilterLoops
+ for (unsigned i = 0; i < env.merger.getNumLoops(); i++)
+ if (isCompressedDLT(env.dimLevelType(tensor, i)) ||
+ isSingletonDLT(env.dimLevelType(tensor, i))) {
allDense = false;
break;
} else {
- assert(isDenseDLT(merger.getDimLevelType(tensor, i)) ||
- isUndefDLT(merger.getDimLevelType(tensor, i)));
+ assert(isDenseDLT(env.dimLevelType(tensor, i)) ||
+ isUndefDLT(env.dimLevelType(tensor, i)));
}
if (allDense)
return true;
// TODO: support compound affine expression on sparse output.
- if (getNumCompoundAffineOnSparseDims(op.getMatchingIndexingMap(lhs),
+ if (getNumCompoundAffineOnSparseDims(env.linalgOp.getMatchingIndexingMap(lhs),
lhs->get()) != 0)
return false;
// A tensor expression with a sparse output tensor that changes its values
// but not its nonzero structure, an operation called "simply dynamic" in
- // [Bik96,Ch9], is also admissible without special codegen.
- if (merger.isSingleCondition(tensor, exp))
+ // [Bik96,Ch9], is also admissible without special env.
+ if (env.merger.isSingleCondition(tensor, exp))
return true;
// Accept "truly dynamic" if the output tensor materializes uninitialized
// into the computation and insertions occur in lexicographic index order.
if (isMaterializing(lhs->get())) {
- auto iteratorTypes = op.getIteratorTypesArray();
+ auto iteratorTypes = env.linalgOp.getIteratorTypesArray();
unsigned nest = 0;
for (unsigned i = 0; i < numLoops; i++) {
- if (!merger.isFilterLoop(topSort[i])) {
+ if (!env.isFilterLoop(env.topSort[i])) {
// We only count non-filter loops as filter loops should be considered
// as a special type of parallel loops.
- if (linalg::isReductionIterator(iteratorTypes[topSort[i]]))
+ if (linalg::isReductionIterator(iteratorTypes[env.topSort[i]]))
break; // terminate at first reduction
nest++;
}
@@ -647,9 +684,9 @@ static bool isAdmissibleTensorExp(Merger &merger, linalg::GenericOp op,
// Determine admissible dynamic insertion situations:
// (1) fully injective, since there are no reductions,
// (2) admissible 1-d expansion in innermost dimension.
- if (nest >= op.getRank(lhs) - 1) {
- *sparseOut = lhs;
- outerParNest = nest;
+ if (nest >= env.linalgOp.getRank(lhs) - 1) {
+ env.sparseOut = lhs;
+ env.outerParNest = nest;
return true;
}
}
@@ -688,9 +725,9 @@ static Reduction getReduction(Kind kind) {
}
/// Updates scalarized reduction value.
-static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) {
- assert(codegen.redKind != kNoReduc);
- codegen.redVal = merger.exp(codegen.redExp).val = reduc;
+static void updateReduc(CodeGenEnv &env, Value reduc) {
+ assert(env.redKind != kNoReduc);
+ env.redVal = env.exp(env.redExp).val = reduc;
}
/// Extracts identity from custom reduce.
@@ -705,38 +742,38 @@ static Value getCustomRedId(Operation *op) {
/// Generates loop boundary statements (entering/exiting loops). The function
/// passes and updates the reduction value.
static Optional<Operation *> genLoopBoundary(
- CodeGen &codegen, Merger &merger,
+ CodeGenEnv &env,
function_ref<Optional<Operation *>(MutableArrayRef<Value> reduc)>
callback) {
SmallVector<Value> reduc;
- if (codegen.redVal)
- reduc.push_back(codegen.redVal);
- if (codegen.expValues)
- reduc.push_back(codegen.expCount);
- if (codegen.insChain)
- reduc.push_back(codegen.insChain);
+ if (env.redVal)
+ reduc.push_back(env.redVal);
+ if (env.expValues)
+ reduc.push_back(env.expCount);
+ if (env.insChain)
+ reduc.push_back(env.insChain);
auto r = callback(reduc);
// Callback should do in-place update on reduction value vector.
unsigned i = 0;
- if (codegen.redVal)
- updateReduc(merger, codegen, reduc[i++]);
- if (codegen.expValues)
- codegen.expCount = reduc[i++];
- if (codegen.insChain)
- codegen.insChain = reduc[i];
+ if (env.redVal)
+ updateReduc(env, reduc[i++]);
+ if (env.expValues)
+ env.expCount = reduc[i++];
+ if (env.insChain)
+ env.insChain = reduc[i];
return r;
}
/// Local bufferization of all dense and sparse data structures.
-static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
- linalg::GenericOp op) {
+static void genBuffers(CodeGenEnv &env, OpBuilder &builder) {
+ linalg::GenericOp op = env.linalgOp;
Location loc = op.getLoc();
assert(op.getNumOperands() == op.getNumDpsInputs() + 1);
- codegen.loopEmitter.initializeLoopEmit(
+ env.loopEmitter->initializeLoopEmit(
builder, loc,
/// Generates buffer for the output tensor.
/// Note that all sparse kernels assume that when all elements are written
@@ -749,10 +786,9 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
Value tensor) -> Value {
// Must not be a sparse tensor.
assert(!getSparseTensorEncoding(tensor.getType()));
+ // Two output tensor references should point to the same object.
OpOperand *lhs = op.getDpsInitOperand(0);
- // Two output tensors references should pointed to the same object.
assert(lhs->get() == tensor);
- bool isInit = op.isInitTensor(lhs);
// An output tensor can simply materialize from the buffer of the tensor
// that appears in the outs() clause. For updates, this has the
// advantage that only the nonzero value are involved in the
@@ -761,6 +797,7 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
// may negatively impact running complexity (viz. O(n^2 + nnz) vs.
// O(nnz) for matrices).
// TODO: use better analysis to avoid zeroing out the buffer?
+ bool isInit = op.isInitTensor(lhs);
Value init = memref;
if (!isInit) {
Value zero = constantZero(builder, loc,
@@ -773,83 +810,82 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
}
/// Generates index for load/store on sparse tensor.
-static Value genIndex(CodeGen &codegen, linalg::GenericOp op, OpOperand *t) {
- auto map = op.getMatchingIndexingMap(t);
+static Value genIndex(CodeGenEnv &env, OpOperand *t) {
+ auto map = env.linalgOp.getMatchingIndexingMap(t);
auto enc = getSparseTensorEncoding(t->get().getType());
AffineExpr a = map.getResult(toOrigDim(enc, map.getNumResults() - 1));
assert(a.getKind() == AffineExprKind::DimId);
unsigned idx = a.cast<AffineDimExpr>().getPosition();
- return codegen.getLoopIdxValue(idx);
+ return env.getLoopIdxValue(idx);
}
/// Generates subscript for load/store on a dense or sparse tensor.
-static Value genSubscript(CodeGen &codegen, OpBuilder &builder,
- linalg::GenericOp op, OpOperand *t,
+static Value genSubscript(CodeGenEnv &env, OpBuilder &builder, OpOperand *t,
SmallVectorImpl<Value> &args) {
+ linalg::GenericOp op = env.linalgOp;
unsigned tensor = t->getOperandNumber();
auto map = op.getMatchingIndexingMap(t);
auto enc = getSparseTensorEncoding(t->get().getType());
unsigned rank = map.getNumResults();
if (enc) {
- Value pidx = codegen.loopEmitter.getPidxs()[tensor].back();
+ Value pidx = env.loopEmitter->getPidxs()[tensor].back();
assert(pidx);
args.push_back(pidx); // position index
} else {
for (unsigned d = 0; d < rank; d++) {
AffineExpr a = map.getResult(d);
- args.push_back(codegen.loopEmitter.genAffine(builder, a, op.getLoc()));
+ args.push_back(env.loopEmitter->genAffine(builder, a, op.getLoc()));
}
}
- return codegen.loopEmitter.getValBuffer()[tensor];
+ return env.getValBuffer()[tensor];
}
/// Generates insertion code to implement dynamic tensor load.
-static Value genInsertionLoad(CodeGen &codegen, OpBuilder &builder,
- linalg::GenericOp op, OpOperand *t) {
+static Value genInsertionLoad(CodeGenEnv &env, OpBuilder &builder,
+ OpOperand *t) {
+ linalg::GenericOp op = env.linalgOp;
Location loc = op.getLoc();
// Direct lexicographic index order, tensor loads as zero.
- if (!codegen.expValues) {
+ if (!env.expValues) {
Type tp = getElementTypeOrSelf(t->get().getType());
return constantZero(builder, loc, tp);
}
// Load from expanded access pattern.
- Value index = genIndex(codegen, op, t);
- return builder.create<memref::LoadOp>(loc, codegen.expValues, index);
+ Value index = genIndex(env, t);
+ return builder.create<memref::LoadOp>(loc, env.expValues, index);
}
/// Generates insertion code to implement dynamic tensor load for reduction.
-static Value genInsertionLoadReduce(Merger &merger, CodeGen &codegen,
- OpBuilder &builder, linalg::GenericOp op,
+static Value genInsertionLoadReduce(CodeGenEnv &env, OpBuilder &builder,
OpOperand *t) {
+ linalg::GenericOp op = env.linalgOp;
Location loc = op.getLoc();
- Value identity = getCustomRedId(merger.exp(codegen.redCustom).op);
+ Value identity = getCustomRedId(env.exp(env.redCustom).op);
// Direct lexicographic index order, tensor loads as identity.
- if (!codegen.expValues) {
+ if (!env.expValues) {
return identity;
}
// Load from expanded access pattern if filled, identity otherwise.
- Value index = genIndex(codegen, op, t);
- Value isFilled =
- builder.create<memref::LoadOp>(loc, codegen.expFilled, index);
- Value valAtIndex =
- builder.create<memref::LoadOp>(loc, codegen.expValues, index);
+ Value index = genIndex(env, t);
+ Value isFilled = builder.create<memref::LoadOp>(loc, env.expFilled, index);
+ Value valAtIndex = builder.create<memref::LoadOp>(loc, env.expValues, index);
return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
}
/// Generates insertion code to implement dynamic tensor store.
-static void genInsertionStore(CodeGen &codegen, OpBuilder &builder,
- linalg::GenericOp op, OpOperand *t, Value rhs) {
+static void genInsertionStore(CodeGenEnv &env, OpBuilder &builder, OpOperand *t,
+ Value rhs) {
+ linalg::GenericOp op = env.linalgOp;
Location loc = op.getLoc();
// Direct insertion in lexicographic index order.
- if (!codegen.expValues) {
+ if (!env.expValues) {
unsigned rank = op.getRank(t);
SmallVector<Value> indices;
for (unsigned i = 0; i < rank; i++) {
- assert(codegen.loopEmitter.getLoopIV(i));
- indices.push_back(codegen.loopEmitter.getLoopIV(i));
+ assert(env.getLoopIV(i));
+ indices.push_back(env.getLoopIV(i));
}
- codegen.insChain =
- builder.create<InsertOp>(loc, rhs, codegen.insChain, indices);
+ env.insChain = builder.create<InsertOp>(loc, rhs, env.insChain, indices);
return;
}
// Generates insertion code along expanded access pattern.
@@ -858,110 +894,108 @@ static void genInsertionStore(CodeGen &codegen, OpBuilder &builder,
// expAdded[inserts++] = i
// endif
// values[i] = rhs
- Value index = genIndex(codegen, op, t);
+ Value index = genIndex(env, t);
Value fval = constantI1(builder, loc, false);
Value tval = constantI1(builder, loc, true);
// If statement.
- Value filled = builder.create<memref::LoadOp>(loc, codegen.expFilled, index);
+ Value filled = builder.create<memref::LoadOp>(loc, env.expFilled, index);
Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
filled, fval);
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond,
/*else=*/true);
// True branch.
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- builder.create<memref::StoreOp>(loc, tval, codegen.expFilled, index);
- builder.create<memref::StoreOp>(loc, index, codegen.expAdded,
- codegen.expCount);
+ builder.create<memref::StoreOp>(loc, tval, env.expFilled, index);
+ builder.create<memref::StoreOp>(loc, index, env.expAdded, env.expCount);
Value one = constantIndex(builder, loc, 1);
- Value add = builder.create<arith::AddIOp>(loc, codegen.expCount, one);
+ Value add = builder.create<arith::AddIOp>(loc, env.expCount, one);
builder.create<scf::YieldOp>(loc, add);
// False branch.
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- builder.create<scf::YieldOp>(loc, codegen.expCount);
+ builder.create<scf::YieldOp>(loc, env.expCount);
builder.setInsertionPointAfter(ifOp);
// Value assignment.
- codegen.expCount = ifOp.getResult(0);
- builder.create<memref::StoreOp>(loc, rhs, codegen.expValues, index);
+ env.expCount = ifOp.getResult(0);
+ builder.create<memref::StoreOp>(loc, rhs, env.expValues, index);
}
/// Generates a load on a dense or sparse tensor.
-static Value genTensorLoad(Merger &merger, CodeGen &codegen, OpBuilder &builder,
- linalg::GenericOp op, unsigned exp) {
+static Value genTensorLoad(CodeGenEnv &env, OpBuilder &builder, unsigned exp) {
// Test if the load was hoisted to a higher loop nest.
- Value val = merger.exp(exp).val;
+ Value val = env.exp(exp).val;
if (val)
return val;
// Load during insertion.
- OpOperand &t = op->getOpOperand(merger.exp(exp).tensor);
- if (&t == codegen.sparseOut) {
- if (codegen.redCustom != -1u)
- return genInsertionLoadReduce(merger, codegen, builder, op, &t);
- return genInsertionLoad(codegen, builder, op, &t);
+ linalg::GenericOp op = env.linalgOp;
+ OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
+ if (&t == env.sparseOut) {
+ if (env.redCustom != -1u)
+ return genInsertionLoadReduce(env, builder, &t);
+ return genInsertionLoad(env, builder, &t);
}
// Actual load.
SmallVector<Value> args;
- Value ptr = genSubscript(codegen, builder, op, &t, args);
+ Value ptr = genSubscript(env, builder, &t, args);
return builder.create<memref::LoadOp>(op.getLoc(), ptr, args);
}
/// Generates a store on a dense or sparse tensor.
-static void genTensorStore(Merger &merger, CodeGen &codegen, OpBuilder &builder,
- linalg::GenericOp op, unsigned exp, Value rhs) {
+static void genTensorStore(CodeGenEnv &env, OpBuilder &builder, unsigned exp,
+ Value rhs) {
+ linalg::GenericOp op = env.linalgOp;
Location loc = op.getLoc();
// Test if this is a scalarized reduction.
- if (codegen.redVal) {
- updateReduc(merger, codegen, rhs);
+ if (env.redVal) {
+ updateReduc(env, rhs);
return;
}
// Store during insertion.
OpOperand *t = op.getDpsInitOperand(0);
- if (t == codegen.sparseOut) {
+ if (t == env.sparseOut) {
if (!rhs) {
// Only unary and binary are allowed to return uninitialized rhs
// to indicate missing output.
- assert(merger.exp(exp).kind == kUnary || merger.exp(exp).kind == kBinary);
- } else if (merger.exp(exp).kind == kSelect) {
+ assert(env.exp(exp).kind == kUnary || env.exp(exp).kind == kBinary);
+ } else if (env.exp(exp).kind == kSelect) {
// Select operation insertion.
- Value insChain = codegen.insChain;
+ Value insChain = env.insChain;
assert(insChain);
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, insChain.getType(), rhs,
/*else=*/true);
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
// Existing value was preserved to be used here.
- assert(merger.exp(exp).val);
- Value v0 = merger.exp(exp).val;
- genInsertionStore(codegen, builder, op, t, v0);
- merger.exp(exp).val = Value();
+ assert(env.exp(exp).val);
+ Value v0 = env.exp(exp).val;
+ genInsertionStore(env, builder, t, v0);
+ env.exp(exp).val = Value();
// Yield modified insertion chain along true branch.
- builder.create<scf::YieldOp>(op.getLoc(), codegen.insChain);
+ builder.create<scf::YieldOp>(op.getLoc(), env.insChain);
// Yield original insertion chain along false branch.
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
builder.create<scf::YieldOp>(loc, insChain);
// Done with if statement.
- codegen.insChain = ifOp->getResult(0);
+ env.insChain = ifOp->getResult(0);
builder.setInsertionPointAfter(ifOp);
} else {
- genInsertionStore(codegen, builder, op, t, rhs);
+ genInsertionStore(env, builder, t, rhs);
}
return;
}
// Actual store.
SmallVector<Value> args;
- Value ptr = genSubscript(codegen, builder, op, t, args);
+ Value ptr = genSubscript(env, builder, t, args);
builder.create<memref::StoreOp>(loc, rhs, ptr, args);
}
/// Generates an invariant value.
-inline static Value genInvariantValue(Merger &merger, CodeGen &codegen,
- OpBuilder &builder, unsigned exp) {
- return merger.exp(exp).val;
+inline static Value genInvariantValue(CodeGenEnv &env, unsigned exp) {
+ return env.exp(exp).val;
}
/// Generates an index value.
-inline static Value genIndexValue(CodeGen &codegen, OpBuilder &builder,
- unsigned idx) {
- return codegen.getLoopIdxValue(idx);
+inline static Value genIndexValue(CodeGenEnv &env, unsigned idx) {
+ return env.getLoopIdxValue(idx);
}
/// Semi-ring branches are simply inlined by the sparse compiler. Prior
@@ -969,86 +1003,84 @@ inline static Value genIndexValue(CodeGen &codegen, OpBuilder &builder,
/// branch or otherwise invariantly defined outside the loop nest, with the
/// exception of index computations, which need to be relinked to actual
/// inlined cloned code.
-static Value relinkBranch(CodeGen &codegen, RewriterBase &rewriter,
- Block *block, Value e, unsigned ldx) {
+static Value relinkBranch(CodeGenEnv &env, RewriterBase &rewriter, Block *block,
+ Value e, unsigned ldx) {
if (Operation *def = e.getDefiningOp()) {
if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
- return genIndexValue(codegen, rewriter, indexOp.getDim());
+ return genIndexValue(env, indexOp.getDim());
if (def->getBlock() == block) {
for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
def->setOperand(
- i, relinkBranch(codegen, rewriter, block, def->getOperand(i), ldx));
+ i, relinkBranch(env, rewriter, block, def->getOperand(i), ldx));
}
}
return e;
}
/// Recursively generates tensor expression.
-static Value genExp(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
- linalg::GenericOp op, unsigned exp, unsigned ldx) {
+static Value genExp(CodeGenEnv &env, RewriterBase &rewriter, unsigned exp,
+ unsigned ldx) {
+ linalg::GenericOp op = env.linalgOp;
Location loc = op.getLoc();
+
if (exp == -1u)
return Value();
- if (merger.exp(exp).kind == Kind::kTensor)
- return genTensorLoad(merger, codegen, rewriter, op, exp);
- if (merger.exp(exp).kind == Kind::kInvariant)
- return genInvariantValue(merger, codegen, rewriter, exp);
- if (merger.exp(exp).kind == Kind::kIndex)
- return genIndexValue(codegen, rewriter, merger.exp(exp).index);
-
- if (merger.exp(exp).kind == Kind::kReduce) {
- // Make custom reduction identity accessible for expanded access pattern.
- assert(codegen.redCustom == -1u);
- codegen.redCustom = exp;
+ if (env.exp(exp).kind == Kind::kTensor)
+ return genTensorLoad(env, rewriter, exp);
+ if (env.exp(exp).kind == Kind::kInvariant)
+ return genInvariantValue(env, exp);
+ if (env.exp(exp).kind == Kind::kIndex)
+ return genIndexValue(env, env.exp(exp).index);
+
+ // Make custom reduction identity accessible for expanded access pattern.
+ if (env.exp(exp).kind == Kind::kReduce) {
+ assert(env.redCustom == -1u);
+ env.redCustom = exp;
}
- Value v0 =
- genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0, ldx);
- Value v1 =
- genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1, ldx);
- Value ee = merger.buildExp(rewriter, loc, exp, v0, v1);
- if (ee && (merger.exp(exp).kind == Kind::kUnary ||
- merger.exp(exp).kind == Kind::kBinary ||
- merger.exp(exp).kind == Kind::kBinaryBranch ||
- merger.exp(exp).kind == Kind::kReduce ||
- merger.exp(exp).kind == Kind::kSelect))
- ee = relinkBranch(codegen, rewriter, ee.getParentBlock(), ee, ldx);
-
- if (merger.exp(exp).kind == kSelect) {
- assert(!merger.exp(exp).val);
- merger.exp(exp).val = v0; // Preserve value for later use.
- }
-
- if (merger.exp(exp).kind == Kind::kReduce) {
- assert(codegen.redCustom != -1u);
- codegen.redCustom = -1u;
+ Value v0 = genExp(env, rewriter, env.exp(exp).children.e0, ldx);
+ Value v1 = genExp(env, rewriter, env.exp(exp).children.e1, ldx);
+ Value ee = env.merger.buildExp(rewriter, loc, exp, v0, v1);
+ if (ee && (env.exp(exp).kind == Kind::kUnary ||
+ env.exp(exp).kind == Kind::kBinary ||
+ env.exp(exp).kind == Kind::kBinaryBranch ||
+ env.exp(exp).kind == Kind::kReduce ||
+ env.exp(exp).kind == Kind::kSelect))
+ ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee, ldx);
+
+ if (env.exp(exp).kind == kSelect) {
+ assert(!env.exp(exp).val);
+ env.exp(exp).val = v0; // Preserve value for later use.
+ } else if (env.exp(exp).kind == Kind::kReduce) {
+ assert(env.redCustom != -1u);
+ env.redCustom = -1u;
}
return ee;
}
/// Hoists loop invariant tensor loads for which indices have been exhausted.
-static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder,
- linalg::GenericOp op, unsigned exp, unsigned ldx,
- bool atStart, unsigned last = -1u) {
+static void genInvariants(CodeGenEnv &env, OpBuilder &builder, unsigned exp,
+ unsigned ldx, bool atStart, unsigned last = -1u) {
if (exp == -1u)
return;
- if (merger.exp(exp).kind == Kind::kTensor) {
+ if (env.exp(exp).kind == Kind::kTensor) {
// Inspect tensor indices.
bool atLevel = ldx == -1u;
- OpOperand &t = op->getOpOperand(merger.exp(exp).tensor);
+ linalg::GenericOp op = env.linalgOp;
+ OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
auto map = op.getMatchingIndexingMap(&t);
auto enc = getSparseTensorEncoding(t.get().getType());
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
AffineExpr a = map.getResult(toOrigDim(enc, d));
- Optional<unsigned> sldx = merger.getLoopIdx(t.getOperandNumber(), d);
- if (sldx && merger.isFilterLoop(*sldx)) {
- if (!codegen.getLoopIdxValue(*sldx))
+ Optional<unsigned> sldx = env.merger.getLoopIdx(t.getOperandNumber(), d);
+ if (sldx && env.isFilterLoop(*sldx)) {
+ if (!env.getLoopIdxValue(*sldx))
// The filter loops has not been constructed.
return;
if (*sldx == ldx)
atLevel = true;
- } else if (!isInvariantAffine(codegen, a, ldx, atLevel))
+ } else if (!isInvariantAffine(env, a, ldx, atLevel))
return; // still in play
}
// All exhausted at this level (atLevel denotes exactly at this level).
@@ -1058,45 +1090,43 @@ static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder,
if (lhs == &t) {
// Start or end a scalarized reduction
if (atStart) {
- Kind kind = merger.exp(last).kind;
- Value load = kind == Kind::kReduce
- ? getCustomRedId(merger.exp(last).op)
- : genTensorLoad(merger, codegen, builder, op, exp);
- codegen.redKind = getReduction(kind);
- codegen.redExp = exp;
- updateReduc(merger, codegen, load);
+ Kind kind = env.exp(last).kind;
+ Value load = kind == Kind::kReduce ? getCustomRedId(env.exp(last).op)
+ : genTensorLoad(env, builder, exp);
+ env.redKind = getReduction(kind);
+ env.redExp = exp;
+ updateReduc(env, load);
} else {
- Value redVal = codegen.redVal;
- updateReduc(merger, codegen, Value());
- codegen.redExp = -1u;
- codegen.redKind = kNoReduc;
- genTensorStore(merger, codegen, builder, op, exp, redVal);
+ Value redVal = env.redVal;
+ updateReduc(env, Value());
+ env.redExp = -1u;
+ env.redKind = kNoReduc;
+ genTensorStore(env, builder, exp, redVal);
}
} else {
// Start or end loop invariant hoisting of a tensor load.
- merger.exp(exp).val =
- atStart ? genTensorLoad(merger, codegen, builder, op, exp) : Value();
+ env.exp(exp).val = atStart ? genTensorLoad(env, builder, exp) : Value();
}
- } else if (merger.exp(exp).kind != Kind::kInvariant &&
- merger.exp(exp).kind != Kind::kIndex) {
+ } else if (env.exp(exp).kind != Kind::kInvariant &&
+ env.exp(exp).kind != Kind::kIndex) {
// Traverse into the binary operations. Note that we only hoist
// tensor loads, since subsequent MLIR/LLVM passes know how to
// deal with all other kinds of derived loop invariants.
- unsigned e0 = merger.exp(exp).children.e0;
- unsigned e1 = merger.exp(exp).children.e1;
- genInvariants(merger, codegen, builder, op, e0, ldx, atStart, exp);
- genInvariants(merger, codegen, builder, op, e1, ldx, atStart, exp);
+ unsigned e0 = env.exp(exp).children.e0;
+ unsigned e1 = env.exp(exp).children.e1;
+ genInvariants(env, builder, e0, ldx, atStart, exp);
+ genInvariants(env, builder, e1, ldx, atStart, exp);
}
}
/// Generates an expanded access pattern in innermost dimension.
-static void genExpansion(Merger &merger, CodeGen &codegen, OpBuilder &builder,
- linalg::GenericOp op, unsigned at, bool atStart) {
- OpOperand *lhs = codegen.sparseOut;
- if (!lhs || codegen.outerParNest != op.getRank(lhs) - 1 ||
- at != codegen.outerParNest)
+static void genExpansion(CodeGenEnv &env, OpBuilder &builder, unsigned at,
+ bool atStart) {
+ linalg::GenericOp op = env.linalgOp;
+ OpOperand *lhs = env.sparseOut;
+ if (!lhs || env.outerParNest != op.getRank(lhs) - 1 || at != env.outerParNest)
return; // not needed at this level
- assert(codegen.redVal == nullptr);
+ assert(env.redVal == nullptr);
// Generate start or end of an expanded access pattern. Note that because
// an expension does not rely on the ongoing contents of the sparse storage
// scheme, we can use the original tensor as incoming SSA value (which
@@ -1114,38 +1144,37 @@ static void genExpansion(Merger &merger, CodeGen &codegen, OpBuilder &builder,
auto res =
builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
assert(res.getNumResults() == 4);
- assert(!codegen.expValues);
- codegen.expValues = res.getResult(0);
- codegen.expFilled = res.getResult(1);
- codegen.expAdded = res.getResult(2);
- codegen.expCount = res.getResult(3);
+ assert(!env.expValues);
+ env.expValues = res.getResult(0);
+ env.expFilled = res.getResult(1);
+ env.expAdded = res.getResult(2);
+ env.expCount = res.getResult(3);
} else {
- assert(codegen.expValues);
+ assert(env.expValues);
SmallVector<Value> indices;
for (unsigned i = 0; i < at; i++) {
- assert(codegen.loopEmitter.getLoopIV(i));
- indices.push_back(codegen.loopEmitter.getLoopIV(i));
+ assert(env.getLoopIV(i));
+ indices.push_back(env.getLoopIV(i));
}
- codegen.insChain = builder.create<CompressOp>(
- loc, codegen.expValues, codegen.expFilled, codegen.expAdded,
- codegen.expCount, codegen.insChain, indices);
- codegen.expValues = codegen.expFilled = codegen.expAdded =
- codegen.expCount = Value();
+ env.insChain = builder.create<CompressOp>(loc, env.expValues, env.expFilled,
+ env.expAdded, env.expCount,
+ env.insChain, indices);
+ env.expValues = env.expFilled = env.expAdded = env.expCount = Value();
}
}
/// Returns parallelization strategy. Any implicit loop in the Linalg
/// operation that is marked "parallel" is a candidate. Whether it is actually
/// converted to a parallel operation depends on the requested strategy.
-static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isSparse) {
+static bool isParallelFor(CodeGenEnv &env, bool isOuter, bool isSparse) {
// Reject parallelization of sparse output.
- if (codegen.sparseOut)
+ if (env.sparseOut)
return false;
// Parallel loops on tensor expansion can cause data races.
- if (codegen.expCount)
+ if (env.expCount)
return false;
// Inspect strategy.
- switch (codegen.options.parallelizationStrategy) {
+ switch (env.options.parallelizationStrategy) {
case SparseParallelizationStrategy::kNone:
return false;
case SparseParallelizationStrategy::kDenseOuterLoop:
@@ -1161,98 +1190,94 @@ static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isSparse) {
}
/// Generates a for-loop on a single index.
-static Operation *genFor(Merger &merger, CodeGen &codegen, OpBuilder &builder,
- linalg::GenericOp op, bool isOuter, bool isInner,
- unsigned idx, size_t tid, size_t dim,
+static Operation *genFor(CodeGenEnv &env, OpBuilder &builder, bool isOuter,
+ bool isInner, unsigned idx, size_t tid, size_t dim,
ArrayRef<size_t> extraTids,
ArrayRef<size_t> extraDims) {
+ linalg::GenericOp op = env.linalgOp;
Location loc = op.getLoc();
- bool isSparse = isCompressedDLT(merger.getDimLevelType(tid, idx)) ||
- isSingletonDLT(merger.getDimLevelType(tid, idx));
- bool isParallel = isParallelFor(codegen, isOuter, isSparse);
-
- Operation *loop =
- *genLoopBoundary(codegen, merger, [&](MutableArrayRef<Value> reduc) {
- if (merger.isFilterLoop(idx)) {
- // extraTids/extraDims must be empty because filter loops only
- // corresponding to the one and only sparse tensor level.
- assert(isSparse && extraTids.empty() && extraDims.empty());
- OpOperand *t = &op->getOpOperand(tid);
- auto enc = getSparseTensorEncoding(t->get().getType());
- // Retrieves the affine expression for the filter loop.
- AffineExpr a =
- op.getMatchingIndexingMap(t).getResult(toOrigDim(enc, dim));
- return codegen.loopEmitter.enterFilterLoopOverTensorAtDim(
- builder, loc, tid, dim, a, reduc);
- }
- return codegen.loopEmitter.enterLoopOverTensorAtDim(
- builder, loc, tid, dim, reduc, isParallel, extraTids, extraDims);
- });
+ auto iteratorTypes = op.getIteratorTypesArray();
+ bool isSparse = isCompressedDLT(env.dimLevelType(tid, idx)) ||
+ isSingletonDLT(env.dimLevelType(tid, idx));
+ bool isParallel = isParallelFor(env, isOuter, isSparse);
+
+ Operation *loop = *genLoopBoundary(env, [&](MutableArrayRef<Value> reduc) {
+ if (env.isFilterLoop(idx)) {
+ // extraTids/extraDims must be empty because filter loops only
+ // corresponding to the one and only sparse tensor level.
+ assert(isSparse && extraTids.empty() && extraDims.empty());
+ OpOperand *t = &op->getOpOperand(tid);
+ auto enc = getSparseTensorEncoding(t->get().getType());
+ // Retrieves the affine expression for the filter loop.
+ AffineExpr a =
+ op.getMatchingIndexingMap(t).getResult(toOrigDim(enc, dim));
+ return env.loopEmitter->enterFilterLoopOverTensorAtDim(builder, loc, tid,
+ dim, a, reduc);
+ }
+ return env.loopEmitter->enterLoopOverTensorAtDim(
+ builder, loc, tid, dim, reduc, isParallel, extraTids, extraDims);
+ });
assert(loop);
return loop;
}
/// Emit a while-loop for co-iteration over multiple indices.
-static Operation *genWhile(Merger &merger, CodeGen &codegen, OpBuilder &builder,
- linalg::GenericOp op, unsigned idx, bool needsUniv,
- ArrayRef<size_t> condTids, ArrayRef<size_t> condDims,
+static Operation *genWhile(CodeGenEnv &env, OpBuilder &builder, unsigned idx,
+ bool needsUniv, ArrayRef<size_t> condTids,
+ ArrayRef<size_t> condDims,
ArrayRef<size_t> extraTids,
ArrayRef<size_t> extraDims) {
-
- Operation *loop =
- *genLoopBoundary(codegen, merger, [&](MutableArrayRef<Value> reduc) {
- // Construct the while-loop with a parameter for each index.
- return codegen.loopEmitter.enterCoIterationOverTensorsAtDims(
- builder, op.getLoc(), condTids, condDims, needsUniv, reduc,
- extraTids, extraDims);
- });
+ Operation *loop = *genLoopBoundary(env, [&](MutableArrayRef<Value> reduc) {
+ // Construct the while-loop with a parameter for each index.
+ return env.loopEmitter->enterCoIterationOverTensorsAtDims(
+ builder, env.linalgOp.getLoc(), condTids, condDims, needsUniv, reduc,
+ extraTids, extraDims);
+ });
assert(loop);
return loop;
}
/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
-static Operation *genLoop(Merger &merger, CodeGen &codegen, OpBuilder &builder,
- linalg::GenericOp op, unsigned at, bool needsUniv,
- ArrayRef<size_t> condTids, ArrayRef<size_t> condDims,
- ArrayRef<size_t> extraTids,
+static Operation *genLoop(CodeGenEnv &env, OpBuilder &builder, unsigned at,
+ bool needsUniv, ArrayRef<size_t> condTids,
+ ArrayRef<size_t> condDims, ArrayRef<size_t> extraTids,
ArrayRef<size_t> extraDims) {
assert(condTids.size() == condDims.size());
assert(extraTids.size() == extraDims.size());
- unsigned idx = codegen.topSort[at];
+ unsigned idx = env.topSort[at];
if (condTids.size() == 1) {
bool isOuter = at == 0;
- bool isInner = at == codegen.topSort.size() - 1;
- return genFor(merger, codegen, builder, op, isOuter, isInner, idx,
- condTids.front(), condDims.front(), extraTids, extraDims);
+ bool isInner = at == env.topSort.size() - 1;
+ return genFor(env, builder, isOuter, isInner, idx, condTids.front(),
+ condDims.front(), extraTids, extraDims);
}
- return genWhile(merger, codegen, builder, op, idx, needsUniv, condTids,
- condDims, extraTids, extraDims);
+ return genWhile(env, builder, idx, needsUniv, condTids, condDims, extraTids,
+ extraDims);
}
/// Generates the induction structure for a while-loop.
-static void finalizeWhileOp(Merger &merger, CodeGen &codegen,
- OpBuilder &builder, linalg::GenericOp op,
- unsigned idx, bool needsUniv, BitVector &induction,
+static void finalizeWhileOp(CodeGenEnv &env, OpBuilder &builder, unsigned idx,
+ bool needsUniv, BitVector &induction,
scf::WhileOp whileOp) {
- Location loc = op.getLoc();
+ Location loc = env.linalgOp.getLoc();
// Finalize each else branch of all if statements.
- if (codegen.redVal || codegen.expValues || codegen.insChain) {
+ if (env.redVal || env.expValues || env.insChain) {
while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
builder.getInsertionBlock()->getParentOp())) {
unsigned y = 0;
SmallVector<Value> yields;
- if (codegen.redVal) {
- yields.push_back(codegen.redVal);
- updateReduc(merger, codegen, ifOp.getResult(y++));
+ if (env.redVal) {
+ yields.push_back(env.redVal);
+ updateReduc(env, ifOp.getResult(y++));
}
- if (codegen.expValues) {
- yields.push_back(codegen.expCount);
- codegen.expCount = ifOp->getResult(y++);
+ if (env.expValues) {
+ yields.push_back(env.expCount);
+ env.expCount = ifOp->getResult(y++);
}
- if (codegen.insChain) {
- yields.push_back(codegen.insChain);
- codegen.insChain = ifOp->getResult(y++);
+ if (env.insChain) {
+ yields.push_back(env.insChain);
+ env.insChain = ifOp->getResult(y++);
}
assert(y == yields.size());
builder.create<scf::YieldOp>(loc, yields);
@@ -1263,62 +1288,61 @@ static void finalizeWhileOp(Merger &merger, CodeGen &codegen,
}
/// Generates a single if-statement within a while-loop.
-static scf::IfOp genIf(Merger &merger, CodeGen &codegen, OpBuilder &builder,
- linalg::GenericOp op, unsigned idx,
+static scf::IfOp genIf(CodeGenEnv &env, OpBuilder &builder, unsigned idx,
BitVector &conditions) {
- Location loc = op.getLoc();
+ Location loc = env.linalgOp.getLoc();
SmallVector<Type> types;
Value cond;
for (unsigned b = 0, be = conditions.size(); b < be; b++) {
if (!conditions[b])
continue;
- unsigned tensor = merger.tensor(b);
- assert(idx == merger.index(b));
+ unsigned tensor = env.merger.tensor(b);
+ assert(idx == env.merger.index(b));
Value clause;
- if (isCompressedDLT(merger.getDimLevelType(b)) ||
- isSingletonDLT(merger.getDimLevelType(b))) {
- auto dim = *merger.getDimNum(tensor, idx);
- Value op1 = codegen.loopEmitter.getCoord()[tensor][dim];
- Value op2 = codegen.getLoopIdxValue(idx);
+ if (isCompressedDLT(env.dimLevelType(b)) ||
+ isSingletonDLT(env.dimLevelType(b))) {
+ auto dim = *env.merger.getDimNum(tensor, idx);
+ Value op1 = env.loopEmitter->getCoord()[tensor][dim];
+ Value op2 = env.getLoopIdxValue(idx);
clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, op1,
op2);
} else {
- assert(isDenseDLT(merger.getDimLevelType(b)) ||
- isUndefDLT(merger.getDimLevelType(b)));
+ assert(isDenseDLT(env.dimLevelType(b)) ||
+ isUndefDLT(env.dimLevelType(b)));
clause = constantI1(builder, loc, true);
}
cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
}
- if (codegen.redVal)
- types.push_back(codegen.redVal.getType());
- if (codegen.expValues)
+ if (env.redVal)
+ types.push_back(env.redVal.getType());
+ if (env.expValues)
types.push_back(builder.getIndexType());
- if (codegen.insChain)
- types.push_back(codegen.insChain.getType());
+ if (env.insChain)
+ types.push_back(env.insChain.getType());
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
return ifOp;
}
/// Generates end of true branch of if-statement within a while-loop.
-static void endIf(Merger &merger, CodeGen &codegen, OpBuilder &builder,
- linalg::GenericOp op, scf::IfOp ifOp, Operation *loop,
- Value redInput, Value cntInput, Value insInput) {
+static void endIf(CodeGenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
+ Operation *loop, Value redInput, Value cntInput,
+ Value insInput) {
SmallVector<Value> operands;
- if (codegen.redVal) {
- operands.push_back(codegen.redVal);
- updateReduc(merger, codegen, redInput);
+ if (env.redVal) {
+ operands.push_back(env.redVal);
+ updateReduc(env, redInput);
}
- if (codegen.expValues) {
- operands.push_back(codegen.expCount);
- codegen.expCount = cntInput;
+ if (env.expValues) {
+ operands.push_back(env.expCount);
+ env.expCount = cntInput;
}
- if (codegen.insChain) {
- operands.push_back(codegen.insChain);
- codegen.insChain = insInput;
+ if (env.insChain) {
+ operands.push_back(env.insChain);
+ env.insChain = insInput;
}
if (!operands.empty())
- builder.create<scf::YieldOp>(op.getLoc(), operands);
+ builder.create<scf::YieldOp>(env.linalgOp.getLoc(), operands);
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
}
@@ -1328,24 +1352,24 @@ static void endIf(Merger &merger, CodeGen &codegen, OpBuilder &builder,
/// Starts a loop sequence at given level. Returns true if
/// the universal loop index must be maintained at this level.
-static bool startLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
- linalg::GenericOp op, unsigned exp, unsigned at,
- unsigned idx, unsigned ldx, unsigned lts) {
- assert(!codegen.getLoopIdxValue(idx));
+static bool startLoopSeq(CodeGenEnv &env, OpBuilder &builder, unsigned exp,
+ unsigned at, unsigned idx, unsigned ldx,
+ unsigned lts) {
+ assert(!env.getLoopIdxValue(idx));
// Emit invariants at this loop sequence level.
- genInvariants(merger, codegen, builder, op, exp, ldx, /*atStart=*/true);
+ genInvariants(env, builder, exp, ldx, /*atStart=*/true);
// Emit access pattern expansion for sparse tensor output.
- genExpansion(merger, codegen, builder, op, at, /*atStart=*/true);
+ genExpansion(env, builder, at, /*atStart=*/true);
// Emit further intitialization at this loop sequence level.
- unsigned l0 = merger.set(lts)[0];
+ unsigned l0 = env.set(lts)[0];
bool needsUniv = false;
SmallVector<size_t> tids;
SmallVector<size_t> dims;
- merger.foreachTidDimPairInBits(
- merger.lat(l0).bits,
+ env.merger.foreachTidDimPairInBits(
+ env.lat(l0).bits,
[&](unsigned b, unsigned tid, Optional<unsigned> dim, DimLevelType dlt) {
- assert(merger.index(b) == idx);
+ assert(env.merger.index(b) == idx);
if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
needsUniv = true;
} else {
@@ -1355,28 +1379,27 @@ static bool startLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
}
});
- codegen.loopEmitter.enterNewLoopSeq(builder, op.getLoc(), tids, dims);
+ env.loopEmitter->enterNewLoopSeq(builder, env.linalgOp.getLoc(), tids, dims);
// Maintain the universal index only if it is actually
// consumed by a subsequent lattice point.
if (needsUniv) {
- unsigned lsize = merger.set(lts).size();
+ unsigned lsize = env.set(lts).size();
for (unsigned i = 1; i < lsize; i++) {
- unsigned li = merger.set(lts)[i];
- if (!merger.hasAnySparse(merger.lat(li).simple))
+ unsigned li = env.set(lts)[i];
+ if (!env.merger.hasAnySparse(env.lat(li).simple))
return true;
}
}
return false;
}
-static void genConstantDenseAddressFromLevel(CodeGen &codegen,
- OpBuilder &builder,
- linalg::GenericOp op, unsigned tid,
+static void genConstantDenseAddressFromLevel(CodeGenEnv &env,
+ OpBuilder &builder, unsigned tid,
unsigned lvl) {
// TODO: Handle affine expression on output tensor.
+ linalg::GenericOp op = env.linalgOp;
assert(tid < op.getNumDpsInputs());
-
OpOperand *input = op.getDpsInputOperands()[tid];
ArrayRef<AffineExpr> affines = op.getMatchingIndexingMap(input).getResults();
auto enc = getSparseTensorEncoding(input->get().getType());
@@ -1384,42 +1407,38 @@ static void genConstantDenseAddressFromLevel(CodeGen &codegen,
for (unsigned i = lvl, e = affines.size(); i < e; i++) {
AffineExpr affine = affines[toOrigDim(enc, i)];
if (isDenseDLT(getDimLevelType(enc, i)) &&
- affine.isa<AffineConstantExpr>()) {
- codegen.loopEmitter.genDenseAffineAddressAtCurLevel(
+ affine.isa<AffineConstantExpr>())
+ env.loopEmitter->genDenseAffineAddressAtCurLevel(
builder, op.getLoc(), input->getOperandNumber(), i, affine);
- } else {
- // Breaks on first non-dense non-constant level.
- return;
- }
+ else
+ return; // break on first non-dense non-constant level
}
}
}
-static void genInitConstantDenseAddress(CodeGen &codegen,
- RewriterBase &rewriter,
- linalg::GenericOp op) {
- // We can generates address for constant affine expression before any loops
+static void genInitConstantDenseAddress(CodeGenEnv &env,
+ RewriterBase &rewriter) {
+ // We can generate addresses for constant affine expressions before any loops
 // starting from the first level as they do not depend on anything.
// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
// levels can be determined before loops.
- for (unsigned tid = 0, e = op.getNumDpsInputs(); tid < e; tid++)
- genConstantDenseAddressFromLevel(codegen, rewriter, op, tid, 0);
+ for (unsigned tid = 0, e = env.linalgOp.getNumDpsInputs(); tid < e; tid++)
+ genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
}
static void translateBitsToTidDimPairs(
- Merger &merger, CodeGen &codegen, linalg::GenericOp op, unsigned li,
- unsigned idx, SmallVectorImpl<size_t> &condTids,
- SmallVectorImpl<size_t> &condDims, SmallVectorImpl<size_t> &extraTids,
- SmallVectorImpl<size_t> &extraDims, SmallVectorImpl<size_t> &affineTids,
- SmallVectorImpl<size_t> &affineDims, SmallVectorImpl<AffineExpr> &exps) {
-
- const BitVector &all = merger.lat(li).bits;
- const BitVector &simple = merger.lat(li).simple;
+ CodeGenEnv &env, unsigned li, unsigned idx,
+ SmallVectorImpl<size_t> &condTids, SmallVectorImpl<size_t> &condDims,
+ SmallVectorImpl<size_t> &extraTids, SmallVectorImpl<size_t> &extraDims,
+ SmallVectorImpl<size_t> &affineTids, SmallVectorImpl<size_t> &affineDims,
+ SmallVectorImpl<AffineExpr> &exps) {
+ const BitVector &all = env.lat(li).bits;
+ const BitVector &simple = env.lat(li).simple;
// Converts bits to array + dim pair
- merger.foreachTidDimPairInBits(all, [&, idx](unsigned b, unsigned tid,
- Optional<unsigned> dim,
- DimLevelType dlt) {
+ env.merger.foreachTidDimPairInBits(all, [&, idx](unsigned b, unsigned tid,
+ Optional<unsigned> dim,
+ DimLevelType dlt) {
if (simple.test(b)) {
if (isUndefDLT(dlt)) {
// An undefined dlt in the lattices, we probably mean to iterate based
@@ -1428,8 +1447,8 @@ static void translateBitsToTidDimPairs(
// output tensor).
// out[i][j] = invariant; or a broadcast
// out[i][j] = in[i] (j is undef for input)
- tid = merger.getOutTensorID();
- dim = merger.getDimNum(tid, idx);
+ tid = env.merger.getOutTensorID();
+ dim = env.merger.getDimNum(tid, idx);
// Skips invalid dim (e.g., when this is a zero ranked tensor).
if (!dim)
return;
@@ -1442,6 +1461,7 @@ static void translateBitsToTidDimPairs(
extraDims.push_back(*dim);
} else {
assert(isUndefDLT(dlt));
+ linalg::GenericOp op = env.linalgOp;
if (tid >= op.getNumDpsInputs())
// We only handle affine expressions on input tensors (for now).
return;
@@ -1464,7 +1484,7 @@ static void translateBitsToTidDimPairs(
// Constant affine expressions are handled in genLoop
if (!exp.isa<AffineConstantExpr>()) {
bool atLevel = false;
- if (isInvariantAffine(codegen, exp, idx, atLevel) && atLevel) {
+ if (isInvariantAffine(env, exp, idx, atLevel) && atLevel) {
// If the compound affine is invariant and we are right at the
// level, we need to generate the address according to the affine
// expression. This is also the best place we can do it to avoid
@@ -1482,20 +1502,19 @@ static void translateBitsToTidDimPairs(
}
});
- if (isDenseDLT(merger.getDimLevelType(merger.getOutTensorID(), idx))) {
+ if (isDenseDLT(env.dimLevelType(env.merger.getOutTensorID(), idx))) {
// Note that we generate dense indices of the output tensor
// unconditionally, since they may not appear in the lattice, but may be
// needed for linearized codegen.
- auto dim = *merger.getDimNum(merger.getOutTensorID(), idx);
- extraTids.push_back(merger.getOutTensorID());
+ auto dim = *env.merger.getDimNum(env.merger.getOutTensorID(), idx);
+ extraTids.push_back(env.merger.getOutTensorID());
extraDims.push_back(dim);
}
}
/// Starts a single loop in the current sequence.
-static Operation *startLoop(Merger &merger, CodeGen &codegen,
- OpBuilder &builder, linalg::GenericOp op,
- unsigned at, unsigned li, bool needsUniv) {
+static Operation *startLoop(CodeGenEnv &env, OpBuilder &builder, unsigned at,
+ unsigned li, bool needsUniv) {
// The set of tensors + dims to generate loops on
SmallVector<size_t> condTids, condDims;
// The set of (dense) tensors that is optimized from condition, yet still
@@ -1506,17 +1525,16 @@ static Operation *startLoop(Merger &merger, CodeGen &codegen,
// level.
SmallVector<size_t> affineTids, affineDims;
SmallVector<AffineExpr> affines;
+ translateBitsToTidDimPairs(env, li, env.topSort[at], condTids, condDims,
+ extraTids, extraDims, affineTids, affineDims,
+ affines);
- translateBitsToTidDimPairs(merger, codegen, op, li, codegen.topSort[at],
- condTids, condDims, extraTids, extraDims,
- affineTids, affineDims, affines);
// Emit the for/while-loop control.
- Operation *loop = genLoop(merger, codegen, builder, op, at, needsUniv,
- condTids, condDims, extraTids, extraDims);
-
+ Operation *loop = genLoop(env, builder, at, needsUniv, condTids, condDims,
+ extraTids, extraDims);
for (auto [tid, dim, exp] : llvm::zip(affineTids, affineDims, affines)) {
- codegen.loopEmitter.genDenseAffineAddressAtCurLevel(builder, op.getLoc(),
- tid, dim, exp);
+ env.loopEmitter->genDenseAffineAddressAtCurLevel(
+ builder, env.linalgOp.getLoc(), tid, dim, exp);
}
// Until now, we have entered every <tid, dim> pair in {cond, extra,
@@ -1525,27 +1543,25 @@ static Operation *startLoop(Merger &merger, CodeGen &codegen,
auto allTids = llvm::concat<size_t>(condTids, extraTids, affineTids);
auto allDims = llvm::concat<size_t>(condDims, extraDims, affineDims);
for (auto [tid, dim] : llvm::zip(allTids, allDims)) {
- if (tid != merger.getOutTensorID())
- genConstantDenseAddressFromLevel(codegen, builder, op, tid, dim + 1);
+ if (tid != env.merger.getOutTensorID())
+ genConstantDenseAddressFromLevel(env, builder, tid, dim + 1);
}
return loop;
}
/// Ends a single loop in the current sequence. Returns the new value for needsUniv.
-static bool endLoop(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
- linalg::GenericOp op, Operation *loop, unsigned idx,
- unsigned li, bool needsUniv) {
+static bool endLoop(CodeGenEnv &env, RewriterBase &rewriter, Operation *loop,
+ unsigned idx, unsigned li, bool needsUniv) {
// End a while-loop.
if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
- finalizeWhileOp(merger, codegen, rewriter, op, idx, needsUniv,
- merger.lat(li).bits, whileOp);
+ finalizeWhileOp(env, rewriter, idx, needsUniv, env.lat(li).bits, whileOp);
} else {
needsUniv = false;
}
- genLoopBoundary(codegen, merger, [&](MutableArrayRef<Value> reduc) {
- codegen.loopEmitter.exitCurrentLoop(rewriter, op.getLoc(), reduc);
+ genLoopBoundary(env, [&](MutableArrayRef<Value> reduc) {
+ env.loopEmitter->exitCurrentLoop(rewriter, env.linalgOp.getLoc(), reduc);
return std::nullopt;
});
@@ -1553,85 +1569,79 @@ static bool endLoop(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
}
/// Ends a loop sequence at given level.
-static void endLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
- linalg::GenericOp op, unsigned exp, unsigned at,
- unsigned idx, unsigned ldx) {
- assert(codegen.getLoopIdxValue(idx) == nullptr);
- codegen.loopEmitter.exitCurrentLoopSeq();
+static void endLoopSeq(CodeGenEnv &env, OpBuilder &builder, unsigned exp,
+ unsigned at, unsigned idx, unsigned ldx) {
+ assert(env.getLoopIdxValue(idx) == nullptr);
+ env.loopEmitter->exitCurrentLoopSeq();
// Unmark bookkeeping of invariants and loop index.
- genInvariants(merger, codegen, builder, op, exp, ldx, /*atStart=*/false);
+ genInvariants(env, builder, exp, ldx, /*atStart=*/false);
// Finalize access pattern expansion for sparse tensor output.
- genExpansion(merger, codegen, builder, op, at, /*atStart=*/false);
+ genExpansion(env, builder, at, /*atStart=*/false);
}
/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
-static void genStmt(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
- linalg::GenericOp op, unsigned exp, unsigned at) {
+static void genStmt(CodeGenEnv &env, RewriterBase &rewriter, unsigned exp,
+ unsigned at) {
// At each leaf, assign remaining tensor (sub)expression to output tensor.
- if (at == codegen.topSort.size()) {
- unsigned ldx = codegen.topSort[at - 1];
- Value rhs = genExp(merger, codegen, rewriter, op, exp, ldx);
- genTensorStore(merger, codegen, rewriter, op, exp, rhs);
+ if (at == env.topSort.size()) {
+ unsigned ldx = env.topSort[at - 1];
+ Value rhs = genExp(env, rewriter, exp, ldx);
+ genTensorStore(env, rewriter, exp, rhs);
return;
}
// Construct iteration lattices for current loop index, with L0 at top.
- unsigned idx = codegen.topSort[at];
- unsigned ldx = at == 0 ? -1u : codegen.topSort[at - 1];
- unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));
+ unsigned idx = env.topSort[at];
+ unsigned ldx = at == 0 ? -1u : env.topSort[at - 1];
+ unsigned lts = env.merger.optimizeSet(env.merger.buildLattices(exp, idx));
// TODO: sort
// TODO: dedup
// Start a loop sequence.
- bool needsUniv =
- startLoopSeq(merger, codegen, rewriter, op, exp, at, idx, ldx, lts);
+ bool needsUniv = startLoopSeq(env, rewriter, exp, at, idx, ldx, lts);
// Emit a loop for every lattice point L0 >= Li in this loop sequence.
- unsigned lsize = merger.set(lts).size();
+ unsigned lsize = env.set(lts).size();
for (unsigned i = 0; i < lsize; i++) {
// Start a loop.
- unsigned li = merger.set(lts)[i];
- Operation *loop =
- startLoop(merger, codegen, rewriter, op, at, li, needsUniv);
+ unsigned li = env.set(lts)[i];
+ Operation *loop = startLoop(env, rewriter, at, li, needsUniv);
// Visit all lattice points with Li >= Lj to generate the
// loop-body, possibly with if statements for coiteration.
- Value redInput = codegen.redVal;
- Value cntInput = codegen.expCount;
- Value insInput = codegen.insChain;
+ Value redInput = env.redVal;
+ Value cntInput = env.expCount;
+ Value insInput = env.insChain;
bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
for (unsigned j = 0; j < lsize; j++) {
- unsigned lj = merger.set(lts)[j];
- unsigned ej = merger.lat(lj).exp;
- if (li == lj || merger.latGT(li, lj)) {
+ unsigned lj = env.set(lts)[j];
+ unsigned ej = env.lat(lj).exp;
+ if (li == lj || env.merger.latGT(li, lj)) {
// Recurse into body of each branch.
if (isWhile) {
- scf::IfOp ifOp =
- genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
- genStmt(merger, codegen, rewriter, op, ej, at + 1);
- endIf(merger, codegen, rewriter, op, ifOp, loop, redInput, cntInput,
- insInput);
+ scf::IfOp ifOp = genIf(env, rewriter, idx, env.lat(lj).simple);
+ genStmt(env, rewriter, ej, at + 1);
+ endIf(env, rewriter, ifOp, loop, redInput, cntInput, insInput);
} else {
- genStmt(merger, codegen, rewriter, op, ej, at + 1);
+ genStmt(env, rewriter, ej, at + 1);
}
}
}
// End a loop.
- needsUniv =
- endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv);
+ needsUniv = endLoop(env, rewriter, loop, idx, li, needsUniv);
}
// End a loop sequence.
- endLoopSeq(merger, codegen, rewriter, op, exp, at, idx, ldx);
+ endLoopSeq(env, rewriter, exp, at, idx, ldx);
}
/// Converts the result computed by the sparse kernel into the required form.
-static void genResult(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
- linalg::GenericOp op) {
+static void genResult(CodeGenEnv &env, RewriterBase &rewriter) {
+ linalg::GenericOp op = env.linalgOp;
OpOperand *lhs = op.getDpsInitOperand(0);
Value tensor = lhs->get();
Type resType = tensor.getType();
@@ -1639,14 +1649,14 @@ static void genResult(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
// The sparse tensor rematerializes from the original sparse tensor's
// underlying sparse storage format. For an insertion chain, the
// tensor materializes from the chain with 'hasInserts' enabled.
- bool hasInserts = codegen.sparseOut == lhs;
+ bool hasInserts = env.sparseOut == lhs;
if (hasInserts)
- tensor = codegen.insChain;
+ tensor = env.insChain;
rewriter.replaceOpWithNewOp<LoadOp>(op, resType, tensor, hasInserts);
} else {
// To rematerialize a non-annotated tensor, simply load it
// from the bufferized value.
- Value val = codegen.loopEmitter.getValBuffer().back(); // value array
+ Value val = env.getValBuffer().back(); // value array
rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
}
}
@@ -1656,6 +1666,7 @@ static void genResult(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
//===----------------------------------------------------------------------===//
namespace {
+
/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
@@ -1664,86 +1675,84 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
LogicalResult matchAndRewrite(linalg::GenericOp op,
PatternRewriter &rewriter) const override {
- // Detects sparse annotations and translate the per-dimension sparsity
- // information for all tensors to loop indices in the kernel.
+ // Only accept single output operations.
if (op.getNumDpsInits() != 1)
return failure();
+
+ // Sets up a code generation environment.
unsigned numTensors = op->getNumOperands();
unsigned numLoops = op.getNumLoops();
unsigned numFilterLoops = getNumCompoundAffineOnSparseDims(op);
- Merger merger(numTensors, numLoops, numFilterLoops);
- if (!findSparseAnnotations(merger, op))
+ CodeGenEnv env(op, options, numTensors, numLoops, numFilterLoops);
+
+ // Detects sparse annotations and translates the per-dimension sparsity
+ // information for all tensors to loop indices in the kernel.
+ if (!findSparseAnnotations(env))
return failure();
// Builds the tensor expression for the Linalg operation in SSA form.
- Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
+ Optional<unsigned> optExp = env.merger.buildTensorExpFromLinalg(op);
if (!optExp.has_value())
return failure();
-
unsigned exp = *optExp;
- OpOperand *sparseOut = nullptr;
- unsigned outerParNest = 0;
+
// Computes a topologically sorted iteration graph to ensure tensors
// are visited in natural index order. Gradually relaxes the considered
// constraints until an acyclic iteration graph results, such that sparse
// code generation can proceed. As a last resort, an attempt is made
// to resolve cycles by inserting a conversion.
- std::vector<unsigned> topSort;
- // Whether the current GenericOp is admissible.
bool isAdmissible = false;
bool hasCycle = true;
// A const list of all masks that we used for iteration graph
- // computation. Must be ordered from strict -> loose.
+ // computation. Must be ordered from more strict to less strict.
const auto allMask = {SortMask::kIncludeAll, SortMask::kIncludeUndef,
SortMask::kIncludeDense, SortMask::kSparseOnly};
for (auto mask : allMask)
- if (computeIterationGraph(merger, op, topSort, mask)) {
+ if (computeIterationGraph(env, mask)) {
hasCycle = false;
- if (isAdmissibleTensorExp(merger, op, topSort, exp, &sparseOut,
- outerParNest)) {
+ if (isAdmissibleTensorExp(env, exp)) {
isAdmissible = true;
break;
}
// else try a set of less strict constraints.
}
-
if (hasCycle)
- // Give it one last shot to resolve the cycle.
- return resolveCycle(merger, rewriter, op);
+ return resolveCycle(env, rewriter); // one last shot
if (!isAdmissible)
- // Inadmissible expression, reject.
- return failure();
-
- merger.setHasSparseOut(sparseOut != nullptr);
+ return failure(); // inadmissible expression, reject
+ // Updates environment with a loop emitter.
+ // TODO: refactor so that emitter can be constructed earlier
+ // and updating is made easy, i.e. remove this whole block?
SmallVector<Value> tensors;
for (OpOperand &t : op->getOpOperands())
tensors.push_back(t.get());
+ SparseTensorLoopEmitter lpe(
+ tensors,
+ StringAttr::get(op.getContext(), linalg::GenericOp::getOperationName()),
+ /*hasOutput=*/true, /*isSparseOut=*/env.sparseOut != nullptr,
+ env.topSort);
+ env.startEmit(&lpe);
// Recursively generates code if admissible.
- CodeGen codegen(options, op.getContext(), tensors, numTensors, numLoops,
- sparseOut, outerParNest, topSort);
- genBuffers(merger, codegen, rewriter, op);
- genInitConstantDenseAddress(codegen, rewriter, op);
- genStmt(merger, codegen, rewriter, op, exp, 0);
- genResult(merger, codegen, rewriter, op);
+ genBuffers(env, rewriter);
+ genInitConstantDenseAddress(env, rewriter);
+ genStmt(env, rewriter, exp, 0);
+ genResult(env, rewriter);
return success();
}
private:
// Last resort cycle resolution.
- LogicalResult resolveCycle(Merger &merger, PatternRewriter &rewriter,
- linalg::GenericOp op) const {
+ LogicalResult resolveCycle(CodeGenEnv &env, PatternRewriter &rewriter) const {
// Compute topological sort while leaving out every
// sparse input tensor in succession until an acyclic
// iteration graph results.
- std::vector<unsigned> topSort;
- for (OpOperand *t : op.getDpsInputOperands()) {
+ for (OpOperand *t : env.linalgOp.getDpsInputOperands()) {
unsigned tensor = t->getOperandNumber();
Value tval = t->get();
auto srcEnc = getSparseTensorEncoding(tval.getType());
- if (!srcEnc ||
- !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly, t))
+ if (!srcEnc || !computeIterationGraph(env, SortMask::kSparseOnly, t))
continue;
// Found an input tensor that resolves the cycle by inserting a
// conversion into a sparse tensor that adheres to the iteration
@@ -1754,16 +1763,15 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
//
auto srcTp = tval.getType().cast<RankedTensorType>();
auto dstEnc = SparseTensorEncodingAttr::get(
- op->getContext(), srcEnc.getDimLevelType(),
- permute(merger, getContext(), op.getMatchingIndexingMap(t),
- topSort), // new order
+ getContext(), srcEnc.getDimLevelType(),
+ permute(env, env.linalgOp.getMatchingIndexingMap(t)), // new order
srcEnc.getHigherOrdering(), srcEnc.getPointerBitWidth(),
srcEnc.getIndexBitWidth());
auto dstTp = RankedTensorType::get(srcTp.getShape(),
srcTp.getElementType(), dstEnc);
auto convert = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
- op->setOperand(tensor, convert);
- rewriter.setInsertionPointAfter(op);
+ env.linalgOp->setOperand(tensor, convert);
+ rewriter.setInsertionPointAfter(env.linalgOp);
rewriter.create<bufferization::DeallocTensorOp>(tval.getLoc(), convert);
return success();
}
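
For readers skimming the diff: the recurring change above is that helpers which previously took a Merger, a CodeGen struct, the builder, and the linalg op now take a single environment object that forwards to its members. Below is a minimal, self-contained sketch of that "aggregate state + delegate" shape, not code from this patch; the types Merger, LoopEmitter, Env, and the function startSeq are simplified stand-ins invented for illustration, not the real classes in Sparsification.cpp.

#include <cassert>
#include <vector>

struct Merger {                       // simplified stand-in for the real merger
  std::vector<int> lats;
  int lat(unsigned l) const { return lats[l]; }
};

struct LoopEmitter {                  // simplified stand-in for the loop emitter
  void enterLoopSeq() {}
};

struct Env {                          // aggregates state passed around during codegen
  Merger merger;
  LoopEmitter *loopEmitter = nullptr; // installed once via startEmit()
  void startEmit(LoopEmitter *le) {
    assert(!loopEmitter && "must only start emitting once");
    loopEmitter = le;
  }
  // Delegate method: callers write env.lat(l) instead of merger.lat(l).
  int lat(unsigned l) const { return merger.lat(l); }
};

// Before the refactoring this would take (Merger&, LoopEmitter&, unsigned);
// with the environment it takes one parameter fewer.
static int startSeq(Env &env, unsigned l0) {
  env.loopEmitter->enterLoopSeq();
  return env.lat(l0);
}

int main() {
  Env env;
  env.merger.lats = {7, 11};
  LoopEmitter le;
  env.startEmit(&le);
  return startSeq(env, 1) == 11 ? 0 : 1; // returns 0 on the expected result
}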