[Mlir-commits] [mlir] 5fd9d80 - [mlir][sparse] extend loop emitter to emit slice driven loops

Peiming Liu llvmlistbot at llvm.org
Wed Apr 12 20:29:46 PDT 2023


Author: Peiming Liu
Date: 2023-04-13T03:29:40Z
New Revision: 5fd9d801350d9232098d073ea04fd64db3cf8e1e

URL: https://github.com/llvm/llvm-project/commit/5fd9d801350d9232098d073ea04fd64db3cf8e1e
DIFF: https://github.com/llvm/llvm-project/commit/5fd9d801350d9232098d073ea04fd64db3cf8e1e.diff

LOG: [mlir][sparse] extend loop emitter to emit slice driven loops

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D142930

Added: 
    mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_2d_slice_based.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d_slice_based.mlir

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 5a6ffeadfa999..bdce4cbce876b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -409,6 +409,9 @@ class Merger {
     lvlTypes[t][i] = dlt;
     loopToLvl[t][i] = lvl;
     lvlToLoop[t][lvl] = i;
+    // TODO: Maybe we should favor a constant loop bound when there are multiple
+    // choices.
+    loopBounds[i] = std::make_pair(t, lvl);
   }
 
   using ForeachTensorLoopIdCallback = function_ref<void(
@@ -436,7 +439,7 @@ class Merger {
         // This must be an undefined level.
         assert(!optLvl.has_value());
         // Slice the tid along the dependent level to iterate current loop.
-        callback(b, t, loopToDependencies[loop(b)][t], lvlTp,
+        callback(b, t, getLoopDependentLevel(b), lvlTp,
                  /*isIdxReduc=*/true);
       } else {
         callback(b, t, optLvl, lvlTp, /*isIdxReduc=*/false);
@@ -447,11 +450,13 @@ class Merger {
   /// Sets whether the output tensor is sparse or not.
   void setHasSparseOut(bool s) { hasSparseOut = s; }
 
-  /// Establishes the two-way map that i <-> <t, lvl>.
-  void setLoopDependentTensorLevel(LoopId i, TensorId t, Level lvl) {
+  /// Establishes the map i -> <t, lvl, dlt> and its inverse <t, lvl> -> i.
+  void setLoopDependentTensorLevel(LoopId i, TensorId t, Level lvl,
+                                   DimLevelType dlt) {
     assert(isValidLoopId(i) && isValidLevel(t, lvl));
-    loopToDependencies[i][t] = lvl;
-    levelToDependentIdx[t][lvl].push_back(i);
+    assert(!loopToDependencies[i][t].has_value()); // must be the first def
+    loopToDependencies[i][t] = std::make_pair(lvl, dlt);
+    levelToDependentLoop[t][lvl].push_back(i);
   }
 
   /// Whether the loop has dependent slice.
@@ -464,7 +469,7 @@ class Merger {
   /// expression on t_l, e.g., A[i+j] => {i, j}
   std::vector<LoopId> &getDependentLoops(TensorId t, Level lvl) {
     assert(isValidLevel(t, lvl));
-    return levelToDependentIdx[t][lvl];
+    return levelToDependentLoop[t][lvl];
   }
 
   /// Returns the defining [tid, lvl] for the loop.
@@ -473,8 +478,8 @@ class Merger {
     return loopBounds[i];
   }
 
-  /// Checks whether the TensorLoopId represents a tensor level with
-  /// non-trivial index expression on it.
+  /// Checks whether the TensorLoopId represents a tensor level contains
+  /// non-trivial index expression.
   bool isLvlWithNonTrivialIdxExp(TensorLoopId b) const {
     const TensorId t = tensor(b);
     const LoopId i = loop(b);
@@ -482,6 +487,26 @@ class Merger {
     return loopToDependencies[i][t].has_value();
   }
 
+  /// Checks whether the TensorLoopId represents a compressed or singleton
+  /// tensor level that contains a non-trivial index expression.
+  bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const {
+    if (isLvlWithNonTrivialIdxExp(b)) {
+      auto dlt = getLoopDependentLevelType(b);
+      return isCompressedDLT(dlt) || isSingletonDLT(dlt);
+    }
+    return false;
+  }
+  /// Returns the level of tensor(b) whose index expression uses loop(b).
+  Level getLoopDependentLevel(TensorLoopId b) const {
+    assert(isLvlWithNonTrivialIdxExp(b));
+    return loopToDependencies[loop(b)][tensor(b)]->first;
+  }
+  /// Returns the level type recorded for that dependent level of tensor(b).
+  DimLevelType getLoopDependentLevelType(TensorLoopId b) const {
+    assert(isLvlWithNonTrivialIdxExp(b));
+    return loopToDependencies[loop(b)][tensor(b)]->second;
+  }
+
   /// Convenience getters to immediately access the stored nodes.
   /// These methods return `const&` because the underlying objects must
   /// not be mutated by client code.  The only exception is for mutating
@@ -578,6 +603,7 @@ class Merger {
     return i != detail::kInvalidId && i < numLoops;
   }
   bool isValidLevel(TensorId t, Level lvl) const {
+    assert(levelToDependentLoop[t].size() == lvlToLoop[t].size());
     return isValidTensorId(t) && lvl < lvlToLoop[t].size();
   }
   bool isValidExprId(ExprId e) const {
@@ -626,13 +652,15 @@ class Merger {
   // Map from a loop to its dependencies if any.
   // The dependencies of a loop is a set of (tensor, level) pairs.
   // It is currently only set for non-trivial index expressions.
-  // E.g., A[i+j] => i and j will have dependencies {A0} to indicate that
-  // i and j are used in the non-trivial index expression on A0.
-  std::vector<std::vector<std::optional<Level>>> loopToDependencies;
+  // E.g., A[i+j] => i and j will have dependencies {A0, dlt(A0)} to indicate
+  // that i and j are used in the non-trivial index expression on A0.
+  std::vector<std::vector<std::optional<std::pair<Level, DimLevelType>>>>
+      loopToDependencies;
+
   // The inverse map of ldxToDependencies from tensor level -> dependent loop
   // E.g., A[i+j], we have A0 => {i, j}, to indicate that A0 uses both {i, j}
   // to compute its indices.
-  std::vector<std::vector<std::vector<LoopId>>> levelToDependentIdx;
+  std::vector<std::vector<std::vector<LoopId>>> levelToDependentLoop;
 
   // Map from a loop to the [tid, lvl] pair that defines the loop boundary.
   std::vector<std::pair<TensorId, Level>> loopBounds;

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 51a9c65714f87..b624aaddd21df 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -25,9 +25,15 @@ using namespace mlir::sparse_tensor;
 // File local helper functions.
 //===----------------------------------------------------------------------===//
 
-/// Generates a position/coordinate load from the sparse storage scheme.
-/// Narrower data types need to be zero extended before casting the
-/// value into the `Index` type used for looping and indexing.
+// Shorthands that expand to code using the `builder` and `loc` in scope.
+#define CMPI(p, l, r)                                                          \
+  (builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::p, l, r)           \
+       .getResult())
+#define C_IDX(v) (constantIndex(builder, loc, v))
+
+/// Generates a position/coordinate load from the sparse storage scheme.
+/// Narrower data types need to be zero extended before casting the value
+/// into the `Index` type used for looping and indexing.
 static Value genIndexLoad(OpBuilder &builder, Location loc, Value mem,
                           Value s) {
   // For the scalar case, we simply zero extend narrower indices into 64-bit
@@ -70,6 +76,27 @@ static Value toSliceCrd(OpBuilder &builder, Location loc, Value crd,
   return crd;
 }
 
+/// Generates code to compute the *absolute* offset of the slice based on the
+/// provided minimum coordinate in the slice.
+/// E.g., when reducing d0 + d1 + d2, we need two slices to fully reduce the
+/// expression, i.e., s1 = slice(T, d0), s2 = slice(s1, d1). The *absolute*
+/// offset is the offset computed relative to the initial tensor T.
+///
+/// When isNonEmpty == false, the computed offset is meaningless and must not
+/// be used at runtime; the generated code currently yields 0 in that case
+/// (as it also does when minCrd < size). The comparison is unsigned (uge).
+///
+/// offset = isNonEmpty && minCrd >= size ? minCrd - size + 1 : 0;
+static Value offsetFromMinCoord(OpBuilder &builder, Location loc, Value minCrd,
+                                Value size, Value isNonEmpty) {
+  Value geSize = CMPI(uge, minCrd, size);
+  Value pred = builder.create<arith::AndIOp>(loc, isNonEmpty, geSize);
+  Value mp1 = builder.create<arith::AddIOp>(loc, minCrd, C_IDX(1));
+  Value mms = builder.create<arith::SubIOp>(loc, mp1, size);
+  // Absolute offset relative to the underlying tensor (0 when !pred).
+  return builder.create<arith::SelectOp>(loc, pred, mms, C_IDX(0));
+}
+
 /// Converts a coordinate relative to the underlying tensor to the coordinate
 /// relative to the slice, returns a extra reminder value
 // FIXME: that description says "tensorCrd -> sliceCrd"; but the function
@@ -102,21 +129,18 @@ LoopEmitter::genSliceLegitPredicate(OpBuilder &builder, Location loc, Value crd,
   // First, coord >= offset (skip the check if offset is known to be 0).
   if (auto staticOffset = enc.getStaticLvlSliceOffset(lvl);
       !(staticOffset.has_value() && *staticOffset == 0)) {
-    auto geOffset = builder.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::uge, crd, offset);
+    auto geOffset = CMPI(uge, crd, offset);
     conds.push_back(geOffset);
   }
 
   // Second, coord_in_slice < length
-  auto ltLength = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
-                                                newCrd, lvlSizes[tid][lvl]);
+  auto ltLength = CMPI(ult, newCrd, lvlSizes[tid][lvl]);
   conds.push_back(ltLength);
 
   // Third, rem == 0 (skip the check if stride is known to be 1).
   if (auto staticStride = enc.getStaticLvlSliceStride(lvl);
       !(staticStride.has_value() && *staticStride == 1)) {
-    auto fitStride = builder.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::eq, crdRem, constantIndex(builder, loc, 0));
+    auto fitStride = CMPI(eq, crdRem, C_IDX(0));
     conds.push_back(fitStride);
   }
 
@@ -134,7 +158,7 @@ LoopEmitter::genSliceLegitPredicate(OpBuilder &builder, Location loc, Value crd,
 
 Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, TensorId tid,
                               Level lvl, Value crd) {
-  Value pos = lvl == 0 ? constantIndex(builder, loc, 0) : posits[tid][lvl - 1];
+  Value pos = lvl == 0 ? C_IDX(0) : posits[tid][lvl - 1];
   Value mul = builder.create<arith::MulIOp>(loc, highs[tid][lvl], pos);
   if (isSparseSlices[tid])
     crd = toSliceCrd(builder, loc, crd, sliceOffsets[tid][lvl],
@@ -177,8 +201,7 @@ Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc,
       /*afterBuilder=*/
       [](OpBuilder &builder, Location loc, ValueRange ivs) {
         // pos ++
-        Value nextPos = builder.create<arith::AddIOp>(
-            loc, ivs[0], constantIndex(builder, loc, 1));
+        Value nextPos = builder.create<arith::AddIOp>(loc, ivs[0], C_IDX(1));
         builder.create<scf::YieldOp>(loc, nextPos);
       });
   // Return the segment high.
@@ -187,7 +210,7 @@ Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc,
 
 Value LoopEmitter::genSparseCrd(OpBuilder &builder, Location loc, TensorId tid,
                                 Level dstLvl) {
-  Value crd = constantIndex(builder, loc, 0);
+  Value crd = C_IDX(0);
   const auto reassoc = getCollapseReassociation(tid, dstLvl);
   const unsigned reassocSize = reassoc.size();
   for (unsigned i = 0; i < reassocSize; i++) {
@@ -244,8 +267,13 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
   this->loopStack.reserve(numLoops);
   this->loopSeqStack.reserve(numLoops);
 
+  // Index-reduction related fields.
   this->dependentLvlMap.assign(
       numTensors, std::vector<std::vector<std::pair<TensorId, Level>>>());
+  this->slicePosBuffer.assign(numTensors, std::vector<std::vector<Value>>());
+  this->sliceSizes.assign(numTensors, std::vector<std::vector<Value>>());
+  this->sliceStack.assign(numTensors, std::vector<SliceInfo>());
+  this->levelReducedDep.assign(numTensors, std::vector<unsigned>());
 
   // Initialize nested types of `TensorId`-indexed fields.
   for (TensorId tid = 0; tid < numTensors; tid++) {
@@ -288,16 +316,30 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
     coordinatesBuffers[tid].assign(lvlRank, Value());
     sliceOffsets[tid].assign(lvlRank, Value());
     sliceStrides[tid].assign(lvlRank, Value());
+
+    // Slice-driven loops related initialization.
+    levelReducedDep[tid].assign(lvlRank, 0);
     dependentLvlMap[tid].assign(lvlRank,
                                 std::vector<std::pair<TensorId, Level>>());
+    slicePosBuffer[tid].assign(lvlRank, std::vector<Value>());
+    sliceSizes[tid].assign(lvlRank, std::vector<Value>());
+    sliceStack[tid].emplace_back(/*minCrd=*/Value(),
+                                 /*offset=*/Value(), /*isNonEmpty*/ Value(),
+                                 std::nullopt, 0);
     if (dimGetter) {
       auto reassoc = collapseReassoc[tid];
       Level dstRank = reassoc ? reassoc.size() : lvlRank;
       for (Level l = 0; l < dstRank; l++) {
         dependentLvlMap[tid][l] = dimGetter(tid, l);
+        unsigned depends = dependentLvlMap[tid][l].size();
+        if (depends == 0)
+          continue;
         // TODO: View-base collapse and dependent index reduction are not
         // compatible right now.
-        assert(!reassoc || dependentLvlMap[tid][l].empty());
+        assert(!reassoc);
+        // We need `depends - 1` slices to fully  the affine expression.
+        sliceSizes[tid][l].assign(depends - 1, nullptr);
+        slicePosBuffer[tid][l].assign(depends - 1, nullptr);
       }
     }
   }
@@ -398,6 +440,31 @@ void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
     // some loop preparation from tensor iteration, but will also (undesirably)
     // hoist the code ouside if-conditions.
   }
+
+  Type indexType = builder.getIndexType();
+  Value c0 = constantZero(builder, loc, indexType);
+  for (TensorId t = 0, e = tensors.size(); t < e; t++) {
+    auto rtp = tensors[t].getType().dyn_cast<RankedTensorType>();
+    if (!rtp)
+      continue;
+
+    Level lvlRank = SparseTensorType(rtp).getLvlRank();
+    for (Level lvl = 0; lvl < lvlRank; lvl++) {
+      if (!dependentLvlMap[t][lvl].empty()) {
+        ArrayRef<std::pair<TensorId, Level>> depLvls = dependentLvlMap[t][lvl];
+        // Needs at least two operands to form a non-trivial affine expression.
+        assert(depLvls.size() > 1);
+
+        Value size = c0;
+        for (unsigned e = depLvls.size() - 1; e >= 1; e--) {
+          auto [dt, dd] = depLvls[e];
+          size = builder.create<arith::AddIOp>(loc, size, lvlSizes[dt][dd]);
+          sliceSizes[t][lvl][e - 1] = size;
+        }
+      }
+    }
+  }
+  localInsertPos = builder.getInsertionPoint()->getPrevNode();
 }
 
 void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
@@ -405,12 +472,47 @@ void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
                                   ArrayRef<Level> lvls) {
   // TODO: sort
   assert(loopSeqStack.size() == loopStack.size());
-  // Universal Index starts from 0.
-  loopSeqStack.emplace_back(constantIndex(builder, loc, 0));
   // Prepares for all the tensors used in the current loop sequence.
-  assert(tids.size() == lvls.size());
-  for (auto [tid, lvl] : llvm::zip(tids, lvls))
-    prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
+  std::vector<std::tuple<TensorId, Level, bool>> slicedTids;
+  for (auto [tid, lvl] : llvm::zip(tids, lvls)) {
+    if (!dependentLvlMap[tid][lvl].empty()) {
+      bool fullyRed = genSliceBegin(builder, loc, tid, lvl);
+      slicedTids.emplace_back(tid, lvl, fullyRed);
+    } else {
+      prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
+    }
+  }
+
+  // Universal Index starts from 0.
+  loopSeqStack.emplace_back(C_IDX(0), std::move(slicedTids));
+}
+
+void LoopEmitter::exitCurrentLoopSeq(OpBuilder &builder, Location loc) {
+  assert(loopSeqStack.size() == loopStack.size() + 1);
+
+  const auto &slicedTids = loopSeqStack.back().second;
+
+  // Finalizes the slices set up by enterNewLoopSeq: unresolved slices are
+  // popped, resolved ones bump the position stored in slicePosBuffer by two.
+  for (auto [tid, lvl, res] : slicedTids) {
+    if (!res) {
+      // If this is an unresolved-slice-driven loop, pops out the slice.
+      assert(sliceStack[tid].back().slicedOnLvl == lvl);
+      sliceStack[tid].pop_back();
+    } else {
+      // Else this is a resolved-slice, and advance posit similar to TACO.
+      Value c1 = C_IDX(1), c2 = C_IDX(2);
+
+      // pIdx += 2, we finished the current lvl, advance the pointer index of
+      // the previous level by two to skip the [pLo, pHi] for current level.
+      Value sPtrBuf = slicePosBuffer[tid][lvl].back();
+      Value curP = genIndexLoad(builder, loc, sPtrBuf, c1);
+      Value nexP = builder.create<arith::AddIOp>(loc, curP, c2);
+      // TODO: we could probably use an SSA value for it.
+      builder.create<memref::StoreOp>(loc, nexP, sPtrBuf, c1);
+    }
+  }
+  loopSeqStack.pop_back();
 }
 
 Value LoopEmitter::genAffine(OpBuilder &builder, Location loc, AffineExpr a) {
@@ -438,51 +540,30 @@ Value LoopEmitter::genAffine(OpBuilder &builder, Location loc, AffineExpr a) {
   }
   case AffineExprKind::Constant: {
     int64_t c = a.cast<AffineConstantExpr>().getValue();
-    return constantIndex(builder, loc, c);
+    return C_IDX(c);
   }
   default:
     llvm_unreachable("unexpected affine subscript");
   }
 }
 
-Operation *LoopEmitter::enterLoopOverTensorAtLvl(
-    OpBuilder &builder, Location loc, ArrayRef<TensorId> tids,
-    ArrayRef<Level> lvls, MutableArrayRef<Value> reduc, bool isParallel) {
-  // TODO: support multiple return on parallel for?
-  assert(!isParallel || reduc.size() <= 1);
-  bool isSparseInput = false;
-  TensorId tid = tids.front();
-  Level dstLvl = lvls.front();
-  assert(tids.size() == lvls.size());
-  for (auto [t, l] : llvm::zip(tids, lvls)) {
-    // TODO: this check for validity of the (t,l) pairs should be
-    // checked/enforced at the callsites, if possible.
-    assert(isValidLevel(t, l));
-    assert(!coords[t][l]); // We cannot re-enter the same level
-    const auto lvlTp = lvlTypes[t][l];
-    const bool isSparse = isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp);
-    // Must be a recognizable level-type.
-    assert(isSparse || isDenseDLT(lvlTp));
-    // We can at most have one sparse input, otherwise, a while loop is required
-    // to co-iterate multiple sparse tensors.
-    assert(!isSparseInput || !isSparse);
-    if (isSparse) {
-      tid = t;
-      dstLvl = l;
-    }
-    isSparseInput = isSparseInput || isSparse;
-  }
+Operation *LoopEmitter::emitForLoopOverTensorAtLvl(OpBuilder &builder,
+                                                   Location loc, TensorId tid,
+                                                   Level dstLvl,
+                                                   MutableArrayRef<Value> reduc,
+                                                   bool isParallel) {
+  bool isSparseCond = isCompressedDLT(lvlTypes[tid][dstLvl]) ||
+                      isSingletonDLT(lvlTypes[tid][dstLvl]);
 
   const auto reassoc = getCollapseReassociation(tid, dstLvl);
   // TODO: support dynamic slices.
-  // Use the first source-level here to build the loop bound (which is
-  // also the biggest range).
+  // Uses the first dimension here to build the loop bound (which is also the
+  // biggest range).
   const Level srcLvl = reassoc.front();
-  const Value step = constantIndex(builder, loc, 1);
-  /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
-  const Value lo = isSparseInput ? posits[tid][srcLvl]  // current position
-                                 : loopSeqStack.back(); // universal index
-  const Value hi = highs[tid][srcLvl];
+  Value step = C_IDX(1);
+  Value lo = isSparseCond ? posits[tid][srcLvl]        // current offset
+                          : loopSeqStack.back().first; // universal index
+  Value hi = highs[tid][srcLvl];
 
   Operation *loop = nullptr;
   Value iv;
@@ -518,7 +599,7 @@ Operation *LoopEmitter::enterLoopOverTensorAtLvl(
   assert(loop && iv);
 
   Value crd;
-  if (isSparseInput) {
+  if (isSparseCond) {
     assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
     // For COO, the position is the same across consecutive levels.
     /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
@@ -530,7 +611,7 @@ Operation *LoopEmitter::enterLoopOverTensorAtLvl(
     crd = iv;
   }
 
-  if (isSparseSlices[tid] && isSparseInput) {
+  if (isSparseSlices[tid] && isSparseCond) {
     // For sparse level slices, we need to filter out invalid coordinates that
     // are not included in the slice.
     SmallVector<Type> types;
@@ -559,17 +640,120 @@ Operation *LoopEmitter::enterLoopOverTensorAtLvl(
   }
 
   assert(crd);
-  coords[tid][srcLvl] = crd;
-  // NOTE: we can also prepare for next level here in advance
-  // Push the loop into stack
-  loopStack.emplace_back(ArrayRef<TensorId>(tid), ArrayRef<Level>(srcLvl), loop,
-                         builder.getInsertionBlock(), crd, loopTag);
-  // Emit extra locals.
-  emitExtraLocalsForTensorsAtDenseLvls(builder, loc, tids, lvls);
+  coords[tid][dstLvl] = crd;
+  return loop;
+}
 
+Operation *LoopEmitter::emitWhileLoopOverSliceAtSparseLvl(
+    OpBuilder &builder, Location loc, Value pLo, Value pHi, Value offset,
+    Value sliceSize, TensorId tid, Level lvl, MutableArrayRef<Value> reduc) {
+  // Iterates level `lvl` of `tid` between positions pLo and pHi. TODO: also
+  // support iterating normal slices this way to allow early break.
+  Operation *insertPoint = nullptr;
+  Operation *loop =
+      genSliceLvlTraverseLoop(
+          builder, loc, pLo, pHi, offset, sliceSize, tid, lvl, reduc,
+          /*genYield=*/false, // unaware of the yield values from user yet
+          [this, tid, lvl, reduc, offset,
+           &insertPoint](OpBuilder &builder, Location loc, Value iv,
+                         MutableArrayRef<Value> innerReduc) {
+            assert(innerReduc.size() == reduc.size());
+            // Updates users' reduction variables in place (through the view).
+            for (unsigned i = 0, e = reduc.size(); i < e; i++)
+              reduc[i] = innerReduc[i];
+            // Loads the absolute coordinate at position `iv`.
+            Value absC =
+                genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], iv);
+
+            // Subtracts the slice offset to get the relative coordinate.
+            // TODO: how to assert (absC - offset) >= 0 at runtime?
+            insertPoint = builder.create<arith::SubIOp>(loc, absC, offset);
+            posits[tid][lvl] = iv;
+            coords[tid][lvl] = insertPoint->getResult(0);
+          })
+          .first;
+  // Resumes in the loop body; assumes the callback above already ran.
+  builder.setInsertionPointAfter(insertPoint);
+  return loop;
+}
 
+Operation *LoopEmitter::enterLoopOverTensorAtLvl(
+    OpBuilder &builder, Location loc, ArrayRef<TensorId> tids,
+    ArrayRef<Level> lvls, MutableArrayRef<Value> reduc, bool isParallel) {
+  // TODO: support multiple return on parallel for?
+  assert(!isParallel || reduc.size() <= 1);
+  assert(tids.size() == lvls.size());
+  bool isSparseCond = false, isSliceCond = false;
+  TensorId tid = tids.front();
+  Level lvl = lvls.front();
+  // Finds the tensor level that drives the loop: at most one sparse or slice.
+  for (auto [t, l] : llvm::zip(tids, lvls)) {
+    assert(lvlTypes[t].size() > l);         // Must be a valid tid, lvl pair
+    assert(!coords[t][l] ||                 // We cannot re-enter the same level
+           !dependentLvlMap[t][l].empty()); // unless it is a slice-driven loop
+    auto dimType = lvlTypes[t][l];
+    // Must be a recognizable DLT.
+    assert(isDenseDLT(dimType) || isCompressedDLT(dimType) ||
+           isSingletonDLT(dimType));
+
+    // This is a slice-driven loop.
+    if (!dependentLvlMap[t][l].empty()) {
+      assert(!isSliceCond && !isSparseCond);
+      isSliceCond = true;
+      tid = t;
+      lvl = l;
+      continue;
+    }
+
+    bool isSparse = isCompressedDLT(dimType) || isSingletonDLT(dimType);
+    // We can at most have one sparse input, otherwise, a while loop is
+    // required to co-iterate multiple sparse tensors.
+    assert(!isSparseCond || !isSparse);
+    assert(!isSliceCond || !isSparse); // must not override the slice level
+    if (isSparse) {
+      tid = t;
+      lvl = l;
+    }
+    isSparseCond = isSparseCond || isSparse;
+  }
+
+  // Generates loops differently depending on whether we need a slice-driven
+  // loop or a simple level traversal loop.
+  Operation *l = nullptr;
+  if (isSliceCond) {
+    bool fullyReduced = depFullyReduced(tid, lvl);
+    if (!fullyReduced) {
+      l = emitSliceDrivenLoopOverTensorAtLvl(builder, loc, tid, lvl, reduc);
+    } else {
+      // If the slice is fully reduced, we can now use TACO-based algorithm to
+      // iterate it.
+      l = emitWhileLoopOverSliceAtSparseLvl(
+          builder, loc, posits[tid][lvl], highs[tid][lvl],
+          getFinalSliceOnLvl(tid, lvl).offset, sliceSizes[tid][lvl].back(), tid,
+          lvl, reduc);
+    }
+    levelReducedDep[tid][lvl]++;
+    // We can also prepare for next dim here in advance
+    // Pushes the loop into stack.
+    loopStack.emplace_back(
+        ArrayRef<TensorId>(), ArrayRef<Level>(), ArrayRef<TensorId>(tid),
+        ArrayRef<Level>(lvl), ArrayRef<bool>(fullyReduced), l,
+        builder.getInsertionBlock(), coords[tid][lvl], loopTag);
+  } else {
+    l = emitForLoopOverTensorAtLvl(builder, loc, tid, lvl, reduc, isParallel);
+    // We can also prepare for next dim here in advance
+    // Pushes the loop into stack.
+    loopStack.emplace_back(ArrayRef<TensorId>(tid), ArrayRef<Level>(lvl),
+                           ArrayRef<TensorId>(), ArrayRef<Level>(),
+                           ArrayRef<bool>(), l, builder.getInsertionBlock(),
+                           coords[tid][lvl], loopTag);
+  }
+
+  // Emit extra locals.
+  emitExtraLocalsForTensorsAtDenseLvls(builder, loc, tids, lvls);
+  return l;
+}
+
 Operation *LoopEmitter::enterFilterLoopOverTensorAtLvl(
     OpBuilder &builder, Location loc, TensorId tid, Level lvl,
     AffineExpr affine, MutableArrayRef<Value> reduc) {
@@ -582,7 +766,7 @@ Operation *LoopEmitter::enterFilterLoopOverTensorAtLvl(
   // break when exceeding (for ordered levels).
   // TODO: There are many other potiential opportunities that we might apply in
   // the future. E.g., we could use binary search to locate positions.
-  const Value step = constantIndex(builder, loc, 1);
+  const Value step = C_IDX(1);
   const Value pLo = posits[tid][lvl];
   const Value pHi = highs[tid][lvl];
   scf::ForOp forOp = builder.create<scf::ForOp>(loc, pLo, pHi, step, reduc);
@@ -604,8 +788,7 @@ Operation *LoopEmitter::enterFilterLoopOverTensorAtLvl(
   // Generate an if-condition to filter out coordinates that are not
   // equal to the result of the affine expression.
   Value expected = genAffine(builder, loc, affine);
-  auto pred = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, crd,
-                                            expected);
+  auto pred = CMPI(eq, crd, expected);
   SmallVector<Type> types;
   for (Value red : reduc) {
     types.push_back(red.getType());
@@ -631,8 +814,10 @@ Operation *LoopEmitter::enterFilterLoopOverTensorAtLvl(
 
   // NOTE: we can also prepare for next lvl here in advance
   // Push the loop into stack
-  loopStack.emplace_back(ArrayRef<TensorId>(tid), ArrayRef<Level>(lvl), forOp,
-                         builder.getInsertionBlock(), crd, nullptr);
+  loopStack.emplace_back(ArrayRef<TensorId>(tid), ArrayRef<Level>(lvl),
+                         ArrayRef<TensorId>(), ArrayRef<Level>(),
+                         ArrayRef<bool>(), forOp, builder.getInsertionBlock(),
+                         coords[tid][lvl], nullptr);
   return forOp;
 }
 
@@ -648,20 +833,24 @@ void LoopEmitter::genDenseAffineAddress(OpBuilder &builder, Location loc,
 Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
     OpBuilder &builder, Location loc, ArrayRef<TensorId> tids,
     ArrayRef<Level> lvls, bool needsUniv, MutableArrayRef<Value> reduc) {
+  // NOTE: the slice driven tensor-related reduction variable must
+  // appear before normal tensors.
   assert(tids.size() == lvls.size());
   SmallVector<Type> types;
   SmallVector<Value> operands;
   // Construct the while-loop with a parameter for each coordinate.
   const Type indexType = builder.getIndexType();
   for (auto [tid, lvl] : llvm::zip(tids, lvls)) {
+    // TODO: support coiteration with slice driven tensors.
     const auto lvlTp = lvlTypes[tid][lvl];
+    assert(dependentLvlMap[tid][lvl].empty() && "TODO: not yet implemented");
     if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp)) {
       const auto reassoc = getCollapseReassociation(tid, lvl);
       for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) {
         if (!isUniqueDLT(lvlTypes[tid][reassoc[i]])) {
           // This is the segment high for each non-unique levels.
           types.push_back(indexType);
-          operands.push_back(constantIndex(builder, loc, 0));
+          operands.push_back(C_IDX(0));
         }
       }
       const auto pos = posits[tid][reassoc.front()];
@@ -678,7 +867,7 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
   if (needsUniv) {
     types.push_back(indexType);
     // Update universal index.
-    operands.push_back(loopSeqStack.back());
+    operands.push_back(loopSeqStack.back().first);
   }
   assert(types.size() == operands.size());
   scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);
@@ -707,8 +896,7 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
       Value op1 = before->getArgument(o);
       // We used the first level bound as the bound the collapsed set of levels.
       Value op2 = highs[tid][reassoc.front()];
-      Value opc = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
-                                                op1, op2);
+      Value opc = CMPI(ult, op1, op2);
       cond = cond ? builder.create<arith::AndIOp>(loc, cond, opc) : opc;
       // Update positions
       Value pos = after->getArgument(o++);
@@ -752,8 +940,7 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
     //
     // This "idx" is the index into `llvm::zip(tids, lvls)`
     for (auto [pred, idx] : slicesPreds) {
-      Value nextPos = builder.create<arith::AddIOp>(
-          loc, yields[idx], constantIndex(builder, loc, 1));
+      Value nextPos = builder.create<arith::AddIOp>(loc, yields[idx], C_IDX(1));
       yields[idx] =
           builder.create<arith::SelectOp>(loc, pred, yields[idx], nextPos);
     }
@@ -783,8 +970,7 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
       if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp)) {
         const auto crd = coords[tid][lvl];
         if (min) {
-          Value cmp = builder.create<arith::CmpIOp>(
-              loc, arith::CmpIPredicate::ult, crd, min);
+          Value cmp = CMPI(ult, crd, min);
           min = builder.create<arith::SelectOp>(loc, cmp, crd, min);
         } else {
           min = crd;
@@ -798,8 +984,9 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
   }
 
   // Sets up the loop stack.
-  loopStack.emplace_back(tids, lvls, whileOp, builder.getInsertionBlock(), min,
-                         loopTag);
+  loopStack.emplace_back(tids, lvls, ArrayRef<TensorId>(), ArrayRef<Level>(),
+                         ArrayRef<bool>(), whileOp, builder.getInsertionBlock(),
+                         min, loopTag);
   assert(loopStack.size() == loopSeqStack.size());
 
   for (auto [tid, dstLvl] : llvm::zip(tids, lvls)) {
@@ -869,8 +1056,8 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
   if (isDenseDLT(lvlTp))
     return;
 
-  const Value c0 = constantIndex(builder, loc, 0);
-  const Value c1 = constantIndex(builder, loc, 1);
+  const Value c0 = C_IDX(0);
+  const Value c1 = C_IDX(1);
   for (const Level srcLvl : getCollapseReassociation(tid, dstLvl)) {
     // Either the first level, or the previous level has been set.
     /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
@@ -1022,6 +1209,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
   auto whileOp = llvm::cast<scf::WhileOp>(loopInfo.loop);
   builder.setInsertionPointToEnd(loopInfo.userCodeBlock);
   Value iv = loopInfo.iv;
+
   // Finalize the induction. Note that the induction could be performed
   // in the individual if-branches to avoid re-evaluating the conditions.
   // However, that would result in a rather elaborate forest of yield
@@ -1029,7 +1217,35 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
   // after the if-statements more closely resembles code generated by TACO.
   unsigned o = 0;
   SmallVector<Value> operands;
-  Value one = constantIndex(builder, loc, 1);
+  unsigned delta = 0;
+  for (auto [tid, lvl, resolved] : llvm::zip(
+           loopInfo.slicedTids, loopInfo.slicedLvls, loopInfo.sliceReduced)) {
+    levelReducedDep[tid][lvl]--;
+    if (!resolved) {
+      genSliceNextInduction(builder, loc, whileOp, tid, lvl, operands, o);
+      continue;
+    }
+    // TODO: We need to distinguish co-iterating loops from slice-driven loops
+    // and from fully reduced while loops that iterate over a single slice.
+    // FIXME: since co-iteration is not implemented yet, this must be an
+    // iteration over a single fully resolved slice.
+    assert(loopInfo.slicedTids.size() == 1 && loopInfo.tids.empty());
+    // The if-guard that filters out out-of-range coordinates.
+    assert(llvm::isa<scf::IfOp>(builder.getInsertionBlock()->getParentOp()));
+    posits[tid][lvl] = whileOp->getResult(o++);
+    // FIXME: the continue flag is not used here since co-iteration on slices
+    // is not supported yet, but it needs to be treated similarly to the
+    // universal index.
+    o++; // skip continue flag.
+    // We did not push these two results of the whileOp, so the operands
+    // vector is smaller than the actual number of values returned by the
+    // whileOp. This is because the yield is generated inside the IfOp
+    // within the whileOp, so that only in-bound coordinates of the slice
+    // are iterated over.
+    delta += 2;
+  };
+
+  Value one = C_IDX(1);
   for (auto [tid, dstLvl] : llvm::zip(loopInfo.tids, loopInfo.lvls)) {
     const auto lvlTp = lvlTypes[tid][dstLvl];
     if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp)) {
@@ -1044,8 +1260,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
       }
       const Value crd = coords[tid][dstLvl];
       const Value pos = posits[tid][dstLvl];
-      Value cmp =
-          builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, crd, iv);
+      Value cmp = CMPI(eq, crd, iv);
       // If the loop contains a coiteration with non-unique level, we fast
       // forward all the duplicated coords by setting the position to the
       // segment high.
@@ -1080,15 +1295,15 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
   }
 
   // An (optional) universal index.
-  if (operands.size() < whileOp.getNumResults()) {
-    assert(operands.size() + 1 == whileOp.getNumResults());
+  if (operands.size() + delta < whileOp.getNumResults()) {
+    assert(operands.size() + delta + 1 == whileOp.getNumResults());
     // The last one is the universial index.
     operands.push_back(builder.create<arith::AddIOp>(loc, iv, one));
     // update the loop starting point of current loop sequence
-    loopSeqStack.back() = whileOp->getResult(o++);
+    loopSeqStack.back().first = whileOp->getResult(o++);
   }
 
-  assert(o == operands.size());
+  assert(o == operands.size() + delta);
   builder.create<scf::YieldOp>(loc, operands);
   builder.setInsertionPointAfter(whileOp);
 }
@@ -1109,3 +1324,629 @@ void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc,
   assert(loopStack.size() == loopSeqStack.size());
   loopStack.pop_back();
 }
+
+//===----------------------------------------------------------------------===//
+// Slice-driven loop related methods.
+//===----------------------------------------------------------------------===//
+// Returns the number of dependencies of [tid, lvl] not yet reduced (total
+// minus levelReducedDep), or 0 when the level has no dependencies at all.
+unsigned LoopEmitter::remDepOnLevel(TensorId tid, Level lvl) const {
+  const unsigned numDeps = dependentLvlMap[tid][lvl].size();
+  if (numDeps == 0)
+    return 0;
+  assert(numDeps >= 2); // A dependent level involves at least two loops.
+  return numDeps - levelReducedDep[tid][lvl];
+}
+// Returns the most recently pushed slice on level `lvl` of tensor `tid`.
+const LoopEmitter::SliceInfo &LoopEmitter::getFinalSliceOnLvl(TensorId tid,
+                                                              Level lvl) {
+  // Scan the slice stack top-down so that the newest matching slice wins.
+  for (const SliceInfo &info : llvm::reverse(sliceStack[tid])) {
+    if (info.slicedOnLvl != lvl)
+      continue;
+    // The final slice on a level fully reduces its expression, i.e. its
+    // depth equals the number of dependencies minus one.
+    assert(info.depth == dependentLvlMap[tid][lvl].size() - 1);
+    return info;
+  }
+
+  llvm_unreachable("Failed to find sliceInfo");
+}
+
+// Generates a while loop over the positions [loopLo, loopHi) of a sliced
+// sparse level, stopping at the first coordinate outside the slice:
+//   while (loopLo < loopHi && cont) {
+//     cont = coords[loopLo] < offset + size;
+//     if (cont) bodyBuilder(loopLo, reduc) else yield reduc unchanged;
+//     loopLo++;
+//   }
+// Returns the while op and its results minus the two internal loop-carried
+// values (position, continue flag). With `genYield`, the then-branch yield
+// is emitted here; otherwise it is presumably left to `bodyBuilder`.
+std::pair<Operation *, ValueRange> LoopEmitter::genSliceLvlTraverseLoop(
+    OpBuilder &builder, Location loc, Value loopLo, Value loopHi, Value offset,
+    Value size, TensorId tid, Level lvl, ValueRange userReduc, bool genYield,
+    llvm::function_ref<void(OpBuilder &, Location, Value,
+                            MutableArrayRef<Value>)>
+        bodyBuilder) {
+  Value c1 = C_IDX(1);
+  Value sliceHi = builder.create<arith::AddIOp>(loc, offset, size);
+  // Loop-carried values: [position, continue flag, user reductions...].
+  SmallVector<Value> reduc = {
+      loopLo,                         // loop lower bounds
+      constantI1(builder, loc, true), // continue
+  };
+  // Append user required reduction value.
+  reduc.append(userReduc.begin(), userReduc.end());
+  scf::WhileOp whileOp = builder.create<scf::WhileOp>(
+      loc, ValueRange(reduc).getTypes(), reduc,
+      /*beforeBuilder=*/
+      [loopHi](OpBuilder &builder, Location loc, ValueRange args) {
+        Value lo = args[0];
+        Value cont = args[1];
+        Value inBound = CMPI(ult, lo, loopHi);
+        Value cond = builder.create<arith::AndIOp>(loc, cont, inBound);
+        // Keep iterating while no break happened and lo is still in bound.
+        builder.create<scf::ConditionOp>(loc, cond, args);
+      },
+      /*afterBuilder=*/
+      [this, c1, tid, lvl, sliceHi, genYield,
+       bodyBuilder](OpBuilder &builder, Location loc, ValueRange args) {
+        Value iv = args[0];
+        Value coord =
+            genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], iv);
+        Value cont = CMPI(ult, coord, sliceHi);
+        TypeRange types = args.drop_front(2).getTypes();
+        // The if yields the user reductions (possibly updated in then-branch).
+        auto ifOp = builder.create<scf::IfOp>(loc, types, cont, true);
+        {
+          // Skip the 2 loop-carried values maintained here (pos, continue).
+          SmallVector<Value> ifRet = args.drop_front(2);
+          assert(ifRet.size() == args.size() - 2);
+          // Restores the builder's insertion point when leaving this scope.
+          OpBuilder::InsertionGuard guard(builder);
+          // If coord >= sliceHi.
+          builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+          // The coordinate is outside the slice: yield reductions unchanged.
+          builder.create<scf::YieldOp>(loc, ifRet);
+          // (`cont` is false here, so the while loop exits on the next check.)
+          // If coord < sliceHi.
+          builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+          // Delegates to the user's callback, which may update ifRet in place.
+          bodyBuilder(builder, loc, iv, ifRet);
+          if (genYield) {
+            builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
+            builder.create<scf::YieldOp>(loc, ifRet);
+          }
+        }
+        // Marks this special ifOp so that sparsification does not finalize it.
+        ifOp->setAttr(getLoopEmitterLoopAttrName(),
+                      StringAttr::get(builder.getContext(), "slice"));
+        // Insertion point restored to after ifOp.
+        SmallVector<Value> yields;
+        // Increase induction variable.
+        yields.push_back(builder.create<arith::AddIOp>(loc, iv, c1));
+        // Terminates the while loop according to the continue flag.
+        yields.push_back(cont);
+        yields.append(ifOp.getResults().begin(), ifOp.getResults().end());
+        builder.create<scf::YieldOp>(loc, yields);
+      });
+  // Continue emitting code after the while loop.
+  builder.setInsertionPointAfter(whileOp);
+  return std::make_pair(whileOp, whileOp.getResults().drop_front(2));
+}
+
+// Generates a loop nest that traverses all the unresolved levels in between.
+// TODO: it can only handle all compressed tensors.
+//
+// The slice position buffer sPos = slicePosBuffer[tid][lvl][depth] is laid
+// out as [memSize, idx, pLo_0, pHi_0, pLo_1, pHi_1, ...]; each [pLo, pHi)
+// segment is traversed with the loop from genSliceLvlTraverseLoop:
+//
+// for (int i = 2; i < sPos[0] /*memSize*/; i += 2) {
+//   loopLo = sPos[i];
+//   loopHi = sPos[i + 1];
+//   while (loopLo < loopHi) {
+//     if (crd[loopLo] < offset + sliceSize) bodyBuilder(); else break;
+//     loopLo++;
+//   }
+// }
+//
+// Returns the results of the for loop, i.e. the final user reductions.
+ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
+    OpBuilder &builder, Location loc, Value offset, TensorId tid, Level lvl,
+    size_t depth, ValueRange userReduc,
+    llvm::function_ref<void(OpBuilder &, Location, Value,
+                            MutableArrayRef<Value>)>
+        bodyBuilder) {
+  // userReduc is threaded through the loops as iteration arguments.
+  Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2);
+  // Iterate over the [pLo, pHi) pairs stored after the 2-entry header.
+  // TODO: it only works on all compressed tensor.
+  Value sPtrBuf = slicePosBuffer[tid][lvl][depth];
+  Value pSt = c2;                                      // pointer starting index
+  Value mSz = genIndexLoad(builder, loc, sPtrBuf, c0); // memSize
+  // One for-loop iteration (step 2) per segment, each a slice traversal.
+  auto forOp =
+      scf::buildLoopNest(
+          builder, loc, pSt, mSz, c2, userReduc,
+          [this, c1, tid, lvl, offset, sPtrBuf,
+           bodyBuilder](OpBuilder &builder, Location loc, ValueRange ivs,
+                        ValueRange iterArgs) -> scf::ValueVector {
+            // generate traversal for each level.
+            Value loopLo = genIndexLoad(builder, loc, sPtrBuf, ivs.front());
+            Value loopHi = genIndexLoad(
+                builder, loc, sPtrBuf,
+                builder.create<arith::AddIOp>(loc, ivs.front(), c1));
+            return genSliceLvlTraverseLoop(builder, loc, loopLo, loopHi, offset,
+                                           sliceSizes[tid][lvl].back(), tid,
+                                           lvl, iterArgs, true, bodyBuilder)
+                .second;
+          })
+          .loops.front();
+
+  // Continue emitting code after the generated for loop.
+  builder.setInsertionPointAfter(forOp);
+  return forOp.getResults();
+}
+// Begins a depth-1 slice on level 0 (the only level supported so far).
+void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
+                                        TensorId tid, Level lvl) {
+  assert(lvl == 0 && "TODO: handle non-first level");
+  Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2), c3 = C_IDX(3),
+        c4 = C_IDX(4);
+  Value size = sliceSizes[tid][0][0];
+  Value sPtrBuf = slicePosBuffer[tid][0][0];
+  Value pHi = genIndexLoad(builder, loc, positionsBuffers[tid][0], c1);
+  // Fills out slicePosBuffer[tid][0][0] with [/*memSize =*/4, 0, 0, pHi].
+  builder.create<memref::StoreOp>(loc, c4, sPtrBuf, c0);  // memSize = 4
+  builder.create<memref::StoreOp>(loc, c0, sPtrBuf, c1);  // index = 0
+  builder.create<memref::StoreOp>(loc, c0, sPtrBuf, c2);  // pLo = 0;
+  builder.create<memref::StoreOp>(loc, pHi, sPtrBuf, c3); // loaded pHi.
+  // I.e. the whole level, positions [0, pHi), is the single cached segment.
+  // The level is non-empty iff 0 < pHi.
+  Value isNonEmpty = CMPI(ult, c0, pHi);
+  // On an ordered level, the minimal coordinate is the first one.
+  // FIXME: Technically we should load the coord only when the slice is
+  // non-empty. We assume that even empty sparse tensors allocate non-empty
+  // position/coordinate buffers for every level, so this load is not OOB
+  // and we avoid generating an ifOp here.
+  Value minCrd = genIndexLoad(builder, loc, coordinatesBuffers[tid][0], c0);
+  // The initial slice offset is derived from minCrd (see offsetFromMinCoord).
+  // FIXME: We need the relative offset related to the base slice.
+  Value absOffset = offsetFromMinCoord(builder, loc, minCrd, size, isNonEmpty);
+  sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, lvl, /*depth=*/1);
+}
+
+// Fills in the slicePosBuffer of level `lvl` before a slice-driven loop
+// begins, when the previous level `lvl - 1` has not been fully resolved.
+// TODO: it can only handle all compressed tensors.
+//
+// // Loop generated by `genUnResolvedSliceTreeTraverse` over the parent slice.
+// memSize = 2; minCrd = lvlSize; isNonEmpty = false;
+// for (int i = 2; i < parentPos[0]; i += 2) {
+//   for (loopLo = parentPos[i]; loopLo < parentPos[i + 1]; loopLo++) {
+//     if (crd_{lvl-1}[loopLo] >= sliceHi) break;
+//     // bodyBuilder
+//     pLo = pos_lvl[loopLo]; pHi = pos_lvl[loopLo + 1];
+//     slicePos[memSize] = pLo; slicePos[memSize + 1] = pHi; memSize += 2;
+//     isNonEmpty |= pLo < pHi;
+//     if (pLo < pHi) minCrd = min(minCrd, crd_lvl[pLo]);
+//   }
+// }
+// slicePos[0] = memSize; slicePos[1] = 0;
+//
+// A new SliceInfo (minCrd, offset derived from minCrd, isNonEmpty) is then
+// pushed onto sliceStack[tid].
+void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
+                                          TensorId tid, Level lvl) {
+  assert(isCompressedDLT(lvlTypes[tid][lvl]));
+  Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2);
+  const SliceInfo &sliceInfo = sliceStack[tid].back(); // parent slice
+  unsigned prevLvl = *sliceInfo.slicedOnLvl;
+  assert(lvl >= prevLvl);
+  // Either lvl = prevSlicedLvl, i.e., t[d0 + d1 + d2,...] (more than one
+  // variable needs to be reduced on the same level).
+  // Or lvl > prevSliceLvl + 1, i.e., t[..., d2, d3 + d4] (having a
+  // simple dim expression in between).
+  assert(lvl == prevLvl + 1 && "TODO: not yet implemented");
+  // Check slice stack integrity.
+  assert(slicePosBuffer[tid][prevLvl].size() == sliceInfo.depth);
+  Value sPtrBuf = slicePosBuffer[tid][lvl].back(); // buffer filled below
+  SmallVector<Value, 3> reduc = {
+      constantI1(builder, loc, false), // isNonEmpty
+      lvlSizes[tid][lvl],              // minCoord
+      c2,                              // memSize
+  };
+  // Traverse the parent slice; the callback records child segments of lvl.
+  ValueRange result = genUnResolvedSliceTreeTraverse(
+      builder, loc, sliceInfo.offset, tid, prevLvl, sliceInfo.depth - 1, reduc,
+      [this, c1, c2, tid, lvl, sPtrBuf](OpBuilder &builder, Location loc,
+                                        Value iv,
+                                        MutableArrayRef<Value> reduc) {
+        Value &nonEmpty = reduc[0];
+        Value &minCrd = reduc[1];
+        Value &curMemSz = reduc[2];
+        // iv is a parent position; [sPLo, sPHi) is its segment on level lvl.
+        Value pHi = builder.create<arith::AddIOp>(loc, iv, c1); // iv + 1
+        Value sPLo = genIndexLoad(builder, loc, positionsBuffers[tid][lvl], iv);
+        Value sPHi =
+            genIndexLoad(builder, loc, positionsBuffers[tid][lvl], pHi);
+        // (pHi above is the index iv + 1 into the positions buffer.)
+        // isNonEmpty = isNonEmpty || lvlNonEmpty, i.e., as long as there is one
+        // non-empty lvl, the slice is non-empty.
+        Value lvlNonEmpty = CMPI(ult, sPLo, sPHi);
+        nonEmpty = builder.create<arith::OrIOp>(loc, lvlNonEmpty, nonEmpty);
+        // Only a non-empty segment has a first coordinate that can be loaded.
+        // Update the minimum coordinate.
+        auto ifNonEmpty = builder.create<scf::IfOp>(loc, builder.getIndexType(),
+                                                    lvlNonEmpty, true);
+        {
+          // Generates code as follows:
+          //
+          // if (lvlNonEmpty) {
+          //   minCrd = min(minCrd, crd[sPLo]);
+          // }
+          OpBuilder::InsertionGuard guard(builder);
+          builder.setInsertionPointToStart(ifNonEmpty.thenBlock());
+          Value curC =
+              genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], sPLo);
+          Value isSmaller = CMPI(ult, curC, minCrd);
+          Value newMin =
+              builder.create<arith::SelectOp>(loc, isSmaller, curC, minCrd);
+          builder.create<scf::YieldOp>(loc, newMin);
+          builder.setInsertionPointToStart(ifNonEmpty.elseBlock());
+          builder.create<scf::YieldOp>(loc, minCrd);
+        }
+        minCrd = ifNonEmpty.getResult(0);
+        builder.create<memref::StoreOp>(loc, sPLo, sPtrBuf, curMemSz);
+        Value nxtMemSize = builder.create<arith::AddIOp>(loc, curMemSz, c1);
+        builder.create<memref::StoreOp>(loc, sPHi, sPtrBuf, nxtMemSize);
+        // Advances the memory size: curMemSz += 2.
+        curMemSz = builder.create<arith::AddIOp>(loc, curMemSz, c2);
+      });
+  // The traversal yields [isNonEmpty, minCoord, memSize], in `reduc` order.
+  unsigned depth = levelReducedDep[tid][lvl];
+  Value size = sliceSizes[tid][lvl][depth];
+  Value isNonEmpty = result[0];
+  Value minCrd = result[1];
+  // Writes the two metadata entries [memSize, idx = 0] at the head.
+  // TODO: Can use an SSA value for these two metadata
+  builder.create<memref::StoreOp>(loc, result[2], sPtrBuf, c0);
+  builder.create<memref::StoreOp>(loc, c0, sPtrBuf, c1);
+  // FIXME: we need the relative offset related to the base slice.
+  Value absOffset = offsetFromMinCoord(builder, loc, minCrd, size, isNonEmpty);
+  sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, lvl, depth + 1);
+}
+// Begins iterating [tid, lvl]: returns true when its dependencies are fully
+// reduced (position/high come from the slice cache), else starts a new slice.
+bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
+                                Level lvl) {
+  if (depFullyReduced(tid, lvl)) {
+    // If constraints on the tensor are fully resolved, we do not need to
+    // generate a slice begin any more; instead we fall back to the TACO-based
+    // algorithm to (co)iterate over the slice.
+    Value c1 = C_IDX(1), c2 = C_IDX(2);
+    Value pLoPtr =
+        genIndexLoad(builder, loc, slicePosBuffer[tid][lvl].back(), c1);
+    pLoPtr = builder.create<arith::AddIOp>(loc, pLoPtr, c2); // skip header
+    Value pHiPtr = builder.create<arith::AddIOp>(loc, pLoPtr, c1);
+    posits[tid][lvl] =
+        genIndexLoad(builder, loc, slicePosBuffer[tid][lvl].back(), pLoPtr);
+    highs[tid][lvl] =
+        genIndexLoad(builder, loc, slicePosBuffer[tid][lvl].back(), pHiPtr);
+    return true;
+  }
+
+  // Only when the level is sorted, the next-non-empty slice can be computed
+  // efficiently.
+  const DimLevelType lvlType = lvlTypes[tid][lvl];
+  assert(isOrderedDLT(lvlType));
+  if (isSingletonDLT(lvlType)) {
+    llvm_unreachable("TODO: dense level should be easy to support, while "
+                     "singleton level requires more effort");
+  }
+
+  assert(!dependentLvlMap[tid][lvl].empty());
+  assert(!sliceStack[tid].empty());
+
+  const SliceInfo &sliceInfo = sliceStack[tid].back();
+  auto baseEnc = getSparseTensorEncoding(tensors[tid].getType());
+  if (baseEnc.isSlice())
+    llvm_unreachable("TODO: not yet implemented");
+
+  // Generate caches required to fast compute next-non-empty slices with
+  // increasing offset for slice-base loop.
+  // We do not need cache for dense levels.
+  if (slicePosBuffer[tid][lvl][0] == nullptr && !isDenseDLT(lvlType)) {
+    OpBuilder::InsertionGuard guard(builder);
+    // The buffer can be reused, and the size is loop invariant: it only depends
+    // on the iteration graph's toposort.
+    builder.setInsertionPointAfter(localInsertPos);
+    Value bufSize = C_IDX(1);
+    Value c2 = C_IDX(2); // Created at the hoisted allocation insertion point.
+    // Accumulates the size required to cache the pLo for the slice.
+    // E.g., if we want to cache the pIdx for slice<d0xd1xf64> on the second
+    // level, we need at most a memref<d0xindex>.
+    // NOTE: this is apparently an over-approximation when the previous
+    // level is compressed, and we can compute a precise memory size
+    // inside the loops. But that would also require us to allocate/free
+    // memory in loops.
+    // TODO: Maybe using allocaScopeOp inside the loop to resolve the issue?
+    for (Level curLevel = lvl;
+         curLevel >= 1 && !lvlFullyResolved(tid, curLevel - 1); curLevel--) {
+      auto depth = remDepOnLevel(tid, curLevel - 1);
+      assert(depth >= 1 && sliceSizes[tid][lvl].size() >= depth);
+      Value sz = *(sliceSizes[tid][lvl].rbegin() + depth - 1);
+      bufSize = builder.create<arith::MulIOp>(loc, bufSize, sz);
+    }
+    // For a pair of [pLo, pHi]. Note that we can not compress pHi because slice
+    // creates segments in the index buffer so that the pHi for the current
+    // level is no longer the pLo for the next level.
+    bufSize = builder.create<arith::MulIOp>(loc, bufSize, c2);
+    // Additional two metadata {memSize, idx} at head.
+    bufSize = builder.create<arith::AddIOp>(loc, bufSize, c2);
+    llvm::for_each(
+        slicePosBuffer[tid][lvl], [bufSize, loc, &builder](Value &cache) {
+          cache = genAlloca(builder, loc, bufSize, builder.getIndexType());
+        });
+  }
+
+  if (sliceInfo.isInitialTensor() ||
+      (lvl >= 1 && lvlFullyResolved(tid, lvl - 1))) {
+    // First level or previous level has been fully resolved.
+    genResolvedSliceBegin(builder, loc, tid, lvl);
+  } else {
+    // The previous level has not been fully resolved.
+    genUnResolvedSliceBegin(builder, loc, tid, lvl);
+  }
+  return false;
+}
+// Generates code that advances the slice-driven loop on [tid, lvl] to the
+// next non-empty slice. Pushes the next {isNonEmpty, minCrd, absOffset} onto
+// `operands` and rebinds the slice info to the results of the while loop
+// `op`, reading them starting at (and advancing) `retIdx`.
+void LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
+                                        const Operation *op, TensorId tid,
+                                        Level lvl,
+                                        SmallVectorImpl<Value> &operands,
+                                        unsigned &retIdx) {
+  if (!isCompressedDLT(lvlTypes[tid][lvl]))
+    llvm_unreachable("TODO");
+
+  // Otherwise, generate code to compute the next non-empty slice.
+  Value c0 = C_IDX(0);
+  Value c1 = C_IDX(1);
+  Value c2 = C_IDX(2);
+
+  auto whileOp = llvm::cast<scf::WhileOp>(op);
+  SliceInfo &info = sliceStack[tid].back();
+  assert(info.slicedOnLvl == lvl);
+
+  //
+  // We forward to the next non empty slice by
+  // if (minCrd > offset) {
+  //   offset += 1
+  // } else {
+  //    minCrd = nextMinInSlice();
+  //    offset = minCrd - size + 1;
+  // }
+  //
+  // if (offset + size > parents.size)
+  //   isNonEmpty = false;
+  //
+  Value absOffset = info.offset;
+  // Resets the slice pointers, as the resolved slices are invalidated once
+  // we move forward to the next slice.
+  for (unsigned i = 0; i <= lvl; i++)
+    builder.create<memref::StoreOp>(loc, c0, slicePosBuffer[tid][i].back(), c1);
+
+  SmallVector<Value, 3> reduc = {info.minCrd, info.isNonEmpty, absOffset};
+  Value sPtrBuf = slicePosBuffer[tid][lvl][info.depth - 1];
+  Value fastPathP = CMPI(ugt, info.minCrd, absOffset);
+  auto ifOp = builder.create<scf::IfOp>(loc, ValueRange(reduc).getTypes(),
+                                        fastPathP, true);
+  {
+    OpBuilder::InsertionGuard guard(builder);
+    // Take the fast path
+    // if (minCrd > offset) {
+    //   return offset += 1
+    // }
+    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+    reduc[2] = builder.create<arith::AddIOp>(loc, absOffset, c1);
+    // Yield offset + 1.
+    builder.create<scf::YieldOp>(loc, reduc);
+
+    // else /*minCrd == offset*/ {
+    //    for (i = 2; i < slicePos[0] /*memSize*/; i += 2) {
+    //       if (slicePos[i] < slicePos[i+1] && crd[slicePos[i]] == minCrd)
+    //          slicePos[i]++;
+    //       if (slicePos[i] < slicePos[i+1])
+    //          newMin = min(newMin /*init: lvlSize*/, crd[slicePos[i]]);
+    //    }
+    //    minCrd = newMin; offset = max(minCrd + 1 - size, 0);
+    // }
+    builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+    reduc[2] = absOffset; // restore value.
+    Value pSt = c2;       // pointer starting index
+    Value mSz = genIndexLoad(builder, loc, sPtrBuf, c0); // memSize
+    reduc[0] = lvlSizes[tid][lvl];                       // next min coord
+    reduc[1] = constantI1(builder, loc, false);          // isNonEmpty
+    auto loopArgs = static_cast<ValueRange>(reduc).drop_back();
+    auto forOp = scf::buildLoopNest(
+        builder, loc, pSt, mSz, c2, loopArgs,
+        [this, tid, lvl, c1, sPtrBuf,
+         &info](OpBuilder &builder, Location loc, ValueRange ivs,
+                ValueRange iterArgs) -> scf::ValueVector {
+          Value curMinCrd = iterArgs[0];
+          Value isNonEmpty = iterArgs[1];
+
+          Type idxTp = builder.getIndexType();
+          Value pLo = genIndexLoad(builder, loc, sPtrBuf, ivs.front());
+          Value pHi =
+              genIndexLoad(builder, loc, sPtrBuf,
+                           builder.create<arith::AddIOp>(loc, ivs.front(), c1));
+          //
+          // if (pLo < pHi) // Only loads when inbound.
+          //   coord = load[pLo]
+          //   if coord == minCrd
+          //     pLo += 1
+          //
+          // if (pLo < pHi)
+          //   curMinCrd = min(curMinCrd, load[pLo])
+          //
+          Value pred = CMPI(ult, pLo, pHi);
+          auto advPLo = builder.create<scf::IfOp>(loc, idxTp, pred, true);
+          /* if pLo < pHi */ {
+            builder.setInsertionPointToStart(&advPLo.getThenRegion().front());
+            // coord = load[pLo]
+            Value coord =
+                genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], pLo);
+            Value pred = CMPI(eq, coord, info.minCrd);
+            auto ifEqual = builder.create<scf::IfOp>(loc, idxTp, pred, true);
+            /* if coord == minCrd */ {
+              builder.setInsertionPointToStart(
+                  &ifEqual.getThenRegion().front());
+              // pLo += 1.
+              Value newPlo = builder.create<arith::AddIOp>(loc, pLo, c1);
+              builder.create<memref::StoreOp>(loc, newPlo, sPtrBuf,
+                                              ivs.front());
+              builder.create<scf::YieldOp>(loc, newPlo);
+            }
+            /* else coord != minCrd */ {
+              builder.setInsertionPointToStart(
+                  &ifEqual.getElseRegion().front());
+              builder.create<scf::YieldOp>(loc, pLo);
+            }
+            builder.setInsertionPointAfter(ifEqual);
+            builder.create<scf::YieldOp>(loc, ifEqual.getResults());
+          }
+          /* else pLo >= pHi */ {
+            builder.setInsertionPointToStart(&advPLo.getElseRegion().front());
+            builder.create<scf::YieldOp>(loc, pLo);
+          }
+
+          builder.setInsertionPointAfter(advPLo);
+          pLo = advPLo.getResult(0);
+          Value lvlNonEmpty = CMPI(ult, pLo, pHi);
+          // Update minCrds
+          auto newMin =
+              builder.create<scf::IfOp>(loc, idxTp, lvlNonEmpty, true);
+          builder.setInsertionPointToStart(&newMin.getThenRegion().front());
+          builder.create<scf::YieldOp>(
+              loc,
+              genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], pLo));
+
+          builder.setInsertionPointToStart(&newMin.getElseRegion().front());
+          builder.create<scf::YieldOp>(loc, curMinCrd);
+          builder.setInsertionPointAfter(newMin);
+
+          // isNonEmpty = isNonEmpty || lvlNonEmpty
+          isNonEmpty =
+              builder.create<arith::OrIOp>(loc, lvlNonEmpty, isNonEmpty);
+          curMinCrd = builder.create<arith::SelectOp>(
+              loc, CMPI(ult, newMin.getResult(0), curMinCrd),
+              newMin.getResult(0), curMinCrd);
+          return {curMinCrd, isNonEmpty};
+        });
+
+    builder.setInsertionPointAfter(forOp.loops.front());
+    // minOffset = minCrd + 1 >= size ? minCrd + 1 - size : c0
+    Value tmp = builder.create<arith::AddIOp>(loc, forOp.results.front(), c1);
+    Value minOffset = builder.create<arith::SubIOp>(
+        loc, tmp, sliceSizes[tid][lvl][info.depth - 1]);
+    Value p = CMPI(uge, tmp, sliceSizes[tid][lvl][info.depth - 1]);
+    minOffset = builder.create<arith::SelectOp>(loc, p, minOffset, c0);
+    SmallVector<Value, 3> yields;
+    yields.assign(forOp.results.begin(), forOp.results.end());
+    yields.push_back(minOffset);
+    builder.create<scf::YieldOp>(loc, yields);
+  }
+
+  Value nextMinCrd = ifOp.getResults()[0];
+  Value nextNonEmpty = ifOp.getResults()[1];
+
+  // The next offset should at least be offset + 1;
+  Value minOffset = ifOp.getResults()[2];
+  Value nxOffset = builder.create<arith::AddIOp>(loc, info.offset, c1);
+  Value maxPred = CMPI(ugt, minOffset, nxOffset);
+  Value nextAbsOffset =
+      builder.create<arith::SelectOp>(loc, maxPred, minOffset, nxOffset);
+
+  Value sliceUB = builder.create<arith::AddIOp>(
+      loc, nextAbsOffset, sliceSizes[tid][lvl][info.depth - 1]);
+
+  // FIXME: this only works if there is only one parent.
+  assert(info.depth - 1 == 0);
+  // nextNonEmpty = nextNonEmpty && slice upper bound <= parent upperbound.
+  nextNonEmpty = builder.create<arith::AndIOp>(
+      loc, nextNonEmpty, CMPI(ule, sliceUB, lvlSizes[tid][lvl]));
+
+  // FIXME: compute the offset relative to the parent slice. With a single
+  // parent (asserted above), the absolute offset is what gets yielded.
+
+  operands.push_back(nextNonEmpty);
+  operands.push_back(nextMinCrd);
+  operands.push_back(nextAbsOffset); // we push the absolute offset.
+
+  // Update the slice stack.
+  info.isNonEmpty = whileOp.getResult(retIdx++);
+  info.minCrd = whileOp.getResult(retIdx++);
+  info.offset = whileOp.getResult(retIdx++);
+}
+// Emits a while loop over the non-empty slices of [tid, lvl]; see below.
+Operation *LoopEmitter::emitSliceDrivenLoopOverTensorAtLvl(
+    OpBuilder &builder, Location loc, TensorId tid, Level lvl,
+    MutableArrayRef<Value> reduc) {
+  assert(!depFullyReduced(tid, lvl));
+  SliceInfo &sliceInfo = sliceStack[tid].back();
+  assert(sliceInfo.slicedOnLvl == lvl);
+  // Loop-carried values: [isNonEmpty, minCrd, offset, user reductions...].
+  // The order matters!
+  SmallVector<Value, 3> operands{sliceInfo.isNonEmpty, sliceInfo.minCrd,
+                                 sliceInfo.offset};
+  // Number of loop-carried values maintained by the loop emitter itself.
+  size_t numMetaReduc = operands.size();
+
+  // Append user-required reduction values.
+  operands.append(reduc.begin(), reduc.end());
+  assert(operands.size() == numMetaReduc + reduc.size());
+
+  // while (slice.nonEmpty()) {
+  //   bodyBuilder();
+  //   SliceNext();
+  // }
+  auto whileOp = builder.create<scf::WhileOp>(
+      loc, ValueRange(operands).getTypes(), operands,
+      /*beforeBuilder=*/
+      [](OpBuilder &builder, Location loc, ValueRange args) {
+        builder.create<scf::ConditionOp>(loc, /*isNonEmpty*/ args[0], args);
+      },
+      /*afterBuilder=*/
+      [this, tid, lvl, reduc, numMetaReduc,
+       &sliceInfo](OpBuilder &builder, Location loc, ValueRange args) {
+        assert(args.size() == reduc.size() + numMetaReduc);
+        sliceInfo.isNonEmpty = args[0]; // Rebind slice state to block args.
+        sliceInfo.minCrd = args[1];
+        sliceInfo.offset = args[2];
+        // The slice offset is used to coiterate with other tensors'
+        // coordinates.
+        Value c = sliceInfo.offset;
+        if (sliceInfo.depth > 1) {
+          // Coord is the relative offset related to its parents.
+          // Update c = absOffset[lvl][depth] - absOffset[lvl][depth - 1]
+          llvm_unreachable("TODO: not yet implement");
+        }
+        coords[tid][lvl] = c;
+        // Rebind the caller's reductions to the loop-carried block arguments.
+        for (unsigned i = 0, e = reduc.size(); i < e; i++)
+          reduc[i] = args[i + numMetaReduc];
+      });
+  // The after-block has no terminator yet; it is presumably emitted later.
+  // Set the insertion point to while loop body.
+  builder.setInsertionPointToEnd(&whileOp.getAfter().front());
+  return whileOp;
+}
+
+#undef CMPI
+#undef C_IDX

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index f3b5a619b06e7..5bbb68198e0f5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -82,6 +82,10 @@ class LoopEmitter {
   // d0 and d1 (for affine expression reduction).
   // If the list is empty, it means that there is no affine expression on the
   // input [tid, dim].
+  // NOTE: The caller is responsible for ensuring that the order of the
+  // returned list is consistent with the topological order of the iteration
+  // graph; otherwise the loop emitter might reduce the wrong dependent index
+  // variable when generating slice-driven loops.
   using DependentLvlGetter =
       function_ref<std::vector<std::pair<TensorId, Level>>(TensorId, Level)>;
 
@@ -133,10 +137,7 @@ class LoopEmitter {
                        ArrayRef<TensorId> tids, ArrayRef<Level> lvls);
 
   /// Exits the current loop sequence, this will reset universal index to 0.
-  void exitCurrentLoopSeq() {
-    assert(loopSeqStack.size() == loopStack.size() + 1);
-    loopSeqStack.pop_back();
-  }
+  void exitCurrentLoopSeq(OpBuilder &builder, Location loc);
 
   // TODO: Get rid of `lvls` in the argument list? Track the level we
   // are currently at internally. Then it would be enterNextLvlForTensor.
@@ -208,26 +209,62 @@ class LoopEmitter {
   }
 
 private:
-  struct LoopInfo {
-    LoopInfo(ArrayRef<TensorId> tids, ArrayRef<Level> lvls, Operation *loop,
-             Block *userBlock, Value iv, StringAttr loopTag)
-        : tids(tids), lvls(lvls), loop(loop), userCodeBlock(userBlock), iv(iv) {
+  // LoopInfo stores information of a loop generated by LoopEmitter. E.g.,
+  // the set of tensor levels that the loop is iterating over.
+  struct LoopInfo final {
+    LoopInfo(ArrayRef<TensorId> tids, ArrayRef<Level> lvls,
+             ArrayRef<TensorId> slicedTids, ArrayRef<Level> slicedLvls,
+             ArrayRef<bool> sliceReduced, Operation *loop, Block *userBlock,
+             Value iv, StringAttr loopTag)
+        : tids(tids), lvls(lvls), slicedTids(slicedTids),
+          slicedLvls(slicedLvls), sliceReduced(sliceReduced), loop(loop),
+          userCodeBlock(userBlock), iv(iv) {
       // Attached a special tag to loop emitter generated loop.
       if (loopTag)
         loop->setAttr(LoopEmitter::getLoopEmitterLoopAttrName(), loopTag);
     }
     // TODO: maybe use a vector<pair> for tid and lvl?
-    //       (Better yet, compress them together a la `TensorLoopId`.)
+    //       (Or compress them together with a `TensorLoopId`.)
     // The set of tensors that the loop is operating on
     const llvm::SmallVector<TensorId> tids;
     // The corresponding levels for the tensors
     const llvm::SmallVector<Level> lvls;
+    // The set of tensors for slice-driven loop conditions.
+    const llvm::SmallVector<TensorId> slicedTids;
+    // The corresponding levels for the slice-driven tensors.
+    const llvm::SmallVector<Level> slicedLvls;
+    // Whether each sliced tensor level is fully reduced (e.g., i + j => j).
+    const llvm::SmallVector<bool> sliceReduced;
     const Operation *loop;      // the loop operation
     Block *const userCodeBlock; // the block holding users' generated code.
     const Value iv;             // the induction variable for the loop
   };
 
-  /// Linearizes address for dense level (i.e., p = (i * d0) + j).
+  // SliceInfo stores information of an extracted slice for slice-driven loop.
+  // E.g., the in-scope SSA values for the minimum coordinate and offset of
+  // the slice, etc.
+  struct SliceInfo final {
+    // Note that we do not need to create an actual sparse tensor slice but
+    // instead only need to maintain the metadata of the slice.
+    SliceInfo(Value minCrd, Value offset, Value isNonEmpty,
+              std::optional<Level> slicedOnLvl, unsigned depth)
+        : minCrd(minCrd), offset(offset), isNonEmpty(isNonEmpty),
+          slicedOnLvl(slicedOnLvl), depth(depth) {
+      // TODO: use std::optional<pair<Level, minCrd>>
+      assert(!slicedOnLvl || minCrd);
+    }
+
+    // Whether this is the initial tensor, i.e., it has not been sliced yet.
+    bool isInitialTensor() const { return !slicedOnLvl.has_value(); }
+
+    Value minCrd;                     // the minimum coordinate of the slice.
+    Value offset;                     // the offset of the current slice.
+    Value isNonEmpty;                 // whether the slice is non-empty.
+    std::optional<Level> slicedOnLvl; // the level on which the slice is done.
+    unsigned depth; // the depth (relative to dependentLvlMap[tid][lvl]).
+  };
+
+  /// Linearizes address for dense level (i.e., p = (i * d0) + j).
   Value genAddress(OpBuilder &builder, Location loc, TensorId tid, Level lvl,
                    Value iv);
 
@@ -281,6 +318,24 @@ class LoopEmitter {
                                             ArrayRef<TensorId> tids,
                                             ArrayRef<Level> lvls);
 
+  /// Emits a for loop to iterate over a dense level, or a sparse level that has
+  /// not been sliced.
+  Operation *emitForLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
+                                        TensorId tid, Level lvl,
+                                        MutableArrayRef<Value> reduc,
+                                        bool isParallel);
+
+  /// Emits a while loop to iterate over a sparse level that has been sliced.
+  /// Inserts a break statement when the coordinate exceeds the sliceSize.
+  /// The method sets the insertion point inside the generated while loop body
+  /// after the break statement before returning (so that callers only need to
+  /// handle in-bound coordinates).
+  Operation *emitWhileLoopOverSliceAtSparseLvl(OpBuilder &builder, Location loc,
+                                               Value pLo, Value pHi,
+                                               Value offset, Value sliceSize,
+                                               TensorId tid, Level lvl,
+                                               MutableArrayRef<Value> reduc);
+
   /// Exits a for loop, returns the reduction results, e.g.,
   /// For sequential for loops:
   /// %ret = for () {
@@ -344,6 +399,95 @@ class LoopEmitter {
     return {dstLvl};
   }
 
+  //
+  // Slice-driven loop related methods.
+  //
+
+  /// Retrieves the most recent slice on lvl. To reduce affine expression like
+  /// d0 + d1 + d2, we need two slices (one of size d1 + d2, and the other of
+  /// size d2). This method returns the latter slice (of size d2), which is
+  /// also the final slice on the level.
+  const SliceInfo &getFinalSliceOnLvl(TensorId tid, Level lvl);
+
+  /// Gets the remaining number of constraints needed to fully *resolve*
+  /// dependent levels on tensor[tid].
+  unsigned remDepOnLevel(TensorId tid, Level lvl) const;
+
+  /// Whether the tid, lvl is fully *reduced*, i.e., the non-trivial index
+  /// expression has been reduced to a trivial one.
+  /// E.g., A[i + j] => A[i + 2] (j is reduced)
+  bool depFullyReduced(TensorId tid, Level lvl) const {
+    return remDepOnLevel(tid, lvl) == 1;
+  }
+
+  /// Whether the tid, lvl is fully resolved, i.e., we entered the level already
+  /// (the index on that level is determined).
+  /// E.g., A[i + j] => A[2 + 3] (both i and j become invariants for inner
+  /// loops).
+  bool lvlFullyResolved(TensorId tid, Level lvl) const {
+    return remDepOnLevel(tid, lvl) == 0;
+  }
+
+  /// Generates a whileOp to iterate over a subset of coordinates on tid on lvl
+  /// using the pHi and pLo provided; the loop breaks on the first coordinate
+  /// that exceeds the slice boundary (i.e., coord >= slice.offset +
+  /// slice.size).
+  std::pair<Operation *, ValueRange>
+  genSliceLvlTraverseLoop(OpBuilder &builder, Location loc, Value pLo,
+                          Value pHi, Value offset, Value size, TensorId tid,
+                          Level lvl, ValueRange userReduc, bool genYield,
+                          /*bodyBuilder=*/
+                          llvm::function_ref<void(OpBuilder &, Location, Value,
+                                                  MutableArrayRef<Value>)>);
+
+  /// Generates a nested loop that iterates over tid on all the coordinates on
+  /// lvl.
+  ValueRange genUnResolvedSliceTreeTraverse(
+      OpBuilder &builder, Location loc, Value offset, TensorId tid, Level lvl,
+      size_t depth, ValueRange userReduc,
+      /*bodyBuilder=*/
+      llvm::function_ref<void(OpBuilder &, Location, Value,
+                              MutableArrayRef<Value>)>);
+
+  /// Generates code to get the first non-empty slice of tid on lvl, when all
+  /// the previous levels before `lvl` are resolved (or lvl is the first level).
+  ///
+  /// This is the simple case because the previous levels are resolved into a
+  /// single node in the storage tree.
+  void genResolvedSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
+                             Level lvl);
+
+  /// Generates code to get the first non-empty slice of tid on lvl, when
+  /// the previous levels before `lvl` are unresolved.
+  ///
+  /// This is the complex case because the previous levels correspond to a
+  /// range of nodes in the storage tree.
+  void genUnResolvedSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
+                               Level lvl);
+
+  /// Generates code to get the first non-empty slice of tid on lvl.
+  /// Returns true if it has already been resolved.
+  bool genSliceBegin(OpBuilder &builder, Location loc, TensorId tid, Level lvl);
+
+  /// Generates code to get the next non-empty slices of tid on lvl.
+  void genSliceNextInduction(OpBuilder &builder, Location loc,
+                             const Operation *whileOp, TensorId tid, Level lvl,
+                             SmallVectorImpl<Value> &operands,
+                             unsigned &retIdx);
+
+  /// Generates a slice-driven while loop as follows.
+  ///
+  /// curSlice = getFirstNonEmptySlice(tensor).
+  ///
+  /// while(isNonEmpty) {
+  ///   ..user code..
+  ///   isNonEmpty, curSlice = getNextNonEmptySlice(curSlice)
+  /// }
+  Operation *emitSliceDrivenLoopOverTensorAtLvl(OpBuilder &builder,
+                                                Location loc, TensorId tid,
+                                                Level lvl,
+                                                MutableArrayRef<Value> reduc);
+
   /// A optional string attribute that should be attached to the loop
   /// generated by loop emitter, it might help following passes to identify
   /// loops that operates on sparse tensors more easily.
