[Mlir-commits] [mlir] [mlir][sparse] setup `SparseIterator` to help generating code to traverse a sparse tensor level. (PR #78345)
Peiming Liu
llvmlistbot at llvm.org
Tue Jan 16 13:12:22 PST 2024
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/78345
>From bd5649b14b4a7b48716628aa257f4d7271ee0f90 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 19 Dec 2023 21:05:25 +0000
Subject: [PATCH 01/11] [mlir][sparse] setup sparse iterator skeleton
---
.../Transforms/SparseTensorRewriting.cpp | 2 +-
.../Transforms/Sparsification.cpp | 9 +-
.../Transforms/Utils/LoopEmitter.cpp | 707 ++++++++++--------
.../Transforms/Utils/LoopEmitter.h | 44 +-
.../Transforms/Utils/SparseTensorLevel.cpp | 394 ++++++++--
.../Transforms/Utils/SparseTensorLevel.h | 195 ++++-
6 files changed, 949 insertions(+), 402 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index fa97e405584791..76df01800bda8e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1105,7 +1105,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
LoopEmitter loopEmitter(
ValueRange{input},
StringAttr::get(getContext(), ForeachOp::getOperationName()));
- loopEmitter.initializeLoopEmit(rewriter, loc);
+ loopEmitter.initializeLoopEmit(rewriter, loc, /*genDedup=*/false);
for (Level l = 0; l < lvlRank; l++) {
// TODO: provide utility function for loop sequences that only contains
// one for loop?
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 5834426cae2f41..9dd9cd42b7d3ad 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -294,7 +294,7 @@ static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
.createLoopRanges(builder, loc);
env.emitter().initializeLoopEmit(
- builder, loc,
+ builder, loc, /*genDedup=*/true,
/// Generates buffer for the output tensor.
/// Note that all sparse kernels assume that when all elements are written
/// to (viz. x(i) = y(i) * z(i)), the output buffer is already initialized
@@ -815,8 +815,7 @@ static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
// Construct while-loop with a parameter for each index.
return env.emitter().enterCoIterationOverTensorsAtLvls(
- builder, env.op().getLoc(), tidLvls, reduc, tryParallel,
- /*genDedup=*/true, needsUniv);
+ builder, env.op().getLoc(), tidLvls, reduc, tryParallel, needsUniv);
});
assert(loop);
return loop;
@@ -1032,10 +1031,12 @@ static bool getAllTidLvlsInLatPoints(
});
if (isDenseLT(env.lt(outTid, curr))) {
+ auto stt = getSparseTensorType(env.op().getOutputs().front());
// Note that we generate dense indices of the output tensor
// unconditionally, since they may not appear in the lattice, but may be
// needed for linearized env.
- callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
+ if (stt.hasEncoding() && stt.isAllDense())
+ callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
}
if (numloopCond == 0) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 80dad064676220..a972de04db0a64 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -208,7 +208,7 @@ LoopEmitter::genSliceLegitPredicate(OpBuilder &builder, Location loc, Value crd,
}
// Second, coord_in_slice < length
- auto ltLength = CMPI(ult, newCrd, lvlSizes[tid][lvl]);
+ auto ltLength = CMPI(ult, newCrd, lvls[tid][lvl]->size());
conds.push_back(ltLength);
// Third, rem == 0 (skip the check if stride is known to be 1).
@@ -309,13 +309,13 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
this->tensors.assign(ts.begin(), ts.end());
// Arrays with len == numTensor.
this->lvlTypes.assign(numTensors, std::vector<LevelType>());
- this->lvlSizes.assign(numTensors, std::vector<Value>());
this->highs.assign(numTensors, std::vector<Value>());
this->segHi.assign(numTensors, std::vector<Value>());
this->posits.assign(numTensors, std::vector<Value>());
this->coords.assign(numTensors, std::vector<Value>());
this->valBuffer.assign(numTensors, nullptr);
this->lvls.resize(numTensors);
+ this->iters.resize(numTensors);
this->isSparseSlices.assign(numTensors, false);
this->sliceOffsets.assign(numTensors, std::vector<Value>());
this->sliceStrides.assign(numTensors, std::vector<Value>());
@@ -367,12 +367,12 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
}
// Initialize using empty value.
- lvlSizes[tid].assign(lvlRank, Value());
highs[tid].assign(lvlRank, Value());
segHi[tid].assign(lvlRank, Value());
posits[tid].assign(lvlRank, Value());
coords[tid].assign(lvlRank, Value());
lvls[tid].resize(lvlRank);
+ iters[tid].resize(lvlRank);
sliceOffsets[tid].assign(lvlRank, Value());
sliceStrides[tid].assign(lvlRank, Value());
@@ -408,14 +408,38 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
}
}
+std::unique_ptr<SparseIterator>
+LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
+ Level l, bool genDedup) {
+ auto it = makeSimpleIterator(*lvls[t][l], genDedup);
+ if (isSparseSlices[t]) {
+ Value offset = genSliceOffset(builder, loc, tensors[t], l);
+ Value stride = genSliceStride(builder, loc, tensors[t], l);
+ auto slicedIt = makeSlicedLevelIterator(std::move(it), offset, stride,
+ lvls[t][l]->size());
+ // TODO: remove below.
+ sliceOffsets[t][l] = offset;
+ sliceStrides[t][l] = stride;
+ return slicedIt;
+ }
+ return it;
+}
+
void LoopEmitter::initializeLoopEmit(
- OpBuilder &builder, Location loc, LoopEmitter::OutputUpdater updater,
+ OpBuilder &builder, Location loc, bool genDedup,
+ LoopEmitter::OutputUpdater updater,
LoopEmitter::SynTensorBoundSetter synSetter) {
-
+ this->genDedup = genDedup;
// For every synthetic tensor, set the high bound by calling the callback.
- if (synSetter)
- for (unsigned i = 0, e = highs[getSynTensorId()].size(); i < e; i++)
- highs[getSynTensorId()][i] = synSetter(builder, loc, i);
+ if (synSetter) {
+ TensorId synId = getSynTensorId();
+ for (unsigned i = 0, e = highs[synId].size(); i < e; i++) {
+ Value sz = highs[synId][i] = synSetter(builder, loc, i);
+ auto [stl, it] = makeSynLevelAndIterator(sz, synId, i);
+ lvls[synId][i] = std::move(stl);
+ iters[synId][i].emplace_back(std::move(it));
+ }
+ }
// For every manifest tensor:
// * get the values buffer.
@@ -448,14 +472,14 @@ void LoopEmitter::initializeLoopEmit(
// Scan all levels of current tensor.
for (Level l = 0; l < lvlRank; l++) {
- lvls[t][l] = makeSparseTensorLevel(builder, loc, tensor, l);
-
// Find upper bound in current dimension.
- highs[t][l] = lvlSizes[t][l] = lvlSzs[l];
- if (isSparseSlices[t]) {
- sliceOffsets[t][l] = genSliceOffset(builder, loc, tensors[t], l);
- sliceStrides[t][l] = genSliceStride(builder, loc, tensors[t], l);
- }
+ highs[t][l] = lvlSzs[l];
+ lvls[t][l] = makeSparseTensorLevel(builder, loc, tensor, t, l);
+ if (!dependentLvlMap[t][l].empty())
+ continue;
+
+ auto it = makeLevelIterator(builder, loc, t, l, genDedup);
+ iters[t][l].emplace_back(std::move(it));
}
// Perform the required bufferization. Dense inputs materialize
@@ -492,9 +516,65 @@ void LoopEmitter::initializeLoopEmit(
// hoist the code ouside if-conditions.
}
+ initSubSectIterator(builder, loc);
initSliceDriven(builder, loc);
}
+void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) {
+ Value c0 = C_IDX(0);
+ for (TensorId t = 0, e = tensors.size(); t < e; t++) {
+ auto rtp = dyn_cast<RankedTensorType>(tensors[t].getType());
+ if (!rtp)
+ continue;
+
+ Level lvlRank = SparseTensorType(rtp).getLvlRank();
+
+ // Compute the dependency reduction order.
+ auto remDepStack = dependentLvlMap;
+ std::vector<std::tuple<LoopId, TensorId, Level>> depRedOrder;
+ for (Level lvl = 0; lvl < lvlRank; lvl++) {
+ // Reverse queue into a stack.
+ std::reverse(remDepStack[t][lvl].begin(), remDepStack[t][lvl].end());
+ for (auto [loop, coeff] : dependentLvlMap[t][lvl])
+ depRedOrder.emplace_back(std::make_tuple(loop, t, lvl));
+ }
+
+ if (depRedOrder.empty())
+ continue;
+
+ std::sort(depRedOrder.begin(), depRedOrder.end(),
+ [](auto &l, auto &r) { return std::get<0>(l) < std::get<0>(r); });
+
+ for (auto [loop, t, lvl] : depRedOrder) {
+ std::pair<LoopId, unsigned> curDep = remDepStack[t][lvl].back();
+ assert(curDep.first == loop);
+ remDepStack[t][lvl].pop_back();
+
+ auto lvlIt = makeLevelIterator(builder, loc, t, lvl, genDedup);
+ const SparseIterator *parent =
+ lvl == 0 && iters[t][lvl].empty()
+ ? nullptr
+ : (!iters[t][lvl].empty() ? iters[t][lvl].back().get()
+ : iters[t][lvl - 1].back().get());
+
+ std::unique_ptr<SparseIterator> it;
+ if (!remDepStack[t][lvl].empty()) {
+ // Compute the subsection size.
+ Value size = c0;
+ for (auto [loop, stride] : remDepStack[t][lvl]) {
+ Value loopHi = highs[getSynTensorId()][loop];
+ size = ADDI(size, MULI(loopHi, C_IDX(stride)));
+ }
+ it = makeNonEmptySubSectIterator(builder, loc, parent, std::move(lvlIt),
+ size, curDep.second);
+ } else {
+ it = makeTraverseSubSectIterator(parent, std::move(lvlIt));
+ }
+ iters[t][lvl].emplace_back(std::move(it));
+ }
+ }
+}
+
void LoopEmitter::initSliceDriven(OpBuilder &builder, Location loc) {
Value c0 = C_IDX(0);
for (TensorId t = 0, e = tensors.size(); t < e; t++) {
@@ -594,6 +674,28 @@ void LoopEmitter::categorizeLoopCondition(
});
}
+void LoopEmitter::categorizeIterators(
+ ArrayRef<TensorLevel> tidLvls, SmallVectorImpl<SparseIterator *> &raIters,
+ SmallVectorImpl<SparseIterator *> &spIters) {
+ // Finds out the tensor level that we should use to generate loops. Amongs all
+ // the tensor levels, there is at most one sparse tensor level.
+ for (auto [t, l] : unpackTensorLevelRange(tidLvls)) {
+ SparseIterator *it =
+ dependentLvlMap[t][l].empty()
+ ? iters[t][l].back().get()
+ : iters[t][l][iters[t][l].size() - remDepOnLevel(t, l)].get();
+ if (it->randomAccessible())
+ raIters.push_back(it);
+ else
+ spIters.push_back(it);
+ }
+
+ std::stable_sort(spIters.begin(), spIters.end(), [](auto lhs, auto rhs) {
+ // AffineUnRed > Affine > Slice > Trivial
+ return static_cast<uint8_t>(lhs->kind) > static_cast<uint8_t>(rhs->kind);
+ });
+}
+
void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
ArrayRef<TensorLevel> tidLvls) {
// TODO: sort
@@ -605,7 +707,7 @@ void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
if (!dependentLvlMap[tid][lvl].empty()) {
bool fullyRed = genSliceBegin(builder, loc, tid, lvl);
slicedTids.emplace_back(tid, lvl, fullyRed);
- } else if (!isSynTensor(tid)) {
+ } else {
prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
}
}
@@ -661,16 +763,15 @@ Value LoopEmitter::genAffine(OpBuilder &builder, Location loc, AffineExpr a) {
}
std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
- OpBuilder &builder, Location loc, TensorId tid, Level lvl, Value lo,
- Value hi, MutableArrayRef<Value> reduc, bool isParallel) {
- bool isSparseCond = isCompressedLT(lvlTypes[tid][lvl]) ||
- isLooseCompressedLT(lvlTypes[tid][lvl]) ||
- is2OutOf4LT(lvlTypes[tid][lvl]) ||
- isSingletonLT(lvlTypes[tid][lvl]);
+ OpBuilder &builder, Location loc, SparseIterator &iter,
+ MutableArrayRef<Value> reduc, bool isParallel) {
+
// TODO: support dynamic slices.
// Uses the first dimension here to build the loop bound (which is also the
// biggest range).
+
Value step = C_IDX(1);
+ auto [lo, hi] = iter.genForCond(builder, loc);
Operation *loop = nullptr;
Value iv;
if (isParallel) {
@@ -703,47 +804,45 @@ std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
}
assert(loop && iv);
- Value crd;
- if (isSparseCond) {
- // For COO, the position is the same across consecutive levels.
- /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
- posits[tid][lvl] = iv;
- crd = genSparseCrd(builder, loc, tid, lvl);
+ Value crd = iv;
+ if (!iter.randomAccessible()) {
+ iter.linkNewScope(iv);
+ crd = iter.deref(builder, loc);
} else {
- // Dense tensor, the coordinate is the inducation variable.
- crd = iv;
+ iter.locate(builder, loc, iv);
}
- if (isSparseSlices[tid] && isSparseCond) {
- // For sparse level slices, we need to filter out invalid coordinates that
- // are not included in the slice.
- SmallVector<Type> types;
- for (Value red : reduc)
- types.push_back(red.getType());
-
- auto [trans, pred] = genSliceLegitPredicate(builder, loc, crd, tid, lvl);
- bool hasReduc = !types.empty();
- scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, pred,
- /*else*/ hasReduc);
- if (hasReduc) {
- // scf.for (a) -> v
- // %s = scf.if (a) -> v
- // user-generated code.
- // else
- // yield a
- // yield %s
- YIELD(ifOp.getResults());
- builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- // On mismatch.
- YIELD(reduc);
- }
- // Set the insertion point to matched branch.
- builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- crd = trans;
- }
+ // if (isSparseSlices[tid] && isSparseCond) {
+ // // For sparse level slices, we need to filter out invalid coordinates
+ // that
+ // // are not included in the slice.
+ // SmallVector<Type> types;
+ // for (Value red : reduc)
+ // types.push_back(red.getType());
+
+ // auto [trans, pred] = genSliceLegitPredicate(builder, loc, crd, tid, lvl);
+ // bool hasReduc = !types.empty();
+ // scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, pred,
+ // /*else*/ hasReduc);
+ // if (hasReduc) {
+ // // scf.for (a) -> v
+ // // %s = scf.if (a) -> v
+ // // user-generated code.
+ // // else
+ // // yield a
+ // // yield %s
+ // YIELD(ifOp.getResults());
+ // builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ // // On mismatch.
+ // YIELD(reduc);
+ // }
+ // // Set the insertion point to matched branch.
+ // builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ // crd = trans;
+ // }
- assert(crd);
- coords[tid][lvl] = crd;
+ coords[iter.tid][iter.lvl] = crd;
+ posits[iter.tid][iter.lvl] = iter.getItVals().front();
return {loop, crd};
}
@@ -908,52 +1007,52 @@ ValueRange LoopEmitter::genCheckedValue(OpBuilder &builder, Location loc,
}
std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
- OpBuilder &builder, Location loc, ArrayRef<TensorLvlCond> spConds,
+ OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
MutableArrayRef<Value> reduc, bool needsUniv) {
// NOTE: the slice driven tensor-related reduction variable must
// appear before normal tensors.
- assert(!spConds.empty());
// The set of induction variables for the while loop.
SmallVector<Value> ivs;
- // Segment sizes for induction variables used for different kinds of loop
- // conditions.
- SmallVector<unsigned> opSegSize;
// Construct the while-loop with a parameter for each coordinate.
- for (auto [tl, cKind] : spConds) {
- auto [tid, lvl] = unpackTensorLevel(tl);
- const auto lvlTp = lvlTypes[tid][lvl];
- // Dense level are handled by the shared univeral index.
- assert(!isDenseCond(cKind));
- // Must be a recognizable sparse level.
- assert(isCompressedLT(lvlTp) || isLooseCompressedLT(lvlTp) ||
- isSingletonLT(lvlTp));
- (void)lvlTp;
-
- unsigned prevSz = ivs.size();
- if (isAffineIdxCond(cKind)) {
- // TODO: Support view-based reshape on sparse levels with affine index
- // expressions.
- if (isAffineIdxUnRedCond(cKind)) {
- SliceInfo &sliceInfo = sliceStack[tid].back();
- // The order matters!
- ivs.push_back(sliceInfo.isNonEmpty);
- ivs.push_back(sliceInfo.minCrd);
- ivs.push_back(sliceInfo.offset);
- } else {
- ivs.push_back(posits[tid][lvl]); // loop lower bound (pos low).
- }
- // We reduced one more dependency after entering the loop.
- levelReducedDep[tid][lvl]++;
- } else {
- assert(dependentLvlMap[tid][lvl].empty());
- const Value pos = posits[tid][lvl];
- ivs.push_back(pos);
- }
- opSegSize.push_back(ivs.size() - prevSz);
+ for (SparseIterator *it : spIters) {
+ ValueRange itVals = it->getItVals();
+ ivs.append(itVals.begin(), itVals.end());
}
+ // for (auto [tl, cKind] : spConds) {
+ // auto [tid, lvl] = unpackTensorLevel(tl);
+ // const auto lvlTp = lvlTypes[tid][lvl];
+ // // Dense level are handled by the shared univeral index.
+ // assert(!isDenseCond(cKind));
+ // // Must be a recognizable sparse level.
+ // assert(isCompressedLT(lvlTp) || isLooseCompressedLT(lvlTp) ||
+ // isSingletonLT(lvlTp));
+ // (void)lvlTp;
+ // unsigned prevSz = ivs.size();
+ // if (isAffineIdxCond(cKind)) {
+ // // TODO: Support view-based reshape on sparse levels with affine index
+ // // expressions.
+ // if (isAffineIdxUnRedCond(cKind)) {
+ // SliceInfo &sliceInfo = sliceStack[tid].back();
+ // // The order matters!
+ // ivs.push_back(sliceInfo.isNonEmpty);
+ // ivs.push_back(sliceInfo.minCrd);
+ // ivs.push_back(sliceInfo.offset);
+ // } else {
+ // ivs.push_back(posits[tid][lvl]); // loop lower bound (pos low).
+ // }
+ // // We reduced one more dependency after entering the loop.
+ // levelReducedDep[tid][lvl]++;
+ // } else {
+ // assert(dependentLvlMap[tid][lvl].empty());
+ // const Value pos = posits[tid][lvl];
+ // ivs.push_back(pos);
+ // }
+ // opSegSize.push_back(ivs.size() - prevSz);
+ // }
+
// The position where user-supplied reduction variable starts.
ivs.append(reduc.begin(), reduc.end());
// Update universal index.
@@ -973,10 +1072,15 @@ std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
builder.setInsertionPointToStart(before);
ValueRange bArgs = before->getArguments();
Value whileCond = nullptr; // bool values for loop condition.
- for (auto [c, segSz] : llvm::zip_equal(spConds, opSegSize)) {
- Value cv = genWhileLoopConditions(builder, loc, bArgs.take_front(segSz), c);
- bArgs = bArgs.drop_front(segSz);
- whileCond = !whileCond ? cv : ANDI(whileCond, cv);
+ // for (auto [c, segSz] : llvm::zip_equal(spConds, opSegSize)) {
+ // Value cv = genWhileLoopConditions(builder, loc, bArgs.take_front(segSz),
+ // c); bArgs = bArgs.drop_front(segSz); whileCond = !whileCond ? cv :
+ // ANDI(whileCond, cv);
+ // }
+ for (SparseIterator *it : spIters) {
+ auto [cond, remArgs] = it->genWhileCond(builder, loc, bArgs);
+ whileCond = !whileCond ? cond : ANDI(whileCond, cond);
+ bArgs = remArgs;
}
// The remaining block arguments are user-provided reduction values and an
// optional universal index. Make sure their sizes match.
@@ -992,48 +1096,57 @@ std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
SmallVector<Value> nextArgs(aArgs.begin(), aArgs.end());
// A mutable alias for convenient slicing.
MutableArrayRef<Value> nextArgsRef = nextArgs;
- Value extraPred = nullptr;
- for (auto [c, segSz] : llvm::zip_equal(spConds, opSegSize)) {
- ValueRange condArgs = aArgs.take_front(segSz);
- auto pred = genWhileLoopBody(builder, loc, condArgs, c);
- assert(pred.has_value() == isCondWithExtraCheck(c.second));
- if (pred.has_value()) {
- // We need all extra checks to pass.
- extraPred = extraPred == nullptr ? *pred : ANDI(*pred, extraPred);
- ValueRange nxArgs = genCheckedValue(builder, loc, *pred, condArgs, c);
- assert(nxArgs.size() == segSz);
- // Update the value for cases when some check fails.
- for (unsigned i = 0; i < segSz; i++) {
- nextArgsRef[i] = nxArgs[i];
- }
- }
- aArgs = aArgs.drop_front(segSz);
- nextArgsRef = nextArgsRef.drop_front(segSz);
- }
-
- if (extraPred) {
- auto ifOp = builder.create<scf::IfOp>(loc, types, extraPred, /*else*/ true);
- // Marks this special IfOp so that Sparsification does not finalizing it.
- ifOp->setAttr(getLoopEmitterLoopAttrName(),
- StringAttr::get(builder.getContext(), "slice"));
- // Links the SSA chain outside the if statement.
- YIELD(ifOp->getResults());
-
- // If not all slices are legit, yield the updated value.
- builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- YIELD(nextArgs);
+ // Value extraPred = nullptr;
+ // for (auto [c, segSz] : llvm::zip_equal(spConds, opSegSize)) {
+ // ValueRange condArgs = aArgs.take_front(segSz);
+ // auto pred = genWhileLoopBody(builder, loc, condArgs, c);
+ // assert(pred.has_value() == isCondWithExtraCheck(c.second));
+ // if (pred.has_value()) {
+ // // We need all extra checks to pass.
+ // extraPred = extraPred == nullptr ? *pred : ANDI(*pred, extraPred);
+ // ValueRange nxArgs = genCheckedValue(builder, loc, *pred, condArgs, c);
+ // assert(nxArgs.size() == segSz);
+ // // Update the value for cases when some check fails.
+ // for (unsigned i = 0; i < segSz; i++) {
+ // nextArgsRef[i] = nxArgs[i];
+ // }
+ // }
+ // aArgs = aArgs.drop_front(segSz);
+ // nextArgsRef = nextArgsRef.drop_front(segSz);
+ // }
- // If all slices are legit, start the user generated code.
- builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ for (SparseIterator *it : spIters) {
+ aArgs = it->linkNewScope(aArgs);
+ Value crd = it->deref(builder, loc);
+ posits[it->tid][it->lvl] = it->getItVals().front();
+ coords[it->tid][it->lvl] = crd;
}
- for (auto [tid, lvl] : unpackTensorLevelFromCondRange(spConds)) {
- // Generates segment high for non-unique level.
- if (!isUniqueLT(lvlTypes[tid][lvl])) {
- segHi[tid][lvl] = genSegmentHigh(builder, loc, tid, lvl, posits[tid][lvl],
- highs[tid][lvl]);
- }
- }
+ // if (extraPred) {
+ // auto ifOp = builder.create<scf::IfOp>(loc, types, extraPred, /*else*/
+ // true);
+ // // Marks this special IfOp so that Sparsification does not finalizing it.
+ // ifOp->setAttr(getLoopEmitterLoopAttrName(),
+ // StringAttr::get(builder.getContext(), "slice"));
+ // // Links the SSA chain outside the if statement.
+ // YIELD(ifOp->getResults());
+
+ // // If not all slices are legit, yield the updated value.
+ // builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ // YIELD(nextArgs);
+
+ // // If all slices are legit, start the user generated code.
+ // builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ // }
+
+ // for (auto [tid, lvl] : unpackTensorLevelFromCondRange(spConds)) {
+ // // Generates segment high for non-unique level.
+ // if (!isUniqueLT(lvlTypes[tid][lvl])) {
+ // segHi[tid][lvl] = genSegmentHigh(builder, loc, tid, lvl,
+ // posits[tid][lvl],
+ // highs[tid][lvl]);
+ // }
+ // }
// In-place update on reduction variable.
assert(aArgs.size() == reduc.size() + needsUniv ? 1 : 0);
@@ -1043,21 +1156,15 @@ std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
Value min;
// Finds the minimum coordinate
if (!needsUniv) {
- for (auto [tid, lvl] : unpackTensorLevelFromCondRange(spConds)) {
- const auto lvlTp = lvlTypes[tid][lvl];
- if (isCompressedLT(lvlTp) || isSingletonLT(lvlTp) ||
- isLooseCompressedLT(lvlTp)) {
- const auto crd = coords[tid][lvl];
- if (min) {
- Value cmp = CMPI(ult, coords[tid][lvl], min);
- min = SELECT(cmp, coords[tid][lvl], min);
- } else {
- min = crd;
- }
+ for (SparseIterator *it : spIters) {
+ if (min) {
+ Value cmp = CMPI(ult, it->getCrd(), min);
+ min = SELECT(cmp, it->getCrd(), min);
+ } else {
+ min = it->getCrd();
}
}
} else {
- assert(!min);
// Otherwise, universal index is the minimal pos.
min = whileOp.getAfterArguments().back();
}
@@ -1065,30 +1172,20 @@ std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
return {whileOp, min};
}
-bool LoopEmitter::shouldIteratedByForLoop(ArrayRef<TensorLvlCond> sparseConds,
- bool genDedup) {
- assert(llvm::all_of(sparseConds,
- [](TensorLvlCond c) { return isSparseCond(c.second); }));
-
+bool LoopEmitter::shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters) {
// If we need to co-iterate over two sparse tensors, we need a while loop
- if (sparseConds.size() > 1)
+ if (spIters.size() > 1)
return false;
- // We also need a while loop for levels with affine index expression and
- // non-unique levels when deduplication is required.
- if (sparseConds.size() == 1) {
- auto [tid, lvl] = unpackTensorLevel(sparseConds.back().first);
- return !isAffineIdxCond(sparseConds.back().second) &&
- !(genDedup && !isUniqueLT(lvlTypes[tid][lvl]));
- }
+ if (spIters.size() == 1)
+ return spIters.front()->iteratableByFor();
return true;
}
Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
- MutableArrayRef<Value> reduc, bool tryParallel, bool genDedup,
- bool needsUniv) {
+ MutableArrayRef<Value> reduc, bool tryParallel, bool needsUniv) {
#ifndef NDEBUG
// Sanity checks.
assert(!tidLvls.empty());
@@ -1104,11 +1201,15 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
SmallVector<TensorLvlCond> dnConds;
categorizeLoopCondition(tidLvls, dnConds, spConds);
+ SmallVector<SparseIterator *> raIters;
+ SmallVector<SparseIterator *> spIters;
+ categorizeIterators(tidLvls, raIters, spIters);
+
// Only when there is at least one sparse conditions, do we really need the
// universal index.
// TODO: Maybe we should instead requires merger to pass in a valid value at
// the first place instead of adjusting it in LoopEmitter?
- needsUniv = !spConds.empty() && needsUniv;
+ needsUniv = !spIters.empty() && needsUniv;
// The TensorLevel used for loop conditions.
// If there is any sparse level, we need to use the sparse condition.
// If all levels are dense, we can pick arbitrary one (dense slice-driven loop
@@ -1120,38 +1221,39 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
// Generates loops differently depending on whether we need a slice-driven
// loop or a simple level traversal loop.
- if (shouldIteratedByForLoop(spConds, genDedup) && !needsUniv) {
- assert(spConds.size() <= 1);
+ if (shouldIteratedByForLoop(spIters) && !needsUniv) {
+ assert(spIters.size() <= 1);
TensorLvlCond tlCond = spConds.empty() ? dnConds.front() : spConds.front();
- auto loopCondKind = tlCond.second;
- auto [tid, lvl] = unpackTensorLevel(tlCond.first);
- Value lo = isSparseCond(loopCondKind)
- ? posits[tid][lvl] // current offset
- : loopSeqStack.back().first; // universal index
- Value hi = highs[tid][lvl];
- if (isDenseCond(loopCondKind) && isAffineIdxCond(loopCondKind)) {
- bool unReduc = isAffineIdxUnRedCond(loopCondKind);
- assert(unReduc == !depFullyReduced(tid, lvl));
- unsigned depth = sliceStack[tid].back().depth;
- assert(depth >= 1);
- // The *next* slice size after reducing the current index variable.
- auto [nxSz, nxStride] = sliceMeta[tid][lvl][depth];
- // The *current* stride to reduce the current index variable.
- // E.g., for 2 * i, stride = 2.
- unsigned stride = sliceMeta[tid][lvl][depth - 1].second;
- hi = nxSz;
- if (unReduc) {
- // Adjust for loop hi for dense slice-driven loop.
- hi = SUBI(lvlSizes[tid][lvl], hi);
- hi = ADDI(hi, C_IDX(1));
- hi = DIVUI(hi, C_IDX(stride));
- } else {
- // TODO: dialuted convolution.
- assert(nxStride == 1 && "Not yet implemented.");
- }
- }
- std::tie(l, iv) = emitForLoopOverTensorAtLvl(builder, loc, tid, lvl, lo, hi,
- reduc, tryParallel);
+ SparseIterator &it = spIters.empty() ? *raIters.front() : *spIters.front();
+ // auto [tid, lvl] = unpackTensorLevel(tlCond.first);
+ // Value lo = isSparseCond(loopCondKind)
+ // ? posits[tid][lvl] // current offset
+ // : loopSeqStack.back().first; // universal index
+ // Value hi = highs[tid][lvl];
+ // if (isDenseCond(loopCondKind) && isAffineIdxCond(loopCondKind)) {
+ // bool unReduc = isAffineIdxUnRedCond(loopCondKind);
+ // assert(unReduc == !depFullyReduced(tid, lvl));
+ // unsigned depth = sliceStack[tid].back().depth;
+ // assert(depth >= 1);
+ // // The *next* slice size after reducing the current index variable.
+ // auto [nxSz, nxStride] = sliceMeta[tid][lvl][depth];
+ // // The *current* stride to reduce the current index variable.
+ // // E.g., for 2 * i, stride = 2.
+ // unsigned stride = sliceMeta[tid][lvl][depth - 1].second;
+ // hi = nxSz;
+ // if (unReduc) {
+ // // Adjust for loop hi for dense slice-driven loop.
+ // hi = SUBI(lvls[tid][lvl]->size(), hi);
+ // hi = ADDI(hi, C_IDX(1));
+ // hi = DIVUI(hi, C_IDX(stride));
+ // } else {
+ // // TODO: dialuted convolution.
+ // assert(nxStride == 1 && "Not yet implemented.");
+ // }
+ // }
+ std::tie(l, iv) =
+ emitForLoopOverTensorAtLvl(builder, loc, it, reduc, tryParallel);
+
// For loop condition must be a trivial condition (levels without affine
// index expression).
trivialLvls.push_back(tlCond.first);
@@ -1167,12 +1269,16 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
}
}
+ if (needsUniv)
+ for (auto *it : raIters)
+ trivialLvls.push_back(makeTensorLevel(it->tid, it->lvl));
+
std::tie(l, iv) =
- emitWhileLoopOverTensorsAtLvls(builder, loc, spConds, reduc, needsUniv);
+ emitWhileLoopOverTensorsAtLvls(builder, loc, spIters, reduc, needsUniv);
}
// Enter dense tensor levels.
- enterTensorsAtDenseLvls(builder, loc, dnConds, iv, sliceDrivenInfo);
+ enterTensorsAtDenseLvls(builder, loc, raIters, iv, sliceDrivenInfo);
// NOTE: we can also prepare for next dim here in advance
// Pushes the loop into stack.
@@ -1259,98 +1365,70 @@ void LoopEmitter::genDenseAffineAddress(OpBuilder &builder, Location loc,
void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
TensorId tid, Level lvl) {
assert(isValidLevel(tid, lvl));
- const auto lvlTp = lvlTypes[tid][lvl];
-
- if (isDenseLT(lvlTp))
- return;
-
- const Value c0 = C_IDX(0);
- const Value c1 = C_IDX(1);
- // Either the first level, or the previous level has been set.
- /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
- assert(lvl == 0 || posits[tid][lvl - 1]);
- if (isCompressedLT(lvlTp) || isLooseCompressedLT(lvlTp) ||
- is2OutOf4LT(lvlTp)) {
-
- Value pos = lvl == 0 ? c0 : posits[tid][lvl - 1];
- std::tie(posits[tid][lvl], highs[tid][lvl]) =
- lvls[tid][lvl]->peekRangeAt(builder, loc, pos);
- return;
- }
- if (isSingletonLT(lvlTp)) {
- // TODO: merge this as well when SparseTensorLevel support dedup.
- const Value pLo = lvl == 0 ? c0 : posits[tid][lvl - 1];
- posits[tid][lvl] = pLo;
-
- // If we are coiterating non-unique levels, then use pHi=segHi;
- // otherwise use pHi=pLo+1.
- // NOTE: Just because the level is non-unique, that does not
- // guarantee that segHi is defined: because we only generate segHi
- // whenever coiterating, in order to improve code quality for the
- // non-coiterating cases.
- const auto parentSegHi = segHi[tid][lvl - 1];
- highs[tid][lvl] = (!isUniqueLT(lvlTypes[tid][lvl - 1]) && parentSegHi)
- ? parentSegHi
- : ADDI(pLo, c1);
- return;
- }
- llvm_unreachable("Unrecognized level-type!");
+ const SparseIterator *parent =
+ lvl == 0 ? nullptr : iters[tid][lvl - 1].back().get();
+ SparseIterator &curIt = *iters[tid][lvl].back();
+ curIt.genInit(builder, loc, parent);
}
void LoopEmitter::enterTensorsAtDenseLvls(
- OpBuilder &builder, Location loc, ArrayRef<TensorLvlCond> dnConds, Value iv,
- SmallVectorImpl<SliceLoopInfo> &sliceInfo) {
- for (auto [dnTidLvl, denseLoopCond] : dnConds) {
- auto [tid, lvl] = unpackTensorLevel(dnTidLvl);
- assert(isDenseLT(lvlTypes[tid][lvl]));
-
- if (isAffineIdxCond(denseLoopCond)) {
- // Pushes sliced levels to build correct LoopInfo.
- bool unReduc = isAffineIdxUnRedCond(denseLoopCond);
- SliceInfo &info = sliceStack[tid].back();
- // Pushes sliced dense loop info to tell LoopEmitter how to exit it.
- sliceInfo.emplace_back(tid, lvl, /*fullyReduced=*/!unReduc);
- // FIXME: The offset and position iterator need to be adjusted when the
- // slice is strided.
- if (unReduc) {
- assert(*info.slicedOnLvl == lvl);
- unsigned depth = sliceStack[tid].back().depth;
- assert(depth >= 1);
- unsigned stride = sliceMeta[tid][lvl][depth - 1].second;
- // Update the slice information as we enter the new loop.
- info.minCrd = info.offset = MULI(iv, C_IDX(stride));
- info.isNonEmpty = constantI1(builder, loc, true);
- } else {
- posits[tid][lvl] =
- genAddress(builder, loc, tid, lvl, ADDI(info.offset, iv));
- Value fwdCnt = lvl == 0 || trivialSlice[tid][lvl]
- ? C_IDX(0)
- : sliceTupleFwdCnt[tid][lvl - 1];
- Value sz = sliceMeta[tid][lvl].back().first;
- Value mul = MULI(fwdCnt, sz);
- sliceTupleFwdCnt[tid][lvl] = ADDI(mul, iv);
- }
- levelReducedDep[tid][lvl]++;
- } else {
- // Skips the synthetic tensor
- if (isSynTensor(tid))
- continue;
- // A dense level with trivial index expression.
- assert(dependentLvlMap[tid][lvl].empty());
- auto enc = getSparseTensorEncoding(tensors[tid].getType());
- if (enc && !isSparseOutput(tid)) {
- bool validPos = lvl == 0 || posits[tid][lvl - 1];
- if (!validPos) {
- // We might not find the pos for the sparse output tensor as it is
- // unconditionally required by the sparsification.
- assert(isOutputTensor(tid));
- continue;
- }
- posits[tid][lvl] = genAddress(builder, loc, tid, lvl, iv);
- // NOTE: we can also prepare for next lvl here in advance
- }
- }
+ OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> raIters,
+ Value crd, SmallVectorImpl<SliceLoopInfo> &sliceInfo) {
+ for (SparseIterator *it : raIters) {
+ it->locate(builder, loc, crd);
+ posits[it->tid][it->lvl] = it->getItVals().front();
}
+ // for (auto [dnTidLvl, denseLoopCond] : dnConds) {
+ // auto [tid, lvl] = unpackTensorLevel(dnTidLvl);
+ // assert(isDenseLT(lvlTypes[tid][lvl]));
+
+ // if (isAffineIdxCond(denseLoopCond)) {
+ // // Pushes sliced levels to build correct LoopInfo.
+ // bool unReduc = isAffineIdxUnRedCond(denseLoopCond);
+ // SliceInfo &info = sliceStack[tid].back();
+ // // Pushes sliced dense loop info to tell LoopEmitter how to exit it.
+ // sliceInfo.emplace_back(tid, lvl, /*fullyReduced=*/!unReduc);
+ // // FIXME: The offset and position iterator need to be adjusted when the
+ // // slice is strided.
+ // if (unReduc) {
+ // assert(*info.slicedOnLvl == lvl);
+ // unsigned depth = sliceStack[tid].back().depth;
+ // assert(depth >= 1);
+ // unsigned stride = sliceMeta[tid][lvl][depth - 1].second;
+ // // Update the slice information as we enter the new loop.
+ // info.minCrd = info.offset = MULI(iv, C_IDX(stride));
+ // info.isNonEmpty = constantI1(builder, loc, true);
+ // } else {
+ // posits[tid][lvl] =
+ // genAddress(builder, loc, tid, lvl, ADDI(info.offset, iv));
+ // Value fwdCnt = lvl == 0 || trivialSlice[tid][lvl]
+ // ? C_IDX(0)
+ // : sliceTupleFwdCnt[tid][lvl - 1];
+ // Value sz = sliceMeta[tid][lvl].back().first;
+ // Value mul = MULI(fwdCnt, sz);
+ // sliceTupleFwdCnt[tid][lvl] = ADDI(mul, iv);
+ // }
+ // levelReducedDep[tid][lvl]++;
+ // } else {
+ // // Skips the synthetic tensor
+ // if (isSynTensor(tid))
+ // continue;
+ // // A dense level with trivial index expression.
+ // assert(dependentLvlMap[tid][lvl].empty());
+ // auto enc = getSparseTensorEncoding(tensors[tid].getType());
+ // if (enc && !isSparseOutput(tid)) {
+ // bool validPos = lvl == 0 || posits[tid][lvl - 1];
+ // if (!validPos) {
+ // // We might not find the pos for the sparse output tensor as it is
+ // // unconditionally required by the sparsification.
+ // assert(isOutputTensor(tid));
+ // continue;
+ // }
+ // posits[tid][lvl] = genAddress(builder, loc, tid, lvl, iv);
+ // // NOTE: we can also prepare for next lvl here in advance
+ // }
+ // }
+ // }
}
void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
@@ -1457,6 +1535,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
unsigned o = 0;
SmallVector<Value> operands;
unsigned delta = 0;
+ ValueRange whileRes = whileOp.getResults();
for (auto [tid, lvl, resolved] : loopInfo.sliceDrivenInfo) {
// TODO: handle dense.
assert(isCompressedLT(lvlTypes[tid][lvl]));
@@ -1499,34 +1578,30 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
};
for (auto [tid, lvl] : unpackTensorLevelRange(loopInfo.trivialTidLvls)) {
- const auto lvlTp = lvlTypes[tid][lvl];
- if (isCompressedLT(lvlTp) || isSingletonLT(lvlTp) ||
- isLooseCompressedLT(lvlTp)) {
- const Value crd = coords[tid][lvl];
- const Value pos = posits[tid][lvl];
- Value cmp = CMPI(eq, crd, iv);
- // If the loop contains a coiteration with non-unique level, we fast
- // forward all the duplicated coords by setting the position to the
- // segment high.
- Value add =
- !isUniqueLT(lvlTypes[tid][lvl]) ? segHi[tid][lvl] : ADDI(pos, one);
-
- operands.push_back(SELECT(cmp, add, pos));
+ SparseIterator &it = *iters[tid][lvl].back();
+ if (!it.randomAccessible()) {
+ // Forward the sparse iterator.
+ Value cmp = CMPI(eq, it.getCrd(), iv);
+ it.forwardIf(builder, loc, cmp);
+ operands.append(it.getItVals().begin(), it.getItVals().end());
+ o += it.getItVals().size();
+ // const Value newPos = whileOp->getResult(o++);
// Following loops continue iteration from the break point of the
// current while loop.
- const Value newPos = whileOp->getResult(o++);
- // We need to define a new local variable for `tid` to avoid
- // warnings about "captured structured bindings are a C++20 extension".
- // FIXME(wrengr): define a helper function to capture this idiom!
- const TensorId newTid = tid;
- posits[newTid][lvl] = newPos;
-
- // The coordinate is invalid now.
- coords[tid][lvl] = nullptr;
- // The segment high is invalid now.
- segHi[tid][lvl] = nullptr;
- // highs remains unchanged.
+ whileRes = it.linkNewScope(whileRes);
+ } else {
+ // Make sure randomly accessible (dense) iterator is set to the right
+ // position according to the universal index.
+ Value uniIdx = whileOp.getResults().back();
+ it.locate(builder, loc, uniIdx);
}
+
+ posits[tid][lvl] = it.getItVals().front();
+ // The coordinate is invalid now.
+ coords[tid][lvl] = nullptr;
+ // The segment high is invalid now.
+ segHi[tid][lvl] = nullptr;
+ // highs remains unchanged.
}
// Reduction value from users.
@@ -1798,7 +1873,7 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
lbs.push_back(offset);
ubs.push_back(ADDI(offset, sliceSz));
steps.push_back(c1);
- lvlSzs.push_back(lvlSizes[tid][sliceLvl]);
+ lvlSzs.push_back(lvls[tid][sliceLvl]->size());
}
auto denseNest =
scf::buildLoopNest(builder, loc, lbs, ubs, steps, innerArgs,
@@ -1938,7 +2013,7 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
Value sPtrBuf = slicePosBuffer[tid][lvl].back();
SmallVector<Value, 3> reduc = {
constantI1(builder, loc, false), // isNonEmpty
- lvlSizes[tid][lvl], // minCoord
+ lvls[tid][lvl]->size(), // minCoord
c0, // memSize
};
@@ -2108,7 +2183,7 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
reduc[2] = absOffset; // restore value.
Value mSz = info.posTupleNum; // tuple number.
- reduc[0] = lvlSizes[tid][lvl]; // next min coord
+ reduc[0] = lvls[tid][lvl]->size(); // next min coord
reduc[1] = constantI1(builder, loc, false); // isNonEmpty
auto loopArgs = static_cast<ValueRange>(reduc).drop_back();
auto forOp = scf::buildLoopNest(
@@ -2216,7 +2291,7 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
// FIXME: this only works if there is only one parent.
assert(info.depth - 1 == 0);
// nextNonEmpty = nextNonEmpty && slice upper bound <= parent upperbound.
- nextNonEmpty = ANDI(nextNonEmpty, CMPI(ule, sliceUB, lvlSizes[tid][lvl]));
+ nextNonEmpty = ANDI(nextNonEmpty, CMPI(ule, sliceUB, lvls[tid][lvl]->size()));
// FIXME: compute relative offset.
assert(info.depth - 1 == 0);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index 450678924c138e..4d0ba11cacfc77 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -95,7 +95,7 @@ class LoopEmitter {
/// Starts a loop emitting session by generating all the buffers needed
/// for iterating over the tensors.
- void initializeLoopEmit(OpBuilder &builder, Location loc,
+ void initializeLoopEmit(OpBuilder &builder, Location loc, bool genDedup,
OutputUpdater updater = nullptr,
SynTensorBoundSetter synSetter = nullptr);
@@ -153,7 +153,7 @@ class LoopEmitter {
Operation *enterCoIterationOverTensorsAtLvls(
OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
MutableArrayRef<Value> reduc = {}, bool isParallel = false,
- bool genDedup = false, bool needsUniv = false);
+ bool needsUniv = false);
/// Generates code to exit the current loop (e.g., generates yields, forwards
/// loop induction variables, etc).
@@ -310,6 +310,7 @@ class LoopEmitter {
///
/// Enums for different kinds of loop conditions.
+ /// TODO: remove the enum after fully migrating to SparseTensorLevel.
///
// The bit indicating whether the loop conditions is sparse.
@@ -392,6 +393,9 @@ class LoopEmitter {
SmallVectorImpl<TensorLvlCond> &dnConds,
SmallVectorImpl<TensorLvlCond> &spConds);
+ void categorizeIterators(ArrayRef<TensorLevel> tidLvls,
+ SmallVectorImpl<SparseIterator *> &raIters,
+ SmallVectorImpl<SparseIterator *> &spIters);
///
/// LoopEmitter internal helper functions.
///
@@ -400,7 +404,7 @@ class LoopEmitter {
MutableArrayRef<Value>)>;
/// Whether the list of the sparse condition should be iterated by for loop.
- bool shouldIteratedByForLoop(ArrayRef<TensorLvlCond> spConds, bool genDedup);
+ bool shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters);
/// Linearizes address for dense dimension (i.e., p = (i * d0) + j).
Value genAddress(OpBuilder &builder, Location loc, TensorId tid, Level lvl,
@@ -441,7 +445,7 @@ class LoopEmitter {
}
bool isValidLevel(TensorId tid, Level lvl) const {
- return tid < lvlTypes.size() && lvl < lvlTypes[tid].size();
+ return tid < lvls.size() && lvl < lvls[tid].size();
}
/// Prepares loop for iterating over `tensor[lvl]`, under the assumption
@@ -453,7 +457,7 @@ class LoopEmitter {
/// optimized from the loop condition, we need to compute the
/// positions/coordinates inside the loop body.
void enterTensorsAtDenseLvls(OpBuilder &builder, Location loc,
- ArrayRef<TensorLvlCond> dnConds, Value iv,
+ ArrayRef<SparseIterator *> dnConds, Value iv,
SmallVectorImpl<SliceLoopInfo> &sliceInfo);
/// Emits a for loop to iterate over a tensor level with the provided
@@ -463,9 +467,9 @@ class LoopEmitter {
/// Returns a pair: the loop generated and the value for the induction
/// variable.
std::pair<Operation *, Value>
- emitForLoopOverTensorAtLvl(OpBuilder &builder, Location loc, TensorId tid,
- Level lvl, Value lo, Value hi,
- MutableArrayRef<Value> reduc, bool isParallel);
+ emitForLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
+ SparseIterator &iter, MutableArrayRef<Value> reduc,
+ bool isParallel);
/// Emits a while loop to co-iterate over a list of sparse condition, or
/// (complex) single sparse condition that can not be handled by for loop
@@ -475,7 +479,7 @@ class LoopEmitter {
/// iterated).
std::pair<Operation *, Value>
emitWhileLoopOverTensorsAtLvls(OpBuilder &builder, Location loc,
- ArrayRef<TensorLvlCond> spConds,
+ ArrayRef<SparseIterator *> iters,
MutableArrayRef<Value> reduc, bool needsUniv);
/// Generates the while loop condition for the given tensor level condition.
@@ -530,6 +534,8 @@ class LoopEmitter {
// Slice-driven loop related methods.
//
+ void initSubSectIterator(OpBuilder &builder, Location loc);
+ // TODO: remove below.
void initSliceDriven(OpBuilder &builder, Location loc);
/// Retrieves the most recent slice on lvl. To reduce affine expression like
@@ -602,6 +608,10 @@ class LoopEmitter {
/// return true if has already been resolved.
bool genSliceBegin(OpBuilder &builder, Location loc, TensorId tid, Level lvl);
+ std::unique_ptr<SparseIterator> makeLevelIterator(OpBuilder &builder,
+ Location loc, TensorId tid,
+ Level l, bool genDedup);
+
/// Generates code to get the next non-empty slices of tid on lvl.
/// Returns a tuple of values for <NonEmpty, MinCrd, AbsOffset> (see
/// SliceInfo) respectively.
@@ -622,15 +632,18 @@ class LoopEmitter {
//
// Fields which have `numTensor` many entries.
//
- // TODO: switch to an AOS style to avoid any possible mismatches.
- //
/// Input and (optional) output tensors.
std::vector<Value> tensors;
+ std::vector<std::vector<std::unique_ptr<SparseTensorLevel>>> lvls;
+ std::vector<std::vector<std::vector<std::unique_ptr<SparseIterator>>>> iters;
+ std::vector<Value> valBuffer; // to_value
+
+ // TODO: remove all below.
/// Level-types for each `(TensorId, Level)` pair.
- std::vector<std::vector<LevelType>> lvlTypes;
// Sparse iteration information for each `(TensorId, Level)` pair.
// These arrays are updated to remain current within the current loop.
+ std::vector<std::vector<LevelType>> lvlTypes;
std::vector<std::vector<Value>> posits;
/// The collection of coordinates for a given element (one such
/// collection for each tensor).
@@ -639,8 +652,7 @@ class LoopEmitter {
std::vector<std::vector<Value>> segHi;
std::vector<std::vector<Value>> highs;
std::vector<std::vector<Value>> lvlSizes;
- std::vector<std::vector<std::unique_ptr<SparseTensorLevel>>> lvls;
- std::vector<Value> valBuffer; // to_value
+ bool genDedup; // TODO: remove it.
//
// Slice-driven loops related fields.
@@ -659,8 +671,8 @@ class LoopEmitter {
// The cached position buffer for the slices, they serve the same purpose as
// ptrBuffer for compressed dimensions.
- // But they always starts with the first pidx pointing to coord > slice.offset
- // to avoid iteration from the beginning.
+ // But they always starts with the first pidx pointing to coord >
+ // slice.offset to avoid iteration from the beginning.
std::vector<std::vector<std::vector<Value>>> slicePosBuffer;
std::vector<std::vector<Value>> sliceTupleNxStartIdx;
std::vector<std::vector<Value>> sliceTupleFwdCnt;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index aea0910d980ab7..58cdbd1645eff2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -9,11 +9,14 @@
#include "SparseTensorLevel.h"
#include "CodegenUtils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
using namespace mlir;
using namespace mlir::sparse_tensor;
using ValuePair = std::pair<Value, Value>;
+using ValueTuple = std::tuple<Value, Value, Value>;
//===----------------------------------------------------------------------===//
// File local helper functions/macros.
@@ -31,8 +34,44 @@ using ValuePair = std::pair<Value, Value>;
#define DIVUI(lhs, rhs) (b.create<arith::DivUIOp>(l, (lhs), (rhs)))
#define SELECT(c, lhs, rhs) (b.create<arith::SelectOp>(l, (c), (lhs), (rhs)))
-static ValuePair constantRange(OpBuilder &b, Location l, Value lo, Value sz) {
- return std::make_pair(lo, ADDI(lo, sz));
+// Helper functions that load/store into the position buffer for slice-driven
+// loops.
+static constexpr unsigned kSliceIterWidth = 3;
+// The sliced pointer buffer is organized as:
+// [[pLo0, pLo1, pLo2, ...],
+// [pHi0, pHi1, pHi2, ...],
+// [pNx0, pNx1, pNx2, ...]]
+static Value allocSlicePosBuf(OpBuilder &b, Location l, Value tupleCnt) {
+ Value bufSz = MULI(tupleCnt, C_IDX(kSliceIterWidth));
+ // Additional two metadata {memSize, idx} at head.
+ return genAlloca(b, l, bufSz, b.getIndexType());
+}
+
+// Gets and sets position values for slice-driven loops.
+enum class SlicePosKind { kLo, kHi, kNext };
+static Value getSlicePosIdx(OpBuilder &b, Location l, Value posBuf,
+ Value tupleIdx, SlicePosKind posKind) {
+ Value dim = b.create<memref::DimOp>(l, posBuf, C_IDX(0));
+ Value tupleCnt = DIVUI(dim, C_IDX(kSliceIterWidth));
+ switch (posKind) {
+ case SlicePosKind::kLo:
+ return tupleIdx;
+ case SlicePosKind::kHi:
+ return ADDI(tupleIdx, tupleCnt);
+ case SlicePosKind::kNext:
+ return ADDI(tupleIdx, MULI(tupleCnt, C_IDX(2)));
+ }
+ llvm_unreachable("unexpected kind");
+}
+static Value loadSlicePos(OpBuilder &b, Location l, Value sPosBuf,
+ Value tupleIdx, SlicePosKind posKind) {
+ return genIndexLoad(b, l, sPosBuf,
+ getSlicePosIdx(b, l, sPosBuf, tupleIdx, posKind));
+}
+static void updateSlicePos(OpBuilder &b, Location l, Value sPosBuf, Value pos,
+ Value tupleIdx, SlicePosKind posKind) {
+ b.create<memref::StoreOp>(l, pos, sPosBuf,
+ getSlicePosIdx(b, l, sPosBuf, tupleIdx, posKind));
}
//===----------------------------------------------------------------------===//
@@ -43,11 +82,12 @@ namespace {
class SparseLevel : public SparseTensorLevel {
public:
- SparseLevel(LevelType lt, Value lvlSize, Value crdBuffer)
- : SparseTensorLevel(lt, lvlSize), crdBuffer(crdBuffer) {}
+ SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+ Value crdBuffer)
+ : SparseTensorLevel(tid, lvl, lt, lvlSize), crdBuffer(crdBuffer) {}
- Value peekCrdAt(OpBuilder &b, Location l, Value pos) const override {
- return genIndexLoad(b, l, crdBuffer, pos);
+ Value peekCrdAt(OpBuilder &b, Location l, Value iv) const override {
+ return genIndexLoad(b, l, crdBuffer, iv);
}
protected:
@@ -56,10 +96,9 @@ class SparseLevel : public SparseTensorLevel {
class DenseLevel : public SparseTensorLevel {
public:
- DenseLevel(Value lvlSize) : SparseTensorLevel(LevelType::Dense, lvlSize) {
- // Dense level, loop upper bound equals to the level size.
- loopHi = lvlSize;
- }
+ DenseLevel(unsigned tid, Level lvl, Value lvlSize, bool encoded)
+ : SparseTensorLevel(tid, lvl, LevelType::Dense, lvlSize),
+ encoded(encoded) {}
Value peekCrdAt(OpBuilder &, Location, Value pos) const override {
return pos;
@@ -68,14 +107,22 @@ class DenseLevel : public SparseTensorLevel {
ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
Value max) const override {
assert(max == nullptr && "Dense level can not be non-unique.");
- return constantRange(b, l, C_IDX(0), lvlSize);
+ if (encoded) {
+ Value posLo = MULI(p, lvlSize);
+ return {posLo, lvlSize};
+ }
+ // No need to linearize the position for non-annotated tensors.
+ return {C_IDX(0), lvlSize};
}
+
+ const bool encoded;
};
class CompressedLevel : public SparseLevel {
public:
- CompressedLevel(LevelType lt, Value lvlSize, Value posBuffer, Value crdBuffer)
- : SparseLevel(lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+ CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+ Value posBuffer, Value crdBuffer)
+ : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
Value max) const override {
@@ -84,7 +131,7 @@ class CompressedLevel : public SparseLevel {
Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
return {pLo, pHi};
}
- llvm_unreachable("TODO: dedup not implemented");
+ llvm_unreachable("compressed-nu should be the first non-unique level.");
}
private:
@@ -93,15 +140,13 @@ class CompressedLevel : public SparseLevel {
class LooseCompressedLevel : public SparseLevel {
public:
- LooseCompressedLevel(LevelType lt, Value lvlSize, Value posBuffer,
- Value crdBuffer)
- : SparseLevel(lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+ LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+ Value posBuffer, Value crdBuffer)
+ : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
Value max) const override {
- // Allows this?
assert(max == nullptr && "loss compressed level can not be non-unique.");
-
p = MULI(p, C_IDX(2));
Value pLo = genIndexLoad(b, l, posBuffer, p);
Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
@@ -114,68 +159,321 @@ class LooseCompressedLevel : public SparseLevel {
class SingletonLevel : public SparseLevel {
public:
- SingletonLevel(LevelType lt, Value lvlSize, Value crdBuffer)
- : SparseLevel(lt, lvlSize, crdBuffer) {}
+ SingletonLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+ Value crdBuffer)
+ : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
- Value max) const override {
- if (max == nullptr)
- return constantRange(b, l, p, C_IDX(1));
- llvm_unreachable("TODO: dedup not implemented");
+ Value segHi) const override {
+ if (segHi == nullptr)
+ return {p, ADDI(p, C_IDX(1))};
+
+ // Use the segHi as the loop upper bound.
+ return {p, segHi};
}
};
class TwoOutFourLevel : public SparseLevel {
public:
- TwoOutFourLevel(LevelType lt, Value lvlSize, Value crdBuffer)
- : SparseLevel(lt, lvlSize, crdBuffer) {}
+ TwoOutFourLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+ Value crdBuffer)
+ : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
Value max) const override {
- assert(max == nullptr && "2:4 level can not be non-unique.");
- // Each 2:4 block has exactly two specified elements.
- Value c2 = C_IDX(2);
- return constantRange(b, l, MULI(p, c2), c2);
+ assert(max == nullptr && isUnique() && "2:4 level can not be non-unique.");
+ // Each 2:4 blk has exactly two specified elements.
+ Value posLo = MULI(p, C_IDX(2));
+ return {posLo, ADDI(posLo, C_IDX(2))};
}
};
} // namespace
+//===----------------------------------------------------------------------===//
+// SparseIterator derived classes.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+class TrivialIterator : public SparseIterator {
+ Value getLoopLo(OpBuilder &b, Location l) const {
+ // Dense loop are traversed by coordinate, delinearize the position to get
+ // the coordinate.
+ if (randomAccessible())
+ return SUBI(itPos, posLo);
+ return itPos;
+ }
+
+public:
+ TrivialIterator(const SparseTensorLevel &stl,
+ const IterKind kind = IterKind::kTrivial)
+ : SparseIterator(kind, stl.tid, stl.lvl, itPos), stl(stl) {}
+
+ // For LLVM-style RTTI.
+ static bool classof(const SparseIterator *from) {
+ return from->kind == IterKind::kTrivial;
+ }
+
+ bool randomAccessible() const override { return isDenseLT(stl.getLT()); };
+ bool iteratableByFor() const override { return true; };
+
+ ValuePair peekNxLvlRange(OpBuilder &b, Location l,
+ const SparseTensorLevel &stl) const override {
+ assert(stl.tid == this->tid && stl.lvl - 1 == this->lvl);
+ return stl.peekRangeAt(b, l, itPos);
+ }
+
+ void genInit(OpBuilder &b, Location l,
+ const SparseIterator *parent) override {
+ if (parent)
+ std::tie(posLo, loopHi) = parent->peekNxLvlRange(b, l, stl);
+ else
+ std::tie(posLo, loopHi) = stl.peekRangeAt(b, l, C_IDX(0));
+
+ // Only randomly accessible iterator's position need to be linearized.
+ seek(posLo);
+ }
+
+ ValuePair genForCond(OpBuilder &b, Location l) override {
+ assert(iteratableByFor());
+ return std::make_pair(getLoopLo(b, l), loopHi);
+ }
+
+ Value genIsEnd(OpBuilder &b, Location l) override {
+ // We used the first level bound as the bound the collapsed set of levels.
+ return CMPI(ult, itPos, loopHi);
+ }
+
+ Value deref(OpBuilder &b, Location l) override {
+ updateCrd(stl.peekCrdAt(b, l, itPos));
+ return getCrd();
+ };
+
+ ValueRange forward(OpBuilder &b, Location l) override {
+ seek(ADDI(itPos, C_IDX(1)).getResult());
+ return getItVals();
+ }
+
+ void locate(OpBuilder &b, Location l, Value crd) override {
+ assert(randomAccessible());
+ // Seek to the linearized position.
+ seek(ADDI(crd, posLo).getResult());
+ updateCrd(crd);
+ }
+
+ Value itPos; // the position that represent the iterator
+
+ Value posLo, loopHi;
+ const SparseTensorLevel &stl;
+};
+
+class DedupIterator : public SparseIterator {
+private:
+ Value genSegmentHigh(OpBuilder &b, Location l, Value pos);
+
+public:
+ DedupIterator(const SparseTensorLevel &stl)
+ : SparseIterator(IterKind::kDedup, stl.tid, stl.lvl, posAndSegHi),
+ stl(stl) {
+ assert(!stl.isUnique());
+ }
+ // For LLVM-style RTTI.
+ static bool classof(const SparseIterator *from) {
+ return from->kind == IterKind::kDedup;
+ }
+
+ bool randomAccessible() const override { return false; };
+ bool iteratableByFor() const override { return false; };
+
+ ValuePair peekNxLvlRange(OpBuilder &b, Location l,
+ const SparseTensorLevel &stl) const override {
+ assert(stl.tid == this->tid && stl.lvl - 1 == this->lvl);
+ return stl.peekRangeAt(b, l, getPos(), getSegHi());
+ }
+
+ void genInit(OpBuilder &b, Location l,
+ const SparseIterator *parent) override {
+ Value posLo;
+
+ if (parent)
+ std::tie(posLo, loopHi) = parent->peekNxLvlRange(b, l, stl);
+ else
+ std::tie(posLo, loopHi) = stl.peekRangeAt(b, l, C_IDX(0));
+
+ seek({posLo, genSegmentHigh(b, l, posLo)});
+ }
+
+ Value genIsEnd(OpBuilder &b, Location l) override {
+ return CMPI(ult, getPos(), loopHi);
+ }
+
+ Value deref(OpBuilder &b, Location l) override {
+ updateCrd(stl.peekCrdAt(b, l, getPos()));
+ return getCrd();
+ };
+
+ ValueRange forward(OpBuilder &b, Location l) override {
+ Value nxPos = getSegHi(); // forward the position to the next segment.
+ seek({nxPos, genSegmentHigh(b, l, nxPos)});
+ return getItVals();
+ }
+
+ Value getPos() const { return posAndSegHi[0]; }
+ Value getSegHi() const { return posAndSegHi[1]; }
+
+ Value loopHi;
+ Value posAndSegHi[2]; // position and segment high
+ const SparseTensorLevel &stl;
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// SparseIterator derived classes impl.
+//===----------------------------------------------------------------------===//
+
+ValueRange SparseIterator::forwardIf(OpBuilder &b, Location l, Value cond) {
+ auto ifOp = b.create<scf::IfOp>(l, getItVals().getTypes(), cond, true);
+ // Generate else branch first, otherwise iterator values will be updated by
+ // `forward()`.
+ b.setInsertionPointToStart(ifOp.elseBlock());
+ YIELD(getItVals());
+
+ b.setInsertionPointToStart(ifOp.thenBlock());
+ YIELD(forward(b, l));
+
+ b.setInsertionPointAfter(ifOp);
+ seek(ifOp.getResults());
+ return getItVals();
+}
+
+Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) {
+ auto whileOp = b.create<scf::WhileOp>(
+ l, pos.getType(), pos,
+ /*beforeBuilder=*/
+ [this, pos](OpBuilder &b, Location l, ValueRange ivs) {
+ Value inBound = CMPI(ult, ivs.front(), loopHi);
+ auto ifInBound = b.create<scf::IfOp>(l, b.getI1Type(), inBound, true);
+ {
+ OpBuilder::InsertionGuard guard(b);
+ // If in bound, load the next coordinates and check duplication.
+ b.setInsertionPointToStart(ifInBound.thenBlock());
+ Value headCrd = stl.peekCrdAt(b, l, pos);
+ Value tailCrd = stl.peekCrdAt(b, l, ivs.front());
+ Value isDup = CMPI(eq, headCrd, tailCrd);
+ YIELD(isDup);
+ // Else, the position is out of bound, yield false.
+ b.setInsertionPointToStart(ifInBound.elseBlock());
+ YIELD(constantI1(b, l, false));
+ }
+ b.create<scf::ConditionOp>(l, ifInBound.getResults()[0], ivs);
+ },
+ /*afterBuilder=*/
+ [](OpBuilder &b, Location l, ValueRange ivs) {
+ // pos ++
+ Value nxPos = ADDI(ivs[0], C_IDX(1));
+ YIELD(nxPos);
+ });
+ // Return the segment high.
+ return whileOp.getResult(0);
+}
+
+Value FilterIterator::genShouldFilter(OpBuilder &b, Location l) {
+ Value end = wrap->genIsEnd(b, l);
+
+ auto shouldFilter = b.create<scf::IfOp>(l, b.getI1Type(), end, true);
+ // it.end() ? false : should_filter(*it);
+ b.setInsertionPointToStart(shouldFilter.thenBlock());
+ YIELD(constantI1(b, l, false));
+
+ // Iterator not at the end.
+ b.setInsertionPointToStart(shouldFilter.elseBlock());
+ Value wrapCrd = wrap->deref(b, l);
+ Value crd = fromWrapCrd(b, l, wrapCrd);
+ // on stride
+ Value legit = CMPI(eq, toWrapCrd(b, l, crd), wrapCrd);
+ // wrapCrd >= offset
+ legit = ANDI(CMPI(uge, wrapCrd, offset), legit);
+ // crd < length
+ legit = ANDI(CMPI(ult, crd, size), legit);
+ YIELD(legit);
+
+ b.setInsertionPointAfter(shouldFilter);
+ return shouldFilter.getResult(0);
+}
+
std::unique_ptr<SparseTensorLevel>
-sparse_tensor::makeSparseTensorLevel(OpBuilder &builder, Location loc, Value t,
- Level l) {
+sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
+ unsigned tid, Level lvl) {
auto stt = getSparseTensorType(t);
- LevelType lt = stt.getLvlType(l);
- Value lvlSz = stt.hasEncoding()
- ? builder.create<LvlOp>(loc, t, l).getResult()
- : builder.create<tensor::DimOp>(loc, t, l).getResult();
+ LevelType lt = stt.getLvlType(lvl);
+ Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult()
+ : b.create<tensor::DimOp>(l, t, lvl).getResult();
switch (*getLevelFormat(lt)) {
case LevelFormat::Dense:
- return std::make_unique<DenseLevel>(lvlSz);
+ return std::make_unique<DenseLevel>(tid, lvl, sz, stt.hasEncoding());
case LevelFormat::Compressed: {
- Value posBuf = genToPositions(builder, loc, t, l);
- Value crdBuf = genToCoordinates(builder, loc, t, l);
- return std::make_unique<CompressedLevel>(lt, lvlSz, posBuf, crdBuf);
+ Value pos = genToPositions(b, l, t, lvl);
+ Value crd = genToCoordinates(b, l, t, lvl);
+ return std::make_unique<CompressedLevel>(tid, lvl, lt, sz, pos, crd);
}
case LevelFormat::LooseCompressed: {
- Value posBuf = genToPositions(builder, loc, t, l);
- Value crdBuf = genToCoordinates(builder, loc, t, l);
- return std::make_unique<LooseCompressedLevel>(lt, lvlSz, posBuf, crdBuf);
+ Value pos = genToPositions(b, l, t, lvl);
+ Value crd = genToCoordinates(b, l, t, lvl);
+ return std::make_unique<LooseCompressedLevel>(tid, lvl, lt, sz, pos, crd);
}
case LevelFormat::Singleton: {
- Value crdBuf = genToCoordinates(builder, loc, t, l);
- return std::make_unique<SingletonLevel>(lt, lvlSz, crdBuf);
+ Value crd = genToCoordinates(b, l, t, lvl);
+ return std::make_unique<SingletonLevel>(tid, lvl, lt, sz, crd);
}
case LevelFormat::TwoOutOfFour: {
- Value crdBuf = genToCoordinates(builder, loc, t, l);
- return std::make_unique<TwoOutFourLevel>(lt, lvlSz, crdBuf);
+ Value crd = genToCoordinates(b, l, t, lvl);
+ return std::make_unique<TwoOutFourLevel>(tid, lvl, lt, sz, crd);
}
}
llvm_unreachable("unrecognizable level format");
}
+std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
+sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl) {
+ auto stl = std::make_unique<DenseLevel>(tid, lvl, sz, /*encoded=*/false);
+ auto it = std::make_unique<TrivialIterator>(*stl);
+ return std::make_pair(std::move(stl), std::move(it));
+}
+
+std::unique_ptr<SparseIterator>
+sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl, bool dedup) {
+ dedup = dedup && !isUniqueLT(stl.getLT());
+ if (dedup)
+ return std::make_unique<DedupIterator>(stl);
+ return std::make_unique<TrivialIterator>(stl);
+}
+
+std::unique_ptr<SparseIterator>
+sparse_tensor::makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit,
+ Value offset, Value stride, Value size) {
+ return nullptr;
+ // return std::make_unique<FilterIterator>(std::move(sit), offset, stride,
+ // size);
+}
+
+std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
+ OpBuilder &b, Location l, const SparseIterator *parent,
+ std::unique_ptr<SparseIterator> &&lvlIt, Value size, unsigned stride) {
+ return nullptr;
+ // return std::make_unique<NonEmptySubSectIterator>(
+ // b, l, parent, std::move(lvlIt), size, stride);
+}
+
+std::unique_ptr<SparseIterator> sparse_tensor::makeTraverseSubSectIterator(
+ const SparseIterator *parent, std::unique_ptr<SparseIterator> &&lvlIt) {
+ // return std::make_unique<SubSectIterator>(parent, std::move(lvlIt));
+ return nullptr;
+}
+
#undef CMPI
#undef C_IDX
#undef YIELD
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
index f5c29cda7c54f4..e6249c245b22ec 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
@@ -21,42 +21,203 @@ class SparseTensorLevel {
SparseTensorLevel &operator=(const SparseTensorLevel &) = delete;
public:
- SparseTensorLevel() : SparseTensorLevel(LevelType::Undef, nullptr){};
virtual ~SparseTensorLevel() = default;
- virtual Value peekCrdAt(OpBuilder &b, Location l, Value p) const = 0;
+ virtual Value peekCrdAt(OpBuilder &b, Location l, Value iv) const = 0;
/// Peeks the lower and upper bound to *fully* traverse the level with
/// the given position `p` that the immediate parent level is current at.
+ /// Returns a pair of values for *posLo* and *loopHi* respectively.
+ ///
+ /// For dense level, the *posLo* is the linearized position at beginning,
+ /// while *loopHi* is the largest *coordinate*, it also implies that the
+ /// smallest *coordinate* to start the loop is 0.
+ ///
+ /// For sparse level, [posLo, loopHi) specifies the range of index pointer to
+ /// load coordinate from the coordinate buffer.
+ ///
/// `bound` is only used when the level is `non-unique` and deduplication is
/// required. It specifies the max upper bound of the non-unique segment.
virtual std::pair<Value, Value> peekRangeAt(OpBuilder &b, Location l, Value p,
- Value bound = Value()) const = 0;
+ Value segHi = Value()) const = 0;
+ Level getLevel() const { return lvl; }
LevelType getLT() const { return lt; }
- Value getPos() const { return pos; }
- Value getCrd() const { return crd; }
- Value getLoopHi() const { return loopHi; }
- Value getLoopLo() const { return loopLo; }
+ Value size() const { return lvlSize; }
+
+ //
+ // Level properties
+ //
+ bool isUnique() const { return isUniqueLT(lt); }
protected:
- SparseTensorLevel(LevelType lt, Value lvlSize)
- : lt(lt), lvlSize(lvlSize), pos(nullptr), crd(nullptr), loopHi(nullptr),
- loopLo(nullptr){};
+ SparseTensorLevel(unsigned tid, unsigned lvl, LevelType lt, Value lvlSize)
+ : tid(tid), lvl(lvl), lt(lt), lvlSize(lvlSize){};
+public:
+ const unsigned tid, lvl;
const LevelType lt;
const Value lvlSize;
+};
-public: // TODO: make these values private upon feature complete.
- Value pos;
- Value crd;
- Value loopHi;
- Value loopLo;
+enum class IterKind : uint8_t {
+ kTrivial,
+ kDedup,
+ kSubSect,
+ kNonEmptySubSect,
+ kFilter,
+};
+
+/// Helper class that helps generating loop conditions, etc, to traverse a
+/// sparse tensor level.
+class SparseIterator {
+ SparseIterator(SparseIterator &&) = delete;
+ SparseIterator(const SparseIterator &) = delete;
+ SparseIterator &operator=(SparseIterator &&) = delete;
+ SparseIterator &operator=(const SparseIterator &) = delete;
+
+protected:
+ SparseIterator(IterKind kind, unsigned tid, unsigned lvl,
+ MutableArrayRef<Value> itVals)
+ : kind(kind), tid(tid), lvl(lvl), crd(nullptr), itVals(itVals){};
+
+ SparseIterator(IterKind kind, const SparseIterator *wrap)
+ : kind(kind), tid(wrap->tid), lvl(wrap->lvl), crd(nullptr),
+ itVals(wrap->itVals){};
+
+public:
+ virtual ~SparseIterator() = default;
+
+ Value getCrd() const { return crd; }
+
+ ValueRange getItVals() const { return itVals; };
+ void seek(ValueRange vals) {
+ assert(vals.size() == itVals.size());
+ for (unsigned i = 0, e = vals.size(); i < e; i++)
+ itVals[i] = vals[i];
+ // Now that the iterator is re-positioned, the coordinate becomes invalid.
+ crd = nullptr;
+ }
+
+ //
+ // Iterator properties.
+ //
+
+ // Whether the iterator support random access (i.e., support look up by
+ // *coordinate*).
+ // A random access iterator also traverses a dense space.
+ virtual bool randomAccessible() const = 0;
+ // Whether the iterator can simply traversed by a for loop.
+ virtual bool iteratableByFor() const { return false; };
+
+ //
+ // Core functions.
+ //
+
+ // Peeks the range to iterate on child level at the current position.
+ // See SparseTensorLevel::peekRangeAt();
+ //
+ // Not every type of iterator supports the operations, e.g., non-empty
+ // subsection iterator does not.
+ virtual std::pair<Value, Value>
+ peekNxLvlRange(OpBuilder &, Location, const SparseTensorLevel &) const {
+ llvm_unreachable("unsupported");
+ };
+
+ // Initialize the iterator according to the parent iterator's state.
+ virtual void genInit(OpBuilder &, Location, const SparseIterator *) = 0;
+
+ // Return a tuple of values for *upper*, *lower* bound and *step*
+ // respectively.
+ virtual std::pair<Value, Value> genForCond(OpBuilder &, Location) {
+ llvm_unreachable("Unsupported");
+ }
+
+ virtual Value genIsEnd(OpBuilder &b, Location l) = 0;
+ std::pair<Value, ValueRange> genWhileCond(OpBuilder &b, Location l,
+ ValueRange vs) {
+ seek(vs.take_front(itVals.size()));
+ return std::make_pair(genIsEnd(b, l), vs.drop_front(itVals.size()));
+ }
+
+ // Dereference the iterator, loads the coordinate at the current position.
+ //
+ // The method assumes that the iterator is not currently exhausted (i.e.,
+ // it != it.end()).
+ virtual Value deref(OpBuilder &b, Location l) = 0;
+
+ virtual ValueRange forward(OpBuilder &b, Location l) = 0;
+
+ // Generate a conditional it.next() in the following form
+ //
+ // if (crd == it.crd)
+ // yield it.next
+ // else
+ // yield it
+ //
+ // The function is virtual to allow alternative implementation. For example,
+ // if it.next() is trivial to compute, we can use a select operation instead.
+ // E.g.,
+ //
+ // it = select crd == it.crd ? it+1 : it
+ virtual ValueRange forwardIf(OpBuilder &b, Location l, Value cond);
+
+ // Locate the iterator to the position specified by *crd*, this can only
+ // be done on an iterator that supports randm access.
+ virtual void locate(OpBuilder &b, Location l, Value crd) {
+ llvm_unreachable("Unsupported");
+ }
+
+ // Update the SSA value for the iterator after entering a new scope.
+ ValueRange linkNewScope(ValueRange pos) {
+ assert(!randomAccessible() && "random accessible iterators are traversed "
+ "by coordinate, call locate() instead.");
+ seek(pos.take_front(itVals.size()));
+ return pos.drop_front(itVals.size());
+ };
+
+protected:
+ void updateCrd(Value crd) { this->crd = crd; }
+
+public:
+ const IterKind kind; // For LLVM-style RTTI.
+ const unsigned tid, lvl; // tensor level identifier.
+
+private:
+ Value crd; // The sparse coordinate used to coiterate;
+
+ // A range of value that together defines the current state of the
+ // iterator.
+ //
+ // For trivial iterators, it is the position; for dedup iterators, it consists
+ // of the positon and the segment high, for non-empty subsection iterator, it
+ // is the metadata that specifies the subsection.
+ MutableArrayRef<Value> itVals;
};
/// Helper function to create a TensorLevel object from given `tensor`.
-std::unique_ptr<SparseTensorLevel>
-makeSparseTensorLevel(OpBuilder &builder, Location loc, Value t, Level l);
+std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &builder,
+ Location loc, Value t,
+ unsigned tid, Level l);
+
+/// Helper function to create a SparseIterator object.
+std::unique_ptr<SparseIterator> makeSimpleIterator(const SparseTensorLevel &stl,
+ bool dedup);
+
+std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
+makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl);
+
+std::unique_ptr<SparseIterator>
+makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit, Value offset,
+ Value stride, Value size);
+
+std::unique_ptr<SparseIterator> makeNonEmptySubSectIterator(
+ OpBuilder &b, Location l, const SparseIterator *parent,
+ std::unique_ptr<SparseIterator> &&lvlIt, Value size, unsigned stride);
+
+std::unique_ptr<SparseIterator>
+makeTraverseSubSectIterator(const SparseIterator *parent,
+ std::unique_ptr<SparseIterator> &&lvlIt);
} // namespace sparse_tensor
} // namespace mlir
>From 5a289930c5f368d30078c6ad2ca2228f8ff7028c Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 5 Jan 2024 17:40:27 +0000
Subject: [PATCH 02/11] [mlir][sparse] setup FilterIterator to handle sparse
slices.
---
.../Transforms/SparseTensorRewriting.cpp | 20 +-
.../Transforms/Sparsification.cpp | 2 +-
.../Transforms/Utils/LoopEmitter.cpp | 12 +-
.../Transforms/Utils/LoopEmitter.h | 8 +-
.../Transforms/Utils/SparseTensorLevel.cpp | 324 ++++++++++++++----
.../Transforms/Utils/SparseTensorLevel.h | 16 +-
6 files changed, 288 insertions(+), 94 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 76df01800bda8e..d32b8520f38618 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1105,7 +1105,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
LoopEmitter loopEmitter(
ValueRange{input},
StringAttr::get(getContext(), ForeachOp::getOperationName()));
- loopEmitter.initializeLoopEmit(rewriter, loc, /*genDedup=*/false);
+ loopEmitter.initializeLoopEmit(rewriter, loc);
for (Level l = 0; l < lvlRank; l++) {
// TODO: provide utility function for loop sequences that only contains
// one for loop?
@@ -1148,17 +1148,17 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands();
rewriter.eraseOp(srcBlock->getTerminator());
- // Inline body.
- if (!reducValue.empty()) {
- rewriter.mergeBlocks(srcBlock, rewriter.getBlock(), args);
- } else {
- // This is annoying, since scf.for inserts a implicit yield op when
- // there is no reduction variable upon creation, in this case we need to
- // merge the block *before* the yield op.
- rewriter.inlineBlockBefore(srcBlock, &*rewriter.getInsertionPoint(),
- args);
+ Operation &last = rewriter.getBlock()->back();
+ if (llvm::isa<scf::YieldOp>(last)) {
+ // scf.for inserts a implicit yield op when there is no reduction
+ // variable upon creation, in this case we need to merge the block
+ // *before* the yield op.
+ rewriter.setInsertionPoint(&last);
}
+ rewriter.inlineBlockBefore(srcBlock, rewriter.getBlock(),
+ rewriter.getInsertionPoint(), args);
+ rewriter.setInsertionPointToEnd(rewriter.getBlock());
for (Level l = 0; l < lvlRank; l++) {
// Link the reduction chain. Note that loop emitter update the reducValue
// in place.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 9dd9cd42b7d3ad..918a0911b5e0aa 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -294,7 +294,7 @@ static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
.createLoopRanges(builder, loc);
env.emitter().initializeLoopEmit(
- builder, loc, /*genDedup=*/true,
+ builder, loc,
/// Generates buffer for the output tensor.
/// Note that all sparse kernels assume that when all elements are written
/// to (viz. x(i) = y(i) * z(i)), the output buffer is already initialized
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index a972de04db0a64..302b4da932017c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -410,8 +410,8 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
std::unique_ptr<SparseIterator>
LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
- Level l, bool genDedup) {
- auto it = makeSimpleIterator(*lvls[t][l], genDedup);
+ Level l) {
+ auto it = makeSimpleIterator(*lvls[t][l]);
if (isSparseSlices[t]) {
Value offset = genSliceOffset(builder, loc, tensors[t], l);
Value stride = genSliceStride(builder, loc, tensors[t], l);
@@ -426,10 +426,8 @@ LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
}
void LoopEmitter::initializeLoopEmit(
- OpBuilder &builder, Location loc, bool genDedup,
- LoopEmitter::OutputUpdater updater,
+ OpBuilder &builder, Location loc, LoopEmitter::OutputUpdater updater,
LoopEmitter::SynTensorBoundSetter synSetter) {
- this->genDedup = genDedup;
// For every synthetic tensor, set the high bound by calling the callback.
if (synSetter) {
TensorId synId = getSynTensorId();
@@ -478,7 +476,7 @@ void LoopEmitter::initializeLoopEmit(
if (!dependentLvlMap[t][l].empty())
continue;
- auto it = makeLevelIterator(builder, loc, t, l, genDedup);
+ auto it = makeLevelIterator(builder, loc, t, l);
iters[t][l].emplace_back(std::move(it));
}
@@ -550,7 +548,7 @@ void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) {
assert(curDep.first == loop);
remDepStack[t][lvl].pop_back();
- auto lvlIt = makeLevelIterator(builder, loc, t, lvl, genDedup);
+ auto lvlIt = makeLevelIterator(builder, loc, t, lvl);
const SparseIterator *parent =
lvl == 0 && iters[t][lvl].empty()
? nullptr
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index 4d0ba11cacfc77..9ab99f4feb5627 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -95,7 +95,7 @@ class LoopEmitter {
/// Starts a loop emitting session by generating all the buffers needed
/// for iterating over the tensors.
- void initializeLoopEmit(OpBuilder &builder, Location loc, bool genDedup,
+ void initializeLoopEmit(OpBuilder &builder, Location loc,
OutputUpdater updater = nullptr,
SynTensorBoundSetter synSetter = nullptr);
@@ -608,9 +608,8 @@ class LoopEmitter {
/// return true if has already been resolved.
bool genSliceBegin(OpBuilder &builder, Location loc, TensorId tid, Level lvl);
- std::unique_ptr<SparseIterator> makeLevelIterator(OpBuilder &builder,
- Location loc, TensorId tid,
- Level l, bool genDedup);
+ std::unique_ptr<SparseIterator>
+ makeLevelIterator(OpBuilder &builder, Location loc, TensorId tid, Level l);
/// Generates code to get the next non-empty slices of tid on lvl.
/// Returns a tuple of values for <NonEmpty, MinCrd, AbsOffset> (see
@@ -652,7 +651,6 @@ class LoopEmitter {
std::vector<std::vector<Value>> segHi;
std::vector<std::vector<Value>> highs;
std::vector<std::vector<Value>> lvlSizes;
- bool genDedup; // TODO: remove it.
//
// Slice-driven loops related fields.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 58cdbd1645eff2..26ddc9b50c107d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -22,17 +22,21 @@ using ValueTuple = std::tuple<Value, Value, Value>;
// File local helper functions/macros.
//===----------------------------------------------------------------------===//
#define CMPI(p, lhs, rhs) \
- (b.create<arith::CmpIOp>(l, arith::CmpIPredicate::p, (lhs), (rhs)))
+ (b.create<arith::CmpIOp>(l, arith::CmpIPredicate::p, (lhs), (rhs)) \
+ .getResult())
+#define C_FALSE (constantI1(b, l, false))
#define C_IDX(v) (constantIndex(b, l, (v)))
#define YIELD(vs) (b.create<scf::YieldOp>(l, (vs)))
-#define ADDI(lhs, rhs) (b.create<arith::AddIOp>(l, (lhs), (rhs)))
-#define ANDI(lhs, rhs) (b.create<arith::AndIOp>(l, (lhs), (rhs)))
-#define SUBI(lhs, rhs) (b.create<arith::SubIOp>(l, (lhs), (rhs)))
-#define MULI(lhs, rhs) (b.create<arith::MulIOp>(l, (lhs), (rhs)))
-#define REMUI(lhs, rhs) (b.create<arith::RemUIOp>(l, (lhs), (rhs)))
-#define DIVUI(lhs, rhs) (b.create<arith::DivUIOp>(l, (lhs), (rhs)))
-#define SELECT(c, lhs, rhs) (b.create<arith::SelectOp>(l, (c), (lhs), (rhs)))
+#define ADDI(lhs, rhs) (b.create<arith::AddIOp>(l, (lhs), (rhs)).getResult())
+#define ORI(lhs, rhs) (b.create<arith::OrIOp>(l, (lhs), (rhs)).getResult())
+#define ANDI(lhs, rhs) (b.create<arith::AndIOp>(l, (lhs), (rhs)).getResult())
+#define SUBI(lhs, rhs) (b.create<arith::SubIOp>(l, (lhs), (rhs)).getResult())
+#define MULI(lhs, rhs) (b.create<arith::MulIOp>(l, (lhs), (rhs)).getResult())
+#define REMUI(lhs, rhs) (b.create<arith::RemUIOp>(l, (lhs), (rhs)).getResult())
+#define DIVUI(lhs, rhs) (b.create<arith::DivUIOp>(l, (lhs), (rhs)).getResult())
+#define SELECT(c, lhs, rhs) \
+ (b.create<arith::SelectOp>(l, (c), (lhs), (rhs)).getResult())
// Helper functions that load/store into the position buffer for slice-driven
// loops.
@@ -218,20 +222,17 @@ class TrivialIterator : public SparseIterator {
bool randomAccessible() const override { return isDenseLT(stl.getLT()); };
bool iteratableByFor() const override { return true; };
- ValuePair peekNxLvlRange(OpBuilder &b, Location l,
- const SparseTensorLevel &stl) const override {
- assert(stl.tid == this->tid && stl.lvl - 1 == this->lvl);
- return stl.peekRangeAt(b, l, itPos);
- }
+ ValuePair getCurPosition() const override { return {itPos, nullptr}; }
void genInit(OpBuilder &b, Location l,
const SparseIterator *parent) override {
+ Value pos = C_IDX(0);
+ Value hi = nullptr;
if (parent)
- std::tie(posLo, loopHi) = parent->peekNxLvlRange(b, l, stl);
- else
- std::tie(posLo, loopHi) = stl.peekRangeAt(b, l, C_IDX(0));
+ std::tie(pos, hi) = parent->getCurPosition();
- // Only randomly accessible iterator's position need to be linearized.
+ std::tie(posLo, loopHi) = stl.peekRangeAt(b, l, pos, hi);
+ // Seek to the lowest position.
seek(posLo);
}
@@ -240,7 +241,7 @@ class TrivialIterator : public SparseIterator {
return std::make_pair(getLoopLo(b, l), loopHi);
}
- Value genIsEnd(OpBuilder &b, Location l) override {
+ Value genNotEnd(OpBuilder &b, Location l) override {
// We used the first level bound as the bound the collapsed set of levels.
return CMPI(ult, itPos, loopHi);
}
@@ -251,14 +252,14 @@ class TrivialIterator : public SparseIterator {
};
ValueRange forward(OpBuilder &b, Location l) override {
- seek(ADDI(itPos, C_IDX(1)).getResult());
+ seek(ADDI(itPos, C_IDX(1)));
return getItVals();
}
void locate(OpBuilder &b, Location l, Value crd) override {
assert(randomAccessible());
// Seek to the linearized position.
- seek(ADDI(crd, posLo).getResult());
+ seek(ADDI(crd, posLo));
updateCrd(crd);
}
@@ -286,26 +287,24 @@ class DedupIterator : public SparseIterator {
bool randomAccessible() const override { return false; };
bool iteratableByFor() const override { return false; };
- ValuePair peekNxLvlRange(OpBuilder &b, Location l,
- const SparseTensorLevel &stl) const override {
- assert(stl.tid == this->tid && stl.lvl - 1 == this->lvl);
- return stl.peekRangeAt(b, l, getPos(), getSegHi());
- }
+ ValuePair getCurPosition() const override { return {getPos(), getSegHi()}; }
void genInit(OpBuilder &b, Location l,
const SparseIterator *parent) override {
- Value posLo;
+ Value pos = C_IDX(0);
+ Value hi = nullptr;
if (parent)
- std::tie(posLo, loopHi) = parent->peekNxLvlRange(b, l, stl);
- else
- std::tie(posLo, loopHi) = stl.peekRangeAt(b, l, C_IDX(0));
+ std::tie(pos, hi) = parent->getCurPosition();
+
+ Value posLo;
+ std::tie(posLo, posHi) = stl.peekRangeAt(b, l, pos, hi);
seek({posLo, genSegmentHigh(b, l, posLo)});
}
- Value genIsEnd(OpBuilder &b, Location l) override {
- return CMPI(ult, getPos(), loopHi);
+ Value genNotEnd(OpBuilder &b, Location l) override {
+ return CMPI(ult, getPos(), posHi);
}
Value deref(OpBuilder &b, Location l) override {
@@ -322,11 +321,145 @@ class DedupIterator : public SparseIterator {
Value getPos() const { return posAndSegHi[0]; }
Value getSegHi() const { return posAndSegHi[1]; }
- Value loopHi;
+ Value posHi;
Value posAndSegHi[2]; // position and segment high
const SparseTensorLevel &stl;
};
+class FilterIterator : public SparseIterator {
+ // Coorindate translation between crd loaded from the wrap iterator and the
+ // filter iterator.
+ Value fromWrapCrd(OpBuilder &b, Location l, Value wrapCrd) {
+ // crd = (wrapCrd - offset) / stride
+ return DIVUI(SUBI(wrapCrd, offset), stride);
+ }
+ Value toWrapCrd(OpBuilder &b, Location l, Value crd) {
+ // wrapCrd = crd * stride + offset
+ return ADDI(MULI(crd, stride), offset);
+ }
+
+ ValueRange genWhenWrapInBound(
+ OpBuilder &b, Location l, ValueRange elseRet,
+ llvm::function_ref<ValueRange(OpBuilder &, Location, Value)> builder);
+
+ Value genCrdNotLegitPredicate(OpBuilder &b, Location l, Value wrapCrd);
+
+ Value genShouldFilter(OpBuilder &b, Location l);
+
+public:
+ FilterIterator(std::unique_ptr<SparseIterator> &&w, Value offset,
+ Value stride, Value size)
+ : SparseIterator(IterKind::kFilter, w.get()), offset(offset),
+ stride(stride), size(size), wrap(std::move(w)) {}
+
+ // For LLVM-style RTTI.
+ static bool classof(const SparseIterator *from) {
+ return from->kind == IterKind::kFilter;
+ }
+
+ bool randomAccessible() const override { return wrap->randomAccessible(); };
+ bool iteratableByFor() const override { return randomAccessible(); };
+
+ ValuePair getCurPosition() const override { return wrap->getCurPosition(); }
+
+ void genInit(OpBuilder &b, Location l,
+ const SparseIterator *parent) override {
+ wrap->genInit(b, l, parent);
+ if (!randomAccessible()) {
+ // TODO: we can skip this when stride == 1 and offset == 0, we can also
+ // use binary search here.
+ forwardIf(b, l, genShouldFilter(b, l));
+ }
+ }
+
+ ValuePair genForCond(OpBuilder &b, Location l) override {
+ assert(randomAccessible());
+
+ auto [lo, hi] = wrap->genForCond(b, l);
+ // if offset < lo, we use lo - offset as the new lower bound, else we use 0.
+ Value loInBound = CMPI(ult, offset, lo);
+ lo = SELECT(loInBound, SUBI(lo, offset), C_IDX(0));
+ return {lo, size};
+ }
+
+ Value genNotEnd(OpBuilder &b, Location l) override;
+
+ Value deref(OpBuilder &b, Location l) override {
+ updateCrd(fromWrapCrd(b, l, wrap->deref(b, l)));
+ return getCrd();
+ }
+
+ void locate(OpBuilder &b, Location l, Value crd) override {
+ assert(randomAccessible());
+ wrap->locate(b, l, toWrapCrd(b, l, crd));
+ updateCrd(crd);
+ }
+
+ ValueRange forward(OpBuilder &b, Location l) override;
+
+ const Value offset, stride, size;
+ std::unique_ptr<SparseIterator> wrap;
+};
+
+/*
+class NonEmptySubSectIterator : public SparseIterator {
+public:
+ NonEmptySubSectIterator(OpBuilder &b, Location l,
+ const SparseIterator *parent,
+ std::unique_ptr<SparseIterator> &&w, Value size)
+ : SparseIterator(IterKind::kNonEmptySubSect, w->tid, w->lvl),
+ parent(parent), wrap(std::move(w)), size(size), stride(stride) {
+
+ auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
+ if (p == nullptr) {
+ // Extract subsections along the root level.
+ prevUnResCnt = C_IDX(1);
+ } else if (p->lvl == lvl) {
+ // Extract subsections along the same level.
+ prevUnResCnt = p->prevUnResCnt;
+ } else {
+ // Extract subsections along the previous level.
+ assert(p->lvl + 1 == lvl);
+ prevUnResCnt = MULI(p->prevUnResCnt, p->size);
+ }
+
+ // We don't need an extra buffer to find subsections on dense levels.
+ if (randomAccessible())
+ return;
+ subSectPosBuf = allocSlicePosBuf(b, l, prevUnResCnt);
+ }
+
+ // For LLVM-style RTTI.
+ static bool classof(const SparseIterator *from) {
+ return from->kind == IterKind::kNonEmptySubSect;
+ }
+
+ bool randomAccessible() const override { return wrap->randomAccessible(); };
+ bool iteratableByFor() const override { return randomAccessible(); };
+
+ Value size, prevUnResCnt, subSectPosBuf;
+ unsigned stride;
+};
+
+class SubSectIterator : public SparseIterator {
+public:
+ SubSectIterator(const SparseIterator *parent,
+ std::unique_ptr<SparseIterator> &&w)
+ : SparseIterator(IterKind::kSubSect, w->tid, w->lvl), parent(parent),
+ wrap(std::move(w)) {}
+
+ // For LLVM-style RTTI.
+ static bool classof(const SparseIterator *from) {
+ return from->kind == IterKind::kSubSect;
+ }
+
+ bool randomAccessible() const override { return wrap->randomAccessible(); };
+ bool iteratableByFor() const override { return randomAccessible(); };
+
+ const SparseIterator *parent;
+ std::unique_ptr<SparseIterator> wrap;
+};
+*/
} // namespace
//===----------------------------------------------------------------------===//
@@ -353,7 +486,7 @@ Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) {
l, pos.getType(), pos,
/*beforeBuilder=*/
[this, pos](OpBuilder &b, Location l, ValueRange ivs) {
- Value inBound = CMPI(ult, ivs.front(), loopHi);
+ Value inBound = CMPI(ult, ivs.front(), posHi);
auto ifInBound = b.create<scf::IfOp>(l, b.getI1Type(), inBound, true);
{
OpBuilder::InsertionGuard guard(b);
@@ -379,28 +512,92 @@ Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) {
return whileOp.getResult(0);
}
-Value FilterIterator::genShouldFilter(OpBuilder &b, Location l) {
- Value end = wrap->genIsEnd(b, l);
+ValueRange FilterIterator::genWhenWrapInBound(
+ OpBuilder &b, Location l, ValueRange elseRet,
+ llvm::function_ref<ValueRange(OpBuilder &, Location, Value)> builder) {
+ // !it.end() ? callback(*crd) : resOOB;
+ TypeRange ifRetTypes = elseRet.getTypes();
+ auto ifOp = b.create<scf::IfOp>(l, ifRetTypes, wrap->genNotEnd(b, l), true);
- auto shouldFilter = b.create<scf::IfOp>(l, b.getI1Type(), end, true);
- // it.end() ? false : should_filter(*it);
- b.setInsertionPointToStart(shouldFilter.thenBlock());
- YIELD(constantI1(b, l, false));
-
- // Iterator not at the end.
- b.setInsertionPointToStart(shouldFilter.elseBlock());
+ b.setInsertionPointToStart(ifOp.thenBlock());
Value wrapCrd = wrap->deref(b, l);
+ YIELD(builder(b, l, wrapCrd));
+
+ b.setInsertionPointToStart(ifOp.elseBlock());
+ YIELD(elseRet);
+
+ b.setInsertionPointAfter(ifOp);
+ return ifOp.getResults();
+}
+
+Value FilterIterator::genCrdNotLegitPredicate(OpBuilder &b, Location l,
+ Value wrapCrd) {
Value crd = fromWrapCrd(b, l, wrapCrd);
- // on stride
- Value legit = CMPI(eq, toWrapCrd(b, l, crd), wrapCrd);
- // wrapCrd >= offset
- legit = ANDI(CMPI(uge, wrapCrd, offset), legit);
- // crd < length
- legit = ANDI(CMPI(ult, crd, size), legit);
- YIELD(legit);
-
- b.setInsertionPointAfter(shouldFilter);
- return shouldFilter.getResult(0);
+ // not on stride
+ Value notlegit = CMPI(ne, toWrapCrd(b, l, crd), wrapCrd);
+ // wrapCrd < offset
+ notlegit = ORI(CMPI(ult, wrapCrd, offset), notlegit);
+ // crd >= length
+ notlegit = ORI(CMPI(uge, crd, size), notlegit);
+ return notlegit;
+}
+
+Value FilterIterator::genShouldFilter(OpBuilder &b, Location l) {
+ ValueRange r = genWhenWrapInBound(
+ b, l, C_FALSE, [this](OpBuilder &b, Location l, Value wrapCrd) {
+ Value notLegit = genCrdNotLegitPredicate(b, l, wrapCrd);
+ return notLegit.getDefiningOp()->getResults();
+ });
+
+ assert(r.size() == 1);
+ return r.front();
+}
+
+Value FilterIterator::genNotEnd(OpBuilder &b, Location l) {
+ assert(!wrap->randomAccessible());
+ ValueRange r = genWhenWrapInBound(
+ b, l, C_FALSE, [this](OpBuilder &b, Location l, Value wrapCrd) {
+ Value crd = fromWrapCrd(b, l, wrapCrd);
+ // crd < size
+ return CMPI(ult, crd, size).getDefiningOp()->getResults();
+ });
+ assert(r.size() == 1);
+ return r.front();
+}
+
+ValueRange FilterIterator::forward(OpBuilder &b, Location l) {
+ assert(!randomAccessible());
+ // Generates
+ //
+ // wrap ++;
+ // while !it.end() && !legit(*it)
+ // wrap ++;
+ wrap->forward(b, l);
+ auto whileOp = b.create<scf::WhileOp>(
+ l, getItVals().getTypes(), getItVals(),
+ /*beforeBuilder=*/
+ [this](OpBuilder &b, Location l, ValueRange ivs) {
+ linkNewScope(ivs);
+ ValueRange cont = genWhenWrapInBound(
+ b, l, C_FALSE, [this](OpBuilder &b, Location l, Value wrapCrd) {
+ // crd < size && !legit();
+ Value notLegit = genCrdNotLegitPredicate(b, l, wrapCrd);
+ Value crd = fromWrapCrd(b, l, wrapCrd);
+ Value ret = ANDI(CMPI(ult, crd, size), notLegit);
+ return ret.getDefiningOp()->getResults();
+ });
+ b.create<scf::ConditionOp>(l, cont.front(), ivs);
+ },
+ /*afterBuilder=*/
+ [this](OpBuilder &b, Location l, ValueRange ivs) {
+ linkNewScope(ivs);
+ wrap->forward(b, l);
+ YIELD(getItVals());
+ });
+
+ b.setInsertionPointAfter(whileOp);
+ linkNewScope(whileOp.getResults());
+ return getItVals();
}
std::unique_ptr<SparseTensorLevel>
@@ -445,33 +642,34 @@ sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl) {
}
std::unique_ptr<SparseIterator>
-sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl, bool dedup) {
- dedup = dedup && !isUniqueLT(stl.getLT());
- if (dedup)
+sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl) {
+ if (!isUniqueLT(stl.getLT())) {
+ // We always dedupliate the non-unique level, but we should optimize it away
+ // if possible.
return std::make_unique<DedupIterator>(stl);
+ }
return std::make_unique<TrivialIterator>(stl);
}
std::unique_ptr<SparseIterator>
sparse_tensor::makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit,
Value offset, Value stride, Value size) {
- return nullptr;
- // return std::make_unique<FilterIterator>(std::move(sit), offset, stride,
- // size);
+
+ return std::make_unique<FilterIterator>(std::move(sit), offset, stride, size);
}
std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
OpBuilder &b, Location l, const SparseIterator *parent,
- std::unique_ptr<SparseIterator> &&lvlIt, Value size, unsigned stride) {
+ std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride) {
return nullptr;
- // return std::make_unique<NonEmptySubSectIterator>(
- // b, l, parent, std::move(lvlIt), size, stride);
+ // return std::make_unique<NonEmptySubSectIterator>(
+ // b, l, parent, std::move(lvlIt), size, stride);
}
std::unique_ptr<SparseIterator> sparse_tensor::makeTraverseSubSectIterator(
- const SparseIterator *parent, std::unique_ptr<SparseIterator> &&lvlIt) {
- // return std::make_unique<SubSectIterator>(parent, std::move(lvlIt));
+ const SparseIterator *, std::unique_ptr<SparseIterator> &&delegate) {
return nullptr;
+ // return std::make_unique<SubSectIterator>(parent, std::move(lvlIt));
}
#undef CMPI
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
index e6249c245b22ec..770a6eb9b78d1f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
@@ -114,13 +114,13 @@ class SparseIterator {
// Core functions.
//
- // Peeks the range to iterate on child level at the current position.
- // See SparseTensorLevel::peekRangeAt();
+ // Get the current position and the optional *position high* (for non-unique
+ // iterators), the value should be able to uniquely identify the sparse range
+ // for the next level. See SparseTensorLevel::peekRangeAt();
//
// Not every type of iterator supports the operations, e.g., non-empty
// subsection iterator does not.
- virtual std::pair<Value, Value>
- peekNxLvlRange(OpBuilder &, Location, const SparseTensorLevel &) const {
+ virtual std::pair<Value, Value> getCurPosition() const {
llvm_unreachable("unsupported");
};
@@ -133,11 +133,11 @@ class SparseIterator {
llvm_unreachable("Unsupported");
}
- virtual Value genIsEnd(OpBuilder &b, Location l) = 0;
+ virtual Value genNotEnd(OpBuilder &b, Location l) = 0;
std::pair<Value, ValueRange> genWhileCond(OpBuilder &b, Location l,
ValueRange vs) {
seek(vs.take_front(itVals.size()));
- return std::make_pair(genIsEnd(b, l), vs.drop_front(itVals.size()));
+ return std::make_pair(genNotEnd(b, l), vs.drop_front(itVals.size()));
}
// Dereference the iterator, loads the coordinate at the current position.
@@ -201,8 +201,8 @@ std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &builder,
unsigned tid, Level l);
/// Helper function to create a SparseIterator object.
-std::unique_ptr<SparseIterator> makeSimpleIterator(const SparseTensorLevel &stl,
- bool dedup);
+std::unique_ptr<SparseIterator>
+makeSimpleIterator(const SparseTensorLevel &stl);
std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl);
>From b16804101fe3d541c7a5413d7743409876e9acea Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 9 Jan 2024 21:54:04 +0000
Subject: [PATCH 03/11] setup non-empty subsection iterator and support 1d
convolution
---
.../Transforms/Sparsification.cpp | 6 +
.../Transforms/Utils/LoopEmitter.cpp | 102 ++--
.../Transforms/Utils/LoopEmitter.h | 13 +-
.../Transforms/Utils/SparseTensorLevel.cpp | 503 ++++++++++++++----
.../Transforms/Utils/SparseTensorLevel.h | 32 +-
5 files changed, 471 insertions(+), 185 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 918a0911b5e0aa..5d890e8b035d0c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1035,6 +1035,8 @@ static bool getAllTidLvlsInLatPoints(
// Note that we generate dense indices of the output tensor
// unconditionally, since they may not appear in the lattice, but may be
// needed for linearized env.
+ // TODO: we should avoid introducing corner cases for all-dense sparse
+ // tensors.
if (stt.hasEncoding() && stt.isAllDense())
callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
}
@@ -1065,6 +1067,10 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
SmallVector<TensorLevel> tidLvls;
getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
+ // TODO: remove this! Duplication can be introduced due to the speical
+ // handling for all-dense "sparse" output tensor.
+ if (llvm::find(tidLvls, tl) != tidLvls.end())
+ return;
tidLvls.emplace_back(tl);
});
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 302b4da932017c..da0d339427920f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -566,7 +566,10 @@ void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) {
it = makeNonEmptySubSectIterator(builder, loc, parent, std::move(lvlIt),
size, curDep.second);
} else {
- it = makeTraverseSubSectIterator(parent, std::move(lvlIt));
+ Value size = highs[getSynTensorId()][loop];
+ const SparseIterator &subSectIter = *iters[t][lvl].back();
+ it = makeTraverseSubSectIterator(subSectIter, *parent, std::move(lvlIt),
+ size, curDep.second);
}
iters[t][lvl].emplace_back(std::move(it));
}
@@ -678,10 +681,7 @@ void LoopEmitter::categorizeIterators(
// Finds out the tensor level that we should use to generate loops. Amongs all
// the tensor levels, there is at most one sparse tensor level.
for (auto [t, l] : unpackTensorLevelRange(tidLvls)) {
- SparseIterator *it =
- dependentLvlMap[t][l].empty()
- ? iters[t][l].back().get()
- : iters[t][l][iters[t][l].size() - remDepOnLevel(t, l)].get();
+ SparseIterator *it = &getCurIterator(t, l);
if (it->randomAccessible())
raIters.push_back(it);
else
@@ -699,35 +699,24 @@ void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
// TODO: sort
assert(loopSeqStack.size() == loopStack.size());
// Prepares for all the tensors used in the current loop sequence.
- std::vector<std::tuple<TensorId, Level, bool>> slicedTids;
for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
- if (!dependentLvlMap[tid][lvl].empty()) {
- bool fullyRed = genSliceBegin(builder, loc, tid, lvl);
- slicedTids.emplace_back(tid, lvl, fullyRed);
- } else {
- prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
- }
+ levelReducedDep[tid][lvl]++;
+ prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
}
// Universal Index starts from 0.
- loopSeqStack.emplace_back(C_IDX(0), std::move(slicedTids));
+ loopSeqStack.emplace_back(C_IDX(0), tidLvls.vec());
}
void LoopEmitter::exitCurrentLoopSeq(OpBuilder &builder, Location loc) {
assert(loopSeqStack.size() == loopStack.size() + 1);
- const auto &slicedTids = loopSeqStack.back().second;
-
// Depending on whether the slice is resolved or not at current loop sequence,
// end them in different ways.
- for (auto [tid, lvl, res] : slicedTids) {
- if (!res) {
- // If this is a unresolved-slice-driven loop, pops out the slice.
- assert(sliceStack[tid].back().slicedOnLvl == lvl);
- sliceStack[tid].pop_back();
- }
- }
+ for (auto [tid, lvl] : unpackTensorLevelRange(loopSeqStack.back().second))
+ levelReducedDep[tid][lvl]--;
+
loopSeqStack.pop_back();
}
@@ -1362,11 +1351,15 @@ void LoopEmitter::genDenseAffineAddress(OpBuilder &builder, Location loc,
void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
TensorId tid, Level lvl) {
- assert(isValidLevel(tid, lvl));
+ // if this is the first level, there is no parent iterator for the current
+ // iterator.
+ // If the current iterator is a subsection-based iterator, the parent iterator
+ // is memorized by the iterator.
+ bool hasParent = lvl == 0 || !dependentLvlMap[tid][lvl].empty();
+
const SparseIterator *parent =
- lvl == 0 ? nullptr : iters[tid][lvl - 1].back().get();
- SparseIterator &curIt = *iters[tid][lvl].back();
- curIt.genInit(builder, loc, parent);
+ hasParent ? nullptr : iters[tid][lvl - 1].back().get();
+ getCurIterator(tid, lvl).genInit(builder, loc, parent);
}
void LoopEmitter::enterTensorsAtDenseLvls(
@@ -1440,7 +1433,6 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
(void)reduced;
info.minCrd = info.offset = info.isNonEmpty = Value();
}
- levelReducedDep[tid][lvl]--;
}
if (auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop)) {
if (!reduc.empty()) {
@@ -1535,48 +1527,26 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
unsigned delta = 0;
ValueRange whileRes = whileOp.getResults();
for (auto [tid, lvl, resolved] : loopInfo.sliceDrivenInfo) {
- // TODO: handle dense.
- assert(isCompressedLT(lvlTypes[tid][lvl]));
- levelReducedDep[tid][lvl]--;
- if (!resolved) {
- // TODO: support coiterating multiple slices
- assert(loopInfo.sliceDrivenInfo.size() == 1);
- auto [nxNonEmpty, nxMinCrd, nxAbsOffset] =
- genSliceNextInduction(builder, loc, tid, lvl);
- // Update while loop induction operands.
- operands.push_back(nxNonEmpty);
- operands.push_back(nxMinCrd);
- operands.push_back(nxAbsOffset);
-
- // Update the slice stack.
- SliceInfo &info = sliceStack[tid].back();
- info.isNonEmpty = whileOp.getResult(o++);
- info.minCrd = whileOp.getResult(o++);
- info.offset = whileOp.getResult(o++);
- continue;
- }
-
- Value forwarded = nullptr;
- if (loopInfo.trivialTidLvls.empty() &&
- loopInfo.sliceDrivenInfo.size() == 1) {
- // Forwards the position iterator.
- operands.push_back(ADDI(posits[tid][lvl], one));
- forwarded = constantI1(builder, loc, true);
+ SparseIterator &it = getCurIterator(tid, lvl);
+ if (!it.randomAccessible()) {
+ // Forward the sparse iterator.
+ Value cmp = CMPI(eq, it.getCrd(), iv);
+ it.forwardIf(builder, loc, cmp);
+ operands.append(it.getItVals().begin(), it.getItVals().end());
+ o += it.getItVals().size();
+ // Following loops continue iteration from the break point of the
+ // current while loop.
+ whileRes = it.linkNewScope(whileRes);
} else {
- const Value pos = posits[tid][lvl];
- const Value nxPos = ADDI(posits[tid][lvl], one);
- forwarded = CMPI(eq, coords[tid][lvl], iv);
- operands.push_back(SELECT(forwarded, nxPos, pos));
+ // Make sure randomly accessible (dense) iterator is set to the right
+ // position according to the universal index.
+ Value uniIdx = whileOp.getResults().back();
+ it.locate(builder, loc, uniIdx);
}
- // The coordinate is invalid now.
- coords[tid][lvl] = nullptr;
-
- // Update the position iterator as we exit the while loop.
- posits[tid][lvl] = whileOp->getResult(o++);
};
for (auto [tid, lvl] : unpackTensorLevelRange(loopInfo.trivialTidLvls)) {
- SparseIterator &it = *iters[tid][lvl].back();
+ SparseIterator &it = getCurIterator(tid, lvl);
if (!it.randomAccessible()) {
// Forward the sparse iterator.
Value cmp = CMPI(eq, it.getCrd(), iv);
@@ -1664,6 +1634,10 @@ unsigned LoopEmitter::remDepOnLevel(TensorId tid, Level lvl) const {
return totalDependencies;
}
+unsigned LoopEmitter::redDepOnLevel(TensorId tid, Level lvl) const {
+ return levelReducedDep[tid][lvl];
+}
+
const LoopEmitter::SliceInfo &LoopEmitter::getMostRecentSliceOnLvl(TensorId tid,
Level lvl) {
// Finds the most-recent slice using a reverse iteration.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index 9ab99f4feb5627..aafb56f03ef607 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -554,6 +554,13 @@ class LoopEmitter {
/// Get the remaining number of constraints needed to fully *resolve*
/// dependent levels on tensor[tid].
unsigned remDepOnLevel(TensorId tid, Level lvl) const;
+ /// Get the reduced number of contraints on tensor[tid][lvl].
+ unsigned redDepOnLevel(TensorId tid, Level lvl) const;
+
+ SparseIterator &getCurIterator(TensorId tid, Level lvl) const {
+ assert(redDepOnLevel(tid, lvl) >= 1);
+ return *iters[tid][lvl][redDepOnLevel(tid, lvl) - 1];
+ }
/// Whether the tid, lvl is fully *reduced*, i.e., the non-trivial index
/// expression has been reduced to a trivial one.
@@ -695,10 +702,8 @@ class LoopEmitter {
std::vector<LoopInfo> loopStack;
// Loop Sequence Stack, stores the unversial index for the current loop
- // sequence. and a list of tids which was taken sliced.
- // TODO: maybe we should have a LoopSeqInfo
- std::vector<std::pair<Value, std::vector<std::tuple<TensorId, Level, bool>>>>
- loopSeqStack;
+ // sequence. and a list of tid level that the loop sequence traverse.
+ std::vector<std::pair<Value, std::vector<TensorLevel>>> loopSeqStack;
};
} // namespace sparse_tensor
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 26ddc9b50c107d..79ba3230ac068d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -26,6 +26,7 @@ using ValueTuple = std::tuple<Value, Value, Value>;
.getResult())
#define C_FALSE (constantI1(b, l, false))
+#define C_TRUE (constantI1(b, l, true))
#define C_IDX(v) (constantIndex(b, l, (v)))
#define YIELD(vs) (b.create<scf::YieldOp>(l, (vs)))
#define ADDI(lhs, rhs) (b.create<arith::AddIOp>(l, (lhs), (rhs)).getResult())
@@ -38,46 +39,6 @@ using ValueTuple = std::tuple<Value, Value, Value>;
#define SELECT(c, lhs, rhs) \
(b.create<arith::SelectOp>(l, (c), (lhs), (rhs)).getResult())
-// Helper functions that load/store into the position buffer for slice-driven
-// loops.
-static constexpr unsigned kSliceIterWidth = 3;
-// The sliced pointer buffer is organized as:
-// [[pLo0, pLo1, pLo2, ...],
-// [pHi0, pHi1, pHi2, ...],
-// [pNx0, pNx1, pNx2, ...]]
-static Value allocSlicePosBuf(OpBuilder &b, Location l, Value tupleCnt) {
- Value bufSz = MULI(tupleCnt, C_IDX(kSliceIterWidth));
- // Additional two metadata {memSize, idx} at head.
- return genAlloca(b, l, bufSz, b.getIndexType());
-}
-
-// Gets and sets position values for slice-driven loops.
-enum class SlicePosKind { kLo, kHi, kNext };
-static Value getSlicePosIdx(OpBuilder &b, Location l, Value posBuf,
- Value tupleIdx, SlicePosKind posKind) {
- Value dim = b.create<memref::DimOp>(l, posBuf, C_IDX(0));
- Value tupleCnt = DIVUI(dim, C_IDX(kSliceIterWidth));
- switch (posKind) {
- case SlicePosKind::kLo:
- return tupleIdx;
- case SlicePosKind::kHi:
- return ADDI(tupleIdx, tupleCnt);
- case SlicePosKind::kNext:
- return ADDI(tupleIdx, MULI(tupleCnt, C_IDX(2)));
- }
- llvm_unreachable("unexpected kind");
-}
-static Value loadSlicePos(OpBuilder &b, Location l, Value sPosBuf,
- Value tupleIdx, SlicePosKind posKind) {
- return genIndexLoad(b, l, sPosBuf,
- getSlicePosIdx(b, l, sPosBuf, tupleIdx, posKind));
-}
-static void updateSlicePos(OpBuilder &b, Location l, Value sPosBuf, Value pos,
- Value tupleIdx, SlicePosKind posKind) {
- b.create<memref::StoreOp>(l, pos, sPosBuf,
- getSlicePosIdx(b, l, sPosBuf, tupleIdx, posKind));
-}
-
//===----------------------------------------------------------------------===//
// SparseTensorLevel derived classes.
//===----------------------------------------------------------------------===//
@@ -194,6 +155,48 @@ class TwoOutFourLevel : public SparseLevel {
} // namespace
+//===----------------------------------------------------------------------===//
+// File local helpers
+//===----------------------------------------------------------------------===//
+
+static ValueRange
+genWhenInBound(OpBuilder &b, Location l, SparseIterator &it, ValueRange elseRet,
+ llvm::function_ref<void(OpBuilder &, Location, Value)> builder) {
+ // !it.end() ? callback(*crd) : resOOB;
+ TypeRange ifRetTypes = elseRet.getTypes();
+ auto ifOp = b.create<scf::IfOp>(l, ifRetTypes, it.genNotEnd(b, l), true);
+
+ b.setInsertionPointToStart(ifOp.thenBlock());
+ Value crd = it.deref(b, l);
+ builder(b, l, crd);
+
+ b.setInsertionPointToStart(ifOp.elseBlock());
+ YIELD(elseRet);
+
+ b.setInsertionPointAfter(ifOp);
+ return ifOp.getResults();
+}
+
+/// Generates code to compute the *absolute* offset of the slice based on the
+/// provide minimum coordinates in the slice.
+/// E.g., when reducing d0 + d1 + d2, we need two slices to fully reduced the
+/// expression, i,e, s1 = slice(T, d0), s2 = slice(s1, d1). The *absolute*
+/// offset is the offset computed relative to the initial tensors T.
+///
+/// When isNonEmpty == true, the computed offset is meaningless and should not
+/// be used during runtime, the method generates code to return 0 currently in
+/// that case.
+///
+/// offset = minCrd >= size ? minCrd - size + 1 : 0;
+static Value offsetFromMinCrd(OpBuilder &b, Location l, Value minCrd,
+ Value size) {
+ Value geSize = CMPI(uge, minCrd, size);
+ // Computes minCrd - size + 1
+ Value mms = SUBI(ADDI(minCrd, C_IDX(1)), size);
+ // This is the absolute offset related to the actual tensor.
+ return SELECT(geSize, mms, C_IDX(0));
+}
+
//===----------------------------------------------------------------------===//
// SparseIterator derived classes.
//===----------------------------------------------------------------------===//
@@ -221,6 +224,24 @@ class TrivialIterator : public SparseIterator {
bool randomAccessible() const override { return isDenseLT(stl.getLT()); };
bool iteratableByFor() const override { return true; };
+ Value upperBound(OpBuilder &b, Location l) const override {
+ return stl.size();
+ };
+
+ SmallVector<Value> serialize() const override {
+ assert(!randomAccessible());
+ SmallVector<Value> ret;
+ ret.push_back(itPos);
+ ret.push_back(loopHi);
+ return ret;
+ };
+
+ void deserialize(ValueRange vs) override {
+ assert(!randomAccessible());
+ assert(vs.size() == 2);
+ seek(vs.front());
+ loopHi = vs.back();
+ };
ValuePair getCurPosition() const override { return {itPos, nullptr}; }
@@ -256,6 +277,13 @@ class TrivialIterator : public SparseIterator {
return getItVals();
}
+ ValueRange forwardIf(OpBuilder &b, Location l, Value cond) override {
+ Value curPos = getItVals().front();
+ Value nxPos = forward(b, l).front();
+ seek(SELECT(cond, nxPos, curPos));
+ return getItVals();
+ }
+
void locate(OpBuilder &b, Location l, Value crd) override {
assert(randomAccessible());
// Seek to the linearized position.
@@ -286,6 +314,9 @@ class DedupIterator : public SparseIterator {
bool randomAccessible() const override { return false; };
bool iteratableByFor() const override { return false; };
+ Value upperBound(OpBuilder &b, Location l) const override {
+ return stl.size();
+ };
ValuePair getCurPosition() const override { return {getPos(), getSegHi()}; }
@@ -303,6 +334,20 @@ class DedupIterator : public SparseIterator {
seek({posLo, genSegmentHigh(b, l, posLo)});
}
+ SmallVector<Value> serialize() const override {
+ assert(!randomAccessible());
+ SmallVector<Value> ret;
+ ret.append(getItVals().begin(), getItVals().end());
+ ret.push_back(posHi);
+ return ret;
+ };
+ void deserialize(ValueRange vs) override {
+ assert(!randomAccessible());
+ assert(vs.size() == 3);
+ seek(vs.take_front(getItVals().size()));
+ posHi = vs.back();
+ };
+
Value genNotEnd(OpBuilder &b, Location l) override {
return CMPI(ult, getPos(), posHi);
}
@@ -329,19 +374,15 @@ class DedupIterator : public SparseIterator {
class FilterIterator : public SparseIterator {
// Coorindate translation between crd loaded from the wrap iterator and the
// filter iterator.
- Value fromWrapCrd(OpBuilder &b, Location l, Value wrapCrd) {
+ Value fromWrapCrd(OpBuilder &b, Location l, Value wrapCrd) const {
// crd = (wrapCrd - offset) / stride
return DIVUI(SUBI(wrapCrd, offset), stride);
}
- Value toWrapCrd(OpBuilder &b, Location l, Value crd) {
+ Value toWrapCrd(OpBuilder &b, Location l, Value crd) const {
// wrapCrd = crd * stride + offset
return ADDI(MULI(crd, stride), offset);
}
- ValueRange genWhenWrapInBound(
- OpBuilder &b, Location l, ValueRange elseRet,
- llvm::function_ref<ValueRange(OpBuilder &, Location, Value)> builder);
-
Value genCrdNotLegitPredicate(OpBuilder &b, Location l, Value wrapCrd);
Value genShouldFilter(OpBuilder &b, Location l);
@@ -359,7 +400,14 @@ class FilterIterator : public SparseIterator {
bool randomAccessible() const override { return wrap->randomAccessible(); };
bool iteratableByFor() const override { return randomAccessible(); };
+ Value upperBound(OpBuilder &b, Location l) const override {
+ Value maxWrapCrd = SUBI(wrap->upperBound(b, l), C_IDX(1));
+ Value maxCrd = fromWrapCrd(b, l, maxWrapCrd);
+ return ADDI(maxCrd, C_IDX(1));
+ };
+ SmallVector<Value> serialize() const override { return wrap->serialize(); };
+ void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
ValuePair getCurPosition() const override { return wrap->getCurPosition(); }
void genInit(OpBuilder &b, Location l,
@@ -401,69 +449,195 @@ class FilterIterator : public SparseIterator {
std::unique_ptr<SparseIterator> wrap;
};
-/*
+class SubSectIterator;
class NonEmptySubSectIterator : public SparseIterator {
+
+ // The sliced pointer buffer is organized as:
+ // [[itVal0, itVal1, ..., pNx0],
+ // [itVal0, itVal1, ..., pNx0],
+ // ...]
+ Value allocSubSectPosBuf(OpBuilder &b, Location l) {
+ return b.create<memref::AllocaOp>(
+ l,
+ MemRefType::get({ShapedType::kDynamic, tupleSz + 1}, b.getIndexType()),
+ maxTupleCnt);
+ }
+
+ SmallVector<Value> loadItVals(OpBuilder &b, Location l, Value tupleId) const {
+ SmallVector<Value> ret;
+ for (unsigned i = 0; i < tupleSz; i++) {
+ Value v = b.create<memref::LoadOp>(l, subSectPosBuf,
+ ValueRange{tupleId, C_IDX(i)});
+ ret.push_back(v);
+ }
+ return ret;
+ }
+
+ void storeItVals(OpBuilder &b, Location l, Value tupleId, ValueRange itVals) {
+ assert(itVals.size() == tupleSz);
+ for (unsigned i = 0; i < tupleSz; i++) {
+ b.create<memref::StoreOp>(l, itVals[i], subSectPosBuf,
+ ValueRange{tupleId, C_IDX(i)});
+ }
+ }
+
public:
NonEmptySubSectIterator(OpBuilder &b, Location l,
const SparseIterator *parent,
- std::unique_ptr<SparseIterator> &&w, Value size)
- : SparseIterator(IterKind::kNonEmptySubSect, w->tid, w->lvl),
- parent(parent), wrap(std::move(w)), size(size), stride(stride) {
+ std::unique_ptr<SparseIterator> &&wrap,
+ Value subSectSz, unsigned stride)
+ : SparseIterator(IterKind::kNonEmptySubSect, wrap->tid, wrap->lvl,
+ /*itVals=*/subSectMeta),
+ tupleSz(wrap->serialize().size()), subSectSz(subSectSz), stride(stride),
+ parent(parent), wrap(std::move(wrap)) {
auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
+ assert(stride == 1);
if (p == nullptr) {
// Extract subsections along the root level.
- prevUnResCnt = C_IDX(1);
+ maxTupleCnt = C_IDX(1);
} else if (p->lvl == lvl) {
// Extract subsections along the same level.
- prevUnResCnt = p->prevUnResCnt;
+ maxTupleCnt = p->maxTupleCnt;
+ assert(false && "Not implemented.");
} else {
// Extract subsections along the previous level.
assert(p->lvl + 1 == lvl);
- prevUnResCnt = MULI(p->prevUnResCnt, p->size);
+ maxTupleCnt = MULI(p->maxTupleCnt, p->subSectSz);
}
-
// We don't need an extra buffer to find subsections on dense levels.
if (randomAccessible())
return;
- subSectPosBuf = allocSlicePosBuf(b, l, prevUnResCnt);
+
+ subSectPosBuf = allocSubSectPosBuf(b, l);
}
+ bool randomAccessible() const override { return wrap->randomAccessible(); };
+ bool iteratableByFor() const override { return randomAccessible(); };
+ Value upperBound(OpBuilder &b, Location l) const override {
+ auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
+ Value parentUB =
+ p && p->lvl == lvl ? p->upperBound(b, l) : wrap->upperBound(b, l);
+ return ADDI(SUBI(parentUB, subSectSz), C_IDX(1));
+ };
+
// For LLVM-style RTTI.
static bool classof(const SparseIterator *from) {
return from->kind == IterKind::kNonEmptySubSect;
}
- bool randomAccessible() const override { return wrap->randomAccessible(); };
- bool iteratableByFor() const override { return randomAccessible(); };
+ void genInit(OpBuilder &b, Location l, const SparseIterator *) override;
- Value size, prevUnResCnt, subSectPosBuf;
- unsigned stride;
+ Value genNotEnd(OpBuilder &b, Location l) override { return getNotEnd(); };
+
+ Value deref(OpBuilder &b, Location l) override {
+ // Use the relative offset to coiterate.
+ Value crd;
+ auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
+ if (p && p->lvl == lvl)
+ crd = SUBI(getAbsOff(), p->getAbsOff());
+ crd = getAbsOff();
+
+ updateCrd(crd);
+ return crd;
+ };
+
+ ValueRange forward(OpBuilder &b, Location l) override;
+
+ Value getMinCrd() const { return subSectMeta[0]; }
+ Value getAbsOff() const { return subSectMeta[1]; }
+ Value getNotEnd() const { return subSectMeta[2]; }
+
+ Value maxTupleCnt, tupleCnt;
+ Value subSectPosBuf;
+ const unsigned tupleSz;
+ const Value subSectSz;
+ const unsigned stride;
+
+ const SparseIterator *parent;
+ std::unique_ptr<SparseIterator> wrap;
+
+ Value subSectMeta[3]; // minCrd, absolute offset, notEnd
+
+ friend SubSectIterator;
};
class SubSectIterator : public SparseIterator {
-public:
- SubSectIterator(const SparseIterator *parent,
- std::unique_ptr<SparseIterator> &&w)
- : SparseIterator(IterKind::kSubSect, w->tid, w->lvl), parent(parent),
- wrap(std::move(w)) {}
-
- // For LLVM-style RTTI.
- static bool classof(const SparseIterator *from) {
- return from->kind == IterKind::kSubSect;
+ Value fromWrapCrd(OpBuilder &b, Location l, Value wrapCrd) {
+ assert(stride == 1);
+ return SUBI(wrapCrd, subSect.getAbsOff());
}
+public:
+ SubSectIterator(const NonEmptySubSectIterator &subSect,
+ const SparseIterator &parent,
+ std::unique_ptr<SparseIterator> &&wrap, Value size,
+ unsigned stride)
+ : SparseIterator(IterKind::kSubSect, wrap.get()), subSect(subSect),
+ parent(parent), wrap(std::move(wrap)), size(size), stride(stride) {
+ assert(stride == 1 && "Not implemented.");
+ assert(subSect.tid == tid && subSect.lvl == lvl);
+ // The immediate parents of a subsection iterator is either a non-empty
+ // subsect iterator or another subsection iterator for the previous level
+ // depending on the index varaiables' reduction order.
+ assert(parent.kind == IterKind::kNonEmptySubSect ||
+ parent.kind == IterKind::kSubSect);
+ assert(parent.kind != IterKind::kNonEmptySubSect || &parent == &subSect);
+ assert(parent.kind != IterKind::kSubSect || parent.lvl + 1 == lvl);
+ };
+
bool randomAccessible() const override { return wrap->randomAccessible(); };
bool iteratableByFor() const override { return randomAccessible(); };
+ Value upperBound(OpBuilder &b, Location l) const override { return size; }
+ std::pair<Value, Value> getCurPosition() const override {
+ return wrap->getCurPosition();
+ };
+
+ void genInit(OpBuilder &b, Location l, const SparseIterator *) override {
+ Value tupleId;
+ if (llvm::isa<NonEmptySubSectIterator>(parent)) {
+ tupleId = C_IDX(0);
+ } else {
+ llvm_unreachable("Not implemented");
+ }
+ wrap->deserialize(subSect.loadItVals(b, l, tupleId));
+ }
+
+ Value genNotEnd(OpBuilder &b, Location l) override {
+ assert(!wrap->randomAccessible());
+ ValueRange r = genWhenInBound(
+ b, l, *wrap, C_FALSE, [this](OpBuilder &b, Location l, Value wrapCrd) {
+ Value crd = fromWrapCrd(b, l, wrapCrd);
+ // crd < size
+ YIELD(CMPI(ult, crd, size));
+ });
+ assert(r.size() == 1);
+ return r.front();
+ }
+
+ Value deref(OpBuilder &b, Location l) override {
+ Value wrapCrd = wrap->deref(b, l);
+ Value crd = fromWrapCrd(b, l, wrapCrd);
+ updateCrd(crd);
+ return crd;
+ };
+
+ ValueRange forward(OpBuilder &b, Location l) override {
+ return wrap->forward(b, l);
+ };
+
+ const NonEmptySubSectIterator &subSect;
+ const SparseIterator &parent;
- const SparseIterator *parent;
std::unique_ptr<SparseIterator> wrap;
+ Value size;
+ unsigned stride;
};
-*/
+
} // namespace
//===----------------------------------------------------------------------===//
-// SparseIterator derived classes impl.
+// Complex SparseIterator derived classes impl.
//===----------------------------------------------------------------------===//
ValueRange SparseIterator::forwardIf(OpBuilder &b, Location l, Value cond) {
@@ -512,24 +686,6 @@ Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) {
return whileOp.getResult(0);
}
-ValueRange FilterIterator::genWhenWrapInBound(
- OpBuilder &b, Location l, ValueRange elseRet,
- llvm::function_ref<ValueRange(OpBuilder &, Location, Value)> builder) {
- // !it.end() ? callback(*crd) : resOOB;
- TypeRange ifRetTypes = elseRet.getTypes();
- auto ifOp = b.create<scf::IfOp>(l, ifRetTypes, wrap->genNotEnd(b, l), true);
-
- b.setInsertionPointToStart(ifOp.thenBlock());
- Value wrapCrd = wrap->deref(b, l);
- YIELD(builder(b, l, wrapCrd));
-
- b.setInsertionPointToStart(ifOp.elseBlock());
- YIELD(elseRet);
-
- b.setInsertionPointAfter(ifOp);
- return ifOp.getResults();
-}
-
Value FilterIterator::genCrdNotLegitPredicate(OpBuilder &b, Location l,
Value wrapCrd) {
Value crd = fromWrapCrd(b, l, wrapCrd);
@@ -543,10 +699,10 @@ Value FilterIterator::genCrdNotLegitPredicate(OpBuilder &b, Location l,
}
Value FilterIterator::genShouldFilter(OpBuilder &b, Location l) {
- ValueRange r = genWhenWrapInBound(
- b, l, C_FALSE, [this](OpBuilder &b, Location l, Value wrapCrd) {
+ ValueRange r = genWhenInBound(
+ b, l, *wrap, C_FALSE, [this](OpBuilder &b, Location l, Value wrapCrd) {
Value notLegit = genCrdNotLegitPredicate(b, l, wrapCrd);
- return notLegit.getDefiningOp()->getResults();
+ YIELD(notLegit);
});
assert(r.size() == 1);
@@ -555,11 +711,11 @@ Value FilterIterator::genShouldFilter(OpBuilder &b, Location l) {
Value FilterIterator::genNotEnd(OpBuilder &b, Location l) {
assert(!wrap->randomAccessible());
- ValueRange r = genWhenWrapInBound(
- b, l, C_FALSE, [this](OpBuilder &b, Location l, Value wrapCrd) {
+ ValueRange r = genWhenInBound(
+ b, l, *wrap, C_FALSE, [this](OpBuilder &b, Location l, Value wrapCrd) {
Value crd = fromWrapCrd(b, l, wrapCrd);
// crd < size
- return CMPI(ult, crd, size).getDefiningOp()->getResults();
+ YIELD(CMPI(ult, crd, size));
});
assert(r.size() == 1);
return r.front();
@@ -578,14 +734,16 @@ ValueRange FilterIterator::forward(OpBuilder &b, Location l) {
/*beforeBuilder=*/
[this](OpBuilder &b, Location l, ValueRange ivs) {
linkNewScope(ivs);
- ValueRange cont = genWhenWrapInBound(
- b, l, C_FALSE, [this](OpBuilder &b, Location l, Value wrapCrd) {
- // crd < size && !legit();
- Value notLegit = genCrdNotLegitPredicate(b, l, wrapCrd);
- Value crd = fromWrapCrd(b, l, wrapCrd);
- Value ret = ANDI(CMPI(ult, crd, size), notLegit);
- return ret.getDefiningOp()->getResults();
- });
+ ValueRange cont =
+ genWhenInBound(b, l, *wrap, C_FALSE,
+ [this](OpBuilder &b, Location l, Value wrapCrd) {
+ // crd < size && !legit();
+ Value notLegit =
+ genCrdNotLegitPredicate(b, l, wrapCrd);
+ Value crd = fromWrapCrd(b, l, wrapCrd);
+ Value ret = ANDI(CMPI(ult, crd, size), notLegit);
+ YIELD(ret);
+ });
b.create<scf::ConditionOp>(l, cont.front(), ivs);
},
/*afterBuilder=*/
@@ -600,6 +758,132 @@ ValueRange FilterIterator::forward(OpBuilder &b, Location l) {
return getItVals();
}
+void NonEmptySubSectIterator::genInit(OpBuilder &b, Location l,
+ const SparseIterator *) {
+ auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
+ if (p) {
+ llvm_unreachable("Not implemented");
+ } else {
+ wrap->genInit(b, l, parent);
+ Value c0 = C_IDX(0);
+ if (randomAccessible()) {
+ seek({/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE});
+ return;
+ }
+ // Handle sparse subsection iterator.
+ tupleCnt = C_IDX(1);
+ SmallVector<Value> elseRet{c0, c0, /*notEnd=*/C_FALSE};
+ ValueRange meta = genWhenInBound(
+ b, l, *wrap, elseRet, [this](OpBuilder &b, Location l, Value crd) {
+ Value offset = offsetFromMinCrd(b, l, crd, subSectSz);
+ YIELD((ValueRange{crd, offset, C_TRUE}));
+ });
+
+ seek(meta);
+ SmallVector<Value> itVals = wrap->serialize();
+ storeItVals(b, l, c0, itVals);
+ }
+}
+
+ValueRange NonEmptySubSectIterator::forward(OpBuilder &b, Location l) {
+ assert(!randomAccessible());
+ Value c0 = C_IDX(0), c1 = C_IDX(1);
+ // Forward to the next non empty slice by generating
+ //
+ // if (minCrd > offset) {
+ // offset += 1
+ // } else {
+ // minCrd = nextMinInSlice();
+ // offset = minCrd - size + 1;
+ // }
+ //
+ // if (offset + size > parents.size)
+ // isNonEmpty = false;
+ Value fastPathP = CMPI(ugt, getMinCrd(), getAbsOff());
+ auto ifOp = b.create<scf::IfOp>(l, getItVals().getTypes(), fastPathP, true);
+ {
+ OpBuilder::InsertionGuard guard(b);
+ // Take the fast path
+ // if (minCrd > offset)
+ // offset += 1
+ b.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ Value nxOffset = ADDI(getAbsOff(), c1);
+ YIELD((ValueRange{getMinCrd(), nxOffset, getNotEnd()}));
+
+ // else /*minCrd == offset*/ {
+ // for (i = 0; i < tupleCnt; i++) {
+ // wrap->deserialize(pos[i]);
+ // minCrd=min(minCrd, *wrap);
+ // }
+ // offset = minCrd - size + 1;
+ // }
+ b.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ ValueRange loopArgs{upperBound(b, l), // nextMinCrd
+ C_FALSE}; // isNotEnd
+ auto loopNest = scf::buildLoopNest(
+ b, l, c0, tupleCnt, c1, loopArgs,
+ [this](OpBuilder &b, Location l, ValueRange ivs,
+ ValueRange iterArgs) -> scf::ValueVector {
+ Value tupleId = ivs.front();
+ SmallVector<Value> itVals = loadItVals(b, l, tupleId);
+ wrap->deserialize(itVals);
+ return genWhenInBound(
+ b, l, *wrap, /*elseRet=*/iterArgs,
+ [this, iterArgs, tupleId](OpBuilder &b, Location l, Value crd) {
+ // if coord == minCrd
+ // wrap->forward();
+ Value isMin = CMPI(eq, crd, getMinCrd());
+ wrap->forwardIf(b, l, isMin);
+ // Update the forwarded iterator values if needed.
+ auto ifIsMin = b.create<scf::IfOp>(l, isMin, false);
+ b.setInsertionPointToStart(&ifIsMin.getThenRegion().front());
+ storeItVals(b, l, tupleId, wrap->serialize());
+ b.setInsertionPointAfter(ifIsMin);
+ // if (!wrap.end())
+ // yield(min(nxMinCrd, *wrap), true)
+ Value nxMin = iterArgs[0];
+ ValueRange ret = genWhenInBound(
+ b, l, *wrap, /*elseRet=*/iterArgs,
+ [nxMin](OpBuilder &b, Location l, Value crd) {
+ Value nx = SELECT(CMPI(ult, crd, nxMin), crd, nxMin);
+ YIELD((ValueRange{nx, C_TRUE}));
+ });
+ YIELD(ret);
+ });
+ });
+
+ scf::ForOp forOp = loopNest.loops.front();
+ b.setInsertionPointAfter(forOp);
+
+ Value nxMinCrd = forOp.getResult(0);
+ Value nxNotEnd = forOp.getResult(1);
+ Value nxAbsOff = offsetFromMinCrd(b, l, nxMinCrd, subSectSz);
+ YIELD((ValueRange{nxMinCrd, nxAbsOff, nxNotEnd}));
+ }
+
+ Value nxMinCrd = ifOp.getResult(0);
+ Value nxAbsOff = ifOp.getResult(1);
+ Value nxNotEnd = ifOp.getResult(2);
+
+ // We should at least forward the offset by one.
+ Value minAbsOff = ADDI(getAbsOff(), c1);
+ nxAbsOff = SELECT(CMPI(ugt, minAbsOff, nxAbsOff), minAbsOff, nxAbsOff);
+
+ assert(stride == 1 && "Not yet implemented");
+
+ seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
+ // The coordinate should not exceeds the space upper bound.
+ Value crd = deref(b, l);
+ nxNotEnd = ANDI(nxNotEnd, CMPI(ult, crd, upperBound(b, l)));
+
+ seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
+ return getItVals();
+}
+
+//===----------------------------------------------------------------------===//
+// SparseIterator factory functions.
+//===----------------------------------------------------------------------===//
+
std::unique_ptr<SparseTensorLevel>
sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
unsigned tid, Level lvl) {
@@ -661,15 +945,16 @@ sparse_tensor::makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit,
std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
OpBuilder &b, Location l, const SparseIterator *parent,
std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride) {
- return nullptr;
- // return std::make_unique<NonEmptySubSectIterator>(
- // b, l, parent, std::move(lvlIt), size, stride);
+ return std::make_unique<NonEmptySubSectIterator>(
+ b, l, parent, std::move(delegate), size, stride);
}
std::unique_ptr<SparseIterator> sparse_tensor::makeTraverseSubSectIterator(
- const SparseIterator *, std::unique_ptr<SparseIterator> &&delegate) {
- return nullptr;
- // return std::make_unique<SubSectIterator>(parent, std::move(lvlIt));
+ const SparseIterator &subsectIter, const SparseIterator &parent,
+ std::unique_ptr<SparseIterator> &&wrap, Value size, unsigned stride) {
+ return std::make_unique<SubSectIterator>(
+ llvm::cast<NonEmptySubSectIterator>(subsectIter), parent, std::move(wrap),
+ size, stride);
}
#undef CMPI
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
index 770a6eb9b78d1f..bf366ad2cdad2d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
@@ -109,6 +109,23 @@ class SparseIterator {
virtual bool randomAccessible() const = 0;
// Whether the iterator can simply traversed by a for loop.
virtual bool iteratableByFor() const { return false; };
+ // Get the upper bound of the sparse space that the iterator might visited. A
+ // sparse space is a subset of a dense space [0, bound), this function returns
+ // *bound*.
+ virtual Value upperBound(OpBuilder &b, Location l) const = 0;
+
+ // Serialize and deserialize the current status to/from a set of values. The
+ // ValueRange should contain values that specifies the postion and loop bound.
+ //
+ // Not every type of iterator supports the operations, e.g., non-empty
+ // subsection iterator does not because the the number of non-empty
+ // subsections can not be determined in advance.
+ //
+ // NOTE: All the values should have index type.
+ virtual SmallVector<Value> serialize() const {
+ llvm_unreachable("unsupported");
+ };
+ virtual void deserialize(ValueRange vs) { llvm_unreachable("unsupported"); };
//
// Core functions.
@@ -127,8 +144,7 @@ class SparseIterator {
// Initialize the iterator according to the parent iterator's state.
virtual void genInit(OpBuilder &, Location, const SparseIterator *) = 0;
- // Return a tuple of values for *upper*, *lower* bound and *step*
- // respectively.
+ // Return a pair of values for *upper*, *lower* bound respectively.
virtual std::pair<Value, Value> genForCond(OpBuilder &, Location) {
llvm_unreachable("Unsupported");
}
@@ -136,8 +152,8 @@ class SparseIterator {
virtual Value genNotEnd(OpBuilder &b, Location l) = 0;
std::pair<Value, ValueRange> genWhileCond(OpBuilder &b, Location l,
ValueRange vs) {
- seek(vs.take_front(itVals.size()));
- return std::make_pair(genNotEnd(b, l), vs.drop_front(itVals.size()));
+ ValueRange rem = linkNewScope(vs);
+ return std::make_pair(genNotEnd(b, l), rem);
}
// Dereference the iterator, loads the coordinate at the current position.
@@ -213,11 +229,11 @@ makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit, Value offset,
std::unique_ptr<SparseIterator> makeNonEmptySubSectIterator(
OpBuilder &b, Location l, const SparseIterator *parent,
- std::unique_ptr<SparseIterator> &&lvlIt, Value size, unsigned stride);
+ std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride);
-std::unique_ptr<SparseIterator>
-makeTraverseSubSectIterator(const SparseIterator *parent,
- std::unique_ptr<SparseIterator> &&lvlIt);
+std::unique_ptr<SparseIterator> makeTraverseSubSectIterator(
+ const SparseIterator &subsectIter, const SparseIterator &parent,
+ std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride);
} // namespace sparse_tensor
} // namespace mlir
>From 0fae1491af80821beb5ff3b39a688aad0a682f36 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 10 Jan 2024 00:42:38 +0000
Subject: [PATCH 04/11] support randomly accessible non-empty subsection
iterator.
---
.../Transforms/Utils/SparseTensorLevel.cpp | 73 +++++++++++++++----
1 file changed, 58 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 79ba3230ac068d..676f7b40a6e9bb 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -229,18 +229,25 @@ class TrivialIterator : public SparseIterator {
};
SmallVector<Value> serialize() const override {
- assert(!randomAccessible());
SmallVector<Value> ret;
- ret.push_back(itPos);
- ret.push_back(loopHi);
+ if (randomAccessible())
+ ret.push_back(posLo);
+ else {
+ ret.push_back(itPos);
+ ret.push_back(loopHi);
+ }
return ret;
};
void deserialize(ValueRange vs) override {
- assert(!randomAccessible());
- assert(vs.size() == 2);
- seek(vs.front());
- loopHi = vs.back();
+ if (randomAccessible()) {
+ assert(vs.size() == 1);
+ posLo = vs.front();
+ } else {
+ assert(vs.size() == 2);
+ seek(vs.front());
+ loopHi = vs.back();
+ }
};
ValuePair getCurPosition() const override { return {itPos, nullptr}; }
@@ -335,14 +342,12 @@ class DedupIterator : public SparseIterator {
}
SmallVector<Value> serialize() const override {
- assert(!randomAccessible());
SmallVector<Value> ret;
ret.append(getItVals().begin(), getItVals().end());
ret.push_back(posHi);
return ret;
};
void deserialize(ValueRange vs) override {
- assert(!randomAccessible());
assert(vs.size() == 3);
seek(vs.take_front(getItVals().size()));
posHi = vs.back();
@@ -488,8 +493,8 @@ class NonEmptySubSectIterator : public SparseIterator {
Value subSectSz, unsigned stride)
: SparseIterator(IterKind::kNonEmptySubSect, wrap->tid, wrap->lvl,
/*itVals=*/subSectMeta),
- tupleSz(wrap->serialize().size()), subSectSz(subSectSz), stride(stride),
- parent(parent), wrap(std::move(wrap)) {
+ subSectSz(subSectSz), stride(stride), parent(parent),
+ wrap(std::move(wrap)) {
auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
assert(stride == 1);
@@ -509,6 +514,7 @@ class NonEmptySubSectIterator : public SparseIterator {
if (randomAccessible())
return;
+ tupleSz = this->wrap->serialize().size();
subSectPosBuf = allocSubSectPosBuf(b, l);
}
@@ -528,6 +534,22 @@ class NonEmptySubSectIterator : public SparseIterator {
void genInit(OpBuilder &b, Location l, const SparseIterator *) override;
+ std::pair<Value, Value> genForCond(OpBuilder &b, Location l) override {
+ // Yield a dense range [curCrd, upperBound).
+ return {deref(b, l), upperBound(b, l)};
+ }
+
+ void locate(OpBuilder &b, Location l, Value crd) override {
+ Value absOff = crd;
+ auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
+ if (p && p->lvl == lvl)
+ absOff = ADDI(crd, p->getAbsOff());
+
+ wrap->locate(b, l, absOff);
+ seek(ValueRange{absOff, absOff, C_TRUE});
+ updateCrd(crd);
+ }
+
Value genNotEnd(OpBuilder &b, Location l) override { return getNotEnd(); };
Value deref(OpBuilder &b, Location l) override {
@@ -548,9 +570,13 @@ class NonEmptySubSectIterator : public SparseIterator {
Value getAbsOff() const { return subSectMeta[1]; }
Value getNotEnd() const { return subSectMeta[2]; }
+ // Number of values required to serialize the wrapped iterator.
+ unsigned tupleSz;
+ // Max number of tuples, and the actual number of tuple.
Value maxTupleCnt, tupleCnt;
+ // The memory used to cache the tuple serialized from the wrapped iterator.
Value subSectPosBuf;
- const unsigned tupleSz;
+
const Value subSectSz;
const unsigned stride;
@@ -594,13 +620,30 @@ class SubSectIterator : public SparseIterator {
};
void genInit(OpBuilder &b, Location l, const SparseIterator *) override {
- Value tupleId;
if (llvm::isa<NonEmptySubSectIterator>(parent)) {
- tupleId = C_IDX(0);
+ if (randomAccessible()) {
+ // A dense range can be inferred without caching.
+ wrap->deserialize(subSect.wrap->serialize());
+ // Locate the random accessible iterator to the offset of the
+ // subsection to iterate over [offset, offset + size) later.
+ wrap->locate(b, l, subSect.getAbsOff());
+ return;
+ }
+ wrap->deserialize(subSect.loadItVals(b, l, C_IDX(0)));
} else {
llvm_unreachable("Not implemented");
}
- wrap->deserialize(subSect.loadItVals(b, l, tupleId));
+ }
+
+ std::pair<Value, Value> genForCond(OpBuilder &b, Location l) override {
+ // Yield a dense range [curCrd, upperBound).
+ return {deref(b, l), upperBound(b, l)};
+ }
+
+ void locate(OpBuilder &b, Location l, Value crd) override {
+ Value absCrd = ADDI(crd, subSect.getAbsOff());
+ wrap->locate(b, l, absCrd);
+ updateCrd(crd);
}
Value genNotEnd(OpBuilder &b, Location l) override {
>From bcfcb4972826b8f4d2dce182bc678d5e278397e5 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 10 Jan 2024 19:09:05 +0000
Subject: [PATCH 05/11] provide default genForCond() implementation for
random-access iterator
---
.../Transforms/Utils/SparseTensorLevel.cpp | 77 ++++++++-----------
.../Transforms/Utils/SparseTensorLevel.h | 6 +-
2 files changed, 34 insertions(+), 49 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 676f7b40a6e9bb..0cab3d1ebef72d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -230,24 +230,24 @@ class TrivialIterator : public SparseIterator {
SmallVector<Value> serialize() const override {
SmallVector<Value> ret;
- if (randomAccessible())
+ ret.push_back(itPos);
+ if (randomAccessible()) {
+ // Loop high is implicit (defined by `upperBound()`) for random-access
+ // iterator, but we need to memorize posLo for linearization.
ret.push_back(posLo);
- else {
- ret.push_back(itPos);
- ret.push_back(loopHi);
+ } else {
+ ret.push_back(posHi);
}
return ret;
};
void deserialize(ValueRange vs) override {
- if (randomAccessible()) {
- assert(vs.size() == 1);
- posLo = vs.front();
- } else {
- assert(vs.size() == 2);
- seek(vs.front());
- loopHi = vs.back();
- }
+ assert(vs.size() == 2);
+ seek(vs.front());
+ if (randomAccessible())
+ posLo = vs.back();
+ else
+ posHi = vs.back();
};
ValuePair getCurPosition() const override { return {itPos, nullptr}; }
@@ -259,23 +259,28 @@ class TrivialIterator : public SparseIterator {
if (parent)
std::tie(pos, hi) = parent->getCurPosition();
- std::tie(posLo, loopHi) = stl.peekRangeAt(b, l, pos, hi);
+ std::tie(posLo, posHi) = stl.peekRangeAt(b, l, pos, hi);
// Seek to the lowest position.
seek(posLo);
}
ValuePair genForCond(OpBuilder &b, Location l) override {
- assert(iteratableByFor());
- return std::make_pair(getLoopLo(b, l), loopHi);
+ if (randomAccessible())
+ return {deref(b, l), upperBound(b, l)};
+ return std::make_pair(getLoopLo(b, l), posHi);
}
Value genNotEnd(OpBuilder &b, Location l) override {
// We used the first level bound as the bound the collapsed set of levels.
- return CMPI(ult, itPos, loopHi);
+ return CMPI(ult, itPos, posHi);
}
Value deref(OpBuilder &b, Location l) override {
- updateCrd(stl.peekCrdAt(b, l, itPos));
+ if (randomAccessible()) {
+ updateCrd(SUBI(itPos, posLo));
+ } else {
+ updateCrd(stl.peekCrdAt(b, l, itPos));
+ }
return getCrd();
};
@@ -300,7 +305,7 @@ class TrivialIterator : public SparseIterator {
Value itPos; // the position that represent the iterator
- Value posLo, loopHi;
+ Value posLo, posHi;
const SparseTensorLevel &stl;
};
@@ -405,11 +410,7 @@ class FilterIterator : public SparseIterator {
bool randomAccessible() const override { return wrap->randomAccessible(); };
bool iteratableByFor() const override { return randomAccessible(); };
- Value upperBound(OpBuilder &b, Location l) const override {
- Value maxWrapCrd = SUBI(wrap->upperBound(b, l), C_IDX(1));
- Value maxCrd = fromWrapCrd(b, l, maxWrapCrd);
- return ADDI(maxCrd, C_IDX(1));
- };
+ Value upperBound(OpBuilder &b, Location l) const override { return size; };
SmallVector<Value> serialize() const override { return wrap->serialize(); };
void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
@@ -422,19 +423,13 @@ class FilterIterator : public SparseIterator {
// TODO: we can skip this when stride == 1 and offset == 0, we can also
// use binary search here.
forwardIf(b, l, genShouldFilter(b, l));
+ } else {
+ // Else, locate to the slice.offset, which is the first coordinate
+ // included by the slice.
+ wrap->locate(b, l, offset);
}
}
- ValuePair genForCond(OpBuilder &b, Location l) override {
- assert(randomAccessible());
-
- auto [lo, hi] = wrap->genForCond(b, l);
- // if offset < lo, we use lo - offset as the new lower bound, else we use 0.
- Value loInBound = CMPI(ult, offset, lo);
- lo = SELECT(loInBound, SUBI(lo, offset), C_IDX(0));
- return {lo, size};
- }
-
Value genNotEnd(OpBuilder &b, Location l) override;
Value deref(OpBuilder &b, Location l) override {
@@ -534,11 +529,6 @@ class NonEmptySubSectIterator : public SparseIterator {
void genInit(OpBuilder &b, Location l, const SparseIterator *) override;
- std::pair<Value, Value> genForCond(OpBuilder &b, Location l) override {
- // Yield a dense range [curCrd, upperBound).
- return {deref(b, l), upperBound(b, l)};
- }
-
void locate(OpBuilder &b, Location l, Value crd) override {
Value absOff = crd;
auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
@@ -622,24 +612,17 @@ class SubSectIterator : public SparseIterator {
void genInit(OpBuilder &b, Location l, const SparseIterator *) override {
if (llvm::isa<NonEmptySubSectIterator>(parent)) {
if (randomAccessible()) {
- // A dense range can be inferred without caching.
+ // We continue from the parent's offset.
wrap->deserialize(subSect.wrap->serialize());
- // Locate the random accessible iterator to the offset of the
- // subsection to iterate over [offset, offset + size) later.
- wrap->locate(b, l, subSect.getAbsOff());
return;
}
+ // Else deserializing from the cached values.
wrap->deserialize(subSect.loadItVals(b, l, C_IDX(0)));
} else {
llvm_unreachable("Not implemented");
}
}
- std::pair<Value, Value> genForCond(OpBuilder &b, Location l) override {
- // Yield a dense range [curCrd, upperBound).
- return {deref(b, l), upperBound(b, l)};
- }
-
void locate(OpBuilder &b, Location l, Value crd) override {
Value absCrd = ADDI(crd, subSect.getAbsOff());
wrap->locate(b, l, absCrd);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
index bf366ad2cdad2d..6f6d28e24c2750 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
@@ -145,8 +145,10 @@ class SparseIterator {
virtual void genInit(OpBuilder &, Location, const SparseIterator *) = 0;
// Return a pair of values for *upper*, *lower* bound respectively.
- virtual std::pair<Value, Value> genForCond(OpBuilder &, Location) {
- llvm_unreachable("Unsupported");
+ virtual std::pair<Value, Value> genForCond(OpBuilder &b, Location l) {
+ assert(randomAccessible());
+ // Random-access iterator is traversed by coordinate, i.e., [curCrd, UB).
+ return {deref(b, l), upperBound(b, l)};
}
virtual Value genNotEnd(OpBuilder &b, Location l) = 0;
>From 3ad3bc1c62dd3a5a93adb24d004cc94441528d30 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 11 Jan 2024 04:08:55 +0000
Subject: [PATCH 06/11] handle more convolution variants
---
.../Transforms/Utils/LoopEmitter.cpp | 25 +-
.../Transforms/Utils/LoopEmitter.h | 3 +
.../Transforms/Utils/SparseTensorLevel.cpp | 543 +++++++++++++-----
.../Transforms/Utils/SparseTensorLevel.h | 24 +-
4 files changed, 445 insertions(+), 150 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index da0d339427920f..08a326eda36bfd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -543,17 +543,19 @@ void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) {
std::sort(depRedOrder.begin(), depRedOrder.end(),
[](auto &l, auto &r) { return std::get<0>(l) < std::get<0>(r); });
+ SmallVector<SparseIterator *> lastIter(tensors.size(), nullptr);
for (auto [loop, t, lvl] : depRedOrder) {
std::pair<LoopId, unsigned> curDep = remDepStack[t][lvl].back();
assert(curDep.first == loop);
remDepStack[t][lvl].pop_back();
auto lvlIt = makeLevelIterator(builder, loc, t, lvl);
- const SparseIterator *parent =
- lvl == 0 && iters[t][lvl].empty()
- ? nullptr
- : (!iters[t][lvl].empty() ? iters[t][lvl].back().get()
- : iters[t][lvl - 1].back().get());
+ const SparseIterator *parent = lastIter[t];
+ if (!parent && lvl > 0) {
+ if (dependentLvlMap[t][lvl - 1].empty()) {
+ parent = iters[t][lvl - 1].back().get();
+ }
+ }
std::unique_ptr<SparseIterator> it;
if (!remDepStack[t][lvl].empty()) {
@@ -571,6 +573,7 @@ void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) {
it = makeTraverseSubSectIterator(subSectIter, *parent, std::move(lvlIt),
size, curDep.second);
}
+ lastIter[t] = it.get();
iters[t][lvl].emplace_back(std::move(it));
}
}
@@ -1343,10 +1346,10 @@ void LoopEmitter::genDenseAffineAddress(OpBuilder &builder, Location loc,
TensorLevel tidLvl,
AffineExpr lvlExpr) {
auto [tid, lvl] = unpackTensorLevel(tidLvl);
- assert(isDenseLT(lvlTypes[tid][lvl]));
- // For dense levels, the vel-coordinate also serves as the position.
+ auto &it = getCurIterator(tid, lvl);
+ assert(it.kind == IterKind::kTrivial && it.randomAccessible());
Value lvlCrd = genAffine(builder, loc, lvlExpr);
- posits[tid][lvl] = genAddress(builder, loc, tid, lvl, lvlCrd);
+ it.locate(builder, loc, lvlCrd);
}
void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
@@ -1359,7 +1362,11 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
const SparseIterator *parent =
hasParent ? nullptr : iters[tid][lvl - 1].back().get();
- getCurIterator(tid, lvl).genInit(builder, loc, parent);
+ auto &it = getCurIterator(tid, lvl);
+ it.genInit(builder, loc, parent);
+ if (it.randomAccessible()) {
+ it.locate(builder, loc, C_IDX(0));
+ }
}
void LoopEmitter::enterTensorsAtDenseLvls(
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index aafb56f03ef607..2bd2b653a4d9f3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -558,6 +558,9 @@ class LoopEmitter {
unsigned redDepOnLevel(TensorId tid, Level lvl) const;
SparseIterator &getCurIterator(TensorId tid, Level lvl) const {
+ if (dependentLvlMap[tid][lvl].empty())
+ return *iters[tid][lvl].back();
+
assert(redDepOnLevel(tid, lvl) >= 1);
return *iters[tid][lvl][redDepOnLevel(tid, lvl) - 1];
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 0cab3d1ebef72d..c7bc365b89c32d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -34,6 +34,7 @@ using ValueTuple = std::tuple<Value, Value, Value>;
#define ANDI(lhs, rhs) (b.create<arith::AndIOp>(l, (lhs), (rhs)).getResult())
#define SUBI(lhs, rhs) (b.create<arith::SubIOp>(l, (lhs), (rhs)).getResult())
#define MULI(lhs, rhs) (b.create<arith::MulIOp>(l, (lhs), (rhs)).getResult())
+#define MINUI(lhs, rhs) (b.create<arith::MinUIOp>(l, (lhs), (rhs)).getResult())
#define REMUI(lhs, rhs) (b.create<arith::RemUIOp>(l, (lhs), (rhs)).getResult())
#define DIVUI(lhs, rhs) (b.create<arith::DivUIOp>(l, (lhs), (rhs)).getResult())
#define SELECT(c, lhs, rhs) \
@@ -159,16 +160,28 @@ class TwoOutFourLevel : public SparseLevel {
// File local helpers
//===----------------------------------------------------------------------===//
-static ValueRange
-genWhenInBound(OpBuilder &b, Location l, SparseIterator &it, ValueRange elseRet,
- llvm::function_ref<void(OpBuilder &, Location, Value)> builder) {
+static scf::ValueVector genWhenInBound(
+ OpBuilder &b, Location l, SparseIterator &it, ValueRange elseRet,
+ llvm::function_ref<scf::ValueVector(OpBuilder &, Location, Value)>
+ builder) {
+ // Value isNotEnd = it.genNotEnd(b, l);
+ // Value crd = it.deref(b, l);
+ // scf::ValueVector ret = builder(b, l, crd);
+
+ // scf::ValueVector res;
+ // for (auto [notEnd, end] : llvm::zip_equal(ret, elseRet)) {
+ // res.push_back(SELECT(isNotEnd, notEnd, end));
+ // };
+ // return res;
+
// !it.end() ? callback(*crd) : resOOB;
TypeRange ifRetTypes = elseRet.getTypes();
auto ifOp = b.create<scf::IfOp>(l, ifRetTypes, it.genNotEnd(b, l), true);
b.setInsertionPointToStart(ifOp.thenBlock());
Value crd = it.deref(b, l);
- builder(b, l, crd);
+ scf::ValueVector ret = builder(b, l, crd);
+ YIELD(ret);
b.setInsertionPointToStart(ifOp.elseBlock());
YIELD(elseRet);
@@ -398,10 +411,10 @@ class FilterIterator : public SparseIterator {
Value genShouldFilter(OpBuilder &b, Location l);
public:
- FilterIterator(std::unique_ptr<SparseIterator> &&w, Value offset,
+ FilterIterator(std::unique_ptr<SparseIterator> &&wrap, Value offset,
Value stride, Value size)
- : SparseIterator(IterKind::kFilter, w.get()), offset(offset),
- stride(stride), size(size), wrap(std::move(w)) {}
+ : SparseIterator(IterKind::kFilter, *wrap), offset(offset),
+ stride(stride), size(size), wrap(std::move(wrap)) {}
// For LLVM-style RTTI.
static bool classof(const SparseIterator *from) {
@@ -449,47 +462,19 @@ class FilterIterator : public SparseIterator {
std::unique_ptr<SparseIterator> wrap;
};
-class SubSectIterator;
class NonEmptySubSectIterator : public SparseIterator {
-
- // The sliced pointer buffer is organized as:
- // [[itVal0, itVal1, ..., pNx0],
- // [itVal0, itVal1, ..., pNx0],
- // ...]
- Value allocSubSectPosBuf(OpBuilder &b, Location l) {
- return b.create<memref::AllocaOp>(
- l,
- MemRefType::get({ShapedType::kDynamic, tupleSz + 1}, b.getIndexType()),
- maxTupleCnt);
- }
-
- SmallVector<Value> loadItVals(OpBuilder &b, Location l, Value tupleId) const {
- SmallVector<Value> ret;
- for (unsigned i = 0; i < tupleSz; i++) {
- Value v = b.create<memref::LoadOp>(l, subSectPosBuf,
- ValueRange{tupleId, C_IDX(i)});
- ret.push_back(v);
- }
- return ret;
- }
-
- void storeItVals(OpBuilder &b, Location l, Value tupleId, ValueRange itVals) {
- assert(itVals.size() == tupleSz);
- for (unsigned i = 0; i < tupleSz; i++) {
- b.create<memref::StoreOp>(l, itVals[i], subSectPosBuf,
- ValueRange{tupleId, C_IDX(i)});
- }
- }
-
public:
+ using TraverseBuilder = llvm::function_ref<scf::ValueVector(
+ OpBuilder &, Location, const SparseIterator *, ValueRange)>;
+
NonEmptySubSectIterator(OpBuilder &b, Location l,
const SparseIterator *parent,
- std::unique_ptr<SparseIterator> &&wrap,
+ std::unique_ptr<SparseIterator> &&delegate,
Value subSectSz, unsigned stride)
- : SparseIterator(IterKind::kNonEmptySubSect, wrap->tid, wrap->lvl,
+ : SparseIterator(IterKind::kNonEmptySubSect, delegate->tid, delegate->lvl,
/*itVals=*/subSectMeta),
subSectSz(subSectSz), stride(stride), parent(parent),
- wrap(std::move(wrap)) {
+ delegate(std::move(delegate)) {
auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
assert(stride == 1);
@@ -508,38 +493,95 @@ class NonEmptySubSectIterator : public SparseIterator {
// We don't need an extra buffer to find subsections on dense levels.
if (randomAccessible())
return;
-
- tupleSz = this->wrap->serialize().size();
+ // The number of values we need to store to serialize the wrapped iterator.
+ tupleSz = this->delegate->serialize().size();
subSectPosBuf = allocSubSectPosBuf(b, l);
}
- bool randomAccessible() const override { return wrap->randomAccessible(); };
+ // For LLVM-style RTTI.
+ static bool classof(const SparseIterator *from) {
+ return from->kind == IterKind::kNonEmptySubSect;
+ }
+
+ // The sliced pointer buffer is organized as:
+ // [[itVal0, itVal1, ..., pNx0],
+ // [itVal0, itVal1, ..., pNx0],
+ // ...]
+ Value allocSubSectPosBuf(OpBuilder &b, Location l) {
+ return b.create<memref::AllocaOp>(
+ l,
+ MemRefType::get({ShapedType::kDynamic, tupleSz + 1}, b.getIndexType()),
+ maxTupleCnt);
+ }
+
+ void storeNxLvlStart(OpBuilder &b, Location l, Value tupleId,
+ Value start) const {
+ b.create<memref::StoreOp>(l, start, subSectPosBuf,
+ ValueRange{tupleId, C_IDX(tupleSz)});
+ }
+
+ Value loadNxLvlStart(OpBuilder &b, Location l, Value tupleId) const {
+ return b.create<memref::LoadOp>(l, subSectPosBuf,
+ ValueRange{tupleId, C_IDX(tupleSz)});
+ }
+
+ void storeItVals(OpBuilder &b, Location l, Value tupleId,
+ ValueRange itVals) const {
+ assert(itVals.size() == tupleSz);
+ for (unsigned i = 0; i < tupleSz; i++) {
+ b.create<memref::StoreOp>(l, itVals[i], subSectPosBuf,
+ ValueRange{tupleId, C_IDX(i)});
+ }
+ }
+
+ SmallVector<Value> loadItVals(OpBuilder &b, Location l, Value tupleId) const {
+ SmallVector<Value> ret;
+ for (unsigned i = 0; i < tupleSz; i++) {
+ Value v = b.create<memref::LoadOp>(l, subSectPosBuf,
+ ValueRange{tupleId, C_IDX(i)});
+ ret.push_back(v);
+ }
+ return ret;
+ }
+
+ bool isSubSectRoot() const {
+ return !parent || !llvm::isa<NonEmptySubSectIterator>(parent);
+ }
+
+ ValueRange genSubSectTraverseTillRoot(OpBuilder &b, Location l,
+ ValueRange reduc,
+ TraverseBuilder builder) const;
+
+ bool randomAccessible() const override {
+ return delegate->randomAccessible();
+ };
bool iteratableByFor() const override { return randomAccessible(); };
Value upperBound(OpBuilder &b, Location l) const override {
auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
Value parentUB =
- p && p->lvl == lvl ? p->upperBound(b, l) : wrap->upperBound(b, l);
+ p && p->lvl == lvl ? p->upperBound(b, l) : delegate->upperBound(b, l);
return ADDI(SUBI(parentUB, subSectSz), C_IDX(1));
};
- // For LLVM-style RTTI.
- static bool classof(const SparseIterator *from) {
- return from->kind == IterKind::kNonEmptySubSect;
- }
-
void genInit(OpBuilder &b, Location l, const SparseIterator *) override;
void locate(OpBuilder &b, Location l, Value crd) override {
Value absOff = crd;
auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
- if (p && p->lvl == lvl)
- absOff = ADDI(crd, p->getAbsOff());
+ if (isSubSectRoot())
+ delegate->locate(b, l, absOff);
+ else
+ assert(p->lvl + 1 == lvl);
- wrap->locate(b, l, absOff);
seek(ValueRange{absOff, absOff, C_TRUE});
updateCrd(crd);
}
+ Value toSubSectCrd(OpBuilder &b, Location l, Value wrapCrd) const {
+ assert(stride == 1);
+ return SUBI(wrapCrd, getAbsOff());
+ }
+
Value genNotEnd(OpBuilder &b, Location l) override { return getNotEnd(); };
Value deref(OpBuilder &b, Location l) override {
@@ -571,37 +613,73 @@ class NonEmptySubSectIterator : public SparseIterator {
const unsigned stride;
const SparseIterator *parent;
- std::unique_ptr<SparseIterator> wrap;
+ std::unique_ptr<SparseIterator> delegate;
Value subSectMeta[3]; // minCrd, absolute offset, notEnd
+};
+
+class SubSectIterator;
+
+// A simple helper that helps generating code to traverse a subsection, used
+// by both `NonEmptySubSectIterator`and `SubSectIterator`.
+struct SubSectIterHelper {
+ explicit SubSectIterHelper(const SubSectIterator &iter);
+ explicit SubSectIterHelper(const NonEmptySubSectIterator &subSect);
+
+ // Delegate methods.
+ void deserializeFromTupleId(OpBuilder &b, Location l, Value tupleId);
+ void locate(OpBuilder &b, Location l, Value crd);
+ Value genNotEnd(OpBuilder &b, Location l);
+ Value deref(OpBuilder &b, Location l);
+ ValueRange forward(OpBuilder &b, Location l);
- friend SubSectIterator;
+ const NonEmptySubSectIterator &subSect;
+ SparseIterator &wrap;
};
class SubSectIterator : public SparseIterator {
- Value fromWrapCrd(OpBuilder &b, Location l, Value wrapCrd) {
- assert(stride == 1);
- return SUBI(wrapCrd, subSect.getAbsOff());
- }
+ // RAII to sync iterator values between the wrap the iterator and the
+ // SubSectIterator.
+ struct WrapItValSyncer {
+ explicit WrapItValSyncer(SubSectIterator &it) : it(it) {
+ if (!it.randomAccessible())
+ it.wrap->seek(it.getItVals().drop_back());
+ }
+ ~WrapItValSyncer() {
+ if (!it.randomAccessible()) {
+ ValueRange wrapItVals = it.wrap->getItVals();
+ std::copy(wrapItVals.begin(), wrapItVals.end(), it.itVals.begin());
+ }
+ }
+ SubSectIterator ⁢
+ };
public:
SubSectIterator(const NonEmptySubSectIterator &subSect,
const SparseIterator &parent,
std::unique_ptr<SparseIterator> &&wrap, Value size,
unsigned stride)
- : SparseIterator(IterKind::kSubSect, wrap.get()), subSect(subSect),
- parent(parent), wrap(std::move(wrap)), size(size), stride(stride) {
+ : SparseIterator(IterKind::kSubSect, *wrap), itVals(), subSect(subSect),
+ wrap(std::move(wrap)), parent(parent), size(size), stride(stride),
+ helper(*this) {
assert(stride == 1 && "Not implemented.");
assert(subSect.tid == tid && subSect.lvl == lvl);
- // The immediate parents of a subsection iterator is either a non-empty
- // subsect iterator or another subsection iterator for the previous level
- // depending on the index varaiables' reduction order.
- assert(parent.kind == IterKind::kNonEmptySubSect ||
- parent.kind == IterKind::kSubSect);
- assert(parent.kind != IterKind::kNonEmptySubSect || &parent == &subSect);
assert(parent.kind != IterKind::kSubSect || parent.lvl + 1 == lvl);
+
+ if (!randomAccessible()) {
+ // We maintain a extra counter to count the actually sparse coordinate
+ // included in the subsection.
+ unsigned itValSz = this->wrap->getItVals().size() + 1;
+ itVals.resize(itValSz, nullptr);
+ relinkItVals(itVals);
+ }
};
+ // For LLVM-style RTTI.
+ static bool classof(const SparseIterator *from) {
+ return from->kind == IterKind::kSubSect;
+ }
+
bool randomAccessible() const override { return wrap->randomAccessible(); };
bool iteratableByFor() const override { return randomAccessible(); };
Value upperBound(OpBuilder &b, Location l) const override { return size; }
@@ -609,55 +687,85 @@ class SubSectIterator : public SparseIterator {
return wrap->getCurPosition();
};
+ Value getNxLvlTupleId(OpBuilder &b, Location l) const {
+ if (randomAccessible()) {
+ return ADDI(getCrd(), nxLvlTupleStart);
+ };
+ return ADDI(itVals.back(), nxLvlTupleStart);
+ }
+
void genInit(OpBuilder &b, Location l, const SparseIterator *) override {
- if (llvm::isa<NonEmptySubSectIterator>(parent)) {
- if (randomAccessible()) {
- // We continue from the parent's offset.
- wrap->deserialize(subSect.wrap->serialize());
- return;
+ WrapItValSyncer syncer(*this);
+ if (randomAccessible()) {
+ if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
+ assert(p->lvl + 1 == lvl);
+ wrap->genInit(b, l, p);
+ // Linearize the dense subsection index.
+ nxLvlTupleStart = MULI(size, p->getNxLvlTupleId(b, l));
+ } else {
+ assert(subSect.lvl == lvl && subSect.isSubSectRoot());
+ wrap->deserialize(subSect.delegate->serialize());
+ nxLvlTupleStart = C_IDX(0);
}
- // Else deserializing from the cached values.
- wrap->deserialize(subSect.loadItVals(b, l, C_IDX(0)));
+ return;
+ }
+ assert(!randomAccessible());
+ assert(itVals.size() == wrap->getItVals().size() + 1);
+ // Extra counter that counts the number of actually visited coordinates in
+ // the sparse subsection.
+ itVals.back() = C_IDX(0);
+ Value tupleId;
+ if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
+ assert(p->lvl + 1 == lvl);
+ tupleId = p->getNxLvlTupleId(b, l);
} else {
- llvm_unreachable("Not implemented");
+ assert(subSect.lvl == lvl && subSect.isSubSectRoot());
+ tupleId = C_IDX(0);
}
+ nxLvlTupleStart = subSect.loadNxLvlStart(b, l, tupleId);
+ helper.deserializeFromTupleId(b, l, tupleId);
}
void locate(OpBuilder &b, Location l, Value crd) override {
- Value absCrd = ADDI(crd, subSect.getAbsOff());
- wrap->locate(b, l, absCrd);
+ WrapItValSyncer syncer(*this);
+ helper.locate(b, l, crd);
updateCrd(crd);
}
Value genNotEnd(OpBuilder &b, Location l) override {
- assert(!wrap->randomAccessible());
- ValueRange r = genWhenInBound(
- b, l, *wrap, C_FALSE, [this](OpBuilder &b, Location l, Value wrapCrd) {
- Value crd = fromWrapCrd(b, l, wrapCrd);
- // crd < size
- YIELD(CMPI(ult, crd, size));
- });
- assert(r.size() == 1);
- return r.front();
+ WrapItValSyncer syncer(*this);
+ return helper.genNotEnd(b, l);
}
Value deref(OpBuilder &b, Location l) override {
- Value wrapCrd = wrap->deref(b, l);
- Value crd = fromWrapCrd(b, l, wrapCrd);
+ WrapItValSyncer syncer(*this);
+ Value crd = helper.deref(b, l);
updateCrd(crd);
return crd;
};
ValueRange forward(OpBuilder &b, Location l) override {
- return wrap->forward(b, l);
+ {
+ WrapItValSyncer syncer(*this);
+ helper.forward(b, l);
+ }
+ assert(!randomAccessible());
+ assert(itVals.size() == wrap->getItVals().size() + 1);
+ itVals.back() = ADDI(itVals.back(), C_IDX(1));
+ return getItVals();
};
+ SmallVector<Value> itVals;
+ Value nxLvlTupleStart;
+
const NonEmptySubSectIterator &subSect;
+ std::unique_ptr<SparseIterator> wrap;
const SparseIterator &parent;
- std::unique_ptr<SparseIterator> wrap;
Value size;
unsigned stride;
+
+ SubSectIterHelper helper;
};
} // namespace
@@ -725,10 +833,11 @@ Value FilterIterator::genCrdNotLegitPredicate(OpBuilder &b, Location l,
}
Value FilterIterator::genShouldFilter(OpBuilder &b, Location l) {
- ValueRange r = genWhenInBound(
- b, l, *wrap, C_FALSE, [this](OpBuilder &b, Location l, Value wrapCrd) {
+ auto r = genWhenInBound(
+ b, l, *wrap, C_FALSE,
+ [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
Value notLegit = genCrdNotLegitPredicate(b, l, wrapCrd);
- YIELD(notLegit);
+ return {notLegit};
});
assert(r.size() == 1);
@@ -737,11 +846,12 @@ Value FilterIterator::genShouldFilter(OpBuilder &b, Location l) {
Value FilterIterator::genNotEnd(OpBuilder &b, Location l) {
assert(!wrap->randomAccessible());
- ValueRange r = genWhenInBound(
- b, l, *wrap, C_FALSE, [this](OpBuilder &b, Location l, Value wrapCrd) {
+ auto r = genWhenInBound(
+ b, l, *wrap, C_FALSE,
+ [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
Value crd = fromWrapCrd(b, l, wrapCrd);
// crd < size
- YIELD(CMPI(ult, crd, size));
+ return {CMPI(ult, crd, size)};
});
assert(r.size() == 1);
return r.front();
@@ -762,13 +872,14 @@ ValueRange FilterIterator::forward(OpBuilder &b, Location l) {
linkNewScope(ivs);
ValueRange cont =
genWhenInBound(b, l, *wrap, C_FALSE,
- [this](OpBuilder &b, Location l, Value wrapCrd) {
+ [this](OpBuilder &b, Location l,
+ Value wrapCrd) -> scf::ValueVector {
// crd < size && !legit();
Value notLegit =
genCrdNotLegitPredicate(b, l, wrapCrd);
Value crd = fromWrapCrd(b, l, wrapCrd);
Value ret = ANDI(CMPI(ult, crd, size), notLegit);
- YIELD(ret);
+ return {ret};
});
b.create<scf::ConditionOp>(l, cont.front(), ivs);
},
@@ -784,31 +895,201 @@ ValueRange FilterIterator::forward(OpBuilder &b, Location l) {
return getItVals();
}
+SubSectIterHelper::SubSectIterHelper(const NonEmptySubSectIterator &subSect)
+ : subSect(subSect), wrap(*subSect.delegate) {}
+
+SubSectIterHelper::SubSectIterHelper(const SubSectIterator &iter)
+ : subSect(iter.subSect), wrap(*iter.wrap) {}
+
+void SubSectIterHelper::deserializeFromTupleId(OpBuilder &b, Location l,
+ Value tupleId) {
+ assert(!subSect.randomAccessible());
+ wrap.deserialize(subSect.loadItVals(b, l, tupleId));
+}
+
+void SubSectIterHelper::locate(OpBuilder &b, Location l, Value crd) {
+ Value absCrd = ADDI(crd, subSect.getAbsOff());
+ wrap.locate(b, l, absCrd);
+}
+
+Value SubSectIterHelper::genNotEnd(OpBuilder &b, Location l) {
+ assert(!wrap.randomAccessible());
+ auto r = genWhenInBound(
+ b, l, wrap, C_FALSE,
+ [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
+ Value crd = SUBI(wrapCrd, subSect.getAbsOff());
+ // crd < size
+ return {CMPI(ult, crd, subSect.subSectSz)};
+ });
+ assert(r.size() == 1);
+ return r.front();
+}
+
+Value SubSectIterHelper::deref(OpBuilder &b, Location l) {
+ Value wrapCrd = wrap.deref(b, l);
+ Value crd = subSect.toSubSectCrd(b, l, wrapCrd);
+ return crd;
+}
+
+ValueRange SubSectIterHelper::forward(OpBuilder &b, Location l) {
+ return wrap.forward(b, l);
+}
+
+ValueRange NonEmptySubSectIterator::genSubSectTraverseTillRoot(
+ OpBuilder &b, Location l, ValueRange reduc, TraverseBuilder builder) const {
+ // Set up the helper to help traverse a sparse subsection.
+ SubSectIterHelper helper(*this);
+ if (!randomAccessible()) {
+ // The subsection tree have been expanded till the level and cached,
+ // traverse all the leaves and expanded to the next level.
+ SmallVector<Value> iterArgs;
+ iterArgs.push_back(C_IDX(0));
+ iterArgs.append(reduc.begin(), reduc.end());
+ auto forEachLeaf = b.create<scf::ForOp>(
+ l, /*lb=*/C_IDX(0), /*ub=*/tupleCnt, /*step=*/C_IDX(1), iterArgs,
+ [&helper, &builder](OpBuilder &b, Location l, Value tupleId,
+ ValueRange iterArgs) {
+ // Deserialize the iterator at the cached position (tupleId).
+ helper.deserializeFromTupleId(b, l, tupleId);
+
+ Value cnt = iterArgs.front();
+ // Record the number of leaf nodes included in the subsection.
+ // The number indicates the starting tupleId for the next level that
+ // is corresponding to the current node.
+ helper.subSect.storeNxLvlStart(b, l, tupleId, cnt);
+
+ SmallVector<Value> whileArgs(helper.wrap.getItVals());
+ whileArgs.append(iterArgs.begin(), iterArgs.end());
+
+ auto whileOp = b.create<scf::WhileOp>(
+ l, ValueRange(whileArgs).getTypes(), whileArgs,
+ /*beforeBuilder=*/
+ [&helper](OpBuilder &b, Location l, ValueRange ivs) {
+ helper.wrap.linkNewScope(ivs);
+ b.create<scf::ConditionOp>(l, helper.genNotEnd(b, l), ivs);
+ },
+ /*afterBuilder=*/
+ [&helper, &builder](OpBuilder &b, Location l, ValueRange ivs) {
+ ValueRange remIter = helper.wrap.linkNewScope(ivs);
+ Value cnt = remIter.front();
+ ValueRange userIter = remIter.drop_front();
+ scf::ValueVector userNx = builder(b, l, &helper.wrap, userIter);
+
+ SmallVector<Value> nxIter = helper.forward(b, l);
+ nxIter.push_back(ADDI(cnt, C_IDX(1)));
+ nxIter.append(userNx.begin(), userNx.end());
+ YIELD(nxIter);
+ });
+ ValueRange res = helper.wrap.linkNewScope(whileOp.getResults());
+ YIELD(res);
+ });
+ return forEachLeaf.getResults().drop_front();
+ }
+
+ assert(randomAccessible());
+ // Helper lambda that traverse the current dense subsection range.
+ auto visitDenseSubSect = [&, this](OpBuilder &b, Location l,
+ const SparseIterator *parent,
+ ValueRange reduc) {
+ assert(!parent || parent->lvl + 1 == lvl);
+ delegate->genInit(b, l, parent);
+ auto forOp = b.create<scf::ForOp>(
+ l, /*lb=*/C_IDX(0), /*ub=*/subSectSz, /*step=*/C_IDX(1), reduc,
+ [&](OpBuilder &b, Location l, Value crd, ValueRange iterArgs) {
+ helper.locate(b, l, crd);
+ scf::ValueVector nx = builder(b, l, &helper.wrap, iterArgs);
+ YIELD(nx);
+ });
+ return forOp.getResults();
+ };
+
+ if (isSubSectRoot()) {
+ return visitDenseSubSect(b, l, parent, reduc);
+ }
+ // Else, this is not the root, recurse until root.
+ auto *p = llvm::cast<NonEmptySubSectIterator>(parent);
+ assert(p->lvl + 1 == lvl);
+ return p->genSubSectTraverseTillRoot(b, l, reduc, visitDenseSubSect);
+}
+
void NonEmptySubSectIterator::genInit(OpBuilder &b, Location l,
const SparseIterator *) {
- auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
- if (p) {
- llvm_unreachable("Not implemented");
- } else {
- wrap->genInit(b, l, parent);
- Value c0 = C_IDX(0);
+ Value c0 = C_IDX(0);
+ if (!isSubSectRoot()) {
+ assert(parent->lvl + 1 == lvl);
+ // We can not call wrap->genInit() here to initialize the wrapped iterator,
+ // because the parent of the curent iterator is still unresolved.
if (randomAccessible()) {
seek({/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE});
return;
}
- // Handle sparse subsection iterator.
- tupleCnt = C_IDX(1);
- SmallVector<Value> elseRet{c0, c0, /*notEnd=*/C_FALSE};
- ValueRange meta = genWhenInBound(
- b, l, *wrap, elseRet, [this](OpBuilder &b, Location l, Value crd) {
- Value offset = offsetFromMinCrd(b, l, crd, subSectSz);
- YIELD((ValueRange{crd, offset, C_TRUE}));
+
+ auto *p = cast<NonEmptySubSectIterator>(parent);
+
+ SmallVector<Value, 3> reduc = {
+ C_IDX(-1), // minCrd (max signless integer)
+ c0, // tupleId
+ };
+
+ ValueRange result = p->genSubSectTraverseTillRoot(
+ b, l, reduc,
+ [this](OpBuilder &b, Location l, const SparseIterator *parent,
+ ValueRange reduc) -> scf::ValueVector {
+ assert(parent->lvl + 1 == lvl && reduc.size() == 2);
+ Value minCrd = reduc.front();
+ Value tupleId = reduc.back();
+
+ // Initialize the subsection range.
+ SubSectIterHelper helper(*this);
+ helper.wrap.genInit(b, l, parent);
+
+ // Update minCrd.
+ minCrd = genWhenInBound(b, l, helper.wrap, minCrd,
+ [minCrd](OpBuilder &b, Location l,
+ Value crd) -> scf::ValueVector {
+ Value min = MINUI(crd, minCrd);
+ return {min};
+ })
+ .front();
+
+ // Cache the sparse range.
+ storeItVals(b, l, tupleId, helper.wrap.serialize());
+ tupleId = ADDI(tupleId, C_IDX(1));
+ return {minCrd, tupleId};
});
+ assert(result.size() == 2);
+ tupleCnt = result.back();
+
+ Value minCrd = result.front();
+ Value absOff = offsetFromMinCrd(b, l, minCrd, subSectSz);
+ Value notEnd = CMPI(ne, minCrd, C_IDX(-1));
+ seek({minCrd, absOff, notEnd});
+ return;
+ }
+
+ // This is the root level of the subsection, which means that it is resolved
+ // to one node.
+ assert(isSubSectRoot());
- seek(meta);
- SmallVector<Value> itVals = wrap->serialize();
- storeItVals(b, l, c0, itVals);
+ delegate->genInit(b, l, parent);
+ if (randomAccessible()) {
+ seek({/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE});
+ return;
}
+
+ // Only have one root node.
+ tupleCnt = C_IDX(1);
+ // Cache the sparse range.
+ storeItVals(b, l, c0, delegate->serialize());
+ SmallVector<Value> elseRet{c0, c0, /*notEnd=*/C_FALSE};
+ auto meta = genWhenInBound(
+ b, l, *delegate, elseRet,
+ [this](OpBuilder &b, Location l, Value crd) -> scf::ValueVector {
+ Value offset = offsetFromMinCrd(b, l, crd, subSectSz);
+ return {crd, offset, C_TRUE};
+ });
+
+ seek(meta);
}
ValueRange NonEmptySubSectIterator::forward(OpBuilder &b, Location l) {
@@ -844,37 +1125,39 @@ ValueRange NonEmptySubSectIterator::forward(OpBuilder &b, Location l) {
// offset = minCrd - size + 1;
// }
b.setInsertionPointToStart(&ifOp.getElseRegion().front());
- ValueRange loopArgs{upperBound(b, l), // nextMinCrd
- C_FALSE}; // isNotEnd
+ ValueRange loopArgs{C_IDX(-1), // nextMinCrd
+ C_FALSE}; // isNotEnd
auto loopNest = scf::buildLoopNest(
b, l, c0, tupleCnt, c1, loopArgs,
[this](OpBuilder &b, Location l, ValueRange ivs,
ValueRange iterArgs) -> scf::ValueVector {
Value tupleId = ivs.front();
- SmallVector<Value> itVals = loadItVals(b, l, tupleId);
- wrap->deserialize(itVals);
+ SubSectIterHelper helper(*this);
+ helper.deserializeFromTupleId(b, l, tupleId);
+
return genWhenInBound(
- b, l, *wrap, /*elseRet=*/iterArgs,
- [this, iterArgs, tupleId](OpBuilder &b, Location l, Value crd) {
+ b, l, *delegate, /*elseRet=*/iterArgs,
+ [this, iterArgs, tupleId](OpBuilder &b, Location l,
+ Value crd) -> scf::ValueVector {
// if coord == minCrd
// wrap->forward();
Value isMin = CMPI(eq, crd, getMinCrd());
- wrap->forwardIf(b, l, isMin);
+ delegate->forwardIf(b, l, isMin);
// Update the forwarded iterator values if needed.
auto ifIsMin = b.create<scf::IfOp>(l, isMin, false);
b.setInsertionPointToStart(&ifIsMin.getThenRegion().front());
- storeItVals(b, l, tupleId, wrap->serialize());
+ storeItVals(b, l, tupleId, delegate->serialize());
b.setInsertionPointAfter(ifIsMin);
// if (!wrap.end())
// yield(min(nxMinCrd, *wrap), true)
Value nxMin = iterArgs[0];
- ValueRange ret = genWhenInBound(
- b, l, *wrap, /*elseRet=*/iterArgs,
- [nxMin](OpBuilder &b, Location l, Value crd) {
- Value nx = SELECT(CMPI(ult, crd, nxMin), crd, nxMin);
- YIELD((ValueRange{nx, C_TRUE}));
- });
- YIELD(ret);
+ return genWhenInBound(b, l, *delegate, /*elseRet=*/iterArgs,
+ [nxMin](OpBuilder &b, Location l,
+ Value crd) -> scf::ValueVector {
+ Value nx = b.create<arith::MinUIOp>(
+ l, crd, nxMin);
+ return {nx, C_TRUE};
+ });
});
});
@@ -893,7 +1176,7 @@ ValueRange NonEmptySubSectIterator::forward(OpBuilder &b, Location l) {
// We should at least forward the offset by one.
Value minAbsOff = ADDI(getAbsOff(), c1);
- nxAbsOff = SELECT(CMPI(ugt, minAbsOff, nxAbsOff), minAbsOff, nxAbsOff);
+ nxAbsOff = b.create<arith::MaxUIOp>(l, minAbsOff, nxAbsOff);
assert(stride == 1 && "Not yet implemented");
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
index 6f6d28e24c2750..9d5904cf456828 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
@@ -81,9 +81,9 @@ class SparseIterator {
MutableArrayRef<Value> itVals)
: kind(kind), tid(tid), lvl(lvl), crd(nullptr), itVals(itVals){};
- SparseIterator(IterKind kind, const SparseIterator *wrap)
- : kind(kind), tid(wrap->tid), lvl(wrap->lvl), crd(nullptr),
- itVals(wrap->itVals){};
+ SparseIterator(IterKind kind, const SparseIterator &wrap)
+ : kind(kind), tid(wrap.tid), lvl(wrap.lvl), crd(nullptr),
+ itVals(wrap.itVals){};
public:
virtual ~SparseIterator() = default;
@@ -93,8 +93,7 @@ class SparseIterator {
ValueRange getItVals() const { return itVals; };
void seek(ValueRange vals) {
assert(vals.size() == itVals.size());
- for (unsigned i = 0, e = vals.size(); i < e; i++)
- itVals[i] = vals[i];
+ std::copy(vals.begin(), vals.end(), itVals.begin());
// Now that the iterator is re-positioned, the coordinate becomes invalid.
crd = nullptr;
}
@@ -132,11 +131,13 @@ class SparseIterator {
//
// Get the current position and the optional *position high* (for non-unique
- // iterators), the value should be able to uniquely identify the sparse range
- // for the next level. See SparseTensorLevel::peekRangeAt();
+ // iterators), the value is essentially the number of sparse coordinate that
+ // the iterator is current visiting. It should be able to uniquely identify
+ // the sparse range for the next level. See SparseTensorLevel::peekRangeAt();
//
- // Not every type of iterator supports the operations, e.g., non-empty
- // subsection iterator does not.
+ // Not every type of iterator supports the operation, e.g., non-empty
+ // subsection iterator does not because it represent a range of coordinates
+ // instead of just one.
virtual std::pair<Value, Value> getCurPosition() const {
llvm_unreachable("unsupported");
};
@@ -148,7 +149,7 @@ class SparseIterator {
virtual std::pair<Value, Value> genForCond(OpBuilder &b, Location l) {
assert(randomAccessible());
// Random-access iterator is traversed by coordinate, i.e., [curCrd, UB).
- return {deref(b, l), upperBound(b, l)};
+ return {getCrd(), upperBound(b, l)};
}
virtual Value genNotEnd(OpBuilder &b, Location l) = 0;
@@ -196,6 +197,7 @@ class SparseIterator {
protected:
void updateCrd(Value crd) { this->crd = crd; }
+ void relinkItVals(MutableArrayRef<Value> itVals) { this->itVals = itVals; }
public:
const IterKind kind; // For LLVM-style RTTI.
@@ -205,7 +207,7 @@ class SparseIterator {
Value crd; // The sparse coordinate used to coiterate;
// A range of value that together defines the current state of the
- // iterator.
+ // iterator. Only loop variants should be included.
//
// For trivial iterators, it is the position; for dedup iterators, it consists
// of the positon and the segment high, for non-empty subsection iterator, it
>From 96df4eac604d095c91ede297c2b7d9c2b240a7a2 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 16 Jan 2024 18:22:09 +0000
Subject: [PATCH 07/11] pass all integration tests.
---
.../Transforms/Utils/SparseTensorLevel.cpp | 95 ++++++++++++++-----
.../Transforms/Utils/SparseTensorLevel.h | 5 +-
2 files changed, 75 insertions(+), 25 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index c7bc365b89c32d..dac9e4e012b4e6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -394,6 +394,10 @@ class DedupIterator : public SparseIterator {
const SparseTensorLevel &stl;
};
+//
+// A filter iterator wrapped from another iterator. The filter iterator update
+// the wrapped iterator *in-place*.
+//
class FilterIterator : public SparseIterator {
// Coorindate translation between crd loaded from the wrap iterator and the
// filter iterator.
@@ -411,6 +415,8 @@ class FilterIterator : public SparseIterator {
Value genShouldFilter(OpBuilder &b, Location l);
public:
+ // TODO: avoid unnessary check when offset == 0 and/or when stride == 1 and/or
+ // when crd always < size.
FilterIterator(std::unique_ptr<SparseIterator> &&wrap, Value offset,
Value stride, Value size)
: SparseIterator(IterKind::kFilter, *wrap), offset(offset),
@@ -548,9 +554,10 @@ class NonEmptySubSectIterator : public SparseIterator {
return !parent || !llvm::isa<NonEmptySubSectIterator>(parent);
}
- ValueRange genSubSectTraverseTillRoot(OpBuilder &b, Location l,
- ValueRange reduc,
- TraverseBuilder builder) const;
+ // Generate code that inflate the current subsection tree till the current
+ // level such that every leaf node is visited.
+ ValueRange inflateSubSectTree(OpBuilder &b, Location l, ValueRange reduc,
+ TraverseBuilder builder) const;
bool randomAccessible() const override {
return delegate->randomAccessible();
@@ -861,24 +868,35 @@ ValueRange FilterIterator::forward(OpBuilder &b, Location l) {
assert(!randomAccessible());
// Generates
//
- // wrap ++;
- // while !it.end() && !legit(*it)
+ // bool isFirst = true;
+ // while !it.end() && (!legit(*it) || isFirst)
// wrap ++;
- wrap->forward(b, l);
+ // isFirst = false;
+ //
+ // We do not hoist the first `wrap++` outside the loop but use a `isFirst`
+ // flag here because `wrap++` might have a complex implementation (e.g., to
+ // forward a subsection).
+ Value isFirst = constantI1(b, l, true);
+
+ SmallVector<Value> whileArgs(getItVals().begin(), getItVals().end());
+ whileArgs.push_back(isFirst);
+
auto whileOp = b.create<scf::WhileOp>(
- l, getItVals().getTypes(), getItVals(),
+ l, ValueRange(whileArgs).getTypes(), whileArgs,
/*beforeBuilder=*/
[this](OpBuilder &b, Location l, ValueRange ivs) {
- linkNewScope(ivs);
+ ValueRange isFirst = linkNewScope(ivs);
+ assert(isFirst.size() == 1);
ValueRange cont =
genWhenInBound(b, l, *wrap, C_FALSE,
- [this](OpBuilder &b, Location l,
- Value wrapCrd) -> scf::ValueVector {
+ [this, isFirst](OpBuilder &b, Location l,
+ Value wrapCrd) -> scf::ValueVector {
// crd < size && !legit();
Value notLegit =
genCrdNotLegitPredicate(b, l, wrapCrd);
Value crd = fromWrapCrd(b, l, wrapCrd);
Value ret = ANDI(CMPI(ult, crd, size), notLegit);
+ ret = ORI(ret, isFirst.front());
return {ret};
});
b.create<scf::ConditionOp>(l, cont.front(), ivs);
@@ -887,7 +905,9 @@ ValueRange FilterIterator::forward(OpBuilder &b, Location l) {
[this](OpBuilder &b, Location l, ValueRange ivs) {
linkNewScope(ivs);
wrap->forward(b, l);
- YIELD(getItVals());
+ SmallVector<Value> yieldVals(getItVals().begin(), getItVals().end());
+ yieldVals.push_back(constantI1(b, l, false));
+ YIELD(yieldVals);
});
b.setInsertionPointAfter(whileOp);
@@ -935,7 +955,7 @@ ValueRange SubSectIterHelper::forward(OpBuilder &b, Location l) {
return wrap.forward(b, l);
}
-ValueRange NonEmptySubSectIterator::genSubSectTraverseTillRoot(
+ValueRange NonEmptySubSectIterator::inflateSubSectTree(
OpBuilder &b, Location l, ValueRange reduc, TraverseBuilder builder) const {
// Set up the helper to help traverse a sparse subsection.
SubSectIterHelper helper(*this);
@@ -1009,7 +1029,7 @@ ValueRange NonEmptySubSectIterator::genSubSectTraverseTillRoot(
// Else, this is not the root, recurse until root.
auto *p = llvm::cast<NonEmptySubSectIterator>(parent);
assert(p->lvl + 1 == lvl);
- return p->genSubSectTraverseTillRoot(b, l, reduc, visitDenseSubSect);
+ return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect);
}
void NonEmptySubSectIterator::genInit(OpBuilder &b, Location l,
@@ -1017,21 +1037,22 @@ void NonEmptySubSectIterator::genInit(OpBuilder &b, Location l,
Value c0 = C_IDX(0);
if (!isSubSectRoot()) {
assert(parent->lvl + 1 == lvl);
- // We can not call wrap->genInit() here to initialize the wrapped iterator,
- // because the parent of the curent iterator is still unresolved.
if (randomAccessible()) {
+ // We can not call wrap->genInit() here to initialize the wrapped
+ // iterator, because the parent of the curent iterator is still
+ // unresolved.
seek({/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE});
return;
}
auto *p = cast<NonEmptySubSectIterator>(parent);
-
SmallVector<Value, 3> reduc = {
C_IDX(-1), // minCrd (max signless integer)
c0, // tupleId
};
- ValueRange result = p->genSubSectTraverseTillRoot(
+ // Expand the subsection tree from the parent level to the current level.
+ ValueRange result = p->inflateSubSectTree(
b, l, reduc,
[this](OpBuilder &b, Location l, const SparseIterator *parent,
ValueRange reduc) -> scf::ValueVector {
@@ -1071,6 +1092,8 @@ void NonEmptySubSectIterator::genInit(OpBuilder &b, Location l,
// to one node.
assert(isSubSectRoot());
+ // Initialize the position, the position marks the *lower bound* of the
+ // subRange. The higher bound is determined by the size of the subsection.
delegate->genInit(b, l, parent);
if (randomAccessible()) {
seek({/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE});
@@ -1251,19 +1274,45 @@ sparse_tensor::makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit,
return std::make_unique<FilterIterator>(std::move(sit), offset, stride, size);
}
+template <typename IterType>
+static const SparseIterator *tryUnwrapFilter(const SparseIterator *it) {
+ auto *filter = llvm::dyn_cast_or_null<FilterIterator>(it);
+ if (filter && llvm::isa<IterType>(filter->wrap.get())) {
+ return filter->wrap.get();
+ }
+ return it;
+}
+template <typename IterType>
+static const IterType *unwrapFilter(const SparseIterator *it) {
+ auto *filter = llvm::dyn_cast_or_null<FilterIterator>(it);
+ if (filter) {
+ return llvm::cast<IterType>(filter->wrap.get());
+ }
+ return llvm::cast<IterType>(it);
+}
+
std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
OpBuilder &b, Location l, const SparseIterator *parent,
std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride) {
- return std::make_unique<NonEmptySubSectIterator>(
- b, l, parent, std::move(delegate), size, stride);
+
+ // Try unwrap the NonEmptySubSectIterator from a filter parent.
+ parent = tryUnwrapFilter<NonEmptySubSectIterator>(parent);
+ auto it = std::make_unique<NonEmptySubSectIterator>(
+ b, l, parent, std::move(delegate), size, 1);
+
+ if (stride != 1)
+ return std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
+ C_IDX(stride), /*size=*/C_IDX(-1));
+ return it;
}
std::unique_ptr<SparseIterator> sparse_tensor::makeTraverseSubSectIterator(
- const SparseIterator &subsectIter, const SparseIterator &parent,
+ const SparseIterator &subSectIter, const SparseIterator &parent,
std::unique_ptr<SparseIterator> &&wrap, Value size, unsigned stride) {
- return std::make_unique<SubSectIterator>(
- llvm::cast<NonEmptySubSectIterator>(subsectIter), parent, std::move(wrap),
- size, stride);
+ // This must be a subsection iterator or a filtered subsection iterator.
+ auto &subSect = *unwrapFilter<NonEmptySubSectIterator>(&subSectIter);
+ return std::make_unique<SubSectIterator>(subSect, parent, std::move(wrap),
+ size, stride);
}
#undef CMPI
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
index 9d5904cf456828..1233f0099aa546 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
@@ -114,11 +114,12 @@ class SparseIterator {
virtual Value upperBound(OpBuilder &b, Location l) const = 0;
// Serialize and deserialize the current status to/from a set of values. The
- // ValueRange should contain values that specifies the postion and loop bound.
+ // ValueRange should contain values that specifies the current postion and
+ // loop bound.
//
// Not every type of iterator supports the operations, e.g., non-empty
// subsection iterator does not because the the number of non-empty
- // subsections can not be determined in advance.
+ // subsections can not be determined easily.
//
// NOTE: All the values should have index type.
virtual SmallVector<Value> serialize() const {
>From 485382110f02a9a78e5d0b91c5977540cde606bf Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 16 Jan 2024 18:22:40 +0000
Subject: [PATCH 08/11] cleanup LoopEmitter
---
.../Transforms/SparseTensorRewriting.cpp | 2 +-
.../Transforms/Sparsification.cpp | 4 +-
.../Transforms/Utils/LoopEmitter.cpp | 1543 +----------------
.../Transforms/Utils/LoopEmitter.h | 326 +---
4 files changed, 43 insertions(+), 1832 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index d32b8520f38618..e47a4db6cffbc7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1126,7 +1126,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
}
Value vals = loopEmitter.getValBuffer()[0];
- Value pos = loopEmitter.getPosits()[0].back();
+ Value pos = loopEmitter.getValPosits(0);
// Loads the value from sparse tensor using position-index;
// loads the value from dense tensor using coords.
Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pos)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 5d890e8b035d0c..6e1670bcc7dc44 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -354,7 +354,7 @@ static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
const auto stt = getSparseTensorType(t->get());
if (stt.hasEncoding()) {
// For sparse tensors we only push the last-level's position onto `args`.
- const auto pos = env.emitter().getPosits()[tid].back();
+ const auto pos = env.emitter().getValPosits(tid);
assert(pos);
args.push_back(pos);
} else {
@@ -893,7 +893,7 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
if (isCompressedLT(lt) || isSingletonLT(lt) ||
isLooseCompressedLT(lt) || is2OutOf4LT(lt)) {
assert(lvl.has_value());
- const Value crd = env.emitter().getCoords()[tid][*lvl];
+ const Value crd = env.emitter().getCoord(tid, *lvl);
const Value lvar = env.getLoopVar(curr);
clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
crd, lvar);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 08a326eda36bfd..76f7adac88b9a7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -63,8 +63,6 @@ LLVM_ATTRIBUTE_UNUSED static void dumpIndexMemRef(OpBuilder &builder,
// specifies the range of the fragment, and pPtr specifies the index of the
// corresponding fragment in the child level (i.e., a pointer to the sliced
// position array).
-static constexpr unsigned kSliceIterWidth = 3;
-
static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
Level lvl) {
auto enc = getSparseTensorEncoding(tensor.getType());
@@ -77,217 +75,10 @@ static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
return createOrFoldSliceStrideOp(builder, loc, tensor, toDim(enc, lvl));
}
-/// Converts a coordinate relative to the slice to the coordinate relative
-/// to the underlying tensor.
-// FIXME: that description says "sliceCrd -> tensorCrd"; but the function
-// name suggests it should be "tensorCrd -> sliceCrd".
-static Value toSliceCrd(OpBuilder &builder, Location loc, Value crd,
- Value offset, Value stride, Value tensor, Level lvl) {
- // tensorCrd = sliceCrd * stride + offset
- return ADDI(MULI(crd, stride), offset);
-}
-
-/// Generates code to compute the *absolute* offset of the slice based on the
-/// provide minimum coordinates in the slice.
-/// E.g., when reducing d0 + d1 + d2, we need two slices to fully reduced the
-/// expression, i,e, s1 = slice(T, d0), s2 = slice(s1, d1). The *absolute*
-/// offset is the offset computed relative to the initial tensors T.
-///
-/// When isNonEmpty == true, the computed offset is meaningless and should not
-/// be used during runtime, the method generates code to return 0 currently in
-/// that case.
-///
-/// offset = isNonEmpty && minCrd >= size ? minCrd - size + 1 : 0;
-static Value offsetFromMinCoord(OpBuilder &builder, Location loc, Value minCrd,
- Value size, Value isNonEmpty) {
- Value geSize = CMPI(uge, minCrd, size);
- Value pred = ANDI(isNonEmpty, geSize);
- // Computes minCrd - size + 1
- Value mms = SUBI(ADDI(minCrd, C_IDX(1)), size);
- // This is the absolute offset related to the underly tensor.
- return SELECT(pred, mms, C_IDX(0));
-}
-
-/// Converts a coordinate relative to the underlying tensor to the coordinate
-/// relative to the slice, returns a extra reminder value
-// FIXME: that description says "tensorCrd -> sliceCrd"; but the function
-// name suggests it should be "sliceCrd -> tensorCrd".
-static std::pair<Value, Value> fromSliceCrd(OpBuilder &builder, Location loc,
- Value crd, Value offset,
- Value stride, Value tensor,
- Level lvl) {
- // sliceCrd = (tensorCrd - offset) / stride
- crd = SUBI(crd, offset);
- Value rem = REMUI(crd, stride);
- crd = DIVUI(crd, stride);
- return std::make_pair(crd, rem);
-}
-
-// Generates a bool value for while loop condition that tries to iterate over a
-// fully reduced level with affine index expression.
-static Value genSparseReducedAffineCond(OpBuilder &builder, Location loc,
- const SparseTensorLevel &level,
- Value crdHi, Value posit, Value posHi) {
- Value inBound = CMPI(ult, posit, posHi);
- auto ifOp =
- builder.create<scf::IfOp>(loc, builder.getI1Type(), inBound, true);
- // if (inbound)
- // yield coord < crdHi
- builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- Value crd = level.peekCrdAt(builder, loc, posit);
- YIELD(CMPI(ult, crd, crdHi));
- // else
- // yield false
- builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- YIELD(constantI1(builder, loc, false));
-
- builder.setInsertionPointAfter(ifOp);
- return ifOp.getResult(0);
-}
-
-// Helper functions that load/store into the position buffer for slice-driven
-// loops.
-// The sliced pointer buffer is organized as:
-// [[pLo0, pLo1, pLo2, ...],
-// [pHi0, pHi1, pHi2, ...],
-// [pNx0, pNx1, pNx2, ...]]
-static Value allocSlicePosBuf(OpBuilder &builder, Location loc,
- Value tupleCnt) {
- Value bufSz = MULI(tupleCnt, C_IDX(kSliceIterWidth));
- // Additional two metadata {memSize, idx} at head.
- return genAlloca(builder, loc, bufSz, builder.getIndexType());
-}
-
-// Gets and sets position values for slice-driven loops.
-enum class SlicePosKind { kLo, kHi, kNext };
-static Value getSlicePosIdx(OpBuilder &builder, Location loc, Value posBuf,
- Value tupleIdx, SlicePosKind posKind) {
- Value dim = builder.create<memref::DimOp>(loc, posBuf, C_IDX(0));
- Value tupleCnt = DIVUI(dim, C_IDX(kSliceIterWidth));
- switch (posKind) {
- case SlicePosKind::kLo:
- return tupleIdx;
- case SlicePosKind::kHi:
- return ADDI(tupleIdx, tupleCnt);
- case SlicePosKind::kNext:
- return ADDI(tupleIdx, MULI(tupleCnt, C_IDX(2)));
- }
- llvm_unreachable("unexpected kind");
-}
-static Value loadSlicePos(OpBuilder &builder, Location loc, Value sPosBuf,
- Value tupleIdx, SlicePosKind posKind) {
- return genIndexLoad(builder, loc, sPosBuf,
- getSlicePosIdx(builder, loc, sPosBuf, tupleIdx, posKind));
-}
-static void updateSlicePos(OpBuilder &builder, Location loc, Value sPosBuf,
- Value pos, Value tupleIdx, SlicePosKind posKind) {
- builder.create<memref::StoreOp>(
- loc, pos, sPosBuf,
- getSlicePosIdx(builder, loc, sPosBuf, tupleIdx, posKind));
-}
-
-std::pair<Value, Value>
-LoopEmitter::genSliceLegitPredicate(OpBuilder &builder, Location loc, Value crd,
- TensorId tid, Level lvl) {
- assert(isSparseSlices[tid]);
- Value slice = tensors[tid];
- Value offset = sliceOffsets[tid][lvl];
- Value stride = sliceStrides[tid][lvl];
- auto enc = getSparseTensorEncoding(slice.getType());
-
- const auto [newCrd, crdRem] =
- fromSliceCrd(builder, loc, crd, offset, stride, slice, lvl);
-
- SmallVector<Value, 3> conds; // at most 3 conditions
-
- // First, coord >= offset (skip the check if offset is known to be 0).
- if (auto staticOffset = enc.getStaticLvlSliceOffset(lvl);
- !(staticOffset.has_value() && *staticOffset == 0)) {
- auto geOffset = CMPI(uge, crd, offset);
- conds.push_back(geOffset);
- }
-
- // Second, coord_in_slice < length
- auto ltLength = CMPI(ult, newCrd, lvls[tid][lvl]->size());
- conds.push_back(ltLength);
-
- // Third, rem == 0 (skip the check if stride is known to be 1).
- if (auto staticStride = enc.getStaticLvlSliceStride(lvl);
- !(staticStride.has_value() && *staticStride == 1)) {
- auto fitStride = CMPI(eq, crdRem, C_IDX(0));
- conds.push_back(fitStride);
- }
-
- // Must meet all condition to be a valid coordinate in slice.
- auto pred = conds.front();
- for (auto cond : ValueRange(conds).drop_front())
- pred = ANDI(pred, cond);
-
- return {newCrd, pred};
-}
-
//===----------------------------------------------------------------------===//
// Sparse tensor loop emitter class implementations
//===----------------------------------------------------------------------===//
-Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, TensorId tid,
- Level lvl, Value crd) {
- Value pos = lvl == 0 ? C_IDX(0) : posits[tid][lvl - 1];
- Value mul = MULI(highs[tid][lvl], pos);
- if (isSparseSlices[tid])
- crd = toSliceCrd(builder, loc, crd, sliceOffsets[tid][lvl],
- sliceStrides[tid][lvl], tensors[tid], lvl);
- Value add = ADDI(mul, crd);
- return add;
-}
-
-Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc,
- TensorId tid, Level lvl, Value pLo,
- Value pHi) {
- SparseTensorLevel &stl = *lvls[tid][lvl];
- const Value sameCrd = stl.peekCrdAt(builder, loc, pLo);
- auto whileOp = builder.create<scf::WhileOp>(
- loc, builder.getIndexType(), pLo,
- /*beforeBuilder=*/
- [pHi, &stl, sameCrd](OpBuilder &builder, Location loc, ValueRange ivs) {
- const auto pos = ivs[0];
- Value inBound = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ult, pos, pHi);
- auto ifInBound =
- builder.create<scf::IfOp>(loc, builder.getI1Type(), inBound, true);
- {
- OpBuilder::InsertionGuard guard(builder);
- // Load the next coordinates only when inbound (to avoid OOB
- // accesses).
- builder.setInsertionPointToStart(ifInBound.thenBlock());
- Value crd = stl.peekCrdAt(builder, loc, pos);
- Value isSameCrd = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, crd, sameCrd);
- YIELD(isSameCrd);
- // Else, the position is out of bound, yield false to terminate the
- // loop.
- builder.setInsertionPointToStart(ifInBound.elseBlock());
- YIELD(constantI1(builder, loc, false));
- }
- builder.create<scf::ConditionOp>(loc, ifInBound.getResults()[0], ivs);
- },
- /*afterBuilder=*/
- [](OpBuilder &builder, Location loc, ValueRange ivs) {
- // pos ++
- Value nextPos = ADDI(ivs[0], C_IDX(1));
- YIELD(nextPos);
- });
- // Return the segment high.
- return whileOp.getResult(0);
-}
-
-Value LoopEmitter::genSparseCrd(OpBuilder &builder, Location loc, TensorId tid,
- Level lvl) {
- const Value pos = posits[tid][lvl];
- const Value crd = lvls[tid][lvl]->peekCrdAt(builder, loc, pos);
- return crd;
-}
-
LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
bool isSparseOut, unsigned numLoops,
DependentLvlGetter dimGetter) {
@@ -308,17 +99,9 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
// tensors array (len == numManifestTensor).
this->tensors.assign(ts.begin(), ts.end());
// Arrays with len == numTensor.
- this->lvlTypes.assign(numTensors, std::vector<LevelType>());
- this->highs.assign(numTensors, std::vector<Value>());
- this->segHi.assign(numTensors, std::vector<Value>());
- this->posits.assign(numTensors, std::vector<Value>());
- this->coords.assign(numTensors, std::vector<Value>());
this->valBuffer.assign(numTensors, nullptr);
this->lvls.resize(numTensors);
this->iters.resize(numTensors);
- this->isSparseSlices.assign(numTensors, false);
- this->sliceOffsets.assign(numTensors, std::vector<Value>());
- this->sliceStrides.assign(numTensors, std::vector<Value>());
// These zeros will be overwritten below, but we need to initialize
// them to something since we'll need random-access assignment.
@@ -328,13 +111,8 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
// Index-reduction related fields.
this->dependentLvlMap.assign(
numTensors, std::vector<std::vector<std::pair<TensorLevel, unsigned>>>());
- this->slicePosBuffer.assign(numTensors, std::vector<std::vector<Value>>());
- this->sliceTupleNxStartIdx.assign(numTensors, std::vector<Value>());
- this->sliceTupleFwdCnt.assign(numTensors, std::vector<Value>());
- this->trivialSlice.assign(numTensors, std::vector<bool>());
this->sliceMeta.assign(
numTensors, std::vector<std::vector<std::pair<Value, unsigned>>>());
- this->sliceStack.assign(numTensors, std::vector<SliceInfo>());
this->levelReducedDep.assign(numTensors, std::vector<unsigned>());
// Initialize nested types of `TensorId`-indexed fields.
@@ -345,7 +123,6 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
// to the total number of loops (each level can potentially be mapped to
// one of the loop being generated).
lvlRank = numLoops;
- lvlTypes[tid].assign(lvlRank, LevelType::Dense);
} else {
const Value t = tensors[tid];
// a scalar or 0-dimension tensors
@@ -355,40 +132,17 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
auto rtp = getRankedTensorType(t);
const SparseTensorType stt(rtp);
lvlRank = stt.getLvlRank();
-
- if (stt.hasEncoding()) {
- const auto enc = stt.getEncoding();
- isSparseSlices[tid] = enc.isSlice();
- for (auto lvlTp : enc.getLvlTypes())
- lvlTypes[tid].push_back(lvlTp);
- } else {
- lvlTypes[tid].assign(lvlRank, LevelType::Dense);
- }
}
- // Initialize using empty value.
- highs[tid].assign(lvlRank, Value());
- segHi[tid].assign(lvlRank, Value());
- posits[tid].assign(lvlRank, Value());
- coords[tid].assign(lvlRank, Value());
lvls[tid].resize(lvlRank);
iters[tid].resize(lvlRank);
-
- sliceOffsets[tid].assign(lvlRank, Value());
- sliceStrides[tid].assign(lvlRank, Value());
+ loopHighs.assign(numLoops, nullptr);
// Slice-driven loops related initialization.
levelReducedDep[tid].assign(lvlRank, 0);
dependentLvlMap[tid].assign(
lvlRank, std::vector<std::pair<TensorLevel, unsigned>>());
- slicePosBuffer[tid].assign(lvlRank, std::vector<Value>());
- sliceTupleNxStartIdx[tid].assign(lvlRank, Value());
- sliceTupleFwdCnt[tid].assign(lvlRank, Value());
- trivialSlice[tid].assign(lvlRank, false);
sliceMeta[tid].assign(lvlRank, std::vector<std::pair<Value, unsigned>>());
- sliceStack[tid].emplace_back(/*minCrd=*/Value(),
- /*offset=*/Value(), /*isNonEmpty*/ Value(),
- /*posTupleNum=*/Value(), std::nullopt, 0);
if (dimGetter && !isSynTensor(tid)) {
for (Level l = 0; l < lvlRank; l++) {
std::vector<std::pair<LoopId, unsigned>> deps = dimGetter(tid, l);
@@ -401,8 +155,6 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
if (depends == 0)
continue;
sliceMeta[tid][l].reserve(depends);
- // We need `depends - 1` slices to fully reduce the affine expression.
- slicePosBuffer[tid][l].reserve(depends - 1);
}
}
}
@@ -412,14 +164,12 @@ std::unique_ptr<SparseIterator>
LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
Level l) {
auto it = makeSimpleIterator(*lvls[t][l]);
- if (isSparseSlices[t]) {
+ auto stt = getSparseTensorType(tensors[t]);
+ if (stt.hasEncoding() && stt.getEncoding().isSlice()) {
Value offset = genSliceOffset(builder, loc, tensors[t], l);
Value stride = genSliceStride(builder, loc, tensors[t], l);
auto slicedIt = makeSlicedLevelIterator(std::move(it), offset, stride,
lvls[t][l]->size());
- // TODO: remove below.
- sliceOffsets[t][l] = offset;
- sliceStrides[t][l] = stride;
return slicedIt;
}
return it;
@@ -431,8 +181,8 @@ void LoopEmitter::initializeLoopEmit(
// For every synthetic tensor, set the high bound by calling the callback.
if (synSetter) {
TensorId synId = getSynTensorId();
- for (unsigned i = 0, e = highs[synId].size(); i < e; i++) {
- Value sz = highs[synId][i] = synSetter(builder, loc, i);
+ for (unsigned i = 0, e = loopHighs.size(); i < e; i++) {
+ Value sz = loopHighs[i] = synSetter(builder, loc, i);
auto [stl, it] = makeSynLevelAndIterator(sz, synId, i);
lvls[synId][i] = std::move(stl);
iters[synId][i].emplace_back(std::move(it));
@@ -471,7 +221,6 @@ void LoopEmitter::initializeLoopEmit(
// Scan all levels of current tensor.
for (Level l = 0; l < lvlRank; l++) {
// Find upper bound in current dimension.
- highs[t][l] = lvlSzs[l];
lvls[t][l] = makeSparseTensorLevel(builder, loc, tensor, t, l);
if (!dependentLvlMap[t][l].empty())
continue;
@@ -513,9 +262,8 @@ void LoopEmitter::initializeLoopEmit(
// some loop preparation from tensor iteration, but will also (undesirably)
// hoist the code ouside if-conditions.
}
-
+ // TODO: avoid treating subsection iterator as a special case.
initSubSectIterator(builder, loc);
- initSliceDriven(builder, loc);
}
void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) {
@@ -562,13 +310,13 @@ void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) {
// Compute the subsection size.
Value size = c0;
for (auto [loop, stride] : remDepStack[t][lvl]) {
- Value loopHi = highs[getSynTensorId()][loop];
+ Value loopHi = loopHighs[loop];
size = ADDI(size, MULI(loopHi, C_IDX(stride)));
}
it = makeNonEmptySubSectIterator(builder, loc, parent, std::move(lvlIt),
size, curDep.second);
} else {
- Value size = highs[getSynTensorId()][loop];
+ Value size = loopHighs[loop];
const SparseIterator &subSectIter = *iters[t][lvl].back();
it = makeTraverseSubSectIterator(subSectIter, *parent, std::move(lvlIt),
size, curDep.second);
@@ -579,105 +327,6 @@ void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) {
}
}
-void LoopEmitter::initSliceDriven(OpBuilder &builder, Location loc) {
- Value c0 = C_IDX(0);
- for (TensorId t = 0, e = tensors.size(); t < e; t++) {
- auto rtp = dyn_cast<RankedTensorType>(tensors[t].getType());
- if (!rtp)
- continue;
-
- Level lvlRank = SparseTensorType(rtp).getLvlRank();
-
- // Compute the dependency reduction order.
- auto remDepStack = dependentLvlMap;
- std::vector<std::tuple<LoopId, TensorId, Level>> depRedOrder;
- for (Level lvl = 0; lvl < lvlRank; lvl++) {
- // Reverse queue into a stack.
- std::reverse(remDepStack[t][lvl].begin(), remDepStack[t][lvl].end());
- for (auto [loop, coeff] : dependentLvlMap[t][lvl])
- depRedOrder.emplace_back(std::make_tuple(loop, t, lvl));
- }
-
- if (depRedOrder.empty())
- continue;
- std::sort(depRedOrder.begin(), depRedOrder.end(),
- [](auto &l, auto &r) { return std::get<0>(l) < std::get<0>(r); });
-
- for (auto [loop, t, lvl] : depRedOrder) {
- std::pair<LoopId, unsigned> curDep = remDepStack[t][lvl].back();
- assert(curDep.first == loop);
- Value size = c0;
- for (auto [loop, stride] : remDepStack[t][lvl]) {
- // The synthetic tensor high defines the loop upper bound.
- Value loopHi = highs[getSynTensorId()][loop];
- size = ADDI(size, MULI(loopHi, C_IDX(stride)));
- }
- sliceMeta[t][lvl].emplace_back(size, curDep.second);
- remDepStack[t][lvl].pop_back();
-
- // Generate caches required to fast compute next-non-empty slices with
- // increasing offset for slice-base loop.
- // We do not need cache for dense levels.
- if (!remDepStack[t][lvl].empty() && !isDenseLT(lvls[t][lvl]->getLT())) {
- Value cnt = C_IDX(1);
- for (int preLvl = lvl - 1; preLvl >= 0; preLvl--) {
- if (remDepStack[t][preLvl].empty())
- break;
- assert(remDepStack[t][preLvl].size() == 1 && "Not implemented");
- auto [loop, stride] = remDepStack[t][preLvl].back();
- assert(stride == 1 && "Not yet implemented");
- // Accumlate the size required to cache the pLo for the slice.
- // E.g., if we want to cache the pIdx for slice<d0xd1xf64> on the
- // second level. We at most need a memref<d0xindex>.
- //
- // NOTE: this is apparently an over-approximation when the previous
- // level is compressed, and we can compute a precise memory size
- // inside the loops. But that would also requires us to allocate/free
- // memory in loops.
- cnt = MULI(highs[getSynTensorId()][loop], cnt);
- }
- slicePosBuffer[t][lvl].push_back(allocSlicePosBuf(builder, loc, cnt));
- } // else fully resolved.
- }
- }
-}
-
-void LoopEmitter::categorizeLoopCondition(
- ArrayRef<TensorLevel> tidLvls, SmallVectorImpl<TensorLvlCond> &dnConds,
- SmallVectorImpl<TensorLvlCond> &spConds) {
- // Finds out the tensor level that we should use to generate loops. Amongs all
- // the tensor levels, there is at most one sparse tensor level.
- for (auto [t, l] : unpackTensorLevelRange(tidLvls)) {
- assert(lvlTypes[t].size() > l); // Must be a valid tid, dim pair
- auto lvlType = lvlTypes[t][l];
- // Must be a recognizable LT.
- assert(isDenseLT(lvlType) || isCompressedLT(lvlType) ||
- isLooseCompressedLT(lvlType) || isSingletonLT(lvlType) ||
- is2OutOf4LT(lvlType));
-
- bool isSparse = !isDenseLT(lvlType);
- bool isSlice = isSparseSlices[t];
- bool isAffine = !dependentLvlMap[t][l].empty();
- bool isUnRedu = false;
- // TODO: Supports affine index expression on sparse tensor slices.
- assert(!isSlice || !isAffine);
-
- // Whether the affine index expression has been fully reduced or not.
- if (!dependentLvlMap[t][l].empty())
- isUnRedu = !depFullyReduced(t, l);
-
- auto &dstVec = isSparse ? spConds : dnConds;
- dstVec.emplace_back(
- makeTensorLevel(t, l),
- makeLoopCondKind(isSparse, isSlice, isAffine, isUnRedu));
- }
-
- std::stable_sort(spConds.begin(), spConds.end(), [](auto lhs, auto rhs) {
- // AffineUnRed > Affine > Slice > Trivial
- return static_cast<uint8_t>(lhs.second) > static_cast<uint8_t>(rhs.second);
- });
-}
-
void LoopEmitter::categorizeIterators(
ArrayRef<TensorLevel> tidLvls, SmallVectorImpl<SparseIterator *> &raIters,
SmallVectorImpl<SparseIterator *> &spIters) {
@@ -802,200 +451,9 @@ std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
iter.locate(builder, loc, iv);
}
- // if (isSparseSlices[tid] && isSparseCond) {
- // // For sparse level slices, we need to filter out invalid coordinates
- // that
- // // are not included in the slice.
- // SmallVector<Type> types;
- // for (Value red : reduc)
- // types.push_back(red.getType());
-
- // auto [trans, pred] = genSliceLegitPredicate(builder, loc, crd, tid, lvl);
- // bool hasReduc = !types.empty();
- // scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, pred,
- // /*else*/ hasReduc);
- // if (hasReduc) {
- // // scf.for (a) -> v
- // // %s = scf.if (a) -> v
- // // user-generated code.
- // // else
- // // yield a
- // // yield %s
- // YIELD(ifOp.getResults());
- // builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- // // On mismatch.
- // YIELD(reduc);
- // }
- // // Set the insertion point to matched branch.
- // builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- // crd = trans;
- // }
-
- coords[iter.tid][iter.lvl] = crd;
- posits[iter.tid][iter.lvl] = iter.getItVals().front();
return {loop, crd};
}
-Value LoopEmitter::genWhileLoopConditions(OpBuilder &builder, Location loc,
- ValueRange ivs, TensorLvlCond cond) {
- auto [tid, lvl] = unpackTensorLevel(cond.first);
-
- switch (cond.second) {
- case LoopCondKind::SparseCond: {
- assert(ivs.size() == 1);
- // We used the first level bound as the bound the collapsed set of levels.
- return CMPI(ult, ivs.back(), highs[tid][lvl]);
- }
- case LoopCondKind::SparseSliceCond: {
- assert(ivs.size() == 1);
- return CMPI(ult, ivs.back(), highs[tid][lvl]);
- }
- case LoopCondKind::SparseAffineCond: {
- assert(ivs.size() == 1);
-
- Value crdHi; // loop upper bound
- {
- OpBuilder::InsertionGuard guard(builder);
- Operation *loop = builder.getInsertionBlock()->getParentOp();
- // crdHi is a loop invariant, hosit the computation outside the loop.
- if (llvm::isa_and_nonnull<scf::WhileOp>(loop))
- builder.setInsertionPoint(loop);
- auto [remSz, stride] = sliceMeta[tid][lvl].back();
- assert(stride == 1 && "Not yet implemented");
- crdHi = ADDI(getMostRecentSliceOnLvl(tid, lvl).offset, remSz);
- }
- assert(crdHi);
- return genSparseReducedAffineCond(builder, loc, *lvls[tid][lvl], crdHi,
- ivs[0], highs[tid][lvl]);
- }
- case LoopCondKind::SparseAffineUnRedCond: {
- assert(ivs.size() == 3);
- return ivs.front(); // isNonEmpty
- }
- default:
- llvm_unreachable("Unhandled LoopCondKind");
- }
- llvm_unreachable("Unhandled LoopCondKind");
-}
-
-std::optional<Value> LoopEmitter::genWhileLoopBody(OpBuilder &builder,
- Location loc, ValueRange ivs,
- TensorLvlCond cond) {
- auto [tid, lvl] = unpackTensorLevel(cond.first);
-
- switch (cond.second) {
- case LoopCondKind::SparseCond: {
- // Updates position. For collapsed COO, the position is the same across
- // consecutive levels.
- posits[tid][lvl] = ivs.back();
-
- // Update coordinates.
- coords[tid][lvl] = genSparseCrd(builder, loc, tid, lvl);
- return std::nullopt;
- }
- case LoopCondKind::SparseSliceCond: {
- assert(ivs.size() == 1);
- posits[tid][lvl] = ivs.front();
- Value sCrd = genSparseCrd(builder, loc, tid, lvl);
- // Converts the coordinate loaded from the actual sparse tensor to the
- // coordinates in the sparse slice.
- auto [dCrd, pred] = genSliceLegitPredicate(builder, loc, sCrd, tid, lvl);
- coords[tid][lvl] = dCrd;
- return pred;
- }
- case LoopCondKind::SparseAffineCond: {
- assert(ivs.size() == 1);
- // Coord is the relative offset related to its parents.
- assert(sliceStack[tid].back().depth == 1 && "TODO: not yet implement");
- sliceTupleFwdCnt[tid][lvl] = SUBI(ivs[0], posits[tid][lvl]);
- // Update c = absOffset[lvl][depth] - absOffset[lvl][depth - 1]
- Value posit = ivs[0];
- // We need to substract the offset to get relative coordinates.
- // TODO: Maybe assert relC >=0 during runtime in debug build?
- Value absC = lvls[tid][lvl]->peekCrdAt(builder, loc, posit);
- auto relC = SUBI(absC, getFinalSliceOnLvl(tid, lvl).offset);
- posits[tid][lvl] = posit;
- coords[tid][lvl] = relC;
- return std::nullopt;
- }
- case LoopCondKind::SparseAffineUnRedCond: {
- unsigned depth = sliceStack[tid].back().depth;
- unsigned curStride = sliceMeta[tid][lvl][depth - 1].second;
- assert(ivs.size() == 3);
-
- // Updates the current slice info
- SliceInfo &sliceInfo = sliceStack[tid].back();
- sliceInfo.isNonEmpty = ivs[0];
- sliceInfo.minCrd = ivs[1];
- sliceInfo.offset = ivs[2];
-
- // Crd (the value we used to coiterate) is the relative offset related to
- // its parents, we can use the absolute offset here because when depth = 1,
- // absOffset[lvl][depth - 1] always equals zero.
- // TODO: Update crd =absOffset[lvl][depth] - absOffset[lvl][depth - 1]
- assert(depth == 1 && "TODO: not yet implement");
- Value crd = sliceInfo.offset;
-
- Value onStride = constantI1(builder, loc, true);
- if (curStride != 1) {
- Value strideVal = C_IDX(curStride);
- Value rem = REMUI(crd, strideVal);
- crd = DIVUI(crd, strideVal);
- onStride = CMPI(eq, rem, C_IDX(0));
- }
- coords[tid][lvl] = crd;
- // No extra check is needed before accessing the tensor level.
- return onStride;
- }
- default:
- llvm_unreachable("Unhandled LoopCondKind");
- }
- llvm_unreachable("Unhandled LoopCondKind");
-}
-
-ValueRange LoopEmitter::genCheckedValue(OpBuilder &builder, Location loc,
- Value pred, ValueRange curArgs,
- TensorLvlCond cond) {
- assert(isSparseCond(cond.second));
- auto [tid, lvl] = unpackTensorLevel(cond.first);
- if (isAffineIdxUnRedCond(cond.second)) {
- unsigned depth = sliceStack[tid].back().depth;
- unsigned curStride = sliceMeta[tid][lvl][depth - 1].second;
- if (curStride == 1)
- return curArgs;
- // Build
- // if (onStride) {
- // yield curSlice
- // } else {
- // yield nxSlice.
- //}
- assert(curArgs.size() == 3);
- auto ifOp = builder.create<scf::IfOp>(loc, curArgs.getTypes(), pred, true);
- {
- OpBuilder::InsertionGuard guard(builder);
- // If not all slices are legit, yield the updated value.
- builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-
- YIELD(curArgs);
- // If not all slices are legit, yield the updated value.
- builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- auto [nonEmpty, minCrd, offset] =
- genSliceNextInduction(builder, loc, tid, lvl);
- SmallVector<Value> nxSlice{nonEmpty, minCrd, offset};
- YIELD(nxSlice);
- }
- // If all slices are legit, start the user generated code.
- return ifOp.getResults();
- } else {
- // Currently only sparse slice condition need extra check.
- assert(isSliceCond(cond.second) && isSparseCond(cond.second));
- assert(curArgs.size() == 1);
- Value nextPos = ADDI(curArgs.front(), C_IDX(1));
- return SELECT(pred, curArgs.front(), nextPos)->getResults();
- }
- llvm_unreachable("unhandled case");
-}
-
std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
MutableArrayRef<Value> reduc, bool needsUniv) {
@@ -1011,38 +469,6 @@ std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
ivs.append(itVals.begin(), itVals.end());
}
- // for (auto [tl, cKind] : spConds) {
- // auto [tid, lvl] = unpackTensorLevel(tl);
- // const auto lvlTp = lvlTypes[tid][lvl];
- // // Dense level are handled by the shared univeral index.
- // assert(!isDenseCond(cKind));
- // // Must be a recognizable sparse level.
- // assert(isCompressedLT(lvlTp) || isLooseCompressedLT(lvlTp) ||
- // isSingletonLT(lvlTp));
- // (void)lvlTp;
- // unsigned prevSz = ivs.size();
- // if (isAffineIdxCond(cKind)) {
- // // TODO: Support view-based reshape on sparse levels with affine index
- // // expressions.
- // if (isAffineIdxUnRedCond(cKind)) {
- // SliceInfo &sliceInfo = sliceStack[tid].back();
- // // The order matters!
- // ivs.push_back(sliceInfo.isNonEmpty);
- // ivs.push_back(sliceInfo.minCrd);
- // ivs.push_back(sliceInfo.offset);
- // } else {
- // ivs.push_back(posits[tid][lvl]); // loop lower bound (pos low).
- // }
- // // We reduced one more dependency after entering the loop.
- // levelReducedDep[tid][lvl]++;
- // } else {
- // assert(dependentLvlMap[tid][lvl].empty());
- // const Value pos = posits[tid][lvl];
- // ivs.push_back(pos);
- // }
- // opSegSize.push_back(ivs.size() - prevSz);
- // }
-
// The position where user-supplied reduction variable starts.
ivs.append(reduc.begin(), reduc.end());
// Update universal index.
@@ -1062,11 +488,7 @@ std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
builder.setInsertionPointToStart(before);
ValueRange bArgs = before->getArguments();
Value whileCond = nullptr; // bool values for loop condition.
- // for (auto [c, segSz] : llvm::zip_equal(spConds, opSegSize)) {
- // Value cv = genWhileLoopConditions(builder, loc, bArgs.take_front(segSz),
- // c); bArgs = bArgs.drop_front(segSz); whileCond = !whileCond ? cv :
- // ANDI(whileCond, cv);
- // }
+
for (SparseIterator *it : spIters) {
auto [cond, remArgs] = it->genWhileCond(builder, loc, bArgs);
whileCond = !whileCond ? cond : ANDI(whileCond, cond);
@@ -1084,60 +506,13 @@ std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
// iterations, we maintains another array to hold the iteration arguments to
// yield if the checks fails.
SmallVector<Value> nextArgs(aArgs.begin(), aArgs.end());
- // A mutable alias for convenient slicing.
- MutableArrayRef<Value> nextArgsRef = nextArgs;
- // Value extraPred = nullptr;
- // for (auto [c, segSz] : llvm::zip_equal(spConds, opSegSize)) {
- // ValueRange condArgs = aArgs.take_front(segSz);
- // auto pred = genWhileLoopBody(builder, loc, condArgs, c);
- // assert(pred.has_value() == isCondWithExtraCheck(c.second));
- // if (pred.has_value()) {
- // // We need all extra checks to pass.
- // extraPred = extraPred == nullptr ? *pred : ANDI(*pred, extraPred);
- // ValueRange nxArgs = genCheckedValue(builder, loc, *pred, condArgs, c);
- // assert(nxArgs.size() == segSz);
- // // Update the value for cases when some check fails.
- // for (unsigned i = 0; i < segSz; i++) {
- // nextArgsRef[i] = nxArgs[i];
- // }
- // }
- // aArgs = aArgs.drop_front(segSz);
- // nextArgsRef = nextArgsRef.drop_front(segSz);
- // }
for (SparseIterator *it : spIters) {
aArgs = it->linkNewScope(aArgs);
- Value crd = it->deref(builder, loc);
- posits[it->tid][it->lvl] = it->getItVals().front();
- coords[it->tid][it->lvl] = crd;
+ // Dereference the iterator to cache the coordinate.
+ it->deref(builder, loc);
}
- // if (extraPred) {
- // auto ifOp = builder.create<scf::IfOp>(loc, types, extraPred, /*else*/
- // true);
- // // Marks this special IfOp so that Sparsification does not finalizing it.
- // ifOp->setAttr(getLoopEmitterLoopAttrName(),
- // StringAttr::get(builder.getContext(), "slice"));
- // // Links the SSA chain outside the if statement.
- // YIELD(ifOp->getResults());
-
- // // If not all slices are legit, yield the updated value.
- // builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- // YIELD(nextArgs);
-
- // // If all slices are legit, start the user generated code.
- // builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- // }
-
- // for (auto [tid, lvl] : unpackTensorLevelFromCondRange(spConds)) {
- // // Generates segment high for non-unique level.
- // if (!isUniqueLT(lvlTypes[tid][lvl])) {
- // segHi[tid][lvl] = genSegmentHigh(builder, loc, tid, lvl,
- // posits[tid][lvl],
- // highs[tid][lvl]);
- // }
- // }
-
// In-place update on reduction variable.
assert(aArgs.size() == reduc.size() + needsUniv ? 1 : 0);
for (unsigned i = 0, e = reduc.size(); i < e; i++)
@@ -1176,21 +551,10 @@ bool LoopEmitter::shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters) {
Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
MutableArrayRef<Value> reduc, bool tryParallel, bool needsUniv) {
-#ifndef NDEBUG
- // Sanity checks.
- assert(!tidLvls.empty());
- for (auto [t, l] : unpackTensorLevelRange(tidLvls)) {
- assert(!coords[t][l] || // We cannot re-enter the same level
- !dependentLvlMap[t][l].empty()); // unless it is a slice-driver loop
- }
-#endif
+
// TODO: support multiple return on parallel for?
tryParallel = tryParallel && reduc.size() <= 1;
- SmallVector<TensorLvlCond> spConds;
- SmallVector<TensorLvlCond> dnConds;
- categorizeLoopCondition(tidLvls, dnConds, spConds);
-
SmallVector<SparseIterator *> raIters;
SmallVector<SparseIterator *> spIters;
categorizeIterators(tidLvls, raIters, spIters);
@@ -1206,142 +570,39 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
// can be generated using a simple ForOp as well).
Operation *l = nullptr;
Value iv = nullptr;
- SmallVector<SliceLoopInfo> sliceDrivenInfo;
- SmallVector<TensorLevel> trivialLvls;
+ SmallVector<TensorLevel> tls;
// Generates loops differently depending on whether we need a slice-driven
// loop or a simple level traversal loop.
if (shouldIteratedByForLoop(spIters) && !needsUniv) {
assert(spIters.size() <= 1);
- TensorLvlCond tlCond = spConds.empty() ? dnConds.front() : spConds.front();
SparseIterator &it = spIters.empty() ? *raIters.front() : *spIters.front();
- // auto [tid, lvl] = unpackTensorLevel(tlCond.first);
- // Value lo = isSparseCond(loopCondKind)
- // ? posits[tid][lvl] // current offset
- // : loopSeqStack.back().first; // universal index
- // Value hi = highs[tid][lvl];
- // if (isDenseCond(loopCondKind) && isAffineIdxCond(loopCondKind)) {
- // bool unReduc = isAffineIdxUnRedCond(loopCondKind);
- // assert(unReduc == !depFullyReduced(tid, lvl));
- // unsigned depth = sliceStack[tid].back().depth;
- // assert(depth >= 1);
- // // The *next* slice size after reducing the current index variable.
- // auto [nxSz, nxStride] = sliceMeta[tid][lvl][depth];
- // // The *current* stride to reduce the current index variable.
- // // E.g., for 2 * i, stride = 2.
- // unsigned stride = sliceMeta[tid][lvl][depth - 1].second;
- // hi = nxSz;
- // if (unReduc) {
- // // Adjust for loop hi for dense slice-driven loop.
- // hi = SUBI(lvls[tid][lvl]->size(), hi);
- // hi = ADDI(hi, C_IDX(1));
- // hi = DIVUI(hi, C_IDX(stride));
- // } else {
- // // TODO: dialuted convolution.
- // assert(nxStride == 1 && "Not yet implemented.");
- // }
- // }
std::tie(l, iv) =
emitForLoopOverTensorAtLvl(builder, loc, it, reduc, tryParallel);
-
- // For loop condition must be a trivial condition (levels without affine
- // index expression).
- trivialLvls.push_back(tlCond.first);
+ tls.push_back(makeTensorLevel(it.tid, it.lvl));
} else {
- for (auto [tl, cKind] : spConds) {
- if (isAffineIdxCond(cKind)) {
- auto [tid, lvl] = unpackTensorLevel(tl);
- bool unReduc = isAffineIdxUnRedCond(cKind);
- assert(unReduc == !depFullyReduced(tid, lvl));
- sliceDrivenInfo.emplace_back(tid, lvl, /*fullyReduced=*/!unReduc);
- } else {
- trivialLvls.push_back(tl);
- }
+ for (auto *it : spIters) {
+ tls.push_back(makeTensorLevel(it->tid, it->lvl));
}
if (needsUniv)
for (auto *it : raIters)
- trivialLvls.push_back(makeTensorLevel(it->tid, it->lvl));
+ tls.push_back(makeTensorLevel(it->tid, it->lvl));
std::tie(l, iv) =
emitWhileLoopOverTensorsAtLvls(builder, loc, spIters, reduc, needsUniv);
}
// Enter dense tensor levels.
- enterTensorsAtDenseLvls(builder, loc, raIters, iv, sliceDrivenInfo);
- // NOTE: we can also prepare for next dim here in advance
+ for (SparseIterator *it : raIters)
+ it->locate(builder, loc, iv);
+ // NOTE: we can also prepare for next dim here in advance
// Pushes the loop into stack.
- loopStack.emplace_back(trivialLvls, sliceDrivenInfo, l,
- builder.getInsertionBlock(), iv, loopTag);
+ loopStack.emplace_back(tidLvls, l, builder.getInsertionBlock(), iv, loopTag);
return l;
}
-Operation *LoopEmitter::enterFilterLoopOverTensorAtLvl(
- OpBuilder &builder, Location loc, TensorId tid, Level lvl,
- AffineExpr affine, MutableArrayRef<Value> reduc) {
- assert(isValidLevel(tid, lvl));
- assert(!isa<AffineDimExpr>(affine) && !isDenseLT(lvlTypes[tid][lvl]));
- // We can not re-enter the same level.
- assert(!coords[tid][lvl]);
-
- // TODO: We should instead use a whileOp for filter loop to allow early
- // break when exceeding (for ordered levels).
- // TODO: There are many other potiential opportunities that we might apply in
- // the future. E.g., we could use binary search to locate positions.
- const Value step = C_IDX(1);
- const Value pLo = posits[tid][lvl];
- const Value pHi = highs[tid][lvl];
- scf::ForOp forOp = builder.create<scf::ForOp>(loc, pLo, pHi, step, reduc);
-
- // In-place update on the reduction variable vector.
- assert(forOp.getNumRegionIterArgs() == reduc.size());
- for (int i = 0, e = reduc.size(); i < e; i++)
- reduc[i] = forOp.getRegionIterArg(i);
-
- builder.setInsertionPointToStart(forOp.getBody());
- // The induction variable gives the position.
- const Value pos = forOp.getInductionVar();
- posits[tid][lvl] = pos;
- const Value crd = lvls[tid][lvl]->peekCrdAt(builder, loc, pos);
- coords[tid][lvl] = crd;
-
- // Generate an if-condition to filter out coordinates that are not
- // equal to the result of the affine expression.
- Value expected = genAffine(builder, loc, affine);
- auto pred = CMPI(eq, crd, expected);
- SmallVector<Type> types;
- for (Value red : reduc) {
- types.push_back(red.getType());
- }
-
- bool hasReduc = !types.empty();
- scf::IfOp ifOp =
- builder.create<scf::IfOp>(loc, types, pred, /*else*/ hasReduc);
- if (hasReduc) {
- // scf.for (a) -> v
- // %s = scf.if (a) -> v
- // user-generated code.
- // else
- // yield a
- // yield %s
- YIELD(ifOp.getResults());
- builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- // On mismatch.
- YIELD(reduc);
- }
- // Set the insert point to matched branch.
- builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-
- // NOTE: we can also prepare for next lvl here in advance
- // Push the loop into stack
- loopStack.emplace_back(ArrayRef<TensorLevel>(makeTensorLevel(tid, lvl)),
- ArrayRef<SliceLoopInfo>(), forOp,
- builder.getInsertionBlock(), coords[tid][lvl],
- nullptr);
- return forOp;
-}
-
void LoopEmitter::genDenseAffineAddress(OpBuilder &builder, Location loc,
TensorLevel tidLvl,
AffineExpr lvlExpr) {
@@ -1364,83 +625,15 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
hasParent ? nullptr : iters[tid][lvl - 1].back().get();
auto &it = getCurIterator(tid, lvl);
it.genInit(builder, loc, parent);
- if (it.randomAccessible()) {
- it.locate(builder, loc, C_IDX(0));
- }
-}
-void LoopEmitter::enterTensorsAtDenseLvls(
- OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> raIters,
- Value crd, SmallVectorImpl<SliceLoopInfo> &sliceInfo) {
- for (SparseIterator *it : raIters) {
- it->locate(builder, loc, crd);
- posits[it->tid][it->lvl] = it->getItVals().front();
- }
- // for (auto [dnTidLvl, denseLoopCond] : dnConds) {
- // auto [tid, lvl] = unpackTensorLevel(dnTidLvl);
- // assert(isDenseLT(lvlTypes[tid][lvl]));
-
- // if (isAffineIdxCond(denseLoopCond)) {
- // // Pushes sliced levels to build correct LoopInfo.
- // bool unReduc = isAffineIdxUnRedCond(denseLoopCond);
- // SliceInfo &info = sliceStack[tid].back();
- // // Pushes sliced dense loop info to tell LoopEmitter how to exit it.
- // sliceInfo.emplace_back(tid, lvl, /*fullyReduced=*/!unReduc);
- // // FIXME: The offset and position iterator need to be adjusted when the
- // // slice is strided.
- // if (unReduc) {
- // assert(*info.slicedOnLvl == lvl);
- // unsigned depth = sliceStack[tid].back().depth;
- // assert(depth >= 1);
- // unsigned stride = sliceMeta[tid][lvl][depth - 1].second;
- // // Update the slice information as we enter the new loop.
- // info.minCrd = info.offset = MULI(iv, C_IDX(stride));
- // info.isNonEmpty = constantI1(builder, loc, true);
- // } else {
- // posits[tid][lvl] =
- // genAddress(builder, loc, tid, lvl, ADDI(info.offset, iv));
- // Value fwdCnt = lvl == 0 || trivialSlice[tid][lvl]
- // ? C_IDX(0)
- // : sliceTupleFwdCnt[tid][lvl - 1];
- // Value sz = sliceMeta[tid][lvl].back().first;
- // Value mul = MULI(fwdCnt, sz);
- // sliceTupleFwdCnt[tid][lvl] = ADDI(mul, iv);
- // }
- // levelReducedDep[tid][lvl]++;
- // } else {
- // // Skips the synthetic tensor
- // if (isSynTensor(tid))
- // continue;
- // // A dense level with trivial index expression.
- // assert(dependentLvlMap[tid][lvl].empty());
- // auto enc = getSparseTensorEncoding(tensors[tid].getType());
- // if (enc && !isSparseOutput(tid)) {
- // bool validPos = lvl == 0 || posits[tid][lvl - 1];
- // if (!validPos) {
- // // We might not find the pos for the sparse output tensor as it is
- // // unconditionally required by the sparsification.
- // assert(isOutputTensor(tid));
- // continue;
- // }
- // posits[tid][lvl] = genAddress(builder, loc, tid, lvl, iv);
- // // NOTE: we can also prepare for next lvl here in advance
- // }
- // }
- // }
+ // Locates the randon accessible iterator to 0.
+ if (it.randomAccessible())
+ it.locate(builder, loc, C_IDX(0));
}
void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
MutableArrayRef<Value> reduc) {
const LoopInfo &loopInfo = loopStack.back();
- for (auto [tid, lvl, reduced] : loopInfo.sliceDrivenInfo) {
- if (!reduced) {
- SliceInfo &info = sliceStack[tid].back();
- assert(isDenseLT(lvlTypes[tid][lvl]));
- assert(*info.slicedOnLvl == lvl);
- (void)reduced;
- info.minCrd = info.offset = info.isNonEmpty = Value();
- }
- }
if (auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop)) {
if (!reduc.empty()) {
assert(reduc.size() == forOp.getNumResults());
@@ -1503,18 +696,6 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
for (unsigned i = 0, e = parOp.getResults().size(); i < e; i++)
reduc[i] = parOp.getResult(i);
}
-
- // Finished iterating a tensor, clean up
- // We only do the clean up on for loop as while loops do not necessarily
- // finish the iteration on a sparse tensor
- for (auto [tid, lvl] : unpackTensorLevelRange(loopInfo.trivialTidLvls)) {
- // Reset to null.
- coords[tid][lvl] = Value();
- posits[tid][lvl] = Value();
- // Dense level, high is fixed.
- if (!isDenseLT(lvlTypes[tid][lvl]))
- highs[tid][lvl] = Value();
- }
}
void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
@@ -1533,26 +714,8 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
SmallVector<Value> operands;
unsigned delta = 0;
ValueRange whileRes = whileOp.getResults();
- for (auto [tid, lvl, resolved] : loopInfo.sliceDrivenInfo) {
- SparseIterator &it = getCurIterator(tid, lvl);
- if (!it.randomAccessible()) {
- // Forward the sparse iterator.
- Value cmp = CMPI(eq, it.getCrd(), iv);
- it.forwardIf(builder, loc, cmp);
- operands.append(it.getItVals().begin(), it.getItVals().end());
- o += it.getItVals().size();
- // Following loops continue iteration from the break point of the
- // current while loop.
- whileRes = it.linkNewScope(whileRes);
- } else {
- // Make sure randomly accessible (dense) iterator is set to the right
- // position according to the universal index.
- Value uniIdx = whileOp.getResults().back();
- it.locate(builder, loc, uniIdx);
- }
- };
- for (auto [tid, lvl] : unpackTensorLevelRange(loopInfo.trivialTidLvls)) {
+ for (auto [tid, lvl] : unpackTensorLevelRange(loopInfo.tidLvls)) {
SparseIterator &it = getCurIterator(tid, lvl);
if (!it.randomAccessible()) {
// Forward the sparse iterator.
@@ -1570,13 +733,6 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
Value uniIdx = whileOp.getResults().back();
it.locate(builder, loc, uniIdx);
}
-
- posits[tid][lvl] = it.getItVals().front();
- // The coordinate is invalid now.
- coords[tid][lvl] = nullptr;
- // The segment high is invalid now.
- segHi[tid][lvl] = nullptr;
- // highs remains unchanged.
}
// Reduction value from users.
@@ -1628,655 +784,6 @@ void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc,
loopStack.pop_back();
}
-//===----------------------------------------------------------------------===//
-// Slice-driven loop related methods.
-//===----------------------------------------------------------------------===//
-
-unsigned LoopEmitter::remDepOnLevel(TensorId tid, Level lvl) const {
- unsigned totalDependencies = dependentLvlMap[tid][lvl].size();
- if (totalDependencies != 0) {
- assert(totalDependencies >= 2);
- return totalDependencies - levelReducedDep[tid][lvl];
- }
- return totalDependencies;
-}
-
-unsigned LoopEmitter::redDepOnLevel(TensorId tid, Level lvl) const {
- return levelReducedDep[tid][lvl];
-}
-
-const LoopEmitter::SliceInfo &LoopEmitter::getMostRecentSliceOnLvl(TensorId tid,
- Level lvl) {
- // Finds the most-recent slice using a reverse iteration.
- for (auto it = sliceStack[tid].rbegin(), ie = sliceStack[tid].rend(); it < ie;
- it++) {
- if (it->slicedOnLvl == lvl) { // the level matched
- return *it;
- }
- }
- llvm_unreachable("Failed to find sliceInfo");
-}
-
-// Generates a while loop to iterate over a slice sparse level as follows.
-//
-// while(coords[loopLo] < offset + size) {
-// body_builder
-// loopLo ++;
-// }
-std::pair<Operation *, ValueRange> LoopEmitter::genSliceLvlTraverseLoop(
- OpBuilder &builder, Location loc, Value posLo, Value posHi, Value offset,
- Value size, TensorId tid, Level lvl, ValueRange userReduc,
- LoopBodyBuilder bodyBuilder) {
- Value c1 = C_IDX(1);
- auto [sliceSz, stride] = sliceMeta[tid][lvl].back();
- assert(stride == 1 && "Not yet implemented");
- Value sliceHi = ADDI(offset, sliceSz);
-
- SmallVector<Value> reduc{posLo}; // loop lower bounds
- const unsigned numMetaReduc = reduc.size();
-
- // Append user required reduction value.
- reduc.append(userReduc.begin(), userReduc.end());
- scf::WhileOp whileOp = builder.create<scf::WhileOp>(
- loc, ValueRange(reduc).getTypes(), reduc,
- /*beforeBuilder=*/
- [this, posHi, sliceHi, tid, lvl](OpBuilder &builder, Location loc,
- ValueRange args) {
- Value cond = genSparseReducedAffineCond(builder, loc, *lvls[tid][lvl],
- sliceHi, args[0], posHi);
- // continue if not yet break nor out of bound.
- builder.create<scf::ConditionOp>(loc, cond, args);
- },
- /*afterBuilder=*/
- [c1, numMetaReduc, bodyBuilder](OpBuilder &builder, Location loc,
- ValueRange args) {
- Value iv = args[0];
- TypeRange types = args.drop_front(numMetaReduc).getTypes();
- // The coordinate must be in bound as guaranteed by the loop
- // condition. We generate a fake if operation here only to hide the
- // extra loop induction variables maintained by us from users, which
- // will be removed by later optimization pass.
- auto ifOp = builder.create<scf::IfOp>(loc, types,
- constantI1(builder, loc, true),
- /*withElseBlock=*/!types.empty());
- {
- // 2 reduction variable maintained by us.
- SmallVector<Value> ifRet = args.drop_front(numMetaReduc);
- assert(ifRet.size() == args.size() - 1);
-
- OpBuilder::InsertionGuard guard(builder);
- // If coord >= sliceHi.
- if (!ifRet.empty()) {
- builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- YIELD(ifRet);
- }
-
- // If coord < sliceHi.
- builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- // Delegates to users' callback.
- bodyBuilder(builder, loc, iv, ifRet);
- }
- // Marks this special ifOp to avoid sparisification finalizing it.
- ifOp->setAttr(getLoopEmitterLoopAttrName(),
- StringAttr::get(builder.getContext(), "slice"));
- // Insertion point restored to after ifOp.
- SmallVector<Value> yields;
- // Increase induction variable.
- yields.push_back(ADDI(iv, c1));
- yields.append(ifOp.getResults().begin(), ifOp.getResults().end());
- YIELD(yields);
- });
-
- builder.setInsertionPointAfter(whileOp);
- return std::make_pair(whileOp, whileOp.getResults().drop_front(numMetaReduc));
-}
-
-// Generates a loop nest that traverse all the unresolved levels in between.
-//
-// for(int i = 0; i < slicePos.size(); i+=2) {
-// loopLo = slicePos[i];
-// loopHi = slicePos[i + 1];
-//
-// // Then the same loop generated by genSliceLvlTraverse above.
-// while (loopLo < loopHI) {
-// if (pos[loopLo] < sliceHi) {
-// bodyBuilder();
-// } else {
-// break;
-// }
-// loopLo ++;
-// }
-// }
-ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
- OpBuilder &builder, Location loc, TensorId tid,
- ArrayRef<const SliceInfo *> unResLvls,
- std::optional<std::pair<TensorId, Level>> firstResLvl, ValueRange userReduc,
- LoopBodyBuilder bodyBuilder) {
-
- Value c0 = C_IDX(0), c1 = C_IDX(1);
- Value pos = c0;
- OpBuilder::InsertPoint ip;
- SmallVector<Value> innerArgs(userReduc.begin(), userReduc.end());
- scf::ForOp outerMost = nullptr; // the outermost loop.
-
- // Wraps body builder and inserts a extra counting instruction at the end.
- auto wrapped = [bodyBuilder](OpBuilder &builder, Location loc, Value iv,
- MutableArrayRef<Value> reduc) {
- bodyBuilder(builder, loc, iv, reduc.drop_back());
- // Increments the counter.
- reduc.back() = ADDI(reduc.back(), C_IDX(1));
- };
-
- // FIXME: Need special handling when the previous unresolved slice is strided:
- // We probably need to filter out coordinates that is not on stride.
- if (firstResLvl.has_value()) {
- // Overwrite position when the first level is fully resolved.
- pos = posits[firstResLvl->first][firstResLvl->second];
- ip = builder.saveInsertionPoint();
- } else {
- const SliceInfo &frontSlice = *unResLvls.back();
- Level firstLvl = *frontSlice.slicedOnLvl;
- if (!lvlFullyResolved(tid, firstLvl)) {
- if (isCompressedLT(lvlTypes[tid][firstLvl])) {
- // An extra counter that tracks how many segments are there in the child
- // compressed level.
- innerArgs.push_back(c0);
- // Overrides the user-provided builder.
- bodyBuilder = wrapped;
- unsigned depth = frontSlice.depth - 1;
- Value offset = frontSlice.offset;
- Value sPtrBuf = slicePosBuffer[tid][firstLvl][depth];
- Value mSz = frontSlice.posTupleNum;
- outerMost = builder.create<scf::ForOp>(
- loc, c0, mSz, c1, innerArgs,
- [this, tid, firstLvl, offset, sPtrBuf, &ip, &pos,
- &innerArgs](OpBuilder &builder, Location loc, Value iv,
- ValueRange iterArgs) {
- // generate traversal for each level.
- Value loopLo =
- loadSlicePos(builder, loc, sPtrBuf, iv, SlicePosKind::kLo);
- Value loopHi =
- loadSlicePos(builder, loc, sPtrBuf, iv, SlicePosKind::kHi);
- // We need to remember the starting index for next level's
- // position, because slice-driven loop breaks the level into
- // non-consecutive segments.
- updateSlicePos(builder, loc, sPtrBuf, iterArgs.back(), iv,
- SlicePosKind::kNext);
-
- auto [size, stride] = sliceMeta[tid][firstLvl].back();
- assert(stride == 1 && "Not yet implemented");
- ValueRange itArgs =
- genSliceLvlTraverseLoop(
- builder, loc, loopLo, loopHi, offset, size, tid, firstLvl,
- iterArgs,
- [&](OpBuilder &builder, Location, Value iv,
- MutableArrayRef<Value> reduc) {
- ip = builder.saveInsertionPoint();
- pos = iv;
- innerArgs.assign(reduc.begin(), reduc.end());
- })
- .second;
- YIELD(itArgs);
- });
- } else if (isDenseLT(lvlTypes[tid][firstLvl])) {
- assert(firstLvl == 0); // This must be the first level.
- Value lb = frontSlice.offset;
- auto [sliceSz, stride] =
- sliceMeta[tid][*frontSlice.slicedOnLvl][frontSlice.depth];
- assert(stride == 1 && "Not yet implemented");
- Value ub = ADDI(lb, sliceSz);
- outerMost = builder.create<scf::ForOp>(
- loc, lb, ub, c1, innerArgs,
- [&](OpBuilder &builder, Location loc, Value iv,
- ValueRange iterArgs) {
- ip = builder.saveInsertionPoint();
- pos = iv;
- innerArgs.assign(iterArgs.begin(), iterArgs.end());
- });
- }
- // We generated the loop for the first slice above, now remove it.
- unResLvls = unResLvls.drop_back();
- }
- }
- // Reset the insertion point into the loop body.
- builder.restoreInsertionPoint(ip);
- if (!unResLvls.empty()) {
- // Fills in dense slices levels in between.
- SmallVector<Value> lbs, ubs, steps, lvlSzs;
- for (const SliceInfo *slice : llvm::reverse(unResLvls)) {
- Level sliceLvl = *slice->slicedOnLvl;
- assert(isDenseLT(lvlTypes[tid][sliceLvl]));
- Value offset = slice->offset;
- auto [sliceSz, stride] = sliceMeta[tid][sliceLvl][slice->depth];
- assert(stride == 1 && "Not yet implemented");
- lbs.push_back(offset);
- ubs.push_back(ADDI(offset, sliceSz));
- steps.push_back(c1);
- lvlSzs.push_back(lvls[tid][sliceLvl]->size());
- }
- auto denseNest =
- scf::buildLoopNest(builder, loc, lbs, ubs, steps, innerArgs,
- [&innerArgs, &lvlSzs, &pos, bodyBuilder](
- OpBuilder &builder, Location loc, ValueRange ivs,
- ValueRange iterArgs) -> scf::ValueVector {
- for (auto em : llvm::enumerate(ivs)) {
- // Linearizes position: pos = (pos * lvlsize) +
- // iv;
- pos = MULI(pos, lvlSzs[em.index()]);
- pos = ADDI(pos, em.value());
- }
- innerArgs.assign(iterArgs.begin(), iterArgs.end());
- // Generates user request loop body.
- bodyBuilder(builder, loc, pos, innerArgs);
- return innerArgs;
- });
-
- if (!outerMost) {
- // If the outermost loop has not been set, this is the outermost loop.
- outerMost = denseNest.loops.front();
- } else {
- // Otherwise we need to generate yield operations to link the SSA chain.
- YIELD(denseNest.results);
- }
- } else {
- assert(outerMost);
- // Generates user request loop body.
- bodyBuilder(builder, loc, pos, innerArgs);
- YIELD(innerArgs);
- }
- assert(outerMost);
- // Insert after current while operation.
- builder.setInsertionPointAfter(outerMost);
- return outerMost.getResults();
-}
-
-void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
- TensorId tid, Level lvl) {
- Value c0 = C_IDX(0), c1 = C_IDX(1);
- if (isDenseLT(lvlTypes[tid][lvl])) {
- // Dense slice begin is trivial.
- sliceStack[tid].emplace_back(/*minCoord=*/c0, /*offset=*/c0,
- /*nonEmpty=*/constantI1(builder, loc, true),
- c0, lvl, /*depth=*/1);
- return;
- }
- auto [nxSz, stride] = sliceMeta[tid][lvl][1];
- assert(stride == 1 && "Not yet implemented");
- Value sPtrBuf = slicePosBuffer[tid][lvl][0];
- const SparseTensorLevel &stl = *lvls[tid][lvl];
-
- Value p = lvl == 0 ? c0 : posits[tid][lvl - 1];
- auto [pLo, pHi] = stl.peekRangeAt(builder, loc, p);
-
- // Fills out pIdxBuffer[tid][lvl][0] with [pLo, pHi]
- updateSlicePos(builder, loc, sPtrBuf, pLo, c0, SlicePosKind::kLo);
- updateSlicePos(builder, loc, sPtrBuf, pHi, c0, SlicePosKind::kHi);
- // Slice over a resolved parent, we only need one pair of pos hi and lo to
- // specify the current slice.
- Value tupleNum = c1;
- // This is an non empty tensor if pLo < pHi.
- Value isNonEmpty = CMPI(ult, pLo, pHi);
- // The minimal coord must be at the first on ordered level.
- // FIXME: Technically we should load the coord only when the slice is
- // nonempty. though we assume that even on empty sparse tensors, a non-empty
- // ptr/idx buffer is allocated for each level so it would not cause OOB to
- // avoid generating a ifOp here.
- Value minCrd = stl.peekCrdAt(builder, loc, pLo);
-
- // FIXME: We need the relative offset related to the base slice.
- Value absOffset = offsetFromMinCoord(builder, loc, minCrd, nxSz, isNonEmpty);
- sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, tupleNum, lvl,
- /*depth=*/1);
-}
-
-// Fills in the slicePosBuffer before slice-driven loop begin.
-// TODO: it can only handle all compressed tensors.
-//
-// // Loop generated by `genUnResolvedSliceTreeTraverse`
-// for(int i = 0; i < slicePos.size(); i+=2) {
-// loopLo = slicePos[i];
-// loopHi = slicePos[i + 1];
-// minCrd = max;
-// while (loopLo < loopHi) {
-// if (pos[loopLo] < sliceHi) {
-// // bodyBuilder
-// slicePos[tid].push_back(pos[loopLo]);
-// slicePos[tid].push_back(pos[loopLo + 1]);
-// minCrd = min(minCrd, crd[pos[loopLo]]);
-// } else {
-// break;
-// }
-// loopLo ++;
-// }
-// }
-void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
- TensorId tid, Level lvl) {
- Value c0 = C_IDX(0);
- unsigned depth = levelReducedDep[tid][lvl];
- // The remaining slice size after reduction.
- Value remSz = sliceMeta[tid][lvl][depth + 1].first;
- // Dense slice begin is trivial
- if (isDenseLT(lvlTypes[tid][lvl])) {
- sliceStack[tid].emplace_back(c0, c0, constantI1(builder, loc, false), c0,
- lvl, depth + 1);
- return;
- }
-
- assert(isCompressedLT(lvlTypes[tid][lvl]));
- // Unhandled Cases:
- //
- // 1st, lvl = prevSlicedLvl, i.e., t[d0 + d1 + d2,...] (more than one
- // variable need to be reduced on the same level).
- //
- // 2nd, lvl > prevSliceLvl + 1, i.e., t[..., d2, d3 + d4] (having a
- // simple dim expression in between).
- assert(lvl == *sliceStack[tid].back().slicedOnLvl + 1);
-
- SmallVector<const SliceInfo *> unResSlices;
- std::optional<std::pair<TensorId, Level>> firstResLvl;
- for (Level curLvl = lvl; curLvl >= 1; curLvl--) {
- Level prevLvl = curLvl - 1;
- if (lvlFullyResolved(tid, prevLvl)) {
- firstResLvl = std::make_pair(tid, prevLvl);
- break;
- }
- unResSlices.push_back(&getMostRecentSliceOnLvl(tid, prevLvl));
- if (!isDenseLT(lvlTypes[tid][prevLvl])) {
- break;
- }
- }
-
- assert(!unResSlices.empty() &&
- !lvlFullyResolved(tid, *unResSlices.front()->slicedOnLvl));
-
- Value sPtrBuf = slicePosBuffer[tid][lvl].back();
- SmallVector<Value, 3> reduc = {
- constantI1(builder, loc, false), // isNonEmpty
- lvls[tid][lvl]->size(), // minCoord
- c0, // memSize
- };
-
- ValueRange result = genUnResolvedSliceTreeTraverse(
- builder, loc, tid, unResSlices, firstResLvl, reduc,
- [this, tid, lvl, sPtrBuf](OpBuilder &builder, Location loc, Value iv,
- MutableArrayRef<Value> reduc) {
- Value &nonEmpty = reduc[0];
- Value &minCrd = reduc[1];
- Value &curTupleCnt = reduc[2];
-
- const SparseTensorLevel &stl = *lvls[tid][lvl];
- auto [sPLo, sPHi] = stl.peekRangeAt(builder, loc, iv);
-
- // isNonEmpty = isNonEmpty || lvlNonEmpty, i.e., as long as there is
- // one non-empty lvl, the slice is non-empty.
- Value lvlNonEmpty = CMPI(ult, sPLo, sPHi);
- nonEmpty = builder.create<arith::OrIOp>(loc, lvlNonEmpty, nonEmpty);
-
- // Update the minimum coordinate.
- auto ifNonEmpty = builder.create<scf::IfOp>(loc, builder.getIndexType(),
- lvlNonEmpty, true);
- {
- // Generate Code as follows.
- //
- // if (nonEmpty) {
- // minCrd = min(minCrd, crd[pos[pLo]]);
- // }
- OpBuilder::InsertionGuard guard(builder);
- builder.setInsertionPointToStart(ifNonEmpty.thenBlock());
- Value curC = stl.peekCrdAt(builder, loc, sPLo);
- Value isSmaller = CMPI(ult, curC, minCrd);
- Value newMin = SELECT(isSmaller, curC, minCrd);
- YIELD(newMin);
- builder.setInsertionPointToStart(ifNonEmpty.elseBlock());
- YIELD(minCrd);
- }
- minCrd = ifNonEmpty.getResult(0);
- updateSlicePos(builder, loc, sPtrBuf, sPLo, curTupleCnt,
- SlicePosKind::kLo);
- updateSlicePos(builder, loc, sPtrBuf, sPHi, curTupleCnt,
- SlicePosKind::kHi);
- curTupleCnt = ADDI(curTupleCnt, C_IDX(1));
- });
-
- Value isNonEmpty = result[0];
- Value minCrd = result[1];
- // Two metadata [memSize, idx].
- // FIXME: we need the relative offset related to the base slice.
- Value absOffset = offsetFromMinCoord(builder, loc, minCrd, remSz, isNonEmpty);
- sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, result[2], lvl,
- depth + 1);
-}
-
-bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
- Level lvl) {
- Value curLvlIdx = C_IDX(0);
- if (depFullyReduced(tid, lvl)) {
- if (lvl == 0 || trivialSlice[tid][lvl]) {
- sliceTupleNxStartIdx[tid][lvl] = C_IDX(0);
- } else {
- if (isDenseLT(lvlTypes[tid][lvl])) {
- sliceTupleNxStartIdx[tid][lvl] = sliceTupleNxStartIdx[tid][lvl - 1];
- } else {
- assert(isCompressedLT(lvlTypes[tid][lvl]));
- curLvlIdx = ADDI(sliceTupleNxStartIdx[tid][lvl - 1],
- sliceTupleFwdCnt[0][lvl - 1]);
- sliceTupleNxStartIdx[tid][lvl] =
- loadSlicePos(builder, loc, slicePosBuffer[tid][lvl].back(),
- curLvlIdx, SlicePosKind::kNext);
- }
- }
- if (isDenseLT(lvlTypes[tid][lvl]))
- return true;
-
- Value sPosBuf = slicePosBuffer[tid][lvl].back();
- // If constraints on the tensor is fully resolved. We do not need to
- // generates slice begin any more, instead we fall back to TACO-based
- // algorithm to (co)iterates over the slice.
- Value tupleIdx = curLvlIdx;
- posits[tid][lvl] =
- loadSlicePos(builder, loc, sPosBuf, tupleIdx, SlicePosKind::kLo);
- highs[tid][lvl] =
- loadSlicePos(builder, loc, sPosBuf, tupleIdx, SlicePosKind::kHi);
- return true;
- }
-
- // Only when the level is sorted, the next-non-empty slice can be computed
- // efficiently.
- const LevelType lvlType = lvlTypes[tid][lvl];
- assert(isOrderedLT(lvlType));
- if (isSingletonLT(lvlType)) {
- llvm_unreachable("TODO: dense level should be easy to support, while "
- "singleton level requires more efforts");
- }
-
- assert(!dependentLvlMap[tid][lvl].empty());
- assert(!sliceStack[tid].empty());
-
- const SliceInfo &sliceInfo = sliceStack[tid].back();
- auto baseEnc = getSparseTensorEncoding(tensors[tid].getType());
- if (baseEnc.isSlice())
- llvm_unreachable("TODO: not yet implemented");
-
- if (sliceInfo.isInitialTensor() ||
- (lvl >= 1 && lvlFullyResolved(tid, lvl - 1))) {
- // First level or previous level has been full resolved.
- trivialSlice[tid][lvl] = true;
- genResolvedSliceBegin(builder, loc, tid, lvl);
- } else {
- // The previous level has not been full resolved.
- trivialSlice[tid][lvl] = false;
- genUnResolvedSliceBegin(builder, loc, tid, lvl);
- }
- return false;
-}
-
-std::tuple<Value, Value, Value>
-LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
- TensorId tid, Level lvl) {
- if (!isCompressedLT(lvlTypes[tid][lvl]))
- llvm_unreachable("TODO");
-
- // else generate code to compute next non empty slice.
- Value c0 = C_IDX(0), c1 = C_IDX(1);
-
- SliceInfo &info = sliceStack[tid].back();
- assert(info.slicedOnLvl == lvl);
- //
- // We forward to the next non empty slice by
- // if (minCrd > offset) {
- // offset += 1
- // } else {
- // minCrd = nextMinInSlice();
- // offset = minCrd - size + 1;
- // }
- //
- // if (offset + size > parents.size)
- // isNonEmpty = false;
- //
- Value absOffset = info.offset;
- SmallVector<Value, 3> reduc = {info.minCrd, info.isNonEmpty, absOffset};
- Value sPtrBuf = slicePosBuffer[tid][lvl][info.depth - 1];
- Value fastPathP = CMPI(ugt, info.minCrd, absOffset);
- auto ifOp = builder.create<scf::IfOp>(loc, ValueRange(reduc).getTypes(),
- fastPathP, true);
- {
- OpBuilder::InsertionGuard guard(builder);
- // Take the fast path
- // if (minCrd > offset) {
- // return offset += 1
- // }
- builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- reduc[2] = ADDI(absOffset, c1);
- // Yield offset + 1.
- YIELD(reduc);
-
- // else /*minCrd == offset*/ {
- // for (i = 0; i < slicePos.size(); i+=kSliceIterWidth) {
- // if (crd[pos[slicePos[i]]] == minCrd) {
- // slicePos[i]++;
- // }
- // minCrd=min(minCrd, crd[pos[slicePos[i]]]);
- // }
- // offset = minCrd - size + 1;
- // }
- builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- reduc[2] = absOffset; // restore value.
- Value mSz = info.posTupleNum; // tuple number.
- reduc[0] = lvls[tid][lvl]->size(); // next min coord
- reduc[1] = constantI1(builder, loc, false); // isNonEmpty
- auto loopArgs = static_cast<ValueRange>(reduc).drop_back();
- auto forOp = scf::buildLoopNest(
- builder, loc, c0, mSz, c1, loopArgs,
- [this, tid, lvl, c1, sPtrBuf,
- &info](OpBuilder &builder, Location loc, ValueRange ivs,
- ValueRange iterArgs) -> scf::ValueVector {
- Value curMinCrd = iterArgs[0];
- Value isNonEmpty = iterArgs[1];
-
- Type idxTp = builder.getIndexType();
- Value pLo = loadSlicePos(builder, loc, sPtrBuf, ivs.front(),
- SlicePosKind::kLo);
- Value pHi = loadSlicePos(builder, loc, sPtrBuf, ivs.front(),
- SlicePosKind::kHi);
- //
- // if (pLo < pHi) // Only loads when inbound.
- // coord = load[pLo]
- // if coord == minCrd
- // pLo += 1
- //
- // if (pLo < pHi)
- // curMinCrd = min(curMinCrd, load[pLo])
- //
- Value pred = CMPI(ult, pLo, pHi);
- auto advPLo = builder.create<scf::IfOp>(loc, idxTp, pred, true);
- /* if pLo < pHi */ {
- builder.setInsertionPointToStart(&advPLo.getThenRegion().front());
- // coord = load[pLo]
- Value coord = lvls[tid][lvl]->peekCrdAt(builder, loc, pLo);
- Value pred = CMPI(eq, coord, info.minCrd);
- auto ifEqual = builder.create<scf::IfOp>(loc, idxTp, pred, true);
- /* if coord == minCrd */ {
- builder.setInsertionPointToStart(
- &ifEqual.getThenRegion().front());
- Value newPlo = ADDI(pLo, c1);
- // Updates the cache.
- updateSlicePos(builder, loc, sPtrBuf, newPlo, ivs.front(),
- SlicePosKind::kLo);
- YIELD(newPlo);
- }
- /* else coord != minCrd */ {
- builder.setInsertionPointToStart(
- &ifEqual.getElseRegion().front());
- YIELD(pLo);
- }
- builder.setInsertionPointAfter(ifEqual);
- YIELD(ifEqual.getResults());
- }
- /* else pLo >= pHi */ {
- builder.setInsertionPointToStart(&advPLo.getElseRegion().front());
- YIELD(pLo);
- }
-
- builder.setInsertionPointAfter(advPLo);
- pLo = advPLo.getResult(0);
- Value lvlNonEmpty = CMPI(ult, pLo, pHi);
- // Update minCrds
- auto newMin =
- builder.create<scf::IfOp>(loc, idxTp, lvlNonEmpty, true);
- builder.setInsertionPointToStart(&newMin.getThenRegion().front());
- YIELD(lvls[tid][lvl]->peekCrdAt(builder, loc, pLo));
-
- builder.setInsertionPointToStart(&newMin.getElseRegion().front());
- YIELD(curMinCrd);
- builder.setInsertionPointAfter(newMin);
-
- // isNonEmpty = isNonEmpty || lvlNonEmpty
- isNonEmpty =
- builder.create<arith::OrIOp>(loc, lvlNonEmpty, isNonEmpty);
- curMinCrd = builder.create<arith::SelectOp>(
- loc, CMPI(ult, newMin.getResult(0), curMinCrd),
- newMin.getResult(0), curMinCrd);
- return {curMinCrd, isNonEmpty};
- });
-
- builder.setInsertionPointAfter(forOp.loops.front());
- // minOffset = minCrd + 1 >= size ? minCrd + 1 - size : c0
- Value tmp = ADDI(forOp.results.front(), c1);
- auto [size, stride] = sliceMeta[tid][lvl][info.depth];
- assert(stride == 1 && "Not yet implemented");
- Value minOffset = SUBI(tmp, size);
- Value p = CMPI(uge, tmp, size);
- minOffset = SELECT(p, minOffset, c0);
-
- SmallVector<Value, 3> yields;
- yields.assign(forOp.results.begin(), forOp.results.end());
- yields.push_back(minOffset);
- YIELD(yields);
- }
-
- Value nextMinCrd = ifOp.getResults()[0];
- Value nextNonEmpty = ifOp.getResults()[1];
-
- // The next offset should at least be offset + 1;
- Value minOffset = ifOp.getResults()[2];
- Value nxOffset = ADDI(info.offset, c1);
- Value maxPred = CMPI(ugt, minOffset, nxOffset);
- Value nextAbsOffset = SELECT(maxPred, minOffset, nxOffset);
-
- auto [size, stride] = sliceMeta[tid][lvl][info.depth];
- assert(stride == 1 && "Not yet implemented");
- Value sliceUB = ADDI(nextAbsOffset, size);
-
- // FIXME: this only works if there is only one parent.
- assert(info.depth - 1 == 0);
- // nextNonEmpty = nextNonEmpty && slice upper bound <= parent upperbound.
- nextNonEmpty = ANDI(nextNonEmpty, CMPI(ule, sliceUB, lvls[tid][lvl]->size()));
-
- // FIXME: compute relative offset.
- assert(info.depth - 1 == 0);
- return std::make_tuple(nextNonEmpty, nextMinCrd, nextAbsOffset);
-}
-
#undef CMPI
#undef C_IDX
#undef YIELD
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index 2bd2b653a4d9f3..2b508e04162325 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -124,19 +124,8 @@ class LoopEmitter {
/// Exits the current loop sequence, this will reset universal index to 0.
void exitCurrentLoopSeq(OpBuilder &builder, Location loc);
- /// Enters a loop that tries to locate a coordinates in a sparse level based
- /// on the value evaluated by the provided affine expression.
- /// DEPRECATED: affine index expression should be handled by index reduction
- /// loop, filter loop-based solution is slow.
- Operation *enterFilterLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
- TensorId tid, Level lvl,
- AffineExpr affine,
- MutableArrayRef<Value> reduc = {});
-
/// Emits the address for a dense level based on the value evaluated by the
/// provided affine expression.
- /// DEPRECATED: affine index expression should be handled by index reduction
- /// loop, filter loop-based solution is slow.
void genDenseAffineAddress(OpBuilder &builder, Location loc,
TensorLevel tidLvl, AffineExpr lvlExpr);
@@ -224,21 +213,16 @@ class LoopEmitter {
});
}
- template <class ContainerTy>
- auto unpackTensorLevelFromCondRange(ContainerTy &&c) const {
- using EltTy = decltype(*c.begin());
- static_assert(std::is_same_v<llvm::remove_cvref_t<EltTy>, TensorLvlCond>,
- "Must be unpacking a TensorLvlCond range");
- return unpackTensorLevelRange(
- llvm::make_first_range(std::forward<ContainerTy>(c)));
- }
-
///
/// Getters.
///
- const std::vector<std::vector<Value>> &getPosits() const { return posits; };
- const std::vector<std::vector<Value>> &getCoords() const { return coords; };
- const std::vector<std::vector<Value>> &getHighs() const { return highs; };
+ Value getValPosits(TensorId tid) const {
+ Value lastLvlPos = iters[tid].back().back()->getCurPosition().first;
+ return lastLvlPos;
+ };
+ Value getCoord(TensorId tid, Level lvl) const {
+ return getCurIterator(tid, lvl).getCrd();
+ };
const std::vector<Value> &getValBuffer() const { return valBuffer; };
constexpr static llvm::StringLiteral getLoopEmitterLoopAttrName() {
@@ -250,22 +234,12 @@ class LoopEmitter {
/// Structure definitions that hold different kinds of loops information.
///
- // A tuple that stored the slice-driven loop information.
- struct SliceLoopInfo final {
- SliceLoopInfo(TensorId tid, Level lvl, bool reduced)
- : tid(tid), lvl(lvl), reduced(reduced) {}
- TensorId tid;
- Level lvl;
- bool reduced;
- };
// LoopInfo stores information of a loop generated by LoopEmitter. E.g.,
// the set of tensors levels that the loop is iterating over.
struct LoopInfo final {
- LoopInfo(ArrayRef<TensorLevel> trivialTidLvls,
- ArrayRef<SliceLoopInfo> sliceDrivenInfo, Operation *loop,
- Block *userBlock, Value iv, StringAttr loopTag)
- : trivialTidLvls(trivialTidLvls), sliceDrivenInfo(sliceDrivenInfo),
- loop(loop), userCodeBlock(userBlock), iv(iv) {
+ LoopInfo(ArrayRef<TensorLevel> tidLvls, Operation *loop, Block *userBlock,
+ Value iv, StringAttr loopTag)
+ : tidLvls(tidLvls), loop(loop), userCodeBlock(userBlock), iv(iv) {
// Attached a special tag to loop emitter generated loop.
if (loopTag)
loop->setAttr(LoopEmitter::getLoopEmitterLoopAttrName(), loopTag);
@@ -274,125 +248,12 @@ class LoopEmitter {
// used as the condition for the generated loop. Extra information is
// required for levels with non-tivial index expressions, which is
// maintained by the sliceDrivenInfo array below.
- const llvm::SmallVector<TensorLevel> trivialTidLvls;
- // The set of <tensor, lvl>, with *only* non-trivial index expressions, that
- // are used as the condition for the generated loop.
- const llvm::SmallVector<SliceLoopInfo> sliceDrivenInfo;
+ const llvm::SmallVector<TensorLevel> tidLvls;
const Operation *loop; // the loop operation
Block *const userCodeBlock; // the block holding users' generated code.
const Value iv; // the induction variable for the loop
};
- // SliceInfo stores information of an extracted slice for slice-driven loop.
- // E.g., the in-scope SSA values for the minimum coordinates and offset for
- // the slice, etc.
- struct SliceInfo final {
- // Note that we do not need to create a actual sparse tensor slice but
- // instead only need to maintain the metadata of the slice.
- SliceInfo(Value minCrd, Value offset, Value isNonEmpty, Value posTupleNum,
- std::optional<Level> slicedOnLvl, unsigned depth)
- : minCrd(minCrd), offset(offset), isNonEmpty(isNonEmpty),
- posTupleNum(posTupleNum), slicedOnLvl(slicedOnLvl), depth(depth) {
- // TODO: use std::optional<pair<Level, minCrd>>
- assert(!slicedOnLvl || minCrd);
- }
-
- // Whether this is the tensor that has not yet been sliced.
- bool isInitialTensor() const { return !slicedOnLvl.has_value(); }
-
- Value minCrd; // the minimum coordinate of the slice.
- Value offset; // the *absolute* offset of the current slice.
- Value isNonEmpty; // whether the slice is empty.
- Value posTupleNum; // The number of position tuples used in the slice.
- std::optional<Level> slicedOnLvl; // the level on which the slice is done
- unsigned depth; // the depth (relative to dependentDimMap[tid][lvl]).
- };
-
- ///
- /// Enums for different kinds of loop conditions.
- /// TODO: remove the enum after fully migrating to SparseTensorLevel.
- ///
-
- // The bit indicating whether the loop conditions is sparse.
- static constexpr uint8_t kSparseCond = 1 << 3;
- // The bit indicating whether the loop iterates over sparse tensor slices
- // (i.e., with non-empty SliceDimAttr).
- static constexpr uint8_t kSliceCond = 1 << 2;
- // The bit indicating whether the loop iterates over tensor levels with
- // non-trivial affine index reduction.
- static constexpr uint8_t kAffineIdxCond = 1 << 1;
- // The bit indicating whether the loop iterates over tensor levels with
- // non-trivial affine index reduction, and it is not fully reduced.
- static constexpr uint8_t kAffineIdxCondUnRed = 1 << 0;
-
- enum class LoopCondKind : uint8_t {
- // Dense conditions.
- DenseCond = 0,
- DenseSliceCond = kSliceCond,
- DenseAffineCond = kAffineIdxCond,
- DenseAffineUnRedCond = kAffineIdxCond | kAffineIdxCondUnRed,
- // Sparse Conditions.
- SparseCond = kSparseCond,
- SparseSliceCond = kSparseCond | kSliceCond,
- SparseAffineCond = kSparseCond | kAffineIdxCond,
- SparseAffineUnRedCond = kSparseCond | kAffineIdxCond | kAffineIdxCondUnRed,
- };
- using TensorLvlCond = std::pair<TensorLevel, LoopCondKind>;
-
- /// Sparse or dense loop condition.
- static bool isSparseCond(LoopCondKind k) {
- return static_cast<uint8_t>(k) & kSparseCond;
- }
- static bool isDenseCond(LoopCondKind k) { return !isSparseCond(k); }
-
- /// Whether loops over sparse tensor slices or sparse tensors.
- static bool isSliceCond(LoopCondKind k) {
- return static_cast<uint8_t>(k) & kSliceCond;
- }
-
- /// Affine or trivial index expression loop condition.
- static bool isAffineIdxCond(LoopCondKind k) {
- return static_cast<uint8_t>(k) & kAffineIdxCond;
- }
- static bool isTrivalIdxCond(LoopCondKind k) { return !isAffineIdxCond(k); }
-
- /// Whether the affine index expression is fully reduced.
- static bool isAffineIdxUnRedCond(LoopCondKind k) {
- return isAffineIdxCond(k) && static_cast<uint8_t>(k) & kAffineIdxCondUnRed;
- }
- static bool isAffineIdxRedCond(LoopCondKind k) {
- return isAffineIdxCond(k) && !isAffineIdxUnRedCond(k);
- }
-
- // Whether the loop condition kind requires extra check inside the loop body.
- // E.g., to iterate over sparse tensor slice, we need to check whether the
- // current cooridnate is on the slice (e.g., due to stride) or not.
- static bool isCondWithExtraCheck(LoopCondKind k) {
- return isSparseCond(k) && (isSliceCond(k) || isAffineIdxUnRedCond(k));
- }
-
- static LoopCondKind makeLoopCondKind(bool isSparse, bool isSlice,
- bool isAffine, bool isUnRedu) {
- assert(!isUnRedu || isAffine);
- uint8_t bits = 0;
- bits = isSparse ? bits | kSparseCond : bits;
- bits = isSlice ? bits | kSliceCond : bits;
- bits = isAffine ? bits | kAffineIdxCond : bits;
- bits = isUnRedu ? bits | kAffineIdxCondUnRed : bits;
- LoopCondKind kind = static_cast<LoopCondKind>(bits);
-
- // Sanity checks.
- assert(isSparse == isSparseCond(kind));
- assert(isSlice == isSliceCond(kind));
- assert(isAffine == isAffineIdxCond(kind));
- assert(isUnRedu == isAffineIdxUnRedCond(kind));
- return kind;
- }
-
- void categorizeLoopCondition(ArrayRef<TensorLevel> tidLvls,
- SmallVectorImpl<TensorLvlCond> &dnConds,
- SmallVectorImpl<TensorLvlCond> &spConds);
-
void categorizeIterators(ArrayRef<TensorLevel> tidLvls,
SmallVectorImpl<SparseIterator *> &raIters,
SmallVectorImpl<SparseIterator *> &spIters);
@@ -406,20 +267,6 @@ class LoopEmitter {
/// Whether the list of the sparse condition should be iterated by for loop.
bool shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters);
- /// Linearizes address for dense dimension (i.e., p = (i * d0) + j).
- Value genAddress(OpBuilder &builder, Location loc, TensorId tid, Level lvl,
- Value iv);
-
- /// Generates the segment high for a non-unique level (to fast forward
- /// duplicated coordinates). That is, it generates the code:
- ///
- /// crd = coordinates_tid_lvl[pos]
- /// while (pos < pHi && coordinates_tid_lvl[pos] == crd)
- /// pos++;
- /// <return pos>;
- Value genSegmentHigh(OpBuilder &builder, Location loc, TensorId tid,
- Level lvl, Value pos, Value pHi);
-
/// Generates instructions to compute the coordinate of tensors[tid][lvl]
/// under the current loop context. The final argument is the
/// collapsed-output level, whereas this function handles converting
@@ -427,13 +274,6 @@ class LoopEmitter {
Value genSparseCrd(OpBuilder &builder, Location loc, TensorId tid,
Level dstLvl);
- /// Generates a predicate to determine whether the tranformed coordinates are
- /// in the given slice.
- /// Returns std::pair<Transformed coordinates, Predicate>
- std::pair<Value, Value> genSliceLegitPredicate(OpBuilder &builder,
- Location loc, Value crd,
- TensorId tid, Level lvl);
-
bool isSynTensor(TensorId tid) const { return tid == getSynTensorId(); }
bool isOutputTensor(TensorId tid) const {
@@ -453,13 +293,6 @@ class LoopEmitter {
void prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
TensorId tid, Level lvl);
- /// Enter dense tensor levels. Since the dense tensor condition could be
- /// optimized from the loop condition, we need to compute the
- /// positions/coordinates inside the loop body.
- void enterTensorsAtDenseLvls(OpBuilder &builder, Location loc,
- ArrayRef<SparseIterator *> dnConds, Value iv,
- SmallVectorImpl<SliceLoopInfo> &sliceInfo);
-
/// Emits a for loop to iterate over a tensor level with the provided
/// lower bound `lo` and upper bound `hi`. Apart from iterating just
/// single tensor level, for loops can be used for slice-driven loop on
@@ -482,23 +315,6 @@ class LoopEmitter {
ArrayRef<SparseIterator *> iters,
MutableArrayRef<Value> reduc, bool needsUniv);
- /// Generates the while loop condition for the given tensor level condition.
- Value genWhileLoopConditions(OpBuilder &builder, Location loc, ValueRange ivs,
- TensorLvlCond cond);
-
- /// Generates the while loop body for the given tensor level condition.
- std::optional<Value> genWhileLoopBody(OpBuilder &builder, Location loc,
- ValueRange ivs, TensorLvlCond cond);
-
- /// Generates the values (to forward the loop) if the extra check failes.
- /// E.g., to iterate over a sparse tensor slice, we need:
- ///
- /// pos = onSlice(curCrd) ? pos : pos + 1
- ///
- /// to skip invalid coordinate that is included in the slice.
- ValueRange genCheckedValue(OpBuilder &builder, Location loc, Value pred,
- ValueRange curArg, TensorLvlCond cond);
-
/// Exits a for loop, returns the reduction results, e.g.,
/// For sequential for loops:
/// %ret = for () {
@@ -535,27 +351,11 @@ class LoopEmitter {
//
void initSubSectIterator(OpBuilder &builder, Location loc);
- // TODO: remove below.
- void initSliceDriven(OpBuilder &builder, Location loc);
-
- /// Retrieves the most recent slice on lvl. To reduce affine expression like
- /// d0 + d1 + d2, we need two slices (one of size d1 + d2, and the other of
- /// size d2). This methods returns the latter slice (of size d2).
- const SliceInfo &getMostRecentSliceOnLvl(TensorId tid, Level lvl);
-
- /// Similar to getMostRecentSliceOnLvl, but yields error when the most recent
- /// slice is not the final slice needed to fully reduced the dependencies.
- const SliceInfo &getFinalSliceOnLvl(TensorId tid, Level lvl) {
- const SliceInfo &info = getMostRecentSliceOnLvl(tid, lvl);
- assert(info.depth == dependentLvlMap[tid][lvl].size() - 1);
- return info;
- }
- /// Get the remaining number of constraints needed to fully *resolve*
- /// dependent levels on tensor[tid].
- unsigned remDepOnLevel(TensorId tid, Level lvl) const;
/// Get the reduced number of contraints on tensor[tid][lvl].
- unsigned redDepOnLevel(TensorId tid, Level lvl) const;
+ unsigned redDepOnLevel(TensorId tid, Level lvl) const {
+ return levelReducedDep[tid][lvl];
+ };
SparseIterator &getCurIterator(TensorId tid, Level lvl) const {
if (dependentLvlMap[tid][lvl].empty())
@@ -565,70 +365,9 @@ class LoopEmitter {
return *iters[tid][lvl][redDepOnLevel(tid, lvl) - 1];
}
- /// Whether the tid, lvl is fully *reduced*, i.e., the non-trivial index
- /// expression has been reduced to a trivial one.
- /// E.g., A[i + j] => A[i + 2] (j is reduced)
- bool depFullyReduced(TensorId tid, Level lvl) const {
- return remDepOnLevel(tid, lvl) == 1;
- }
-
- /// Whether the tid, lvl is fully resolved, i.e., we entered the level already
- /// (the index on that level is determined).
- /// E.g., A[i + j] => A[2 + 3] (both i and j become invariants for inner
- /// loops).
- bool lvlFullyResolved(TensorId tid, Level lvl) const {
- return remDepOnLevel(tid, lvl) == 0;
- }
-
- /// Generates a whileOp to iterate over a subset of coordinates on tid on lvl
- /// using the pHi and pLo provided, the loop break on the first coordinate
- /// that exceeds the slice boundary (i.e., coord >= slice.offset +
- /// slice.size).
- std::pair<Operation *, ValueRange>
- genSliceLvlTraverseLoop(OpBuilder &builder, Location loc, Value pLo,
- Value pHi, Value offset, Value size, TensorId tid,
- Level lvl, ValueRange userReduc,
- LoopBodyBuilder bodyBuilder);
-
- /// Generates a nested loop that iterates over tid on all the coordinates on
- /// lvl.
- ValueRange genUnResolvedSliceTreeTraverse(
- OpBuilder &builder, Location loc, TensorId tid,
- ArrayRef<const SliceInfo *> unResLvls,
- std::optional<std::pair<TensorId, Level>> firstResLvl,
- ValueRange userReduc, LoopBodyBuilder bodyBuilder);
-
- /// Generates code to get the first non-empty slice of tid on lvl, when all
- /// the previous level before `lvl` are resolved (or lvl is the first level).
- ///
- /// This is the simple case because the previous level are resolved into a
- /// single node in the storage tree.
- void genResolvedSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
- Level lvl);
-
- /// Generates code to get the first non-empty slice of tid on lvl, when
- /// the previous levels before `lvl` are unresolved
- ///
- /// This is the complex case because the previous levels corresponding to a
- /// range of nodes in the storage tree.
- void genUnResolvedSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
- Level lvl);
-
- /// Generates code to get the first non-empty slice of tid on lvl.
- /// return true if has already been resolved.
- bool genSliceBegin(OpBuilder &builder, Location loc, TensorId tid, Level lvl);
-
std::unique_ptr<SparseIterator>
makeLevelIterator(OpBuilder &builder, Location loc, TensorId tid, Level l);
- /// Generates code to get the next non-empty slices of tid on lvl.
- /// Returns a tuple of values for <NonEmpty, MinCrd, AbsOffset> (see
- /// SliceInfo) respectively.
- std::tuple<Value, Value, Value> genSliceNextInduction(OpBuilder &builder,
- Location loc,
- TensorId tid,
- Level lvl);
-
/// A optional string attribute that should be attached to the loop
/// generated by loop emitter, it might help following passes to identify
/// loops that operates on sparse tensors more easily.
@@ -644,48 +383,16 @@ class LoopEmitter {
/// Input and (optional) output tensors.
std::vector<Value> tensors;
+ std::vector<Value> loopHighs;
std::vector<std::vector<std::unique_ptr<SparseTensorLevel>>> lvls;
std::vector<std::vector<std::vector<std::unique_ptr<SparseIterator>>>> iters;
std::vector<Value> valBuffer; // to_value
- // TODO: remove all below.
- /// Level-types for each `(TensorId, Level)` pair.
- // Sparse iteration information for each `(TensorId, Level)` pair.
- // These arrays are updated to remain current within the current loop.
- std::vector<std::vector<LevelType>> lvlTypes;
- std::vector<std::vector<Value>> posits;
- /// The collection of coordinates for a given element (one such
- /// collection for each tensor).
- std::vector<std::vector<Value>> coords;
- // The segment upper bound for non-uniques level after de-duplication.
- std::vector<std::vector<Value>> segHi;
- std::vector<std::vector<Value>> highs;
- std::vector<std::vector<Value>> lvlSizes;
-
- //
- // Slice-driven loops related fields.
- //
-
- /// Whether the sparse input is a slice.
- std::vector<bool> isSparseSlices;
- /// Values related to slices.
- std::vector<std::vector<Value>> sliceOffsets;
- std::vector<std::vector<Value>> sliceStrides;
-
// Map from [tid, level] to a list of dependent [tidlevel, coefficient].
// See comments for `DependentLvlGetter`.
std::vector<std::vector<std::vector<std::pair<LoopId, unsigned>>>>
dependentLvlMap;
- // The cached position buffer for the slices, they serve the same purpose as
- // ptrBuffer for compressed dimensions.
- // But they always starts with the first pidx pointing to coord >
- // slice.offset to avoid iteration from the beginning.
- std::vector<std::vector<std::vector<Value>>> slicePosBuffer;
- std::vector<std::vector<Value>> sliceTupleNxStartIdx;
- std::vector<std::vector<Value>> sliceTupleFwdCnt;
- std::vector<std::vector<bool>> trivialSlice;
-
// The (size, stride) for each conceptual slice used for index reduction
// loops.
std::vector<std::vector<std::vector<std::pair<Value, unsigned>>>> sliceMeta;
@@ -693,9 +400,6 @@ class LoopEmitter {
// The number of reduced dependencies on a tensor level so far.
std::vector<std::vector<unsigned>> levelReducedDep;
- // sliceStack[tid] holds the generated slice stack on tid.
- std::vector<std::vector<SliceInfo>> sliceStack;
-
//
// Fields which have at most `numLoops` many entries.
//
>From 69339a59ccd09b2926a270f00f7125d9bd05ac7c Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 16 Jan 2024 19:56:08 +0000
Subject: [PATCH 09/11] fix bugs
---
.../Transforms/Sparsification.cpp | 4 +--
.../Transforms/Utils/LoopEmitter.cpp | 26 ++++++++++---------
.../Transforms/Utils/LoopEmitter.h | 4 +--
3 files changed, 18 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 6e1670bcc7dc44..21d73f5cc7469a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1103,7 +1103,7 @@ static void genConstantDenseAddressFromLevel(CodegenEnv &env,
for (Level l = startLvl; l < lvlRank; l++) {
AffineExpr lvlExpr = lvlExprs[l];
if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
- env.emitter().genDenseAffineAddress(
+ env.emitter().locateLvlAtAffineAddress(
builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
else
return; // break on first non-dense non-constant level
@@ -1152,7 +1152,7 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
Operation *loop = genLoop(env, builder, curr, needsUniv, tidLvls);
Location loc = env.op().getLoc();
for (auto [tidLvl, exp] : affineTidLvls) {
- env.emitter().genDenseAffineAddress(builder, loc, tidLvl, exp);
+ env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp);
}
// Until now, we have entered every <tid, lvl> pair in {cond, extra,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 76f7adac88b9a7..e4c9d3e253015b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -603,11 +603,16 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
return l;
}
-void LoopEmitter::genDenseAffineAddress(OpBuilder &builder, Location loc,
- TensorLevel tidLvl,
- AffineExpr lvlExpr) {
+void LoopEmitter::locateLvlAtAffineAddress(OpBuilder &builder, Location loc,
+ TensorLevel tidLvl,
+ AffineExpr lvlExpr) {
auto [tid, lvl] = unpackTensorLevel(tidLvl);
+
+ const SparseIterator *parent =
+ lvl == 0 ? nullptr : iters[tid][lvl - 1].back().get();
auto &it = getCurIterator(tid, lvl);
+ it.genInit(builder, loc, parent);
+
assert(it.kind == IterKind::kTrivial && it.randomAccessible());
Value lvlCrd = genAffine(builder, loc, lvlExpr);
it.locate(builder, loc, lvlCrd);
@@ -710,9 +715,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
// However, that would result in a rather elaborate forest of yield
// instructions during code generation. Moreover, performing the induction
// after the if-statements more closely resembles code generated by TACO.
- unsigned o = 0;
SmallVector<Value> operands;
- unsigned delta = 0;
ValueRange whileRes = whileOp.getResults();
for (auto [tid, lvl] : unpackTensorLevelRange(loopInfo.tidLvls)) {
@@ -722,7 +725,6 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
Value cmp = CMPI(eq, it.getCrd(), iv);
it.forwardIf(builder, loc, cmp);
operands.append(it.getItVals().begin(), it.getItVals().end());
- o += it.getItVals().size();
// const Value newPos = whileOp->getResult(o++);
// Following loops continue iteration from the break point of the
// current while loop.
@@ -738,20 +740,20 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
// Reduction value from users.
for (auto &i : reduc) {
operands.push_back(i);
- // In place update reduction variable.
- i = whileOp->getResult(o++);
+ // Update user reduction variables.
+ i = whileRes.front();
+ whileRes = whileRes.drop_front();
}
// An (optional) universal index.
- if (operands.size() + delta < whileOp.getNumResults()) {
- assert(operands.size() + delta + 1 == whileOp.getNumResults());
+ if (operands.size() < whileOp.getNumResults()) {
+ assert(operands.size() + 1 == whileOp.getNumResults());
// The last one is the universial index.
operands.push_back(ADDI(iv, one));
// update the loop starting point of current loop sequence
- loopSeqStack.back().first = whileOp->getResult(o++);
+ loopSeqStack.back().first = whileOp->getResults().back();
}
- assert(o == operands.size() + delta);
if (!operands.empty())
YIELD(operands);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index 2b508e04162325..b8fe450ca9f55f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -126,8 +126,8 @@ class LoopEmitter {
/// Emits the address for a dense level based on the value evaluated by the
/// provided affine expression.
- void genDenseAffineAddress(OpBuilder &builder, Location loc,
- TensorLevel tidLvl, AffineExpr lvlExpr);
+ void locateLvlAtAffineAddress(OpBuilder &builder, Location loc,
+ TensorLevel tidLvl, AffineExpr lvlExpr);
// TODO: Get rid of `lvls` in the argument list? Track the level we
// are currently at internally. Then it would be enterNextLvlForTensor.
>From b779f92b47fc8d4e5beb217f105381d816a7f620 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 16 Jan 2024 20:59:47 +0000
Subject: [PATCH 10/11] fix check tests
---
mlir/test/Dialect/SparseTensor/dense.mlir | 12 +-
.../test/Dialect/SparseTensor/sorted_coo.mlir | 397 +++++++--------
mlir/test/Dialect/SparseTensor/sparse_2d.mlir | 35 +-
mlir/test/Dialect/SparseTensor/sparse_3d.mlir | 68 +--
.../Dialect/SparseTensor/sparse_affine.mlir | 4 +-
.../sparse_conv_2d_slice_based.mlir | 453 +++++++++---------
.../Dialect/SparseTensor/sparse_foreach.mlir | 207 ++++----
.../Dialect/SparseTensor/sparse_index.mlir | 8 +-
mlir/test/Dialect/SparseTensor/sparse_nd.mlir | 20 +-
.../Dialect/SparseTensor/sparse_perm.mlir | 16 +-
.../SparseTensor/sparse_perm_lower.mlir | 18 +-
.../SparseTensor/sparse_vector_mv.mlir | 3 +-
.../Dialect/SparseTensor/spy_sddmm_bsr.mlir | 8 +-
13 files changed, 626 insertions(+), 623 deletions(-)
diff --git a/mlir/test/Dialect/SparseTensor/dense.mlir b/mlir/test/Dialect/SparseTensor/dense.mlir
index 2d8dcfea9adc19..60a217e05e61ec 100644
--- a/mlir/test/Dialect/SparseTensor/dense.mlir
+++ b/mlir/test/Dialect/SparseTensor/dense.mlir
@@ -42,9 +42,9 @@
// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xf32>
// CHECK: %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : memref<32x16xf32>
// CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
+// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_9]], %[[VAL_4]] : index
// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
-// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_9]], %[[VAL_4]] : index
-// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_10]] : index
+// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_10]], %[[VAL_11]] : index
// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref<?xf32>
// CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_13]], %[[VAL_2]] : f32
// CHECK: memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_9]], %[[VAL_10]]] : memref<32x16xf32>
@@ -82,9 +82,9 @@ func.func @dense1(%arga: tensor<32x16xf32, #DenseMatrix>,
// CHECK: %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_0]] : memref<32x16xf32>
// CHECK: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xf32>
// CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
+// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_9]], %[[VAL_4]] : index
// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
-// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_9]], %[[VAL_4]] : index
-// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_10]] : index
+// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_10]], %[[VAL_11]] : index
// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_9]], %[[VAL_10]]] : memref<32x16xf32>
// CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_13]], %[[VAL_2]] : f32
// CHECK: memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<?xf32>
@@ -125,9 +125,9 @@ func.func @dense2(%arga: tensor<32x16xf32>,
// CHECK: %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_0]] : memref<32x16x8xf32>
// CHECK: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xf32>
// CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
+// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_9]], %[[VAL_4]] : index
// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
-// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_9]], %[[VAL_4]] : index
-// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_10]] : index
+// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_10]], %[[VAL_11]] : index
// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<?xf32>
// CHECK: %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_5]] to %[[VAL_2]] step %[[VAL_6]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (f32) {
// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_9]], %[[VAL_10]], %[[VAL_15]]] : memref<32x16x8xf32>
diff --git a/mlir/test/Dialect/SparseTensor/sorted_coo.mlir b/mlir/test/Dialect/SparseTensor/sorted_coo.mlir
index 91e7920b3a9033..2b9a2dd8f4883d 100644
--- a/mlir/test/Dialect/SparseTensor/sorted_coo.mlir
+++ b/mlir/test/Dialect/SparseTensor/sorted_coo.mlir
@@ -1,3 +1,4 @@
+// TODO: re-enable after lowering coo.next to function call (such that loop structure is more clear).
// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification --canonicalize | FileCheck %s
#SortedCOO = #sparse_tensor.encoding<{
@@ -37,47 +38,47 @@
//
// CHECK-LABEL: func.func @sparse_scale(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32, #sparse{{[0-9]*}}>) -> tensor<?x?xf32, #sparse{{[0-9]*}}> {
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant false
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2.000000e+00 : f32
-// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex, strided<[?], offset: ?>>
-// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xf32>
-// CHECK-DAG: %[[VAL_8:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_2]]] : memref<?xindex>
-// CHECK-DAG: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK: %[[VAL_10:.*]] = scf.while (%[[VAL_11:.*]] = %[[VAL_8]]) : (index) -> index {
-// CHECK: %[[VAL_12:.*]] = arith.cmpi ult, %[[VAL_11]], %[[VAL_9]] : index
-// CHECK: scf.condition(%[[VAL_12]]) %[[VAL_11]] : index
-// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_13:.*]]: index):
-// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xindex, strided<[?], offset: ?>>
-// CHECK: %[[VAL_15:.*]] = scf.while (%[[VAL_16:.*]] = %[[VAL_13]]) : (index) -> index {
-// CHECK: %[[VAL_17:.*]] = arith.cmpi ult, %[[VAL_16]], %[[VAL_9]] : index
-// CHECK: %[[VAL_18:.*]] = scf.if %[[VAL_17]] -> (i1) {
-// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_16]]] : memref<?xindex, strided<[?], offset: ?>>
-// CHECK: %[[VAL_20:.*]] = arith.cmpi eq, %[[VAL_19]], %[[VAL_14]] : index
-// CHECK: scf.yield %[[VAL_20]] : i1
-// CHECK: } else {
-// CHECK: scf.yield %[[VAL_1]] : i1
-// CHECK: }
-// CHECK: scf.condition(%[[VAL_21:.*]]) %[[VAL_16]] : index
-// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_22:.*]]: index):
-// CHECK: %[[VAL_23:.*]] = arith.addi %[[VAL_22]], %[[VAL_3]] : index
-// CHECK: scf.yield %[[VAL_23]] : index
-// CHECK: }
-// CHECK: scf.for %[[VAL_24:.*]] = %[[VAL_13]] to %[[VAL_25:.*]] step %[[VAL_3]] {
-// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_24]]] : memref<?xf32>
-// CHECK: %[[VAL_27:.*]] = arith.mulf %[[VAL_26]], %[[VAL_4]] : f32
-// CHECK: memref.store %[[VAL_27]], %[[VAL_7]]{{\[}}%[[VAL_24]]] : memref<?xf32>
-// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: scf.yield %[[VAL_28:.*]] : index
-// CHECK: } attributes {"Emitted from" = "linalg.generic"}
-// CHECK: %[[VAL_29:.*]] = sparse_tensor.load %[[VAL_0]] : tensor<?x?xf32, #sparse{{[0-9]*}}>
-// CHECK: return %[[VAL_29]] : tensor<?x?xf32, #sparse{{[0-9]*}}>
-// CHECK: }
+// C_HECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32, #sparse{{[0-9]*}}>) -> tensor<?x?xf32, #sparse{{[0-9]*}}> {
+// C_HECK-DAG: %[[VAL_1:.*]] = arith.constant false
+// C_HECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index
+// C_HECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index
+// C_HECK-DAG: %[[VAL_4:.*]] = arith.constant 2.000000e+00 : f32
+// C_HECK-DAG: %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
+// C_HECK-DAG: %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex, strided<[?], offset: ?>>
+// C_HECK-DAG: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xf32>
+// C_HECK-DAG: %[[VAL_8:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// C_HECK-DAG: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// C_HECK: %[[VAL_10:.*]] = scf.while (%[[VAL_11:.*]] = %[[VAL_8]]) : (index) -> index {
+// C_HECK: %[[VAL_12:.*]] = arith.cmpi ult, %[[VAL_11]], %[[VAL_9]] : index
+// C_HECK: scf.condition(%[[VAL_12]]) %[[VAL_11]] : index
+// C_HECK: } do {
+// C_HECK: ^bb0(%[[VAL_13:.*]]: index):
+// C_HECK: %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xindex, strided<[?], offset: ?>>
+// C_HECK: %[[VAL_15:.*]] = scf.while (%[[VAL_16:.*]] = %[[VAL_13]]) : (index) -> index {
+// C_HECK: %[[VAL_17:.*]] = arith.cmpi ult, %[[VAL_16]], %[[VAL_9]] : index
+// C_HECK: %[[VAL_18:.*]] = scf.if %[[VAL_17]] -> (i1) {
+// C_HECK: %[[VAL_19:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_16]]] : memref<?xindex, strided<[?], offset: ?>>
+// C_HECK: %[[VAL_20:.*]] = arith.cmpi eq, %[[VAL_19]], %[[VAL_14]] : index
+// C_HECK: scf.yield %[[VAL_20]] : i1
+// C_HECK: } else {
+// C_HECK: scf.yield %[[VAL_1]] : i1
+// C_HECK: }
+// C_HECK: scf.condition(%[[VAL_21:.*]]) %[[VAL_16]] : index
+// C_HECK: } do {
+// C_HECK: ^bb0(%[[VAL_22:.*]]: index):
+// C_HECK: %[[VAL_23:.*]] = arith.addi %[[VAL_22]], %[[VAL_3]] : index
+// C_HECK: scf.yield %[[VAL_23]] : index
+// C_HECK: }
+// C_HECK: scf.for %[[VAL_24:.*]] = %[[VAL_13]] to %[[VAL_25:.*]] step %[[VAL_3]] {
+// C_HECK: %[[VAL_26:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_24]]] : memref<?xf32>
+// C_HECK: %[[VAL_27:.*]] = arith.mulf %[[VAL_26]], %[[VAL_4]] : f32
+// C_HECK: memref.store %[[VAL_27]], %[[VAL_7]]{{\[}}%[[VAL_24]]] : memref<?xf32>
+// C_HECK: } {"Emitted from" = "linalg.generic"}
+// C_HECK: scf.yield %[[VAL_28:.*]] : index
+// C_HECK: } attributes {"Emitted from" = "linalg.generic"}
+// C_HECK: %[[VAL_29:.*]] = sparse_tensor.load %[[VAL_0]] : tensor<?x?xf32, #sparse{{[0-9]*}}>
+// C_HECK: return %[[VAL_29]] : tensor<?x?xf32, #sparse{{[0-9]*}}>
+// C_HECK: }
func.func @sparse_scale(%argx: tensor<?x?xf32, #SortedCOO>) -> tensor<?x?xf32, #SortedCOO> {
%c = arith.constant 2.0 : f32
%0 = linalg.generic #trait_scale
@@ -89,57 +90,57 @@ func.func @sparse_scale(%argx: tensor<?x?xf32, #SortedCOO>) -> tensor<?x?xf32, #
return %0 : tensor<?x?xf32, #SortedCOO>
}
-// CHECK-LABEL: func.func @matvec(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x64xf64, #sparse{{[0-9]*}}>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<64xf64>,
-// CHECK-SAME: %[[VAL_2:.*]]: tensor<32xf64>) -> tensor<32xf64> {
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant false
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref<?xindex, strided<[?], offset: ?>>
-// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref<?xindex, strided<[?], offset: ?>>
-// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref<?xf64>
-// CHECK: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf64>
-// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK: %[[VAL_13:.*]] = scf.while (%[[VAL_14:.*]] = %[[VAL_11]]) : (index) -> index {
-// CHECK: %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_14]], %[[VAL_12]] : index
-// CHECK: scf.condition(%[[VAL_15]]) %[[VAL_14]] : index
-// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_16:.*]]: index):
-// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref<?xindex, strided<[?], offset: ?>>
-// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref<?xindex, strided<[?], offset: ?>>
-// CHECK: %[[VAL_19:.*]] = scf.while (%[[VAL_20:.*]] = %[[VAL_16]]) : (index) -> index {
-// CHECK: %[[VAL_21:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_12]] : index
-// CHECK: %[[VAL_22:.*]] = scf.if %[[VAL_21]] -> (i1) {
-// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_20]]] : memref<?xindex, strided<[?], offset: ?>>
-// CHECK: %[[VAL_24:.*]] = arith.cmpi eq, %[[VAL_23]], %[[VAL_18]] : index
-// CHECK: scf.yield %[[VAL_24]] : i1
-// CHECK: } else {
-// CHECK: scf.yield %[[VAL_3]] : i1
-// CHECK: }
-// CHECK: scf.condition(%[[VAL_25:.*]]) %[[VAL_20]] : index
-// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_26:.*]]: index):
-// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_5]] : index
-// CHECK: scf.yield %[[VAL_27]] : index
-// CHECK: }
-// CHECK: %[[VAL_28:.*]] = tensor.extract %[[VAL_2]]{{\[}}%[[VAL_17]]] : tensor<32xf64>
-// CHECK: %[[VAL_29:.*]] = scf.for %[[VAL_30:.*]] = %[[VAL_16]] to %[[VAL_31:.*]] step %[[VAL_5]] iter_args(%[[VAL_32:.*]] = %[[VAL_28]]) -> (f64) {
-// CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_30]]] : memref<?xindex, strided<[?], offset: ?>>
-// CHECK: %[[VAL_34:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_30]]] : memref<?xf64>
-// CHECK: %[[VAL_35:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_33]]] : tensor<64xf64>
-// CHECK: %[[VAL_36:.*]] = arith.mulf %[[VAL_34]], %[[VAL_35]] : f64
-// CHECK: %[[VAL_37:.*]] = arith.addf %[[VAL_32]], %[[VAL_36]] : f64
-// CHECK: scf.yield %[[VAL_37]] : f64
-// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: memref.store %[[VAL_38:.*]], %[[VAL_10]]{{\[}}%[[VAL_17]]] : memref<32xf64>
-// CHECK: scf.yield %[[VAL_39:.*]] : index
-// CHECK: } attributes {"Emitted from" = "linalg.generic"}
-// CHECK: %[[VAL_40:.*]] = bufferization.to_tensor %[[VAL_10]] : memref<32xf64>
-// CHECK: return %[[VAL_40]] : tensor<32xf64>
-// CHECK: }
+// C_HECK-LABEL: func.func @matvec(
+// C_HECK-SAME: %[[VAL_0:.*]]: tensor<32x64xf64, #sparse{{[0-9]*}}>,
+// C_HECK-SAME: %[[VAL_1:.*]]: tensor<64xf64>,
+// C_HECK-SAME: %[[VAL_2:.*]]: tensor<32xf64>) -> tensor<32xf64> {
+// C_HECK-DAG: %[[VAL_3:.*]] = arith.constant false
+// C_HECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
+// C_HECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
+// C_HECK-DAG: %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// C_HECK-DAG: %[[VAL_7:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref<?xindex, strided<[?], offset: ?>>
+// C_HECK-DAG: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref<?xindex, strided<[?], offset: ?>>
+// C_HECK-DAG: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref<?xf64>
+// C_HECK: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32xf64>
+// C_HECK: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// C_HECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// C_HECK: %[[VAL_13:.*]] = scf.while (%[[VAL_14:.*]] = %[[VAL_11]]) : (index) -> index {
+// C_HECK: %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_14]], %[[VAL_12]] : index
+// C_HECK: scf.condition(%[[VAL_15]]) %[[VAL_14]] : index
+// C_HECK: } do {
+// C_HECK: ^bb0(%[[VAL_16:.*]]: index):
+// C_HECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref<?xindex, strided<[?], offset: ?>>
+// C_HECK: %[[VAL_18:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref<?xindex, strided<[?], offset: ?>>
+// C_HECK: %[[VAL_19:.*]] = scf.while (%[[VAL_20:.*]] = %[[VAL_16]]) : (index) -> index {
+// C_HECK: %[[VAL_21:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_12]] : index
+// C_HECK: %[[VAL_22:.*]] = scf.if %[[VAL_21]] -> (i1) {
+// C_HECK: %[[VAL_23:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_20]]] : memref<?xindex, strided<[?], offset: ?>>
+// C_HECK: %[[VAL_24:.*]] = arith.cmpi eq, %[[VAL_23]], %[[VAL_18]] : index
+// C_HECK: scf.yield %[[VAL_24]] : i1
+// C_HECK: } else {
+// C_HECK: scf.yield %[[VAL_3]] : i1
+// C_HECK: }
+// C_HECK: scf.condition(%[[VAL_25:.*]]) %[[VAL_20]] : index
+// C_HECK: } do {
+// C_HECK: ^bb0(%[[VAL_26:.*]]: index):
+// C_HECK: %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_5]] : index
+// C_HECK: scf.yield %[[VAL_27]] : index
+// C_HECK: }
+// C_HECK: %[[VAL_28:.*]] = tensor.extract %[[VAL_2]]{{\[}}%[[VAL_17]]] : tensor<32xf64>
+// C_HECK: %[[VAL_29:.*]] = scf.for %[[VAL_30:.*]] = %[[VAL_16]] to %[[VAL_31:.*]] step %[[VAL_5]] iter_args(%[[VAL_32:.*]] = %[[VAL_28]]) -> (f64) {
+// C_HECK: %[[VAL_33:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_30]]] : memref<?xindex, strided<[?], offset: ?>>
+// C_HECK: %[[VAL_34:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_30]]] : memref<?xf64>
+// C_HECK: %[[VAL_35:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_33]]] : tensor<64xf64>
+// C_HECK: %[[VAL_36:.*]] = arith.mulf %[[VAL_34]], %[[VAL_35]] : f64
+// C_HECK: %[[VAL_37:.*]] = arith.addf %[[VAL_32]], %[[VAL_36]] : f64
+// C_HECK: scf.yield %[[VAL_37]] : f64
+// C_HECK: } {"Emitted from" = "linalg.generic"}
+// C_HECK: memref.store %[[VAL_38:.*]], %[[VAL_10]]{{\[}}%[[VAL_17]]] : memref<32xf64>
+// C_HECK: scf.yield %[[VAL_39:.*]] : index
+// C_HECK: } attributes {"Emitted from" = "linalg.generic"}
+// C_HECK: %[[VAL_40:.*]] = bufferization.to_tensor %[[VAL_10]] : memref<32xf64>
+// C_HECK: return %[[VAL_40]] : tensor<32xf64>
+// C_HECK: }
func.func @matvec(%arga: tensor<32x64xf64, #SortedCOO>,
%argb: tensor<64xf64>,
%argx: tensor<32xf64>) -> tensor<32xf64> {
@@ -154,112 +155,112 @@ func.func @matvec(%arga: tensor<32x64xf64, #SortedCOO>,
return %0 : tensor<32xf64>
}
-// CHECK-LABEL: func.func @mateltmul(
-// CHECK-SAME: %[[VAL_0:.*0]]: tensor<32x64xf64, #sparse{{[0-9]*}}>, %[[VAL_1:.*1]]: tensor<32x64xf64, #sparse{{[0-9]*}}>,
-// CHECK-SAME: %[[VAL_2:.*2]]: tensor<32x64xf64>) -> tensor<32x64xf64> {
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant false
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref<?xindex, strided<[?], offset: ?>>
-// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref<?xindex, strided<[?], offset: ?>>
-// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref<?xf64>
-// CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-DAG: %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 0 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref<?xindex, strided<[?], offset: ?>>
-// CHECK-DAG: %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref<?xindex, strided<[?], offset: ?>>
-// CHECK-DAG: %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref<?xf64>
-// CHECK: %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x64xf64>
-// CHECK: linalg.fill ins(%[[VAL_4]] : f64) outs(%[[VAL_15]] : memref<32x64xf64>)
-// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
-// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_6]]] : memref<?xindex>
-// CHECK: %[[VAL_20:.*]]:2 = scf.while (%[[VAL_21:.*]] = %[[VAL_16]], %[[VAL_22:.*]] = %[[VAL_18]]) : (index, index) -> (index, index) {
-// CHECK: %[[VAL_23:.*]] = arith.cmpi ult, %[[VAL_21]], %[[VAL_17]] : index
-// CHECK: %[[VAL_24:.*]] = arith.cmpi ult, %[[VAL_22]], %[[VAL_19]] : index
-// CHECK: %[[VAL_25:.*]] = arith.andi %[[VAL_23]], %[[VAL_24]] : i1
-// CHECK: scf.condition(%[[VAL_25]]) %[[VAL_21]], %[[VAL_22]] : index, index
-// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_26:.*]]: index, %[[VAL_27:.*]]: index):
-// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_26]]] : memref<?xindex, strided<[?], offset: ?>>
-// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_27]]] : memref<?xindex, strided<[?], offset: ?>>
-// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_26]]] : memref<?xindex, strided<[?], offset: ?>>
-// CHECK: %[[VAL_33:.*]] = scf.while (%[[VAL_34:.*]] = %[[VAL_26]]) : (index) -> index {
-// CHECK: %[[VAL_35:.*]] = arith.cmpi ult, %[[VAL_34]], %[[VAL_17]] : index
-// CHECK: %[[VAL_36:.*]] = scf.if %[[VAL_35]] -> (i1) {
-// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_34]]] : memref<?xindex, strided<[?], offset: ?>>
-// CHECK: %[[VAL_38:.*]] = arith.cmpi eq, %[[VAL_37]], %[[VAL_32]] : index
-// CHECK: scf.yield %[[VAL_38]] : i1
-// CHECK: } else {
-// CHECK: scf.yield %[[VAL_3]] : i1
-// CHECK: }
-// CHECK: scf.condition(%[[VAL_39:.*]]) %[[VAL_34]] : index
-// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_40:.*]]: index):
-// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_40]], %[[VAL_6]] : index
-// CHECK: scf.yield %[[VAL_41]] : index
-// CHECK: }
-// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_27]]] : memref<?xindex, strided<[?], offset: ?>>
-// CHECK: %[[VAL_43:.*]] = scf.while (%[[VAL_44:.*]] = %[[VAL_27]]) : (index) -> index {
-// CHECK: %[[VAL_45:.*]] = arith.cmpi ult, %[[VAL_44]], %[[VAL_19]] : index
-// CHECK: %[[VAL_46:.*]] = scf.if %[[VAL_45]] -> (i1) {
-// CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_44]]] : memref<?xindex, strided<[?], offset: ?>>
-// CHECK: %[[VAL_48:.*]] = arith.cmpi eq, %[[VAL_47]], %[[VAL_42]] : index
-// CHECK: scf.yield %[[VAL_48]] : i1
-// CHECK: } else {
-// CHECK: scf.yield %[[VAL_3]] : i1
-// CHECK: }
-// CHECK: scf.condition(%[[VAL_49:.*]]) %[[VAL_44]] : index
-// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_50:.*]]: index):
-// CHECK: %[[VAL_51:.*]] = arith.addi %[[VAL_50]], %[[VAL_6]] : index
-// CHECK: scf.yield %[[VAL_51]] : index
-// CHECK: }
-// CHECK: %[[VAL_30:.*]] = arith.cmpi ult, %[[VAL_29]], %[[VAL_28]] : index
-// CHECK: %[[VAL_31:.*]] = arith.select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : index
-// CHECK: %[[VAL_52:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_31]] : index
-// CHECK: %[[VAL_53:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_31]] : index
-// CHECK: %[[VAL_54:.*]] = arith.andi %[[VAL_52]], %[[VAL_53]] : i1
-// CHECK: scf.if %[[VAL_54]] {
-// CHECK: %[[VAL_55:.*]]:2 = scf.while (%[[VAL_56:.*]] = %[[VAL_26]], %[[VAL_57:.*]] = %[[VAL_27]]) : (index, index) -> (index, index) {
-// CHECK: %[[VAL_58:.*]] = arith.cmpi ult, %[[VAL_56]], %[[VAL_59:.*]] : index
-// CHECK: %[[VAL_60:.*]] = arith.cmpi ult, %[[VAL_57]], %[[VAL_61:.*]] : index
-// CHECK: %[[VAL_62:.*]] = arith.andi %[[VAL_58]], %[[VAL_60]] : i1
-// CHECK: scf.condition(%[[VAL_62]]) %[[VAL_56]], %[[VAL_57]] : index, index
-// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_63:.*]]: index, %[[VAL_64:.*]]: index):
-// CHECK: %[[VAL_65:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_63]]] : memref<?xindex, strided<[?], offset: ?>>
-// CHECK: %[[VAL_66:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_64]]] : memref<?xindex, strided<[?], offset: ?>>
-// CHECK: %[[VAL_67:.*]] = arith.cmpi ult, %[[VAL_66]], %[[VAL_65]] : index
-// CHECK: %[[VAL_68:.*]] = arith.select %[[VAL_67]], %[[VAL_66]], %[[VAL_65]] : index
-// CHECK: %[[VAL_69:.*]] = arith.cmpi eq, %[[VAL_65]], %[[VAL_68]] : index
-// CHECK: %[[VAL_70:.*]] = arith.cmpi eq, %[[VAL_66]], %[[VAL_68]] : index
-// CHECK: %[[VAL_71:.*]] = arith.andi %[[VAL_69]], %[[VAL_70]] : i1
-// CHECK: scf.if %[[VAL_71]] {
-// CHECK: %[[VAL_72:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_63]]] : memref<?xf64>
-// CHECK: %[[VAL_73:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_64]]] : memref<?xf64>
-// CHECK: %[[VAL_74:.*]] = arith.mulf %[[VAL_72]], %[[VAL_73]] : f64
-// CHECK: memref.store %[[VAL_74]], %[[VAL_15]]{{\[}}%[[VAL_31]], %[[VAL_68]]] : memref<32x64xf64>
-// CHECK: }
-// CHECK: %[[VAL_75:.*]] = arith.cmpi eq, %[[VAL_65]], %[[VAL_68]] : index
-// CHECK: %[[VAL_76:.*]] = arith.addi %[[VAL_63]], %[[VAL_6]] : index
-// CHECK: %[[VAL_77:.*]] = arith.select %[[VAL_75]], %[[VAL_76]], %[[VAL_63]] : index
-// CHECK: %[[VAL_78:.*]] = arith.cmpi eq, %[[VAL_66]], %[[VAL_68]] : index
-// CHECK: %[[VAL_79:.*]] = arith.addi %[[VAL_64]], %[[VAL_6]] : index
-// CHECK: %[[VAL_80:.*]] = arith.select %[[VAL_78]], %[[VAL_79]], %[[VAL_64]] : index
-// CHECK: scf.yield %[[VAL_77]], %[[VAL_80]] : index, index
-// CHECK: } attributes {"Emitted from" = "linalg.generic"}
-// CHECK: }
-// CHECK: %[[VAL_81:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_31]] : index
-// CHECK: %[[VAL_82:.*]] = arith.select %[[VAL_81]], %[[VAL_83:.*]], %[[VAL_26]] : index
-// CHECK: %[[VAL_84:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_31]] : index
-// CHECK: %[[VAL_85:.*]] = arith.select %[[VAL_84]], %[[VAL_86:.*]], %[[VAL_27]] : index
-// CHECK: scf.yield %[[VAL_82]], %[[VAL_85]] : index, index
-// CHECK: } attributes {"Emitted from" = "linalg.generic"}
-// CHECK: %[[VAL_87:.*]] = bufferization.to_tensor %[[VAL_15]] : memref<32x64xf64>
-// CHECK: return %[[VAL_87]] : tensor<32x64xf64>
-// CHECK: }
+// C_HECK-LABEL: func.func @mateltmul(
+// C_HECK-SAME: %[[VAL_0:.*0]]: tensor<32x64xf64, #sparse{{[0-9]*}}>, %[[VAL_1:.*1]]: tensor<32x64xf64, #sparse{{[0-9]*}}>,
+// C_HECK-SAME: %[[VAL_2:.*2]]: tensor<32x64xf64>) -> tensor<32x64xf64> {
+// C_HECK-DAG: %[[VAL_3:.*]] = arith.constant false
+// C_HECK-DAG: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f64
+// C_HECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index
+// C_HECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index
+// C_HECK-DAG: %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// C_HECK-DAG: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref<?xindex, strided<[?], offset: ?>>
+// C_HECK-DAG: %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref<?xindex, strided<[?], offset: ?>>
+// C_HECK-DAG: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref<?xf64>
+// C_HECK-DAG: %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// C_HECK-DAG: %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 0 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref<?xindex, strided<[?], offset: ?>>
+// C_HECK-DAG: %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref<?xindex, strided<[?], offset: ?>>
+// C_HECK-DAG: %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x64xf64, #sparse{{[0-9]*}}> to memref<?xf64>
+// C_HECK: %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x64xf64>
+// C_HECK: linalg.fill ins(%[[VAL_4]] : f64) outs(%[[VAL_15]] : memref<32x64xf64>)
+// C_HECK: %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// C_HECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// C_HECK: %[[VAL_18:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// C_HECK: %[[VAL_19:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// C_HECK: %[[VAL_20:.*]]:2 = scf.while (%[[VAL_21:.*]] = %[[VAL_16]], %[[VAL_22:.*]] = %[[VAL_18]]) : (index, index) -> (index, index) {
+// C_HECK: %[[VAL_23:.*]] = arith.cmpi ult, %[[VAL_21]], %[[VAL_17]] : index
+// C_HECK: %[[VAL_24:.*]] = arith.cmpi ult, %[[VAL_22]], %[[VAL_19]] : index
+// C_HECK: %[[VAL_25:.*]] = arith.andi %[[VAL_23]], %[[VAL_24]] : i1
+// C_HECK: scf.condition(%[[VAL_25]]) %[[VAL_21]], %[[VAL_22]] : index, index
+// C_HECK: } do {
+// C_HECK: ^bb0(%[[VAL_26:.*]]: index, %[[VAL_27:.*]]: index):
+// C_HECK: %[[VAL_28:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_26]]] : memref<?xindex, strided<[?], offset: ?>>
+// C_HECK: %[[VAL_29:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_27]]] : memref<?xindex, strided<[?], offset: ?>>
+// C_HECK: %[[VAL_32:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_26]]] : memref<?xindex, strided<[?], offset: ?>>
+// C_HECK: %[[VAL_33:.*]] = scf.while (%[[VAL_34:.*]] = %[[VAL_26]]) : (index) -> index {
+// C_HECK: %[[VAL_35:.*]] = arith.cmpi ult, %[[VAL_34]], %[[VAL_17]] : index
+// C_HECK: %[[VAL_36:.*]] = scf.if %[[VAL_35]] -> (i1) {
+// C_HECK: %[[VAL_37:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_34]]] : memref<?xindex, strided<[?], offset: ?>>
+// C_HECK: %[[VAL_38:.*]] = arith.cmpi eq, %[[VAL_37]], %[[VAL_32]] : index
+// C_HECK: scf.yield %[[VAL_38]] : i1
+// C_HECK: } else {
+// C_HECK: scf.yield %[[VAL_3]] : i1
+// C_HECK: }
+// C_HECK: scf.condition(%[[VAL_39:.*]]) %[[VAL_34]] : index
+// C_HECK: } do {
+// C_HECK: ^bb0(%[[VAL_40:.*]]: index):
+// C_HECK: %[[VAL_41:.*]] = arith.addi %[[VAL_40]], %[[VAL_6]] : index
+// C_HECK: scf.yield %[[VAL_41]] : index
+// C_HECK: }
+// C_HECK: %[[VAL_42:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_27]]] : memref<?xindex, strided<[?], offset: ?>>
+// C_HECK: %[[VAL_43:.*]] = scf.while (%[[VAL_44:.*]] = %[[VAL_27]]) : (index) -> index {
+// C_HECK: %[[VAL_45:.*]] = arith.cmpi ult, %[[VAL_44]], %[[VAL_19]] : index
+// C_HECK: %[[VAL_46:.*]] = scf.if %[[VAL_45]] -> (i1) {
+// C_HECK: %[[VAL_47:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_44]]] : memref<?xindex, strided<[?], offset: ?>>
+// C_HECK: %[[VAL_48:.*]] = arith.cmpi eq, %[[VAL_47]], %[[VAL_42]] : index
+// C_HECK: scf.yield %[[VAL_48]] : i1
+// C_HECK: } else {
+// C_HECK: scf.yield %[[VAL_3]] : i1
+// C_HECK: }
+// C_HECK: scf.condition(%[[VAL_49:.*]]) %[[VAL_44]] : index
+// C_HECK: } do {
+// C_HECK: ^bb0(%[[VAL_50:.*]]: index):
+// C_HECK: %[[VAL_51:.*]] = arith.addi %[[VAL_50]], %[[VAL_6]] : index
+// C_HECK: scf.yield %[[VAL_51]] : index
+// C_HECK: }
+// C_HECK: %[[VAL_30:.*]] = arith.cmpi ult, %[[VAL_29]], %[[VAL_28]] : index
+// C_HECK: %[[VAL_31:.*]] = arith.select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : index
+// C_HECK: %[[VAL_52:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_31]] : index
+// C_HECK: %[[VAL_53:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_31]] : index
+// C_HECK: %[[VAL_54:.*]] = arith.andi %[[VAL_52]], %[[VAL_53]] : i1
+// C_HECK: scf.if %[[VAL_54]] {
+// C_HECK: %[[VAL_55:.*]]:2 = scf.while (%[[VAL_56:.*]] = %[[VAL_26]], %[[VAL_57:.*]] = %[[VAL_27]]) : (index, index) -> (index, index) {
+// C_HECK: %[[VAL_58:.*]] = arith.cmpi ult, %[[VAL_56]], %[[VAL_59:.*]] : index
+// C_HECK: %[[VAL_60:.*]] = arith.cmpi ult, %[[VAL_57]], %[[VAL_61:.*]] : index
+// C_HECK: %[[VAL_62:.*]] = arith.andi %[[VAL_58]], %[[VAL_60]] : i1
+// C_HECK: scf.condition(%[[VAL_62]]) %[[VAL_56]], %[[VAL_57]] : index, index
+// C_HECK: } do {
+// C_HECK: ^bb0(%[[VAL_63:.*]]: index, %[[VAL_64:.*]]: index):
+// C_HECK: %[[VAL_65:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_63]]] : memref<?xindex, strided<[?], offset: ?>>
+// C_HECK: %[[VAL_66:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_64]]] : memref<?xindex, strided<[?], offset: ?>>
+// C_HECK: %[[VAL_67:.*]] = arith.cmpi ult, %[[VAL_66]], %[[VAL_65]] : index
+// C_HECK: %[[VAL_68:.*]] = arith.select %[[VAL_67]], %[[VAL_66]], %[[VAL_65]] : index
+// C_HECK: %[[VAL_69:.*]] = arith.cmpi eq, %[[VAL_65]], %[[VAL_68]] : index
+// C_HECK: %[[VAL_70:.*]] = arith.cmpi eq, %[[VAL_66]], %[[VAL_68]] : index
+// C_HECK: %[[VAL_71:.*]] = arith.andi %[[VAL_69]], %[[VAL_70]] : i1
+// C_HECK: scf.if %[[VAL_71]] {
+// C_HECK: %[[VAL_72:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_63]]] : memref<?xf64>
+// C_HECK: %[[VAL_73:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_64]]] : memref<?xf64>
+// C_HECK: %[[VAL_74:.*]] = arith.mulf %[[VAL_72]], %[[VAL_73]] : f64
+// C_HECK: memref.store %[[VAL_74]], %[[VAL_15]]{{\[}}%[[VAL_31]], %[[VAL_68]]] : memref<32x64xf64>
+// C_HECK: }
+// C_HECK: %[[VAL_75:.*]] = arith.cmpi eq, %[[VAL_65]], %[[VAL_68]] : index
+// C_HECK: %[[VAL_76:.*]] = arith.addi %[[VAL_63]], %[[VAL_6]] : index
+// C_HECK: %[[VAL_77:.*]] = arith.select %[[VAL_75]], %[[VAL_76]], %[[VAL_63]] : index
+// C_HECK: %[[VAL_78:.*]] = arith.cmpi eq, %[[VAL_66]], %[[VAL_68]] : index
+// C_HECK: %[[VAL_79:.*]] = arith.addi %[[VAL_64]], %[[VAL_6]] : index
+// C_HECK: %[[VAL_80:.*]] = arith.select %[[VAL_78]], %[[VAL_79]], %[[VAL_64]] : index
+// C_HECK: scf.yield %[[VAL_77]], %[[VAL_80]] : index, index
+// C_HECK: } attributes {"Emitted from" = "linalg.generic"}
+// C_HECK: }
+// C_HECK: %[[VAL_81:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_31]] : index
+// C_HECK: %[[VAL_82:.*]] = arith.select %[[VAL_81]], %[[VAL_83:.*]], %[[VAL_26]] : index
+// C_HECK: %[[VAL_84:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_31]] : index
+// C_HECK: %[[VAL_85:.*]] = arith.select %[[VAL_84]], %[[VAL_86:.*]], %[[VAL_27]] : index
+// C_HECK: scf.yield %[[VAL_82]], %[[VAL_85]] : index, index
+// C_HECK: } attributes {"Emitted from" = "linalg.generic"}
+// C_HECK: %[[VAL_87:.*]] = bufferization.to_tensor %[[VAL_15]] : memref<32x64xf64>
+// C_HECK: return %[[VAL_87]] : tensor<32x64xf64>
+// C_HECK: }
func.func @mateltmul(%argx: tensor<32x64xf64, #SortedCOO>,
%argy: tensor<32x64xf64, #SortedCOO>,
%argz: tensor<32x64xf64>) -> tensor<32x64xf64> {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
index 57ae18391daf8a..85ae0db916899e 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
@@ -29,9 +29,9 @@
// CHECK-DAG: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf32>
// CHECK: linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_10]] : memref<32x16xf32>)
// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
+// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_11]], %[[VAL_4]] : index
// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
-// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_11]], %[[VAL_4]] : index
-// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_12]] : index
+// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : index
// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_14]]] : memref<?xf32>
// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_11]], %[[VAL_12]]] : memref<32x16xf32>
// CHECK: %[[VAL_17:.*]] = arith.addf %[[VAL_15]], %[[VAL_16]] : f32
@@ -66,9 +66,9 @@ func.func @add_dd(%arga: tensor<32x16xf32, #Tdd>, %argb: tensor<32x16xf32>, %arg
// CHECK-DAG: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xi1>
// CHECK: linalg.fill ins(%[[VAL_5]] : i1) outs(%[[VAL_10]] : memref<32x16xi1>)
// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] {
+// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_11]], %[[VAL_4]] : index
// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
-// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_11]], %[[VAL_4]] : index
-// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_12]] : index
+// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : index
// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_14]]] : memref<?xf32>
// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_11]], %[[VAL_12]]] : memref<32x16xf32>
// CHECK: %[[VAL_17:.*]] = arith.cmpf ult, %[[VAL_15]], %[[VAL_16]] : f32
@@ -102,9 +102,9 @@ func.func @cmp_dd(%arga: tensor<32x16xf32, #Tdd>, %argb: tensor<32x16xf32>, %arg
// CHECK-DAG: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf32>
// CHECK: linalg.fill ins(%{{.*}} : f32) outs(%[[VAL_10]] : memref<32x16xf32>)
// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
+// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_11]], %[[VAL_4]] : index
// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
-// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_11]], %[[VAL_4]] : index
-// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_12]] : index
+// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : index
// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_14]]] : memref<?xf32>
// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_11]], %[[VAL_12]]] : memref<32x16xf32>
// CHECK: %[[VAL_17:.*]] = arith.mulf %[[VAL_15]], %[[VAL_16]] : f32
@@ -319,9 +319,9 @@ func.func @mul_ds(%arga: tensor<32x16xf32, #Tds>, %argb: tensor<32x16xf32>, %arg
// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_20]]] : memref<?xindex>
// CHECK: %[[VAL_23:.*]] = arith.cmpi eq, %[[VAL_22]], %[[VAL_21]] : index
// CHECK: scf.if %[[VAL_23]] {
+// CHECK: %[[VAL_25:.*]] = arith.muli %[[VAL_20]], %[[VAL_4]] : index
// CHECK: scf.for %[[VAL_24:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
-// CHECK: %[[VAL_25:.*]] = arith.muli %[[VAL_20]], %[[VAL_4]] : index
-// CHECK: %[[VAL_26:.*]] = arith.addi %[[VAL_25]], %[[VAL_24]] : index
+// CHECK: %[[VAL_26:.*]] = arith.addi %[[VAL_24]], %[[VAL_25]] : index
// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_26]]] : memref<?xf32>
// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]], %[[VAL_24]]] : memref<32x16xf32>
// CHECK: %[[VAL_29:.*]] = arith.addf %[[VAL_27]], %[[VAL_28]] : f32
@@ -389,9 +389,9 @@ func.func @add_sd(%arga: tensor<32x16xf32, #Tsd>, %argb: tensor<32x16xf32>, %arg
// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]]] : memref<?xindex>
// CHECK: %[[VAL_24:.*]] = arith.cmpi eq, %[[VAL_23]], %[[VAL_22]] : index
// CHECK: scf.if %[[VAL_24]] {
+// CHECK: %[[VAL_26:.*]] = arith.muli %[[VAL_21]], %[[VAL_3]] : index
// CHECK: scf.for %[[VAL_25:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] {
-// CHECK: %[[VAL_26:.*]] = arith.muli %[[VAL_21]], %[[VAL_3]] : index
-// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_25]] : index
+// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_25]], %[[VAL_26]] : index
// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_27]]] : memref<?xf32>
// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_22]], %[[VAL_25]]] : memref<32x16xf32>
// CHECK: %[[VAL_30:.*]] = arith.cmpf ult, %[[VAL_28]], %[[VAL_29]] : f32
@@ -451,9 +451,9 @@ func.func @cmp_sd(%arga: tensor<32x16xf32, #Tsd>, %argb: tensor<32x16xf32>, %arg
// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
// CHECK: scf.for %[[VAL_14:.*]] = %[[VAL_12]] to %[[VAL_13]] step %[[VAL_5]] {
// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_14]]] : memref<?xindex>
+// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_14]], %[[VAL_3]] : index
// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
-// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_14]], %[[VAL_3]] : index
-// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_16]] : index
+// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_16]], %[[VAL_17]] : index
// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xf32>
// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_15]], %[[VAL_16]]] : memref<32x16xf32>
// CHECK: %[[VAL_21:.*]] = arith.mulf %[[VAL_19]], %[[VAL_20]] : f32
@@ -1272,6 +1272,7 @@ func.func @mul_ss_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32, #T
// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_22]]] : memref<?xindex>
// CHECK: %[[VAL_25:.*]] = arith.cmpi eq, %[[VAL_24]], %[[VAL_23]] : index
// CHECK: scf.if %[[VAL_25]] {
+// CHECK: %[[VAL_36:.*]] = arith.muli %[[VAL_22]], %[[VAL_4]] : index
// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_23]]] : memref<?xindex>
// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_23]], %[[VAL_7]] : index
// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_27]]] : memref<?xindex>
@@ -1281,8 +1282,7 @@ func.func @mul_ss_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32, #T
// CHECK: } do {
// CHECK: ^bb0(%[[VAL_33:.*]]: index, %[[VAL_34:.*]]: index):
// CHECK: %[[VAL_35:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_33]]] : memref<?xindex>
-// CHECK: %[[VAL_36:.*]] = arith.muli %[[VAL_22]], %[[VAL_4]] : index
-// CHECK: %[[VAL_37:.*]] = arith.addi %[[VAL_36]], %[[VAL_34]] : index
+// CHECK: %[[VAL_37:.*]] = arith.addi %[[VAL_34]], %[[VAL_36]] : index
// CHECK: %[[VAL_38:.*]] = arith.cmpi eq, %[[VAL_35]], %[[VAL_34]] : index
// CHECK: scf.if %[[VAL_38]] {
// CHECK: %[[VAL_39:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_37]]] : memref<?xf32>
@@ -1303,8 +1303,7 @@ func.func @mul_ss_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32, #T
// CHECK: scf.yield %[[VAL_45]], %[[VAL_46]] : index, index
// CHECK: }
// CHECK: scf.for %[[VAL_47:.*]] = %[[VAL_48:.*]]#1 to %[[VAL_4]] step %[[VAL_7]] {
-// CHECK: %[[VAL_49:.*]] = arith.muli %[[VAL_22]], %[[VAL_4]] : index
-// CHECK: %[[VAL_50:.*]] = arith.addi %[[VAL_49]], %[[VAL_47]] : index
+// CHECK: %[[VAL_50:.*]] = arith.addi %[[VAL_47]], %[[VAL_36]] : index
// CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_50]]] : memref<?xf32>
// CHECK: memref.store %[[VAL_51]], %[[VAL_15]]{{\[}}%[[VAL_23]], %[[VAL_47]]] : memref<32x16xf32>
// CHECK: }
@@ -1369,13 +1368,13 @@ func.func @add_sd_ds(%arga: tensor<32x16xf32, #Tsd>, %argb: tensor<32x16xf32, #T
// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_15]] step %[[VAL_5]] {
// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK: %[[VAL_23:.*]] = arith.muli %[[VAL_16]], %[[VAL_3]] : index
// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_17]]] : memref<?xindex>
// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_17]], %[[VAL_5]] : index
// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref<?xindex>
// CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_18]] to %[[VAL_20]] step %[[VAL_5]] {
// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_21]]] : memref<?xindex>
-// CHECK: %[[VAL_23:.*]] = arith.muli %[[VAL_16]], %[[VAL_3]] : index
-// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[VAL_22]] : index
+// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_22]], %[[VAL_23]] : index
// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref<?xf32>
// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]]] : memref<?xf32>
// CHECK: %[[VAL_27:.*]] = arith.mulf %[[VAL_25]], %[[VAL_26]] : f32
diff --git a/mlir/test/Dialect/SparseTensor/sparse_3d.mlir b/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
index 4911c78bcff341..b2f528fc7a25e7 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
@@ -37,12 +37,12 @@
// CHECK-DAG: %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
// CHECK: linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_11]] : memref<32x16x8xf32>)
// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] {
+// CHECK: %[[VAL_14:.*]] = arith.muli %[[VAL_12]], %[[VAL_4]] : index
// CHECK: scf.for %[[VAL_13:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
-// CHECK: %[[VAL_14:.*]] = arith.muli %[[VAL_12]], %[[VAL_4]] : index
-// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_13]] : index
+// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : index
+// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_15]], %[[VAL_5]] : index
// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] {
-// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_15]], %[[VAL_5]] : index
-// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_16]] : index
+// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_16]], %[[VAL_17]] : index
// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xf32>
// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]], %[[VAL_13]], %[[VAL_16]]] : memref<32x16x8xf32>
// CHECK: %[[VAL_21:.*]] = arith.addf %[[VAL_19]], %[[VAL_20]] : f32
@@ -79,12 +79,12 @@ func.func @add_ddd(%arga: tensor<32x16x8xf32, #Tddd>, %argb: tensor<32x16x8xf32>
// CHECK-DAG: %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
// CHECK: linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_11]] : memref<32x16x8xf32>)
// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] {
+// CHECK: %[[VAL_14:.*]] = arith.muli %[[VAL_12]], %[[VAL_4]] : index
// CHECK: scf.for %[[VAL_13:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
-// CHECK: %[[VAL_14:.*]] = arith.muli %[[VAL_12]], %[[VAL_4]] : index
-// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_13]] : index
+// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : index
+// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_15]], %[[VAL_5]] : index
// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] {
-// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_15]], %[[VAL_5]] : index
-// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_16]] : index
+// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_16]], %[[VAL_17]] : index
// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xf32>
// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]], %[[VAL_13]], %[[VAL_16]]] : memref<32x16x8xf32>
// CHECK: %[[VAL_21:.*]] = arith.mulf %[[VAL_19]], %[[VAL_20]] : f32
@@ -124,9 +124,9 @@ func.func @mul_ddd(%arga: tensor<32x16x8xf32, #Tddd>, %argb: tensor<32x16x8xf32>
// CHECK-DAG: %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
// CHECK: linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_15]] : memref<32x16x8xf32>)
// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_9]] {
+// CHECK: %[[VAL_18:.*]] = arith.muli %[[VAL_16]], %[[VAL_5]] : index
// CHECK: scf.for %[[VAL_17:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_9]] {
-// CHECK: %[[VAL_18:.*]] = arith.muli %[[VAL_16]], %[[VAL_5]] : index
-// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_17]] : index
+// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_17]], %[[VAL_18]] : index
// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<?xindex>
// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_19]], %[[VAL_9]] : index
// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_21]]] : memref<?xindex>
@@ -191,9 +191,9 @@ func.func @add_dds(%arga: tensor<32x16x8xf32, #Tdds>, %argb: tensor<32x16x8xf32>
// CHECK-DAG: %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16x8xf32>
// CHECK: linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_13]] : memref<32x16x8xf32>)
// CHECK: scf.for %[[VAL_14:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
+// CHECK: %[[VAL_16:.*]] = arith.muli %[[VAL_14]], %[[VAL_5]] : index
// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] {
-// CHECK: %[[VAL_16:.*]] = arith.muli %[[VAL_14]], %[[VAL_5]] : index
-// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_15]] : index
+// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_16]] : index
// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref<?xindex>
// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_17]], %[[VAL_7]] : index
// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_19]]] : memref<?xindex>
@@ -249,9 +249,9 @@ func.func @mul_dds(%arga: tensor<32x16x8xf32, #Tdds>, %argb: tensor<32x16x8xf32>
// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_23]]] : memref<?xindex>
// CHECK: %[[VAL_26:.*]] = arith.cmpi eq, %[[VAL_25]], %[[VAL_24]] : index
// CHECK: scf.if %[[VAL_26]] {
+// CHECK: %[[VAL_28:.*]] = arith.muli %[[VAL_23]], %[[VAL_5]] : index
// CHECK: scf.for %[[VAL_27:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
-// CHECK: %[[VAL_28:.*]] = arith.muli %[[VAL_23]], %[[VAL_5]] : index
-// CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_28]], %[[VAL_27]] : index
+// CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_27]], %[[VAL_28]] : index
// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_29]]] : memref<?xf32>
// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_15]], %[[VAL_24]], %[[VAL_27]]] : memref<32x16x8xf32>
// CHECK: %[[VAL_32:.*]] = arith.addf %[[VAL_30]], %[[VAL_31]] : f32
@@ -314,9 +314,9 @@ func.func @add_dsd(%arga: tensor<32x16x8xf32, #Tdsd>, %argb: tensor<32x16x8xf32>
// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xindex>
// CHECK: scf.for %[[VAL_17:.*]] = %[[VAL_14]] to %[[VAL_16]] step %[[VAL_6]] {
// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref<?xindex>
+// CHECK: %[[VAL_20:.*]] = arith.muli %[[VAL_17]], %[[VAL_4]] : index
// CHECK: scf.for %[[VAL_19:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
-// CHECK: %[[VAL_20:.*]] = arith.muli %[[VAL_17]], %[[VAL_4]] : index
-// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_19]] : index
+// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_19]], %[[VAL_20]] : index
// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_21]]] : memref<?xf32>
// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_13]], %[[VAL_18]], %[[VAL_19]]] : memref<32x16x8xf32>
// CHECK: %[[VAL_24:.*]] = arith.mulf %[[VAL_22]], %[[VAL_23]] : f32
@@ -512,12 +512,12 @@ func.func @mul_dss(%arga: tensor<32x16x8xf32, #Tdss>, %argb: tensor<32x16x8xf32>
// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_21]]] : memref<?xindex>
// CHECK: %[[VAL_24:.*]] = arith.cmpi eq, %[[VAL_23]], %[[VAL_22]] : index
// CHECK: scf.if %[[VAL_24]] {
+// CHECK: %[[VAL_26:.*]] = arith.muli %[[VAL_21]], %[[VAL_4]] : index
// CHECK: scf.for %[[VAL_25:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] {
-// CHECK: %[[VAL_26:.*]] = arith.muli %[[VAL_21]], %[[VAL_4]] : index
-// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_25]] : index
+// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_25]], %[[VAL_26]] : index
+// CHECK: %[[VAL_29:.*]] = arith.muli %[[VAL_27]], %[[VAL_5]] : index
// CHECK: scf.for %[[VAL_28:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
-// CHECK: %[[VAL_29:.*]] = arith.muli %[[VAL_27]], %[[VAL_5]] : index
-// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_29]], %[[VAL_28]] : index
+// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_28]], %[[VAL_29]] : index
// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_30]]] : memref<?xf32>
// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_22]], %[[VAL_25]], %[[VAL_28]]] : memref<32x16x8xf32>
// CHECK: %[[VAL_33:.*]] = arith.addf %[[VAL_31]], %[[VAL_32]] : f32
@@ -582,12 +582,12 @@ func.func @add_sdd(%arga: tensor<32x16x8xf32, #Tsdd>, %argb: tensor<32x16x8xf32>
// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_13]] to %[[VAL_14]] step %[[VAL_6]] {
// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref<?xindex>
+// CHECK: %[[VAL_18:.*]] = arith.muli %[[VAL_15]], %[[VAL_3]] : index
// CHECK: scf.for %[[VAL_17:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
-// CHECK: %[[VAL_18:.*]] = arith.muli %[[VAL_15]], %[[VAL_3]] : index
-// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_17]] : index
+// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_17]], %[[VAL_18]] : index
+// CHECK: %[[VAL_21:.*]] = arith.muli %[[VAL_19]], %[[VAL_4]] : index
// CHECK: scf.for %[[VAL_20:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
-// CHECK: %[[VAL_21:.*]] = arith.muli %[[VAL_19]], %[[VAL_4]] : index
-// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_21]], %[[VAL_20]] : index
+// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_20]], %[[VAL_21]] : index
// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_22]]] : memref<?xf32>
// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_20]]] : memref<32x16x8xf32>
// CHECK: %[[VAL_25:.*]] = arith.mulf %[[VAL_23]], %[[VAL_24]] : f32
@@ -638,9 +638,9 @@ func.func @mul_sdd(%arga: tensor<32x16x8xf32, #Tsdd>, %argb: tensor<32x16x8xf32>
// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_24]]] : memref<?xindex>
// CHECK: %[[VAL_27:.*]] = arith.cmpi eq, %[[VAL_26]], %[[VAL_25]] : index
// CHECK: scf.if %[[VAL_27]] {
+// CHECK: %[[VAL_29:.*]] = arith.muli %[[VAL_24]], %[[VAL_5]] : index
// CHECK: scf.for %[[VAL_28:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_9]] {
-// CHECK: %[[VAL_29:.*]] = arith.muli %[[VAL_24]], %[[VAL_5]] : index
-// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_29]], %[[VAL_28]] : index
+// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_28]], %[[VAL_29]] : index
// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_30]]] : memref<?xindex>
// CHECK: %[[VAL_32:.*]] = arith.addi %[[VAL_30]], %[[VAL_9]] : index
// CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_32]]] : memref<?xindex>
@@ -733,9 +733,9 @@ func.func @add_sds(%arga: tensor<32x16x8xf32, #Tsds>, %argb: tensor<32x16x8xf32>
// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
// CHECK: scf.for %[[VAL_17:.*]] = %[[VAL_15]] to %[[VAL_16]] step %[[VAL_6]] {
// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref<?xindex>
+// CHECK: %[[VAL_20:.*]] = arith.muli %[[VAL_17]], %[[VAL_4]] : index
// CHECK: scf.for %[[VAL_19:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
-// CHECK: %[[VAL_20:.*]] = arith.muli %[[VAL_17]], %[[VAL_4]] : index
-// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_19]] : index
+// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_19]], %[[VAL_20]] : index
// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_21]]] : memref<?xindex>
// CHECK: %[[VAL_23:.*]] = arith.addi %[[VAL_21]], %[[VAL_6]] : index
// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_23]]] : memref<?xindex>
@@ -802,9 +802,9 @@ func.func @mul_sds(%arga: tensor<32x16x8xf32, #Tsds>, %argb: tensor<32x16x8xf32>
// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_34]]] : memref<?xindex>
// CHECK: %[[VAL_37:.*]] = arith.cmpi eq, %[[VAL_36]], %[[VAL_35]] : index
// CHECK: scf.if %[[VAL_37]] {
+// CHECK: %[[VAL_39:.*]] = arith.muli %[[VAL_34]], %[[VAL_5]] : index
// CHECK: scf.for %[[VAL_38:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
-// CHECK: %[[VAL_39:.*]] = arith.muli %[[VAL_34]], %[[VAL_5]] : index
-// CHECK: %[[VAL_40:.*]] = arith.addi %[[VAL_39]], %[[VAL_38]] : index
+// CHECK: %[[VAL_40:.*]] = arith.addi %[[VAL_38]], %[[VAL_39]] : index
// CHECK: %[[VAL_41:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_40]]] : memref<?xf32>
// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_24]], %[[VAL_35]], %[[VAL_38]]] : memref<32x16x8xf32>
// CHECK: %[[VAL_43:.*]] = arith.addf %[[VAL_41]], %[[VAL_42]] : f32
@@ -895,9 +895,9 @@ func.func @add_ssd(%arga: tensor<32x16x8xf32, #Tssd>, %argb: tensor<32x16x8xf32>
// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_19]]] : memref<?xindex>
// CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_18]] to %[[VAL_20]] step %[[VAL_5]] {
// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_21]]] : memref<?xindex>
+// CHECK: %[[VAL_24:.*]] = arith.muli %[[VAL_21]], %[[VAL_3]] : index
// CHECK: scf.for %[[VAL_23:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
-// CHECK: %[[VAL_24:.*]] = arith.muli %[[VAL_21]], %[[VAL_3]] : index
-// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_24]], %[[VAL_23]] : index
+// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_23]], %[[VAL_24]] : index
// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref<?xf32>
// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_17]], %[[VAL_22]], %[[VAL_23]]] : memref<32x16x8xf32>
// CHECK: %[[VAL_28:.*]] = arith.mulf %[[VAL_26]], %[[VAL_27]] : f32
@@ -1133,9 +1133,9 @@ func.func @mul_sss(%arga: tensor<32x16x8xf32, #Tsss>, %argb: tensor<32x16x8xf32>
// CHECK-DAG: %[[VAL_14:.*]] = tensor.dim %[[VAL_2]], %[[VAL_6]] : tensor<?x?xf32>
// CHECK-DAG: %[[VAL_16:.*]] = bufferization.to_memref %[[VAL_0]] : memref<?x?xf32>
// CHECK: scf.for %[[VAL_17:.*]] = %[[VAL_5]] to %[[VAL_13]] step %[[VAL_6]] {
+// CHECK: %[[VAL_19:.*]] = arith.muli %[[VAL_17]], %[[VAL_10]] : index
// CHECK: scf.for %[[VAL_18:.*]] = %[[VAL_5]] to %[[VAL_10]] step %[[VAL_6]] {
-// CHECK: %[[VAL_19:.*]] = arith.muli %[[VAL_10]], %[[VAL_17]] : index
-// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_19]], %[[VAL_18]] : index
+// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_18]], %[[VAL_19]] : index
// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_20]]] : memref<?xindex>
// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_20]], %[[VAL_6]] : index
// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_22]]] : memref<?xindex>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
index 886b21fa975679..2128ca7539fa08 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
@@ -234,9 +234,9 @@ func.func @mul_affine_dense2d(%arga: tensor<32x16xf64, #CSR>,
// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_21]]] : memref<?xindex>
// CHECK: scf.for %[[VAL_23:.*]] = %[[VAL_20]] to %[[VAL_22]] step %[[VAL_5]] {
// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_23]]] : memref<?xindex>
-// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_24]], %[[VAL_7]] : index
// CHECK: %[[VAL_26:.*]] = arith.muli %[[VAL_17]], %[[VAL_3]] : index
-// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_25]] : index
+// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_24]], %[[VAL_7]] : index
+// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_25]], %[[VAL_26]] : index
// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_18]], %[[VAL_24]]] : memref<32x16xf64>
// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_23]]] : memref<?xf64>
// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_27]]] : memref<?xf64>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
index bf61e792ffbe05..70cf0f9af45b50 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
@@ -1,3 +1,4 @@
+// TODO: re-enable after lowering coo.next to function call (such that loop structure is more clear).
// RUN: mlir-opt %s --sparse-reinterpret-map --sparsification --canonicalize --cse | FileCheck %s
#map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
@@ -8,232 +9,232 @@
// CHECK-LABEL: func.func @conv2d_all_sparse_CSR(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xi32, #sparse>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3xi32>) -> tensor<6x6xi32, #sparse> {
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant true
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant -2 : index
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 8 : index
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[VAL_10:.*]] = arith.constant false
-// CHECK-DAG: %[[VAL_11:.*]] = tensor.empty() : tensor<6x6xi32, #sparse>
-// CHECK-DAG: %[[VAL_12:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
-// CHECK-DAG: %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
-// CHECK-DAG: %[[VAL_14:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
-// CHECK-DAG: %[[VAL_15:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
-// CHECK-DAG: %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xi32, #sparse> to memref<?xi32>
-// CHECK-DAG: %[[VAL_17:.*]] = memref.alloca() : memref<9xindex>
-// CHECK-DAG: %[[VAL_18:.*]] = memref.alloca() : memref<3xindex>
-// CHECK-DAG: %[[POS_LO:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_8]]] : memref<?xindex>
-// CHECK-DAG: %[[POS_HI:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_7]]] : memref<?xindex>
-// CHECK: memref.store %[[POS_LO]], %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex>
-// CHECK: memref.store %[[POS_HI]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex>
-// CHECK: %[[VAL_20:.*]] = arith.cmpi ult, %[[POS_LO]], %[[POS_HI]] : index
-// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_13]]{{\[}}%[[POS_LO]]] : memref<?xindex>
-// CHECK: %[[VAL_22:.*]] = arith.cmpi uge, %[[VAL_21]], %[[VAL_6]] : index
-// CHECK: %[[VAL_23:.*]] = arith.andi %[[VAL_20]], %[[VAL_22]] : i1
-// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_21]], %[[VAL_3]] : index
-// CHECK: %[[VAL_25:.*]] = arith.select %[[VAL_23]], %[[VAL_24]], %[[VAL_8]] : index
-// CHECK: %[[VAL_26:.*]]:3 = scf.while (%[[VAL_27:.*]] = %[[VAL_20]], %[[VAL_28:.*]] = %[[VAL_21]], %[[VAL_29:.*]] = %[[VAL_25]], %[[VAL_30:.*]] = %[[VAL_11]]) : (i1, index, index, tensor<6x6xi32, #sparse>) -> (index, index, tensor<6x6xi32, #sparse>) {
-// CHECK: scf.condition(%[[VAL_27]]) %[[VAL_28]], %[[VAL_29]], %[[VAL_30]] : index, index, tensor<6x6xi32, #sparse>
-// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_31:.*]]: index, %[[VAL_32:.*]]: index, %[[VAL_33:.*]]: tensor<6x6xi32, #sparse>):
-// CHECK: %[[VAL_34:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex>
-// CHECK: %[[VAL_35:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex>
-// CHECK: memref.store %[[VAL_8]], %[[VAL_18]]{{\[}}%[[VAL_4]]] : memref<3xindex>
-// CHECK: %[[VAL_36:.*]] = arith.addi %[[VAL_32]], %[[VAL_6]] : index
-// CHECK: %[[VAL_37:.*]]:5 = scf.while (%[[VAL_38:.*]] = %[[VAL_34]], %[[VAL_39:.*]] = %[[VAL_10]], %[[VAL_40:.*]] = %[[VAL_5]], %[[VAL_41:.*]] = %[[VAL_8]], %[[VAL_42:.*]] = %[[VAL_8]]) : (index, i1, index, index, index) -> (index, i1, index, index, index) {
-// CHECK: %[[VAL_43:.*]] = arith.cmpi ult, %[[VAL_38]], %[[VAL_35]] : index
-// CHECK: %[[VAL_44:.*]] = scf.if %[[VAL_43]] -> (i1) {
-// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_38]]] : memref<?xindex>
-// CHECK: %[[VAL_46:.*]] = arith.cmpi ult, %[[VAL_45]], %[[VAL_36]] : index
-// CHECK: scf.yield %[[VAL_46]] : i1
-// CHECK: } else {
-// CHECK: scf.yield %[[VAL_10]] : i1
-// CHECK: }
-// CHECK: scf.condition(%[[VAL_44]]) %[[VAL_38]], %[[VAL_39]], %[[VAL_40]], %[[VAL_41]], %[[VAL_42]] : index, i1, index, index, index
-// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_47:.*]]: index, %[[VAL_48:.*]]: i1, %[[VAL_49:.*]]: index, %[[VAL_50:.*]]: index, %[[VAL_51:.*]]: index):
-// CHECK-DAG: %[[VAL_52:.*]] = arith.addi %[[VAL_47]], %[[VAL_7]] : index
-// CHECK-DAG: %[[VAL_53:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_47]]] : memref<?xindex>
-// CHECK: %[[VAL_54:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_52]]] : memref<?xindex>
-// CHECK: %[[VAL_55:.*]] = arith.cmpi ult, %[[VAL_53]], %[[VAL_54]] : index
-// CHECK: %[[VAL_56:.*]] = arith.ori %[[VAL_55]], %[[VAL_48]] : i1
-// CHECK: %[[VAL_57:.*]] = scf.if %[[VAL_55]] -> (index) {
-// CHECK: %[[VAL_58:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_53]]] : memref<?xindex>
-// CHECK: %[[VAL_59:.*]] = arith.cmpi ult, %[[VAL_58]], %[[VAL_49]] : index
-// CHECK: %[[VAL_60:.*]] = arith.select %[[VAL_59]], %[[VAL_58]], %[[VAL_49]] : index
-// CHECK: scf.yield %[[VAL_60]] : index
-// CHECK: } else {
-// CHECK: scf.yield %[[VAL_49]] : index
-// CHECK: }
-// CHECK: memref.store %[[VAL_53]], %[[VAL_17]]{{\[}}%[[VAL_50]]] : memref<9xindex>
-// CHECK: %[[VAL_61:.*]] = arith.addi %[[VAL_50]], %[[VAL_6]] : index
-// CHECK: memref.store %[[VAL_54]], %[[VAL_17]]{{\[}}%[[VAL_61]]] : memref<9xindex>
-// CHECK: %[[VAL_62:.*]] = arith.addi %[[VAL_50]], %[[VAL_7]] : index
-// CHECK: %[[VAL_63:.*]] = arith.addi %[[VAL_51]], %[[VAL_7]] : index
-// CHECK: scf.yield %[[VAL_52]], %[[VAL_56]], %[[VAL_57]], %[[VAL_62]], %[[VAL_63]] : index, i1, index, index, index
-// CHECK: }
-// CHECK: %[[VAL_64:.*]] = arith.cmpi uge, %[[VAL_65:.*]]#2, %[[VAL_6]] : index
-// CHECK: %[[VAL_66:.*]] = arith.andi %[[VAL_65]]#1, %[[VAL_64]] : i1
-// CHECK: %[[VAL_67:.*]] = arith.addi %[[VAL_65]]#2, %[[VAL_3]] : index
-// CHECK: %[[VAL_68:.*]] = arith.select %[[VAL_66]], %[[VAL_67]], %[[VAL_8]] : index
-// CHECK: %[[VAL_69:.*]]:3 = scf.while (%[[VAL_70:.*]] = %[[VAL_65]]#1, %[[VAL_71:.*]] = %[[VAL_65]]#2, %[[VAL_72:.*]] = %[[VAL_68]], %[[VAL_73:.*]] = %[[VAL_33]]) : (i1, index, index, tensor<6x6xi32, #sparse>) -> (index, index, tensor<6x6xi32, #sparse>) {
-// CHECK: scf.condition(%[[VAL_70]]) %[[VAL_71]], %[[VAL_72]], %[[VAL_73]] : index, index, tensor<6x6xi32, #sparse>
-// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_74:.*]]: index, %[[VAL_75:.*]]: index, %[[VAL_76:.*]]: tensor<6x6xi32, #sparse>):
-// CHECK: %[[VAL_77:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex>
-// CHECK: %[[VAL_78:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex>
-// CHECK: %[[VAL_79:.*]]:3 = scf.while (%[[VAL_80:.*]] = %[[VAL_77]], %[[VAL_81:.*]] = %[[VAL_9]], %[[VAL_82:.*]] = %[[VAL_10]]) : (index, i32, i1) -> (index, i32, i1) {
-// CHECK: %[[VAL_83:.*]] = arith.cmpi ult, %[[VAL_80]], %[[VAL_78]] : index
-// CHECK: %[[VAL_84:.*]] = scf.if %[[VAL_83]] -> (i1) {
-// CHECK: %[[VAL_85:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_80]]] : memref<?xindex>
-// CHECK: %[[VAL_86:.*]] = arith.cmpi ult, %[[VAL_85]], %[[VAL_36]] : index
-// CHECK: scf.yield %[[VAL_86]] : i1
-// CHECK: } else {
-// CHECK: scf.yield %[[VAL_10]] : i1
-// CHECK: }
-// CHECK: scf.condition(%[[VAL_84]]) %[[VAL_80]], %[[VAL_81]], %[[VAL_82]] : index, i32, i1
-// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_87:.*]]: index, %[[VAL_88:.*]]: i32, %[[VAL_89:.*]]: i1):
-// CHECK: %[[VAL_90:.*]] = arith.subi %[[VAL_87]], %[[VAL_77]] : index
-// CHECK: %[[VAL_91:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_87]]] : memref<?xindex>
-// CHECK: %[[VAL_92:.*]] = arith.subi %[[VAL_91]], %[[VAL_32]] : index
-// CHECK: %[[VAL_93:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_90]]] : memref<9xindex>
-// CHECK: %[[VAL_94:.*]] = arith.addi %[[VAL_90]], %[[VAL_6]] : index
-// CHECK: %[[VAL_95:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_94]]] : memref<9xindex>
-// CHECK: %[[VAL_96:.*]] = arith.addi %[[VAL_75]], %[[VAL_6]] : index
-// CHECK: %[[VAL_97:.*]]:2 = scf.while (%[[VAL_98:.*]] = %[[VAL_93]], %[[VAL_99:.*]] = %[[VAL_88]]) : (index, i32) -> (index, i32) {
-// CHECK: %[[VAL_100:.*]] = arith.cmpi ult, %[[VAL_98]], %[[VAL_95]] : index
-// CHECK: %[[VAL_101:.*]] = scf.if %[[VAL_100]] -> (i1) {
-// CHECK: %[[VAL_102:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_98]]] : memref<?xindex>
-// CHECK: %[[VAL_103:.*]] = arith.cmpi ult, %[[VAL_102]], %[[VAL_96]] : index
-// CHECK: scf.yield %[[VAL_103]] : i1
-// CHECK: } else {
-// CHECK: scf.yield %[[VAL_10]] : i1
-// CHECK: }
-// CHECK: scf.condition(%[[VAL_101]]) %[[VAL_98]], %[[VAL_99]] : index, i32
-// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_104:.*]]: index, %[[VAL_105:.*]]: i32):
-// CHECK: %[[VAL_106:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_104]]] : memref<?xindex>
-// CHECK: %[[VAL_107:.*]] = arith.subi %[[VAL_106]], %[[VAL_75]] : index
-// CHECK: %[[VAL_108:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_104]]] : memref<?xi32>
-// CHECK: %[[VAL_109:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_92]], %[[VAL_107]]] : tensor<3x3xi32>
-// CHECK: %[[VAL_110:.*]] = arith.muli %[[VAL_108]], %[[VAL_109]] : i32
-// CHECK: %[[VAL_111:.*]] = arith.addi %[[VAL_105]], %[[VAL_110]] : i32
-// CHECK: %[[VAL_112:.*]] = arith.addi %[[VAL_104]], %[[VAL_7]] : index
-// CHECK: scf.yield %[[VAL_112]], %[[VAL_111]] : index, i32
-// CHECK: }
-// CHECK: %[[VAL_113:.*]] = arith.addi %[[VAL_87]], %[[VAL_7]] : index
-// CHECK: scf.yield %[[VAL_113]], %[[VAL_114:.*]]#1, %[[VAL_2]] : index, i32, i1
-// CHECK: }
-// CHECK: %[[VAL_115:.*]] = scf.if %[[VAL_116:.*]]#2 -> (tensor<6x6xi32, #sparse>) {
-// CHECK: %[[VAL_117:.*]] = sparse_tensor.insert %[[VAL_116]]#1 into %[[VAL_76]]{{\[}}%[[VAL_32]], %[[VAL_75]]] : tensor<6x6xi32, #sparse>
-// CHECK: scf.yield %[[VAL_117]] : tensor<6x6xi32, #sparse>
-// CHECK: } else {
-// CHECK: scf.yield %[[VAL_76]] : tensor<6x6xi32, #sparse>
-// CHECK: }
-// CHECK: %[[VAL_118:.*]] = arith.cmpi ugt, %[[VAL_74]], %[[VAL_75]] : index
-// CHECK: %[[VAL_119:.*]]:3 = scf.if %[[VAL_118]] -> (index, i1, index) {
-// CHECK: %[[VAL_120:.*]] = arith.addi %[[VAL_75]], %[[VAL_7]] : index
-// CHECK: scf.yield %[[VAL_74]], %[[VAL_2]], %[[VAL_120]] : index, i1, index
-// CHECK: } else {
-// CHECK: %[[VAL_121:.*]]:2 = scf.for %[[VAL_122:.*]] = %[[VAL_8]] to %[[VAL_65]]#3 step %[[VAL_7]] iter_args(%[[VAL_123:.*]] = %[[VAL_5]], %[[VAL_124:.*]] = %[[VAL_10]]) -> (index, i1) {
-// CHECK: %[[VAL_125:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_122]]] : memref<9xindex>
-// CHECK: %[[VAL_126:.*]] = arith.addi %[[VAL_122]], %[[VAL_6]] : index
-// CHECK: %[[VAL_127:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_126]]] : memref<9xindex>
-// CHECK: %[[VAL_128:.*]] = arith.cmpi ult, %[[VAL_125]], %[[VAL_127]] : index
-// CHECK: %[[VAL_129:.*]] = scf.if %[[VAL_128]] -> (index) {
-// CHECK: %[[VAL_130:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_125]]] : memref<?xindex>
-// CHECK: %[[VAL_131:.*]] = arith.cmpi eq, %[[VAL_130]], %[[VAL_74]] : index
-// CHECK: %[[VAL_132:.*]] = scf.if %[[VAL_131]] -> (index) {
-// CHECK: %[[VAL_133:.*]] = arith.addi %[[VAL_125]], %[[VAL_7]] : index
-// CHECK: memref.store %[[VAL_133]], %[[VAL_17]]{{\[}}%[[VAL_122]]] : memref<9xindex>
-// CHECK: scf.yield %[[VAL_133]] : index
-// CHECK: } else {
-// CHECK: scf.yield %[[VAL_125]] : index
-// CHECK: }
-// CHECK: scf.yield %[[VAL_132]] : index
-// CHECK: } else {
-// CHECK: scf.yield %[[VAL_125]] : index
-// CHECK: }
-// CHECK: %[[VAL_134:.*]] = arith.cmpi ult, %[[VAL_129]], %[[VAL_127]] : index
-// CHECK: %[[VAL_135:.*]] = scf.if %[[VAL_134]] -> (index) {
-// CHECK: %[[VAL_136:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_129]]] : memref<?xindex>
-// CHECK: scf.yield %[[VAL_136]] : index
-// CHECK: } else {
-// CHECK: scf.yield %[[VAL_123]] : index
-// CHECK: }
-// CHECK: %[[VAL_137:.*]] = arith.ori %[[VAL_134]], %[[VAL_124]] : i1
-// CHECK: %[[VAL_138:.*]] = arith.cmpi ult, %[[VAL_135]], %[[VAL_123]] : index
-// CHECK: %[[VAL_139:.*]] = arith.select %[[VAL_138]], %[[VAL_135]], %[[VAL_123]] : index
-// CHECK: scf.yield %[[VAL_139]], %[[VAL_137]] : index, i1
-// CHECK: }
-// CHECK: %[[VAL_140:.*]] = arith.addi %[[VAL_141:.*]]#0, %[[VAL_7]] : index
-// CHECK: %[[VAL_142:.*]] = arith.addi %[[VAL_141]]#0, %[[VAL_3]] : index
-// CHECK: %[[VAL_143:.*]] = arith.cmpi uge, %[[VAL_140]], %[[VAL_6]] : index
-// CHECK: %[[VAL_144:.*]] = arith.select %[[VAL_143]], %[[VAL_142]], %[[VAL_8]] : index
-// CHECK: scf.yield %[[VAL_141]]#0, %[[VAL_141]]#1, %[[VAL_144]] : index, i1, index
-// CHECK: }
-// CHECK: %[[VAL_145:.*]] = arith.addi %[[VAL_75]], %[[VAL_7]] : index
-// CHECK: %[[VAL_146:.*]] = arith.cmpi ugt, %[[VAL_147:.*]]#2, %[[VAL_145]] : index
-// CHECK: %[[VAL_148:.*]] = arith.select %[[VAL_146]], %[[VAL_147]]#2, %[[VAL_145]] : index
-// CHECK: %[[VAL_149:.*]] = arith.addi %[[VAL_148]], %[[VAL_6]] : index
-// CHECK: %[[VAL_150:.*]] = arith.cmpi ule, %[[VAL_149]], %[[VAL_5]] : index
-// CHECK: %[[VAL_151:.*]] = arith.andi %[[VAL_147]]#1, %[[VAL_150]] : i1
-// CHECK: scf.yield %[[VAL_151]], %[[VAL_147]]#0, %[[VAL_148]], %[[VAL_115]] : i1, index, index, tensor<6x6xi32, #sparse>
-// CHECK: }
-// CHECK: %[[VAL_152:.*]] = arith.cmpi ugt, %[[VAL_31]], %[[VAL_32]] : index
-// CHECK: %[[VAL_153:.*]]:3 = scf.if %[[VAL_152]] -> (index, i1, index) {
-// CHECK: %[[VAL_154:.*]] = arith.addi %[[VAL_32]], %[[VAL_7]] : index
-// CHECK: scf.yield %[[VAL_31]], %[[VAL_2]], %[[VAL_154]] : index, i1, index
-// CHECK: } else {
-// CHECK: %[[VAL_155:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex>
-// CHECK: %[[VAL_156:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex>
-// CHECK: %[[VAL_157:.*]] = arith.cmpi ult, %[[VAL_155]], %[[VAL_156]] : index
-// CHECK: %[[VAL_158:.*]] = scf.if %[[VAL_157]] -> (index) {
-// CHECK: %[[VAL_159:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_155]]] : memref<?xindex>
-// CHECK: %[[VAL_160:.*]] = arith.cmpi eq, %[[VAL_159]], %[[VAL_31]] : index
-// CHECK: %[[VAL_161:.*]] = scf.if %[[VAL_160]] -> (index) {
-// CHECK: %[[VAL_162:.*]] = arith.addi %[[VAL_155]], %[[VAL_7]] : index
-// CHECK: memref.store %[[VAL_162]], %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex>
-// CHECK: scf.yield %[[VAL_162]] : index
-// CHECK: } else {
-// CHECK: scf.yield %[[VAL_155]] : index
-// CHECK: }
-// CHECK: scf.yield %[[VAL_161]] : index
-// CHECK: } else {
-// CHECK: scf.yield %[[VAL_155]] : index
-// CHECK: }
-// CHECK: %[[VAL_163:.*]] = arith.cmpi ult, %[[VAL_158]], %[[VAL_156]] : index
-// CHECK: %[[VAL_164:.*]] = scf.if %[[VAL_163]] -> (index) {
-// CHECK: %[[VAL_165:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_158]]] : memref<?xindex>
-// CHECK: scf.yield %[[VAL_165]] : index
-// CHECK: } else {
-// CHECK: scf.yield %[[VAL_5]] : index
-// CHECK: }
-// CHECK: %[[VAL_166:.*]] = arith.cmpi ult, %[[VAL_164]], %[[VAL_5]] : index
-// CHECK: %[[VAL_167:.*]] = arith.select %[[VAL_166]], %[[VAL_164]], %[[VAL_5]] : index
-// CHECK: %[[VAL_168:.*]] = arith.addi %[[VAL_167]], %[[VAL_7]] : index
-// CHECK: %[[VAL_169:.*]] = arith.addi %[[VAL_167]], %[[VAL_3]] : index
-// CHECK: %[[VAL_170:.*]] = arith.cmpi uge, %[[VAL_168]], %[[VAL_6]] : index
-// CHECK: %[[VAL_171:.*]] = arith.select %[[VAL_170]], %[[VAL_169]], %[[VAL_8]] : index
-// CHECK: scf.yield %[[VAL_167]], %[[VAL_163]], %[[VAL_171]] : index, i1, index
-// CHECK: }
-// CHECK: %[[VAL_172:.*]] = arith.addi %[[VAL_32]], %[[VAL_7]] : index
-// CHECK: %[[VAL_173:.*]] = arith.cmpi ugt, %[[VAL_174:.*]]#2, %[[VAL_172]] : index
-// CHECK: %[[VAL_175:.*]] = arith.select %[[VAL_173]], %[[VAL_174]]#2, %[[VAL_172]] : index
-// CHECK: %[[VAL_176:.*]] = arith.addi %[[VAL_175]], %[[VAL_6]] : index
-// CHECK: %[[VAL_177:.*]] = arith.cmpi ule, %[[VAL_176]], %[[VAL_5]] : index
-// CHECK: %[[VAL_178:.*]] = arith.andi %[[VAL_174]]#1, %[[VAL_177]] : i1
-// CHECK: scf.yield %[[VAL_178]], %[[VAL_174]]#0, %[[VAL_175]], %[[VAL_179:.*]]#2 : i1, index, index, tensor<6x6xi32, #sparse>
-// CHECK: }
-// CHECK: %[[VAL_180:.*]] = sparse_tensor.load %[[VAL_181:.*]]#2 hasInserts : tensor<6x6xi32, #sparse>
-// CHECK: return %[[VAL_180]] : tensor<6x6xi32, #sparse>
-// CHECK: }
+// C_HECK-SAME: %[[VAL_0:.*]]: tensor<8x8xi32, #sparse>,
+// C_HECK-SAME: %[[VAL_1:.*]]: tensor<3x3xi32>) -> tensor<6x6xi32, #sparse> {
+// C_HECK-DAG: %[[VAL_2:.*]] = arith.constant true
+// C_HECK-DAG: %[[VAL_3:.*]] = arith.constant -2 : index
+// C_HECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index
+// C_HECK-DAG: %[[VAL_5:.*]] = arith.constant 8 : index
+// C_HECK-DAG: %[[VAL_6:.*]] = arith.constant 3 : index
+// C_HECK-DAG: %[[VAL_7:.*]] = arith.constant 1 : index
+// C_HECK-DAG: %[[VAL_8:.*]] = arith.constant 0 : index
+// C_HECK-DAG: %[[VAL_9:.*]] = arith.constant 0 : i32
+// C_HECK-DAG: %[[VAL_10:.*]] = arith.constant false
+// C_HECK-DAG: %[[VAL_11:.*]] = tensor.empty() : tensor<6x6xi32, #sparse>
+// C_HECK-DAG: %[[VAL_12:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
+// C_HECK-DAG: %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
+// C_HECK-DAG: %[[VAL_14:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
+// C_HECK-DAG: %[[VAL_15:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
+// C_HECK-DAG: %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xi32, #sparse> to memref<?xi32>
+// C_HECK-DAG: %[[VAL_17:.*]] = memref.alloca() : memref<9xindex>
+// C_HECK-DAG: %[[VAL_18:.*]] = memref.alloca() : memref<3xindex>
+// C_HECK-DAG: %[[POS_LO:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_8]]] : memref<?xindex>
+// C_HECK-DAG: %[[POS_HI:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_7]]] : memref<?xindex>
+// C_HECK: memref.store %[[POS_LO]], %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex>
+// C_HECK: memref.store %[[POS_HI]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex>
+// C_HECK: %[[VAL_20:.*]] = arith.cmpi ult, %[[POS_LO]], %[[POS_HI]] : index
+// C_HECK: %[[VAL_21:.*]] = memref.load %[[VAL_13]]{{\[}}%[[POS_LO]]] : memref<?xindex>
+// C_HECK: %[[VAL_22:.*]] = arith.cmpi uge, %[[VAL_21]], %[[VAL_6]] : index
+// C_HECK: %[[VAL_23:.*]] = arith.andi %[[VAL_20]], %[[VAL_22]] : i1
+// C_HECK: %[[VAL_24:.*]] = arith.addi %[[VAL_21]], %[[VAL_3]] : index
+// C_HECK: %[[VAL_25:.*]] = arith.select %[[VAL_23]], %[[VAL_24]], %[[VAL_8]] : index
+// C_HECK: %[[VAL_26:.*]]:3 = scf.while (%[[VAL_27:.*]] = %[[VAL_20]], %[[VAL_28:.*]] = %[[VAL_21]], %[[VAL_29:.*]] = %[[VAL_25]], %[[VAL_30:.*]] = %[[VAL_11]]) : (i1, index, index, tensor<6x6xi32, #sparse>) -> (index, index, tensor<6x6xi32, #sparse>) {
+// C_HECK: scf.condition(%[[VAL_27]]) %[[VAL_28]], %[[VAL_29]], %[[VAL_30]] : index, index, tensor<6x6xi32, #sparse>
+// C_HECK: } do {
+// C_HECK: ^bb0(%[[VAL_31:.*]]: index, %[[VAL_32:.*]]: index, %[[VAL_33:.*]]: tensor<6x6xi32, #sparse>):
+// C_HECK: %[[VAL_34:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex>
+// C_HECK: %[[VAL_35:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex>
+// C_HECK: memref.store %[[VAL_8]], %[[VAL_18]]{{\[}}%[[VAL_4]]] : memref<3xindex>
+// C_HECK: %[[VAL_36:.*]] = arith.addi %[[VAL_32]], %[[VAL_6]] : index
+// C_HECK: %[[VAL_37:.*]]:5 = scf.while (%[[VAL_38:.*]] = %[[VAL_34]], %[[VAL_39:.*]] = %[[VAL_10]], %[[VAL_40:.*]] = %[[VAL_5]], %[[VAL_41:.*]] = %[[VAL_8]], %[[VAL_42:.*]] = %[[VAL_8]]) : (index, i1, index, index, index) -> (index, i1, index, index, index) {
+// C_HECK: %[[VAL_43:.*]] = arith.cmpi ult, %[[VAL_38]], %[[VAL_35]] : index
+// C_HECK: %[[VAL_44:.*]] = scf.if %[[VAL_43]] -> (i1) {
+// C_HECK: %[[VAL_45:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_38]]] : memref<?xindex>
+// C_HECK: %[[VAL_46:.*]] = arith.cmpi ult, %[[VAL_45]], %[[VAL_36]] : index
+// C_HECK: scf.yield %[[VAL_46]] : i1
+// C_HECK: } else {
+// C_HECK: scf.yield %[[VAL_10]] : i1
+// C_HECK: }
+// C_HECK: scf.condition(%[[VAL_44]]) %[[VAL_38]], %[[VAL_39]], %[[VAL_40]], %[[VAL_41]], %[[VAL_42]] : index, i1, index, index, index
+// C_HECK: } do {
+// C_HECK: ^bb0(%[[VAL_47:.*]]: index, %[[VAL_48:.*]]: i1, %[[VAL_49:.*]]: index, %[[VAL_50:.*]]: index, %[[VAL_51:.*]]: index):
+// C_HECK-DAG: %[[VAL_52:.*]] = arith.addi %[[VAL_47]], %[[VAL_7]] : index
+// C_HECK-DAG: %[[VAL_53:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_47]]] : memref<?xindex>
+// C_HECK: %[[VAL_54:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_52]]] : memref<?xindex>
+// C_HECK: %[[VAL_55:.*]] = arith.cmpi ult, %[[VAL_53]], %[[VAL_54]] : index
+// C_HECK: %[[VAL_56:.*]] = arith.ori %[[VAL_55]], %[[VAL_48]] : i1
+// C_HECK: %[[VAL_57:.*]] = scf.if %[[VAL_55]] -> (index) {
+// C_HECK: %[[VAL_58:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_53]]] : memref<?xindex>
+// C_HECK: %[[VAL_59:.*]] = arith.cmpi ult, %[[VAL_58]], %[[VAL_49]] : index
+// C_HECK: %[[VAL_60:.*]] = arith.select %[[VAL_59]], %[[VAL_58]], %[[VAL_49]] : index
+// C_HECK: scf.yield %[[VAL_60]] : index
+// C_HECK: } else {
+// C_HECK: scf.yield %[[VAL_49]] : index
+// C_HECK: }
+// C_HECK: memref.store %[[VAL_53]], %[[VAL_17]]{{\[}}%[[VAL_50]]] : memref<9xindex>
+// C_HECK: %[[VAL_61:.*]] = arith.addi %[[VAL_50]], %[[VAL_6]] : index
+// C_HECK: memref.store %[[VAL_54]], %[[VAL_17]]{{\[}}%[[VAL_61]]] : memref<9xindex>
+// C_HECK: %[[VAL_62:.*]] = arith.addi %[[VAL_50]], %[[VAL_7]] : index
+// C_HECK: %[[VAL_63:.*]] = arith.addi %[[VAL_51]], %[[VAL_7]] : index
+// C_HECK: scf.yield %[[VAL_52]], %[[VAL_56]], %[[VAL_57]], %[[VAL_62]], %[[VAL_63]] : index, i1, index, index, index
+// C_HECK: }
+// C_HECK: %[[VAL_64:.*]] = arith.cmpi uge, %[[VAL_65:.*]]#2, %[[VAL_6]] : index
+// C_HECK: %[[VAL_66:.*]] = arith.andi %[[VAL_65]]#1, %[[VAL_64]] : i1
+// C_HECK: %[[VAL_67:.*]] = arith.addi %[[VAL_65]]#2, %[[VAL_3]] : index
+// C_HECK: %[[VAL_68:.*]] = arith.select %[[VAL_66]], %[[VAL_67]], %[[VAL_8]] : index
+// C_HECK: %[[VAL_69:.*]]:3 = scf.while (%[[VAL_70:.*]] = %[[VAL_65]]#1, %[[VAL_71:.*]] = %[[VAL_65]]#2, %[[VAL_72:.*]] = %[[VAL_68]], %[[VAL_73:.*]] = %[[VAL_33]]) : (i1, index, index, tensor<6x6xi32, #sparse>) -> (index, index, tensor<6x6xi32, #sparse>) {
+// C_HECK: scf.condition(%[[VAL_70]]) %[[VAL_71]], %[[VAL_72]], %[[VAL_73]] : index, index, tensor<6x6xi32, #sparse>
+// C_HECK: } do {
+// C_HECK: ^bb0(%[[VAL_74:.*]]: index, %[[VAL_75:.*]]: index, %[[VAL_76:.*]]: tensor<6x6xi32, #sparse>):
+// C_HECK: %[[VAL_77:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex>
+// C_HECK: %[[VAL_78:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex>
+// C_HECK: %[[VAL_79:.*]]:3 = scf.while (%[[VAL_80:.*]] = %[[VAL_77]], %[[VAL_81:.*]] = %[[VAL_9]], %[[VAL_82:.*]] = %[[VAL_10]]) : (index, i32, i1) -> (index, i32, i1) {
+// C_HECK: %[[VAL_83:.*]] = arith.cmpi ult, %[[VAL_80]], %[[VAL_78]] : index
+// C_HECK: %[[VAL_84:.*]] = scf.if %[[VAL_83]] -> (i1) {
+// C_HECK: %[[VAL_85:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_80]]] : memref<?xindex>
+// C_HECK: %[[VAL_86:.*]] = arith.cmpi ult, %[[VAL_85]], %[[VAL_36]] : index
+// C_HECK: scf.yield %[[VAL_86]] : i1
+// C_HECK: } else {
+// C_HECK: scf.yield %[[VAL_10]] : i1
+// C_HECK: }
+// C_HECK: scf.condition(%[[VAL_84]]) %[[VAL_80]], %[[VAL_81]], %[[VAL_82]] : index, i32, i1
+// C_HECK: } do {
+// C_HECK: ^bb0(%[[VAL_87:.*]]: index, %[[VAL_88:.*]]: i32, %[[VAL_89:.*]]: i1):
+// C_HECK: %[[VAL_90:.*]] = arith.subi %[[VAL_87]], %[[VAL_77]] : index
+// C_HECK: %[[VAL_91:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_87]]] : memref<?xindex>
+// C_HECK: %[[VAL_92:.*]] = arith.subi %[[VAL_91]], %[[VAL_32]] : index
+// C_HECK: %[[VAL_93:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_90]]] : memref<9xindex>
+// C_HECK: %[[VAL_94:.*]] = arith.addi %[[VAL_90]], %[[VAL_6]] : index
+// C_HECK: %[[VAL_95:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_94]]] : memref<9xindex>
+// C_HECK: %[[VAL_96:.*]] = arith.addi %[[VAL_75]], %[[VAL_6]] : index
+// C_HECK: %[[VAL_97:.*]]:2 = scf.while (%[[VAL_98:.*]] = %[[VAL_93]], %[[VAL_99:.*]] = %[[VAL_88]]) : (index, i32) -> (index, i32) {
+// C_HECK: %[[VAL_100:.*]] = arith.cmpi ult, %[[VAL_98]], %[[VAL_95]] : index
+// C_HECK: %[[VAL_101:.*]] = scf.if %[[VAL_100]] -> (i1) {
+// C_HECK: %[[VAL_102:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_98]]] : memref<?xindex>
+// C_HECK: %[[VAL_103:.*]] = arith.cmpi ult, %[[VAL_102]], %[[VAL_96]] : index
+// C_HECK: scf.yield %[[VAL_103]] : i1
+// C_HECK: } else {
+// C_HECK: scf.yield %[[VAL_10]] : i1
+// C_HECK: }
+// C_HECK: scf.condition(%[[VAL_101]]) %[[VAL_98]], %[[VAL_99]] : index, i32
+// C_HECK: } do {
+// C_HECK: ^bb0(%[[VAL_104:.*]]: index, %[[VAL_105:.*]]: i32):
+// C_HECK: %[[VAL_106:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_104]]] : memref<?xindex>
+// C_HECK: %[[VAL_107:.*]] = arith.subi %[[VAL_106]], %[[VAL_75]] : index
+// C_HECK: %[[VAL_108:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_104]]] : memref<?xi32>
+// C_HECK: %[[VAL_109:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_92]], %[[VAL_107]]] : tensor<3x3xi32>
+// C_HECK: %[[VAL_110:.*]] = arith.muli %[[VAL_108]], %[[VAL_109]] : i32
+// C_HECK: %[[VAL_111:.*]] = arith.addi %[[VAL_105]], %[[VAL_110]] : i32
+// C_HECK: %[[VAL_112:.*]] = arith.addi %[[VAL_104]], %[[VAL_7]] : index
+// C_HECK: scf.yield %[[VAL_112]], %[[VAL_111]] : index, i32
+// C_HECK: }
+// C_HECK: %[[VAL_113:.*]] = arith.addi %[[VAL_87]], %[[VAL_7]] : index
+// C_HECK: scf.yield %[[VAL_113]], %[[VAL_114:.*]]#1, %[[VAL_2]] : index, i32, i1
+// C_HECK: }
+// C_HECK: %[[VAL_115:.*]] = scf.if %[[VAL_116:.*]]#2 -> (tensor<6x6xi32, #sparse>) {
+// C_HECK: %[[VAL_117:.*]] = sparse_tensor.insert %[[VAL_116]]#1 into %[[VAL_76]]{{\[}}%[[VAL_32]], %[[VAL_75]]] : tensor<6x6xi32, #sparse>
+// C_HECK: scf.yield %[[VAL_117]] : tensor<6x6xi32, #sparse>
+// C_HECK: } else {
+// C_HECK: scf.yield %[[VAL_76]] : tensor<6x6xi32, #sparse>
+// C_HECK: }
+// C_HECK: %[[VAL_118:.*]] = arith.cmpi ugt, %[[VAL_74]], %[[VAL_75]] : index
+// C_HECK: %[[VAL_119:.*]]:3 = scf.if %[[VAL_118]] -> (index, i1, index) {
+// C_HECK: %[[VAL_120:.*]] = arith.addi %[[VAL_75]], %[[VAL_7]] : index
+// C_HECK: scf.yield %[[VAL_74]], %[[VAL_2]], %[[VAL_120]] : index, i1, index
+// C_HECK: } else {
+// C_HECK: %[[VAL_121:.*]]:2 = scf.for %[[VAL_122:.*]] = %[[VAL_8]] to %[[VAL_65]]#3 step %[[VAL_7]] iter_args(%[[VAL_123:.*]] = %[[VAL_5]], %[[VAL_124:.*]] = %[[VAL_10]]) -> (index, i1) {
+// C_HECK: %[[VAL_125:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_122]]] : memref<9xindex>
+// C_HECK: %[[VAL_126:.*]] = arith.addi %[[VAL_122]], %[[VAL_6]] : index
+// C_HECK: %[[VAL_127:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_126]]] : memref<9xindex>
+// C_HECK: %[[VAL_128:.*]] = arith.cmpi ult, %[[VAL_125]], %[[VAL_127]] : index
+// C_HECK: %[[VAL_129:.*]] = scf.if %[[VAL_128]] -> (index) {
+// C_HECK: %[[VAL_130:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_125]]] : memref<?xindex>
+// C_HECK: %[[VAL_131:.*]] = arith.cmpi eq, %[[VAL_130]], %[[VAL_74]] : index
+// C_HECK: %[[VAL_132:.*]] = scf.if %[[VAL_131]] -> (index) {
+// C_HECK: %[[VAL_133:.*]] = arith.addi %[[VAL_125]], %[[VAL_7]] : index
+// C_HECK: memref.store %[[VAL_133]], %[[VAL_17]]{{\[}}%[[VAL_122]]] : memref<9xindex>
+// C_HECK: scf.yield %[[VAL_133]] : index
+// C_HECK: } else {
+// C_HECK: scf.yield %[[VAL_125]] : index
+// C_HECK: }
+// C_HECK: scf.yield %[[VAL_132]] : index
+// C_HECK: } else {
+// C_HECK: scf.yield %[[VAL_125]] : index
+// C_HECK: }
+// C_HECK: %[[VAL_134:.*]] = arith.cmpi ult, %[[VAL_129]], %[[VAL_127]] : index
+// C_HECK: %[[VAL_135:.*]] = scf.if %[[VAL_134]] -> (index) {
+// C_HECK: %[[VAL_136:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_129]]] : memref<?xindex>
+// C_HECK: scf.yield %[[VAL_136]] : index
+// C_HECK: } else {
+// C_HECK: scf.yield %[[VAL_123]] : index
+// C_HECK: }
+// C_HECK: %[[VAL_137:.*]] = arith.ori %[[VAL_134]], %[[VAL_124]] : i1
+// C_HECK: %[[VAL_138:.*]] = arith.cmpi ult, %[[VAL_135]], %[[VAL_123]] : index
+// C_HECK: %[[VAL_139:.*]] = arith.select %[[VAL_138]], %[[VAL_135]], %[[VAL_123]] : index
+// C_HECK: scf.yield %[[VAL_139]], %[[VAL_137]] : index, i1
+// C_HECK: }
+// C_HECK: %[[VAL_140:.*]] = arith.addi %[[VAL_141:.*]]#0, %[[VAL_7]] : index
+// C_HECK: %[[VAL_142:.*]] = arith.addi %[[VAL_141]]#0, %[[VAL_3]] : index
+// C_HECK: %[[VAL_143:.*]] = arith.cmpi uge, %[[VAL_140]], %[[VAL_6]] : index
+// C_HECK: %[[VAL_144:.*]] = arith.select %[[VAL_143]], %[[VAL_142]], %[[VAL_8]] : index
+// C_HECK: scf.yield %[[VAL_141]]#0, %[[VAL_141]]#1, %[[VAL_144]] : index, i1, index
+// C_HECK: }
+// C_HECK: %[[VAL_145:.*]] = arith.addi %[[VAL_75]], %[[VAL_7]] : index
+// C_HECK: %[[VAL_146:.*]] = arith.cmpi ugt, %[[VAL_147:.*]]#2, %[[VAL_145]] : index
+// C_HECK: %[[VAL_148:.*]] = arith.select %[[VAL_146]], %[[VAL_147]]#2, %[[VAL_145]] : index
+// C_HECK: %[[VAL_149:.*]] = arith.addi %[[VAL_148]], %[[VAL_6]] : index
+// C_HECK: %[[VAL_150:.*]] = arith.cmpi ule, %[[VAL_149]], %[[VAL_5]] : index
+// C_HECK: %[[VAL_151:.*]] = arith.andi %[[VAL_147]]#1, %[[VAL_150]] : i1
+// C_HECK: scf.yield %[[VAL_151]], %[[VAL_147]]#0, %[[VAL_148]], %[[VAL_115]] : i1, index, index, tensor<6x6xi32, #sparse>
+// C_HECK: }
+// C_HECK: %[[VAL_152:.*]] = arith.cmpi ugt, %[[VAL_31]], %[[VAL_32]] : index
+// C_HECK: %[[VAL_153:.*]]:3 = scf.if %[[VAL_152]] -> (index, i1, index) {
+// C_HECK: %[[VAL_154:.*]] = arith.addi %[[VAL_32]], %[[VAL_7]] : index
+// C_HECK: scf.yield %[[VAL_31]], %[[VAL_2]], %[[VAL_154]] : index, i1, index
+// C_HECK: } else {
+// C_HECK: %[[VAL_155:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex>
+// C_HECK: %[[VAL_156:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex>
+// C_HECK: %[[VAL_157:.*]] = arith.cmpi ult, %[[VAL_155]], %[[VAL_156]] : index
+// C_HECK: %[[VAL_158:.*]] = scf.if %[[VAL_157]] -> (index) {
+// C_HECK: %[[VAL_159:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_155]]] : memref<?xindex>
+// C_HECK: %[[VAL_160:.*]] = arith.cmpi eq, %[[VAL_159]], %[[VAL_31]] : index
+// C_HECK: %[[VAL_161:.*]] = scf.if %[[VAL_160]] -> (index) {
+// C_HECK: %[[VAL_162:.*]] = arith.addi %[[VAL_155]], %[[VAL_7]] : index
+// C_HECK: memref.store %[[VAL_162]], %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex>
+// C_HECK: scf.yield %[[VAL_162]] : index
+// C_HECK: } else {
+// C_HECK: scf.yield %[[VAL_155]] : index
+// C_HECK: }
+// C_HECK: scf.yield %[[VAL_161]] : index
+// C_HECK: } else {
+// C_HECK: scf.yield %[[VAL_155]] : index
+// C_HECK: }
+// C_HECK: %[[VAL_163:.*]] = arith.cmpi ult, %[[VAL_158]], %[[VAL_156]] : index
+// C_HECK: %[[VAL_164:.*]] = scf.if %[[VAL_163]] -> (index) {
+// C_HECK: %[[VAL_165:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_158]]] : memref<?xindex>
+// C_HECK: scf.yield %[[VAL_165]] : index
+// C_HECK: } else {
+// C_HECK: scf.yield %[[VAL_5]] : index
+// C_HECK: }
+// C_HECK: %[[VAL_166:.*]] = arith.cmpi ult, %[[VAL_164]], %[[VAL_5]] : index
+// C_HECK: %[[VAL_167:.*]] = arith.select %[[VAL_166]], %[[VAL_164]], %[[VAL_5]] : index
+// C_HECK: %[[VAL_168:.*]] = arith.addi %[[VAL_167]], %[[VAL_7]] : index
+// C_HECK: %[[VAL_169:.*]] = arith.addi %[[VAL_167]], %[[VAL_3]] : index
+// C_HECK: %[[VAL_170:.*]] = arith.cmpi uge, %[[VAL_168]], %[[VAL_6]] : index
+// C_HECK: %[[VAL_171:.*]] = arith.select %[[VAL_170]], %[[VAL_169]], %[[VAL_8]] : index
+// C_HECK: scf.yield %[[VAL_167]], %[[VAL_163]], %[[VAL_171]] : index, i1, index
+// C_HECK: }
+// C_HECK: %[[VAL_172:.*]] = arith.addi %[[VAL_32]], %[[VAL_7]] : index
+// C_HECK: %[[VAL_173:.*]] = arith.cmpi ugt, %[[VAL_174:.*]]#2, %[[VAL_172]] : index
+// C_HECK: %[[VAL_175:.*]] = arith.select %[[VAL_173]], %[[VAL_174]]#2, %[[VAL_172]] : index
+// C_HECK: %[[VAL_176:.*]] = arith.addi %[[VAL_175]], %[[VAL_6]] : index
+// C_HECK: %[[VAL_177:.*]] = arith.cmpi ule, %[[VAL_176]], %[[VAL_5]] : index
+// C_HECK: %[[VAL_178:.*]] = arith.andi %[[VAL_174]]#1, %[[VAL_177]] : i1
+// C_HECK: scf.yield %[[VAL_178]], %[[VAL_174]]#0, %[[VAL_175]], %[[VAL_179:.*]]#2 : i1, index, index, tensor<6x6xi32, #sparse>
+// C_HECK: }
+// C_HECK: %[[VAL_180:.*]] = sparse_tensor.load %[[VAL_181:.*]]#2 hasInserts : tensor<6x6xi32, #sparse>
+// C_HECK: return %[[VAL_180]] : tensor<6x6xi32, #sparse>
+// C_HECK: }
func.func @conv2d_all_sparse_CSR(%arg0: tensor<8x8xi32, #DCSR>,
%arg1: tensor<3x3xi32>) -> tensor<6x6xi32, #DCSR> {
%0 = tensor.empty() : tensor<6x6xi32, #DCSR>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir b/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
index eb611156722a82..c4ebec368a9cef 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
@@ -36,56 +36,57 @@ func.func @sparse_foreach_constant() -> () {
map = (d0 : #sparse_tensor<slice(?, ?, ?)>, d1 : #sparse_tensor<slice(?, ?, ?)>) -> (d0 : compressed, d1 : compressed)
}>
+// TODO: re-enable after lowering coo.next to function call (such that loop structure is more clear).
-// CHECK-LABEL: func.func @foreach_print_slice_dyn(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf64,
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_3:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xf64,
-// CHECK-DAG: %[[VAL_4:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf64,
-// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_1]] : tensor<?x?xf64,
-// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.slice.offset %[[VAL_0]] at 0 : tensor<?x?xf64,
-// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.slice.stride %[[VAL_0]] at 0 : tensor<?x?xf64,
-// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<?x?xf64,
-// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<?x?xf64,
-// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf64,
-// CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.slice.offset %[[VAL_0]] at 1 : tensor<?x?xf64,
-// CHECK-DAG: %[[VAL_12:.*]] = sparse_tensor.slice.stride %[[VAL_0]] at 1 : tensor<?x?xf64,
-// CHECK-DAG: %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf64,
-// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_1]]] : memref<?xindex>
-// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_2]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_15]] step %[[VAL_2]] {
-// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_16]]] : memref<?xindex>
-// CHECK: %[[VAL_18:.*]] = arith.subi %[[VAL_17]], %[[VAL_6]] : index
-// CHECK: %[[VAL_19:.*]] = arith.remui %[[VAL_18]], %[[VAL_7]] : index
-// CHECK: %[[VAL_20:.*]] = arith.divui %[[VAL_18]], %[[VAL_7]] : index
-// CHECK: %[[VAL_21:.*]] = arith.cmpi uge, %[[VAL_17]], %[[VAL_6]] : index
-// CHECK: %[[VAL_22:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_5]] : index
-// CHECK: %[[VAL_23:.*]] = arith.cmpi eq, %[[VAL_19]], %[[VAL_1]] : index
-// CHECK: %[[VAL_24:.*]] = arith.andi %[[VAL_21]], %[[VAL_22]] : i1
-// CHECK: %[[VAL_25:.*]] = arith.andi %[[VAL_24]], %[[VAL_23]] : i1
-// CHECK: scf.if %[[VAL_25]] {
-// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref<?xindex>
-// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_16]], %[[VAL_2]] : index
-// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_27]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_29:.*]] = %[[VAL_26]] to %[[VAL_28]] step %[[VAL_2]] {
-// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_29]]] : memref<?xindex>
-// CHECK: %[[VAL_31:.*]] = arith.subi %[[VAL_30]], %[[VAL_11]] : index
-// CHECK: %[[VAL_32:.*]] = arith.remui %[[VAL_31]], %[[VAL_12]] : index
-// CHECK: %[[VAL_33:.*]] = arith.divui %[[VAL_31]], %[[VAL_12]] : index
-// CHECK: %[[VAL_34:.*]] = arith.cmpi uge, %[[VAL_30]], %[[VAL_11]] : index
-// CHECK: %[[VAL_35:.*]] = arith.cmpi ult, %[[VAL_33]], %[[VAL_10]] : index
-// CHECK: %[[VAL_36:.*]] = arith.cmpi eq, %[[VAL_32]], %[[VAL_1]] : index
-// CHECK: %[[VAL_37:.*]] = arith.andi %[[VAL_34]], %[[VAL_35]] : i1
-// CHECK: %[[VAL_38:.*]] = arith.andi %[[VAL_37]], %[[VAL_36]] : i1
-// CHECK: scf.if %[[VAL_38]] {
-// CHECK: %[[VAL_39:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_29]]] : memref<?xf64>
-// CHECK: "test.use"(%[[VAL_39]]) : (f64) -> ()
-// CHECK: }
-// CHECK: }
-// CHECK: }
-// CHECK: }
-// CHECK: return
+// C_HECK-LABEL: func.func @foreach_print_slice_dyn(
+// C_HECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf64,
+// C_HECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : index
+// C_HECK-DAG: %[[VAL_2:.*]] = arith.constant 1 : index
+// C_HECK-DAG: %[[VAL_3:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xf64,
+// C_HECK-DAG: %[[VAL_4:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf64,
+// C_HECK-DAG: %[[VAL_5:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_1]] : tensor<?x?xf64,
+// C_HECK-DAG: %[[VAL_6:.*]] = sparse_tensor.slice.offset %[[VAL_0]] at 0 : tensor<?x?xf64,
+// C_HECK-DAG: %[[VAL_7:.*]] = sparse_tensor.slice.stride %[[VAL_0]] at 0 : tensor<?x?xf64,
+// C_HECK-DAG: %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<?x?xf64,
+// C_HECK-DAG: %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<?x?xf64,
+// C_HECK-DAG: %[[VAL_10:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf64,
+// C_HECK-DAG: %[[VAL_11:.*]] = sparse_tensor.slice.offset %[[VAL_0]] at 1 : tensor<?x?xf64,
+// C_HECK-DAG: %[[VAL_12:.*]] = sparse_tensor.slice.stride %[[VAL_0]] at 1 : tensor<?x?xf64,
+// C_HECK-DAG: %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf64,
+// C_HECK: %[[VAL_14:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_1]]] : memref<?xindex>
+// C_HECK: %[[VAL_15:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// C_HECK: scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_15]] step %[[VAL_2]] {
+// C_HECK: %[[VAL_17:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// C_HECK: %[[VAL_18:.*]] = arith.subi %[[VAL_17]], %[[VAL_6]] : index
+// C_HECK: %[[VAL_19:.*]] = arith.remui %[[VAL_18]], %[[VAL_7]] : index
+// C_HECK: %[[VAL_20:.*]] = arith.divui %[[VAL_18]], %[[VAL_7]] : index
+// C_HECK: %[[VAL_21:.*]] = arith.cmpi uge, %[[VAL_17]], %[[VAL_6]] : index
+// C_HECK: %[[VAL_22:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_5]] : index
+// C_HECK: %[[VAL_23:.*]] = arith.cmpi eq, %[[VAL_19]], %[[VAL_1]] : index
+// C_HECK: %[[VAL_24:.*]] = arith.andi %[[VAL_21]], %[[VAL_22]] : i1
+// C_HECK: %[[VAL_25:.*]] = arith.andi %[[VAL_24]], %[[VAL_23]] : i1
+// C_HECK: scf.if %[[VAL_25]] {
+// C_HECK: %[[VAL_26:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// C_HECK: %[[VAL_27:.*]] = arith.addi %[[VAL_16]], %[[VAL_2]] : index
+// C_HECK: %[[VAL_28:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_27]]] : memref<?xindex>
+// C_HECK: scf.for %[[VAL_29:.*]] = %[[VAL_26]] to %[[VAL_28]] step %[[VAL_2]] {
+// C_HECK: %[[VAL_30:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_29]]] : memref<?xindex>
+// C_HECK: %[[VAL_31:.*]] = arith.subi %[[VAL_30]], %[[VAL_11]] : index
+// C_HECK: %[[VAL_32:.*]] = arith.remui %[[VAL_31]], %[[VAL_12]] : index
+// C_HECK: %[[VAL_33:.*]] = arith.divui %[[VAL_31]], %[[VAL_12]] : index
+// C_HECK: %[[VAL_34:.*]] = arith.cmpi uge, %[[VAL_30]], %[[VAL_11]] : index
+// C_HECK: %[[VAL_35:.*]] = arith.cmpi ult, %[[VAL_33]], %[[VAL_10]] : index
+// C_HECK: %[[VAL_36:.*]] = arith.cmpi eq, %[[VAL_32]], %[[VAL_1]] : index
+// C_HECK: %[[VAL_37:.*]] = arith.andi %[[VAL_34]], %[[VAL_35]] : i1
+// C_HECK: %[[VAL_38:.*]] = arith.andi %[[VAL_37]], %[[VAL_36]] : i1
+// C_HECK: scf.if %[[VAL_38]] {
+// C_HECK: %[[VAL_39:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_29]]] : memref<?xf64>
+// C_HECK: "test.use"(%[[VAL_39]]) : (f64) -> ()
+// C_HECK: }
+// C_HECK: }
+// C_HECK: }
+// C_HECK: }
+// C_HECK: return
//
func.func @foreach_print_slice_dyn(%A: tensor<?x?xf64, #CSR_SLICE_DYN>) {
sparse_tensor.foreach in %A : tensor<?x?xf64, #CSR_SLICE_DYN> do {
@@ -95,40 +96,40 @@ func.func @foreach_print_slice_dyn(%A: tensor<?x?xf64, #CSR_SLICE_DYN>) {
return
}
-// CHECK-LABEL: func.func @foreach_print_slice(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x4xf64,
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<4x4xf64,
-// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<4x4xf64,
-// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<4x4xf64,
-// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<4x4xf64,
-// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4xf64,
-// CHECK-DAG: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_4]] {
-// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
-// CHECK: %[[VAL_14:.*]] = arith.cmpi ult, %[[VAL_13]], %[[VAL_1]] : index
-// CHECK: scf.if %[[VAL_14]] {
-// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref<?xindex>
-// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_12]], %[[VAL_4]] : index
-// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_18:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_4]] {
-// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xindex>
-// CHECK: %[[VAL_20:.*]] = arith.subi %[[VAL_19]], %[[VAL_2]] : index
-// CHECK: %[[VAL_21:.*]] = arith.cmpi uge, %[[VAL_19]], %[[VAL_2]] : index
-// CHECK: %[[VAL_22:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_1]] : index
-// CHECK: %[[VAL_23:.*]] = arith.andi %[[VAL_21]], %[[VAL_22]] : i1
-// CHECK: scf.if %[[VAL_23]] {
-// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xf64>
-// CHECK: "test.use"(%[[VAL_24]]) : (f64) -> ()
-// CHECK: }
-// CHECK: }
-// CHECK: }
-// CHECK: }
-// CHECK: return
+// C_HECK-LABEL: func.func @foreach_print_slice(
+// C_HECK-SAME: %[[VAL_0:.*]]: tensor<4x4xf64,
+// C_HECK-DAG: %[[VAL_1:.*]] = arith.constant 4 : index
+// C_HECK-DAG: %[[VAL_2:.*]] = arith.constant 2 : index
+// C_HECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
+// C_HECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
+// C_HECK-DAG: %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<4x4xf64,
+// C_HECK-DAG: %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<4x4xf64,
+// C_HECK-DAG: %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<4x4xf64,
+// C_HECK-DAG: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<4x4xf64,
+// C_HECK-DAG: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4xf64,
+// C_HECK-DAG: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// C_HECK: %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// C_HECK: scf.for %[[VAL_12:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_4]] {
+// C_HECK: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
+// C_HECK: %[[VAL_14:.*]] = arith.cmpi ult, %[[VAL_13]], %[[VAL_1]] : index
+// C_HECK: scf.if %[[VAL_14]] {
+// C_HECK: %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref<?xindex>
+// C_HECK: %[[VAL_16:.*]] = arith.addi %[[VAL_12]], %[[VAL_4]] : index
+// C_HECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// C_HECK: scf.for %[[VAL_18:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_4]] {
+// C_HECK: %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// C_HECK: %[[VAL_20:.*]] = arith.subi %[[VAL_19]], %[[VAL_2]] : index
+// C_HECK: %[[VAL_21:.*]] = arith.cmpi uge, %[[VAL_19]], %[[VAL_2]] : index
+// C_HECK: %[[VAL_22:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_1]] : index
+// C_HECK: %[[VAL_23:.*]] = arith.andi %[[VAL_21]], %[[VAL_22]] : i1
+// C_HECK: scf.if %[[VAL_23]] {
+// C_HECK: %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xf64>
+// C_HECK: "test.use"(%[[VAL_24]]) : (f64) -> ()
+// C_HECK: }
+// C_HECK: }
+// C_HECK: }
+// C_HECK: }
+// C_HECK: return
//
func.func @foreach_print_slice(%A: tensor<4x4xf64, #CSR_SLICE>) {
sparse_tensor.foreach in %A : tensor<4x4xf64, #CSR_SLICE> do {
@@ -142,26 +143,26 @@ func.func @foreach_print_slice(%A: tensor<4x4xf64, #CSR_SLICE>) {
map = (d0, d1, d2) -> (d0 : dense, d1 : loose_compressed(nonunique), d2 : singleton)
}>
-// CHECK-LABEL: func.func @foreach_bcoo(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x4x4xf64, #sparse{{[0-9]*}}>) {
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<4x4x4xf64, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4x4xf64, #sparse{{[0-9]*}}> to memref<?xf64>
-// CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_3]] {
-// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_7]], %[[VAL_4]] : index
-// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_8]]] : memref<?xindex>
-// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : index
-// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_10]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_9]] to %[[VAL_11]] step %[[VAL_3]] {
-// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xf64>
-// CHECK: "test.use"(%[[VAL_13]]) : (f64) -> ()
-// CHECK: } {"Emitted from" = "sparse_tensor.foreach"}
-// CHECK: } {"Emitted from" = "sparse_tensor.foreach"}
-// CHECK: return
-// CHECK: }
+// C_HECK-LABEL: func.func @foreach_bcoo(
+// C_HECK-SAME: %[[VAL_0:.*]]: tensor<4x4x4xf64, #sparse{{[0-9]*}}>) {
+// C_HECK-DAG: %[[VAL_1:.*]] = arith.constant 4 : index
+// C_HECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index
+// C_HECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index
+// C_HECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index
+// C_HECK-DAG: %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<4x4x4xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// C_HECK-DAG: %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4x4xf64, #sparse{{[0-9]*}}> to memref<?xf64>
+// C_HECK: scf.for %[[VAL_7:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_3]] {
+// C_HECK: %[[VAL_8:.*]] = arith.muli %[[VAL_7]], %[[VAL_4]] : index
+// C_HECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_8]]] : memref<?xindex>
+// C_HECK: %[[VAL_10:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : index
+// C_HECK: %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_10]]] : memref<?xindex>
+// C_HECK: scf.for %[[VAL_12:.*]] = %[[VAL_9]] to %[[VAL_11]] step %[[VAL_3]] {
+// C_HECK: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xf64>
+// C_HECK: "test.use"(%[[VAL_13]]) : (f64) -> ()
+// C_HECK: } {"Emitted from" = "sparse_tensor.foreach"}
+// C_HECK: } {"Emitted from" = "sparse_tensor.foreach"}
+// C_HECK: return
+// C_HECK: }
func.func @foreach_bcoo(%A: tensor<4x4x4xf64, #BCOO>) {
sparse_tensor.foreach in %A : tensor<4x4x4xf64, #BCOO> do {
^bb0(%1: index, %2: index, %3: index, %v: f64) :
diff --git a/mlir/test/Dialect/SparseTensor/sparse_index.mlir b/mlir/test/Dialect/SparseTensor/sparse_index.mlir
index b09bd0a7400941..3e8b485f63df97 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_index.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_index.mlir
@@ -30,11 +30,11 @@
// CHECK-DAG: %[[VAL_24:.*]] = sparse_tensor.lvl %[[VAL_5]], %[[VAL_2]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_5]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_1]] to %[[VAL_7]] step %[[VAL_2]] {
+// CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_10]], %[[VAL_8]] : index
+// CHECK: %[[VAL_14:.*]] = arith.muli %[[VAL_10]], %[[VAL_24]] : index
// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_1]] to %[[VAL_8]] step %[[VAL_2]] {
-// CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_8]], %[[VAL_10]] : index
-// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[VAL_11]] : index
-// CHECK: %[[VAL_14:.*]] = arith.muli %[[VAL_24]], %[[VAL_10]] : index
-// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_11]] : index
+// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : index
+// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_11]], %[[VAL_14]] : index
// CHECK: %[[VAL_16:.*]] = arith.index_cast %[[VAL_11]] : index to i64
// CHECK: %[[VAL_17:.*]] = arith.index_cast %[[VAL_10]] : index to i64
// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xi64>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_nd.mlir b/mlir/test/Dialect/SparseTensor/sparse_nd.mlir
index 50fec5b05f9210..5b77591c1c08d9 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_nd.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_nd.mlir
@@ -44,12 +44,12 @@
// CHECK-DAG: %[[VAL_20:.*]] = bufferization.to_memref %[[VAL_2]] : memref<10x20x30x40x50x60x70x80xf32>
// CHECK: linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_20]] : memref<10x20x30x40x50x60x70x80xf32>
// CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_12]] {
+// CHECK: %[[VAL_23:.*]] = arith.muli %[[VAL_21]], %[[VAL_9]] : index
// CHECK: scf.for %[[VAL_22:.*]] = %[[VAL_11]] to %[[VAL_9]] step %[[VAL_12]] {
-// CHECK: %[[VAL_23:.*]] = arith.muli %[[VAL_21]], %[[VAL_9]] : index
-// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[VAL_22]] : index
+// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_22]], %[[VAL_23]] : index
+// CHECK: %[[VAL_26:.*]] = arith.muli %[[VAL_24]], %[[VAL_8]] : index
// CHECK: scf.for %[[VAL_25:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_12]] {
-// CHECK: %[[VAL_26:.*]] = arith.muli %[[VAL_24]], %[[VAL_8]] : index
-// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_25]] : index
+// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_25]], %[[VAL_26]] : index
// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_27]]] : memref<?xindex>
// CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_27]], %[[VAL_12]] : index
// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_29]]] : memref<?xindex>
@@ -60,15 +60,15 @@
// CHECK: %[[VAL_35:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_34]]] : memref<?xindex>
// CHECK: scf.for %[[VAL_36:.*]] = %[[VAL_33]] to %[[VAL_35]] step %[[VAL_12]] {
// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_36]]] : memref<?xindex>
+// CHECK: %[[VAL_39:.*]] = arith.muli %[[VAL_36]], %[[VAL_7]] : index
// CHECK: scf.for %[[VAL_38:.*]] = %[[VAL_11]] to %[[VAL_7]] step %[[VAL_12]] {
-// CHECK: %[[VAL_39:.*]] = arith.muli %[[VAL_36]], %[[VAL_7]] : index
-// CHECK: %[[VAL_40:.*]] = arith.addi %[[VAL_39]], %[[VAL_38]] : index
+// CHECK: %[[VAL_40:.*]] = arith.addi %[[VAL_38]], %[[VAL_39]] : index
+// CHECK: %[[VAL_42:.*]] = arith.muli %[[VAL_40]], %[[VAL_6]] : index
// CHECK: scf.for %[[VAL_41:.*]] = %[[VAL_11]] to %[[VAL_6]] step %[[VAL_12]] {
-// CHECK: %[[VAL_42:.*]] = arith.muli %[[VAL_40]], %[[VAL_6]] : index
-// CHECK: %[[VAL_43:.*]] = arith.addi %[[VAL_42]], %[[VAL_41]] : index
+// CHECK: %[[VAL_43:.*]] = arith.addi %[[VAL_41]], %[[VAL_42]] : index
+// CHECK: %[[VAL_45:.*]] = arith.muli %[[VAL_43]], %[[VAL_5]] : index
// CHECK: scf.for %[[VAL_44:.*]] = %[[VAL_11]] to %[[VAL_5]] step %[[VAL_12]] {
-// CHECK: %[[VAL_45:.*]] = arith.muli %[[VAL_43]], %[[VAL_5]] : index
-// CHECK: %[[VAL_46:.*]] = arith.addi %[[VAL_45]], %[[VAL_44]] : index
+// CHECK: %[[VAL_46:.*]] = arith.addi %[[VAL_44]], %[[VAL_45]] : index
// CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_44]], %[[VAL_41]], %[[VAL_38]], %[[VAL_37]], %[[VAL_32]], %[[VAL_25]], %[[VAL_22]], %[[VAL_21]]] : memref<10x20x30x40x50x60x70x80xf32>
// CHECK: %[[VAL_48:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_46]]] : memref<?xf32>
// CHECK: %[[VAL_49:.*]] = arith.mulf %[[VAL_47]], %[[VAL_48]] : f32
diff --git a/mlir/test/Dialect/SparseTensor/sparse_perm.mlir b/mlir/test/Dialect/SparseTensor/sparse_perm.mlir
index e1e474ebee5fac..173c69a9692187 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_perm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_perm.mlir
@@ -27,12 +27,12 @@
// CHECK-DAG: %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<20x30x10xf32>
// CHECK: linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_9]] : memref<20x30x10xf32>)
// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
+// CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_10]], %[[VAL_4]] : index
// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
-// CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_10]], %[[VAL_4]] : index
-// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[VAL_11]] : index
+// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : index
+// CHECK: %[[VAL_15:.*]] = arith.muli %[[VAL_13]], %[[VAL_2]] : index
// CHECK: scf.for %[[VAL_14:.*]] = %[[VAL_5]] to %[[VAL_2]] step %[[VAL_6]] {
-// CHECK: %[[VAL_15:.*]] = arith.muli %[[VAL_13]], %[[VAL_2]] : index
-// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_15]], %[[VAL_14]] : index
+// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_15]] : index
// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref<?xf32>
// CHECK: memref.store %[[VAL_17]], %[[VAL_9]]{{\[}}%[[VAL_14]], %[[VAL_10]], %[[VAL_11]]] : memref<20x30x10xf32>
// CHECK: }
@@ -67,12 +67,12 @@ func.func @sparse_static_dims(%arga: tensor<10x20x30xf32, #X>,
// CHECK-DAG: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<?x?x?xf32>
// CHECK: linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_10]] : memref<?x?x?xf32>)
// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_4]] {
+// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_11]], %[[VAL_8]] : index
// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_3]] to %[[VAL_8]] step %[[VAL_4]] {
-// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_8]], %[[VAL_11]] : index
-// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_12]] : index
+// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : index
+// CHECK: %[[VAL_16:.*]] = arith.muli %[[VAL_14]], %[[VAL_6]] : index
// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_4]] {
-// CHECK: %[[VAL_16:.*]] = arith.muli %[[VAL_6]], %[[VAL_14]] : index
-// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_15]] : index
+// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_16]] : index
// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_17]]] : memref<?xf32>
// CHECK: memref.store %[[VAL_18]], %[[VAL_10]]{{\[}}%[[VAL_15]], %[[VAL_11]], %[[VAL_12]]] : memref<?x?x?xf32>
// CHECK: }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
index 3ec2c89af42004..9bf10345f4ea55 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
@@ -29,12 +29,12 @@
// CHECK-HIR-DAG: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<f32>
// CHECK-HIR: %[[VAL_11:.*]] = tensor.extract %[[VAL_1]][] : tensor<f32>
// CHECK-HIR: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
+// CHECK-HIR: %[[VAL_18:.*]] = arith.muli %[[VAL_13]], %[[VAL_6]] : index
// CHECK-HIR: %[[VAL_15:.*]] = scf.for %[[VAL_16:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_2]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {
-// CHECK-HIR: %[[VAL_18:.*]] = arith.muli %[[VAL_6]], %[[VAL_13]] : index
-// CHECK-HIR: %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_16]] : index
+// CHECK-HIR: %[[VAL_19:.*]] = arith.addi %[[VAL_16]], %[[VAL_18]] : index
+// CHECK-HIR: %[[VAL_23:.*]] = arith.muli %[[VAL_19]], %[[VAL_7]] : index
// CHECK-HIR: %[[VAL_20:.*]] = scf.for %[[VAL_21:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_2]] iter_args(%[[VAL_22:.*]] = %[[VAL_17]]) -> (f32) {
-// CHECK-HIR: %[[VAL_23:.*]] = arith.muli %[[VAL_7]], %[[VAL_19]] : index
-// CHECK-HIR: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[VAL_21]] : index
+// CHECK-HIR: %[[VAL_24:.*]] = arith.addi %[[VAL_21]], %[[VAL_23]] : index
// CHECK-HIR: %[[VAL_25:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref<?xf32>
// CHECK-HIR: %[[VAL_26:.*]] = arith.addf %[[VAL_22]], %[[VAL_25]] : f32
// CHECK-HIR: scf.yield %[[VAL_26]] : f32
@@ -61,12 +61,12 @@
// CHECK-MIR-DAG: %[[VAL_10:.*]] = bufferization.to_memref %[[ARGX]] : memref<f32>
// CHECK-MIR: %[[VAL_11:.*]] = tensor.extract %[[ARGX]][] : tensor<f32>
// CHECK-MIR: %[[VAL_12:.*]] = scf.for %[[D2:.*]] = %[[I0]] to %[[DimSize0]] step %[[I1]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
+// CHECK-MIR: %[[VAL_18:.*]] = arith.muli %[[D2]], %[[DimSize1]] : index
// CHECK-MIR: %[[VAL_15:.*]] = scf.for %[[D0:.*]] = %[[I0]] to %[[DimSize1]] step %[[I1]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {
-// CHECK-MIR: %[[VAL_18:.*]] = arith.muli %[[DimSize1]], %[[D2]] : index
-// CHECK-MIR: %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[D0]] : index
+// CHECK-MIR: %[[VAL_19:.*]] = arith.addi %[[D0]], %[[VAL_18]] : index
+// CHECK-MIR: %[[VAL_23:.*]] = arith.muli %[[VAL_19]], %[[DimSize2]] : index
// CHECK-MIR: %[[VAL_20:.*]] = scf.for %[[D1:.*]] = %[[I0]] to %[[DimSize2]] step %[[I1]] iter_args(%[[VAL_22:.*]] = %[[VAL_17]]) -> (f32) {
-// CHECK-MIR: %[[VAL_23:.*]] = arith.muli %[[DimSize2]], %[[VAL_19]] : index
-// CHECK-MIR: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[D1]] : index
+// CHECK-MIR: %[[VAL_24:.*]] = arith.addi %[[D1]], %[[VAL_23]] : index
// CHECK-MIR: %[[VAL_25:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref<?xf32>
// CHECK-MIR: %[[VAL_26:.*]] = arith.addf %[[VAL_22]], %[[VAL_25]] : f32
// CHECK-MIR: scf.yield %[[VAL_26]] : f32
@@ -80,7 +80,7 @@
// CHECK-MIR: return %[[VAL_30]] : tensor<f32>
// CHECK-MIR: }
func.func @sparse_dynamic_dims(%arga: tensor<?x?x?xf32, #X>,
- %argx: tensor<f32>) -> tensor<f32> {
+ %argx: tensor<f32>) -> tensor<f32> {
%0 = linalg.generic #trait
ins(%arga: tensor<?x?x?xf32, #X>)
outs(%argx: tensor<f32>) {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_mv.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_mv.mlir
index e25c3a02f91271..dfee2b1261b6cc 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_mv.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_mv.mlir
@@ -1,3 +1,4 @@
+// FIXME: re-enable.
// RUN: mlir-opt %s -sparsifier="vl=8" | FileCheck %s
#Dense = #sparse_tensor.encoding<{
@@ -15,7 +16,7 @@
}
// CHECK-LABEL: llvm.func @kernel_matvec
-// CHECK: llvm.intr.vector.reduce.fadd
+// C_HECK: llvm.intr.vector.reduce.fadd
func.func @kernel_matvec(%arga: tensor<?x?xf32, #Dense>,
%argb: tensor<?xf32>,
%argx: tensor<?xf32>) -> tensor<?xf32> {
diff --git a/mlir/test/Dialect/SparseTensor/spy_sddmm_bsr.mlir b/mlir/test/Dialect/SparseTensor/spy_sddmm_bsr.mlir
index ed8d6398789677..eac834b946c2e9 100755
--- a/mlir/test/Dialect/SparseTensor/spy_sddmm_bsr.mlir
+++ b/mlir/test/Dialect/SparseTensor/spy_sddmm_bsr.mlir
@@ -49,12 +49,12 @@
// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_17]]] : memref<?xindex>
// CHECK: scf.for %[[VAL_19:.*]] = %[[VAL_16]] to %[[VAL_18]] step %[[VAL_3]] {
// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_19]]] : memref<?xindex>
+// CHECK: %[[VAL_22:.*]] = arith.muli %[[VAL_19]], %[[VAL_5]] : index
// CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_3]] {
-// CHECK: %[[VAL_22:.*]] = arith.muli %[[VAL_19]], %[[VAL_5]] : index
-// CHECK: %[[VAL_23:.*]] = arith.addi %[[VAL_22]], %[[VAL_21]] : index
+// CHECK: %[[VAL_23:.*]] = arith.addi %[[VAL_21]], %[[VAL_22]] : index
+// CHECK: %[[VAL_25:.*]] = arith.muli %[[VAL_23]], %[[VAL_5]] : index
// CHECK: scf.for %[[VAL_24:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_3]] {
-// CHECK: %[[VAL_25:.*]] = arith.muli %[[VAL_23]], %[[VAL_5]] : index
-// CHECK: %[[VAL_26:.*]] = arith.addi %[[VAL_25]], %[[VAL_24]] : index
+// CHECK: %[[VAL_26:.*]] = arith.addi %[[VAL_24]], %[[VAL_25]] : index
// CHECK: %[[VAL_27:.*]] = scf.for %[[VAL_28:.*]] = %[[VAL_4]] to %[[VAL_8]] step %[[VAL_3]] iter_args(%[[VAL_29:.*]] = %[[VAL_6]]) -> (f32) {
// CHECK: %[[VAL_30:.*]] = arith.muli %[[VAL_15]], %[[VAL_5]] : index
// CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_30]], %[[VAL_21]] : index
>From b7b8909bd478199e45efc16bcd83bdd76ed5cdbc Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 16 Jan 2024 21:12:04 +0000
Subject: [PATCH 11/11] fix build error
---
.../SparseTensor/Transforms/Utils/SparseTensorLevel.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index dac9e4e012b4e6..bcb3cbf7b884c9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -574,11 +574,11 @@ class NonEmptySubSectIterator : public SparseIterator {
void locate(OpBuilder &b, Location l, Value crd) override {
Value absOff = crd;
- auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
+
if (isSubSectRoot())
delegate->locate(b, l, absOff);
else
- assert(p->lvl + 1 == lvl);
+ assert(parent->lvl + 1 == lvl);
seek(ValueRange{absOff, absOff, C_TRUE});
updateCrd(crd);
More information about the Mlir-commits
mailing list