[Mlir-commits] [mlir] fc5d8fc - [mlir][sparse] support dual sparse convolution.

Peiming Liu llvmlistbot at llvm.org
Mon Jul 10 09:49:38 PDT 2023


Author: Peiming Liu
Date: 2023-07-10T16:49:32Z
New Revision: fc5d8fce7dddd27b591f0a6dcb20a7cfa59842fd

URL: https://github.com/llvm/llvm-project/commit/fc5d8fce7dddd27b591f0a6dcb20a7cfa59842fd
DIFF: https://github.com/llvm/llvm-project/commit/fc5d8fce7dddd27b591f0a6dcb20a7cfa59842fd.diff

LOG: [mlir][sparse] support dual sparse convolution.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D152601

Added: 
    mlir/test/Integration/Dialect/SparseTensor/CPU/dual_sparse_conv_2d.mlir

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 9a130ec04445db..f4e7cc49b9f46d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -37,10 +37,31 @@ using namespace mlir::sparse_tensor;
 #define MULI(lhs, rhs) (builder.create<arith::MulIOp>(loc, (lhs), (rhs)))
 #define SELECT(c, l, r) (builder.create<arith::SelectOp>(loc, (c), (l), (r)))
 
+//===----------------------------------------------------------------------===//
+// Debugging utils
+//===----------------------------------------------------------------------===//
+
+#ifndef NDEBUG
+LLVM_ATTRIBUTE_UNUSED static void dumpIndexMemRef(OpBuilder &builder,
+                                                  Location loc, Value memref) {
+  memref = builder.create<memref::CastOp>(
+      loc, UnrankedMemRefType::get(builder.getIndexType(), 0), memref);
+  createFuncCall(builder, loc, "printMemrefInd", TypeRange{},
+                 ValueRange{memref}, EmitCInterface::On);
+}
+#endif // NDEBUG
+
 //===----------------------------------------------------------------------===//
 // File local helper functions.
 //===----------------------------------------------------------------------===//
 
