[Mlir-commits] [mlir] ff8815e - [mlir][sparse] code cleanup (remove topSort in CodegenEnv). (#72550)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 16 13:21:53 PST 2023
Author: Peiming Liu
Date: 2023-11-16T13:21:49-08:00
New Revision: ff8815e597597a319ffde9d18e708040d226bbae
URL: https://github.com/llvm/llvm-project/commit/ff8815e597597a319ffde9d18e708040d226bbae
DIFF: https://github.com/llvm/llvm-project/commit/ff8815e597597a319ffde9d18e708040d226bbae.diff
LOG: [mlir][sparse] code cleanup (remove topSort in CodegenEnv). (#72550)
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
index 3a02d5634586070..ad2f649e5babdf4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
@@ -28,21 +28,13 @@ static bool isMaterializing(Value val) {
val.getDefiningOp<bufferization::AllocTensorOp>();
}
-/// Makes target array's elements sorted according to the `order` array.
-static void sortArrayBasedOnOrder(std::vector<LoopCoeffPair> &target,
- ArrayRef<LoopId> order) {
+/// Sorts the dependent loops such that they are ordered in the same sequence
+/// in which the loops will be generated.
+static void sortDependentLoops(std::vector<LoopCoeffPair> &target) {
std::sort(target.begin(), target.end(),
- [&order](const LoopCoeffPair &l, const LoopCoeffPair &r) {
+ [](const LoopCoeffPair &l, const LoopCoeffPair &r) {
assert(std::addressof(l) == std::addressof(r) || l != r);
- int idxL = -1, idxR = -1;
- for (int i = 0, e = order.size(); i < e; i++) {
- if (order[i] == l.first)
- idxL = i;
- if (order[i] == r.first)
- idxR = i;
- }
- assert(idxL >= 0 && idxR >= 0);
- return idxL < idxR;
+ return l.first < r.first;
});
}
//===----------------------------------------------------------------------===//
@@ -54,14 +46,10 @@ CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
unsigned numFilterLoops, unsigned maxRank)
: linalgOp(linop), sparseOptions(opts),
latticeMerger(numTensors, numLoops, numFilterLoops, maxRank),
- loopEmitter(), topSort(), sparseOut(nullptr), outerParNest(-1u),
- insChain(), expValues(), expFilled(), expAdded(), expCount(), redVal(),
+ loopEmitter(), sparseOut(nullptr), outerParNest(-1u), insChain(),
+ expValues(), expFilled(), expAdded(), expCount(), redVal(),
redExp(detail::kInvalidId), redCustom(detail::kInvalidId),
- redValidLexInsert() {
- // TODO: remove topSort, loops should be already sorted by previous pass.
- for (unsigned l = 0; l < latticeMerger.getNumLoops(); l++)
- topSort.push_back(l);
-}
+ redValidLexInsert() {}
LogicalResult CodegenEnv::initTensorExp() {
// Builds the tensor expression for the Linalg operation in SSA form.
@@ -97,7 +85,7 @@ void CodegenEnv::startEmit() {
(void)enc;
assert(!enc || lvlRank == enc.getLvlRank());
for (Level lvl = 0; lvl < lvlRank; lvl++)
- sortArrayBasedOnOrder(latticeMerger.getDependentLoops(tid, lvl), topSort);
+ sortDependentLoops(latticeMerger.getDependentLoops(tid, lvl));
}
loopEmitter.initialize(
@@ -105,7 +93,7 @@ void CodegenEnv::startEmit() {
StringAttr::get(linalgOp.getContext(),
linalg::GenericOp::getOperationName()),
/*hasOutput=*/true,
- /*isSparseOut=*/sparseOut != nullptr, topSort,
+ /*isSparseOut=*/sparseOut != nullptr, /*numLoops=*/getLoopNum(),
// TODO: compute the map and pass it to loop emitter directly instead of
// passing in a callback.
/*dependentLvlGetter=*/
@@ -190,8 +178,7 @@ bool CodegenEnv::isAdmissibleTensorExp(ExprId exp) {
// needed.
outerParNest = 0;
const auto iteratorTypes = linalgOp.getIteratorTypesArray();
- assert(topSortSize() == latticeMerger.getNumLoops());
- for (const LoopId i : topSort) {
+ for (unsigned i = 0, e = getLoopNum(); i < e; i++) {
if (linalg::isReductionIterator(iteratorTypes[i]))
break; // terminate at first reduction
outerParNest++;
@@ -208,26 +195,8 @@ bool CodegenEnv::isAdmissibleTensorExp(ExprId exp) {
// Code generation environment topological sort methods
//===----------------------------------------------------------------------===//
-ArrayRef<LoopId> CodegenEnv::getTopSortSlice(LoopOrd n, LoopOrd m) const {
- return ArrayRef<LoopId>(topSort).slice(n, m);
-}
-
-ArrayRef<LoopId> CodegenEnv::getLoopStackUpTo(LoopOrd n) const {
- return ArrayRef<LoopId>(topSort).take_front(n);
-}
-
-ArrayRef<LoopId> CodegenEnv::getCurrentLoopStack() const {
- return getLoopStackUpTo(loopEmitter.getCurrentDepth());
-}
-
Value CodegenEnv::getLoopVar(LoopId i) const {
- // TODO: this class should store the inverse of `topSort` so that
- // it can do this conversion directly, instead of searching through
- // `topSort` every time. (Or else, `LoopEmitter` should handle this.)
- for (LoopOrd n = 0, numLoops = topSortSize(); n < numLoops; n++)
- if (topSort[n] == i)
- return loopEmitter.getLoopIV(n);
- llvm_unreachable("invalid loop identifier");
+ return loopEmitter.getLoopIV(i);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
index c0fc505d153a459..af783b9dca276e3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
@@ -83,6 +83,8 @@ class CodegenEnv {
}
DimLevelType dlt(TensorLoopId b) const { return latticeMerger.getLvlType(b); }
+ unsigned getLoopNum() const { return latticeMerger.getNumLoops(); }
+
//
// LoopEmitter delegates.
//
@@ -107,6 +109,8 @@ class CodegenEnv {
return loopEmitter.unpackTensorLevelRange(std::forward<ContainerTy>(c));
}
+ unsigned getLoopDepth() const { return loopEmitter.getCurrentDepth(); }
+
//
// Code generation environment verify functions.
//
@@ -115,25 +119,6 @@ class CodegenEnv {
/// It also sets the sparseOut if the output tensor is sparse.
bool isAdmissibleTensorExp(ExprId e);
- /// Whether the iteration graph is sorted in admissible topoOrder.
- /// Sets outerParNest on success with sparse output
- bool isAdmissibleTopoOrder();
-
- //
- // Topological delegate and sort methods.
- //
-
- LoopOrd topSortSize() const { return topSort.size(); }
- LoopId topSortAt(LoopOrd n) const { return topSort.at(n); }
- void topSortPushBack(LoopId i) { topSort.push_back(i); }
- void topSortClear(size_t capacity = 0) {
- topSort.clear();
- topSort.reserve(capacity);
- }
-
- ArrayRef<LoopId> getTopSortSlice(LoopOrd n, LoopOrd m) const;
- ArrayRef<LoopId> getLoopStackUpTo(LoopOrd n) const;
- ArrayRef<LoopId> getCurrentLoopStack() const;
/// Returns the induction-variable for the loop identified by the given
/// `LoopId`. This method handles application of the topological sort
/// in order to convert the `LoopId` into the corresponding `LoopOrd`.
@@ -191,10 +176,6 @@ class CodegenEnv {
// Loop emitter helper class.
LoopEmitter loopEmitter;
- // Topological sort. This serves as a mapping from `LoopOrd` to `LoopId`
- // (cf., `getLoopVar` and `topSortAt`).
- std::vector<LoopId> topSort;
-
// Sparse tensor as output. Implemented either through direct injective
// insertion in lexicographic index order or through access pattern
// expansion in the innermost loop nest (`expValues` through `expCount`).
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 81ce525a62d9f61..ba798f09c4d583b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -276,13 +276,13 @@ Value LoopEmitter::genSparseCrd(OpBuilder &builder, Location loc, TensorId tid,
}
LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
- bool isSparseOut, ArrayRef<LoopId> topSort,
+ bool isSparseOut, unsigned numLoops,
DependentLvlGetter dimGetter) {
- initialize(tensors, loopTag, hasOutput, isSparseOut, topSort, dimGetter);
+ initialize(tensors, loopTag, hasOutput, isSparseOut, numLoops, dimGetter);
}
void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
- bool isSparseOut, ArrayRef<LoopId> topSort,
+ bool isSparseOut, unsigned numLoops,
DependentLvlGetter dimGetter) {
// First initialize the top-level type of the fields.
this->loopTag = loopTag;
@@ -308,10 +308,8 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
this->sliceOffsets.assign(numTensors, std::vector<Value>());
this->sliceStrides.assign(numTensors, std::vector<Value>());
- const LoopOrd numLoops = topSort.size();
// These zeros will be overwritten below, but we need to initialize
// them to something since we'll need random-access assignment.
- this->loopIdToOrd.assign(numLoops, 0);
this->loopStack.reserve(numLoops);
this->loopSeqStack.reserve(numLoops);
@@ -387,13 +385,6 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
}
}
}
-
- // Construct the inverse of the `topSort` from the sparsifier.
- // This is needed to map `AffineDimExpr`s back to the `LoopOrd`
- // used in loop emitter.
- // FIXME: This map should be maintained outside loop emitter.
- for (LoopOrd n = 0; n < numLoops; n++)
- loopIdToOrd[topSort[n]] = n;
}
void LoopEmitter::initializeLoopEmit(
@@ -611,8 +602,7 @@ Value LoopEmitter::genAffine(OpBuilder &builder, Location loc, AffineExpr a) {
// However, elsewhere we have been lead to expect that `loopIdToOrd`
// should be indexed by `LoopId`...
const auto loopId = cast<AffineDimExpr>(a).getPosition();
- assert(loopId < loopIdToOrd.size());
- return loopStack[loopIdToOrd[loopId]].iv;
+ return loopStack[loopId].iv;
}
case AffineExprKind::Add: {
auto binOp = cast<AffineBinaryOpExpr>(a);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index c6518decbdee0a4..320b39765dea4a7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -113,12 +113,11 @@ class LoopEmitter {
/// to `LoopId`.
void initialize(ValueRange tensors, StringAttr loopTag = nullptr,
bool hasOutput = false, bool isSparseOut = false,
- ArrayRef<LoopId> topSort = {},
- DependentLvlGetter getter = nullptr);
+ unsigned numLoops = 0, DependentLvlGetter getter = nullptr);
explicit LoopEmitter(ValueRange tensors, StringAttr loopTag = nullptr,
bool hasOutput = false, bool isSparseOut = false,
- ArrayRef<LoopId> topSort = {},
+ unsigned numLoops = 0,
DependentLvlGetter getter = nullptr);
/// Starts a loop emitting session by generating all the buffers needed
@@ -751,11 +750,6 @@ class LoopEmitter {
// TODO: maybe we should have a LoopSeqInfo
std::vector<std::pair<Value, std::vector<std::tuple<TensorId, Level, bool>>>>
loopSeqStack;
-
- /// Maps `LoopId` (used by `AffineDimExpr`) to `LoopOrd` (in the `loopStack`).
- /// TODO: We should probably use a callback function here to make it more
- /// general.
- std::vector<LoopOrd> loopIdToOrd;
};
} // namespace sparse_tensor
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index cd6f689a04acb53..ec96ce23be03f28 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -49,8 +49,8 @@ using namespace mlir::sparse_tensor;
// (assuming that's the actual meaning behind the "idx"-vs-"ldx" convention).
/// Determines if affine expression is invariant.
-static bool isInvariantAffine(AffineExpr a, ArrayRef<LoopId> loopStack,
- LoopId ldx, bool &isAtLoop) {
+static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
+ bool &isAtLoop) {
switch (a.getKind()) {
case AffineExprKind::DimId: {
const LoopId i = cast<AffineDimExpr>(a).getPosition();
@@ -59,19 +59,14 @@ static bool isInvariantAffine(AffineExpr a, ArrayRef<LoopId> loopStack,
// Must be invariant if we are at the given loop.
return true;
}
- bool isInvariant = false;
- for (LoopId l : loopStack) {
- isInvariant = (l == i);
- if (isInvariant)
- break;
- }
- return isInvariant;
+ // The DimExpr is invariant if the loop has already been generated.
+ return i < loopDepth;
}
case AffineExprKind::Add:
case AffineExprKind::Mul: {
auto binOp = cast<AffineBinaryOpExpr>(a);
- return isInvariantAffine(binOp.getLHS(), loopStack, ldx, isAtLoop) &&
- isInvariantAffine(binOp.getRHS(), loopStack, ldx, isAtLoop);
+ return isInvariantAffine(binOp.getLHS(), loopDepth, ldx, isAtLoop) &&
+ isInvariantAffine(binOp.getRHS(), loopDepth, ldx, isAtLoop);
}
default: {
assert(isa<AffineConstantExpr>(a));
@@ -80,12 +75,6 @@ static bool isInvariantAffine(AffineExpr a, ArrayRef<LoopId> loopStack,
}
}
-/// Determines if affine expression is invariant.
-static bool isInvariantAffine(CodegenEnv &env, AffineExpr a, LoopId ldx,
- bool &isAtLoop) {
- return isInvariantAffine(a, env.getCurrentLoopStack(), ldx, isAtLoop);
-}
-
/// Helper method to inspect affine expressions. Rejects cases where the
/// same index is used more than once. Also rejects compound affine
/// expressions in sparse dimensions.
@@ -351,17 +340,6 @@ static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
llvm::cast<linalg::LinalgOp>(op.getOperation())
.createLoopRanges(builder, loc);
- assert(loopRange.size() == env.merger().getStartingFilterLoopId());
- SmallVector<Range, 4> sortedRange;
- for (unsigned i = 0, e = env.topSortSize(); i < e; i++) {
- LoopId ldx = env.topSortAt(i);
- // FIXME: Gets rid of filter loops since we have a better algorithm to deal
- // with affine index expression.
- if (ldx < env.merger().getStartingFilterLoopId()) {
- sortedRange.push_back(loopRange[ldx]);
- }
- }
-
env.emitter().initializeLoopEmit(
builder, loc,
/// Generates buffer for the output tensor.
@@ -396,15 +374,14 @@ static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
}
return init;
},
- [&sortedRange, &env](OpBuilder &b, Location loc, Level l) {
- assert(l < env.topSortSize());
+ [&loopRange, &env](OpBuilder &b, Location loc, Level l) {
+ assert(l < env.getLoopNum());
// FIXME: Remove filter loop since we have a better algorithm to
// deal with affine index expression.
if (l >= env.merger().getStartingFilterLoopId())
return Value();
- return mlir::getValueOrCreateConstantIndexOp(b, loc,
- sortedRange[l].size);
+ return mlir::getValueOrCreateConstantIndexOp(b, loc, loopRange[l].size);
});
}
@@ -762,7 +739,7 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
return;
if (*sldx == ldx)
isAtLoop = true;
- } else if (!isInvariantAffine(env, a, ldx, isAtLoop))
+ } else if (!isInvariantAffine(a, env.getLoopDepth(), ldx, isAtLoop))
return; // still in play
}
// All exhausted at this level (isAtLoop denotes exactly at this LoopId).
@@ -890,29 +867,6 @@ static bool shouldTryParallize(CodegenEnv &env, LoopId ldx, bool isOuter,
return isParallelFor(env, isOuter, isSparse);
}
-/// Generates a "filter loop" on the given tid level to locate a coordinate that
-/// is of the same value as evaluated by the affine expression in its matching
-/// indexing map.
-static Operation *genFilterLoop(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
- TensorLevel tidLvl) {
- linalg::GenericOp op = env.op();
- Location loc = op.getLoc();
- Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
- assert(env.merger().isFilterLoop(ldx));
- const auto [tid, lvl] = env.unpackTensorLevel(tidLvl);
- // tids/lvls must only have one value because filter loops only
- // corresponding to the one and only sparse tensor level.
- OpOperand *t = &op->getOpOperand(tid);
- auto enc = getSparseTensorEncoding(t->get().getType());
- // Retrieves the affine expression for the filter loop.
- // FIXME: `toOrigDim` is deprecated.
- AffineExpr a = op.getMatchingIndexingMap(t).getResult(toOrigDim(enc, lvl));
- return env.emitter().enterFilterLoopOverTensorAtLvl(builder, loc, tid, lvl,
- a, reduc);
- });
- return loop;
-}
-
/// Emit a loop to coiterate over the list of tensor levels. The generated loop
/// can either be a for loop or while loop depending on whether there is at most
/// one sparse level in the list.
@@ -934,14 +888,8 @@ static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopOrd at,
bool needsUniv, ArrayRef<TensorLevel> tidLvls) {
- const LoopId ldx = env.topSortAt(at);
- if (env.merger().isFilterLoop(ldx)) {
- assert(tidLvls.size() == 1);
- return genFilterLoop(env, builder, ldx, tidLvls.front());
- }
-
- bool tryParallel = shouldTryParallize(env, ldx, at == 0, tidLvls);
- return genCoIteration(env, builder, ldx, tidLvls, tryParallel, needsUniv);
+ bool tryParallel = shouldTryParallize(env, at, at == 0, tidLvls);
+ return genCoIteration(env, builder, at, tidLvls, tryParallel, needsUniv);
}
/// Generates the induction structure for a while-loop.
@@ -1066,12 +1014,12 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
/// Starts a loop sequence at given level. Returns true if
/// the universal loop index must be maintained at this level.
static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
- LoopOrd at, LoopId idx, LoopId ldx, LatSetId lts) {
+ LoopOrd idx, LoopId ldx, LatSetId lts) {
assert(!env.getLoopVar(idx));
// Emit invariants at this loop sequence level.
genInvariants(env, builder, exp, ldx, /*atStart=*/true);
// Emit access pattern expansion for sparse tensor output.
- genExpand(env, builder, at, /*atStart=*/true);
+ genExpand(env, builder, idx, /*atStart=*/true);
// Emit further intitialization at this loop sequence level.
const LatPointId l0 = env.set(lts)[0];
bool needsUniv = false;
@@ -1226,7 +1174,8 @@ static bool translateBitsToTidLvlPairs(
// Constant affine expression are handled in genLoop
if (!isa<AffineConstantExpr>(exp)) {
bool isAtLoop = false;
- if (isInvariantAffine(env, exp, ldx, isAtLoop) && isAtLoop) {
+ if (isInvariantAffine(exp, env.getLoopDepth(), ldx, isAtLoop) &&
+ isAtLoop) {
// If the compound affine is invariant and we are right at the
// level. We need to generate the address according to the
// affine expression. This is also the best place we can do it
@@ -1273,8 +1222,8 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
// becomes invariant and the address shall now be generated at the current
// level.
SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls;
- bool isSingleCond = translateBitsToTidLvlPairs(env, li, env.topSortAt(at),
- tidLvls, affineTidLvls);
+ bool isSingleCond =
+ translateBitsToTidLvlPairs(env, li, at, tidLvls, affineTidLvls);
// Emit the for/while-loop control.
Operation *loop = genLoop(env, builder, at, needsUniv, tidLvls);
@@ -1324,13 +1273,13 @@ static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
/// Ends a loop sequence at given level.
static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
- unsigned at, unsigned idx, unsigned ldx) {
+ unsigned idx, unsigned ldx) {
assert(!env.getLoopVar(idx));
env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc());
// Unmark bookkeeping of invariants and loop index.
genInvariants(env, builder, exp, ldx, /*atStart=*/false);
// Finalize access pattern expansion for sparse tensor output.
- genExpand(env, builder, at, /*atStart=*/false);
+ genExpand(env, builder, idx, /*atStart=*/false);
}
/// Recursively generates code while computing iteration lattices in order
@@ -1339,22 +1288,19 @@ static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
LoopOrd at) {
// At each leaf, assign remaining tensor (sub)expression to output tensor.
- if (at == env.topSortSize()) {
- const LoopId ldx = env.topSortAt(at - 1);
- Value rhs = genExp(env, rewriter, exp, ldx);
+ if (at == env.getLoopNum()) {
+ Value rhs = genExp(env, rewriter, exp, at - 1);
genTensorStore(env, rewriter, exp, rhs);
return;
}
// Construct iteration lattices for current loop index, with L0 at top.
- const LoopId idx = env.topSortAt(at);
- const LoopId ldx = at == 0 ? ::mlir::sparse_tensor::detail::kInvalidId
- : env.topSortAt(at - 1);
+ const LoopId ldx = at == 0 ? sparse_tensor::detail::kInvalidId : at - 1;
const LatSetId lts =
- env.merger().optimizeSet(env.merger().buildLattices(exp, idx));
+ env.merger().optimizeSet(env.merger().buildLattices(exp, at));
// Start a loop sequence.
- bool needsUniv = startLoopSeq(env, rewriter, exp, at, idx, ldx, lts);
+ bool needsUniv = startLoopSeq(env, rewriter, exp, at, ldx, lts);
// Emit a loop for every lattice point L0 >= Li in this loop sequence.
//
@@ -1382,7 +1328,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
if (li == lj || env.merger().latGT(li, lj)) {
// Recurse into body of each branch.
if (!isSingleCond) {
- scf::IfOp ifOp = genIf(env, rewriter, idx, lj);
+ scf::IfOp ifOp = genIf(env, rewriter, at, lj);
genStmt(env, rewriter, ej, at + 1);
endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
} else {
@@ -1392,11 +1338,11 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
}
// End a loop.
- needsUniv = endLoop(env, rewriter, loop, idx, li, needsUniv, isSingleCond);
+ needsUniv = endLoop(env, rewriter, loop, at, li, needsUniv, isSingleCond);
}
// End a loop sequence.
- endLoopSeq(env, rewriter, exp, at, idx, ldx);
+ endLoopSeq(env, rewriter, exp, at, ldx);
}
/// Converts the result computed by the sparse kernel into the required form.
More information about the Mlir-commits
mailing list