@@ -353,6 +497,9 @@ class LoopEmitter {
   bool hasOutput;
   bool isSparseOut;
 
+  /// The insertion point to allocate top level local variables.
+  Operation *localInsertPos;
+
   //
   // Fields which have `numTensor` many entries.
   //
@@ -388,6 +535,10 @@ class LoopEmitter {
   std::vector<std::vector<Value>> coordinatesBuffers; // to_coordinates
   std::vector<Value> valBuffer;                       // to_value
 
+  //
+  // Slice-driven loops related fields.
+  //
+
   /// Whether the sparse input is a slice.
   std::vector<bool> isSparseSlices;
   /// Values related to slices.
@@ -399,6 +550,21 @@ class LoopEmitter {
   std::vector<std::vector<std::vector<std::pair<TensorId, Level>>>>
       dependentLvlMap;
 
+  // The cached position buffer for the slices; they serve the same purpose as
+  // ptrBuffer for compressed dimensions.
+  // But they always start with the first pidx pointing to coord > slice.offset
+  // to avoid iteration from the beginning.
+  std::vector<std::vector<std::vector<Value>>> slicePosBuffer;
+
+  // The cached size for each slice.
+  std::vector<std::vector<std::vector<Value>>> sliceSizes;
+
+  // The number of reduced dependencies on a tensor level so far.
+  std::vector<std::vector<unsigned>> levelReducedDep;
+
+  // sliceStack[tid] holds the generated slice stack on tid.
+  std::vector<std::vector<SliceInfo>> sliceStack;
+
   //
   // View based reshape related-fields and methods
   //
@@ -419,9 +585,11 @@ class LoopEmitter {
   /// alive.
   std::vector<LoopInfo> loopStack;
 
-  /// Loop Sequence Stack, stores the universal index for the current loop
-  /// sequence.
-  std::vector<Value> loopSeqStack;
+  // Loop Sequence Stack, stores the universal index for the current loop
+  // sequence, and a list of (tid, lvl, flag) entries for the sliced levels.
+  // TODO: maybe we should have a LoopSeqInfo
+  std::vector<std::pair<Value, std::vector<std::tuple<TensorId, Level, bool>>>>
+      loopSeqStack;
 
   /// Maps `LoopId` (used by `AffineDimExpr`) to `LoopOrd` (in the `loopStack`).
   /// TODO: We should probably use a callback function here to make it more

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 19ed23108f2f5..d22ea43fbf5ac 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1014,7 +1014,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
       // Link the reduction chain. Note that loop emitter update the reducValue
       // in place.
       loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
-      loopEmitter.exitCurrentLoopSeq();
+      loopEmitter.exitCurrentLoopSeq(rewriter, loc);
     }
 
     // Replace the foreach operator with the value returned by the outtermost

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 64a86aa1a8570..40a2454f779de 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -83,9 +83,14 @@ inline static bool includesUndef(SortMask mask) {
 class AffineDimFinder : public AffineExprVisitor<AffineDimFinder> {
 public:
   explicit AffineDimFinder(linalg::GenericOp op)
-      : iterTypes(op.getIteratorTypesArray()) {}
+      : iterTypes(op.getIteratorTypes()) {}
+
+  // Overrides method from AffineExprVisitor.
   void visitDimExpr(AffineDimExpr expr) {
-    if (pickedDim == nullptr || pickIterType == iterTypes[expr.getPosition()]) {
+    if (pickedDim == nullptr ||
+        pickIterType == iterTypes[expr.getPosition()]
+                            .cast<linalg::IteratorTypeAttr>()
+                            .getValue()) {
       pickedDim = expr;
     }
   }
@@ -106,11 +111,12 @@ class AffineDimFinder : public AffineExprVisitor<AffineDimFinder> {
   /// The iterator type that we want.
   utils::IteratorType pickIterType;
   /// The mapping between dim=>iterator type.
-  SmallVector<utils::IteratorType> iterTypes;
+  ArrayAttr iterTypes;
 };
 
 // Flattens an affine expression into a list of AffineDimExprs.
 struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
+  // Overrides AffineExprVisitor::visitDimExpr; appends each visited dim.
   void visitDimExpr(AffineDimExpr expr) { dims.push_back(expr); }
   SmallVector<AffineDimExpr> dims;
 };
@@ -306,7 +312,7 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
         // else increase min(d0_1, d0_2).
         return false;
       }
-      merger.setLoopDependentTensorLevel(ldx, tensor, lvl);
+      merger.setLoopDependentTensorLevel(ldx, tensor, lvl, dlt);
     }
     return true;
   }
@@ -774,10 +780,6 @@ static bool computeIterationGraph(CodegenEnv &env, SortMask mask,
   for (OpOperand &t : env.op()->getOpOperands()) {
     // Get map and encoding.
     const auto enc = getSparseTensorEncoding(t.get().getType());
-    assert(env.op().getMatchingIndexingMap(&t).getNumDims() +
-               getNumNonTrivialIdxExpOnSparseLvls(env.op()) ==
-           numLoops);
-
     // Skips dense inputs/outputs when not requested.
     const bool isDenseInput = !enc && env.op().isDpsInput(&t);
     const bool isDenseOutput = !enc && !isDenseInput;
@@ -1498,8 +1500,7 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
   // consumed by a subsequent lattice point.
   if (needsUniv) {
     for (const LatPointId li : env.set(lts).drop_front())
-      if (!env.merger().hasAnySparse(env.lat(li).simple) &&
-          !env.merger().hasSparseIdxReduction(env.lat(li).simple))
+      if (!env.merger().hasAnySparse(env.lat(li).simple))
         return true;
   }
   return false;
@@ -1680,14 +1681,17 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
 
 /// Ends a single loop in current sequence. Returns new values for needsUniv.
 static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
-                    LoopId idx, LatPointId li, bool needsUniv) {
-  // End a while-loop.
-  if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
-    finalizeWhileOp(env, rewriter, idx, needsUniv, whileOp);
-  } else if (auto forOp = dyn_cast<scf::ForOp>(loop)) {
-    // Any iteration of a reduction for-loop creates a valid lex insert.
+                    LoopId idx, LatPointId li, bool needsUniv,
+                    bool isSingleCond) {
+
+  if (isSingleCond) {
+    // Either a for-loop or a while-loop that iterates over a slice.
+    // Any iteration creates a valid lex insert.
     if (env.isReduc() && env.getValidLexInsert())
       env.setValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
+  } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
+    // End a while-loop.
+    finalizeWhileOp(env, rewriter, idx, needsUniv, whileOp);
   } else {
     needsUniv = false;
   }
@@ -1701,10 +1705,10 @@ static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
 }
 
 /// Ends a loop sequence at given level.
-static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
-                       LoopOrd at, LoopId idx, LoopId ldx) {
+static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
+                       unsigned at, unsigned idx, unsigned ldx) {
   assert(!env.getLoopVar(idx));
-  env.emitter().exitCurrentLoopSeq();
+  env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc());
   // Unmark bookkeeping of invariants and loop index.
   genInvariants(env, builder, exp, ldx, /*atStart=*/false);
   // Finalize access pattern expansion for sparse tensor output.
@@ -1769,7 +1773,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
     }
 
     // End a loop.