+// For index reduction loops, since the tensors are sliced into non-contiguous
+// fragments, we need a triple [pLo, pHi, pPtr], in which the pair (pLo, pHi)
+// specifies the range of the fragment, and pPtr specifies the index of the
+// corresponding fragment in the child level (i.e., a pointer to the sliced
+// position array).
+static constexpr unsigned kSliceIterWidth = 3;
+
 static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
                             Level lvl) {
   auto enc = getSparseTensorEncoding(tensor.getType());
@@ -123,6 +144,28 @@ static Value genSparseReducedAffineCond(OpBuilder &builder, Location loc,
   return ifOp.getResult(0);
 }
 
+// Helper functions that load/store into the position buffer for slice-driven
+// loops.
+// The sliced position buffer is organized as (pNext a.k.a. pPtr):
+// [size, curPtr] (two metadata) + [[pLo, pHi, pNext], ...] (list of tuples)
+static Value loadSlicePosPtr(OpBuilder &builder, Location loc, Value sPosBuf) {
+  // Loads curPtr, which is stored at index 1 of the buffer.
+  // TODO: We should use an SSA value for it.
+  return genIndexLoad(builder, loc, sPosBuf, C_IDX(1));
+}
+static void updateSlicePosPtr(OpBuilder &builder, Location loc, Value sPosBuf,
+                              Value pPtr) {
+  // Stores `pPtr` as the new curPtr (index 1 of the buffer).
+  // TODO: We should use an SSA value for it.
+  builder.create<memref::StoreOp>(loc, pPtr, sPosBuf, C_IDX(1));
+}
+static Value loadSliceNextPosPtrStart(OpBuilder &builder, Location loc,
+                                      Value sPosBuf, Value tupleIdx) {
+  // Loads the pNext of the tuple starting at `tupleIdx` (tuple-list offset).
+  // 4 = 2 (two metadata) + 2 (pNext == tuple[2])
+  return genIndexLoad(builder, loc, sPosBuf, ADDI(tupleIdx, C_IDX(4)));
+}
+
 std::pair<Value, Value>
 LoopEmitter::genSliceLegitPredicate(OpBuilder &builder, Location loc, Value crd,
                                     TensorId tid, Level lvl) {
@@ -566,18 +609,6 @@ void LoopEmitter::exitCurrentLoopSeq(OpBuilder &builder, Location loc) {
       // If this is a unresolved-slice-driven loop, pops out the slice.
       assert(sliceStack[tid].back().slicedOnLvl == lvl);
       sliceStack[tid].pop_back();
-    } else {
-      if (!isDenseDLT(lvlTypes[tid][lvl])) {
-        // Else this is a resolved-slice, and advance posit similar to TACO.
-        Value c1 = C_IDX(1), c2 = C_IDX(2);
-        // pIdx += 2, we finished the current lvl, advance the pointer index of
-        // the previous level by two to skip the [pLo, pHi] for current level.
-        Value sPtrBuf = slicePosBuffer[tid][lvl].back();
-        Value curP = genIndexLoad(builder, loc, sPtrBuf, c1);
-        // TODO: we could probably use an SSA value for it.
-        Value nexP = ADDI(curP, c2);
-        builder.create<memref::StoreOp>(loc, nexP, sPtrBuf, c1);
-      }
     }
   }
   loopSeqStack.pop_back();
@@ -1274,11 +1305,11 @@ void LoopEmitter::enterTensorsAtDenseLvls(
       // Pushes sliced levels to build correct LoopInfo.
       bool unReduc = isAffineIdxUnRedCond(denseLoopCond);
       SliceInfo &info = sliceStack[tid].back();
+      // Pushes sliced dense loop info to tell LoopEmitter how to exit it.
+      sliceInfo.emplace_back(tid, lvl, /*fullyReduced=*/!unReduc);
       if (unReduc) {
-        // Pushes sliced dense loop info to tell LoopEmitter how to exit it.
-        sliceInfo.emplace_back(tid, lvl, /*fullyReduced=*/false);
-        // Update the slice information as we enter the new loop.
         assert(*info.slicedOnLvl == lvl);
+        // Update the slice information as we enter the new loop.
         info.minCrd = info.offset = iv;
         info.isNonEmpty = constantI1(builder, loc, true);
         levelReducedDep[tid][lvl]++;
@@ -1312,15 +1343,20 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
                               MutableArrayRef<Value> reduc) {
   const LoopInfo &loopInfo = loopStack.back();
   for (auto [tid, lvl, reduced] : loopInfo.sliceDrivenInfo) {
-    SliceInfo &info = sliceStack[tid].back();
-    assert(isDenseDLT(lvlTypes[tid][lvl]));
-    assert(*info.slicedOnLvl == lvl && !reduced);
-    (void)reduced;
-    // Resets slices pointers as the resolved slices are invalidated after we
-    // moves forward to the next slice.
-    invalidateSliceIterIdx(rewriter, loc, tid, lvl);
-    info.minCrd = info.offset = info.isNonEmpty = Value();
-    levelReducedDep[tid][lvl]--;
+    if (!reduced) {
+      SliceInfo &info = sliceStack[tid].back();
+      assert(isDenseDLT(lvlTypes[tid][lvl]));
+      assert(*info.slicedOnLvl == lvl);
+      (void)reduced;
+      // Resets slice pointers as the resolved slices are invalidated after we
+      // move forward to the next slice.
+      invalidateSliceIterIdx(rewriter, loc, tid, lvl);
+      info.minCrd = info.offset = info.isNonEmpty = Value();
+      levelReducedDep[tid][lvl]--;
+    } else {
+      forwardsReducedSliceLevelTreeIt(rewriter, loc, tid, lvl,
+                                      constantIndex(rewriter, loc, 1));
+    }
   }
   if (auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop)) {
     if (!reduc.empty()) {
@@ -1398,6 +1434,61 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
   }
 }
 
+void LoopEmitter::forwardsReducedSliceLevelTreeIt(OpBuilder &builder,
+                                                  Location loc, TensorId tid,
+                                                  Level rootLvl, Value fcnt) {
+  auto stt = getSparseTensorType(tensors[tid]);
+
+  // Finds a [rootLvl, leafLvl) range such that all levels in between are fully
+  // reduced (but not resolved). Since we forward an iterator at a higher level
+  // of the tree, the subtree needs to be pruned.
+  Level leafLvl = rootLvl + 1;
+  while (leafLvl < stt.getLvlRank() && !dependentLvlMap[tid][leafLvl].empty()) {
+    assert(depFullyReduced(tid, leafLvl));
+    leafLvl++;
+  }
+
+  Level curLvl = rootLvl + 1;
+  // Prunes all dense subtrees.
+  while (curLvl < leafLvl && isDenseDLT(lvlTypes[tid][curLvl])) {
+    // One step forward in the parent level results in forwarding `slice.size`
+    // steps in the child dense level.
+    fcnt = MULI(sliceSizes[tid][curLvl].back(), fcnt);
+    curLvl++;
+  }
+
+  Value nxPosPtr = nullptr;
+  if (curLvl < leafLvl) {
+    assert(!isDenseDLT(lvlTypes[tid][curLvl]));
+    // The first non-dense level: sets up the position pointer for it.
+    Value sPosBuf = slicePosBuffer[tid][curLvl].back();
+    // One step forward in the parent level results in forwarding one `segment`
+    // (kSliceIterWidth entries) in the child sparse level.
+    Value fPosPtr = MULI(fcnt, C_IDX(kSliceIterWidth));     // forward ptr
+    Value pPosPtr = loadSlicePosPtr(builder, loc, sPosBuf); // previous ptr
+    Value cPosPtr = ADDI(fPosPtr, pPosPtr);                 // current ptr
+    updateSlicePosPtr(builder, loc, sPosBuf, cPosPtr);
+    // Loads the start of the position pointer for the next level.
+    nxPosPtr = loadSliceNextPosPtrStart(builder, loc, sPosBuf, cPosPtr);
+    curLvl++;
+  }
+
+  // TODO: This is not always needed, but we do it unconditionally for now for
+  // simplicity.
+  // It is only needed when `curLvl` is forwarded without traversing its child
+  // level (e.g., the level is in a conjunctive lattice and got pruned), such
+  // that the position pointer is not forwarded inside the loop.
+  for (; curLvl < leafLvl; curLvl++) {
+    assert(nxPosPtr);
+    if (!isDenseDLT(lvlTypes[tid][curLvl])) {
+      nxPosPtr = MULI(nxPosPtr, C_IDX(kSliceIterWidth));
+      Value sPosBuf = slicePosBuffer[tid][curLvl].back();
+      updateSlicePosPtr(builder, loc, sPosBuf, nxPosPtr);
+      nxPosPtr = loadSliceNextPosPtrStart(builder, loc, sPosBuf, nxPosPtr);
+    }
+  }
+}
+
 void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
                                 MutableArrayRef<Value> reduc) {
   const LoopInfo &loopInfo = loopStack.back();
@@ -1425,17 +1516,25 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
       continue;
     }
 
+    Value forwarded = nullptr;
     if (loopInfo.trivialTidLvls.empty() &&
         loopInfo.sliceDrivenInfo.size() == 1) {
       // Forwards the position iterator.
       operands.push_back(ADDI(posits[tid][lvl], one));
+      forwarded = constantI1(builder, loc, true);
     } else {
       const Value pos = posits[tid][lvl];
       const Value nxPos = ADDI(posits[tid][lvl], one);
-      Value cmp = CMPI(eq, coords[tid][lvl], iv);
-      operands.push_back(SELECT(cmp, nxPos, pos));
+      forwarded = CMPI(eq, coords[tid][lvl], iv);
+      operands.push_back(SELECT(forwarded, nxPos, pos));
+    }
+    {
+      OpBuilder::InsertionGuard guard(builder);
+      auto ifOp = builder.create<scf::IfOp>(loc, TypeRange{}, forwarded,
+                                            /*else=*/false);
+      builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+      forwardsReducedSliceLevelTreeIt(builder, loc, tid, lvl, one);
     }
-
     // The coordinate is invalid now.
     coords[tid][lvl] = nullptr;
 
@@ -1633,7 +1732,6 @@ std::pair<Operation *, ValueRange> LoopEmitter::genSliceLvlTraverseLoop(
 }
 
 // Generates a loop nest that traverse all the unresolved levels in between.
-// TODO: it can only handle all compressed tensors.
 //
 // for(int i = 0; i < slicePos.size(); i+=2) {
 //   loopLo = slicePos[i];
@@ -1660,6 +1758,15 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
   OpBuilder::InsertPoint ip;
   SmallVector<Value> innerArgs(userReduc.begin(), userReduc.end());
   scf::ForOp outerMost = nullptr; // the outtermost loop.
+
+  // Wraps body builder and inserts an extra counting instruction at the end.
+  auto wrapped = [bodyBuilder](OpBuilder &builder, Location loc, Value iv,
+                               MutableArrayRef<Value> reduc) {
+    bodyBuilder(builder, loc, iv, reduc.drop_back());
+    // Increments the counter.
+    reduc.back() = ADDI(reduc.back(), C_IDX(1));
+  };
+
   if (firstResLvl.has_value()) {
     // Overwrite position when the first level is fully resolved.
     pos = posits[firstResLvl->first][firstResLvl->second];
@@ -1669,18 +1776,28 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
     Level firstLvl = *frontSlice.slicedOnLvl;
     if (!lvlFullyResolved(tid, firstLvl)) {
       if (isCompressedDLT(lvlTypes[tid][firstLvl])) {
+        // An extra counter that tracks how many segments there are in the child
+        // compressed level.
+        innerArgs.push_back(c0);
+        // Overrides the user-provided builder.
+        bodyBuilder = wrapped;
         unsigned depth = frontSlice.depth - 1;
         Value offset = frontSlice.offset;
         Value sPtrBuf = slicePosBuffer[tid][firstLvl][depth];
         Value mSz = genIndexLoad(builder, loc, sPtrBuf, c0); // memSize
         outerMost = builder.create<scf::ForOp>(
-            loc, c2, mSz, c2, innerArgs,
-            [this, c1, tid, firstLvl, offset, sPtrBuf, &ip, &pos,
+            loc, c2, mSz, C_IDX(kSliceIterWidth), innerArgs,
+            [this, c1, c2, tid, firstLvl, offset, sPtrBuf, &ip, &pos,
              &innerArgs](OpBuilder &builder, Location loc, Value iv,
                          ValueRange iterArgs) {
               // generate traversal for each level.
               Value loopLo = genIndexLoad(builder, loc, sPtrBuf, iv);
               Value loopHi = genIndexLoad(builder, loc, sPtrBuf, ADDI(iv, c1));
+              // We need to remember the starting index for the next level's
+              // position, because the slice-driven loop breaks the level into
+              // non-consecutive segments.
+              builder.create<memref::StoreOp>(loc, iterArgs.back(), sPtrBuf,
+                                              ADDI(iv, c2).getResult());
               ValueRange itArgs =
                   genSliceLvlTraverseLoop(
                       builder, loc, loopLo, loopHi, offset,
@@ -1832,8 +1949,7 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
                                           TensorId tid, Level lvl) {
   Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2);
   unsigned depth = levelReducedDep[tid][lvl];
-  Value size = sliceSizes[tid][lvl][depth];
-  // Dense slice begin is trivial
+  Value size = sliceSizes[tid][lvl][depth]; // Dense slice begin is trivial
   if (isDenseDLT(lvlTypes[tid][lvl])) {
     sliceStack[tid].emplace_back(c0, c0, constantI1(builder, loc, false), lvl,
                                  depth + 1);
@@ -1879,9 +1995,8 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
 
   ValueRange result = genUnResolvedSliceTreeTraverse(
       builder, loc, tid, unResSlices, firstResLvl, reduc,
-      [this, c1, c2, tid, lvl, sPtrBuf](OpBuilder &builder, Location loc,
-                                        Value iv,
-                                        MutableArrayRef<Value> reduc) {
+      [this, c1, tid, lvl, sPtrBuf](OpBuilder &builder, Location loc, Value iv,
+                                    MutableArrayRef<Value> reduc) {
         Value &nonEmpty = reduc[0];
         Value &minCrd = reduc[1];
         Value &curMemSz = reduc[2];
@@ -1919,8 +2034,8 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
         builder.create<memref::StoreOp>(loc, sPLo, sPtrBuf, curMemSz);
         Value nxtMemSize = ADDI(curMemSz, c1);
         builder.create<memref::StoreOp>(loc, sPHi, sPtrBuf, nxtMemSize);
-        // curMemSize += 2
-        curMemSz = ADDI(curMemSz, c2);
+        // curMemSize += kSliceIterWidth
+        curMemSz = ADDI(curMemSz, C_IDX(kSliceIterWidth));
       });
 
   Value isNonEmpty = result[0];
@@ -1947,7 +2062,7 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
     // generates slice begin any more, instead we fall back to TACO-based
     // algorithm to (co)iterates over the slice.
     Value pLoPtr =
-        genIndexLoad(builder, loc, slicePosBuffer[tid][lvl].back(), c1);
+        loadSlicePosPtr(builder, loc, slicePosBuffer[tid][lvl].back());
     pLoPtr = ADDI(pLoPtr, c2);
     Value pHiPtr = ADDI(pLoPtr, c1);
     posits[tid][lvl] =
@@ -1999,10 +2114,10 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
       Value sz = *(sliceSizes[tid][lvl].rbegin() + depth - 1);
       bufSize = MULI(bufSize, sz);
     }
-    // For a pair of [pLo, pHi]. Note that we can not compress pHi because
-    // slice creates segments in the index buffer so that the pHi for the
-    // current level is no longer the pLo for the next level.
-    bufSize = MULI(bufSize, c2);
+    // For a triple of [pLo, pHi, pPtr]. Note that we can not compress pHi
+    // because slice creates segments in the index buffer so that the pHi for
+    // the current level is no longer the pLo for the next level.
+    bufSize = MULI(bufSize, C_IDX(kSliceIterWidth));
     // Additional two metadata {memSize, idx} at head.
     bufSize = ADDI(bufSize, c2);
     llvm::for_each(
@@ -2026,8 +2141,7 @@ void LoopEmitter::invalidateSliceIterIdx(OpBuilder &builder, Location loc,
                                          TensorId tid, Level lvl) {
   for (unsigned i = 0; i <= lvl; i++) {
     if (!isDenseDLT(lvlTypes[tid][i]) && !dependentLvlMap[tid][i].empty()) {
-      builder.create<memref::StoreOp>(loc, C_IDX(0),
-                                      slicePosBuffer[tid][i].back(), C_IDX(1));
+      updateSlicePosPtr(builder, loc, slicePosBuffer[tid][i].back(), C_IDX(0));
     }
   }
 }
@@ -2080,7 +2194,7 @@ void LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
     YIELD(reduc);
 
     // else /*minCrd == offset*/ {
-    //    for (i = 0; i < slicePos.size(); i+=2) {
+    //    for (i = 0; i < slicePos.size(); i+=kSliceIterWidth) {
     //       if (crd[pos[slicePos[i]]] == minCrd) {
     //          slicePos[i]++;
     //       }
@@ -2096,7 +2210,7 @@ void LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
     reduc[1] = constantI1(builder, loc, false);          // isNonEmpty
     auto loopArgs = static_cast<ValueRange>(reduc).drop_back();
     auto forOp = scf::buildLoopNest(
-        builder, loc, pSt, mSz, c2, loopArgs,
+        builder, loc, pSt, mSz, C_IDX(kSliceIterWidth), loopArgs,
         [this, tid, lvl, c1, sPtrBuf,
          &info](OpBuilder &builder, Location loc, ValueRange ivs,
                 ValueRange iterArgs) -> scf::ValueVector {

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index 4f5100cc361041..315cc1d05e9266 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -470,6 +470,11 @@ class LoopEmitter {
     return tid < lvlTypes.size() && lvl < lvlTypes[tid].size();
   }
 
+  /// Forwards the (conceptual) "tree iterator" when iterating over a fully
+  /// reduced slice created by index-reduction, by `fcnt` steps at `lvl`.
+  void forwardsReducedSliceLevelTreeIt(OpBuilder &builder, Location loc,
+                                       TensorId tid, Level lvl, Value fcnt);
+
   /// Prepares loop for iterating over `tensor[lvl]`, under the assumption
   /// that `tensor[0...lvl-1]` loops have already been set up.
   void prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index b767f61598ff74..b75cdba8449a1a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1484,7 +1484,16 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
   env.merger().foreachTensorLoopId(
       p, /*simple=*/true,
       [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
-          DimLevelType dlt, bool /*unused*/) {
+          DimLevelType dlt, bool isIdxRed) {
+        if (isIdxRed) {
+          // Since there is no 1:1 mapping from loop to level (multiple loops
+          // are required to resolve one level with non-trivial index
+          // expression), we need to reconstruct the tensor level types if this
+          // loop requires index reduction condition.
+          assert(lvl.has_value() && isUndefDLT(dlt));
+          auto stt = getSparseTensorType(env.op().getInputs()[tid]);
+          dlt = stt.getLvlType(*lvl);
+        }
         assert(ldx == env.merger().loop(b));
         Value clause;
         if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
@@ -1517,14 +1526,16 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
 /// Generates end of true branch of if-statement within a while-loop.
 static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
                   Operation *loop, Value redInput, Value cntInput,
-                  Value insInput) {
+                  Value insInput, Value validIns) {
   SmallVector<Value> operands;
   if (env.isReduc()) {
     operands.push_back(env.getReduc());
     env.updateReduc(redInput);
-    if (env.getValidLexInsert())
+    if (env.getValidLexInsert()) {
       // Any overlapping indices during a reduction creates a valid lex insert.
       operands.push_back(constantI1(builder, env.op().getLoc(), true));
+      env.setValidLexInsert(validIns);
+    }
   }
   if (env.isExpand()) {
     operands.push_back(env.getExpandCount());
@@ -1852,6 +1863,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
     Value redInput = env.getReduc();
     Value cntInput = env.getExpandCount();
     Value insInput = env.getInsertionChain();
+    Value validIns = env.getValidLexInsert();
     // NOTE: We cannot change this to `for (const LatPointId lj : env.set(lts))`
     // because the loop body causes data-movement which invalidates the
     // iterator.
@@ -1863,7 +1875,8 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
         if (!isSingleCond) {
           scf::IfOp ifOp = genIf(env, rewriter, idx, lj);
           genStmt(env, rewriter, ej, at + 1);
-          endIf(env, rewriter, ifOp, loop, redInput, cntInput, insInput);
+          endIf(env, rewriter, ifOp, loop, redInput, cntInput, insInput,
+                validIns);
         } else {
           genStmt(env, rewriter, ej, at + 1);
         }

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
index d1620125a43ed9..d2b8b6a9316ae1 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
@@ -5,17 +5,18 @@
 #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
 
 #DCSR = #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>
+
 // CHECK-LABEL:   func.func @conv2d_all_sparse_CSR(
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<8x8xi32, #sparse_tensor.encoding<{{.*}}>>,
-// CHECK-SAME:      %[[VAL_1:.*]]: tensor<3x3xi32>)
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<3x3xi32>) -> tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>> {
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant true
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant -2 : index
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 8 : index
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 3 : index
-// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 4 : index
-// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 2 : index
-// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 2 : index
+// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 4 : index
 // CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 0 : i32
 // CHECK-DAG:       %[[VAL_11:.*]] = arith.constant false
 // CHECK-DAG:       %[[VAL_12:.*]] = bufferization.alloc_tensor() : tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
@@ -24,241 +25,241 @@
 // CHECK-DAG:       %[[VAL_15:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_16:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_17:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xi32, #sparse_tensor.encoding<{{.*}}>> to memref<?xi32>
-// CHECK-DAG:       %[[VAL_18:.*]] = memref.alloca() : memref<8xindex>
-// CHECK-DAG:       %[[VAL_19:.*]] = memref.alloca() : memref<4xindex>
-// CHECK-DAG:       %[[VAL_20:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_9]]] : memref<?xindex>
-// CHECK:           memref.store %[[VAL_6]], %[[VAL_19]]{{\[}}%[[VAL_8]]] : memref<4xindex>
-// CHECK:           memref.store %[[VAL_8]], %[[VAL_19]]{{\[}}%[[VAL_9]]] : memref<4xindex>
-// CHECK:           memref.store %[[VAL_8]], %[[VAL_19]]{{\[}}%[[VAL_7]]] : memref<4xindex>
-// CHECK:           memref.store %[[VAL_20]], %[[VAL_19]]{{\[}}%[[VAL_5]]] : memref<4xindex>
-// CHECK:           %[[VAL_21:.*]] = arith.cmpi ugt, %[[VAL_20]], %[[VAL_8]] : index
-// CHECK:           %[[VAL_22:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_8]]] : memref<?xindex>
+// CHECK-DAG:       %[[VAL_18:.*]] = memref.alloca() : memref<11xindex>
+// CHECK-DAG:       %[[VAL_19:.*]] = memref.alloca() : memref<5xindex>
+// CHECK-DAG:       %[[VAL_20:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_7]]] : memref<?xindex>
+// CHECK:           memref.store %[[VAL_9]], %[[VAL_19]]{{\[}}%[[VAL_6]]] : memref<5xindex>
+// CHECK:           memref.store %[[VAL_6]], %[[VAL_19]]{{\[}}%[[VAL_7]]] : memref<5xindex>
+// CHECK:           memref.store %[[VAL_6]], %[[VAL_19]]{{\[}}%[[VAL_8]]] : memref<5xindex>
+// CHECK:           memref.store %[[VAL_20]], %[[VAL_19]]{{\[}}%[[VAL_5]]] : memref<5xindex>
+// CHECK:           %[[VAL_21:.*]] = arith.cmpi ugt, %[[VAL_20]], %[[VAL_6]] : index
+// CHECK:           %[[VAL_22:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_6]]] : memref<?xindex>
 // CHECK:           %[[VAL_23:.*]] = arith.cmpi uge, %[[VAL_22]], %[[VAL_5]] : index
 // CHECK:           %[[VAL_24:.*]] = arith.andi %[[VAL_21]], %[[VAL_23]] : i1
 // CHECK:           %[[VAL_25:.*]] = arith.addi %[[VAL_22]], %[[VAL_3]] : index
-// CHECK:           %[[VAL_26:.*]] = arith.select %[[VAL_24]], %[[VAL_25]], %[[VAL_8]] : index
+// CHECK:           %[[VAL_26:.*]] = arith.select %[[VAL_24]], %[[VAL_25]], %[[VAL_6]] : index
 // CHECK:           %[[VAL_27:.*]]:3 = scf.while (%[[VAL_28:.*]] = %[[VAL_21]], %[[VAL_29:.*]] = %[[VAL_22]], %[[VAL_30:.*]] = %[[VAL_26]], %[[VAL_31:.*]] = %[[VAL_12]]) : (i1, index, index, tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>) -> (index, index, tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>) {
 // CHECK:             scf.condition(%[[VAL_28]]) %[[VAL_29]], %[[VAL_30]], %[[VAL_31]] : index, index, tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:           } do {
 // CHECK:           ^bb0(%[[VAL_32:.*]]: index, %[[VAL_33:.*]]: index, %[[VAL_34:.*]]: tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>):
-// CHECK:             %[[VAL_35:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_8]]] : memref<4xindex>
-// CHECK:             %[[VAL_36:.*]]:3 = scf.for %[[VAL_37:.*]] = %[[VAL_7]] to %[[VAL_35]] step %[[VAL_7]] iter_args(%[[VAL_38:.*]] = %[[VAL_11]], %[[VAL_39:.*]] = %[[VAL_4]], %[[VAL_40:.*]] = %[[VAL_7]]) -> (i1, index, index) {
-// CHECK:               %[[VAL_41:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_37]]] : memref<4xindex>
-// CHECK:               %[[VAL_42:.*]] = arith.addi %[[VAL_37]], %[[VAL_9]] : index
-// CHECK:               %[[VAL_43:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_42]]] : memref<4xindex>
-// CHECK:               %[[VAL_44:.*]] = arith.addi %[[VAL_33]], %[[VAL_5]] : index
-// CHECK:               %[[VAL_45:.*]]:4 = scf.while (%[[VAL_46:.*]] = %[[VAL_41]], %[[VAL_47:.*]] = %[[VAL_38]], %[[VAL_48:.*]] = %[[VAL_39]], %[[VAL_49:.*]] = %[[VAL_40]]) : (index, i1, index, index) -> (index, i1, index, index) {
-// CHECK:                 %[[VAL_50:.*]] = arith.cmpi ult, %[[VAL_46]], %[[VAL_43]] : index
-// CHECK:                 %[[VAL_51:.*]] = scf.if %[[VAL_50]] -> (i1) {
-// CHECK:                   %[[VAL_52:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_46]]] : memref<?xindex>
-// CHECK:                   %[[VAL_53:.*]] = arith.cmpi ult, %[[VAL_52]], %[[VAL_44]] : index
-// CHECK:                   scf.yield %[[VAL_53]] : i1
+// CHECK:             %[[VAL_35:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_6]]] : memref<5xindex>
+// CHECK:             %[[VAL_36:.*]]:4 = scf.for %[[VAL_37:.*]] = %[[VAL_8]] to %[[VAL_35]] step %[[VAL_5]] iter_args(%[[VAL_38:.*]] = %[[VAL_11]], %[[VAL_39:.*]] = %[[VAL_4]], %[[VAL_40:.*]] = %[[VAL_8]], %[[VAL_41:.*]] = %[[VAL_6]]) -> (i1, index, index, index) {
+// CHECK:               %[[VAL_42:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_37]]] : memref<5xindex>
+// CHECK:               %[[VAL_43:.*]] = arith.addi %[[VAL_37]], %[[VAL_7]] : index
+// CHECK:               %[[VAL_44:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_43]]] : memref<5xindex>
+// CHECK:               %[[VAL_45:.*]] = arith.addi %[[VAL_37]], %[[VAL_8]] : index
+// CHECK:               memref.store %[[VAL_41]], %[[VAL_19]]{{\[}}%[[VAL_45]]] : memref<5xindex>
+// CHECK:               %[[VAL_46:.*]] = arith.addi %[[VAL_33]], %[[VAL_5]] : index
+// CHECK:               %[[VAL_47:.*]]:5 = scf.while (%[[VAL_48:.*]] = %[[VAL_42]], %[[VAL_49:.*]] = %[[VAL_38]], %[[VAL_50:.*]] = %[[VAL_39]], %[[VAL_51:.*]] = %[[VAL_40]], %[[VAL_52:.*]] = %[[VAL_41]]) : (index, i1, index, index, index) -> (index, i1, index, index, index) {
+// CHECK:                 %[[VAL_53:.*]] = arith.cmpi ult, %[[VAL_48]], %[[VAL_44]] : index
+// CHECK:                 %[[VAL_54:.*]] = scf.if %[[VAL_53]] -> (i1) {
+// CHECK:                   %[[VAL_55:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_48]]] : memref<?xindex>
+// CHECK:                   %[[VAL_56:.*]] = arith.cmpi ult, %[[VAL_55]], %[[VAL_46]] : index
+// CHECK:                   scf.yield %[[VAL_56]] : i1
 // CHECK:                 } else {
 // CHECK:                   scf.yield %[[VAL_11]] : i1
 // CHECK:                 }
-// CHECK:                 scf.condition(%[[VAL_54:.*]]) %[[VAL_46]], %[[VAL_47]], %[[VAL_48]], %[[VAL_49]] : index, i1, index, index
+// CHECK:                 scf.condition(%[[VAL_57:.*]]) %[[VAL_48]], %[[VAL_49]], %[[VAL_50]], %[[VAL_51]], %[[VAL_52]] : index, i1, index, index, index
 // CHECK:               } do {
-// CHECK:               ^bb0(%[[VAL_55:.*]]: index, %[[VAL_56:.*]]: i1, %[[VAL_57:.*]]: index, %[[VAL_58:.*]]: index):
-// CHECK:                 %[[VAL_59:.*]] = arith.addi %[[VAL_55]], %[[VAL_9]] : index
-// CHECK:                 %[[VAL_60:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_55]]] : memref<?xindex>
-// CHECK:                 %[[VAL_61:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_59]]] : memref<?xindex>
-// CHECK:                 %[[VAL_62:.*]] = arith.cmpi ult, %[[VAL_60]], %[[VAL_61]] : index
-// CHECK:                 %[[VAL_63:.*]] = arith.ori %[[VAL_62]], %[[VAL_56]] : i1
-// CHECK:                 %[[VAL_64:.*]] = scf.if %[[VAL_62]] -> (index) {
-// CHECK:                   %[[VAL_65:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_60]]] : memref<?xindex>
-// CHECK:                   %[[VAL_66:.*]] = arith.cmpi ult, %[[VAL_65]], %[[VAL_57]] : index
-// CHECK:                   %[[VAL_67:.*]] = arith.select %[[VAL_66]], %[[VAL_65]], %[[VAL_57]] : index
-// CHECK:                   scf.yield %[[VAL_67]] : index
+// CHECK:               ^bb0(%[[VAL_58:.*]]: index, %[[VAL_59:.*]]: i1, %[[VAL_60:.*]]: index, %[[VAL_61:.*]]: index, %[[VAL_62:.*]]: index):
+// CHECK:                 %[[VAL_63:.*]] = arith.addi %[[VAL_58]], %[[VAL_7]] : index
+// CHECK:                 %[[VAL_64:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_58]]] : memref<?xindex>
+// CHECK:                 %[[VAL_65:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_63]]] : memref<?xindex>
+// CHECK:                 %[[VAL_66:.*]] = arith.cmpi ult, %[[VAL_64]], %[[VAL_65]] : index
+// CHECK:                 %[[VAL_67:.*]] = arith.ori %[[VAL_66]], %[[VAL_59]] : i1
+// CHECK:                 %[[VAL_68:.*]] = scf.if %[[VAL_66]] -> (index) {
+// CHECK:                   %[[VAL_69:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_64]]] : memref<?xindex>
+// CHECK:                   %[[VAL_70:.*]] = arith.cmpi ult, %[[VAL_69]], %[[VAL_60]] : index
+// CHECK:                   %[[VAL_71:.*]] = arith.select %[[VAL_70]], %[[VAL_69]], %[[VAL_60]] : index
+// CHECK:                   scf.yield %[[VAL_71]] : index
 // CHECK:                 } else {
-// CHECK:                   scf.yield %[[VAL_57]] : index
+// CHECK:                   scf.yield %[[VAL_60]] : index
 // CHECK:                 }
-// CHECK:                 memref.store %[[VAL_60]], %[[VAL_18]]{{\[}}%[[VAL_58]]] : memref<8xindex>
-// CHECK:                 %[[VAL_68:.*]] = arith.addi %[[VAL_58]], %[[VAL_9]] : index
-// CHECK:                 memref.store %[[VAL_61]], %[[VAL_18]]{{\[}}%[[VAL_68]]] : memref<8xindex>
-// CHECK:                 %[[VAL_69:.*]] = arith.addi %[[VAL_58]], %[[VAL_7]] : index
-// CHECK:                 scf.yield %[[VAL_59]], %[[VAL_63]], %[[VAL_70:.*]], %[[VAL_69]] : index, i1, index, index
+// CHECK:                 memref.store %[[VAL_64]], %[[VAL_18]]{{\[}}%[[VAL_61]]] : memref<11xindex>
+// CHECK:                 %[[VAL_72:.*]] = arith.addi %[[VAL_61]], %[[VAL_7]] : index
+// CHECK:                 memref.store %[[VAL_65]], %[[VAL_18]]{{\[}}%[[VAL_72]]] : memref<11xindex>
+// CHECK:                 %[[VAL_73:.*]] = arith.addi %[[VAL_61]], %[[VAL_5]] : index
+// CHECK:                 %[[VAL_74:.*]] = arith.addi %[[VAL_62]], %[[VAL_7]] : index
+// CHECK:                 scf.yield %[[VAL_63]], %[[VAL_67]], %[[VAL_75:.*]], %[[VAL_73]], %[[VAL_74]] : index, i1, index, index, index
 // CHECK:               }
-// CHECK:               scf.yield %[[VAL_71:.*]]#1, %[[VAL_71]]#2, %[[VAL_71]]#3 : i1, index, index
+// CHECK:               scf.yield %[[VAL_76:.*]]#1, %[[VAL_76]]#2, %[[VAL_76]]#3, %[[VAL_76]]#4 : i1, index, index, index
 // CHECK:             }
-// CHECK:             memref.store %[[VAL_72:.*]]#2, %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<8xindex>
-// CHECK:             memref.store %[[VAL_8]], %[[VAL_18]]{{\[}}%[[VAL_9]]] : memref<8xindex>
-// CHECK:             %[[VAL_73:.*]] = arith.cmpi uge, %[[VAL_72]]#1, %[[VAL_5]] : index
-// CHECK:             %[[VAL_74:.*]] = arith.andi %[[VAL_72]]#0, %[[VAL_73]] : i1
-// CHECK:             %[[VAL_75:.*]] = arith.addi %[[VAL_72]]#1, %[[VAL_3]] : index
-// CHECK:             %[[VAL_76:.*]] = arith.select %[[VAL_74]], %[[VAL_75]], %[[VAL_8]] : index
-// CHECK:             %[[VAL_77:.*]]:3 = scf.while (%[[VAL_78:.*]] = %[[VAL_72]]#0, %[[VAL_79:.*]] = %[[VAL_72]]#1, %[[VAL_80:.*]] = %[[VAL_76]], %[[VAL_81:.*]] = %[[VAL_34]]) : (i1, index, index, tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>) -> (index, index, tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>) {
-// CHECK:               scf.condition(%[[VAL_78]]) %[[VAL_79]], %[[VAL_80]], %[[VAL_81]] : index, index, tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:             memref.store %[[VAL_77:.*]]#2, %[[VAL_18]]{{\[}}%[[VAL_6]]] : memref<11xindex>
+// CHECK:             memref.store %[[VAL_6]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
+// CHECK:             %[[VAL_78:.*]] = arith.cmpi uge, %[[VAL_77]]#1, %[[VAL_5]] : index
+// CHECK:             %[[VAL_79:.*]] = arith.andi %[[VAL_77]]#0, %[[VAL_78]] : i1
+// CHECK:             %[[VAL_80:.*]] = arith.addi %[[VAL_77]]#1, %[[VAL_3]] : index
+// CHECK:             %[[VAL_81:.*]] = arith.select %[[VAL_79]], %[[VAL_80]], %[[VAL_6]] : index
+// CHECK:             %[[VAL_82:.*]]:3 = scf.while (%[[VAL_83:.*]] = %[[VAL_77]]#0, %[[VAL_84:.*]] = %[[VAL_77]]#1, %[[VAL_85:.*]] = %[[VAL_81]], %[[VAL_86:.*]] = %[[VAL_34]]) : (i1, index, index, tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>) -> (index, index, tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>) {
+// CHECK:               scf.condition(%[[VAL_83]]) %[[VAL_84]], %[[VAL_85]], %[[VAL_86]] : index, index, tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:             } do {
-// CHECK:             ^bb0(%[[VAL_82:.*]]: index, %[[VAL_83:.*]]: index, %[[VAL_84:.*]]: tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>):
-// CHECK:               %[[VAL_85:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_9]]] : memref<4xindex>
-// CHECK:               %[[VAL_86:.*]] = arith.addi %[[VAL_85]], %[[VAL_7]] : index
-// CHECK:               %[[VAL_87:.*]] = arith.addi %[[VAL_85]], %[[VAL_5]] : index
-// CHECK:               %[[VAL_88:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_86]]] : memref<4xindex>
-// CHECK:               %[[VAL_89:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_87]]] : memref<4xindex>
-// CHECK:               %[[VAL_90:.*]] = arith.addi %[[VAL_33]], %[[VAL_5]] : index
-// CHECK:               %[[VAL_91:.*]]:3 = scf.while (%[[VAL_92:.*]] = %[[VAL_88]], %[[VAL_93:.*]] = %[[VAL_10]], %[[VAL_94:.*]] = %[[VAL_11]]) : (index, i32, i1) -> (index, i32, i1) {
-// CHECK:                 %[[VAL_95:.*]] = arith.cmpi ult, %[[VAL_92]], %[[VAL_89]] : index
-// CHECK:                 %[[VAL_96:.*]] = scf.if %[[VAL_95]] -> (i1) {
-// CHECK:                   %[[VAL_97:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_92]]] : memref<?xindex>
-// CHECK:                   %[[VAL_98:.*]] = arith.cmpi ult, %[[VAL_97]], %[[VAL_90]] : index
-// CHECK:                   scf.yield %[[VAL_98]] : i1
+// CHECK:             ^bb0(%[[VAL_87:.*]]: index, %[[VAL_88:.*]]: index, %[[VAL_89:.*]]: tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>):
+// CHECK:               %[[VAL_90:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_7]]] : memref<5xindex>
+// CHECK:               %[[VAL_91:.*]] = arith.addi %[[VAL_90]], %[[VAL_8]] : index
+// CHECK:               %[[VAL_92:.*]] = arith.addi %[[VAL_90]], %[[VAL_5]] : index
+// CHECK:               %[[VAL_93:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_91]]] : memref<5xindex>
+// CHECK:               %[[VAL_94:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_92]]] : memref<5xindex>
+// CHECK:               %[[VAL_95:.*]] = arith.addi %[[VAL_33]], %[[VAL_5]] : index
+// CHECK:               %[[VAL_96:.*]]:3 = scf.while (%[[VAL_97:.*]] = %[[VAL_93]], %[[VAL_98:.*]] = %[[VAL_10]], %[[VAL_99:.*]] = %[[VAL_11]]) : (index, i32, i1) -> (index, i32, i1) {
+// CHECK:                 %[[VAL_100:.*]] = arith.cmpi ult, %[[VAL_97]], %[[VAL_94]] : index
+// CHECK:                 %[[VAL_101:.*]] = scf.if %[[VAL_100]] -> (i1) {
+// CHECK:                   %[[VAL_102:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_97]]] : memref<?xindex>
+// CHECK:                   %[[VAL_103:.*]] = arith.cmpi ult, %[[VAL_102]], %[[VAL_95]] : index
+// CHECK:                   scf.yield %[[VAL_103]] : i1
 // CHECK:                 } else {
 // CHECK:                   scf.yield %[[VAL_11]] : i1
 // CHECK:                 }
-// CHECK:                 scf.condition(%[[VAL_99:.*]]) %[[VAL_92]], %[[VAL_93]], %[[VAL_94]] : index, i32, i1
+// CHECK:                 scf.condition(%[[VAL_104:.*]]) %[[VAL_97]], %[[VAL_98]], %[[VAL_99]] : index, i32, i1
 // CHECK:               } do {
-// CHECK:               ^bb0(%[[VAL_100:.*]]: index, %[[VAL_101:.*]]: i32, %[[VAL_102:.*]]: i1):
-// CHECK:                 %[[VAL_103:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_100]]] : memref<?xindex>
-// CHECK:                 %[[VAL_104:.*]] = arith.subi %[[VAL_103]], %[[VAL_33]] : index
-// CHECK:                 %[[VAL_105:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_9]]] : memref<8xindex>
-// CHECK:                 %[[VAL_106:.*]] = arith.addi %[[VAL_105]], %[[VAL_7]] : index
-// CHECK:                 %[[VAL_107:.*]] = arith.addi %[[VAL_105]], %[[VAL_5]] : index
-// CHECK:                 %[[VAL_108:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_106]]] : memref<8xindex>
-// CHECK:                 %[[VAL_109:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_107]]] : memref<8xindex>
-// CHECK:                 %[[VAL_110:.*]] = arith.addi %[[VAL_83]], %[[VAL_5]] : index
-// CHECK:                 %[[VAL_111:.*]]:2 = scf.while (%[[VAL_112:.*]] = %[[VAL_108]], %[[VAL_113:.*]] = %[[VAL_101]]) : (index, i32) -> (index, i32) {
-// CHECK:                   %[[VAL_114:.*]] = arith.cmpi ult, %[[VAL_112]], %[[VAL_109]] : index
-// CHECK:                   %[[VAL_115:.*]] = scf.if %[[VAL_114]] -> (i1) {
-// CHECK:                     %[[VAL_116:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_112]]] : memref<?xindex>
-// CHECK:                     %[[VAL_117:.*]] = arith.cmpi ult, %[[VAL_116]], %[[VAL_110]] : index
-// CHECK:                     scf.yield %[[VAL_117]] : i1
+// CHECK:               ^bb0(%[[VAL_105:.*]]: index, %[[VAL_106:.*]]: i32, %[[VAL_107:.*]]: i1):
+// CHECK:                 %[[VAL_108:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_105]]] : memref<?xindex>
+// CHECK:                 %[[VAL_109:.*]] = arith.subi %[[VAL_108]], %[[VAL_33]] : index
+// CHECK:                 %[[VAL_110:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
+// CHECK:                 %[[VAL_111:.*]] = arith.addi %[[VAL_110]], %[[VAL_8]] : index
+// CHECK:                 %[[VAL_112:.*]] = arith.addi %[[VAL_110]], %[[VAL_5]] : index
+// CHECK:                 %[[VAL_113:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_111]]] : memref<11xindex>
+// CHECK:                 %[[VAL_114:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_112]]] : memref<11xindex>
+// CHECK:                 %[[VAL_115:.*]] = arith.addi %[[VAL_88]], %[[VAL_5]] : index
+// CHECK:                 %[[VAL_116:.*]]:2 = scf.while (%[[VAL_117:.*]] = %[[VAL_113]], %[[VAL_118:.*]] = %[[VAL_106]]) : (index, i32) -> (index, i32) {
+// CHECK:                   %[[VAL_119:.*]] = arith.cmpi ult, %[[VAL_117]], %[[VAL_114]] : index
+// CHECK:                   %[[VAL_120:.*]] = scf.if %[[VAL_119]] -> (i1) {
+// CHECK:                     %[[VAL_121:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_117]]] : memref<?xindex>
+// CHECK:                     %[[VAL_122:.*]] = arith.cmpi ult, %[[VAL_121]], %[[VAL_115]] : index
+// CHECK:                     scf.yield %[[VAL_122]] : i1
 // CHECK:                   } else {
 // CHECK:                     scf.yield %[[VAL_11]] : i1
 // CHECK:                   }
-// CHECK:                   scf.condition(%[[VAL_118:.*]]) %[[VAL_112]], %[[VAL_113]] : index, i32
+// CHECK:                   scf.condition(%[[VAL_123:.*]]) %[[VAL_117]], %[[VAL_118]] : index, i32
 // CHECK:                 } do {
-// CHECK:                 ^bb0(%[[VAL_119:.*]]: index, %[[VAL_120:.*]]: i32):
-// CHECK:                   %[[VAL_121:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_119]]] : memref<?xindex>
-// CHECK:                   %[[VAL_122:.*]] = arith.subi %[[VAL_121]], %[[VAL_83]] : index
-// CHECK:                   %[[VAL_123:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_119]]] : memref<?xi32>
-// CHECK:                   %[[VAL_124:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_104]], %[[VAL_122]]] : tensor<3x3xi32>
-// CHECK:                   %[[VAL_125:.*]] = arith.muli %[[VAL_123]], %[[VAL_124]] : i32
-// CHECK:                   %[[VAL_126:.*]] = arith.addi %[[VAL_120]], %[[VAL_125]] : i32
-// CHECK:                   %[[VAL_127:.*]] = arith.addi %[[VAL_119]], %[[VAL_9]] : index
-// CHECK:                   scf.yield %[[VAL_127]], %[[VAL_126]] : index, i32
+// CHECK:                 ^bb0(%[[VAL_124:.*]]: index, %[[VAL_125:.*]]: i32):
+// CHECK:                   %[[VAL_126:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_124]]] : memref<?xindex>
+// CHECK:                   %[[VAL_127:.*]] = arith.subi %[[VAL_126]], %[[VAL_88]] : index
+// CHECK:                   %[[VAL_128:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_124]]] : memref<?xi32>
+// CHECK:                   %[[VAL_129:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_109]], %[[VAL_127]]] : tensor<3x3xi32>
+// CHECK:                   %[[VAL_130:.*]] = arith.muli %[[VAL_128]], %[[VAL_129]] : i32
+// CHECK:                   %[[VAL_131:.*]] = arith.addi %[[VAL_125]], %[[VAL_130]] : i32
+// CHECK:                   %[[VAL_132:.*]] = arith.addi %[[VAL_124]], %[[VAL_7]] : index
+// CHECK:                   scf.yield %[[VAL_132]], %[[VAL_131]] : index, i32
 // CHECK:                 }
-// CHECK:                 %[[VAL_128:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_9]]] : memref<8xindex>
-// CHECK:                 %[[VAL_129:.*]] = arith.addi %[[VAL_128]], %[[VAL_7]] : index
-// CHECK:                 memref.store %[[VAL_129]], %[[VAL_18]]{{\[}}%[[VAL_9]]] : memref<8xindex>
-// CHECK:                 %[[VAL_130:.*]] = arith.addi %[[VAL_100]], %[[VAL_9]] : index
-// CHECK:                 scf.yield %[[VAL_130]], %[[VAL_131:.*]]#1, %[[VAL_2]] : index, i32, i1
+// CHECK:                 %[[VAL_133:.*]] = arith.addi %[[VAL_105]], %[[VAL_7]] : index
+// CHECK:                 %[[VAL_134:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
+// CHECK:                 %[[VAL_135:.*]] = arith.addi %[[VAL_134]], %[[VAL_5]] : index
+// CHECK:                 memref.store %[[VAL_135]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
+// CHECK:                 scf.yield %[[VAL_133]], %[[VAL_136:.*]]#1, %[[VAL_2]] : index, i32, i1
 // CHECK:               }
-// CHECK:               %[[VAL_132:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_9]]] : memref<4xindex>
-// CHECK:               %[[VAL_133:.*]] = arith.addi %[[VAL_132]], %[[VAL_7]] : index
-// CHECK:               memref.store %[[VAL_133]], %[[VAL_19]]{{\[}}%[[VAL_9]]] : memref<4xindex>
-// CHECK:               %[[VAL_134:.*]] = scf.if %[[VAL_135:.*]]#2 -> (tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>) {
-// CHECK:                 %[[VAL_136:.*]] = sparse_tensor.insert %[[VAL_135]]#1 into %[[VAL_84]]{{\[}}%[[VAL_33]], %[[VAL_83]]] : tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
-// CHECK:                 scf.yield %[[VAL_136]] : tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:               %[[VAL_137:.*]] = scf.if %[[VAL_138:.*]]#2 -> (tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>) {
+// CHECK:                 %[[VAL_139:.*]] = sparse_tensor.insert %[[VAL_138]]#1 into %[[VAL_89]]{{\[}}%[[VAL_33]], %[[VAL_88]]] : tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:                 scf.yield %[[VAL_139]] : tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:               } else {
-// CHECK:                 scf.yield %[[VAL_84]] : tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:                 scf.yield %[[VAL_89]] : tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:               }
-// CHECK:               memref.store %[[VAL_8]], %[[VAL_19]]{{\[}}%[[VAL_9]]] : memref<4xindex>
-// CHECK:               memref.store %[[VAL_8]], %[[VAL_18]]{{\[}}%[[VAL_9]]] : memref<8xindex>
-// CHECK:               %[[VAL_137:.*]] = arith.cmpi ugt, %[[VAL_82]], %[[VAL_83]] : index
-// CHECK:               %[[VAL_138:.*]]:3 = scf.if %[[VAL_137]] -> (index, i1, index) {
-// CHECK:                 %[[VAL_139:.*]] = arith.addi %[[VAL_83]], %[[VAL_9]] : index
-// CHECK:                 scf.yield %[[VAL_82]], %[[VAL_2]], %[[VAL_139]] : index, i1, index
+// CHECK:               memref.store %[[VAL_6]], %[[VAL_19]]{{\[}}%[[VAL_7]]] : memref<5xindex>
+// CHECK:               memref.store %[[VAL_6]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
+// CHECK:               %[[VAL_140:.*]] = arith.cmpi ugt, %[[VAL_87]], %[[VAL_88]] : index
+// CHECK:               %[[VAL_141:.*]]:3 = scf.if %[[VAL_140]] -> (index, i1, index) {
+// CHECK:                 %[[VAL_142:.*]] = arith.addi %[[VAL_88]], %[[VAL_7]] : index
+// CHECK:                 scf.yield %[[VAL_87]], %[[VAL_2]], %[[VAL_142]] : index, i1, index
 // CHECK:               } else {
-// CHECK:                 %[[VAL_140:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<8xindex>
-// CHECK:                 %[[VAL_141:.*]]:2 = scf.for %[[VAL_142:.*]] = %[[VAL_7]] to %[[VAL_140]] step %[[VAL_7]] iter_args(%[[VAL_143:.*]] = %[[VAL_4]], %[[VAL_144:.*]] = %[[VAL_11]]) -> (index, i1) {
-// CHECK:                   %[[VAL_145:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_142]]] : memref<8xindex>
-// CHECK:                   %[[VAL_146:.*]] = arith.addi %[[VAL_142]], %[[VAL_9]] : index
-// CHECK:                   %[[VAL_147:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_146]]] : memref<8xindex>
-// CHECK:                   %[[VAL_148:.*]] = arith.cmpi ult, %[[VAL_145]], %[[VAL_147]] : index
-// CHECK:                   %[[VAL_149:.*]] = scf.if %[[VAL_148]] -> (index) {
-// CHECK:                     %[[VAL_150:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_145]]] : memref<?xindex>
-// CHECK:                     %[[VAL_151:.*]] = arith.cmpi eq, %[[VAL_150]], %[[VAL_82]] : index
-// CHECK:                     %[[VAL_152:.*]] = scf.if %[[VAL_151]] -> (index) {
-// CHECK:                       %[[VAL_153:.*]] = arith.addi %[[VAL_145]], %[[VAL_9]] : index
-// CHECK:                       memref.store %[[VAL_153]], %[[VAL_18]]{{\[}}%[[VAL_142]]] : memref<8xindex>
-// CHECK:                       scf.yield %[[VAL_153]] : index
+// CHECK:                 %[[VAL_143:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_6]]] : memref<11xindex>
+// CHECK:                 %[[VAL_144:.*]]:2 = scf.for %[[VAL_145:.*]] = %[[VAL_8]] to %[[VAL_143]] step %[[VAL_5]] iter_args(%[[VAL_146:.*]] = %[[VAL_4]], %[[VAL_147:.*]] = %[[VAL_11]]) -> (index, i1) {
+// CHECK:                   %[[VAL_148:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_145]]] : memref<11xindex>
+// CHECK:                   %[[VAL_149:.*]] = arith.addi %[[VAL_145]], %[[VAL_7]] : index
+// CHECK:                   %[[VAL_150:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_149]]] : memref<11xindex>
+// CHECK:                   %[[VAL_151:.*]] = arith.cmpi ult, %[[VAL_148]], %[[VAL_150]] : index
+// CHECK:                   %[[VAL_152:.*]] = scf.if %[[VAL_151]] -> (index) {
+// CHECK:                     %[[VAL_153:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_148]]] : memref<?xindex>
+// CHECK:                     %[[VAL_154:.*]] = arith.cmpi eq, %[[VAL_153]], %[[VAL_87]] : index
+// CHECK:                     %[[VAL_155:.*]] = scf.if %[[VAL_154]] -> (index) {
+// CHECK:                       %[[VAL_156:.*]] = arith.addi %[[VAL_148]], %[[VAL_7]] : index
+// CHECK:                       memref.store %[[VAL_156]], %[[VAL_18]]{{\[}}%[[VAL_145]]] : memref<11xindex>
+// CHECK:                       scf.yield %[[VAL_156]] : index
 // CHECK:                     } else {
-// CHECK:                       scf.yield %[[VAL_145]] : index
+// CHECK:                       scf.yield %[[VAL_148]] : index
 // CHECK:                     }
-// CHECK:                     scf.yield %[[VAL_154:.*]] : index
+// CHECK:                     scf.yield %[[VAL_157:.*]] : index
 // CHECK:                   } else {
-// CHECK:                     scf.yield %[[VAL_145]] : index
+// CHECK:                     scf.yield %[[VAL_148]] : index
 // CHECK:                   }
-// CHECK:                   %[[VAL_155:.*]] = arith.cmpi ult, %[[VAL_156:.*]], %[[VAL_147]] : index
-// CHECK:                   %[[VAL_157:.*]] = scf.if %[[VAL_155]] -> (index) {
-// CHECK:                     %[[VAL_158:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_156]]] : memref<?xindex>
-// CHECK:                     scf.yield %[[VAL_158]] : index
+// CHECK:                   %[[VAL_158:.*]] = arith.cmpi ult, %[[VAL_159:.*]], %[[VAL_150]] : index
+// CHECK:                   %[[VAL_160:.*]] = scf.if %[[VAL_158]] -> (index) {
+// CHECK:                     %[[VAL_161:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_159]]] : memref<?xindex>
+// CHECK:                     scf.yield %[[VAL_161]] : index
 // CHECK:                   } else {
-// CHECK:                     scf.yield %[[VAL_143]] : index
+// CHECK:                     scf.yield %[[VAL_146]] : index
 // CHECK:                   }
-// CHECK:                   %[[VAL_159:.*]] = arith.ori %[[VAL_155]], %[[VAL_144]] : i1
-// CHECK:                   %[[VAL_160:.*]] = arith.cmpi ult, %[[VAL_161:.*]], %[[VAL_143]] : index
-// CHECK:                   %[[VAL_162:.*]] = arith.select %[[VAL_160]], %[[VAL_161]], %[[VAL_143]] : index
-// CHECK:                   scf.yield %[[VAL_162]], %[[VAL_159]] : index, i1
+// CHECK:                   %[[VAL_162:.*]] = arith.ori %[[VAL_158]], %[[VAL_147]] : i1
+// CHECK:                   %[[VAL_163:.*]] = arith.cmpi ult, %[[VAL_164:.*]], %[[VAL_146]] : index
+// CHECK:                   %[[VAL_165:.*]] = arith.select %[[VAL_163]], %[[VAL_164]], %[[VAL_146]] : index
+// CHECK:                   scf.yield %[[VAL_165]], %[[VAL_162]] : index, i1
 // CHECK:                 }
-// CHECK:                 %[[VAL_163:.*]] = arith.addi %[[VAL_164:.*]]#0, %[[VAL_9]] : index
-// CHECK:                 %[[VAL_165:.*]] = arith.addi %[[VAL_164]]#0, %[[VAL_3]] : index
-// CHECK:                 %[[VAL_166:.*]] = arith.cmpi uge, %[[VAL_163]], %[[VAL_5]] : index
-// CHECK:                 %[[VAL_167:.*]] = arith.select %[[VAL_166]], %[[VAL_165]], %[[VAL_8]] : index
-// CHECK:                 scf.yield %[[VAL_164]]#0, %[[VAL_164]]#1, %[[VAL_167]] : index, i1, index
+// CHECK:                 %[[VAL_166:.*]] = arith.addi %[[VAL_167:.*]]#0, %[[VAL_7]] : index
+// CHECK:                 %[[VAL_168:.*]] = arith.addi %[[VAL_167]]#0, %[[VAL_3]] : index
+// CHECK:                 %[[VAL_169:.*]] = arith.cmpi uge, %[[VAL_166]], %[[VAL_5]] : index
+// CHECK:                 %[[VAL_170:.*]] = arith.select %[[VAL_169]], %[[VAL_168]], %[[VAL_6]] : index
+// CHECK:                 scf.yield %[[VAL_167]]#0, %[[VAL_167]]#1, %[[VAL_170]] : index, i1, index
 // CHECK:               }
-// CHECK:               %[[VAL_168:.*]] = arith.addi %[[VAL_83]], %[[VAL_9]] : index
-// CHECK:               %[[VAL_169:.*]] = arith.cmpi ugt, %[[VAL_170:.*]]#2, %[[VAL_168]] : index
-// CHECK:               %[[VAL_171:.*]] = arith.select %[[VAL_169]], %[[VAL_170]]#2, %[[VAL_168]] : index
-// CHECK:               %[[VAL_172:.*]] = arith.addi %[[VAL_171]], %[[VAL_5]] : index
-// CHECK:               %[[VAL_173:.*]] = arith.cmpi ule, %[[VAL_172]], %[[VAL_4]] : index
-// CHECK:               %[[VAL_174:.*]] = arith.andi %[[VAL_170]]#1, %[[VAL_173]] : i1
-// CHECK:               scf.yield %[[VAL_174]], %[[VAL_170]]#0, %[[VAL_171]], %[[VAL_175:.*]] : i1, index, index, tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:               %[[VAL_171:.*]] = arith.addi %[[VAL_88]], %[[VAL_7]] : index
+// CHECK:               %[[VAL_172:.*]] = arith.cmpi ugt, %[[VAL_173:.*]]#2, %[[VAL_171]] : index
+// CHECK:               %[[VAL_174:.*]] = arith.select %[[VAL_172]], %[[VAL_173]]#2, %[[VAL_171]] : index
+// CHECK:               %[[VAL_175:.*]] = arith.addi %[[VAL_174]], %[[VAL_5]] : index
+// CHECK:               %[[VAL_176:.*]] = arith.cmpi ule, %[[VAL_175]], %[[VAL_4]] : index
+// CHECK:               %[[VAL_177:.*]] = arith.andi %[[VAL_173]]#1, %[[VAL_176]] : i1
+// CHECK:               scf.yield %[[VAL_177]], %[[VAL_173]]#0, %[[VAL_174]], %[[VAL_178:.*]] : i1, index, index, tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:             }
-// CHECK:             memref.store %[[VAL_8]], %[[VAL_19]]{{\[}}%[[VAL_9]]] : memref<4xindex>
-// CHECK:             %[[VAL_176:.*]] = arith.cmpi ugt, %[[VAL_32]], %[[VAL_33]] : index
-// CHECK:             %[[VAL_177:.*]]:3 = scf.if %[[VAL_176]] -> (index, i1, index) {
-// CHECK:               %[[VAL_178:.*]] = arith.addi %[[VAL_33]], %[[VAL_9]] : index
-// CHECK:               scf.yield %[[VAL_32]], %[[VAL_2]], %[[VAL_178]] : index, i1, index
+// CHECK:             memref.store %[[VAL_6]], %[[VAL_19]]{{\[}}%[[VAL_7]]] : memref<5xindex>
+// CHECK:             %[[VAL_179:.*]] = arith.cmpi ugt, %[[VAL_32]], %[[VAL_33]] : index
+// CHECK:             %[[VAL_180:.*]]:3 = scf.if %[[VAL_179]] -> (index, i1, index) {
+// CHECK:               %[[VAL_181:.*]] = arith.addi %[[VAL_33]], %[[VAL_7]] : index
+// CHECK:               scf.yield %[[VAL_32]], %[[VAL_2]], %[[VAL_181]] : index, i1, index
 // CHECK:             } else {
-// CHECK:               %[[VAL_179:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_8]]] : memref<4xindex>
-// CHECK:               %[[VAL_180:.*]]:2 = scf.for %[[VAL_181:.*]] = %[[VAL_7]] to %[[VAL_179]] step %[[VAL_7]] iter_args(%[[VAL_182:.*]] = %[[VAL_4]], %[[VAL_183:.*]] = %[[VAL_11]]) -> (index, i1) {
-// CHECK:                 %[[VAL_184:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_181]]] : memref<4xindex>
-// CHECK:                 %[[VAL_185:.*]] = arith.addi %[[VAL_181]], %[[VAL_9]] : index
-// CHECK:                 %[[VAL_186:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_185]]] : memref<4xindex>
-// CHECK:                 %[[VAL_187:.*]] = arith.cmpi ult, %[[VAL_184]], %[[VAL_186]] : index
-// CHECK:                 %[[VAL_188:.*]] = scf.if %[[VAL_187]] -> (index) {
-// CHECK:                   %[[VAL_189:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_184]]] : memref<?xindex>
-// CHECK:                   %[[VAL_190:.*]] = arith.cmpi eq, %[[VAL_189]], %[[VAL_32]] : index
-// CHECK:                   %[[VAL_191:.*]] = scf.if %[[VAL_190]] -> (index) {
-// CHECK:                     %[[VAL_192:.*]] = arith.addi %[[VAL_184]], %[[VAL_9]] : index
-// CHECK:                     memref.store %[[VAL_192]], %[[VAL_19]]{{\[}}%[[VAL_181]]] : memref<4xindex>
-// CHECK:                     scf.yield %[[VAL_192]] : index
+// CHECK:               %[[VAL_182:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_6]]] : memref<5xindex>
+// CHECK:               %[[VAL_183:.*]]:2 = scf.for %[[VAL_184:.*]] = %[[VAL_8]] to %[[VAL_182]] step %[[VAL_5]] iter_args(%[[VAL_185:.*]] = %[[VAL_4]], %[[VAL_186:.*]] = %[[VAL_11]]) -> (index, i1) {
+// CHECK:                 %[[VAL_187:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_184]]] : memref<5xindex>
+// CHECK:                 %[[VAL_188:.*]] = arith.addi %[[VAL_184]], %[[VAL_7]] : index
+// CHECK:                 %[[VAL_189:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_188]]] : memref<5xindex>
+// CHECK:                 %[[VAL_190:.*]] = arith.cmpi ult, %[[VAL_187]], %[[VAL_189]] : index
+// CHECK:                 %[[VAL_191:.*]] = scf.if %[[VAL_190]] -> (index) {
+// CHECK:                   %[[VAL_192:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_187]]] : memref<?xindex>
+// CHECK:                   %[[VAL_193:.*]] = arith.cmpi eq, %[[VAL_192]], %[[VAL_32]] : index
+// CHECK:                   %[[VAL_194:.*]] = scf.if %[[VAL_193]] -> (index) {
+// CHECK:                     %[[VAL_195:.*]] = arith.addi %[[VAL_187]], %[[VAL_7]] : index
+// CHECK:                     memref.store %[[VAL_195]], %[[VAL_19]]{{\[}}%[[VAL_184]]] : memref<5xindex>
+// CHECK:                     scf.yield %[[VAL_195]] : index
 // CHECK:                   } else {
-// CHECK:                     scf.yield %[[VAL_184]] : index
+// CHECK:                     scf.yield %[[VAL_187]] : index
 // CHECK:                   }
-// CHECK:                   scf.yield %[[VAL_193:.*]] : index
+// CHECK:                   scf.yield %[[VAL_196:.*]] : index
 // CHECK:                 } else {
-// CHECK:                   scf.yield %[[VAL_184]] : index
+// CHECK:                   scf.yield %[[VAL_187]] : index
 // CHECK:                 }
-// CHECK:                 %[[VAL_194:.*]] = arith.cmpi ult, %[[VAL_195:.*]], %[[VAL_186]] : index
-// CHECK:                 %[[VAL_196:.*]] = scf.if %[[VAL_194]] -> (index) {
-// CHECK:                   %[[VAL_197:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_195]]] : memref<?xindex>
-// CHECK:                   scf.yield %[[VAL_197]] : index
+// CHECK:                 %[[VAL_197:.*]] = arith.cmpi ult, %[[VAL_198:.*]], %[[VAL_189]] : index
+// CHECK:                 %[[VAL_199:.*]] = scf.if %[[VAL_197]] -> (index) {
+// CHECK:                   %[[VAL_200:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_198]]] : memref<?xindex>
+// CHECK:                   scf.yield %[[VAL_200]] : index
 // CHECK:                 } else {
-// CHECK:                   scf.yield %[[VAL_182]] : index
+// CHECK:                   scf.yield %[[VAL_185]] : index
 // CHECK:                 }
-// CHECK:                 %[[VAL_198:.*]] = arith.ori %[[VAL_194]], %[[VAL_183]] : i1
-// CHECK:                 %[[VAL_199:.*]] = arith.cmpi ult, %[[VAL_200:.*]], %[[VAL_182]] : index
-// CHECK:                 %[[VAL_201:.*]] = arith.select %[[VAL_199]], %[[VAL_200]], %[[VAL_182]] : index
-// CHECK:                 scf.yield %[[VAL_201]], %[[VAL_198]] : index, i1
+// CHECK:                 %[[VAL_201:.*]] = arith.ori %[[VAL_197]], %[[VAL_186]] : i1
+// CHECK:                 %[[VAL_202:.*]] = arith.cmpi ult, %[[VAL_203:.*]], %[[VAL_185]] : index
+// CHECK:                 %[[VAL_204:.*]] = arith.select %[[VAL_202]], %[[VAL_203]], %[[VAL_185]] : index
+// CHECK:                 scf.yield %[[VAL_204]], %[[VAL_201]] : index, i1
 // CHECK:               }
-// CHECK:               %[[VAL_202:.*]] = arith.addi %[[VAL_203:.*]]#0, %[[VAL_9]] : index
-// CHECK:               %[[VAL_204:.*]] = arith.addi %[[VAL_203]]#0, %[[VAL_3]] : index
-// CHECK:               %[[VAL_205:.*]] = arith.cmpi uge, %[[VAL_202]], %[[VAL_5]] : index
-// CHECK:               %[[VAL_206:.*]] = arith.select %[[VAL_205]], %[[VAL_204]], %[[VAL_8]] : index
-// CHECK:               scf.yield %[[VAL_203]]#0, %[[VAL_203]]#1, %[[VAL_206]] : index, i1, index
+// CHECK:               %[[VAL_205:.*]] = arith.addi %[[VAL_206:.*]]#0, %[[VAL_7]] : index
+// CHECK:               %[[VAL_207:.*]] = arith.addi %[[VAL_206]]#0, %[[VAL_3]] : index
+// CHECK:               %[[VAL_208:.*]] = arith.cmpi uge, %[[VAL_205]], %[[VAL_5]] : index
+// CHECK:               %[[VAL_209:.*]] = arith.select %[[VAL_208]], %[[VAL_207]], %[[VAL_6]] : index
+// CHECK:               scf.yield %[[VAL_206]]#0, %[[VAL_206]]#1, %[[VAL_209]] : index, i1, index
 // CHECK:             }
-// CHECK:             %[[VAL_207:.*]] = arith.addi %[[VAL_33]], %[[VAL_9]] : index
-// CHECK:             %[[VAL_208:.*]] = arith.cmpi ugt, %[[VAL_209:.*]]#2, %[[VAL_207]] : index
-// CHECK:             %[[VAL_210:.*]] = arith.select %[[VAL_208]], %[[VAL_209]]#2, %[[VAL_207]] : index
-// CHECK:             %[[VAL_211:.*]] = arith.addi %[[VAL_210]], %[[VAL_5]] : index
-// CHECK:             %[[VAL_212:.*]] = arith.cmpi ule, %[[VAL_211]], %[[VAL_4]] : index
-// CHECK:             %[[VAL_213:.*]] = arith.andi %[[VAL_209]]#1, %[[VAL_212]] : i1
-// CHECK:             scf.yield %[[VAL_213]], %[[VAL_209]]#0, %[[VAL_210]], %[[VAL_214:.*]]#2 : i1, index, index, tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:             %[[VAL_210:.*]] = arith.addi %[[VAL_33]], %[[VAL_7]] : index
+// CHECK:             %[[VAL_211:.*]] = arith.cmpi ugt, %[[VAL_212:.*]]#2, %[[VAL_210]] : index
+// CHECK:             %[[VAL_213:.*]] = arith.select %[[VAL_211]], %[[VAL_212]]#2, %[[VAL_210]] : index
+// CHECK:             %[[VAL_214:.*]] = arith.addi %[[VAL_213]], %[[VAL_5]] : index
+// CHECK:             %[[VAL_215:.*]] = arith.cmpi ule, %[[VAL_214]], %[[VAL_4]] : index
+// CHECK:             %[[VAL_216:.*]] = arith.andi %[[VAL_212]]#1, %[[VAL_215]] : i1
+// CHECK:             scf.yield %[[VAL_216]], %[[VAL_212]]#0, %[[VAL_213]], %[[VAL_217:.*]]#2 : i1, index, index, tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:           }
-// CHECK:           %[[VAL_215:.*]] = sparse_tensor.load %[[VAL_216:.*]]#2 hasInserts : tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
-// CHECK:           return %[[VAL_215]] : tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:           %[[VAL_218:.*]] = sparse_tensor.load %[[VAL_219:.*]]#2 hasInserts : tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:           return %[[VAL_218]] : tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:         }
 func.func @conv2d_all_sparse_CSR(%arg0: tensor<8x8xi32, #DCSR>,
                                  %arg1: tensor<3x3xi32>) -> tensor<6x6xi32, #DCSR> {

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/dual_sparse_conv_2d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/dual_sparse_conv_2d.mlir
new file mode 100644
index 00000000000000..713c644fd4bc03
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/dual_sparse_conv_2d.mlir
@@ -0,0 +1,231 @@
+// DEFINE: %{option} = "enable-runtime-library=true enable-index-reduction=true"
+// DEFINE: %{compile} = mlir-opt %s --sparse-compiler=%{option}
+// DEFINE: %{run} = mlir-cpu-runner \
+// DEFINE:  -e entry -entry-point-result=void  \
+// DEFINE:  -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils | \
+// DEFINE: FileCheck %s
+// The first run compiles with the sparse runtime support library enabled.
+// RUN: %{compile} | %{run}
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{option} = "enable-runtime-library=false enable-index-reduction=true"
+// RUN: %{compile} | %{run}
+//
+// Do the same run, but now with direct IR generation and vectorization.
+// REDEFINE: %{option} = "enable-runtime-library=false enable-index-reduction=true vl=2 reassociate-fp-reductions=true enable-index-optimizations=true"
+// RUN: %{compile} | %{run}
+
+// Do the same run, but now with direct IR generation and, if available, VLA
+// vectorization.
+// REDEFINE: %{option} = "enable-runtime-library=false vl=4 enable-index-reduction=true enable-arm-sve=%ENABLE_VLA"
+// REDEFINE: %{run} = %lli_host_or_aarch64_cmd \
+// REDEFINE:   --entry-function=entry_lli \
+// REDEFINE:   --extra-module=%S/Inputs/main_for_lli.ll \
+// REDEFINE:   %VLA_ARCH_ATTR_OPTIONS \
+// REDEFINE:   --dlopen=%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext | \
+// REDEFINE: FileCheck %s
+// RUN: %{compile} | mlir-translate -mlir-to-llvmir | %{run}
+// Encodings under test; CSC permutes the levels so that columns are outermost.
+#DCSR = #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>
+#CSR = #sparse_tensor.encoding<{lvlTypes = ["dense", "compressed"]}>
+#CDR = #sparse_tensor.encoding<{lvlTypes = ["compressed", "dense"]}>
+#CSC = #sparse_tensor.encoding<{
+  lvlTypes = [ "dense", "compressed" ],
+  dimToLvl = affine_map<(i,j) -> (j,i)>
+}>
+
+// An example of a 2D convolution where both the input and the filter are sparse.
+module {
+  // Dense reference kernel; every sparse variant below must reproduce its result.
+  func.func @conv2d(%input:  tensor<8x8xi32>,
+               %filter: tensor<3x3xi32>,
+               %output: tensor<6x6xi32>) -> tensor<6x6xi32> {
+    %0 = linalg.conv_2d
+      ins  (%input, %filter: tensor<8x8xi32>, tensor<3x3xi32>)
+      outs (%output: tensor<6x6xi32>) -> tensor<6x6xi32>
+    return %0 : tensor<6x6xi32>
+  }
+  // All-sparse variants: input, filter, and freshly allocated output share one encoding.
+  func.func @conv2d_all_sparse_DCSR(%input:  tensor<8x8xi32, #DCSR>,
+               %filter: tensor<3x3xi32, #DCSR>) -> tensor<6x6xi32, #DCSR> {
+    %s = bufferization.alloc_tensor() : tensor<6x6xi32, #DCSR>
+    %0 = linalg.conv_2d
+      ins  (%input, %filter: tensor<8x8xi32, #DCSR>, tensor<3x3xi32, #DCSR>)
+      outs (%s: tensor<6x6xi32, #DCSR>) -> tensor<6x6xi32, #DCSR>
+    return %0 : tensor<6x6xi32, #DCSR>
+  }
+
+  func.func @conv2d_all_sparse_CSR(%input:  tensor<8x8xi32, #CSR>,
+               %filter: tensor<3x3xi32, #CSR>) -> tensor<6x6xi32, #CSR> {
+    %s = bufferization.alloc_tensor() : tensor<6x6xi32, #CSR>
+    %0 = linalg.conv_2d
+      ins  (%input, %filter: tensor<8x8xi32, #CSR>, tensor<3x3xi32, #CSR>)
+      outs (%s: tensor<6x6xi32, #CSR>) -> tensor<6x6xi32, #CSR>
+    return %0 : tensor<6x6xi32, #CSR>
+  }
+
+  func.func @conv2d_all_sparse_CD(%input:  tensor<8x8xi32, #CDR>,
+               %filter: tensor<3x3xi32, #CDR>) -> tensor<6x6xi32, #CDR> {
+    %s = bufferization.alloc_tensor() : tensor<6x6xi32, #CDR>
+    %0 = linalg.conv_2d
+      ins  (%input, %filter: tensor<8x8xi32, #CDR>, tensor<3x3xi32, #CDR>)
+      outs (%s: tensor<6x6xi32, #CDR>) -> tensor<6x6xi32, #CDR>
+    return %0 : tensor<6x6xi32, #CDR>
+  }
+
+  func.func @conv2d_all_sparse_CSC(%input:  tensor<8x8xi32, #CSC>,
+               %filter: tensor<3x3xi32, #CSC>) -> tensor<6x6xi32, #CSC> {
+    %s = bufferization.alloc_tensor() : tensor<6x6xi32, #CSC>
+    %0 = linalg.conv_2d
+      ins  (%input, %filter: tensor<8x8xi32, #CSC>, tensor<3x3xi32, #CSC>)
+      outs (%s: tensor<6x6xi32, #CSC>) -> tensor<6x6xi32, #CSC>
+    return %0 : tensor<6x6xi32, #CSC>
+  }
+  // Driver: converts the dense data into each encoding, runs every kernel, prints results.
+  func.func @entry() {
+    %c0 = arith.constant 0 : index
+    %i0 = arith.constant 0 : i32
+
+    // A typical edge detection filter.
+    %filter = arith.constant dense<[
+      [  1,  0, -1 ],
+      [  0,  0,  0 ],
+      [ -1,  0,  1 ]
+    ]> : tensor<3x3xi32>
+    %sparse_filter_DCSR = sparse_tensor.convert %filter
+      : tensor<3x3xi32> to tensor<3x3xi32, #DCSR>
+    %sparse_filter_CSR = sparse_tensor.convert %filter
+      : tensor<3x3xi32> to tensor<3x3xi32, #CSR>
+    %sparse_filter_CD = sparse_tensor.convert %filter
+      : tensor<3x3xi32> to tensor<3x3xi32, #CDR>
+    %sparse_filter_CSC = sparse_tensor.convert %filter
+      : tensor<3x3xi32> to tensor<3x3xi32, #CSC>
+
+    %input = arith.constant dense<[
+      [  1,  2,  3,  4,  0,  6,  7,  8 ],
+      [  2,  2,  4,  4,  0,  0,  6,  8 ],
+      [  2,  2,  4,  4,  0,  0,  6,  8 ],
+      [  2,  2,  3,  4,  0,  0,  7,  8 ],
+      [  1,  3,  3,  4,  0,  0,  6,  8 ],
+      [  3,  2,  3,  4,  0,  0,  7,  8 ],
+      [  1,  3,  3,  4,  3,  6,  6,  8 ],
+      [  1,  3,  3,  4,  3,  0,  7,  8 ]
+    ]> : tensor<8x8xi32>
+    %sparse_input_DCSR = sparse_tensor.convert %input
+      : tensor<8x8xi32> to tensor<8x8xi32, #DCSR>
+    %sparse_input_CSR = sparse_tensor.convert %input
+      : tensor<8x8xi32> to tensor<8x8xi32, #CSR>
+    %sparse_input_CD = sparse_tensor.convert %input
+      : tensor<8x8xi32> to tensor<8x8xi32, #CDR>
+    %sparse_input_CSC = sparse_tensor.convert %input
+      : tensor<8x8xi32> to tensor<8x8xi32, #CSC>
+
+    // Call the dense reference kernel and the four all-sparse variants.
+    %output = arith.constant dense<0> : tensor<6x6xi32>
+    %0 = call @conv2d(%input, %filter, %output)
+       : (tensor<8x8xi32>,
+          tensor<3x3xi32>, tensor<6x6xi32>) -> tensor<6x6xi32>
+    %2 = call @conv2d_all_sparse_DCSR(%sparse_input_DCSR, %sparse_filter_DCSR)
+       : (tensor<8x8xi32, #DCSR>,
+          tensor<3x3xi32, #DCSR>) -> tensor<6x6xi32, #DCSR>
+    %3 = call @conv2d_all_sparse_CSR(%sparse_input_CSR, %sparse_filter_CSR)
+       : (tensor<8x8xi32, #CSR>,
+          tensor<3x3xi32, #CSR>) -> tensor<6x6xi32, #CSR>
+    %4 = call @conv2d_all_sparse_CD(%sparse_input_CD, %sparse_filter_CD)
+       : (tensor<8x8xi32, #CDR>,
+          tensor<3x3xi32, #CDR>) -> tensor<6x6xi32, #CDR>
+    %5 = call @conv2d_all_sparse_CSC(%sparse_input_CSC, %sparse_filter_CSC)
+       : (tensor<8x8xi32, #CSC>,
+          tensor<3x3xi32, #CSC>) -> tensor<6x6xi32, #CSC>
+
+
+    // Verify the dense reference output.
+    // Each sparse result printed afterwards must match these values.
+    // CHECK:    ( ( 0, 0, -1, -6, -1, 6 ),
+    // CHECK-SAME: ( -1, 0, 1, 0, 1, 0 ),
+    // CHECK-SAME: ( 0, -1, 1, 0, 0, 0 ),
+    // CHECK-SAME: ( -1, 0, 0, 0, 0, 0 ),
+    // CHECK-SAME: ( 0, 0, 3, 6, -3, -6 ),
+    // CHECK-SAME: ( 2, -1, 3, 0, -3, 0 ) )
+    //
+    %v = vector.transfer_read %0[%c0, %c0], %i0
+      : tensor<6x6xi32>, vector<6x6xi32>
+    vector.print %v : vector<6x6xi32>
+
+    // DCSR result (%2), converted back to a dense tensor for printing.
+    // Should be the same as the dense output.
+    // CHECK:    ( ( 0, 0, -1, -6, -1, 6 ),
+    // CHECK-SAME: ( -1, 0, 1, 0, 1, 0 ),
+    // CHECK-SAME: ( 0, -1, 1, 0, 0, 0 ),
+    // CHECK-SAME: ( -1, 0, 0, 0, 0, 0 ),
+    // CHECK-SAME: ( 0, 0, 3, 6, -3, -6 ),
+    // CHECK-SAME: ( 2, -1, 3, 0, -3, 0 ) )
+    //
+    %all_sparse_DCSR = sparse_tensor.convert %2
+      : tensor<6x6xi32, #DCSR> to tensor<6x6xi32>
+    %v2 = vector.transfer_read %all_sparse_DCSR[%c0, %c0], %i0
+      : tensor<6x6xi32>, vector<6x6xi32>
+    vector.print %v2 : vector<6x6xi32>
+
+    // CDR result (%4), converted back to a dense tensor for printing.
+    // Should be the same as the dense output.
+    // CHECK:    ( ( 0, 0, -1, -6, -1, 6 ),
+    // CHECK-SAME: ( -1, 0, 1, 0, 1, 0 ),
+    // CHECK-SAME: ( 0, -1, 1, 0, 0, 0 ),
+    // CHECK-SAME: ( -1, 0, 0, 0, 0, 0 ),
+    // CHECK-SAME: ( 0, 0, 3, 6, -3, -6 ),
+    // CHECK-SAME: ( 2, -1, 3, 0, -3, 0 ) )
+    //
+    %all_sparse_CD = sparse_tensor.convert %4
+      : tensor<6x6xi32, #CDR> to tensor<6x6xi32>
+    %v4 = vector.transfer_read %all_sparse_CD[%c0, %c0], %i0
+      : tensor<6x6xi32>, vector<6x6xi32>
+    vector.print %v4 : vector<6x6xi32>
+
+    // CSR result (%3), converted back to a dense tensor for printing.
+    // Should be the same as the dense output.
+    // CHECK:    ( ( 0, 0, -1, -6, -1, 6 ),
+    // CHECK-SAME: ( -1, 0, 1, 0, 1, 0 ),
+    // CHECK-SAME: ( 0, -1, 1, 0, 0, 0 ),
+    // CHECK-SAME: ( -1, 0, 0, 0, 0, 0 ),
+    // CHECK-SAME: ( 0, 0, 3, 6, -3, -6 ),
+    // CHECK-SAME: ( 2, -1, 3, 0, -3, 0 ) )
+    //
+    %all_sparse_CSR = sparse_tensor.convert %3
+      : tensor<6x6xi32, #CSR> to tensor<6x6xi32>
+    %v3 = vector.transfer_read %all_sparse_CSR[%c0, %c0], %i0
+      : tensor<6x6xi32>, vector<6x6xi32>
+    vector.print %v3 : vector<6x6xi32>
+
+    // CSC result (%5), converted back to a dense tensor for printing.
+    // Should be the same as the dense output.
+    // CHECK:    ( ( 0, 0, -1, -6, -1, 6 ),
+    // CHECK-SAME: ( -1, 0, 1, 0, 1, 0 ),
+    // CHECK-SAME: ( 0, -1, 1, 0, 0, 0 ),
+    // CHECK-SAME: ( -1, 0, 0, 0, 0, 0 ),
+    // CHECK-SAME: ( 0, 0, 3, 6, -3, -6 ),
+    // CHECK-SAME: ( 2, -1, 3, 0, -3, 0 ) )
+    //
+    %all_sparse_CSC = sparse_tensor.convert %5
+      : tensor<6x6xi32, #CSC> to tensor<6x6xi32>
+    %v5 = vector.transfer_read %all_sparse_CSC[%c0, %c0], %i0
+      : tensor<6x6xi32>, vector<6x6xi32>
+    vector.print %v5 : vector<6x6xi32>
+
+    // Release the sparse inputs, filters, and kernel results.
+    bufferization.dealloc_tensor %sparse_input_DCSR : tensor<8x8xi32, #DCSR>
+    bufferization.dealloc_tensor %sparse_input_CSR : tensor<8x8xi32, #CSR>
+    bufferization.dealloc_tensor %sparse_input_CSC : tensor<8x8xi32, #CSC>
+    bufferization.dealloc_tensor %sparse_input_CD : tensor<8x8xi32, #CDR>
+    bufferization.dealloc_tensor %sparse_filter_DCSR : tensor<3x3xi32, #DCSR>
+    bufferization.dealloc_tensor %sparse_filter_CSR : tensor<3x3xi32, #CSR>
+    bufferization.dealloc_tensor %sparse_filter_CD : tensor<3x3xi32, #CDR>
+    bufferization.dealloc_tensor %sparse_filter_CSC : tensor<3x3xi32, #CSC>
+
+    bufferization.dealloc_tensor %2 : tensor<6x6xi32, #DCSR>
+    bufferization.dealloc_tensor %3 : tensor<6x6xi32, #CSR>
+    bufferization.dealloc_tensor %4 : tensor<6x6xi32, #CDR>
+    bufferization.dealloc_tensor %5 : tensor<6x6xi32, #CSC>
+    return
+  }
+}


        


More information about the Mlir-commits mailing list