[Mlir-commits] [mlir] [mlir][sparse] Support pretty print to debug sparse iteration. (PR #80207)

Peiming Liu llvmlistbot at llvm.org
Thu Feb 1 12:09:07 PST 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/80207

>From 372d11372712bf827a9f6877c9b74636e256fee5 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 30 Jan 2024 22:08:32 +0000
Subject: [PATCH 1/4] [mlir][sparse] Support pretty print to debug sparse
 iteration.

---
 .../Transforms/SparseTensorPasses.cpp         |   4 +-
 .../Transforms/Sparsification.cpp             |   2 +-
 .../Transforms/Utils/CodegenEnv.cpp           |   5 +-
 .../Transforms/Utils/CodegenEnv.h             |   2 +-
 .../Transforms/Utils/LoopEmitter.cpp          |  13 +-
 .../Transforms/Utils/LoopEmitter.h            |  20 +-
 .../Transforms/Utils/SparseTensorLevel.cpp    | 315 +++++++++++-------
 .../Transforms/Utils/SparseTensorLevel.h      | 172 ++++++----
 .../sparse_conv_2d_slice_based.mlir           | 276 +++------------
 9 files changed, 387 insertions(+), 422 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 375e10f9068e4..0ae9f6483588d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -82,13 +82,15 @@ struct SparsificationPass
   SparsificationPass(const SparsificationPass &pass) = default;
   SparsificationPass(const SparsificationOptions &options) {
     parallelization = options.parallelizationStrategy;
+    debugSparseIteration = options.debugSparseIteration;
     enableRuntimeLibrary = options.enableRuntimeLibrary;
   }
 
   void runOnOperation() override {
     auto *ctx = &getContext();
     // Translate strategy flags to strategy options.
-    SparsificationOptions options(parallelization, enableRuntimeLibrary);
+    SparsificationOptions options(parallelization, debugSparseIteration,
+                                  enableRuntimeLibrary);
     // Apply sparsification and cleanup rewriting.
     RewritePatternSet patterns(ctx);
     populateSparsificationPatterns(patterns, options);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 5266ca7213bfc..2ceb214052aa3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1369,7 +1369,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
       return failure();
 
     // Recursively generates code if admissible.
-    env.startEmit();
+    env.startEmit(options.debugSparseIteration);
     genBuffers(env, rewriter);
     // TODO: Constant affine expression should be handled differently when using
     // slice-based codegen, it does not matter now because we already reject the
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
index d3de55e4d59bd..0af1cc1745f51 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
@@ -59,7 +59,7 @@ LogicalResult CodegenEnv::initTensorExp() {
   return success();
 }
 
-void CodegenEnv::startEmit() {
+void CodegenEnv::startEmit(DebugSparseIteration emitStrategy) {
   assert(insChain == nullptr && "must only start emitting once");
   if (sparseOut) {
     insChain = sparseOut->get();
@@ -96,7 +96,8 @@ void CodegenEnv::startEmit() {
       /*dependentLvlGetter=*/
       [this](TensorId t, Level lvl) -> std::vector<LoopCoeffPair> {
         return merger().getDependentLoops(t, lvl);
-      });
+      },
+      emitStrategy);
 }
 
 std::optional<Operation *> CodegenEnv::genLoopBoundary(
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
index 728af841cc7b1..7eeddac48f4f1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
@@ -52,7 +52,7 @@ class CodegenEnv {
   Merger &merger() { return latticeMerger; }
   LoopEmitter &emitter() { return loopEmitter; }
 
-  void startEmit();
+  void startEmit(DebugSparseIteration emitStrategy);
 
   /// Generates loop boundary statements (entering/exiting loops). The function
   /// passes and updates the passed-in parameters.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 3fa4004ae460e..8c1680a393181 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -81,17 +81,21 @@ static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
 
 LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
                          bool isSparseOut, unsigned numLoops,
-                         DependentLvlGetter dimGetter) {
-  initialize(tensors, loopTag, hasOutput, isSparseOut, numLoops, dimGetter);
+                         DependentLvlGetter dimGetter,
+                         DebugSparseIteration emitStrategy) {
+  initialize(tensors, loopTag, hasOutput, isSparseOut, numLoops, dimGetter,
+             emitStrategy);
 }
 
 void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
                              bool isSparseOut, unsigned numLoops,
-                             DependentLvlGetter dimGetter) {
+                             DependentLvlGetter dimGetter,
+                             DebugSparseIteration emitStrategy) {
   // First initialize the top-level type of the fields.
   this->loopTag = loopTag;
   this->hasOutput = hasOutput;
   this->isSparseOut = isSparseOut;
+  SparseIterator::setDebugSparseIteration(emitStrategy);
 
   const unsigned numManifestTensors = ts.size();
   const unsigned synTensorId = numManifestTensors;
@@ -169,7 +173,7 @@ LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
     Value offset = genSliceOffset(builder, loc, tensors[t], l);
     Value stride = genSliceStride(builder, loc, tensors[t], l);
     auto slicedIt = makeSlicedLevelIterator(std::move(it), offset, stride,
-                                            lvls[t][l]->size());
+                                            lvls[t][l]->getSize());
     return slicedIt;
   }
   return it;
@@ -465,7 +469,7 @@ std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
 
   // Construct the while-loop with a parameter for each coordinate.
   for (SparseIterator *it : spIters) {
-    ValueRange itVals = it->getItVals();
+    ValueRange itVals = it->getCursor();
     ivs.append(itVals.begin(), itVals.end());
   }
 
@@ -724,7 +728,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
       // Forward the sparse iterator.
       Value cmp = CMPI(eq, it.getCrd(), iv);
       it.forwardIf(builder, loc, cmp);
-      operands.append(it.getItVals().begin(), it.getItVals().end());
+      operands.append(it.getCursor().begin(), it.getCursor().end());
       // const Value newPos = whileOp->getResult(o++);
       // Following loops continue iteration from the break point of the
       // current while loop.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index d0f447d926f71..e0b4f81487a68 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -15,6 +15,7 @@
 
 #include "mlir/Dialect/SparseTensor/IR/Enums.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
 #include "mlir/IR/PatternMatch.h"
 
@@ -84,14 +85,17 @@ class LoopEmitter {
   /// `isSparseOut` indicates that the sparse output tensor is empty,
   /// so the loop emitter will generate loops over it according to the
   /// level-sizes.
-  void initialize(ValueRange tensors, StringAttr loopTag = nullptr,
-                  bool hasOutput = false, bool isSparseOut = false,
-                  unsigned numLoops = 0, DependentLvlGetter getter = nullptr);
-
-  explicit LoopEmitter(ValueRange tensors, StringAttr loopTag = nullptr,
-                       bool hasOutput = false, bool isSparseOut = false,
-                       unsigned numLoops = 0,
-                       DependentLvlGetter getter = nullptr);
+  void
+  initialize(ValueRange tensors, StringAttr loopTag = nullptr,
+             bool hasOutput = false, bool isSparseOut = false,
+             unsigned numLoops = 0, DependentLvlGetter getter = nullptr,
+             DebugSparseIteration emitStrategy = DebugSparseIteration::kNone);
+
+  explicit LoopEmitter(
+      ValueRange tensors, StringAttr loopTag = nullptr, bool hasOutput = false,
+      bool isSparseOut = false, unsigned numLoops = 0,
+      DependentLvlGetter getter = nullptr,
+      DebugSparseIteration emitStrategy = DebugSparseIteration::kNone);
 
   /// Starts a loop emitting session by generating all the buffers needed
   /// for iterating over the tensors.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 98323c2195461..bdaf794744bea 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -46,20 +46,6 @@ using ValueTuple = std::tuple<Value, Value, Value>;
 
 namespace {
 
-class SparseLevel : public SparseTensorLevel {
-public:
-  SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
-              Value crdBuffer)
-      : SparseTensorLevel(tid, lvl, lt, lvlSize), crdBuffer(crdBuffer) {}
-
-  Value peekCrdAt(OpBuilder &b, Location l, Value iv) const override {
-    return genIndexLoad(b, l, crdBuffer, iv);
-  }
-
-protected:
-  const Value crdBuffer;
-};
-
 class DenseLevel : public SparseTensorLevel {
 public:
   DenseLevel(unsigned tid, Level lvl, Value lvlSize, bool encoded)
@@ -74,53 +60,27 @@ class DenseLevel : public SparseTensorLevel {
                         Value max) const override {
     assert(max == nullptr && "Dense level can not be non-unique.");
     if (encoded) {
-      Value posLo = MULI(p, lvlSize);
-      return {posLo, lvlSize};
+      Value posLo = MULI(p, getSize());
+      return {posLo, getSize()};
     }
     // No need to linearize the position for non-annotated tensors.
-    return {C_IDX(0), lvlSize};
+    return {C_IDX(0), getSize()};
   }
 
   const bool encoded;
 };
 
-class CompressedLevel : public SparseLevel {
+class SparseLevel : public SparseTensorLevel {
 public:
-  CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
-                  Value posBuffer, Value crdBuffer)
-      : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
-
-  ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
-                        Value max) const override {
-    if (max == nullptr) {
-      Value pLo = genIndexLoad(b, l, posBuffer, p);
-      Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
-      return {pLo, pHi};
-    }
-    llvm_unreachable("compressed-nu should be the first non-unique level.");
+  SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+              ValueRange lvlBuf)
+      : SparseTensorLevel(tid, lvl, lt, lvlSize, lvlBuf) {
+    assert(!lvlBuf.empty());
   }
 
-private:
-  const Value posBuffer;
-};
-
-class LooseCompressedLevel : public SparseLevel {
-public:
-  LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
-                       Value posBuffer, Value crdBuffer)
-      : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
-
-  ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
-                        Value max) const override {
-    assert(max == nullptr && "loss compressed level can not be non-unique.");
-    p = MULI(p, C_IDX(2));
-    Value pLo = genIndexLoad(b, l, posBuffer, p);
-    Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
-    return {pLo, pHi};
+  Value peekCrdAt(OpBuilder &b, Location l, Value iv) const override {
+    return genIndexLoad(b, l, getLvlBufs().front(), iv);
   }
-
-private:
-  const Value posBuffer;
 };
 
 class SingletonLevel : public SparseLevel {
@@ -142,8 +102,8 @@ class SingletonLevel : public SparseLevel {
 class TwoOutFourLevel : public SparseLevel {
 public:
   TwoOutFourLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
-                  Value crdBuffer)
-      : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
+                  Value crdBuf)
+      : SparseLevel(tid, lvl, lt, lvlSize, crdBuf) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
                         Value max) const override {
@@ -154,6 +114,39 @@ class TwoOutFourLevel : public SparseLevel {
   }
 };
 
