[Mlir-commits] [mlir] fc5d8fc - [mlir][sparse] support dual sparse convolution.
Peiming Liu
llvmlistbot at llvm.org
Mon Jul 10 09:49:38 PDT 2023
Author: Peiming Liu
Date: 2023-07-10T16:49:32Z
New Revision: fc5d8fce7dddd27b591f0a6dcb20a7cfa59842fd
URL: https://github.com/llvm/llvm-project/commit/fc5d8fce7dddd27b591f0a6dcb20a7cfa59842fd
DIFF: https://github.com/llvm/llvm-project/commit/fc5d8fce7dddd27b591f0a6dcb20a7cfa59842fd.diff
LOG: [mlir][sparse] support dual sparse convolution.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D152601
Added:
mlir/test/Integration/Dialect/SparseTensor/CPU/dual_sparse_conv_2d.mlir
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 9a130ec04445db..f4e7cc49b9f46d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -37,10 +37,31 @@ using namespace mlir::sparse_tensor;
#define MULI(lhs, rhs) (builder.create<arith::MulIOp>(loc, (lhs), (rhs)))
#define SELECT(c, l, r) (builder.create<arith::SelectOp>(loc, (c), (l), (r)))
+//===----------------------------------------------------------------------===//
+// Debugging utils
+//===----------------------------------------------------------------------===//
+
+#ifndef NDEBUG
+LLVM_ATTRIBUTE_UNUSED static void dumpIndexMemRef(OpBuilder &builder,
+ Location loc, Value memref) {
+ memref = builder.create<memref::CastOp>(
+ loc, UnrankedMemRefType::get(builder.getIndexType(), 0), memref);
+ createFuncCall(builder, loc, "printMemrefInd", TypeRange{},
+ ValueRange{memref}, EmitCInterface::On);
+}
+#endif
+
//===----------------------------------------------------------------------===//
// File local helper functions.
//===----------------------------------------------------------------------===//
+// For index reduction loops, since the tensors are sliced into non-contiguous
+// fragments, we need a triple [pLo, pHi, pPtr], in which the pair (pLo, pHi)
+// specifies the range of the fragment, and pPtr specifies the index of the
+// corresponding fragment in the child level (i.e., a pointer to the sliced
+// position array).
+static constexpr unsigned kSliceIterWidth = 3;
+
static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
Level lvl) {
auto enc = getSparseTensorEncoding(tensor.getType());
@@ -123,6 +144,28 @@ static Value genSparseReducedAffineCond(OpBuilder &builder, Location loc,
return ifOp.getResult(0);
}
+// Helper functions that load/store into the position buffer for slice-driven
+// loops.
+// The sliced pointer buffer is organized as:
+// [size, curPtr] (two metadata) + [[pLo, pHi, pNext], ...] (list of tuples)
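+// For illustration, with kSliceIterWidth == 3 a buffer holding two tuples is
+// laid out as [memSize, curPtr, pLo0, pHi0, pNext0, pLo1, pHi1, pNext1]:
+// tuple i starts at absolute index 2 + kSliceIterWidth * i, so the pNext of
+// the tuple at `tupleIdx` sits at tupleIdx + 4 (see below).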
+static Value loadSlicePosPtr(OpBuilder &builder, Location loc, Value sPosBuf) {
+ // Load curPtr.
+ // TODO: We should use an SSA value for it.
+ return genIndexLoad(builder, loc, sPosBuf, C_IDX(1));
+}
+static void updateSlicePosPtr(OpBuilder &builder, Location loc, Value sPosBuf,
+ Value pPtr) {
+ // Set curPtr.
+ // TODO: We should use an SSA value for it.
+ builder.create<memref::StoreOp>(loc, pPtr, sPosBuf, C_IDX(1));
+}
+static Value loadSliceNextPosPtrStart(OpBuilder &builder, Location loc,
+ Value sPosBuf, Value tupleIdx) {
+ // Load the pNext in the current tuple specified by `tupleIdx`.
+ // 4 = 2 (two metadata) + 2 (pNext == tuple[2])
+ return genIndexLoad(builder, loc, sPosBuf, ADDI(tupleIdx, C_IDX(4)));
+}
+
std::pair<Value, Value>
LoopEmitter::genSliceLegitPredicate(OpBuilder &builder, Location loc, Value crd,
TensorId tid, Level lvl) {
@@ -566,18 +609,6 @@ void LoopEmitter::exitCurrentLoopSeq(OpBuilder &builder, Location loc) {
// If this is an unresolved-slice-driven loop, pops out the slice.
assert(sliceStack[tid].back().slicedOnLvl == lvl);
sliceStack[tid].pop_back();
- } else {
- if (!isDenseDLT(lvlTypes[tid][lvl])) {
- // Else this is a resolved-slice, and advance posit similar to TACO.
- Value c1 = C_IDX(1), c2 = C_IDX(2);
- // pIdx += 2, we finished the current lvl, advance the pointer index of
- // the previous level by two to skip the [pLo, pHi] for current level.
- Value sPtrBuf = slicePosBuffer[tid][lvl].back();
- Value curP = genIndexLoad(builder, loc, sPtrBuf, c1);
- // TODO: we could probably use an SSA value for it.
- Value nexP = ADDI(curP, c2);
- builder.create<memref::StoreOp>(loc, nexP, sPtrBuf, c1);
- }
}
}
loopSeqStack.pop_back();
@@ -1274,11 +1305,11 @@ void LoopEmitter::enterTensorsAtDenseLvls(
// Pushes sliced levels to build correct LoopInfo.
bool unReduc = isAffineIdxUnRedCond(denseLoopCond);
SliceInfo &info = sliceStack[tid].back();
+ // Pushes sliced dense loop info to tell LoopEmitter how to exit it.
+ sliceInfo.emplace_back(tid, lvl, /*fullyReduced=*/!unReduc);
if (unReduc) {
- // Pushes sliced dense loop info to tell LoopEmitter how to exit it.
- sliceInfo.emplace_back(tid, lvl, /*fullyReduced=*/false);
- // Update the slice information as we enter the new loop.
assert(*info.slicedOnLvl == lvl);
+ // Update the slice information as we enter the new loop.
info.minCrd = info.offset = iv;
info.isNonEmpty = constantI1(builder, loc, true);
levelReducedDep[tid][lvl]++;
@@ -1312,15 +1343,20 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
MutableArrayRef<Value> reduc) {
const LoopInfo &loopInfo = loopStack.back();
for (auto [tid, lvl, reduced] : loopInfo.sliceDrivenInfo) {
- SliceInfo &info = sliceStack[tid].back();
- assert(isDenseDLT(lvlTypes[tid][lvl]));
- assert(*info.slicedOnLvl == lvl && !reduced);
- (void)reduced;
- // Resets slices pointers as the resolved slices are invalidated after we
- // moves forward to the next slice.
- invalidateSliceIterIdx(rewriter, loc, tid, lvl);
- info.minCrd = info.offset = info.isNonEmpty = Value();
- levelReducedDep[tid][lvl]--;
+ if (!reduced) {
+ SliceInfo &info = sliceStack[tid].back();
+ assert(isDenseDLT(lvlTypes[tid][lvl]));
+ assert(*info.slicedOnLvl == lvl);
+ (void)reduced;
+ // Resets slice pointers as the resolved slices are invalidated after we
+ // move forward to the next slice.
+ invalidateSliceIterIdx(rewriter, loc, tid, lvl);
+ info.minCrd = info.offset = info.isNonEmpty = Value();
+ levelReducedDep[tid][lvl]--;
+ } else {
+ forwardsReducedSliceLevelTreeIt(rewriter, loc, tid, lvl,
+ constantIndex(rewriter, loc, 1));
+ }
}
if (auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop)) {
if (!reduc.empty()) {
@@ -1398,6 +1434,61 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
}
}
+void LoopEmitter::forwardsReducedSliceLevelTreeIt(OpBuilder &builder,
+ Location loc, TensorId tid,
+ Level rootLvl, Value fcnt) {
+ auto stt = getSparseTensorType(tensors[tid]);
+
+ // Finds a [rootLvl, leafLvl) range in which all levels in between are
+ // fully reduced (but not resolved). Since we forward an iterator at a
+ // higher level of the tree, the subtree needs to be pruned.
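+ // For example, forwarding the root iterator by one step advances a fully
+ // reduced dense child of size N by N steps, and a fully reduced compressed
+ // child by one kSliceIterWidth-sized segment in its position buffer.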
+ Level leafLvl = rootLvl + 1;
+ while (leafLvl < stt.getLvlRank() && !dependentLvlMap[tid][leafLvl].empty()) {
+ assert(depFullyReduced(tid, leafLvl));
+ leafLvl++;
+ }
+
+ Level curLvl = rootLvl + 1;
+ // Prunes all dense subtrees.
+ while (curLvl < leafLvl && isDenseDLT(lvlTypes[tid][curLvl])) {
+ // One step forward in the parent level results in forwarding `slice.size`
+ // steps in the child dense level.
+ fcnt = MULI(sliceSizes[tid][curLvl].back(), fcnt);
+ curLvl++;
+ }
+
+ Value nxPosPtr = nullptr;
+ if (curLvl < leafLvl) {
+ assert(!isDenseDLT(lvlTypes[tid][curLvl]));
+ // This is the first compressed level; set up its position pointer.
+ Value sPosBuf = slicePosBuffer[tid][curLvl].back();
+ // One step forward in the parent level results in forwarding one segment
+ // (kSliceIterWidth entries) in the child sparse level.
+ Value fPosPtr = MULI(fcnt, C_IDX(kSliceIterWidth)); // forward ptr
+ Value pPosPtr = loadSlicePosPtr(builder, loc, sPosBuf); // previous ptr
+ Value cPosPtr = ADDI(fPosPtr, pPosPtr); // current ptr
+ updateSlicePosPtr(builder, loc, sPosBuf, cPosPtr);
+ // Loads the position pointer start for the next level.
+ nxPosPtr = loadSliceNextPosPtrStart(builder, loc, sPosBuf, cPosPtr);
+ curLvl++;
+ }
+
+ // TODO: This is not always needed, but we do it unconditionally for now
+ // for simplicity. It is only needed when `curLvl` is forwarded without
+ // traversing its child level (e.g., when the level is in a conjunctive
+ // lattice and got pruned), such that the position pointer is not forwarded
+ // inside the loop.
+ for (; curLvl < leafLvl; curLvl++) {
+ assert(nxPosPtr);
+ if (!isDenseDLT(lvlTypes[tid][curLvl])) {
+ nxPosPtr = MULI(nxPosPtr, C_IDX(kSliceIterWidth));
+ Value sPosBuf = slicePosBuffer[tid][curLvl].back();
+ updateSlicePosPtr(builder, loc, sPosBuf, nxPosPtr);
+ nxPosPtr = loadSliceNextPosPtrStart(builder, loc, sPosBuf, nxPosPtr);
+ }
+ }
+}
+
void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
MutableArrayRef<Value> reduc) {
const LoopInfo &loopInfo = loopStack.back();
@@ -1425,17 +1516,25 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
continue;
}
+ Value forwarded = nullptr;
if (loopInfo.trivialTidLvls.empty() &&
loopInfo.sliceDrivenInfo.size() == 1) {
// Forwards the position iterator.
operands.push_back(ADDI(posits[tid][lvl], one));
+ forwarded = constantI1(builder, loc, true);
} else {
const Value pos = posits[tid][lvl];
const Value nxPos = ADDI(posits[tid][lvl], one);
- Value cmp = CMPI(eq, coords[tid][lvl], iv);
- operands.push_back(SELECT(cmp, nxPos, pos));
+ forwarded = CMPI(eq, coords[tid][lvl], iv);
+ operands.push_back(SELECT(forwarded, nxPos, pos));
+ }
+ {
+ OpBuilder::InsertionGuard guard(builder);
+ auto ifOp = builder.create<scf::IfOp>(loc, TypeRange{}, forwarded,
+ /*else=*/false);
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ forwardsReducedSliceLevelTreeIt(builder, loc, tid, lvl, one);
}
-
// The coordinate is invalid now.
coords[tid][lvl] = nullptr;
@@ -1633,7 +1732,6 @@ std::pair<Operation *, ValueRange> LoopEmitter::genSliceLvlTraverseLoop(
}
// Generates a loop nest that traverses all the unresolved levels in between.
-// TODO: it can only handle all compressed tensors.
//
// for(int i = 0; i < slicePos.size(); i += kSliceIterWidth) {
// loopLo = slicePos[i];
@@ -1660,6 +1758,15 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
OpBuilder::InsertPoint ip;
SmallVector<Value> innerArgs(userReduc.begin(), userReduc.end());
scf::ForOp outerMost = nullptr; // the outermost loop.
+
+ // Wraps the body builder and inserts an extra counting instruction at the
+ // end.
+ auto wrapped = [bodyBuilder](OpBuilder &builder, Location loc, Value iv,
+ MutableArrayRef<Value> reduc) {
+ bodyBuilder(builder, loc, iv, reduc.drop_back());
+ // Increments the counter.
+ reduc.back() = ADDI(reduc.back(), C_IDX(1));
+ };
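+ // E.g., the running count is what gets stored as each tuple's pNext below,
+ // recording where the corresponding segment starts in the child level's
+ // position buffer.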
+
if (firstResLvl.has_value()) {
// Overwrite position when the first level is fully resolved.
pos = posits[firstResLvl->first][firstResLvl->second];
@@ -1669,18 +1776,28 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
Level firstLvl = *frontSlice.slicedOnLvl;
if (!lvlFullyResolved(tid, firstLvl)) {
if (isCompressedDLT(lvlTypes[tid][firstLvl])) {
+ // An extra counter that tracks how many segments there are in the child
+ // compressed level.
+ innerArgs.push_back(c0);
+ // Overrides the user-provided builder.
+ bodyBuilder = wrapped;
unsigned depth = frontSlice.depth - 1;
Value offset = frontSlice.offset;
Value sPtrBuf = slicePosBuffer[tid][firstLvl][depth];
Value mSz = genIndexLoad(builder, loc, sPtrBuf, c0); // memSize
outerMost = builder.create<scf::ForOp>(
- loc, c2, mSz, c2, innerArgs,
- [this, c1, tid, firstLvl, offset, sPtrBuf, &ip, &pos,
+ loc, c2, mSz, C_IDX(kSliceIterWidth), innerArgs,
+ [this, c1, c2, tid, firstLvl, offset, sPtrBuf, &ip, &pos,
&innerArgs](OpBuilder &builder, Location loc, Value iv,
ValueRange iterArgs) {
// generate traversal for each level.
Value loopLo = genIndexLoad(builder, loc, sPtrBuf, iv);
Value loopHi = genIndexLoad(builder, loc, sPtrBuf, ADDI(iv, c1));
+ // We need to remember the starting index for the next level's
+ // position, because the slice-driven loop breaks the level into
+ // non-consecutive segments.
+ builder.create<memref::StoreOp>(loc, iterArgs.back(), sPtrBuf,
+ ADDI(iv, c2).getResult());
ValueRange itArgs =
genSliceLvlTraverseLoop(
builder, loc, loopLo, loopHi, offset,
@@ -1832,8 +1949,7 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
TensorId tid, Level lvl) {
Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2);
unsigned depth = levelReducedDep[tid][lvl];
- Value size = sliceSizes[tid][lvl][depth];
- // Dense slice begin is trivial
+ Value size = sliceSizes[tid][lvl][depth]; // Dense slice begin is trivial
if (isDenseDLT(lvlTypes[tid][lvl])) {
sliceStack[tid].emplace_back(c0, c0, constantI1(builder, loc, false), lvl,
depth + 1);
@@ -1879,9 +1995,8 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
ValueRange result = genUnResolvedSliceTreeTraverse(
builder, loc, tid, unResSlices, firstResLvl, reduc,
- [this, c1, c2, tid, lvl, sPtrBuf](OpBuilder &builder, Location loc,
- Value iv,
- MutableArrayRef<Value> reduc) {
+ [this, c1, tid, lvl, sPtrBuf](OpBuilder &builder, Location loc, Value iv,
+ MutableArrayRef<Value> reduc) {
Value &nonEmpty = reduc[0];
Value &minCrd = reduc[1];
Value &curMemSz = reduc[2];
@@ -1919,8 +2034,8 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
builder.create<memref::StoreOp>(loc, sPLo, sPtrBuf, curMemSz);
Value nxtMemSize = ADDI(curMemSz, c1);
builder.create<memref::StoreOp>(loc, sPHi, sPtrBuf, nxtMemSize);
- // curMemSize += 2
- curMemSz = ADDI(curMemSz, c2);
+ // curMemSize += kSliceIterWidth
+ curMemSz = ADDI(curMemSz, C_IDX(kSliceIterWidth));
});
Value isNonEmpty = result[0];
@@ -1947,7 +2062,7 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
// generates slice begin any more, instead we fall back to TACO-based
// algorithm to (co)iterate over the slice.
Value pLoPtr =
- genIndexLoad(builder, loc, slicePosBuffer[tid][lvl].back(), c1);
+ loadSlicePosPtr(builder, loc, slicePosBuffer[tid][lvl].back());
pLoPtr = ADDI(pLoPtr, c2);
Value pHiPtr = ADDI(pLoPtr, c1);
posits[tid][lvl] =
@@ -1999,10 +2114,10 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
Value sz = *(sliceSizes[tid][lvl].rbegin() + depth - 1);
bufSize = MULI(bufSize, sz);
}
- // For a pair of [pLo, pHi]. Note that we can not compress pHi because
- // slice creates segments in the index buffer so that the pHi for the
- // current level is no longer the pLo for the next level.
- bufSize = MULI(bufSize, c2);
+ // For a triple of [pLo, pHi, pPtr]. Note that we cannot compress pHi
+ // because slicing creates segments in the index buffer so that the pHi for
+ // the current level is no longer the pLo for the next level.
+ bufSize = MULI(bufSize, C_IDX(kSliceIterWidth));
// Additional two metadata {memSize, idx} at head.
bufSize = ADDI(bufSize, c2);
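+ // E.g., a level whose slice spans 3 tuples needs 3 * kSliceIterWidth + 2
+ // = 11 entries; cf. the memref<11xindex> (and memref<5xindex>) allocas in
+ // the updated test below.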
llvm::for_each(
@@ -2026,8 +2141,7 @@ void LoopEmitter::invalidateSliceIterIdx(OpBuilder &builder, Location loc,
TensorId tid, Level lvl) {
for (unsigned i = 0; i <= lvl; i++) {
if (!isDenseDLT(lvlTypes[tid][i]) && !dependentLvlMap[tid][i].empty()) {
- builder.create<memref::StoreOp>(loc, C_IDX(0),
- slicePosBuffer[tid][i].back(), C_IDX(1));
+ updateSlicePosPtr(builder, loc, slicePosBuffer[tid][i].back(), C_IDX(0));
}
}
}
@@ -2080,7 +2194,7 @@ void LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
YIELD(reduc);
// else /*minCrd == offset*/ {
- // for (i = 0; i < slicePos.size(); i+=2) {
+ // for (i = 0; i < slicePos.size(); i+=kSliceIterWidth) {
// if (crd[pos[slicePos[i]]] == minCrd) {
// slicePos[i]++;
// }
@@ -2096,7 +2210,7 @@ void LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
reduc[1] = constantI1(builder, loc, false); // isNonEmpty
auto loopArgs = static_cast<ValueRange>(reduc).drop_back();
auto forOp = scf::buildLoopNest(
- builder, loc, pSt, mSz, c2, loopArgs,
+ builder, loc, pSt, mSz, C_IDX(kSliceIterWidth), loopArgs,
[this, tid, lvl, c1, sPtrBuf,
&info](OpBuilder &builder, Location loc, ValueRange ivs,
ValueRange iterArgs) -> scf::ValueVector {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index 4f5100cc361041..315cc1d05e9266 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -470,6 +470,11 @@ class LoopEmitter {
return tid < lvlTypes.size() && lvl < lvlTypes[tid].size();
}
+ /// Forwards the (conceptual) "tree iterator" when iterating over a fully
+ /// reduced slice created by index-reduction.
+ void forwardsReducedSliceLevelTreeIt(OpBuilder &builder, Location loc,
+ TensorId tid, Level lvl, Value fcnt);
+
/// Prepares loop for iterating over `tensor[lvl]`, under the assumption
/// that `tensor[0...lvl-1]` loops have already been set up.
void prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index b767f61598ff74..b75cdba8449a1a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1484,7 +1484,16 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
env.merger().foreachTensorLoopId(
p, /*simple=*/true,
[&](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
- DimLevelType dlt, bool /*unused*/) {
+ DimLevelType dlt, bool isIdxRed) {
+ if (isIdxRed) {
+ // Since there is no 1:1 mapping from loop to level (multiple loops
+ // are required to resolve one level with a non-trivial index
+ // expression), we need to reconstruct the tensor level types if this
+ // loop requires an index reduction condition.
+ assert(lvl.has_value() && isUndefDLT(dlt));
+ auto stt = getSparseTensorType(env.op().getInputs()[tid]);
+ dlt = stt.getLvlType(*lvl);
+ }
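+ // E.g., a convolution input level addressed by d0 + d2 is only fully
+ // resolved once both loops are entered; until then the merger reports
+ // an undefined DLT for it.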
assert(ldx == env.merger().loop(b));
Value clause;
if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
@@ -1517,14 +1526,16 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
/// Generates end of true branch of if-statement within a while-loop.
static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
Operation *loop, Value redInput, Value cntInput,
- Value insInput) {
+ Value insInput, Value validIns) {
SmallVector<Value> operands;
if (env.isReduc()) {
operands.push_back(env.getReduc());
env.updateReduc(redInput);
- if (env.getValidLexInsert())
+ if (env.getValidLexInsert()) {
// Any overlapping indices during a reduction create a valid lex insert.
operands.push_back(constantI1(builder, env.op().getLoc(), true));
+ env.setValidLexInsert(validIns);
+ }
}
if (env.isExpand()) {
operands.push_back(env.getExpandCount());
@@ -1852,6 +1863,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
Value redInput = env.getReduc();
Value cntInput = env.getExpandCount();
Value insInput = env.getInsertionChain();
+ Value validIns = env.getValidLexInsert();
// NOTE: We cannot change this to `for (const LatPointId lj : env.set(lts))`
// because the loop body causes data-movement which invalidates the
// iterator.
@@ -1863,7 +1875,8 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
if (!isSingleCond) {
scf::IfOp ifOp = genIf(env, rewriter, idx, lj);
genStmt(env, rewriter, ej, at + 1);
- endIf(env, rewriter, ifOp, loop, redInput, cntInput, insInput);
+ endIf(env, rewriter, ifOp, loop, redInput, cntInput, insInput,
+ validIns);
} else {
genStmt(env, rewriter, ej, at + 1);
}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
index d1620125a43ed9..d2b8b6a9316ae1 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
@@ -5,17 +5,18 @@
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
#DCSR = #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>
+
// CHECK-LABEL: func.func @conv2d_all_sparse_CSR(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xi32, #sparse_tensor.encoding<{{.*}}>>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3xi32>)
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3xi32>) -> tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>> {
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant true
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant -2 : index
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[VAL_11:.*]] = arith.constant false
// CHECK-DAG: %[[VAL_12:.*]] = bufferization.alloc_tensor() : tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
@@ -24,241 +25,241 @@
// CHECK-DAG: %[[VAL_15:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
// CHECK-DAG: %[[VAL_16:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
// CHECK-DAG: %[[VAL_17:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xi32, #sparse_tensor.encoding<{{.*}}>> to memref<?xi32>
-// CHECK-DAG: %[[VAL_18:.*]] = memref.alloca() : memref<8xindex>
-// CHECK-DAG: %[[VAL_19:.*]] = memref.alloca() : memref<4xindex>
-// CHECK-DAG: %[[VAL_20:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_9]]] : memref<?xindex>
-// CHECK: memref.store %[[VAL_6]], %[[VAL_19]]{{\[}}%[[VAL_8]]] : memref<4xindex>
-// CHECK: memref.store %[[VAL_8]], %[[VAL_19]]{{\[}}%[[VAL_9]]] : memref<4xindex>
-// CHECK: memref.store %[[VAL_8]], %[[VAL_19]]{{\[}}%[[VAL_7]]] : memref<4xindex>
-// CHECK: memref.store %[[VAL_20]], %[[VAL_19]]{{\[}}%[[VAL_5]]] : memref<4xindex>
-// CHECK: %[[VAL_21:.*]] = arith.cmpi ugt, %[[VAL_20]], %[[VAL_8]] : index
-// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_8]]] : memref<?xindex>
+// CHECK-DAG: %[[VAL_18:.*]] = memref.alloca() : memref<11xindex>
+// CHECK-DAG: %[[VAL_19:.*]] = memref.alloca() : memref<5xindex>
+// CHECK-DAG: %[[VAL_20:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_7]]] : memref<?xindex>
+// CHECK: memref.store %[[VAL_9]], %[[VAL_19]]{{\[}}%[[VAL_6]]] : memref<5xindex>
+// CHECK: memref.store %[[VAL_6]], %[[VAL_19]]{{\[}}%[[VAL_7]]] : memref<5xindex>
+// CHECK: memref.store %[[VAL_6]], %[[VAL_19]]{{\[}}%[[VAL_8]]] : memref<5xindex>
+// CHECK: memref.store %[[VAL_20]], %[[VAL_19]]{{\[}}%[[VAL_5]]] : memref<5xindex>
+// CHECK: %[[VAL_21:.*]] = arith.cmpi ugt, %[[VAL_20]], %[[VAL_6]] : index
+// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_6]]] : memref<?xindex>
// CHECK: %[[VAL_23:.*]] = arith.cmpi uge, %[[VAL_22]], %[[VAL_5]] : index
// CHECK: %[[VAL_24:.*]] = arith.andi %[[VAL_21]], %[[VAL_23]] : i1
// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_22]], %[[VAL_3]] : index
-// CHECK: %[[VAL_26:.*]] = arith.select %[[VAL_24]], %[[VAL_25]], %[[VAL_8]] : index
+// CHECK: %[[VAL_26:.*]] = arith.select %[[VAL_24]], %[[VAL_25]], %[[VAL_6]] : index
// CHECK: %[[VAL_27:.*]]:3 = scf.while (%[[VAL_28:.*]] = %[[VAL_21]], %[[VAL_29:.*]] = %[[VAL_22]], %[[VAL_30:.*]] = %[[VAL_26]], %[[VAL_31:.*]] = %[[VAL_12]]) : (i1, index, index, tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>) -> (index, index, tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>) {
// CHECK: scf.condition(%[[VAL_28]]) %[[VAL_29]], %[[VAL_30]], %[[VAL_31]] : index, index, tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
// CHECK: } do {
// CHECK: ^bb0(%[[VAL_32:.*]]: index, %[[VAL_33:.*]]: index, %[[VAL_34:.*]]: tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>):
-// CHECK: %[[VAL_35:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_8]]] : memref<4xindex>
-// CHECK: %[[VAL_36:.*]]:3 = scf.for %[[VAL_37:.*]] = %[[VAL_7]] to %[[VAL_35]] step %[[VAL_7]] iter_args(%[[VAL_38:.*]] = %[[VAL_11]], %[[VAL_39:.*]] = %[[VAL_4]], %[[VAL_40:.*]] = %[[VAL_7]]) -> (i1, index, index) {
-// CHECK: %[[VAL_41:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_37]]] : memref<4xindex>
-// CHECK: %[[VAL_42:.*]] = arith.addi %[[VAL_37]], %[[VAL_9]] : index
-// CHECK: %[[VAL_43:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_42]]] : memref<4xindex>
-// CHECK: %[[VAL_44:.*]] = arith.addi %[[VAL_33]], %[[VAL_5]] : index
-// CHECK: %[[VAL_45:.*]]:4 = scf.while (%[[VAL_46:.*]] = %[[VAL_41]], %[[VAL_47:.*]] = %[[VAL_38]], %[[VAL_48:.*]] = %[[VAL_39]], %[[VAL_49:.*]] = %[[VAL_40]]) : (index, i1, index, index) -> (index, i1, index, index) {
-// CHECK: %[[VAL_50:.*]] = arith.cmpi ult, %[[VAL_46]], %[[VAL_43]] : index
-// CHECK: %[[VAL_51:.*]] = scf.if %[[VAL_50]] -> (i1) {
-// CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_46]]] : memref<?xindex>
-// CHECK: %[[VAL_53:.*]] = arith.cmpi ult, %[[VAL_52]], %[[VAL_44]] : index
-// CHECK: scf.yield %[[VAL_53]] : i1
+// CHECK: %[[VAL_35:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_6]]] : memref<5xindex>
+// CHECK: %[[VAL_36:.*]]:4 = scf.for %[[VAL_37:.*]] = %[[VAL_8]] to %[[VAL_35]] step %[[VAL_5]] iter_args(%[[VAL_38:.*]] = %[[VAL_11]], %[[VAL_39:.*]] = %[[VAL_4]], %[[VAL_40:.*]] = %[[VAL_8]], %[[VAL_41:.*]] = %[[VAL_6]]) -> (i1, index, index, index) {
+// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_37]]] : memref<5xindex>
+// CHECK: %[[VAL_43:.*]] = arith.addi %[[VAL_37]], %[[VAL_7]] : index
+// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_43]]] : memref<5xindex>
+// CHECK: %[[VAL_45:.*]] = arith.addi %[[VAL_37]], %[[VAL_8]] : index
+// CHECK: memref.store %[[VAL_41]], %[[VAL_19]]{{\[}}%[[VAL_45]]] : memref<5xindex>
+// CHECK: %[[VAL_46:.*]] = arith.addi %[[VAL_33]], %[[VAL_5]] : index
+// CHECK: %[[VAL_47:.*]]:5 = scf.while (%[[VAL_48:.*]] = %[[VAL_42]], %[[VAL_49:.*]] = %[[VAL_38]], %[[VAL_50:.*]] = %[[VAL_39]], %[[VAL_51:.*]] = %[[VAL_40]], %[[VAL_52:.*]] = %[[VAL_41]]) : (index, i1, index, index, index) -> (index, i1, index, index, index) {
+// CHECK: %[[VAL_53:.*]] = arith.cmpi ult, %[[VAL_48]], %[[VAL_44]] : index
+// CHECK: %[[VAL_54:.*]] = scf.if %[[VAL_53]] -> (i1) {
+// CHECK: %[[VAL_55:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_48]]] : memref<?xindex>
+// CHECK: %[[VAL_56:.*]] = arith.cmpi ult, %[[VAL_55]], %[[VAL_46]] : index
+// CHECK: scf.yield %[[VAL_56]] : i1
// CHECK: } else {
// CHECK: scf.yield %[[VAL_11]] : i1
// CHECK: }
-// CHECK: scf.condition(%[[VAL_54:.*]]) %[[VAL_46]], %[[VAL_47]], %[[VAL_48]], %[[VAL_49]] : index, i1, index, index
+// CHECK: scf.condition(%[[VAL_57:.*]]) %[[VAL_48]], %[[VAL_49]], %[[VAL_50]], %[[VAL_51]], %[[VAL_52]] : index, i1, index, index, index
// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_55:.*]]: index, %[[VAL_56:.*]]: i1, %[[VAL_57:.*]]: index, %[[VAL_58:.*]]: index):
-// CHECK: %[[VAL_59:.*]] = arith.addi %[[VAL_55]], %[[VAL_9]] : index
-// CHECK: %[[VAL_60:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_55]]] : memref<?xindex>
-// CHECK: %[[VAL_61:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_59]]] : memref<?xindex>
-// CHECK: %[[VAL_62:.*]] = arith.cmpi ult, %[[VAL_60]], %[[VAL_61]] : index
-// CHECK: %[[VAL_63:.*]] = arith.ori %[[VAL_62]], %[[VAL_56]] : i1
-// CHECK: %[[VAL_64:.*]] = scf.if %[[VAL_62]] -> (index) {
-// CHECK: %[[VAL_65:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_60]]] : memref<?xindex>
-// CHECK: %[[VAL_66:.*]] = arith.cmpi ult, %[[VAL_65]], %[[VAL_57]] : index
-// CHECK: %[[VAL_67:.*]] = arith.select %[[VAL_66]], %[[VAL_65]], %[[VAL_57]] : index
-// CHECK: scf.yield %[[VAL_67]] : index
+// CHECK: ^bb0(%[[VAL_58:.*]]: index, %[[VAL_59:.*]]: i1, %[[VAL_60:.*]]: index, %[[VAL_61:.*]]: index, %[[VAL_62:.*]]: index):
+// CHECK: %[[VAL_63:.*]] = arith.addi %[[VAL_58]], %[[VAL_7]] : index
+// CHECK: %[[VAL_64:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_58]]] : memref<?xindex>
+// CHECK: %[[VAL_65:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_63]]] : memref<?xindex>
+// CHECK: %[[VAL_66:.*]] = arith.cmpi ult, %[[VAL_64]], %[[VAL_65]] : index
+// CHECK: %[[VAL_67:.*]] = arith.ori %[[VAL_66]], %[[VAL_59]] : i1
+// CHECK: %[[VAL_68:.*]] = scf.if %[[VAL_66]] -> (index) {
+// CHECK: %[[VAL_69:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_64]]] : memref<?xindex>
+// CHECK: %[[VAL_70:.*]] = arith.cmpi ult, %[[VAL_69]], %[[VAL_60]] : index
+// CHECK: %[[VAL_71:.*]] = arith.select %[[VAL_70]], %[[VAL_69]], %[[VAL_60]] : index
+// CHECK: scf.yield %[[VAL_71]] : index
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_57]] : index
+// CHECK: scf.yield %[[VAL_60]] : index
// CHECK: }
-// CHECK: memref.store %[[VAL_60]], %[[VAL_18]]{{\[}}%[[VAL_58]]] : memref<8xindex>
-// CHECK: %[[VAL_68:.*]] = arith.addi %[[VAL_58]], %[[VAL_9]] : index
-// CHECK: memref.store %[[VAL_61]], %[[VAL_18]]{{\[}}%[[VAL_68]]] : memref<8xindex>
-// CHECK: %[[VAL_69:.*]] = arith.addi %[[VAL_58]], %[[VAL_7]] : index
-// CHECK: scf.yield %[[VAL_59]], %[[VAL_63]], %[[VAL_70:.*]], %[[VAL_69]] : index, i1, index, index
+// CHECK: memref.store %[[VAL_64]], %[[VAL_18]]{{\[}}%[[VAL_61]]] : memref<11xindex>
+// CHECK: %[[VAL_72:.*]] = arith.addi %[[VAL_61]], %[[VAL_7]] : index
+// CHECK: memref.store %[[VAL_65]], %[[VAL_18]]{{\[}}%[[VAL_72]]] : memref<11xindex>
+// CHECK: %[[VAL_73:.*]] = arith.addi %[[VAL_61]], %[[VAL_5]] : index
+// CHECK: %[[VAL_74:.*]] = arith.addi %[[VAL_62]], %[[VAL_7]] : index
+// CHECK: scf.yield %[[VAL_63]], %[[VAL_67]], %[[VAL_75:.*]], %[[VAL_73]], %[[VAL_74]] : index, i1, index, index, index
// CHECK: }
-// CHECK: scf.yield %[[VAL_71:.*]]#1, %[[VAL_71]]#2, %[[VAL_71]]#3 : i1, index, index
+// CHECK: scf.yield %[[VAL_76:.*]]#1, %[[VAL_76]]#2, %[[VAL_76]]#3, %[[VAL_76]]#4 : i1, index, index, index
// CHECK: }
-// CHECK: memref.store %[[VAL_72:.*]]#2, %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<8xindex>
-// CHECK: memref.store %[[VAL_8]], %[[VAL_18]]{{\[}}%[[VAL_9]]] : memref<8xindex>
-// CHECK: %[[VAL_73:.*]] = arith.cmpi uge, %[[VAL_72]]#1, %[[VAL_5]] : index
-// CHECK: %[[VAL_74:.*]] = arith.andi %[[VAL_72]]#0, %[[VAL_73]] : i1
-// CHECK: %[[VAL_75:.*]] = arith.addi %[[VAL_72]]#1, %[[VAL_3]] : index
-// CHECK: %[[VAL_76:.*]] = arith.select %[[VAL_74]], %[[VAL_75]], %[[VAL_8]] : index
-// CHECK: %[[VAL_77:.*]]:3 = scf.while (%[[VAL_78:.*]] = %[[VAL_72]]#0, %[[VAL_79:.*]] = %[[VAL_72]]#1, %[[VAL_80:.*]] = %[[VAL_76]], %[[VAL_81:.*]] = %[[VAL_34]]) : (i1, index, index, tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>) -> (index, index, tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>) {
-// CHECK: scf.condition(%[[VAL_78]]) %[[VAL_79]], %[[VAL_80]], %[[VAL_81]] : index, index, tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: memref.store %[[VAL_77:.*]]#2, %[[VAL_18]]{{\[}}%[[VAL_6]]] : memref<11xindex>
+// CHECK: memref.store %[[VAL_6]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
+// CHECK: %[[VAL_78:.*]] = arith.cmpi uge, %[[VAL_77]]#1, %[[VAL_5]] : index
+// CHECK: %[[VAL_79:.*]] = arith.andi %[[VAL_77]]#0, %[[VAL_78]] : i1
+// CHECK: %[[VAL_80:.*]] = arith.addi %[[VAL_77]]#1, %[[VAL_3]] : index
+// CHECK: %[[VAL_81:.*]] = arith.select %[[VAL_79]], %[[VAL_80]], %[[VAL_6]] : index
+// CHECK: %[[VAL_82:.*]]:3 = scf.while (%[[VAL_83:.*]] = %[[VAL_77]]#0, %[[VAL_84:.*]] = %[[VAL_77]]#1, %[[VAL_85:.*]] = %[[VAL_81]], %[[VAL_86:.*]] = %[[VAL_34]]) : (i1, index, index, tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>) -> (index, index, tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>) {
+// CHECK: scf.condition(%[[VAL_83]]) %[[VAL_84]], %[[VAL_85]], %[[VAL_86]] : index, index, tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_82:.*]]: index, %[[VAL_83:.*]]: index, %[[VAL_84:.*]]: tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>):
-// CHECK: %[[VAL_85:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_9]]] : memref<4xindex>
-// CHECK: %[[VAL_86:.*]] = arith.addi %[[VAL_85]], %[[VAL_7]] : index
-// CHECK: %[[VAL_87:.*]] = arith.addi %[[VAL_85]], %[[VAL_5]] : index
-// CHECK: %[[VAL_88:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_86]]] : memref<4xindex>
-// CHECK: %[[VAL_89:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_87]]] : memref<4xindex>
-// CHECK: %[[VAL_90:.*]] = arith.addi %[[VAL_33]], %[[VAL_5]] : index
-// CHECK: %[[VAL_91:.*]]:3 = scf.while (%[[VAL_92:.*]] = %[[VAL_88]], %[[VAL_93:.*]] = %[[VAL_10]], %[[VAL_94:.*]] = %[[VAL_11]]) : (index, i32, i1) -> (index, i32, i1) {
-// CHECK: %[[VAL_95:.*]] = arith.cmpi ult, %[[VAL_92]], %[[VAL_89]] : index
-// CHECK: %[[VAL_96:.*]] = scf.if %[[VAL_95]] -> (i1) {
-// CHECK: %[[VAL_97:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_92]]] : memref<?xindex>
-// CHECK: %[[VAL_98:.*]] = arith.cmpi ult, %[[VAL_97]], %[[VAL_90]] : index
-// CHECK: scf.yield %[[VAL_98]] : i1
+// CHECK: ^bb0(%[[VAL_87:.*]]: index, %[[VAL_88:.*]]: index, %[[VAL_89:.*]]: tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>):
+// CHECK: %[[VAL_90:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_7]]] : memref<5xindex>
+// CHECK: %[[VAL_91:.*]] = arith.addi %[[VAL_90]], %[[VAL_8]] : index
+// CHECK: %[[VAL_92:.*]] = arith.addi %[[VAL_90]], %[[VAL_5]] : index
+// CHECK: %[[VAL_93:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_91]]] : memref<5xindex>
+// CHECK: %[[VAL_94:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_92]]] : memref<5xindex>
+// CHECK: %[[VAL_95:.*]] = arith.addi %[[VAL_33]], %[[VAL_5]] : index
+// CHECK: %[[VAL_96:.*]]:3 = scf.while (%[[VAL_97:.*]] = %[[VAL_93]], %[[VAL_98:.*]] = %[[VAL_10]], %[[VAL_99:.*]] = %[[VAL_11]]) : (index, i32, i1) -> (index, i32, i1) {
+// CHECK: %[[VAL_100:.*]] = arith.cmpi ult, %[[VAL_97]], %[[VAL_94]] : index
+// CHECK: %[[VAL_101:.*]] = scf.if %[[VAL_100]] -> (i1) {
+// CHECK: %[[VAL_102:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_97]]] : memref<?xindex>
+// CHECK: %[[VAL_103:.*]] = arith.cmpi ult, %[[VAL_102]], %[[VAL_95]] : index
+// CHECK: scf.yield %[[VAL_103]] : i1
// CHECK: } else {
// CHECK: scf.yield %[[VAL_11]] : i1
// CHECK: }
-// CHECK: scf.condition(%[[VAL_99:.*]]) %[[VAL_92]], %[[VAL_93]], %[[VAL_94]] : index, i32, i1
+// CHECK: scf.condition(%[[VAL_104:.*]]) %[[VAL_97]], %[[VAL_98]], %[[VAL_99]] : index, i32, i1
// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_100:.*]]: index, %[[VAL_101:.*]]: i32, %[[VAL_102:.*]]: i1):
-// CHECK: %[[VAL_103:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_100]]] : memref<?xindex>
-// CHECK: %[[VAL_104:.*]] = arith.subi %[[VAL_103]], %[[VAL_33]] : index
-// CHECK: %[[VAL_105:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_9]]] : memref<8xindex>
-// CHECK: %[[VAL_106:.*]] = arith.addi %[[VAL_105]], %[[VAL_7]] : index
-// CHECK: %[[VAL_107:.*]] = arith.addi %[[VAL_105]], %[[VAL_5]] : index
-// CHECK: %[[VAL_108:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_106]]] : memref<8xindex>
-// CHECK: %[[VAL_109:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_107]]] : memref<8xindex>
-// CHECK: %[[VAL_110:.*]] = arith.addi %[[VAL_83]], %[[VAL_5]] : index
-// CHECK: %[[VAL_111:.*]]:2 = scf.while (%[[VAL_112:.*]] = %[[VAL_108]], %[[VAL_113:.*]] = %[[VAL_101]]) : (index, i32) -> (index, i32) {
-// CHECK: %[[VAL_114:.*]] = arith.cmpi ult, %[[VAL_112]], %[[VAL_109]] : index
-// CHECK: %[[VAL_115:.*]] = scf.if %[[VAL_114]] -> (i1) {
-// CHECK: %[[VAL_116:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_112]]] : memref<?xindex>
-// CHECK: %[[VAL_117:.*]] = arith.cmpi ult, %[[VAL_116]], %[[VAL_110]] : index
-// CHECK: scf.yield %[[VAL_117]] : i1
+// CHECK: ^bb0(%[[VAL_105:.*]]: index, %[[VAL_106:.*]]: i32, %[[VAL_107:.*]]: i1):
+// CHECK: %[[VAL_108:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_105]]] : memref<?xindex>
+// CHECK: %[[VAL_109:.*]] = arith.subi %[[VAL_108]], %[[VAL_33]] : index
+// CHECK: %[[VAL_110:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
+// CHECK: %[[VAL_111:.*]] = arith.addi %[[VAL_110]], %[[VAL_8]] : index
+// CHECK: %[[VAL_112:.*]] = arith.addi %[[VAL_110]], %[[VAL_5]] : index
+// CHECK: %[[VAL_113:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_111]]] : memref<11xindex>
+// CHECK: %[[VAL_114:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_112]]] : memref<11xindex>
+// CHECK: %[[VAL_115:.*]] = arith.addi %[[VAL_88]], %[[VAL_5]] : index
+// CHECK: %[[VAL_116:.*]]:2 = scf.while (%[[VAL_117:.*]] = %[[VAL_113]], %[[VAL_118:.*]] = %[[VAL_106]]) : (index, i32) -> (index, i32) {
+// CHECK: %[[VAL_119:.*]] = arith.cmpi ult, %[[VAL_117]], %[[VAL_114]] : index
+// CHECK: %[[VAL_120:.*]] = scf.if %[[VAL_119]] -> (i1) {
+// CHECK: %[[VAL_121:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_117]]] : memref<?xindex>
+// CHECK: %[[VAL_122:.*]] = arith.cmpi ult, %[[VAL_121]], %[[VAL_115]] : index
+// CHECK: scf.yield %[[VAL_122]] : i1
// CHECK: } else {
// CHECK: scf.yield %[[VAL_11]] : i1
// CHECK: }
-// CHECK: scf.condition(%[[VAL_118:.*]]) %[[VAL_112]], %[[VAL_113]] : index, i32
+// CHECK: scf.condition(%[[VAL_123:.*]]) %[[VAL_117]], %[[VAL_118]] : index, i32
// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_119:.*]]: index, %[[VAL_120:.*]]: i32):
-// CHECK: %[[VAL_121:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_119]]] : memref<?xindex>
-// CHECK: %[[VAL_122:.*]] = arith.subi %[[VAL_121]], %[[VAL_83]] : index
-// CHECK: %[[VAL_123:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_119]]] : memref<?xi32>
-// CHECK: %[[VAL_124:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_104]], %[[VAL_122]]] : tensor<3x3xi32>
-// CHECK: %[[VAL_125:.*]] = arith.muli %[[VAL_123]], %[[VAL_124]] : i32
-// CHECK: %[[VAL_126:.*]] = arith.addi %[[VAL_120]], %[[VAL_125]] : i32
-// CHECK: %[[VAL_127:.*]] = arith.addi %[[VAL_119]], %[[VAL_9]] : index
-// CHECK: scf.yield %[[VAL_127]], %[[VAL_126]] : index, i32
+// CHECK: ^bb0(%[[VAL_124:.*]]: index, %[[VAL_125:.*]]: i32):
+// CHECK: %[[VAL_126:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_124]]] : memref<?xindex>
+// CHECK: %[[VAL_127:.*]] = arith.subi %[[VAL_126]], %[[VAL_88]] : index
+// CHECK: %[[VAL_128:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_124]]] : memref<?xi32>
+// CHECK: %[[VAL_129:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_109]], %[[VAL_127]]] : tensor<3x3xi32>
+// CHECK: %[[VAL_130:.*]] = arith.muli %[[VAL_128]], %[[VAL_129]] : i32
+// CHECK: %[[VAL_131:.*]] = arith.addi %[[VAL_125]], %[[VAL_130]] : i32
+// CHECK: %[[VAL_132:.*]] = arith.addi %[[VAL_124]], %[[VAL_7]] : index
+// CHECK: scf.yield %[[VAL_132]], %[[VAL_131]] : index, i32
// CHECK: }
-// CHECK: %[[VAL_128:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_9]]] : memref<8xindex>
-// CHECK: %[[VAL_129:.*]] = arith.addi %[[VAL_128]], %[[VAL_7]] : index
-// CHECK: memref.store %[[VAL_129]], %[[VAL_18]]{{\[}}%[[VAL_9]]] : memref<8xindex>
-// CHECK: %[[VAL_130:.*]] = arith.addi %[[VAL_100]], %[[VAL_9]] : index
-// CHECK: scf.yield %[[VAL_130]], %[[VAL_131:.*]]#1, %[[VAL_2]] : index, i32, i1
+// CHECK: %[[VAL_133:.*]] = arith.addi %[[VAL_105]], %[[VAL_7]] : index
+// CHECK: %[[VAL_134:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
+// CHECK: %[[VAL_135:.*]] = arith.addi %[[VAL_134]], %[[VAL_5]] : index
+// CHECK: memref.store %[[VAL_135]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
+// CHECK: scf.yield %[[VAL_133]], %[[VAL_136:.*]]#1, %[[VAL_2]] : index, i32, i1
// CHECK: }
-// CHECK: %[[VAL_132:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_9]]] : memref<4xindex>
-// CHECK: %[[VAL_133:.*]] = arith.addi %[[VAL_132]], %[[VAL_7]] : index
-// CHECK: memref.store %[[VAL_133]], %[[VAL_19]]{{\[}}%[[VAL_9]]] : memref<4xindex>
-// CHECK: %[[VAL_134:.*]] = scf.if %[[VAL_135:.*]]#2 -> (tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>) {
-// CHECK: %[[VAL_136:.*]] = sparse_tensor.insert %[[VAL_135]]#1 into %[[VAL_84]]{{\[}}%[[VAL_33]], %[[VAL_83]]] : tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
-// CHECK: scf.yield %[[VAL_136]] : tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: %[[VAL_137:.*]] = scf.if %[[VAL_138:.*]]#2 -> (tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>) {
+// CHECK: %[[VAL_139:.*]] = sparse_tensor.insert %[[VAL_138]]#1 into %[[VAL_89]]{{\[}}%[[VAL_33]], %[[VAL_88]]] : tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: scf.yield %[[VAL_139]] : tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_84]] : tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: scf.yield %[[VAL_89]] : tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
// CHECK: }
-// CHECK: memref.store %[[VAL_8]], %[[VAL_19]]{{\[}}%[[VAL_9]]] : memref<4xindex>
-// CHECK: memref.store %[[VAL_8]], %[[VAL_18]]{{\[}}%[[VAL_9]]] : memref<8xindex>
-// CHECK: %[[VAL_137:.*]] = arith.cmpi ugt, %[[VAL_82]], %[[VAL_83]] : index
-// CHECK: %[[VAL_138:.*]]:3 = scf.if %[[VAL_137]] -> (index, i1, index) {
-// CHECK: %[[VAL_139:.*]] = arith.addi %[[VAL_83]], %[[VAL_9]] : index
-// CHECK: scf.yield %[[VAL_82]], %[[VAL_2]], %[[VAL_139]] : index, i1, index
+// CHECK: memref.store %[[VAL_6]], %[[VAL_19]]{{\[}}%[[VAL_7]]] : memref<5xindex>
+// CHECK: memref.store %[[VAL_6]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
+// CHECK: %[[VAL_140:.*]] = arith.cmpi ugt, %[[VAL_87]], %[[VAL_88]] : index
+// CHECK: %[[VAL_141:.*]]:3 = scf.if %[[VAL_140]] -> (index, i1, index) {
+// CHECK: %[[VAL_142:.*]] = arith.addi %[[VAL_88]], %[[VAL_7]] : index
+// CHECK: scf.yield %[[VAL_87]], %[[VAL_2]], %[[VAL_142]] : index, i1, index
// CHECK: } else {
-// CHECK: %[[VAL_140:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<8xindex>
-// CHECK: %[[VAL_141:.*]]:2 = scf.for %[[VAL_142:.*]] = %[[VAL_7]] to %[[VAL_140]] step %[[VAL_7]] iter_args(%[[VAL_143:.*]] = %[[VAL_4]], %[[VAL_144:.*]] = %[[VAL_11]]) -> (index, i1) {
-// CHECK: %[[VAL_145:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_142]]] : memref<8xindex>
-// CHECK: %[[VAL_146:.*]] = arith.addi %[[VAL_142]], %[[VAL_9]] : index
-// CHECK: %[[VAL_147:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_146]]] : memref<8xindex>
-// CHECK: %[[VAL_148:.*]] = arith.cmpi ult, %[[VAL_145]], %[[VAL_147]] : index
-// CHECK: %[[VAL_149:.*]] = scf.if %[[VAL_148]] -> (index) {
-// CHECK: %[[VAL_150:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_145]]] : memref<?xindex>
-// CHECK: %[[VAL_151:.*]] = arith.cmpi eq, %[[VAL_150]], %[[VAL_82]] : index
-// CHECK: %[[VAL_152:.*]] = scf.if %[[VAL_151]] -> (index) {
-// CHECK: %[[VAL_153:.*]] = arith.addi %[[VAL_145]], %[[VAL_9]] : index
-// CHECK: memref.store %[[VAL_153]], %[[VAL_18]]{{\[}}%[[VAL_142]]] : memref<8xindex>
-// CHECK: scf.yield %[[VAL_153]] : index
+// CHECK: %[[VAL_143:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_6]]] : memref<11xindex>
+// CHECK: %[[VAL_144:.*]]:2 = scf.for %[[VAL_145:.*]] = %[[VAL_8]] to %[[VAL_143]] step %[[VAL_5]] iter_args(%[[VAL_146:.*]] = %[[VAL_4]], %[[VAL_147:.*]] = %[[VAL_11]]) -> (index, i1) {
+// CHECK: %[[VAL_148:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_145]]] : memref<11xindex>
+// CHECK: %[[VAL_149:.*]] = arith.addi %[[VAL_145]], %[[VAL_7]] : index
+// CHECK: %[[VAL_150:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_149]]] : memref<11xindex>
+// CHECK: %[[VAL_151:.*]] = arith.cmpi ult, %[[VAL_148]], %[[VAL_150]] : index
+// CHECK: %[[VAL_152:.*]] = scf.if %[[VAL_151]] -> (index) {
+// CHECK: %[[VAL_153:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_148]]] : memref<?xindex>
+// CHECK: %[[VAL_154:.*]] = arith.cmpi eq, %[[VAL_153]], %[[VAL_87]] : index
+// CHECK: %[[VAL_155:.*]] = scf.if %[[VAL_154]] -> (index) {
+// CHECK: %[[VAL_156:.*]] = arith.addi %[[VAL_148]], %[[VAL_7]] : index
+// CHECK: memref.store %[[VAL_156]], %[[VAL_18]]{{\[}}%[[VAL_145]]] : memref<11xindex>
+// CHECK: scf.yield %[[VAL_156]] : index
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_145]] : index
+// CHECK: scf.yield %[[VAL_148]] : index
// CHECK: }
-// CHECK: scf.yield %[[VAL_154:.*]] : index
+// CHECK: scf.yield %[[VAL_157:.*]] : index
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_145]] : index
+// CHECK: scf.yield %[[VAL_148]] : index
// CHECK: }
-// CHECK: %[[VAL_155:.*]] = arith.cmpi ult, %[[VAL_156:.*]], %[[VAL_147]] : index
-// CHECK: %[[VAL_157:.*]] = scf.if %[[VAL_155]] -> (index) {
-// CHECK: %[[VAL_158:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_156]]] : memref<?xindex>
-// CHECK: scf.yield %[[VAL_158]] : index
+// CHECK: %[[VAL_158:.*]] = arith.cmpi ult, %[[VAL_159:.*]], %[[VAL_150]] : index
+// CHECK: %[[VAL_160:.*]] = scf.if %[[VAL_158]] -> (index) {
+// CHECK: %[[VAL_161:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_159]]] : memref<?xindex>
+// CHECK: scf.yield %[[VAL_161]] : index
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_143]] : index
+// CHECK: scf.yield %[[VAL_146]] : index
// CHECK: }
-// CHECK: %[[VAL_159:.*]] = arith.ori %[[VAL_155]], %[[VAL_144]] : i1
-// CHECK: %[[VAL_160:.*]] = arith.cmpi ult, %[[VAL_161:.*]], %[[VAL_143]] : index
-// CHECK: %[[VAL_162:.*]] = arith.select %[[VAL_160]], %[[VAL_161]], %[[VAL_143]] : index
-// CHECK: scf.yield %[[VAL_162]], %[[VAL_159]] : index, i1
+// CHECK: %[[VAL_162:.*]] = arith.ori %[[VAL_158]], %[[VAL_147]] : i1
+// CHECK: %[[VAL_163:.*]] = arith.cmpi ult, %[[VAL_164:.*]], %[[VAL_146]] : index
+// CHECK: %[[VAL_165:.*]] = arith.select %[[VAL_163]], %[[VAL_164]], %[[VAL_146]] : index
+// CHECK: scf.yield %[[VAL_165]], %[[VAL_162]] : index, i1
// CHECK: }
-// CHECK: %[[VAL_163:.*]] = arith.addi %[[VAL_164:.*]]#0, %[[VAL_9]] : index
-// CHECK: %[[VAL_165:.*]] = arith.addi %[[VAL_164]]#0, %[[VAL_3]] : index
-// CHECK: %[[VAL_166:.*]] = arith.cmpi uge, %[[VAL_163]], %[[VAL_5]] : index
-// CHECK: %[[VAL_167:.*]] = arith.select %[[VAL_166]], %[[VAL_165]], %[[VAL_8]] : index
-// CHECK: scf.yield %[[VAL_164]]#0, %[[VAL_164]]#1, %[[VAL_167]] : index, i1, index
+// CHECK: %[[VAL_166:.*]] = arith.addi %[[VAL_167:.*]]#0, %[[VAL_7]] : index
+// CHECK: %[[VAL_168:.*]] = arith.addi %[[VAL_167]]#0, %[[VAL_3]] : index
+// CHECK: %[[VAL_169:.*]] = arith.cmpi uge, %[[VAL_166]], %[[VAL_5]] : index
+// CHECK: %[[VAL_170:.*]] = arith.select %[[VAL_169]], %[[VAL_168]], %[[VAL_6]] : index
+// CHECK: scf.yield %[[VAL_167]]#0, %[[VAL_167]]#1, %[[VAL_170]] : index, i1, index
// CHECK: }
-// CHECK: %[[VAL_168:.*]] = arith.addi %[[VAL_83]], %[[VAL_9]] : index
-// CHECK: %[[VAL_169:.*]] = arith.cmpi ugt, %[[VAL_170:.*]]#2, %[[VAL_168]] : index
-// CHECK: %[[VAL_171:.*]] = arith.select %[[VAL_169]], %[[VAL_170]]#2, %[[VAL_168]] : index
-// CHECK: %[[VAL_172:.*]] = arith.addi %[[VAL_171]], %[[VAL_5]] : index
-// CHECK: %[[VAL_173:.*]] = arith.cmpi ule, %[[VAL_172]], %[[VAL_4]] : index
-// CHECK: %[[VAL_174:.*]] = arith.andi %[[VAL_170]]#1, %[[VAL_173]] : i1
-// CHECK: scf.yield %[[VAL_174]], %[[VAL_170]]#0, %[[VAL_171]], %[[VAL_175:.*]] : i1, index, index, tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: %[[VAL_171:.*]] = arith.addi %[[VAL_88]], %[[VAL_7]] : index
+// CHECK: %[[VAL_172:.*]] = arith.cmpi ugt, %[[VAL_173:.*]]#2, %[[VAL_171]] : index
+// CHECK: %[[VAL_174:.*]] = arith.select %[[VAL_172]], %[[VAL_173]]#2, %[[VAL_171]] : index
+// CHECK: %[[VAL_175:.*]] = arith.addi %[[VAL_174]], %[[VAL_5]] : index
+// CHECK: %[[VAL_176:.*]] = arith.cmpi ule, %[[VAL_175]], %[[VAL_4]] : index
+// CHECK: %[[VAL_177:.*]] = arith.andi %[[VAL_173]]#1, %[[VAL_176]] : i1
+// CHECK: scf.yield %[[VAL_177]], %[[VAL_173]]#0, %[[VAL_174]], %[[VAL_178:.*]] : i1, index, index, tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
// CHECK: }
-// CHECK: memref.store %[[VAL_8]], %[[VAL_19]]{{\[}}%[[VAL_9]]] : memref<4xindex>
-// CHECK: %[[VAL_176:.*]] = arith.cmpi ugt, %[[VAL_32]], %[[VAL_33]] : index
-// CHECK: %[[VAL_177:.*]]:3 = scf.if %[[VAL_176]] -> (index, i1, index) {
-// CHECK: %[[VAL_178:.*]] = arith.addi %[[VAL_33]], %[[VAL_9]] : index
-// CHECK: scf.yield %[[VAL_32]], %[[VAL_2]], %[[VAL_178]] : index, i1, index
+// CHECK: memref.store %[[VAL_6]], %[[VAL_19]]{{\[}}%[[VAL_7]]] : memref<5xindex>
+// CHECK: %[[VAL_179:.*]] = arith.cmpi ugt, %[[VAL_32]], %[[VAL_33]] : index
+// CHECK: %[[VAL_180:.*]]:3 = scf.if %[[VAL_179]] -> (index, i1, index) {
+// CHECK: %[[VAL_181:.*]] = arith.addi %[[VAL_33]], %[[VAL_7]] : index
+// CHECK: scf.yield %[[VAL_32]], %[[VAL_2]], %[[VAL_181]] : index, i1, index
// CHECK: } else {
-// CHECK: %[[VAL_179:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_8]]] : memref<4xindex>
-// CHECK: %[[VAL_180:.*]]:2 = scf.for %[[VAL_181:.*]] = %[[VAL_7]] to %[[VAL_179]] step %[[VAL_7]] iter_args(%[[VAL_182:.*]] = %[[VAL_4]], %[[VAL_183:.*]] = %[[VAL_11]]) -> (index, i1) {
-// CHECK: %[[VAL_184:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_181]]] : memref<4xindex>
-// CHECK: %[[VAL_185:.*]] = arith.addi %[[VAL_181]], %[[VAL_9]] : index
-// CHECK: %[[VAL_186:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_185]]] : memref<4xindex>
-// CHECK: %[[VAL_187:.*]] = arith.cmpi ult, %[[VAL_184]], %[[VAL_186]] : index
-// CHECK: %[[VAL_188:.*]] = scf.if %[[VAL_187]] -> (index) {
-// CHECK: %[[VAL_189:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_184]]] : memref<?xindex>
-// CHECK: %[[VAL_190:.*]] = arith.cmpi eq, %[[VAL_189]], %[[VAL_32]] : index
-// CHECK: %[[VAL_191:.*]] = scf.if %[[VAL_190]] -> (index) {
-// CHECK: %[[VAL_192:.*]] = arith.addi %[[VAL_184]], %[[VAL_9]] : index
-// CHECK: memref.store %[[VAL_192]], %[[VAL_19]]{{\[}}%[[VAL_181]]] : memref<4xindex>
-// CHECK: scf.yield %[[VAL_192]] : index
+// CHECK: %[[VAL_182:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_6]]] : memref<5xindex>
+// CHECK: %[[VAL_183:.*]]:2 = scf.for %[[VAL_184:.*]] = %[[VAL_8]] to %[[VAL_182]] step %[[VAL_5]] iter_args(%[[VAL_185:.*]] = %[[VAL_4]], %[[VAL_186:.*]] = %[[VAL_11]]) -> (index, i1) {
+// CHECK: %[[VAL_187:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_184]]] : memref<5xindex>
+// CHECK: %[[VAL_188:.*]] = arith.addi %[[VAL_184]], %[[VAL_7]] : index
+// CHECK: %[[VAL_189:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_188]]] : memref<5xindex>
+// CHECK: %[[VAL_190:.*]] = arith.cmpi ult, %[[VAL_187]], %[[VAL_189]] : index
+// CHECK: %[[VAL_191:.*]] = scf.if %[[VAL_190]] -> (index) {
+// CHECK: %[[VAL_192:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_187]]] : memref<?xindex>
+// CHECK: %[[VAL_193:.*]] = arith.cmpi eq, %[[VAL_192]], %[[VAL_32]] : index
+// CHECK: %[[VAL_194:.*]] = scf.if %[[VAL_193]] -> (index) {
+// CHECK: %[[VAL_195:.*]] = arith.addi %[[VAL_187]], %[[VAL_7]] : index
+// CHECK: memref.store %[[VAL_195]], %[[VAL_19]]{{\[}}%[[VAL_184]]] : memref<5xindex>
+// CHECK: scf.yield %[[VAL_195]] : index
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_184]] : index
+// CHECK: scf.yield %[[VAL_187]] : index
// CHECK: }
-// CHECK: scf.yield %[[VAL_193:.*]] : index
+// CHECK: scf.yield %[[VAL_196:.*]] : index
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_184]] : index
+// CHECK: scf.yield %[[VAL_187]] : index
// CHECK: }
-// CHECK: %[[VAL_194:.*]] = arith.cmpi ult, %[[VAL_195:.*]], %[[VAL_186]] : index
-// CHECK: %[[VAL_196:.*]] = scf.if %[[VAL_194]] -> (index) {
-// CHECK: %[[VAL_197:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_195]]] : memref<?xindex>
-// CHECK: scf.yield %[[VAL_197]] : index
+// CHECK: %[[VAL_197:.*]] = arith.cmpi ult, %[[VAL_198:.*]], %[[VAL_189]] : index
+// CHECK: %[[VAL_199:.*]] = scf.if %[[VAL_197]] -> (index) {
+// CHECK: %[[VAL_200:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_198]]] : memref<?xindex>
+// CHECK: scf.yield %[[VAL_200]] : index
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_182]] : index
+// CHECK: scf.yield %[[VAL_185]] : index
// CHECK: }
-// CHECK: %[[VAL_198:.*]] = arith.ori %[[VAL_194]], %[[VAL_183]] : i1
-// CHECK: %[[VAL_199:.*]] = arith.cmpi ult, %[[VAL_200:.*]], %[[VAL_182]] : index
-// CHECK: %[[VAL_201:.*]] = arith.select %[[VAL_199]], %[[VAL_200]], %[[VAL_182]] : index
-// CHECK: scf.yield %[[VAL_201]], %[[VAL_198]] : index, i1
+// CHECK: %[[VAL_201:.*]] = arith.ori %[[VAL_197]], %[[VAL_186]] : i1
+// CHECK: %[[VAL_202:.*]] = arith.cmpi ult, %[[VAL_203:.*]], %[[VAL_185]] : index
+// CHECK: %[[VAL_204:.*]] = arith.select %[[VAL_202]], %[[VAL_203]], %[[VAL_185]] : index
+// CHECK: scf.yield %[[VAL_204]], %[[VAL_201]] : index, i1
// CHECK: }
-// CHECK: %[[VAL_202:.*]] = arith.addi %[[VAL_203:.*]]#0, %[[VAL_9]] : index
-// CHECK: %[[VAL_204:.*]] = arith.addi %[[VAL_203]]#0, %[[VAL_3]] : index
-// CHECK: %[[VAL_205:.*]] = arith.cmpi uge, %[[VAL_202]], %[[VAL_5]] : index
-// CHECK: %[[VAL_206:.*]] = arith.select %[[VAL_205]], %[[VAL_204]], %[[VAL_8]] : index
-// CHECK: scf.yield %[[VAL_203]]#0, %[[VAL_203]]#1, %[[VAL_206]] : index, i1, index
+// CHECK: %[[VAL_205:.*]] = arith.addi %[[VAL_206:.*]]#0, %[[VAL_7]] : index
+// CHECK: %[[VAL_207:.*]] = arith.addi %[[VAL_206]]#0, %[[VAL_3]] : index
+// CHECK: %[[VAL_208:.*]] = arith.cmpi uge, %[[VAL_205]], %[[VAL_5]] : index
+// CHECK: %[[VAL_209:.*]] = arith.select %[[VAL_208]], %[[VAL_207]], %[[VAL_6]] : index
+// CHECK: scf.yield %[[VAL_206]]#0, %[[VAL_206]]#1, %[[VAL_209]] : index, i1, index
// CHECK: }
-// CHECK: %[[VAL_207:.*]] = arith.addi %[[VAL_33]], %[[VAL_9]] : index
-// CHECK: %[[VAL_208:.*]] = arith.cmpi ugt, %[[VAL_209:.*]]#2, %[[VAL_207]] : index
-// CHECK: %[[VAL_210:.*]] = arith.select %[[VAL_208]], %[[VAL_209]]#2, %[[VAL_207]] : index
-// CHECK: %[[VAL_211:.*]] = arith.addi %[[VAL_210]], %[[VAL_5]] : index
-// CHECK: %[[VAL_212:.*]] = arith.cmpi ule, %[[VAL_211]], %[[VAL_4]] : index
-// CHECK: %[[VAL_213:.*]] = arith.andi %[[VAL_209]]#1, %[[VAL_212]] : i1
-// CHECK: scf.yield %[[VAL_213]], %[[VAL_209]]#0, %[[VAL_210]], %[[VAL_214:.*]]#2 : i1, index, index, tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: %[[VAL_210:.*]] = arith.addi %[[VAL_33]], %[[VAL_7]] : index
+// CHECK: %[[VAL_211:.*]] = arith.cmpi ugt, %[[VAL_212:.*]]#2, %[[VAL_210]] : index
+// CHECK: %[[VAL_213:.*]] = arith.select %[[VAL_211]], %[[VAL_212]]#2, %[[VAL_210]] : index
+// CHECK: %[[VAL_214:.*]] = arith.addi %[[VAL_213]], %[[VAL_5]] : index
+// CHECK: %[[VAL_215:.*]] = arith.cmpi ule, %[[VAL_214]], %[[VAL_4]] : index
+// CHECK: %[[VAL_216:.*]] = arith.andi %[[VAL_212]]#1, %[[VAL_215]] : i1
+// CHECK: scf.yield %[[VAL_216]], %[[VAL_212]]#0, %[[VAL_213]], %[[VAL_217:.*]]#2 : i1, index, index, tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
// CHECK: }
-// CHECK: %[[VAL_215:.*]] = sparse_tensor.load %[[VAL_216:.*]]#2 hasInserts : tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
-// CHECK: return %[[VAL_215]] : tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: %[[VAL_218:.*]] = sparse_tensor.load %[[VAL_219:.*]]#2 hasInserts : tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: return %[[VAL_218]] : tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
// CHECK: }
func.func @conv2d_all_sparse_CSR(%arg0: tensor<8x8xi32, #DCSR>,
%arg1: tensor<3x3xi32>) -> tensor<6x6xi32, #DCSR> {
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/dual_sparse_conv_2d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/dual_sparse_conv_2d.mlir
new file mode 100644
index 00000000000000..713c644fd4bc03
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/dual_sparse_conv_2d.mlir
@@ -0,0 +1,231 @@
+// DEFINE: %{option} = "enable-runtime-library=true enable-index-reduction=true"
+// DEFINE: %{compile} = mlir-opt %s --sparse-compiler=%{option}
+// DEFINE: %{run} = mlir-cpu-runner \
+// DEFINE: -e entry -entry-point-result=void \
+// DEFINE: -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils | \
+// DEFINE: FileCheck %s
+//
+// RUN: %{compile} | %{run}
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{option} = "enable-runtime-library=false enable-index-reduction=true"
+// RUN: %{compile} | %{run}
+//
+// Do the same run, but now with direct IR generation and vectorization.
+// REDEFINE: %{option} = "enable-runtime-library=false enable-index-reduction=true vl=2 reassociate-fp-reductions=true enable-index-optimizations=true"
+// RUN: %{compile} | %{run}
+
+// Do the same run, but now with direct IR generation and, if available, VLA
+// vectorization.
+// REDEFINE: %{option} = "enable-runtime-library=false vl=4 enable-index-reduction=true enable-arm-sve=%ENABLE_VLA"
+// REDEFINE: %{run} = %lli_host_or_aarch64_cmd \
+// REDEFINE: --entry-function=entry_lli \
+// REDEFINE: --extra-module=%S/Inputs/main_for_lli.ll \
+// REDEFINE: %VLA_ARCH_ATTR_OPTIONS \
+// REDEFINE: --dlopen=%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext | \
+// REDEFINE: FileCheck %s
+// RUN: %{compile} | mlir-translate -mlir-to-llvmir | %{run}
+
+#DCSR = #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>
+#CSR = #sparse_tensor.encoding<{lvlTypes = ["dense", "compressed"]}>
+#CDR = #sparse_tensor.encoding<{lvlTypes = ["compressed", "dense"]}>
+#CSC = #sparse_tensor.encoding<{
+ lvlTypes = [ "dense", "compressed" ],
+ dimToLvl = affine_map<(i,j) -> (j,i)>
+}>
+
+// An example of a 2D convolution with both a sparse input and a sparse filter.
+module {
+
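+  // Dense reference kernel; the sparse variants below should produce the
+  // same result (compared via the CHECK lines in @entry).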
+ func.func @conv2d(%input: tensor<8x8xi32>,
+ %filter: tensor<3x3xi32>,
+ %output: tensor<6x6xi32>) -> tensor<6x6xi32> {
+ %0 = linalg.conv_2d
+ ins (%input, %filter: tensor<8x8xi32>, tensor<3x3xi32>)
+ outs (%output: tensor<6x6xi32>) -> tensor<6x6xi32>
+ return %0 : tensor<6x6xi32>
+ }
+
+ func.func @conv2d_all_sparse_DCSR(%input: tensor<8x8xi32, #DCSR>,
+ %filter: tensor<3x3xi32, #DCSR>) -> tensor<6x6xi32, #DCSR> {
+ %s = bufferization.alloc_tensor() : tensor<6x6xi32, #DCSR>
+ %0 = linalg.conv_2d
+ ins (%input, %filter: tensor<8x8xi32, #DCSR>, tensor<3x3xi32, #DCSR>)
+ outs (%s: tensor<6x6xi32, #DCSR>) -> tensor<6x6xi32, #DCSR>
+ return %0 : tensor<6x6xi32, #DCSR>
+ }
+
+ func.func @conv2d_all_sparse_CSR(%input: tensor<8x8xi32, #CSR>,
+ %filter: tensor<3x3xi32, #CSR>) -> tensor<6x6xi32, #CSR> {
+ %s = bufferization.alloc_tensor() : tensor<6x6xi32, #CSR>
+ %0 = linalg.conv_2d
+ ins (%input, %filter: tensor<8x8xi32, #CSR>, tensor<3x3xi32, #CSR>)
+ outs (%s: tensor<6x6xi32, #CSR>) -> tensor<6x6xi32, #CSR>
+ return %0 : tensor<6x6xi32, #CSR>
+ }
+
+ func.func @conv2d_all_sparse_CD(%input: tensor<8x8xi32, #CDR>,
+ %filter: tensor<3x3xi32, #CDR>) -> tensor<6x6xi32, #CDR> {
+ %s = bufferization.alloc_tensor() : tensor<6x6xi32, #CDR>
+ %0 = linalg.conv_2d
+ ins (%input, %filter: tensor<8x8xi32, #CDR>, tensor<3x3xi32, #CDR>)
+ outs (%s: tensor<6x6xi32, #CDR>) -> tensor<6x6xi32, #CDR>
+ return %0 : tensor<6x6xi32, #CDR>
+ }
+
+ func.func @conv2d_all_sparse_CSC(%input: tensor<8x8xi32, #CSC>,
+ %filter: tensor<3x3xi32, #CSC>) -> tensor<6x6xi32, #CSC> {
+ %s = bufferization.alloc_tensor() : tensor<6x6xi32, #CSC>
+ %0 = linalg.conv_2d
+ ins (%input, %filter: tensor<8x8xi32, #CSC>, tensor<3x3xi32, #CSC>)
+ outs (%s: tensor<6x6xi32, #CSC>) -> tensor<6x6xi32, #CSC>
+ return %0 : tensor<6x6xi32, #CSC>
+ }
+
+ func.func @entry() {
+ %c0 = arith.constant 0 : index
+ %i0 = arith.constant 0 : i32
+
+ // A typical edge detection filter.
+ %filter = arith.constant dense<[
+ [ 1, 0, -1 ],
+ [ 0, 0, 0 ],
+ [ -1, 0, 1 ]
+ ]> : tensor<3x3xi32>
+ %sparse_filter_DCSR = sparse_tensor.convert %filter
+ : tensor<3x3xi32> to tensor<3x3xi32, #DCSR>
+ %sparse_filter_CSR = sparse_tensor.convert %filter
+ : tensor<3x3xi32> to tensor<3x3xi32, #CSR>
+ %sparse_filter_CD = sparse_tensor.convert %filter
+ : tensor<3x3xi32> to tensor<3x3xi32, #CDR>
+ %sparse_filter_CSC = sparse_tensor.convert %filter
+ : tensor<3x3xi32> to tensor<3x3xi32, #CSC>
+
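+    // An 8x8 input whose middle columns are mostly zero, so the sparse
+    // conversions below actually drop entries.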
+ %input = arith.constant dense<[
+ [ 1, 2, 3, 4, 0, 6, 7, 8 ],
+ [ 2, 2, 4, 4, 0, 0, 6, 8 ],
+ [ 2, 2, 4, 4, 0, 0, 6, 8 ],
+ [ 2, 2, 3, 4, 0, 0, 7, 8 ],
+ [ 1, 3, 3, 4, 0, 0, 6, 8 ],
+ [ 3, 2, 3, 4, 0, 0, 7, 8 ],
+ [ 1, 3, 3, 4, 3, 6, 6, 8 ],
+ [ 1, 3, 3, 4, 3, 0, 7, 8 ]
+ ]> : tensor<8x8xi32>
+ %sparse_input_DCSR = sparse_tensor.convert %input
+ : tensor<8x8xi32> to tensor<8x8xi32, #DCSR>
+ %sparse_input_CSR = sparse_tensor.convert %input
+ : tensor<8x8xi32> to tensor<8x8xi32, #CSR>
+ %sparse_input_CD = sparse_tensor.convert %input
+ : tensor<8x8xi32> to tensor<8x8xi32, #CDR>
+ %sparse_input_CSC = sparse_tensor.convert %input
+ : tensor<8x8xi32> to tensor<8x8xi32, #CSC>
+
+ // Call the kernel.
+ %output = arith.constant dense<0> : tensor<6x6xi32>
+ %0 = call @conv2d(%input, %filter, %output)
+ : (tensor<8x8xi32>,
+ tensor<3x3xi32>, tensor<6x6xi32>) -> tensor<6x6xi32>
+ %2 = call @conv2d_all_sparse_DCSR(%sparse_input_DCSR, %sparse_filter_DCSR)
+ : (tensor<8x8xi32, #DCSR>,
+ tensor<3x3xi32, #DCSR>) -> tensor<6x6xi32, #DCSR>
+ %3 = call @conv2d_all_sparse_CSR(%sparse_input_CSR, %sparse_filter_CSR)
+ : (tensor<8x8xi32, #CSR>,
+ tensor<3x3xi32, #CSR>) -> tensor<6x6xi32, #CSR>
+ %4 = call @conv2d_all_sparse_CD(%sparse_input_CD, %sparse_filter_CD)
+ : (tensor<8x8xi32, #CDR>,
+ tensor<3x3xi32, #CDR>) -> tensor<6x6xi32, #CDR>
+ %5 = call @conv2d_all_sparse_CSC(%sparse_input_CSC, %sparse_filter_CSC)
+ : (tensor<8x8xi32, #CSC>,
+ tensor<3x3xi32, #CSC>) -> tensor<6x6xi32, #CSC>
+
+    // Verify the dense reference output.
+ //
+ // CHECK: ( ( 0, 0, -1, -6, -1, 6 ),
+ // CHECK-SAME: ( -1, 0, 1, 0, 1, 0 ),
+ // CHECK-SAME: ( 0, -1, 1, 0, 0, 0 ),
+ // CHECK-SAME: ( -1, 0, 0, 0, 0, 0 ),
+ // CHECK-SAME: ( 0, 0, 3, 6, -3, -6 ),
+ // CHECK-SAME: ( 2, -1, 3, 0, -3, 0 ) )
+ //
+ %v = vector.transfer_read %0[%c0, %c0], %i0
+ : tensor<6x6xi32>, vector<6x6xi32>
+ vector.print %v : vector<6x6xi32>
+
+ //
+    // Should be the same as the dense output (DCSR).
+ // CHECK: ( ( 0, 0, -1, -6, -1, 6 ),
+ // CHECK-SAME: ( -1, 0, 1, 0, 1, 0 ),
+ // CHECK-SAME: ( 0, -1, 1, 0, 0, 0 ),
+ // CHECK-SAME: ( -1, 0, 0, 0, 0, 0 ),
+ // CHECK-SAME: ( 0, 0, 3, 6, -3, -6 ),
+ // CHECK-SAME: ( 2, -1, 3, 0, -3, 0 ) )
+ //
+ %all_sparse_DCSR = sparse_tensor.convert %2
+ : tensor<6x6xi32, #DCSR> to tensor<6x6xi32>
+ %v2 = vector.transfer_read %all_sparse_DCSR[%c0, %c0], %i0
+ : tensor<6x6xi32>, vector<6x6xi32>
+ vector.print %v2 : vector<6x6xi32>
+
+ //
+    // Should be the same as the dense output (CDR).
+ // CHECK: ( ( 0, 0, -1, -6, -1, 6 ),
+ // CHECK-SAME: ( -1, 0, 1, 0, 1, 0 ),
+ // CHECK-SAME: ( 0, -1, 1, 0, 0, 0 ),
+ // CHECK-SAME: ( -1, 0, 0, 0, 0, 0 ),
+ // CHECK-SAME: ( 0, 0, 3, 6, -3, -6 ),
+ // CHECK-SAME: ( 2, -1, 3, 0, -3, 0 ) )
+ //
+ %all_sparse_CD = sparse_tensor.convert %4
+ : tensor<6x6xi32, #CDR> to tensor<6x6xi32>
+ %v4 = vector.transfer_read %all_sparse_CD[%c0, %c0], %i0
+ : tensor<6x6xi32>, vector<6x6xi32>
+ vector.print %v4 : vector<6x6xi32>
+
+ //
+    // Should be the same as the dense output (CSR).
+ // CHECK: ( ( 0, 0, -1, -6, -1, 6 ),
+ // CHECK-SAME: ( -1, 0, 1, 0, 1, 0 ),
+ // CHECK-SAME: ( 0, -1, 1, 0, 0, 0 ),
+ // CHECK-SAME: ( -1, 0, 0, 0, 0, 0 ),
+ // CHECK-SAME: ( 0, 0, 3, 6, -3, -6 ),
+ // CHECK-SAME: ( 2, -1, 3, 0, -3, 0 ) )
+ //
+ %all_sparse_CSR = sparse_tensor.convert %3
+ : tensor<6x6xi32, #CSR> to tensor<6x6xi32>
+ %v3 = vector.transfer_read %all_sparse_CSR[%c0, %c0], %i0
+ : tensor<6x6xi32>, vector<6x6xi32>
+ vector.print %v3 : vector<6x6xi32>
+
+ //
+    // Should be the same as the dense output (CSC).
+ // CHECK: ( ( 0, 0, -1, -6, -1, 6 ),
+ // CHECK-SAME: ( -1, 0, 1, 0, 1, 0 ),
+ // CHECK-SAME: ( 0, -1, 1, 0, 0, 0 ),
+ // CHECK-SAME: ( -1, 0, 0, 0, 0, 0 ),
+ // CHECK-SAME: ( 0, 0, 3, 6, -3, -6 ),
+ // CHECK-SAME: ( 2, -1, 3, 0, -3, 0 ) )
+ //
+ %all_sparse_CSC = sparse_tensor.convert %5
+ : tensor<6x6xi32, #CSC> to tensor<6x6xi32>
+ %v5 = vector.transfer_read %all_sparse_CSC[%c0, %c0], %i0
+ : tensor<6x6xi32>, vector<6x6xi32>
+ vector.print %v5 : vector<6x6xi32>
+
+ // Release the resources.
+ bufferization.dealloc_tensor %sparse_input_DCSR : tensor<8x8xi32, #DCSR>
+ bufferization.dealloc_tensor %sparse_input_CSR : tensor<8x8xi32, #CSR>
+ bufferization.dealloc_tensor %sparse_input_CSC : tensor<8x8xi32, #CSC>
+ bufferization.dealloc_tensor %sparse_input_CD : tensor<8x8xi32, #CDR>
+ bufferization.dealloc_tensor %sparse_filter_DCSR : tensor<3x3xi32, #DCSR>
+ bufferization.dealloc_tensor %sparse_filter_CSR : tensor<3x3xi32, #CSR>
+ bufferization.dealloc_tensor %sparse_filter_CD : tensor<3x3xi32, #CDR>
+ bufferization.dealloc_tensor %sparse_filter_CSC : tensor<3x3xi32, #CSC>
+
+ bufferization.dealloc_tensor %2 : tensor<6x6xi32, #DCSR>
+ bufferization.dealloc_tensor %3 : tensor<6x6xi32, #CSR>
+ bufferization.dealloc_tensor %4 : tensor<6x6xi32, #CDR>
+ bufferization.dealloc_tensor %5 : tensor<6x6xi32, #CSC>
+ return
+ }
+}
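As a quick out-of-band sanity check on the CHECK values above (not part of
the commit), the following NumPy sketch recomputes the dense 6x6 reference
result from the test's 8x8 input and 3x3 filter. linalg.conv_2d performs a
sliding-window sum without flipping the kernel, which is what the loop below
implements; for this 180-degree-symmetric filter the distinction would not
matter anyway.

import numpy as np

# The 8x8 input and 3x3 edge-detection filter from the test above.
inp = np.array([
    [1, 2, 3, 4, 0, 6, 7, 8],
    [2, 2, 4, 4, 0, 0, 6, 8],
    [2, 2, 4, 4, 0, 0, 6, 8],
    [2, 2, 3, 4, 0, 0, 7, 8],
    [1, 3, 3, 4, 0, 0, 6, 8],
    [3, 2, 3, 4, 0, 0, 7, 8],
    [1, 3, 3, 4, 3, 6, 6, 8],
    [1, 3, 3, 4, 3, 0, 7, 8],
], dtype=np.int32)
flt = np.array([
    [ 1, 0, -1],
    [ 0, 0,  0],
    [-1, 0,  1],
], dtype=np.int32)

# out[i, j] = sum_{k, l} inp[i + k, j + l] * flt[k, l]
out = np.zeros((6, 6), dtype=np.int32)
for i in range(6):
    for j in range(6):
        out[i, j] = int(np.sum(inp[i:i + 3, j:j + 3] * flt))

print(out)

Its printed first row, 0 0 -1 -6 -1 6, matches the first CHECK line of each
verification block in @entry.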