-    needsUniv = endLoop(env, rewriter, loop, idx, li, needsUniv);
+    needsUniv = endLoop(env, rewriter, loop, idx, li, needsUniv, isSingleCond);
   }
 
   // End a loop sequence.

diff  --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 8bdff8de8dcfe..7d331b2d298d4 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -213,10 +213,11 @@ Merger::Merger(unsigned numInputOutputTensors, unsigned numNativeLoops,
                 std::vector<std::optional<Level>>(numLoops, std::nullopt)),
       lvlToLoop(numTensors,
                 std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)),
-      loopToDependencies(numLoops, std::vector<std::optional<Level>>(
-                                       numTensors, std::nullopt)),
-      levelToDependentIdx(numTensors, std::vector<std::vector<LoopId>>(
-                                          maxLvlRank, std::vector<LoopId>())),
+      loopToDependencies(
+          numLoops, std::vector<std::optional<std::pair<Level, DimLevelType>>>(
+                        numTensors, std::nullopt)),
+      levelToDependentLoop(numTensors, std::vector<std::vector<LoopId>>(
+                                           maxLvlRank, std::vector<LoopId>())),
       loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}
 
 //===----------------------------------------------------------------------===//
@@ -396,15 +397,10 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
     }
   }
 
-  BitVector simple(lat(p0).bits);
-  bool reset =
-      isSingleton && (hasAnySparse(simple) || hasSparseIdxReduction(simple));
-  // `be`, `b`, and `offset` are `TensorLoopId` in spirit; but we avoid
-  // using that class in this function because we need to do a bunch of
-  // arithmetic on them, so using the newtype would introduce too much
-  // boilerplate.
-  const unsigned be = simple.size();
-  unsigned offset = 0; // relative to the end
+  BitVector simple(latPoints[p0].bits);
+  bool reset = isSingleton && hasAnySparse(simple);
+  const TensorLoopId be = simple.size();
+  TensorLoopId offset = 0; // relative to the end
   if (!reset)
     // Starts resetting from a dense level, so that the first bit (if kept)
     // is not undefined level-type.