+class CompressedLevel : public SparseLevel {
+public:
+  CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+                  Value posBuffer, Value crdBuffer)
+      : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer, posBuffer}) {}
+
+  ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
+                        Value max) const override {
+    if (max == nullptr) {
+      Value pLo = genIndexLoad(b, l, getPosBuf(), p);
+      Value pHi = genIndexLoad(b, l, getPosBuf(), ADDI(p, C_IDX(1)));
+      return {pLo, pHi};
+    }
+    llvm_unreachable("compressed-nu should be the first non-unique level.");
+  }
+};
+
+class LooseCompressedLevel : public SparseLevel {
+public:
+  LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+                       Value posBuffer, Value crdBuffer)
+      : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer, posBuffer}) {}
+
+  ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
+                        Value max) const override {
+    assert(max == nullptr && "loose compressed level can not be non-unique.");
+    p = MULI(p, C_IDX(2));
+    Value pLo = genIndexLoad(b, l, getPosBuf(), p);
+    Value pHi = genIndexLoad(b, l, getPosBuf(), ADDI(p, C_IDX(1)));
+    return {pLo, pHi};
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -203,7 +196,8 @@ static Value offsetFromMinCrd(OpBuilder &b, Location l, Value minCrd,
 // SparseIterator derived classes.
 //===----------------------------------------------------------------------===//
 
-namespace {
+namespace mlir {
+namespace sparse_tensor {
 
 // The iterator that traverses a concrete sparse tensor levels. High-level
 // abstract iterators wrap it to achieve more complex goals (such as collapsing
@@ -212,12 +206,11 @@ namespace {
 class ConcreteIterator : public SparseIterator {
 protected:
   ConcreteIterator(const SparseTensorLevel &stl, IterKind kind,
-                   unsigned itValCnt)
-      : SparseIterator(kind, stl.tid, stl.lvl, itValCnt, itValsStorage),
-        stl(stl) {
-    // Allocate enough storage for iterator values.
-    itValsStorage.resize(itValCnt);
-  }
+                   unsigned cursorValCnt)
+      : SparseIterator(kind, stl.tid, stl.lvl, cursorValCnt, cursorValsStorage),
+        stl(stl), cursorValsStorage(cursorValCnt, nullptr) {
+    assert(getCursor().size() == cursorValCnt);
+  }
 
 public:
   // For LLVM-style RTTI.
@@ -228,22 +221,34 @@ class ConcreteIterator : public SparseIterator {
   bool randomAccessible() const override { return isDenseLT(stl.getLT()); };
   bool iteratableByFor() const override { return kind != IterKind::kDedup; };
   Value upperBound(OpBuilder &b, Location l) const override {
-    return stl.size();
+    return stl.getSize();
   };
 
 protected:
+  const SparseTensorLevel &stl;
   // Owner of the storage, all wrappers build on top of a concrete iterator
   // share the same storage such that the iterator values are always
   // synchronized.
-  SmallVector<Value> itValsStorage;
-  const SparseTensorLevel &stl;
+  SmallVector<Value> cursorValsStorage;
 };
 
+} // namespace sparse_tensor
+} // namespace mlir
+
+namespace {
+
 class TrivialIterator : public ConcreteIterator {
 public:
   TrivialIterator(const SparseTensorLevel &stl)
-      : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {}
+      : ConcreteIterator(stl, IterKind::kTrivial, /*cursorValCnt=*/1) {}
 
+  std::string getDebugInterfacePrefix() const override {
+    return std::string("trivial<") + stl.toString() + ">";
+  }
+  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
+    return {b.getIndexType()};
+  }
+
   SmallVector<Value> serialize() const override {
     SmallVector<Value> ret;
     ret.push_back(getItPos());
@@ -286,12 +291,12 @@ class TrivialIterator : public ConcreteIterator {
     return std::make_pair(getItPos(), posHi);
   }
 
-  Value genNotEnd(OpBuilder &b, Location l) override {
+  Value genNotEndImpl(OpBuilder &b, Location l) override {
     // We used the first level bound as the bound the collapsed set of levels.
     return CMPI(ult, getItPos(), posHi);
   }
 
-  Value deref(OpBuilder &b, Location l) override {
+  Value derefImpl(OpBuilder &b, Location l) override {
     if (randomAccessible()) {
       updateCrd(SUBI(getItPos(), posLo));
     } else {
@@ -302,24 +307,24 @@ class TrivialIterator : public ConcreteIterator {
 
   ValueRange forwardImpl(OpBuilder &b, Location l) override {
     seek(ADDI(getItPos(), C_IDX(1)));
-    return getItVals();
+    return getCursor();
   }
 
   ValueRange forwardIf(OpBuilder &b, Location l, Value cond) override {
-    Value curPos = getItVals().front();
+    Value curPos = getCursor().front();
     Value nxPos = forward(b, l).front();
     seek(SELECT(cond, nxPos, curPos));
-    return getItVals();
+    return getCursor();
   }
 
-  void locate(OpBuilder &b, Location l, Value crd) override {
+  void locateImpl(OpBuilder &b, Location l, Value crd) override {
     assert(randomAccessible());
     // Seek to the linearized position.
     seek(ADDI(crd, posLo));
     updateCrd(crd);
   }
 
-  Value getItPos() const { return getItVals().front(); }
+  Value getItPos() const { return getCursor().front(); }
   Value posLo, posHi;
 };
 
@@ -337,6 +342,13 @@ class DedupIterator : public ConcreteIterator {
     return from->kind == IterKind::kDedup;
   }
 
+  std::string getDebugInterfacePrefix() const override {
+    return std::string("dedup<") + stl.toString() + ">";
+  }
+  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
+    return {b.getIndexType(), b.getIndexType()};
+  }
+
   ValuePair getCurPosition() const override { return {getPos(), getSegHi()}; }
 
   void genInitImpl(OpBuilder &b, Location l,
@@ -355,21 +367,21 @@ class DedupIterator : public ConcreteIterator {
 
   SmallVector<Value> serialize() const override {
     SmallVector<Value> ret;
-    ret.append(getItVals().begin(), getItVals().end());
+    ret.append(getCursor().begin(), getCursor().end());
     ret.push_back(posHi);
     return ret;
   };
   void deserialize(ValueRange vs) override {
     assert(vs.size() == 3);
-    seek(vs.take_front(getItVals().size()));
+    seek(vs.take_front(getCursor().size()));
     posHi = vs.back();
   };
 
-  Value genNotEnd(OpBuilder &b, Location l) override {
+  Value genNotEndImpl(OpBuilder &b, Location l) override {
     return CMPI(ult, getPos(), posHi);
   }
 
-  Value deref(OpBuilder &b, Location l) override {
+  Value derefImpl(OpBuilder &b, Location l) override {
     updateCrd(stl.peekCrdAt(b, l, getPos()));
     return getCrd();
   };
@@ -377,11 +389,11 @@ class DedupIterator : public ConcreteIterator {
   ValueRange forwardImpl(OpBuilder &b, Location l) override {
     Value nxPos = getSegHi(); // forward the position to the next segment.
     seek({nxPos, genSegmentHigh(b, l, nxPos)});
-    return getItVals();
+    return getCursor();
   }
 
-  Value getPos() const { return getItVals()[0]; }
-  Value getSegHi() const { return getItVals()[1]; }
+  Value getPos() const { return getCursor()[0]; }
+  Value getSegHi() const { return getCursor()[1]; }
 
   Value posHi;
 };
@@ -419,6 +431,13 @@ class FilterIterator : public SparseIterator {
     return from->kind == IterKind::kFilter;
   }
 
+  std::string getDebugInterfacePrefix() const override {
+    return std::string("filter<") + wrap->getDebugInterfacePrefix() + ">";
+  }
+  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
+    return wrap->getCursorValTypes(b);
+  }
+
   bool randomAccessible() const override { return wrap->randomAccessible(); };
   bool iteratableByFor() const override { return randomAccessible(); };
   Value upperBound(OpBuilder &b, Location l) const override { return size; };
@@ -441,14 +460,14 @@ class FilterIterator : public SparseIterator {
     }
   }
 
-  Value genNotEnd(OpBuilder &b, Location l) override;
+  Value genNotEndImpl(OpBuilder &b, Location l) override;
 
-  Value deref(OpBuilder &b, Location l) override {
+  Value derefImpl(OpBuilder &b, Location l) override {
     updateCrd(fromWrapCrd(b, l, wrap->deref(b, l)));
     return getCrd();
   }
 
-  void locate(OpBuilder &b, Location l, Value crd) override {
+  void locateImpl(OpBuilder &b, Location l, Value crd) override {
     assert(randomAccessible());
     wrap->locate(b, l, toWrapCrd(b, l, crd));
     updateCrd(crd);
@@ -469,8 +488,7 @@ class NonEmptySubSectIterator : public SparseIterator {
                           const SparseIterator *parent,
                           std::unique_ptr<SparseIterator> &&delegate,
                           Value subSectSz)
-      : SparseIterator(IterKind::kNonEmptySubSect, delegate->tid, delegate->lvl,
-                       3, /*itVals=*/subSectMeta),
+      : SparseIterator(IterKind::kNonEmptySubSect, 3, subSectMeta, *delegate),
         parent(parent), delegate(std::move(delegate)),
         tupleSz(this->delegate->serialize().size()), subSectSz(subSectSz) {
     auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
@@ -497,6 +515,14 @@ class NonEmptySubSectIterator : public SparseIterator {
     return from->kind == IterKind::kNonEmptySubSect;
   }
 
+  std::string getDebugInterfacePrefix() const override {
+    return std::string("ne_sub<") + delegate->getDebugInterfacePrefix() + ">";
+  }
+  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
+    // minCrd, absolute offset, notEnd
+    return {b.getIndexType(), b.getIndexType(), b.getI1Type()};
+  }
+
   // The sliced pointer buffer is organized as:
   //     [[itVal0, itVal1, ..., pNx0],
   //      [itVal0, itVal1, ..., pNx0],
@@ -519,8 +545,8 @@ class NonEmptySubSectIterator : public SparseIterator {
                                     ValueRange{tupleId, C_IDX(tupleSz)});
   }
 
-  void storeItVals(OpBuilder &b, Location l, Value tupleId,
-                   ValueRange itVals) const {
+  void storeCursorVals(OpBuilder &b, Location l, Value tupleId,
+                       ValueRange itVals) const {
     assert(itVals.size() == tupleSz);
     for (unsigned i = 0; i < tupleSz; i++) {
       b.create<memref::StoreOp>(l, itVals[i], subSectPosBuf,
@@ -528,7 +554,8 @@ class NonEmptySubSectIterator : public SparseIterator {
     }
   }
 
-  SmallVector<Value> loadItVals(OpBuilder &b, Location l, Value tupleId) const {
+  SmallVector<Value> loadCursorVals(OpBuilder &b, Location l,
+                                    Value tupleId) const {
     SmallVector<Value> ret;
     for (unsigned i = 0; i < tupleSz; i++) {
       Value v = b.create<memref::LoadOp>(l, subSectPosBuf,
@@ -560,7 +587,7 @@ class NonEmptySubSectIterator : public SparseIterator {
 
   void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override;
 
-  void locate(OpBuilder &b, Location l, Value crd) override {
+  void locateImpl(OpBuilder &b, Location l, Value crd) override {
     Value absOff = crd;
 
     if (isSubSectRoot())
@@ -576,9 +603,11 @@ class NonEmptySubSectIterator : public SparseIterator {
     return SUBI(wrapCrd, getAbsOff());
   }
 
-  Value genNotEnd(OpBuilder &b, Location l) override { return getNotEnd(); };
+  Value genNotEndImpl(OpBuilder &b, Location l) override {
+    return getNotEnd();
+  };
 
-  Value deref(OpBuilder &b, Location l) override {
+  Value derefImpl(OpBuilder &b, Location l) override {
     // Use the relative offset to coiterate.
     Value crd;
     auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
@@ -638,7 +667,7 @@ class SubSectIterator : public SparseIterator {
                   std::unique_ptr<SparseIterator> &&wrap, Value size,
                   unsigned stride)
       : SparseIterator(IterKind::kSubSect, *wrap,
-                       /*extraVal=*/wrap->randomAccessible() ? 0 : 1),
+                       /*extraCursorCnt=*/wrap->randomAccessible() ? 0 : 1),
         subSect(subSect), wrap(std::move(wrap)), parent(parent), size(size),
         stride(stride), helper(*this) {
     assert(stride == 1 && "Not implemented.");
@@ -651,6 +680,16 @@ class SubSectIterator : public SparseIterator {
     return from->kind == IterKind::kSubSect;
   }
 
+  std::string getDebugInterfacePrefix() const override {
+    return std::string("subsect<") + wrap->getDebugInterfacePrefix() + ">";
+  }
+  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
+    SmallVector<Type> ret = wrap->getCursorValTypes(b);
+    if (!randomAccessible())
+      ret.push_back(b.getIndexType()); // The extra counter.
+    return ret;
+  }
+
   bool randomAccessible() const override { return wrap->randomAccessible(); };
   bool iteratableByFor() const override { return randomAccessible(); };
   Value upperBound(OpBuilder &b, Location l) const override { return size; }
@@ -662,7 +701,7 @@ class SubSectIterator : public SparseIterator {
     if (randomAccessible()) {
       return ADDI(getCrd(), nxLvlTupleStart);
     };
-    return ADDI(getItVals().back(), nxLvlTupleStart);
+    return ADDI(getCursor().back(), nxLvlTupleStart);
   }
 
   void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override {
@@ -680,10 +719,10 @@ class SubSectIterator : public SparseIterator {
       return;
     }
     assert(!randomAccessible());
-    assert(getItVals().size() == wrap->getItVals().size() + 1);
+    assert(getCursor().size() == wrap->getCursor().size() + 1);
     // Extra counter that counts the number of actually visited coordinates in
     // the sparse subsection.
-    getMutItVals().back() = C_IDX(0);
+    getMutCursorVals().back() = C_IDX(0);
     Value tupleId;
     if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
       assert(p->lvl + 1 == lvl);
@@ -696,16 +735,16 @@ class SubSectIterator : public SparseIterator {
     helper.deserializeFromTupleId(b, l, tupleId);
   }
 
-  void locate(OpBuilder &b, Location l, Value crd) override {
+  void locateImpl(OpBuilder &b, Location l, Value crd) override {
     helper.locate(b, l, crd);
     updateCrd(crd);
   }
 
-  Value genNotEnd(OpBuilder &b, Location l) override {
+  Value genNotEndImpl(OpBuilder &b, Location l) override {
     return helper.genNotEnd(b, l);
   }
 
-  Value deref(OpBuilder &b, Location l) override {
+  Value derefImpl(OpBuilder &b, Location l) override {
     Value crd = helper.deref(b, l);
     updateCrd(crd);
     return crd;
@@ -714,9 +753,9 @@ class SubSectIterator : public SparseIterator {
   ValueRange forwardImpl(OpBuilder &b, Location l) override {
     helper.forward(b, l);
     assert(!randomAccessible());
-    assert(getItVals().size() == wrap->getItVals().size() + 1);
-    getMutItVals().back() = ADDI(getItVals().back(), C_IDX(1));
-    return getItVals();
+    assert(getCursor().size() == wrap->getCursor().size() + 1);
+    getMutCursorVals().back() = ADDI(getCursor().back(), C_IDX(1));
+    return getCursor();
   };
 
   Value nxLvlTupleStart;
@@ -737,30 +776,82 @@ class SubSectIterator : public SparseIterator {
 // SparseIterator derived classes implementation.
 //===----------------------------------------------------------------------===//
 
+DebugSparseIteration SparseIterator::emitStrategy = DebugSparseIteration::kNone;
+
 void SparseIterator::genInit(OpBuilder &b, Location l,
                              const SparseIterator *p) {
+  if (emitStrategy == DebugSparseIteration::kInterfaceOnly) {
+    std::string prefix = getDebugInterfacePrefix();
+    Operation *begin = b.create(l, b.getStringAttr(prefix + ".begin"), {},
+                                getCursorValTypes(b));
+    seek(begin->getResults());
+    return;
+  }
   // TODO: support lowering to function call.
   return genInitImpl(b, l, p);
 }
 
-ValueRange SparseIterator::forward(OpBuilder &b, Location l) {
+Value SparseIterator::genNotEnd(OpBuilder &b, Location l) {
+  if (emitStrategy == DebugSparseIteration::kInterfaceOnly) {
+    std::string prefix = getDebugInterfacePrefix();
+    Operation *notEnd = b.create(l, b.getStringAttr(prefix + ".not_end"),
+                                 getCursor(), b.getI1Type());
+    return notEnd->getResult(0);
+  }
   // TODO: support lowering to function call.
+  return genNotEndImpl(b, l);
+}
+
+void SparseIterator::locate(OpBuilder &b, Location l, Value crd) {
+  if (emitStrategy == DebugSparseIteration::kInterfaceOnly) {
+    std::string prefix = getDebugInterfacePrefix();
+    SmallVector<Value> args = getCursor();
+    args.push_back(crd);
+    Operation *locate = b.create(l, b.getStringAttr(prefix + ".locate"), args,
+                                 getCursorValTypes(b));
+    seek(locate->getResults());
+    updateCrd(crd);
+    return;
+  }
+  return locateImpl(b, l, crd);
+}
+
+Value SparseIterator::deref(OpBuilder &b, Location l) {
+  if (emitStrategy == DebugSparseIteration::kInterfaceOnly) {
+    std::string prefix = getDebugInterfacePrefix();
+    // Operands: the cursor values; result: the (index-typed) coordinate.
+    Operation *deref = b.create(l, b.getStringAttr(prefix + ".deref"),
+                                getCursor(), b.getIndexType());
+    updateCrd(deref->getResult(0));
+    return getCrd();
+  }
+  return derefImpl(b, l);
+}
+
+ValueRange SparseIterator::forward(OpBuilder &b, Location l) {
+  if (emitStrategy == DebugSparseIteration::kInterfaceOnly) {
+    std::string prefix = getDebugInterfacePrefix();
+    Operation *next = b.create(l, b.getStringAttr(prefix + ".next"),
+                               getCursor(), getCursorValTypes(b));
+    seek(next->getResults());
+    return getCursor();
+  }
   return forwardImpl(b, l);
 }
 
 ValueRange SparseIterator::forwardIf(OpBuilder &b, Location l, Value cond) {
-  auto ifOp = b.create<scf::IfOp>(l, getItVals().getTypes(), cond, true);
+  auto ifOp = b.create<scf::IfOp>(l, getCursor().getTypes(), cond, true);
   // Generate else branch first, otherwise iterator values will be updated by
   // `forward()`.
   b.setInsertionPointToStart(ifOp.elseBlock());
-  YIELD(getItVals());
+  YIELD(getCursor());
 
   b.setInsertionPointToStart(ifOp.thenBlock());
   YIELD(forward(b, l));
 
   b.setInsertionPointAfter(ifOp);
   seek(ifOp.getResults());
-  return getItVals();
+  return getCursor();
 }
 
 Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) {
@@ -817,7 +908,7 @@ Value FilterIterator::genShouldFilter(OpBuilder &b, Location l) {
   return r.front();
 }
 
-Value FilterIterator::genNotEnd(OpBuilder &b, Location l) {
+Value FilterIterator::genNotEndImpl(OpBuilder &b, Location l) {
   assert(!wrap->randomAccessible());
   auto r = genWhenInBound(
       b, l, *wrap, C_FALSE,
@@ -844,7 +935,7 @@ ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) {
   // forward a subsection).
   Value isFirst = constantI1(b, l, true);
 
-  SmallVector<Value> whileArgs(getItVals().begin(), getItVals().end());
+  SmallVector<Value> whileArgs(getCursor().begin(), getCursor().end());
   whileArgs.push_back(isFirst);
   auto whileOp = b.create<scf::WhileOp>(
       l, ValueRange(whileArgs).getTypes(), whileArgs,
@@ -870,14 +961,14 @@ ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) {
       [this](OpBuilder &b, Location l, ValueRange ivs) {
         linkNewScope(ivs);
         wrap->forward(b, l);
-        SmallVector<Value> yieldVals(getItVals().begin(), getItVals().end());
+        SmallVector<Value> yieldVals(getCursor().begin(), getCursor().end());
         yieldVals.push_back(constantI1(b, l, false));
         YIELD(yieldVals);
       });
 
   b.setInsertionPointAfter(whileOp);
   linkNewScope(whileOp.getResults());
-  return getItVals();
+  return getCursor();
 }
 
 SubSectIterHelper::SubSectIterHelper(const NonEmptySubSectIterator &subSect)
@@ -889,7 +980,7 @@ SubSectIterHelper::SubSectIterHelper(const SubSectIterator &iter)
 void SubSectIterHelper::deserializeFromTupleId(OpBuilder &b, Location l,
                                                Value tupleId) {
   assert(!subSect.randomAccessible());
-  wrap.deserialize(subSect.loadItVals(b, l, tupleId));
+  wrap.deserialize(subSect.loadCursorVals(b, l, tupleId));
 }
 
 void SubSectIterHelper::locate(OpBuilder &b, Location l, Value crd) {
@@ -943,7 +1034,7 @@ ValueRange NonEmptySubSectIterator::inflateSubSectTree(
           // is corresponding to the current node.
           helper.subSect.storeNxLvlStart(b, l, tupleId, cnt);
 
-          SmallVector<Value> whileArgs(helper.wrap.getItVals());
+          SmallVector<Value> whileArgs(helper.wrap.getCursor());
           whileArgs.append(iterArgs.begin(), iterArgs.end());
 
           auto whileOp = b.create<scf::WhileOp>(
@@ -1039,7 +1130,7 @@ void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
                        .front();
 
           // Cache the sparse range.
-          storeItVals(b, l, tupleId, helper.wrap.serialize());
+          storeCursorVals(b, l, tupleId, helper.wrap.serialize());
           tupleId = ADDI(tupleId, C_IDX(1));
           return {minCrd, tupleId};
         });
@@ -1068,7 +1159,7 @@ void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
   // Only have one root node.
   tupleCnt = C_IDX(1);
   // Cache the sparse range.
-  storeItVals(b, l, c0, delegate->serialize());
+  storeCursorVals(b, l, c0, delegate->serialize());
   SmallVector<Value> elseRet{c0, c0, /*notEnd=*/C_FALSE};
   auto meta = genWhenInBound(
       b, l, *delegate, elseRet,
@@ -1095,7 +1186,7 @@ ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
   // if (offset + size > parents.size)
   //   isNonEmpty = false;
   Value fastPathP = CMPI(ugt, getMinCrd(), getAbsOff());
-  auto ifOp = b.create<scf::IfOp>(l, getItVals().getTypes(), fastPathP, true);
+  auto ifOp = b.create<scf::IfOp>(l, getCursor().getTypes(), fastPathP, true);
   {
     OpBuilder::InsertionGuard guard(b);
     // Take the fast path
@@ -1134,7 +1225,7 @@ ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
                 // Update the forwarded iterator values if needed.
                 auto ifIsMin = b.create<scf::IfOp>(l, isMin, false);
                 b.setInsertionPointToStart(&ifIsMin.getThenRegion().front());
-                storeItVals(b, l, tupleId, delegate->serialize());
+                storeCursorVals(b, l, tupleId, delegate->serialize());
                 b.setInsertionPointAfter(ifIsMin);
                 // if (!wrap.end())
                 //  yield(min(nxMinCrd, *wrap), true)
@@ -1172,7 +1263,7 @@ ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
   nxNotEnd = ANDI(nxNotEnd, CMPI(ult, crd, upperBound(b, l)));
 
   seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
-  return getItVals();
+  return getCursor();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
index bf115712bdfc1..2faa2a8de5651 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
@@ -10,13 +10,16 @@
 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORLEVEL_H_
 
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 
 namespace mlir {
 namespace sparse_tensor {
 
-/// The base class for all types of sparse tensor levels. It provides interfaces
-/// to query the loop range (see `peekRangeAt`) and look up the coordinates (see
-/// `peekCrdAt`).
+class ConcreteIterator;
+
+/// The base class for all types of sparse tensor levels. It provides
+/// interfaces to query the loop range (see `peekRangeAt`) and look up the
+/// coordinates (see `peekCrdAt`).
 class SparseTensorLevel {
   SparseTensorLevel(SparseTensorLevel &&) = delete;
   SparseTensorLevel(const SparseTensorLevel &) = delete;
@@ -26,6 +29,10 @@ class SparseTensorLevel {
 public:
   virtual ~SparseTensorLevel() = default;
 
+  std::string toString() const {
+    return std::string(toMLIRString(lt)) + "[" + std::to_string(tid) + "," +
+           std::to_string(lvl) + "]";
+  }
   virtual Value peekCrdAt(OpBuilder &b, Location l, Value iv) const = 0;
 
   /// Peeks the lower and upper bound to *fully* traverse the level with
@@ -46,7 +53,17 @@ class SparseTensorLevel {
 
   Level getLevel() const { return lvl; }
   LevelType getLT() const { return lt; }
-  Value size() const { return lvlSize; }
+  Value getSize() const { return lvlVals.front(); }
+  Value getCrdBuf() const {
+    assert(lvlVals.size() > 1);
+    return lvlVals[1];
+  }
+  Value getPosBuf() const {
+    assert(lvlVals.size() > 2);
+    return lvlVals[2];
+  }
+  ValueRange getLvlVals() const { return lvlVals; }
+  ValueRange getLvlBufs() const { return ValueRange(lvlVals).drop_front(); }
 
   //
   // Level properties
@@ -55,12 +72,24 @@ class SparseTensorLevel {
 
 protected:
   SparseTensorLevel(unsigned tid, unsigned lvl, LevelType lt, Value lvlSize)
-      : tid(tid), lvl(lvl), lt(lt), lvlSize(lvlSize){};
+      : tid(tid), lvl(lvl), lt(lt), lvlVals() {
+    lvlVals.push_back(lvlSize);
+  };
+
+  SparseTensorLevel(unsigned tid, unsigned lvl, LevelType lt, Value lvlSize,
+                    ValueRange lvlBufs)
+      : tid(tid), lvl(lvl), lt(lt), lvlVals() {
+    lvlVals.push_back(lvlSize);
+    lvlVals.append(lvlBufs.begin(), lvlBufs.end());
+  };
 
 public:
   const unsigned tid, lvl;
   const LevelType lt;
-  const Value lvlSize;
+  // The first value in the vector is always lvlSize; for sparse levels, the
+  // second value is always the coordinate buffer; for sparse levels with
+  // position buffers, the third value is always the position buffer.
+  SmallVector<Value, 3> lvlVals;
 };
 
 enum class IterKind : uint8_t {
@@ -80,37 +109,47 @@ class SparseIterator {
   SparseIterator &operator=(const SparseIterator &) = delete;
 
 protected:
-  SparseIterator(IterKind kind, unsigned tid, unsigned lvl, unsigned itValsCnt,
-                 SmallVectorImpl<Value> &storage)
-      : kind(kind), tid(tid), lvl(lvl), crd(nullptr), itValsCnt(itValsCnt),
-        itValsStorageRef(storage){};
-
-  SparseIterator(IterKind kind, const SparseIterator &wrap)
-      : kind(kind), tid(wrap.tid), lvl(wrap.lvl), crd(nullptr),
-        itValsCnt(wrap.itValsCnt), itValsStorageRef(wrap.itValsStorageRef) {
-    assert(wrap.itValsCnt == itValsStorageRef.size());
-  };
-
-  SparseIterator(IterKind kind, const SparseIterator &wrap, unsigned extraVal)
-      : kind(kind), tid(wrap.tid), lvl(wrap.lvl), crd(nullptr),
-        itValsCnt(wrap.itValsCnt + extraVal),
-        itValsStorageRef(wrap.itValsStorageRef) {
-    itValsStorageRef.append(extraVal, nullptr);
-    assert(itValsCnt == itValsStorageRef.size());
+  SparseIterator(IterKind kind, unsigned tid, unsigned lvl,
+                 unsigned cursorValsCnt,
+                 SmallVectorImpl<Value> &cursorValStorage)
+      : kind(kind), tid(tid), lvl(lvl), crd(nullptr),
+        cursorValsCnt(cursorValsCnt), cursorValsStorageRef(cursorValStorage){};
+
+  SparseIterator(IterKind kind, unsigned cursorValsCnt,
+                 SmallVectorImpl<Value> &cursorValStorage,
+                 const SparseIterator &delegate)
+      : SparseIterator(kind, delegate.tid, delegate.lvl, cursorValsCnt,
+                       cursorValStorage){};
+
+  SparseIterator(IterKind kind, const SparseIterator &wrap,
+                 unsigned extraCursorCnt = 0)
+      : SparseIterator(kind, wrap.tid, wrap.lvl,
+                       extraCursorCnt + wrap.cursorValsCnt,
+                       wrap.cursorValsStorageRef) {
+    assert(wrap.cursorValsCnt == wrap.cursorValsStorageRef.size());
+    cursorValsStorageRef.append(extraCursorCnt, nullptr);
+    assert(cursorValsStorageRef.size() == wrap.cursorValsCnt + extraCursorCnt);
   };
 
 public:
   virtual ~SparseIterator() = default;
 
+  static void setDebugSparseIteration(DebugSparseIteration strategy) {
+    SparseIterator::emitStrategy = strategy;
+  }
+
+  virtual std::string getDebugInterfacePrefix() const = 0;
+  virtual SmallVector<Type> getCursorValTypes(OpBuilder &b) const = 0;
+
   Value getCrd() const { return crd; }
-  ValueRange getItVals() const {
-    return ValueRange(itValsStorageRef).take_front(itValsCnt);
+  ValueRange getCursor() const {
+    return ValueRange(cursorValsStorageRef).take_front(cursorValsCnt);
   };
 
   // Sets the iterate to the specified position.
   void seek(ValueRange vals) {
-    assert(vals.size() == itValsCnt);
-    std::copy(vals.begin(), vals.end(), itValsStorageRef.begin());
+    assert(vals.size() == cursorValsCnt);
+    std::copy(vals.begin(), vals.end(), cursorValsStorageRef.begin());
     // Now that the iterator is re-positioned, the coordinate becomes invalid.
     crd = nullptr;
   }
@@ -120,20 +159,21 @@ class SparseIterator {
   //
 
   // Whether the iterator support random access (i.e., support look up by
-  // *coordinate*). A random access iterator must also traverses a dense space.
+  // *coordinate*). A random access iterator must also traverse a dense
+  // space.
   virtual bool randomAccessible() const = 0;
 
   // Whether the iterator can simply traversed by a for loop.
   virtual bool iteratableByFor() const { return false; };
 
-  // Get the upper bound of the sparse space that the iterator might visited. A
-  // sparse space is a subset of a dense space [0, bound), this function returns
-  // *bound*.
+  // Gets the upper bound of the sparse space that the iterator might visit.
+  // A sparse space is a subset of a dense space [0, bound); this function
+  // returns *bound*.
   virtual Value upperBound(OpBuilder &b, Location l) const = 0;
 
-  // Serializes and deserializes the current status to/from a set of values. The
-  // ValueRange should contain values that are sufficient to recover the current
-  // iterating postion (i.e., itVals) as well as loop bound.
+  // Serializes and deserializes the current status to/from a set of values.
+  // The ValueRange should contain values that are sufficient to recover the
+  // current iteration position (i.e., the cursor) as well as the loop bound.
   //
   // Not every type of iterator supports the operations, e.g., non-empty
   // subsection iterator does not because the the number of non-empty
@@ -155,23 +195,31 @@ class SparseIterator {
   // Forwards the iterator to the next element.
   ValueRange forward(OpBuilder &b, Location l);
 
-  // Actual Implementation provided by derived class.
-  virtual void genInitImpl(OpBuilder &, Location, const SparseIterator *) = 0;
-  virtual ValueRange forwardImpl(OpBuilder &b, Location l) = 0;
+  // Locates the iterator at *crd*; only valid for random-access iterators.
+  void locate(OpBuilder &b, Location l, Value crd);
 
   // Returns a boolean value that equals `!it.end()`
-  virtual Value genNotEnd(OpBuilder &b, Location l) = 0;
+  Value genNotEnd(OpBuilder &b, Location l);
 
   // Dereferences the iterator, loads the coordinate at the current position.
   //
   // The method assumes that the iterator is not currently exhausted (i.e.,
   // it != it.end()).
-  virtual Value deref(OpBuilder &b, Location l) = 0;
+  Value deref(OpBuilder &b, Location l);
 
-  // Gets the current position and the optional *position high* (for non-unique
-  // iterators), the value is essentially the number of sparse coordinate that
-  // the iterator is current visiting. It should be able to uniquely identify
-  // the sparse range for the next level. See SparseTensorLevel::peekRangeAt();
+  // Actual Implementation provided by derived class.
+  virtual void genInitImpl(OpBuilder &, Location, const SparseIterator *) = 0;
+  virtual ValueRange forwardImpl(OpBuilder &b, Location l) = 0;
+  virtual void locateImpl(OpBuilder &b, Location l, Value crd) {
+    llvm_unreachable("Unsupported");
+  }
+  virtual Value genNotEndImpl(OpBuilder &b, Location l) = 0;
+  virtual Value derefImpl(OpBuilder &b, Location l) = 0;
+  // Gets the current position and the optional *position high* (for
+  // non-unique iterators); the value is essentially the number of sparse
+  // coordinates that the iterator is currently visiting. It should be able to
+  // uniquely identify the sparse range for the next level. See
+  // SparseTensorLevel::peekRangeAt();
   //
   // Not every type of iterator supports the operation, e.g., non-empty
   // subsection iterator does not because it represent a range of coordinates
@@ -202,33 +250,29 @@ class SparseIterator {
   //    yield it
   //
   // The function is virtual to allow alternative implementation. For example,
-  // if it.next() is trivial to compute, we can use a select operation instead.
-  // E.g.,
+  // if it.next() is trivial to compute, we can use a select operation
+  // instead. E.g.,
   //
   //  it = select cond ? it+1 : it
   virtual ValueRange forwardIf(OpBuilder &b, Location l, Value cond);
 
-  // Locate the iterator to the position specified by *crd*, this can only
-  // be done on an iterator that supports randm access.
-  virtual void locate(OpBuilder &b, Location l, Value crd) {
-    llvm_unreachable("Unsupported");
-  }
-
   // Update the SSA value for the iterator after entering a new scope.
   ValueRange linkNewScope(ValueRange pos) {
     assert(!randomAccessible() && "random accessible iterators are traversed "
                                   "by coordinate, call locate() instead.");
-    seek(pos.take_front(itValsCnt));
-    return pos.drop_front(itValsCnt);
+    seek(pos.take_front(cursorValsCnt));
+    return pos.drop_front(cursorValsCnt);
   };
 
 protected:
   void updateCrd(Value crd) { this->crd = crd; }
-  MutableArrayRef<Value> getMutItVals() {
-    MutableArrayRef<Value> ref = itValsStorageRef;
-    return ref.take_front(itValsCnt);
+  MutableArrayRef<Value> getMutCursorVals() {
+    MutableArrayRef<Value> ref = cursorValsStorageRef;
+    return ref.take_front(cursorValsCnt);
   }
 
+  static DebugSparseIteration emitStrategy;
+
 public:
   const IterKind kind;     // For LLVM-style RTTI.
   const unsigned tid, lvl; // tensor level identifier.
@@ -239,14 +283,14 @@ class SparseIterator {
   // A range of value that together defines the current state of the
   // iterator. Only loop variants should be included.
   //
-  // For trivial iterators, it is the position; for dedup iterators, it consists
-  // of the positon and the segment high, for non-empty subsection iterator, it
-  // is the metadata that specifies the subsection.
-  // Note that the wrapped iterator shares the same storage to maintain itVals
-  // with it wrapper, which means the wrapped iterator might only own a subset
-  // of all the values stored in itValStorage.
-  const unsigned itValsCnt;
-  SmallVectorImpl<Value> &itValsStorageRef;
+  // For trivial iterators, it is the position; for dedup iterators, it
+  // consists of the position and the segment high; for non-empty subsection
+  // iterators, it is the metadata that specifies the subsection. Note that
+  // the wrapped iterator shares the same storage for its cursor values with
+  // its wrapper, which means the wrapped iterator might only own a subset of
+  // the values stored in cursorValsStorageRef.
+  const unsigned cursorValsCnt;
+  SmallVectorImpl<Value> &cursorValsStorageRef;
 };
 
 /// Helper function to create a TensorLevel object from given `tensor`.
diff --git a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
index 70cf0f9af45b5..18118cab9e52c 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
@@ -1,5 +1,4 @@
-// TODO: re-enable after lowering coo.next to function call (such that loop structure is more clear).
-// RUN: mlir-opt %s --sparse-reinterpret-map --sparsification --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map --sparsification="debug-sparse-iteration=interface-only" --canonicalize --cse --allow-unregistered-dialect | FileCheck %s
 
 #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
@@ -8,233 +7,54 @@
 #DCSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>
 
 
+
 // CHECK-LABEL:   func.func @conv2d_all_sparse_CSR(
-// C_HECK-SAME:      %[[VAL_0:.*]]: tensor<8x8xi32, #sparse>,
-// C_HECK-SAME:      %[[VAL_1:.*]]: tensor<3x3xi32>) -> tensor<6x6xi32, #sparse> {
-// C_HECK-DAG:       %[[VAL_2:.*]] = arith.constant true
-// C_HECK-DAG:       %[[VAL_3:.*]] = arith.constant -2 : index
-// C_HECK-DAG:       %[[VAL_4:.*]] = arith.constant 2 : index
-// C_HECK-DAG:       %[[VAL_5:.*]] = arith.constant 8 : index
-// C_HECK-DAG:       %[[VAL_6:.*]] = arith.constant 3 : index
-// C_HECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
-// C_HECK-DAG:       %[[VAL_8:.*]] = arith.constant 0 : index
-// C_HECK-DAG:       %[[VAL_9:.*]] = arith.constant 0 : i32
-// C_HECK-DAG:       %[[VAL_10:.*]] = arith.constant false
-// C_HECK-DAG:       %[[VAL_11:.*]] = tensor.empty() : tensor<6x6xi32, #sparse>
-// C_HECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
-// C_HECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
-// C_HECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
-// C_HECK-DAG:       %[[VAL_15:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
-// C_HECK-DAG:       %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xi32, #sparse> to memref<?xi32>
-// C_HECK-DAG:       %[[VAL_17:.*]] = memref.alloca() : memref<9xindex>
-// C_HECK-DAG:       %[[VAL_18:.*]] = memref.alloca() : memref<3xindex>
-// C_HECK-DAG:       %[[POS_LO:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_8]]] : memref<?xindex>
-// C_HECK-DAG:       %[[POS_HI:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_7]]] : memref<?xindex>
-// C_HECK:           memref.store %[[POS_LO]], %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex>
-// C_HECK:           memref.store %[[POS_HI]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex>
-// C_HECK:           %[[VAL_20:.*]] = arith.cmpi ult, %[[POS_LO]], %[[POS_HI]] : index
-// C_HECK:           %[[VAL_21:.*]] = memref.load %[[VAL_13]]{{\[}}%[[POS_LO]]] : memref<?xindex>
-// C_HECK:           %[[VAL_22:.*]] = arith.cmpi uge, %[[VAL_21]], %[[VAL_6]] : index
-// C_HECK:           %[[VAL_23:.*]] = arith.andi %[[VAL_20]], %[[VAL_22]] : i1
-// C_HECK:           %[[VAL_24:.*]] = arith.addi %[[VAL_21]], %[[VAL_3]] : index
-// C_HECK:           %[[VAL_25:.*]] = arith.select %[[VAL_23]], %[[VAL_24]], %[[VAL_8]] : index
-// C_HECK:           %[[VAL_26:.*]]:3 = scf.while (%[[VAL_27:.*]] = %[[VAL_20]], %[[VAL_28:.*]] = %[[VAL_21]], %[[VAL_29:.*]] = %[[VAL_25]], %[[VAL_30:.*]] = %[[VAL_11]]) : (i1, index, index, tensor<6x6xi32, #sparse>) -> (index, index, tensor<6x6xi32, #sparse>) {
-// C_HECK:             scf.condition(%[[VAL_27]]) %[[VAL_28]], %[[VAL_29]], %[[VAL_30]] : index, index, tensor<6x6xi32, #sparse>
-// C_HECK:           } do {
-// C_HECK:           ^bb0(%[[VAL_31:.*]]: index, %[[VAL_32:.*]]: index, %[[VAL_33:.*]]: tensor<6x6xi32, #sparse>):
-// C_HECK:             %[[VAL_34:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex>
-// C_HECK:             %[[VAL_35:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex>
-// C_HECK:             memref.store %[[VAL_8]], %[[VAL_18]]{{\[}}%[[VAL_4]]] : memref<3xindex>
-// C_HECK:             %[[VAL_36:.*]] = arith.addi %[[VAL_32]], %[[VAL_6]] : index
-// C_HECK:             %[[VAL_37:.*]]:5 = scf.while (%[[VAL_38:.*]] = %[[VAL_34]], %[[VAL_39:.*]] = %[[VAL_10]], %[[VAL_40:.*]] = %[[VAL_5]], %[[VAL_41:.*]] = %[[VAL_8]], %[[VAL_42:.*]] = %[[VAL_8]]) : (index, i1, index, index, index) -> (index, i1, index, index, index) {
-// C_HECK:               %[[VAL_43:.*]] = arith.cmpi ult, %[[VAL_38]], %[[VAL_35]] : index
-// C_HECK:               %[[VAL_44:.*]] = scf.if %[[VAL_43]] -> (i1) {
-// C_HECK:                 %[[VAL_45:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_38]]] : memref<?xindex>
-// C_HECK:                 %[[VAL_46:.*]] = arith.cmpi ult, %[[VAL_45]], %[[VAL_36]] : index
-// C_HECK:                 scf.yield %[[VAL_46]] : i1
-// C_HECK:               } else {
-// C_HECK:                 scf.yield %[[VAL_10]] : i1
-// C_HECK:               }
-// C_HECK:               scf.condition(%[[VAL_44]]) %[[VAL_38]], %[[VAL_39]], %[[VAL_40]], %[[VAL_41]], %[[VAL_42]] : index, i1, index, index, index
-// C_HECK:             } do {
-// C_HECK:             ^bb0(%[[VAL_47:.*]]: index, %[[VAL_48:.*]]: i1, %[[VAL_49:.*]]: index, %[[VAL_50:.*]]: index, %[[VAL_51:.*]]: index):
-// C_HECK-DAG:           %[[VAL_52:.*]] = arith.addi %[[VAL_47]], %[[VAL_7]] : index
-// C_HECK-DAG:           %[[VAL_53:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_47]]] : memref<?xindex>
-// C_HECK:               %[[VAL_54:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_52]]] : memref<?xindex>
-// C_HECK:               %[[VAL_55:.*]] = arith.cmpi ult, %[[VAL_53]], %[[VAL_54]] : index
-// C_HECK:               %[[VAL_56:.*]] = arith.ori %[[VAL_55]], %[[VAL_48]] : i1
-// C_HECK:               %[[VAL_57:.*]] = scf.if %[[VAL_55]] -> (index) {
-// C_HECK:                 %[[VAL_58:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_53]]] : memref<?xindex>
-// C_HECK:                 %[[VAL_59:.*]] = arith.cmpi ult, %[[VAL_58]], %[[VAL_49]] : index
-// C_HECK:                 %[[VAL_60:.*]] = arith.select %[[VAL_59]], %[[VAL_58]], %[[VAL_49]] : index
-// C_HECK:                 scf.yield %[[VAL_60]] : index
-// C_HECK:               } else {
-// C_HECK:                 scf.yield %[[VAL_49]] : index
-// C_HECK:               }
-// C_HECK:               memref.store %[[VAL_53]], %[[VAL_17]]{{\[}}%[[VAL_50]]] : memref<9xindex>
-// C_HECK:               %[[VAL_61:.*]] = arith.addi %[[VAL_50]], %[[VAL_6]] : index
-// C_HECK:               memref.store %[[VAL_54]], %[[VAL_17]]{{\[}}%[[VAL_61]]] : memref<9xindex>
-// C_HECK:               %[[VAL_62:.*]] = arith.addi %[[VAL_50]], %[[VAL_7]] : index
-// C_HECK:               %[[VAL_63:.*]] = arith.addi %[[VAL_51]], %[[VAL_7]] : index
-// C_HECK:               scf.yield %[[VAL_52]], %[[VAL_56]], %[[VAL_57]], %[[VAL_62]], %[[VAL_63]] : index, i1, index, index, index
-// C_HECK:             }
-// C_HECK:             %[[VAL_64:.*]] = arith.cmpi uge, %[[VAL_65:.*]]#2, %[[VAL_6]] : index
-// C_HECK:             %[[VAL_66:.*]] = arith.andi %[[VAL_65]]#1, %[[VAL_64]] : i1
-// C_HECK:             %[[VAL_67:.*]] = arith.addi %[[VAL_65]]#2, %[[VAL_3]] : index
-// C_HECK:             %[[VAL_68:.*]] = arith.select %[[VAL_66]], %[[VAL_67]], %[[VAL_8]] : index
-// C_HECK:             %[[VAL_69:.*]]:3 = scf.while (%[[VAL_70:.*]] = %[[VAL_65]]#1, %[[VAL_71:.*]] = %[[VAL_65]]#2, %[[VAL_72:.*]] = %[[VAL_68]], %[[VAL_73:.*]] = %[[VAL_33]]) : (i1, index, index, tensor<6x6xi32, #sparse>) -> (index, index, tensor<6x6xi32, #sparse>) {
-// C_HECK:               scf.condition(%[[VAL_70]]) %[[VAL_71]], %[[VAL_72]], %[[VAL_73]] : index, index, tensor<6x6xi32, #sparse>
-// C_HECK:             } do {
-// C_HECK:             ^bb0(%[[VAL_74:.*]]: index, %[[VAL_75:.*]]: index, %[[VAL_76:.*]]: tensor<6x6xi32, #sparse>):
-// C_HECK:               %[[VAL_77:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex>
-// C_HECK:               %[[VAL_78:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex>
-// C_HECK:               %[[VAL_79:.*]]:3 = scf.while (%[[VAL_80:.*]] = %[[VAL_77]], %[[VAL_81:.*]] = %[[VAL_9]], %[[VAL_82:.*]] = %[[VAL_10]]) : (index, i32, i1) -> (index, i32, i1) {
-// C_HECK:                 %[[VAL_83:.*]] = arith.cmpi ult, %[[VAL_80]], %[[VAL_78]] : index
-// C_HECK:                 %[[VAL_84:.*]] = scf.if %[[VAL_83]] -> (i1) {
-// C_HECK:                   %[[VAL_85:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_80]]] : memref<?xindex>
-// C_HECK:                   %[[VAL_86:.*]] = arith.cmpi ult, %[[VAL_85]], %[[VAL_36]] : index
-// C_HECK:                   scf.yield %[[VAL_86]] : i1
-// C_HECK:                 } else {
-// C_HECK:                   scf.yield %[[VAL_10]] : i1
-// C_HECK:                 }
-// C_HECK:                 scf.condition(%[[VAL_84]]) %[[VAL_80]], %[[VAL_81]], %[[VAL_82]] : index, i32, i1
-// C_HECK:               } do {
-// C_HECK:               ^bb0(%[[VAL_87:.*]]: index, %[[VAL_88:.*]]: i32, %[[VAL_89:.*]]: i1):
-// C_HECK:                 %[[VAL_90:.*]] = arith.subi %[[VAL_87]], %[[VAL_77]] : index
-// C_HECK:                 %[[VAL_91:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_87]]] : memref<?xindex>
-// C_HECK:                 %[[VAL_92:.*]] = arith.subi %[[VAL_91]], %[[VAL_32]] : index
-// C_HECK:                 %[[VAL_93:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_90]]] : memref<9xindex>
-// C_HECK:                 %[[VAL_94:.*]] = arith.addi %[[VAL_90]], %[[VAL_6]] : index
-// C_HECK:                 %[[VAL_95:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_94]]] : memref<9xindex>
-// C_HECK:                 %[[VAL_96:.*]] = arith.addi %[[VAL_75]], %[[VAL_6]] : index
-// C_HECK:                 %[[VAL_97:.*]]:2 = scf.while (%[[VAL_98:.*]] = %[[VAL_93]], %[[VAL_99:.*]] = %[[VAL_88]]) : (index, i32) -> (index, i32) {
-// C_HECK:                   %[[VAL_100:.*]] = arith.cmpi ult, %[[VAL_98]], %[[VAL_95]] : index
-// C_HECK:                   %[[VAL_101:.*]] = scf.if %[[VAL_100]] -> (i1) {
-// C_HECK:                     %[[VAL_102:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_98]]] : memref<?xindex>
-// C_HECK:                     %[[VAL_103:.*]] = arith.cmpi ult, %[[VAL_102]], %[[VAL_96]] : index
-// C_HECK:                     scf.yield %[[VAL_103]] : i1
-// C_HECK:                   } else {
-// C_HECK:                     scf.yield %[[VAL_10]] : i1
-// C_HECK:                   }
-// C_HECK:                   scf.condition(%[[VAL_101]]) %[[VAL_98]], %[[VAL_99]] : index, i32
-// C_HECK:                 } do {
-// C_HECK:                 ^bb0(%[[VAL_104:.*]]: index, %[[VAL_105:.*]]: i32):
-// C_HECK:                   %[[VAL_106:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_104]]] : memref<?xindex>
-// C_HECK:                   %[[VAL_107:.*]] = arith.subi %[[VAL_106]], %[[VAL_75]] : index
-// C_HECK:                   %[[VAL_108:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_104]]] : memref<?xi32>
-// C_HECK:                   %[[VAL_109:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_92]], %[[VAL_107]]] : tensor<3x3xi32>
-// C_HECK:                   %[[VAL_110:.*]] = arith.muli %[[VAL_108]], %[[VAL_109]] : i32
-// C_HECK:                   %[[VAL_111:.*]] = arith.addi %[[VAL_105]], %[[VAL_110]] : i32
-// C_HECK:                   %[[VAL_112:.*]] = arith.addi %[[VAL_104]], %[[VAL_7]] : index
-// C_HECK:                   scf.yield %[[VAL_112]], %[[VAL_111]] : index, i32
-// C_HECK:                 }
-// C_HECK:                 %[[VAL_113:.*]] = arith.addi %[[VAL_87]], %[[VAL_7]] : index
-// C_HECK:                 scf.yield %[[VAL_113]], %[[VAL_114:.*]]#1, %[[VAL_2]] : index, i32, i1
-// C_HECK:               }
-// C_HECK:               %[[VAL_115:.*]] = scf.if %[[VAL_116:.*]]#2 -> (tensor<6x6xi32, #sparse>) {
-// C_HECK:                 %[[VAL_117:.*]] = sparse_tensor.insert %[[VAL_116]]#1 into %[[VAL_76]]{{\[}}%[[VAL_32]], %[[VAL_75]]] : tensor<6x6xi32, #sparse>
-// C_HECK:                 scf.yield %[[VAL_117]] : tensor<6x6xi32, #sparse>
-// C_HECK:               } else {
-// C_HECK:                 scf.yield %[[VAL_76]] : tensor<6x6xi32, #sparse>
-// C_HECK:               }
-// C_HECK:               %[[VAL_118:.*]] = arith.cmpi ugt, %[[VAL_74]], %[[VAL_75]] : index
-// C_HECK:               %[[VAL_119:.*]]:3 = scf.if %[[VAL_118]] -> (index, i1, index) {
-// C_HECK:                 %[[VAL_120:.*]] = arith.addi %[[VAL_75]], %[[VAL_7]] : index
-// C_HECK:                 scf.yield %[[VAL_74]], %[[VAL_2]], %[[VAL_120]] : index, i1, index
-// C_HECK:               } else {
-// C_HECK:                 %[[VAL_121:.*]]:2 = scf.for %[[VAL_122:.*]] = %[[VAL_8]] to %[[VAL_65]]#3 step %[[VAL_7]] iter_args(%[[VAL_123:.*]] = %[[VAL_5]], %[[VAL_124:.*]] = %[[VAL_10]]) -> (index, i1) {
-// C_HECK:                   %[[VAL_125:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_122]]] : memref<9xindex>
-// C_HECK:                   %[[VAL_126:.*]] = arith.addi %[[VAL_122]], %[[VAL_6]] : index
-// C_HECK:                   %[[VAL_127:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_126]]] : memref<9xindex>
-// C_HECK:                   %[[VAL_128:.*]] = arith.cmpi ult, %[[VAL_125]], %[[VAL_127]] : index
-// C_HECK:                   %[[VAL_129:.*]] = scf.if %[[VAL_128]] -> (index) {
-// C_HECK:                     %[[VAL_130:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_125]]] : memref<?xindex>
-// C_HECK:                     %[[VAL_131:.*]] = arith.cmpi eq, %[[VAL_130]], %[[VAL_74]] : index
-// C_HECK:                     %[[VAL_132:.*]] = scf.if %[[VAL_131]] -> (index) {
-// C_HECK:                       %[[VAL_133:.*]] = arith.addi %[[VAL_125]], %[[VAL_7]] : index
-// C_HECK:                       memref.store %[[VAL_133]], %[[VAL_17]]{{\[}}%[[VAL_122]]] : memref<9xindex>
-// C_HECK:                       scf.yield %[[VAL_133]] : index
-// C_HECK:                     } else {
-// C_HECK:                       scf.yield %[[VAL_125]] : index
-// C_HECK:                     }
-// C_HECK:                     scf.yield %[[VAL_132]] : index
-// C_HECK:                   } else {
-// C_HECK:                     scf.yield %[[VAL_125]] : index
-// C_HECK:                   }
-// C_HECK:                   %[[VAL_134:.*]] = arith.cmpi ult, %[[VAL_129]], %[[VAL_127]] : index
-// C_HECK:                   %[[VAL_135:.*]] = scf.if %[[VAL_134]] -> (index) {
-// C_HECK:                     %[[VAL_136:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_129]]] : memref<?xindex>
-// C_HECK:                     scf.yield %[[VAL_136]] : index
-// C_HECK:                   } else {
-// C_HECK:                     scf.yield %[[VAL_123]] : index
-// C_HECK:                   }
-// C_HECK:                   %[[VAL_137:.*]] = arith.ori %[[VAL_134]], %[[VAL_124]] : i1
-// C_HECK:                   %[[VAL_138:.*]] = arith.cmpi ult, %[[VAL_135]], %[[VAL_123]] : index
-// C_HECK:                   %[[VAL_139:.*]] = arith.select %[[VAL_138]], %[[VAL_135]], %[[VAL_123]] : index
-// C_HECK:                   scf.yield %[[VAL_139]], %[[VAL_137]] : index, i1
-// C_HECK:                 }
-// C_HECK:                 %[[VAL_140:.*]] = arith.addi %[[VAL_141:.*]]#0, %[[VAL_7]] : index
-// C_HECK:                 %[[VAL_142:.*]] = arith.addi %[[VAL_141]]#0, %[[VAL_3]] : index
-// C_HECK:                 %[[VAL_143:.*]] = arith.cmpi uge, %[[VAL_140]], %[[VAL_6]] : index
-// C_HECK:                 %[[VAL_144:.*]] = arith.select %[[VAL_143]], %[[VAL_142]], %[[VAL_8]] : index
-// C_HECK:                 scf.yield %[[VAL_141]]#0, %[[VAL_141]]#1, %[[VAL_144]] : index, i1, index
-// C_HECK:               }
-// C_HECK:               %[[VAL_145:.*]] = arith.addi %[[VAL_75]], %[[VAL_7]] : index
-// C_HECK:               %[[VAL_146:.*]] = arith.cmpi ugt, %[[VAL_147:.*]]#2, %[[VAL_145]] : index
-// C_HECK:               %[[VAL_148:.*]] = arith.select %[[VAL_146]], %[[VAL_147]]#2, %[[VAL_145]] : index
-// C_HECK:               %[[VAL_149:.*]] = arith.addi %[[VAL_148]], %[[VAL_6]] : index
-// C_HECK:               %[[VAL_150:.*]] = arith.cmpi ule, %[[VAL_149]], %[[VAL_5]] : index
-// C_HECK:               %[[VAL_151:.*]] = arith.andi %[[VAL_147]]#1, %[[VAL_150]] : i1
-// C_HECK:               scf.yield %[[VAL_151]], %[[VAL_147]]#0, %[[VAL_148]], %[[VAL_115]] : i1, index, index, tensor<6x6xi32, #sparse>
-// C_HECK:             }
-// C_HECK:             %[[VAL_152:.*]] = arith.cmpi ugt, %[[VAL_31]], %[[VAL_32]] : index
-// C_HECK:             %[[VAL_153:.*]]:3 = scf.if %[[VAL_152]] -> (index, i1, index) {
-// C_HECK:               %[[VAL_154:.*]] = arith.addi %[[VAL_32]], %[[VAL_7]] : index
-// C_HECK:               scf.yield %[[VAL_31]], %[[VAL_2]], %[[VAL_154]] : index, i1, index
-// C_HECK:             } else {
-// C_HECK:               %[[VAL_155:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex>
-// C_HECK:               %[[VAL_156:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex>
-// C_HECK:               %[[VAL_157:.*]] = arith.cmpi ult, %[[VAL_155]], %[[VAL_156]] : index
-// C_HECK:               %[[VAL_158:.*]] = scf.if %[[VAL_157]] -> (index) {
-// C_HECK:                 %[[VAL_159:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_155]]] : memref<?xindex>
-// C_HECK:                 %[[VAL_160:.*]] = arith.cmpi eq, %[[VAL_159]], %[[VAL_31]] : index
-// C_HECK:                 %[[VAL_161:.*]] = scf.if %[[VAL_160]] -> (index) {
-// C_HECK:                   %[[VAL_162:.*]] = arith.addi %[[VAL_155]], %[[VAL_7]] : index
-// C_HECK:                   memref.store %[[VAL_162]], %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex>
-// C_HECK:                   scf.yield %[[VAL_162]] : index
-// C_HECK:                 } else {
-// C_HECK:                   scf.yield %[[VAL_155]] : index
-// C_HECK:                 }
-// C_HECK:                 scf.yield %[[VAL_161]] : index
-// C_HECK:               } else {
-// C_HECK:                 scf.yield %[[VAL_155]] : index
-// C_HECK:               }
-// C_HECK:               %[[VAL_163:.*]] = arith.cmpi ult, %[[VAL_158]], %[[VAL_156]] : index
-// C_HECK:               %[[VAL_164:.*]] = scf.if %[[VAL_163]] -> (index) {
-// C_HECK:                 %[[VAL_165:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_158]]] : memref<?xindex>
-// C_HECK:                 scf.yield %[[VAL_165]] : index
-// C_HECK:               } else {
-// C_HECK:                 scf.yield %[[VAL_5]] : index
-// C_HECK:               }
-// C_HECK:               %[[VAL_166:.*]] = arith.cmpi ult, %[[VAL_164]], %[[VAL_5]] : index
-// C_HECK:               %[[VAL_167:.*]] = arith.select %[[VAL_166]], %[[VAL_164]], %[[VAL_5]] : index
-// C_HECK:               %[[VAL_168:.*]] = arith.addi %[[VAL_167]], %[[VAL_7]] : index
-// C_HECK:               %[[VAL_169:.*]] = arith.addi %[[VAL_167]], %[[VAL_3]] : index
-// C_HECK:               %[[VAL_170:.*]] = arith.cmpi uge, %[[VAL_168]], %[[VAL_6]] : index
-// C_HECK:               %[[VAL_171:.*]] = arith.select %[[VAL_170]], %[[VAL_169]], %[[VAL_8]] : index
-// C_HECK:               scf.yield %[[VAL_167]], %[[VAL_163]], %[[VAL_171]] : index, i1, index
-// C_HECK:             }
-// C_HECK:             %[[VAL_172:.*]] = arith.addi %[[VAL_32]], %[[VAL_7]] : index
-// C_HECK:             %[[VAL_173:.*]] = arith.cmpi ugt, %[[VAL_174:.*]]#2, %[[VAL_172]] : index
-// C_HECK:             %[[VAL_175:.*]] = arith.select %[[VAL_173]], %[[VAL_174]]#2, %[[VAL_172]] : index
-// C_HECK:             %[[VAL_176:.*]] = arith.addi %[[VAL_175]], %[[VAL_6]] : index
-// C_HECK:             %[[VAL_177:.*]] = arith.cmpi ule, %[[VAL_176]], %[[VAL_5]] : index
-// C_HECK:             %[[VAL_178:.*]] = arith.andi %[[VAL_174]]#1, %[[VAL_177]] : i1
-// C_HECK:             scf.yield %[[VAL_178]], %[[VAL_174]]#0, %[[VAL_175]], %[[VAL_179:.*]]#2 : i1, index, index, tensor<6x6xi32, #sparse>
-// C_HECK:           }
-// C_HECK:           %[[VAL_180:.*]] = sparse_tensor.load %[[VAL_181:.*]]#2 hasInserts : tensor<6x6xi32, #sparse>
-// C_HECK:           return %[[VAL_180]] : tensor<6x6xi32, #sparse>
-// C_HECK:         }
+// CHECK:           "ne_sub<trivial<compressed[0,0]>>.begin"
+// CHECK:           scf.while {{.*}} {
+// CHECK:             "ne_sub<trivial<compressed[0,0]>>.not_end"
+// CHECK:           } do {
+// CHECK:             %[[D0:.*]] = "ne_sub<trivial<compressed[0,0]>>.deref"
+// CHECK:             "ne_sub<trivial<compressed[0,1]>>.begin"
+// CHECK:             scf.while {{.*}} {
+// CHECK:               "ne_sub<trivial<compressed[0,1]>>.not_end"
+// CHECK:             } do {
+// CHECK:               %[[D1:.*]] = "ne_sub<trivial<compressed[0,1]>>.deref"
+// CHECK:               "subsect<trivial<compressed[0,0]>>.begin"
+// CHECK:               scf.while {{.*}} {
+// CHECK:                 "subsect<trivial<compressed[0,0]>>.not_end
+// CHECK:               } do {
+// CHECK:                 %[[D2:.*]] = "subsect<trivial<compressed[0,0]>>.deref"
+// CHECK:                 "trivial<dense[1,0]>.locate"(%{{.*}}, %[[D2]])
+// CHECK:                 "subsect<trivial<compressed[0,1]>>.begin"
+// CHECK:                 scf.while {{.*}} {
+// CHECK:                   "subsect<trivial<compressed[0,1]>>.not_end"
+// CHECK:                 } do {
+// CHECK:                   %[[D3:.*]] = "subsect<trivial<compressed[0,1]>>.deref"
+// CHECK:                   "trivial<dense[1,1]>.locate"(%{{.*}}, %[[D3]])
+// CHECK:                   tensor.extract %{{.*}}{{\[}}%[[D2]], %[[D3]]]
+// CHECK:                   arith.muli
+// CHECK:                   arith.addi
+// CHECK:                   "subsect<trivial<compressed[0,1]>>.next
+// CHECK:                   scf.yield
+// CHECK:                 }
+// CHECK:                 "subsect<trivial<compressed[0,0]>>.next
+// CHECK:                 scf.yield
+// CHECK:               }
+// CHECK:               scf.if {{.*}} {
+// CHECK:                 sparse_tensor.insert %{{.*}} into %{{.*}}{{\[}}%[[D0]], %[[D1]]]
+// CHECK:                 scf.yield
+// CHECK:               } else {
+// CHECK:                 scf.yield
+// CHECK:               }
+// CHECK:               "ne_sub<trivial<compressed[0,1]>>.next"
+// CHECK:               scf.yield
+// CHECK:             }
+// CHECK:             "ne_sub<trivial<compressed[0,0]>>.next"
+// CHECK:             scf.yield
+// CHECK:           }
+// CHECK:           sparse_tensor.load
+// CHECK:           return
+// CHECK:         }
 func.func @conv2d_all_sparse_CSR(%arg0: tensor<8x8xi32, #DCSR>,
                                  %arg1: tensor<3x3xi32>) -> tensor<6x6xi32, #DCSR> {
   %0 = tensor.empty() : tensor<6x6xi32, #DCSR>

>From b5744e8d32609cc50049f719edd974f13d959e73 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 31 Jan 2024 22:15:32 +0000
Subject: [PATCH 2/4] revert unintended change

---
 .../Transforms/Utils/SparseTensorLevel.cpp    | 105 +++++++++---------
 .../Transforms/Utils/SparseTensorLevel.h      |  37 ++----
 2 files changed, 59 insertions(+), 83 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index bdaf794744bea..604136c3884f9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -45,6 +45,19 @@ using ValueTuple = std::tuple<Value, Value, Value>;
 //===----------------------------------------------------------------------===//
 
 namespace {
+class SparseLevel : public SparseTensorLevel {
+public:
+  SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+              Value crdBuffer)
+      : SparseTensorLevel(tid, lvl, lt, lvlSize), crdBuffer(crdBuffer) {}
+
+  Value peekCrdAt(OpBuilder &b, Location l, Value iv) const override {
+    return genIndexLoad(b, l, crdBuffer, iv);
+  }
+
+protected:
+  const Value crdBuffer;
+};
 
 class DenseLevel : public SparseTensorLevel {
 public:
@@ -60,27 +73,53 @@ class DenseLevel : public SparseTensorLevel {
                         Value max) const override {
     assert(max == nullptr && "Dense level can not be non-unique.");
     if (encoded) {
-      Value posLo = MULI(p, getSize());
-      return {posLo, getSize()};
+      Value posLo = MULI(p, lvlSize);
+      return {posLo, lvlSize};
     }
     // No need to linearize the position for non-annotated tensors.
-    return {C_IDX(0), getSize()};
+    return {C_IDX(0), lvlSize};
   }
 
   const bool encoded;
 };
 
-class SparseLevel : public SparseTensorLevel {
+class CompressedLevel : public SparseLevel {
 public:
-  SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
-              ValueRange lvlBuf)
-      : SparseTensorLevel(tid, lvl, lt, lvlSize, lvlBuf) {
-    assert(!lvlBuf.empty());
+  CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+                  Value posBuffer, Value crdBuffer)
+      : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+
+  ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
+                        Value max) const override {
+    if (max == nullptr) {
+      Value pLo = genIndexLoad(b, l, posBuffer, p);
+      Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
+      return {pLo, pHi};
+    }
+    llvm_unreachable("compressed-nu should be the first non-unique level.");
   }
 
-  Value peekCrdAt(OpBuilder &b, Location l, Value iv) const override {
-    return genIndexLoad(b, l, getLvlBufs().front(), iv);
+private:
+  const Value posBuffer;
+};
+
+class LooseCompressedLevel : public SparseLevel {
+public:
+  LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+                       Value posBuffer, Value crdBuffer)
+      : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+
+  ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
+                        Value max) const override {
+    assert(max == nullptr && "loss compressed level can not be non-unique.");
+    p = MULI(p, C_IDX(2));
+    Value pLo = genIndexLoad(b, l, posBuffer, p);
+    Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
+    return {pLo, pHi};
   }
+
+private:
+  const Value posBuffer;
 };
 
 class SingletonLevel : public SparseLevel {
@@ -102,8 +141,8 @@ class SingletonLevel : public SparseLevel {
 class TwoOutFourLevel : public SparseLevel {
 public:
   TwoOutFourLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
-                  Value crdBuf)
-      : SparseLevel(tid, lvl, lt, lvlSize, crdBuf) {}
+                  Value crdBuffer)
+      : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
                         Value max) const override {
@@ -114,39 +153,6 @@ class TwoOutFourLevel : public SparseLevel {
   }
 };
 
-class CompressedLevel : public SparseLevel {
-public:
-  CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
-                  Value posBuffer, Value crdBuffer)
-      : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer, posBuffer}) {}
-
-  ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
-                        Value max) const override {
-    if (max == nullptr) {
-      Value pLo = genIndexLoad(b, l, getPosBuf(), p);
-      Value pHi = genIndexLoad(b, l, getPosBuf(), ADDI(p, C_IDX(1)));
-      return {pLo, pHi};
-    }
-    llvm_unreachable("compressed-nu should be the first non-unique level.");
-  }
-};
-
-class LooseCompressedLevel : public SparseLevel {
-public:
-  LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
-                       Value posBuffer, Value crdBuffer)
-      : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer, posBuffer}) {}
-
-  ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
-                        Value max) const override {
-    assert(max == nullptr && "loss compressed level can not be non-unique.");
-    p = MULI(p, C_IDX(2));
-    Value pLo = genIndexLoad(b, l, getPosBuf(), p);
-    Value pHi = genIndexLoad(b, l, getPosBuf(), ADDI(p, C_IDX(1)));
-    return {pLo, pHi};
-  }
-};
-
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -195,9 +201,7 @@ static Value offsetFromMinCrd(OpBuilder &b, Location l, Value minCrd,
 //===----------------------------------------------------------------------===//
 // SparseIterator derived classes.
 //===----------------------------------------------------------------------===//
-
-namespace mlir {
-namespace sparse_tensor {
+namespace {
 
 // The iterator that traverses a concrete sparse tensor levels. High-level
 // abstract iterators wrap it to achieve more complex goals (such as collapsing
@@ -232,11 +236,6 @@ class ConcreteIterator : public SparseIterator {
   SmallVector<Value> cursorValsStorage;
 };
 
-} // namespace sparse_tensor
-} // namespace mlir
-
-namespace {
-
 class TrivialIterator : public ConcreteIterator {
 public:
   TrivialIterator(const SparseTensorLevel &stl)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
index 2faa2a8de5651..eb75df9feaae9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
@@ -15,11 +15,9 @@
 namespace mlir {
 namespace sparse_tensor {
 
-class ConcreteIterator;
-
-/// The base class for all types of sparse tensor levels. It provides
-/// interfaces to query the loop range (see `peekRangeAt`) and look up the
-/// coordinates (see `peekCrdAt`).
+/// The base class for all types of sparse tensor levels. It provides interfaces
+/// to query the loop range (see `peekRangeAt`) and look up the coordinates (see
+/// `peekCrdAt`).
 class SparseTensorLevel {
   SparseTensorLevel(SparseTensorLevel &&) = delete;
   SparseTensorLevel(const SparseTensorLevel &) = delete;
@@ -33,6 +31,7 @@ class SparseTensorLevel {
     return std::string(toMLIRString(lt)) + "[" + std::to_string(tid) + "," +
            std::to_string(lvl) + "]";
   }
+
   virtual Value peekCrdAt(OpBuilder &b, Location l, Value iv) const = 0;
 
   /// Peeks the lower and upper bound to *fully* traverse the level with
@@ -53,17 +52,7 @@ class SparseTensorLevel {
 
   Level getLevel() const { return lvl; }
   LevelType getLT() const { return lt; }
-  Value getSize() const { return lvlVals.front(); }
-  Value getCrdBuf() const {
-    assert(lvlVals.size() > 1);
-    return lvlVals[1];
-  }
-  Value getPosBuf() const {
-    assert(lvlVals.size() > 2);
-    return lvlVals[2];
-  }
-  ValueRange getLvlVals() const { return lvlVals; }
-  ValueRange getLvlBufs() const { return ValueRange(lvlVals).drop_front(); }
+  Value getSize() const { return lvlSize; }
 
   //
   // Level properties
@@ -72,24 +61,12 @@ class SparseTensorLevel {
 
 protected:
   SparseTensorLevel(unsigned tid, unsigned lvl, LevelType lt, Value lvlSize)
-      : tid(tid), lvl(lvl), lt(lt), lvlVals() {
-    lvlVals.push_back(lvlSize);
-  };
-
-  SparseTensorLevel(unsigned tid, unsigned lvl, LevelType lt, Value lvlSize,
-                    ValueRange lvlBufs)
-      : tid(tid), lvl(lvl), lt(lt), lvlVals() {
-    lvlVals.push_back(lvlSize);
-    lvlVals.append(lvlBufs.begin(), lvlBufs.end());
-  };
+      : tid(tid), lvl(lvl), lt(lt), lvlSize(lvlSize){};
 
 public:
   const unsigned tid, lvl;
   const LevelType lt;
-  // The first value in the vector is always lvlsize; for sparse levels, the
-  // second value is always the coordinate buffer; for sparse level with
-  // position buffers, the third value is always the position buffer.
-  SmallVector<Value, 3> lvlVals;
+  const Value lvlSize;
 };
 
 enum class IterKind : uint8_t {

>From afd24ca0d52ca80a00570eabdd6b4022bac3f5ef Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 31 Jan 2024 22:21:33 +0000
Subject: [PATCH 3/4] revert unintended change

---
 .../Transforms/Utils/SparseTensorLevel.cpp    |  2 ++
 .../Transforms/Utils/SparseTensorLevel.h      | 31 +++++++++----------
 2 files changed, 17 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 604136c3884f9..97d51dbec4a5e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -45,6 +45,7 @@ using ValueTuple = std::tuple<Value, Value, Value>;
 //===----------------------------------------------------------------------===//
 
 namespace {
+
 class SparseLevel : public SparseTensorLevel {
 public:
   SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
@@ -201,6 +202,7 @@ static Value offsetFromMinCrd(OpBuilder &b, Location l, Value minCrd,
 //===----------------------------------------------------------------------===//
 // SparseIterator derived classes.
 //===----------------------------------------------------------------------===//
+
 namespace {
 
 // The iterator that traverses a concrete sparse tensor levels. High-level
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
index eb75df9feaae9..728e2973e83c3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
@@ -136,21 +136,20 @@ class SparseIterator {
   //
 
   // Whether the iterator support random access (i.e., support look up by
-  // *coordinate*). A random access iterator must also traverses a dense
-  // space.
+  // *coordinate*). A random access iterator must also traverse a dense space.
   virtual bool randomAccessible() const = 0;
 
   // Whether the iterator can simply traversed by a for loop.
   virtual bool iteratableByFor() const { return false; };
 
-  // Get the upper bound of the sparse space that the iterator might visited.
-  // A sparse space is a subset of a dense space [0, bound), this function
-  // returns *bound*.
+  // Get the upper bound of the sparse space that the iterator might visit. A
+  // sparse space is a subset of a dense space [0, bound); this function
+  // returns *bound*.
   virtual Value upperBound(OpBuilder &b, Location l) const = 0;
 
-  // Serializes and deserializes the current status to/from a set of values.
-  // The ValueRange should contain values that are sufficient to recover the
-  // current iterating postion (i.e., itVals) as well as loop bound.
+  // Serializes and deserializes the current status to/from a set of values. The
+  // ValueRange should contain values that are sufficient to recover the current
+  // iterating position (i.e., the cursor values) as well as the loop bound.
   //
   // Not every type of iterator supports the operations, e.g., non-empty
   // subsection iterator does not because the the number of non-empty
@@ -227,8 +226,8 @@ class SparseIterator {
   //    yield it
   //
   // The function is virtual to allow alternative implementation. For example,
-  // if it.next() is trivial to compute, we can use a select operation
-  // instead. E.g.,
+  // if it.next() is trivial to compute, we can use a select operation instead.
+  // E.g.,
   //
   //  it = select cond ? it+1 : it
   virtual ValueRange forwardIf(OpBuilder &b, Location l, Value cond);
@@ -260,12 +259,12 @@ class SparseIterator {
   // A range of value that together defines the current state of the
   // iterator. Only loop variants should be included.
   //
-  // For trivial iterators, it is the position; for dedup iterators, it
-  // consists of the positon and the segment high, for non-empty subsection
-  // iterator, it is the metadata that specifies the subsection. Note that the
-  // wrapped iterator shares the same storage to maintain itVals with it
-  // wrapper, which means the wrapped iterator might only own a subset of all
-  // the values stored in itValStorage.
+  // For trivial iterators, it is the position; for dedup iterators, it consists
+  // of the position and the segment high; for non-empty subsection iterators,
+  // it is the metadata that specifies the subsection.
+  // Note that the wrapped iterator shares its cursor-value storage with its
+  // wrapper, which means the wrapped iterator might only own a subset of all
+  // the values stored in cursorValsStorageRef.
   const unsigned cursorValsCnt;
   SmallVectorImpl<Value> &cursorValsStorageRef;
 };

>From f0c9c744b7c2000b181ebf8da6b7a619bd28d6ae Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 1 Feb 2024 20:08:49 +0000
Subject: [PATCH 4/4] rebase

---
 .../SparseTensor/Transforms/SparseTensorPasses.cpp  |  4 ++--
 .../SparseTensor/Transforms/Sparsification.cpp      |  2 +-
 .../SparseTensor/Transforms/Utils/CodegenEnv.cpp    |  2 +-
 .../SparseTensor/Transforms/Utils/CodegenEnv.h      |  2 +-
 .../SparseTensor/Transforms/Utils/LoopEmitter.cpp   |  6 +++---
 .../SparseTensor/Transforms/Utils/LoopEmitter.h     |  4 ++--
 .../Transforms/Utils/SparseTensorLevel.cpp          | 13 +++++++------
 .../Transforms/Utils/SparseTensorLevel.h            |  4 ++--
 .../SparseTensor/sparse_conv_2d_slice_based.mlir    |  2 +-
 9 files changed, 20 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 0ae9f6483588d..8b89bd4dcdd03 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -82,14 +82,14 @@ struct SparsificationPass
   SparsificationPass(const SparsificationPass &pass) = default;
   SparsificationPass(const SparsificationOptions &options) {
     parallelization = options.parallelizationStrategy;
-    debugSparseIteration = options.debugSparseIteration;
+    sparseEmitStrategy = options.sparseEmitStrategy;
     enableRuntimeLibrary = options.enableRuntimeLibrary;
   }
 
   void runOnOperation() override {
     auto *ctx = &getContext();
     // Translate strategy flags to strategy options.
-    SparsificationOptions options(parallelization, debugSparseIteration,
+    SparsificationOptions options(parallelization, sparseEmitStrategy,
                                   enableRuntimeLibrary);
     // Apply sparsification and cleanup rewriting.
     RewritePatternSet patterns(ctx);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 2ceb214052aa3..ab38ab5cc3f78 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1369,7 +1369,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
       return failure();
 
     // Recursively generates code if admissible.
-    env.startEmit(options.debugSparseIteration);
+    env.startEmit(options.sparseEmitStrategy);
     genBuffers(env, rewriter);
     // TODO: Constant affine expression should be handled differently when using
     // slice-based codegen, it does not matter now because we already reject the
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
index 0af1cc1745f51..86c13d03c7ec6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
@@ -59,7 +59,7 @@ LogicalResult CodegenEnv::initTensorExp() {
   return success();
 }
 
-void CodegenEnv::startEmit(DebugSparseIteration emitStrategy) {
+void CodegenEnv::startEmit(SparseEmitStrategy emitStrategy) {
   assert(insChain == nullptr && "must only start emitting once");
   if (sparseOut) {
     insChain = sparseOut->get();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
index 7eeddac48f4f1..d69ae53fb0f29 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
@@ -52,7 +52,7 @@ class CodegenEnv {
   Merger &merger() { return latticeMerger; }
   LoopEmitter &emitter() { return loopEmitter; }
 
-  void startEmit(DebugSparseIteration emitStrategy);
+  void startEmit(SparseEmitStrategy emitStrategy);
 
   /// Generates loop boundary statements (entering/exiting loops). The function
   /// passes and updates the passed-in parameters.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 8c1680a393181..70488c34e440c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -82,19 +82,19 @@ static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
 LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
                          bool isSparseOut, unsigned numLoops,
                          DependentLvlGetter dimGetter,
-                         DebugSparseIteration emitStrategy) {
+                         SparseEmitStrategy emitStrategy) {
   initialize(tensors, loopTag, hasOutput, isSparseOut, numLoops, dimGetter);
 }
 
 void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
                              bool isSparseOut, unsigned numLoops,
                              DependentLvlGetter dimGetter,
-                             DebugSparseIteration emitStrategy) {
+                             SparseEmitStrategy emitStrategy) {
   // First initialize the top-level type of the fields.
   this->loopTag = loopTag;
   this->hasOutput = hasOutput;
   this->isSparseOut = isSparseOut;
-  SparseIterator::setDebugSparseIteration(emitStrategy);
+  SparseIterator::setSparseEmitStrategy(emitStrategy);
 
   const unsigned numManifestTensors = ts.size();
   const unsigned synTensorId = numManifestTensors;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index e0b4f81487a68..5bab2c6a86081 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -89,13 +89,13 @@ class LoopEmitter {
   initialize(ValueRange tensors, StringAttr loopTag = nullptr,
              bool hasOutput = false, bool isSparseOut = false,
              unsigned numLoops = 0, DependentLvlGetter getter = nullptr,
-             DebugSparseIteration emitStrategy = DebugSparseIteration::kNone);
+             SparseEmitStrategy emitStrategy = SparseEmitStrategy::kFunctional);
 
   explicit LoopEmitter(
       ValueRange tensors, StringAttr loopTag = nullptr, bool hasOutput = false,
       bool isSparseOut = false, unsigned numLoops = 0,
       DependentLvlGetter getter = nullptr,
-      DebugSparseIteration emitStrategy = DebugSparseIteration::kNone);
+      SparseEmitStrategy emitStrategy = SparseEmitStrategy::kFunctional);
 
   /// Starts a loop emitting session by generating all the buffers needed
   /// for iterating over the tensors.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 97d51dbec4a5e..c1fc2a062fa10 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -777,11 +777,12 @@ class SubSectIterator : public SparseIterator {
 // SparseIterator derived classes implementation.
 //===----------------------------------------------------------------------===//
 
-DebugSparseIteration SparseIterator::emitStrategy = DebugSparseIteration::kNone;
+SparseEmitStrategy SparseIterator::emitStrategy =
+    SparseEmitStrategy::kFunctional;
 
 void SparseIterator::genInit(OpBuilder &b, Location l,
                              const SparseIterator *p) {
-  if (emitStrategy == DebugSparseIteration::kInterfaceOnly) {
+  if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
     std::string prefix = getDebugInterfacePrefix();
     Operation *begin = b.create(l, b.getStringAttr(prefix + ".begin"), {},
                                 getCursorValTypes(b));
@@ -793,7 +794,7 @@ void SparseIterator::genInit(OpBuilder &b, Location l,
 }
 
 Value SparseIterator::genNotEnd(OpBuilder &b, Location l) {
-  if (emitStrategy == DebugSparseIteration::kInterfaceOnly) {
+  if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
     std::string prefix = getDebugInterfacePrefix();
     Operation *notEnd = b.create(l, b.getStringAttr(prefix + ".not_end"),
                                  getCursor(), b.getI1Type());
@@ -804,7 +805,7 @@ Value SparseIterator::genNotEnd(OpBuilder &b, Location l) {
 }
 
 void SparseIterator::locate(OpBuilder &b, Location l, Value crd) {
-  if (emitStrategy == DebugSparseIteration::kInterfaceOnly) {
+  if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
     std::string prefix = getDebugInterfacePrefix();
     SmallVector<Value> args = getCursor();
     args.push_back(crd);
@@ -818,7 +819,7 @@ void SparseIterator::locate(OpBuilder &b, Location l, Value crd) {
 }
 
 Value SparseIterator::deref(OpBuilder &b, Location l) {
-  if (emitStrategy == DebugSparseIteration::kInterfaceOnly) {
+  if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
     std::string prefix = getDebugInterfacePrefix();
     SmallVector<Value> args = getCursor();
     Operation *deref = b.create(l, b.getStringAttr(prefix + ".deref"),
@@ -830,7 +831,7 @@ Value SparseIterator::deref(OpBuilder &b, Location l) {
 }
 
 ValueRange SparseIterator::forward(OpBuilder &b, Location l) {
-  if (emitStrategy == DebugSparseIteration::kInterfaceOnly) {
+  if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
     std::string prefix = getDebugInterfacePrefix();
     Operation *next = b.create(l, b.getStringAttr(prefix + ".next"),
                                getCursor(), getCursorValTypes(b));
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
index 728e2973e83c3..6f5da8073cb60 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
@@ -111,7 +111,7 @@ class SparseIterator {
 public:
   virtual ~SparseIterator() = default;
 
-  static void setDebugSparseIteration(DebugSparseIteration strategy) {
+  static void setSparseEmitStrategy(SparseEmitStrategy strategy) {
     SparseIterator::emitStrategy = strategy;
   }
 
@@ -247,7 +247,7 @@ class SparseIterator {
     return ref.take_front(cursorValsCnt);
   }
 
-  static DebugSparseIteration emitStrategy;
+  static SparseEmitStrategy emitStrategy;
 
 public:
   const IterKind kind;     // For LLVM-style RTTI.
diff --git a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
index 18118cab9e52c..6aba0ada947e1 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparse-reinterpret-map --sparsification="debug-sparse-iteration=interface-only" --canonicalize --cse --allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map --sparsification="sparse-emit-strategy=debug-interface" --canonicalize --cse --allow-unregistered-dialect | FileCheck %s
 
 #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>



More information about the Mlir-commits mailing list