@@ -419,10 +415,9 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
   // keep the rightmost bit (which could possibly be a synthetic tensor).
   for (unsigned b = be - 1 - offset, i = 0; i < be;
        b = b == 0 ? be - 1 : b - 1, i++) {
-    // FIXME: better name? also slice on dense level has locate property as
-    // well. Handle it correctly!
-    if (simple[b] && !isLvlWithNonTrivialIdxExp(TensorLoopId{b})) {
-      const auto dlt = getDimLevelType(TensorLoopId{b});
+    // Slice on dense level has `locate` property as well, and can be optimized.
+    if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
+      const auto dlt = getDimLevelType(b);
       if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt)) {
         if (reset)
           simple.reset(b);
@@ -447,9 +442,9 @@ bool Merger::latGT(LatPointId i, LatPointId j) const {
 }
 
 bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {
-  BitVector tmp(lat(j).bits);
-  tmp ^= lat(i).bits;
-  return !hasAnySparse(tmp) && !hasSparseIdxReduction(tmp);
+  BitVector tmp(latPoints[j].bits);
+  tmp ^= latPoints[i].bits;
+  return !hasAnySparse(tmp); // hasAnySparse also checks idx reductions.
 }
 
 bool Merger::expContainsTensor(ExprId e, TensorId t) const {
@@ -588,19 +583,17 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
 }
 
 bool Merger::hasAnySparse(const BitVector &bits) const {
-  for (TensorLoopId b = 0, be = bits.size(); b < be; b++)
-    if (bits[b]) {
-      const auto dlt = getDimLevelType(b);
-      if (isCompressedDLT(dlt) || isSingletonDLT(dlt))
-        return true;
-    }
-  return false;
+  for (TensorLoopId b : bits.set_bits()) {
+    const auto dlt = getDimLevelType(b);
+    if (isCompressedDLT(dlt) || isSingletonDLT(dlt))
+      return true;
+  }
+  return hasSparseIdxReduction(bits); // sparse idx reductions count too.
 }
 
 bool Merger::hasSparseIdxReduction(const BitVector &bits) const {
-  // TODO: return false on dense levels.
-  for (unsigned b = 0, be = bits.size(); b < be; b++)
-    if (bits[b] && isLvlWithNonTrivialIdxExp(b))
+  for (TensorLoopId b : bits.set_bits())
+    if (isSparseLvlWithNonTrivialIdxExp(b))
       return true;
   return false;
 }

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
new file mode 100644
index 0000000000000..37d3e1026a167
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
@@ -0,0 +1,284 @@
+// RUN: mlir-opt %s --sparsification="enable-index-reduction=true" --cse | FileCheck %s
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+
+#DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
+
+// CHECK-LABEL:   func.func @conv2d_all_sparse_CSR(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<8x8xi32, #{{.*}}>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<3x3xi32>) -> tensor<6x6xi32, #{{.*}}> {
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 8 : index
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 3 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 2 : index
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 4 : index
+// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 0 : i32
+// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant true
+// CHECK-DAG:       %[[VAL_10:.*]] = arith.constant false
+// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.alloc_tensor() : tensor<6x6xi32, #{{.*}}>
+// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #{{.*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #{{.*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #{{.*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_15:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #{{.*}}> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xi32, #{{.*}}> to memref<?xi32>
+// CHECK-DAG:       %[[VAL_17:.*]] = bufferization.to_memref %[[VAL_1]] : memref<3x3xi32>
+// CHECK-DAG:       %[[VAL_18:.*]] = memref.alloca(%[[VAL_2]]) : memref<?xindex>
+// CHECK-DAG:       %[[VAL_19:.*]] = memref.alloca(%[[VAL_7]]) : memref<?xindex>
+// CHECK:           %[[VAL_20:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:           memref.store %[[VAL_7]], %[[VAL_19]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           memref.store %[[VAL_4]], %[[VAL_19]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:           memref.store %[[VAL_4]], %[[VAL_19]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK:           memref.store %[[VAL_20]], %[[VAL_19]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           %[[VAL_21:.*]] = arith.cmpi ugt, %[[VAL_20]], %[[VAL_4]] : index
+// CHECK:           %[[VAL_22:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           %[[VAL_23:.*]] = arith.cmpi uge, %[[VAL_22]], %[[VAL_3]] : index
+// CHECK:           %[[VAL_24:.*]] = arith.andi %[[VAL_21]], %[[VAL_23]] : i1
+// CHECK:           %[[VAL_25:.*]] = arith.addi %[[VAL_22]], %[[VAL_5]] : index
+// CHECK:           %[[VAL_26:.*]] = arith.subi %[[VAL_25]], %[[VAL_3]] : index
+// CHECK:           %[[VAL_27:.*]] = arith.select %[[VAL_24]], %[[VAL_26]], %[[VAL_4]] : index
+// CHECK:           %[[VAL_28:.*]]:4 = scf.while (%[[VAL_29:.*]] = %[[VAL_21]], %[[VAL_30:.*]] = %[[VAL_22]], %[[VAL_31:.*]] = %[[VAL_27]], %[[VAL_32:.*]] = %[[VAL_11]]) : (i1, index, index, tensor<6x6xi32, #{{.*}}>) -> (i1, index, index, tensor<6x6xi32, #{{.*}}>) {
+// CHECK:             scf.condition(%[[VAL_29]]) %[[VAL_29]], %[[VAL_30]], %[[VAL_31]], %[[VAL_32]] : i1, index, index, tensor<6x6xi32, #{{.*}}>
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_33:.*]]: i1, %[[VAL_34:.*]]: index, %[[VAL_35:.*]]: index, %[[VAL_36:.*]]: tensor<6x6xi32, #{{.*}}>):
+// CHECK:             %[[VAL_37:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:             %[[VAL_38:.*]]:3 = scf.for %[[VAL_39:.*]] = %[[VAL_6]] to %[[VAL_37]] step %[[VAL_6]] iter_args(%[[VAL_40:.*]] = %[[VAL_10]], %[[VAL_41:.*]] = %[[VAL_2]], %[[VAL_42:.*]] = %[[VAL_6]]) -> (i1, index, index) {
+// CHECK:               %[[VAL_43:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_39]]] : memref<?xindex>
+// CHECK:               %[[VAL_44:.*]] = arith.addi %[[VAL_39]], %[[VAL_5]] : index
+// CHECK:               %[[VAL_45:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_44]]] : memref<?xindex>
+// CHECK:               %[[VAL_46:.*]] = arith.addi %[[VAL_35]], %[[VAL_3]] : index
+// CHECK:               %[[VAL_47:.*]]:5 = scf.while (%[[VAL_48:.*]] = %[[VAL_43]], %[[VAL_49:.*]] = %[[VAL_9]], %[[VAL_50:.*]] = %[[VAL_40]], %[[VAL_51:.*]] = %[[VAL_41]], %[[VAL_52:.*]] = %[[VAL_42]]) : (index, i1, i1, index, index) -> (index, i1, i1, index, index) {
+// CHECK:                 %[[VAL_53:.*]] = arith.cmpi ult, %[[VAL_48]], %[[VAL_45]] : index
+// CHECK:                 %[[VAL_54:.*]] = arith.andi %[[VAL_49]], %[[VAL_53]] : i1
+// CHECK:                 scf.condition(%[[VAL_54]]) %[[VAL_48]], %[[VAL_49]], %[[VAL_50]], %[[VAL_51]], %[[VAL_52]] : index, i1, i1, index, index
+// CHECK:               } do {
+// CHECK:               ^bb0(%[[VAL_55:.*]]: index, %[[VAL_56:.*]]: i1, %[[VAL_57:.*]]: i1, %[[VAL_58:.*]]: index, %[[VAL_59:.*]]: index):
+// CHECK:                 %[[VAL_60:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_55]]] : memref<?xindex>
+// CHECK:                 %[[VAL_61:.*]] = arith.cmpi ult, %[[VAL_60]], %[[VAL_46]] : index
+// CHECK:                 %[[VAL_62:.*]]:3 = scf.if %[[VAL_61]] -> (i1, index, index) {
+// CHECK:                   %[[VAL_63:.*]] = arith.addi %[[VAL_55]], %[[VAL_5]] : index
+// CHECK:                   %[[VAL_64:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_55]]] : memref<?xindex>
+// CHECK:                   %[[VAL_65:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_63]]] : memref<?xindex>
+// CHECK:                   %[[VAL_66:.*]] = arith.cmpi ult, %[[VAL_64]], %[[VAL_65]] : index
+// CHECK:                   %[[VAL_67:.*]] = arith.ori %[[VAL_66]], %[[VAL_57]] : i1
+// CHECK:                   %[[VAL_68:.*]] = scf.if %[[VAL_66]] -> (index) {
+// CHECK:                     %[[VAL_69:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_64]]] : memref<?xindex>
+// CHECK:                     %[[VAL_70:.*]] = arith.cmpi ult, %[[VAL_69]], %[[VAL_58]] : index
+// CHECK:                     %[[VAL_71:.*]] = arith.select %[[VAL_70]], %[[VAL_69]], %[[VAL_58]] : index
+// CHECK:                     scf.yield %[[VAL_71]] : index
+// CHECK:                   } else {
+// CHECK:                     scf.yield %[[VAL_58]] : index
+// CHECK:                   }
+// CHECK:                   memref.store %[[VAL_64]], %[[VAL_18]]{{\[}}%[[VAL_59]]] : memref<?xindex>
+// CHECK:                   %[[VAL_72:.*]] = arith.addi %[[VAL_59]], %[[VAL_5]] : index
+// CHECK:                   memref.store %[[VAL_65]], %[[VAL_18]]{{\[}}%[[VAL_72]]] : memref<?xindex>
+// CHECK:                   %[[VAL_73:.*]] = arith.addi %[[VAL_59]], %[[VAL_6]] : index
+// CHECK:                   scf.yield %[[VAL_67]], %[[VAL_74:.*]], %[[VAL_73]] : i1, index, index
+// CHECK:                 } else {
+// CHECK:                   scf.yield %[[VAL_57]], %[[VAL_58]], %[[VAL_59]] : i1, index, index
+// CHECK:                 } {"Emitted from" = "slice"}
+// CHECK:                 %[[VAL_75:.*]] = arith.addi %[[VAL_55]], %[[VAL_5]] : index
+// CHECK:                 scf.yield %[[VAL_75]], %[[VAL_61]], %[[VAL_76:.*]]#0, %[[VAL_76]]#1, %[[VAL_76]]#2 : index, i1, i1, index, index
+// CHECK:               }
+// CHECK:               scf.yield %[[VAL_77:.*]]#2, %[[VAL_77]]#3, %[[VAL_77]]#4 : i1, index, index
+// CHECK:             }
+// CHECK:             memref.store %[[VAL_78:.*]]#2, %[[VAL_18]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:             memref.store %[[VAL_4]], %[[VAL_18]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:             %[[VAL_79:.*]] = arith.cmpi uge, %[[VAL_78]]#1, %[[VAL_3]] : index
+// CHECK:             %[[VAL_80:.*]] = arith.andi %[[VAL_78]]#0, %[[VAL_79]] : i1
+// CHECK:             %[[VAL_81:.*]] = arith.addi %[[VAL_78]]#1, %[[VAL_5]] : index
+// CHECK:             %[[VAL_82:.*]] = arith.subi %[[VAL_81]], %[[VAL_3]] : index
+// CHECK:             %[[VAL_83:.*]] = arith.select %[[VAL_80]], %[[VAL_82]], %[[VAL_4]] : index
+// CHECK:             %[[VAL_84:.*]]:4 = scf.while (%[[VAL_85:.*]] = %[[VAL_78]]#0, %[[VAL_86:.*]] = %[[VAL_78]]#1, %[[VAL_87:.*]] = %[[VAL_83]], %[[VAL_88:.*]] = %[[VAL_36]]) : (i1, index, index, tensor<6x6xi32, #{{.*}}>) -> (i1, index, index, tensor<6x6xi32, #{{.*}}>) {
+// CHECK:               scf.condition(%[[VAL_85]]) %[[VAL_85]], %[[VAL_86]], %[[VAL_87]], %[[VAL_88]] : i1, index, index, tensor<6x6xi32, #{{.*}}>
+// CHECK:             } do {
+// CHECK:             ^bb0(%[[VAL_89:.*]]: i1, %[[VAL_90:.*]]: index, %[[VAL_91:.*]]: index, %[[VAL_92:.*]]: tensor<6x6xi32, #{{.*}}>):
+// CHECK:               %[[VAL_93:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:               %[[VAL_94:.*]] = arith.addi %[[VAL_93]], %[[VAL_6]] : index
+// CHECK:               %[[VAL_95:.*]] = arith.addi %[[VAL_94]], %[[VAL_5]] : index
+// CHECK:               %[[VAL_96:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_94]]] : memref<?xindex>
+// CHECK:               %[[VAL_97:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_95]]] : memref<?xindex>
+// CHECK:               %[[VAL_98:.*]] = arith.addi %[[VAL_35]], %[[VAL_3]] : index
+// CHECK:               %[[VAL_99:.*]]:5 = scf.while (%[[VAL_100:.*]] = %[[VAL_96]], %[[VAL_101:.*]] = %[[VAL_9]], %[[VAL_102:.*]] = %[[VAL_8]], %[[VAL_103:.*]] = %[[VAL_10]], %[[VAL_104:.*]] = %[[VAL_92]]) : (index, i1, i32, i1, tensor<6x6xi32, #{{.*}}>) -> (index, i1, i32, i1, tensor<6x6xi32, #{{.*}}>) {
+// CHECK:                 %[[VAL_105:.*]] = arith.cmpi ult, %[[VAL_100]], %[[VAL_97]] : index
+// CHECK:                 %[[VAL_106:.*]] = arith.andi %[[VAL_101]], %[[VAL_105]] : i1
+// CHECK:                 scf.condition(%[[VAL_106]]) %[[VAL_100]], %[[VAL_101]], %[[VAL_102]], %[[VAL_103]], %[[VAL_104]] : index, i1, i32, i1, tensor<6x6xi32, #{{.*}}>
+// CHECK:               } do {
+// CHECK:               ^bb0(%[[VAL_107:.*]]: index, %[[VAL_108:.*]]: i1, %[[VAL_109:.*]]: i32, %[[VAL_110:.*]]: i1, %[[VAL_111:.*]]: tensor<6x6xi32, #{{.*}}>):
+// CHECK:                 %[[VAL_112:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_107]]] : memref<?xindex>
+// CHECK:                 %[[VAL_113:.*]] = arith.cmpi ult, %[[VAL_112]], %[[VAL_98]] : index
+// CHECK:                 %[[VAL_114:.*]]:3 = scf.if %[[VAL_113]] -> (i32, i1, tensor<6x6xi32, #{{.*}}>) {
+// CHECK:                   %[[VAL_115:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_107]]] : memref<?xindex>
+// CHECK:                   %[[VAL_116:.*]] = arith.subi %[[VAL_115]], %[[VAL_35]] : index
+// CHECK:                   %[[VAL_117:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:                   %[[VAL_118:.*]] = arith.addi %[[VAL_117]], %[[VAL_6]] : index
+// CHECK:                   %[[VAL_119:.*]] = arith.addi %[[VAL_118]], %[[VAL_5]] : index
+// CHECK:                   %[[VAL_120:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_118]]] : memref<?xindex>
+// CHECK:                   %[[VAL_121:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_119]]] : memref<?xindex>
+// CHECK:                   %[[VAL_122:.*]] = arith.addi %[[VAL_91]], %[[VAL_3]] : index
+// CHECK:                   %[[VAL_123:.*]]:5 = scf.while (%[[VAL_124:.*]] = %[[VAL_120]], %[[VAL_125:.*]] = %[[VAL_9]], %[[VAL_126:.*]] = %[[VAL_109]], %[[VAL_127:.*]] = %[[VAL_110]], %[[VAL_128:.*]] = %[[VAL_111]]) : (index, i1, i32, i1, tensor<6x6xi32, #{{.*}}>) -> (index, i1, i32, i1, tensor<6x6xi32, #{{.*}}>) {
+// CHECK:                     %[[VAL_129:.*]] = arith.cmpi ult, %[[VAL_124]], %[[VAL_121]] : index
+// CHECK:                     %[[VAL_130:.*]] = arith.andi %[[VAL_125]], %[[VAL_129]] : i1
+// CHECK:                     scf.condition(%[[VAL_130]]) %[[VAL_124]], %[[VAL_125]], %[[VAL_126]], %[[VAL_127]], %[[VAL_128]] : index, i1, i32, i1, tensor<6x6xi32, #{{.*}}>
+// CHECK:                   } do {
+// CHECK:                   ^bb0(%[[VAL_131:.*]]: index, %[[VAL_132:.*]]: i1, %[[VAL_133:.*]]: i32, %[[VAL_134:.*]]: i1, %[[VAL_135:.*]]: tensor<6x6xi32, #{{.*}}>):
+// CHECK:                     %[[VAL_136:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_131]]] : memref<?xindex>
+// CHECK:                     %[[VAL_137:.*]] = arith.cmpi ult, %[[VAL_136]], %[[VAL_122]] : index
+// CHECK:                     %[[VAL_138:.*]]:3 = scf.if %[[VAL_137]] -> (i32, i1, tensor<6x6xi32, #{{.*}}>) {
+// CHECK:                       %[[VAL_139:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_131]]] : memref<?xindex>
+// CHECK:                       %[[VAL_140:.*]] = arith.subi %[[VAL_139]], %[[VAL_91]] : index
+// CHECK:                       %[[VAL_141:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_131]]] : memref<?xi32>
+// CHECK:                       %[[VAL_142:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_116]], %[[VAL_140]]] : memref<3x3xi32>
+// CHECK:                       %[[VAL_143:.*]] = arith.muli %[[VAL_141]], %[[VAL_142]] : i32
+// CHECK:                       %[[VAL_144:.*]] = arith.addi %[[VAL_133]], %[[VAL_143]] : i32
+// CHECK:                       scf.yield %[[VAL_144]], %[[VAL_9]], %[[VAL_135]] : i32, i1, tensor<6x6xi32, #{{.*}}>
+// CHECK:                     } else {
+// CHECK:                       scf.yield %[[VAL_133]], %[[VAL_134]], %[[VAL_135]] : i32, i1, tensor<6x6xi32, #{{.*}}>
+// CHECK:                     } {"Emitted from" = "slice"}
+// CHECK:                     %[[VAL_145:.*]] = arith.addi %[[VAL_131]], %[[VAL_5]] : index
+// CHECK:                     scf.yield %[[VAL_145]], %[[VAL_137]], %[[VAL_146:.*]]#0, %[[VAL_146]]#1, %[[VAL_146]]#2 : index, i1, i32, i1, tensor<6x6xi32, #{{.*}}>
+// CHECK:                   } attributes {"Emitted from" = "linalg.generic"}
+// CHECK:                   %[[VAL_147:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:                   %[[VAL_148:.*]] = arith.addi %[[VAL_147]], %[[VAL_6]] : index
+// CHECK:                   memref.store %[[VAL_148]], %[[VAL_18]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:                   scf.yield %[[VAL_149:.*]]#2, %[[VAL_9]], %[[VAL_149]]#4 : i32, i1, tensor<6x6xi32, #{{.*}}>
+// CHECK:                 } else {
+// CHECK:                   scf.yield %[[VAL_109]], %[[VAL_110]], %[[VAL_111]] : i32, i1, tensor<6x6xi32, #{{.*}}>
+// CHECK:                 } {"Emitted from" = "slice"}
+// CHECK:                 %[[VAL_150:.*]] = arith.addi %[[VAL_107]], %[[VAL_5]] : index
+// CHECK:                 scf.yield %[[VAL_150]], %[[VAL_113]], %[[VAL_151:.*]]#0, %[[VAL_151]]#1, %[[VAL_151]]#2 : index, i1, i32, i1, tensor<6x6xi32, #{{.*}}>
+// CHECK:               } attributes {"Emitted from" = "linalg.generic"}
+// CHECK:               %[[VAL_152:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:               %[[VAL_153:.*]] = arith.addi %[[VAL_152]], %[[VAL_6]] : index
+// CHECK:               memref.store %[[VAL_153]], %[[VAL_19]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:               %[[VAL_154:.*]] = scf.if %[[VAL_155:.*]]#3 -> (tensor<6x6xi32, #{{.*}}>) {
+// CHECK:                 %[[VAL_156:.*]] = sparse_tensor.insert %[[VAL_155]]#2 into %[[VAL_155]]#4{{\[}}%[[VAL_35]], %[[VAL_91]]] : tensor<6x6xi32, #{{.*}}>
+// CHECK:                 scf.yield %[[VAL_156]] : tensor<6x6xi32, #{{.*}}>
+// CHECK:               } else {
+// CHECK:                 scf.yield %[[VAL_157:.*]]#4 : tensor<6x6xi32, #{{.*}}>
+// CHECK:               }
+// CHECK:               memref.store %[[VAL_4]], %[[VAL_19]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:               memref.store %[[VAL_4]], %[[VAL_18]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:               %[[VAL_158:.*]] = arith.cmpi ugt, %[[VAL_90]], %[[VAL_91]] : index
+// CHECK:               %[[VAL_159:.*]]:3 = scf.if %[[VAL_158]] -> (index, i1, index) {
+// CHECK:                 %[[VAL_160:.*]] = arith.addi %[[VAL_91]], %[[VAL_5]] : index
+// CHECK:                 scf.yield %[[VAL_90]], %[[VAL_89]], %[[VAL_160]] : index, i1, index
+// CHECK:               } else {
+// CHECK:                 %[[VAL_161:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:                 %[[VAL_162:.*]]:2 = scf.for %[[VAL_163:.*]] = %[[VAL_6]] to %[[VAL_161]] step %[[VAL_6]] iter_args(%[[VAL_164:.*]] = %[[VAL_2]], %[[VAL_165:.*]] = %[[VAL_10]]) -> (index, i1) {
+// CHECK:                   %[[VAL_166:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_163]]] : memref<?xindex>
+// CHECK:                   %[[VAL_167:.*]] = arith.addi %[[VAL_163]], %[[VAL_5]] : index
+// CHECK:                   %[[VAL_168:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_167]]] : memref<?xindex>
+// CHECK:                   %[[VAL_169:.*]] = arith.cmpi ult, %[[VAL_166]], %[[VAL_168]] : index
+// CHECK:                   %[[VAL_170:.*]] = scf.if %[[VAL_169]] -> (index) {
+// CHECK:                     %[[VAL_171:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_166]]] : memref<?xindex>
+// CHECK:                     %[[VAL_172:.*]] = arith.cmpi eq, %[[VAL_171]], %[[VAL_90]] : index
+// CHECK:                     %[[VAL_173:.*]] = scf.if %[[VAL_172]] -> (index) {
+// CHECK:                       %[[VAL_174:.*]] = arith.addi %[[VAL_166]], %[[VAL_5]] : index
+// CHECK:                       memref.store %[[VAL_174]], %[[VAL_18]]{{\[}}%[[VAL_163]]] : memref<?xindex>
+// CHECK:                       scf.yield %[[VAL_174]] : index
+// CHECK:                     } else {
+// CHECK:                       scf.yield %[[VAL_166]] : index
+// CHECK:                     }
+// CHECK:                     scf.yield %[[VAL_175:.*]] : index
+// CHECK:                   } else {
+// CHECK:                     scf.yield %[[VAL_166]] : index
+// CHECK:                   }
+// CHECK:                   %[[VAL_176:.*]] = arith.cmpi ult, %[[VAL_177:.*]], %[[VAL_168]] : index
+// CHECK:                   %[[VAL_178:.*]] = scf.if %[[VAL_176]] -> (index) {
+// CHECK:                     %[[VAL_179:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_177]]] : memref<?xindex>
+// CHECK:                     scf.yield %[[VAL_179]] : index
+// CHECK:                   } else {
+// CHECK:                     scf.yield %[[VAL_164]] : index
+// CHECK:                   }
+// CHECK:                   %[[VAL_180:.*]] = arith.ori %[[VAL_176]], %[[VAL_165]] : i1
+// CHECK:                   %[[VAL_181:.*]] = arith.cmpi ult, %[[VAL_182:.*]], %[[VAL_164]] : index
+// CHECK:                   %[[VAL_183:.*]] = arith.select %[[VAL_181]], %[[VAL_182]], %[[VAL_164]] : index
+// CHECK:                   scf.yield %[[VAL_183]], %[[VAL_180]] : index, i1
+// CHECK:                 }
+// CHECK:                 %[[VAL_184:.*]] = arith.addi %[[VAL_185:.*]]#0, %[[VAL_5]] : index
+// CHECK:                 %[[VAL_186:.*]] = arith.subi %[[VAL_184]], %[[VAL_3]] : index
+// CHECK:                 %[[VAL_187:.*]] = arith.cmpi uge, %[[VAL_184]], %[[VAL_3]] : index
+// CHECK:                 %[[VAL_188:.*]] = arith.select %[[VAL_187]], %[[VAL_186]], %[[VAL_4]] : index
+// CHECK:                 scf.yield %[[VAL_185]]#0, %[[VAL_185]]#1, %[[VAL_188]] : index, i1, index
+// CHECK:               }
+// CHECK:               %[[VAL_189:.*]] = arith.addi %[[VAL_91]], %[[VAL_5]] : index
+// CHECK:               %[[VAL_190:.*]] = arith.cmpi ugt, %[[VAL_191:.*]]#2, %[[VAL_189]] : index
+// CHECK:               %[[VAL_192:.*]] = arith.select %[[VAL_190]], %[[VAL_191]]#2, %[[VAL_189]] : index
+// CHECK:               %[[VAL_193:.*]] = arith.addi %[[VAL_192]], %[[VAL_3]] : index
+// CHECK:               %[[VAL_194:.*]] = arith.cmpi ule, %[[VAL_193]], %[[VAL_2]] : index
+// CHECK:               %[[VAL_195:.*]] = arith.andi %[[VAL_191]]#1, %[[VAL_194]] : i1
+// CHECK:               scf.yield %[[VAL_195]], %[[VAL_191]]#0, %[[VAL_192]], %[[VAL_196:.*]] : i1, index, index, tensor<6x6xi32, #{{.*}}>
+// CHECK:             } attributes {"Emitted from" = "linalg.generic"}
+// CHECK:             memref.store %[[VAL_4]], %[[VAL_19]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:             %[[VAL_197:.*]] = arith.cmpi ugt, %[[VAL_34]], %[[VAL_35]] : index
+// CHECK:             %[[VAL_198:.*]]:3 = scf.if %[[VAL_197]] -> (index, i1, index) {
+// CHECK:               %[[VAL_199:.*]] = arith.addi %[[VAL_35]], %[[VAL_5]] : index
+// CHECK:               scf.yield %[[VAL_34]], %[[VAL_33]], %[[VAL_199]] : index, i1, index
+// CHECK:             } else {
+// CHECK:               %[[VAL_200:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:               %[[VAL_201:.*]]:2 = scf.for %[[VAL_202:.*]] = %[[VAL_6]] to %[[VAL_200]] step %[[VAL_6]] iter_args(%[[VAL_203:.*]] = %[[VAL_2]], %[[VAL_204:.*]] = %[[VAL_10]]) -> (index, i1) {
+// CHECK:                 %[[VAL_205:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_202]]] : memref<?xindex>
+// CHECK:                 %[[VAL_206:.*]] = arith.addi %[[VAL_202]], %[[VAL_5]] : index
+// CHECK:                 %[[VAL_207:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_206]]] : memref<?xindex>
+// CHECK:                 %[[VAL_208:.*]] = arith.cmpi ult, %[[VAL_205]], %[[VAL_207]] : index
+// CHECK:                 %[[VAL_209:.*]] = scf.if %[[VAL_208]] -> (index) {
+// CHECK:                   %[[VAL_210:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_205]]] : memref<?xindex>
+// CHECK:                   %[[VAL_211:.*]] = arith.cmpi eq, %[[VAL_210]], %[[VAL_34]] : index
+// CHECK:                   %[[VAL_212:.*]] = scf.if %[[VAL_211]] -> (index) {
+// CHECK:                     %[[VAL_213:.*]] = arith.addi %[[VAL_205]], %[[VAL_5]] : index
+// CHECK:                     memref.store %[[VAL_213]], %[[VAL_19]]{{\[}}%[[VAL_202]]] : memref<?xindex>
+// CHECK:                     scf.yield %[[VAL_213]] : index
+// CHECK:                   } else {
+// CHECK:                     scf.yield %[[VAL_205]] : index
+// CHECK:                   }
+// CHECK:                   scf.yield %[[VAL_214:.*]] : index
+// CHECK:                 } else {
+// CHECK:                   scf.yield %[[VAL_205]] : index
+// CHECK:                 }
+// CHECK:                 %[[VAL_215:.*]] = arith.cmpi ult, %[[VAL_216:.*]], %[[VAL_207]] : index
+// CHECK:                 %[[VAL_217:.*]] = scf.if %[[VAL_215]] -> (index) {
+// CHECK:                   %[[VAL_218:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_216]]] : memref<?xindex>
+// CHECK:                   scf.yield %[[VAL_218]] : index
+// CHECK:                 } else {
+// CHECK:                   scf.yield %[[VAL_203]] : index
+// CHECK:                 }
+// CHECK:                 %[[VAL_219:.*]] = arith.ori %[[VAL_215]], %[[VAL_204]] : i1
+// CHECK:                 %[[VAL_220:.*]] = arith.cmpi ult, %[[VAL_221:.*]], %[[VAL_203]] : index
+// CHECK:                 %[[VAL_222:.*]] = arith.select %[[VAL_220]], %[[VAL_221]], %[[VAL_203]] : index
+// CHECK:                 scf.yield %[[VAL_222]], %[[VAL_219]] : index, i1
+// CHECK:               }
+// CHECK:               %[[VAL_223:.*]] = arith.addi %[[VAL_224:.*]]#0, %[[VAL_5]] : index
+// CHECK:               %[[VAL_225:.*]] = arith.subi %[[VAL_223]], %[[VAL_3]] : index
+// CHECK:               %[[VAL_226:.*]] = arith.cmpi uge, %[[VAL_223]], %[[VAL_3]] : index
+// CHECK:               %[[VAL_227:.*]] = arith.select %[[VAL_226]], %[[VAL_225]], %[[VAL_4]] : index
+// CHECK:               scf.yield %[[VAL_224]]#0, %[[VAL_224]]#1, %[[VAL_227]] : index, i1, index
+// CHECK:             }
+// CHECK:             %[[VAL_228:.*]] = arith.addi %[[VAL_35]], %[[VAL_5]] : index
+// CHECK:             %[[VAL_229:.*]] = arith.cmpi ugt, %[[VAL_230:.*]]#2, %[[VAL_228]] : index
+// CHECK:             %[[VAL_231:.*]] = arith.select %[[VAL_229]], %[[VAL_230]]#2, %[[VAL_228]] : index
+// CHECK:             %[[VAL_232:.*]] = arith.addi %[[VAL_231]], %[[VAL_3]] : index
+// CHECK:             %[[VAL_233:.*]] = arith.cmpi ule, %[[VAL_232]], %[[VAL_2]] : index
+// CHECK:             %[[VAL_234:.*]] = arith.andi %[[VAL_230]]#1, %[[VAL_233]] : i1
+// CHECK:             scf.yield %[[VAL_234]], %[[VAL_230]]#0, %[[VAL_231]], %[[VAL_235:.*]]#3 : i1, index, index, tensor<6x6xi32, #{{.*}}>
+// CHECK:           } attributes {"Emitted from" = "linalg.generic"}
+// CHECK:           %[[VAL_236:.*]] = sparse_tensor.load %[[VAL_237:.*]]#3 hasInserts : tensor<6x6xi32, #{{.*}}>
+// CHECK:           return %[[VAL_236]] : tensor<6x6xi32, #{{.*}}>
+// CHECK:         }
+func.func @conv2d_all_sparse_CSR(%arg0: tensor<8x8xi32, #DCSR>,
+                                 %arg1: tensor<3x3xi32>) -> tensor<6x6xi32, #DCSR> {
+  %0 = bufferization.alloc_tensor() : tensor<6x6xi32, #DCSR>
+  %1 = linalg.generic {
+         indexing_maps = [#map, #map1, #map2],
+         iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+         ins(%arg0, %arg1 : tensor<8x8xi32, #DCSR>, tensor<3x3xi32>)
+         outs(%0 : tensor<6x6xi32, #DCSR>) {
+    ^bb0(%in: i32, %in_0: i32, %out: i32):
+      %2 = arith.muli %in, %in_0 : i32
+      %3 = arith.addi %out, %2 : i32
+      linalg.yield %3 : i32
+    } -> tensor<6x6xi32, #DCSR>
+  return %1 : tensor<6x6xi32, #DCSR>
+}

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_2d_slice_based.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_2d_slice_based.mlir
new file mode 100644
index 0000000000000..0191300076737
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_2d_slice_based.mlir
@@ -0,0 +1,83 @@
+// DEFINE: %{option} = "enable-index-reduction=true enable-runtime-library=false"
+// DEFINE: %{command} = mlir-opt %s --sparse-compiler=%{option} | \
+// DEFINE: mlir-cpu-runner \
+// DEFINE:  -e entry -entry-point-result=void  \
+// DEFINE:  -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// DEFINE: FileCheck %s
+//
+// RUN: %{command}
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+
+#DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
+
+module {
+  func.func @conv2d_all_sparse_CSR(%arg0: tensor<8x8xi32, #DCSR>, %arg1: tensor<3x3xi32>) -> tensor<6x6xi32, #DCSR> {
+    %0 = bufferization.alloc_tensor() : tensor<6x6xi32, #DCSR>
+    %1 = linalg.generic {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    ins(%arg0, %arg1 : tensor<8x8xi32, #DCSR>, tensor<3x3xi32>)
+    outs(%0 : tensor<6x6xi32, #DCSR>) {
+    ^bb0(%in: i32, %in_0: i32, %out: i32):
+      %2 = arith.muli %in, %in_0 : i32
+      %3 = arith.addi %out, %2 : i32
+      linalg.yield %3 : i32
+    } -> tensor<6x6xi32, #DCSR>
+    return %1 : tensor<6x6xi32, #DCSR>
+  }
+
+  func.func @entry() {
+    %c0 = arith.constant 0 : index
+    %i0 = arith.constant 0 : i32
+
+    // A typical edge detection filter.
+    %filter = arith.constant dense<[
+      [  1,  0, -1 ],
+      [  0,  0,  0 ],
+      [ -1,  0,  1 ]
+    ]> : tensor<3x3xi32>
+
+    %input = arith.constant dense<[
+      [  1,  2,  3,  4,  0,  6,  7,  8 ],
+      [  2,  2,  4,  4,  0,  0,  6,  8 ],
+      [  2,  2,  4,  4,  0,  0,  6,  8 ],
+      [  2,  2,  3,  4,  0,  0,  7,  8 ],
+      [  1,  3,  3,  4,  0,  0,  6,  8 ],
+      [  3,  2,  3,  4,  0,  0,  7,  8 ],
+      [  1,  3,  3,  4,  3,  6,  6,  8 ],
+      [  1,  3,  3,  4,  3,  0,  7,  8 ]
+    ]> : tensor<8x8xi32>
+
+    %sparse_filter_CSR = sparse_tensor.convert %filter
+      : tensor<3x3xi32> to tensor<3x3xi32>
+
+    %sparse_input_CSR = sparse_tensor.convert %input
+      : tensor<8x8xi32> to tensor<8x8xi32, #DCSR>
+
+    %3 = call @conv2d_all_sparse_CSR(%sparse_input_CSR, %sparse_filter_CSR)
+         : (tensor<8x8xi32, #DCSR>,
+            tensor<3x3xi32>) -> tensor<6x6xi32, #DCSR>
+
+    %out = sparse_tensor.convert %3
+      : tensor<6x6xi32, #DCSR> to tensor<6x6xi32>
+    //
+    // CHECK:    ( ( 0, 0, -1, -6, -1, 6 ),
+    // CHECK-SAME: ( -1, 0, 1, 0, 1, 0 ),
+    // CHECK-SAME: ( 0, -1, 1, 0, 0, 0 ),
+    // CHECK-SAME: ( -1, 0, 0, 0, 0, 0 ),
+    // CHECK-SAME: ( 0, 0, 3, 6, -3, -6 ),
+    // CHECK-SAME: ( 2, -1, 3, 0, -3, 0 ) )
+    //
+    %v2 = vector.transfer_read %out[%c0, %c0], %i0
+      : tensor<6x6xi32>, vector<6x6xi32>
+    vector.print %v2 : vector<6x6xi32>
+
+    bufferization.dealloc_tensor %sparse_input_CSR : tensor<8x8xi32, #DCSR>
+    bufferization.dealloc_tensor %3 : tensor<6x6xi32, #DCSR>
+    return
+  }
+
+}

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d_slice_based.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d_slice_based.mlir
new file mode 100644
index 0000000000000..e7269c8e77b40
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d_slice_based.mlir
@@ -0,0 +1,98 @@
+// DEFINE: %{option} = "enable-index-reduction=true enable-runtime-library=false"
+// DEFINE: %{command} = mlir-opt %s --sparse-compiler=%{option} | \
+// DEFINE: mlir-cpu-runner \
+// DEFINE:  -e entry -entry-point-result=void  \
+// DEFINE:  -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// DEFINE: FileCheck %s
+//
+// RUN: %{command}
+
+#CCC = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed", "compressed", "compressed"]  // All 3 levels compressed.
+}>
+
+func.func @alloc_3d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %f : f32) -> tensor<?x?x?xf32> {
+  %empty = bufferization.alloc_tensor(%s1, %s2, %s3) : tensor<?x?x?xf32>  // Fresh s1 x s2 x s3 tensor.
+  %filled = linalg.fill ins(%f : f32) outs(%empty : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %filled : tensor<?x?x?xf32>  // Every element is %f.
+}
+
+func.func @conv_3d_CCC(%arg0: tensor<?x?x?xf32, #CCC>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32, #CCC> {
+  %n = arith.constant 6 : index  // Hard-coded output extent: 8 - 3 + 1 for the shapes @entry passes.
+  %init = bufferization.alloc_tensor(%n, %n, %n) : tensor<?x?x?xf32, #CCC>
+  %res = linalg.conv_3d
+    ins(%arg0, %arg1 : tensor<?x?x?xf32, #CCC>, tensor<?x?x?xf32>)
+    outs(%init : tensor<?x?x?xf32, #CCC>) -> tensor<?x?x?xf32, #CCC>
+  return %res : tensor<?x?x?xf32, #CCC>
+}
+
+func.func @entry() {
+  %c0 = arith.constant 0 : index
+  %c3 = arith.constant 3 : index
+  %c8 = arith.constant 8 : index
+  %f10 = arith.constant 10.00000e+00 : f32
+  %val = arith.constant 2.00000e+00 : f32
+  %zero = arith.constant 0.00000e+00 : f32
+
+  // A 3x3x3 filter and an 8x8x8 input, both filled with 2.0, except that
+  // input[0][3][0] is overwritten with 10.0 to make a few outputs stand out.
+  %filter3D = call @alloc_3d_filled_f32(%c3, %c3, %c3, %val) : (index, index, index, f32) -> (tensor<?x?x?xf32>)
+  %in3D_tmp = call @alloc_3d_filled_f32(%c8, %c8, %c8, %val) : (index, index, index, f32) -> (tensor<?x?x?xf32>)
+  %in3D = tensor.insert %f10 into %in3D_tmp[%c0, %c3, %c0] : tensor<?x?x?xf32>
+
+  %in3D_CCC = sparse_tensor.convert %in3D
+    : tensor<?x?x?xf32> to tensor<?x?x?xf32, #CCC>
+  %CCC_ret = call @conv_3d_CCC(%in3D_CCC, %filter3D) : (tensor<?x?x?xf32, #CCC>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32, #CCC>)
+  // Every output sums 27 products 2.0 * 2.0 = 108; only out[0][1..3][0] see
+  // input[0][3][0] = 10.0, which adds (10 - 2) * 2 = 16 for a total of 124.
+  // CHECK:     ( ( ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 124, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 124, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 124, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ) ),
+  // CHECK-SAME:  ( ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ) ),
+  // CHECK-SAME:  ( ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ) ),
+  // CHECK-SAME:  ( ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ) ),
+  // CHECK-SAME:  ( ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ) ),
+  // CHECK-SAME:  ( ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ) ) )
+  %1 = sparse_tensor.convert %CCC_ret
+    : tensor<?x?x?xf32, #CCC> to tensor<?x?x?xf32>
+  %v1 = vector.transfer_read %1[%c0, %c0, %c0], %zero
+      : tensor<?x?x?xf32>, vector<6x6x6xf32>
+  vector.print %v1 : vector<6x6x6xf32>
+
+  // Free the resources.
+  bufferization.dealloc_tensor %in3D : tensor<?x?x?xf32>
+  bufferization.dealloc_tensor %filter3D : tensor<?x?x?xf32>
+
+  bufferization.dealloc_tensor %in3D_CCC : tensor<?x?x?xf32, #CCC>
+  bufferization.dealloc_tensor %CCC_ret : tensor<?x?x?xf32, #CCC>
+
+  return
+}


        


More information about the Mlir-commits mailing list