[Mlir-commits] [mlir] [mlir][sparse] implement lowering rules for ExtractIterSpaceOp. @PeimingLiu (PR #89143)

Peiming Liu llvmlistbot@llvm.org
Wed Apr 17 13:59:13 PDT 2024


https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/89143

**DO NOT MERGE** until https://github.com/llvm/llvm-project/pull/89003 lands.
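For context, a rough before/after sketch of what the new lowering does to `extract_iteration_space` (illustrative only; `#CSR` and the decomposed value list reflect my reading of `convertIterSpaceType` in patch 3, not verbatim pass output):

```mlir
// Before: the iteration space is a single opaque SSA value.
%space = sparse_tensor.extract_iteration_space %sp lvls = 0
    : tensor<?x?xf64, #CSR>

// After the 1:N type conversion, the space is carried as its constituent
// values instead: the positions/coordinates buffers of each level (when
// present), one `index` per level size, and a final lo/hi position pair.
```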

From 7db93fc5e9fab3de3b4805d56066cfdbf74895b4 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming@google.com>
Date: Fri, 12 Apr 2024 22:03:06 +0000
Subject: [PATCH 1/3] [mlir][sparse] introduce `sparse_tensor.iterate`
 operation

---
 .../Dialect/SparseTensor/IR/SparseTensor.h    |  38 +++
 .../SparseTensor/IR/SparseTensorAttrDefs.td   |  15 ++
 .../SparseTensor/IR/SparseTensorOps.td        | 101 ++++++-
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 255 ++++++++++++++++++
 mlir/test/Dialect/SparseTensor/invalid.mlir   |  57 ++++
 mlir/test/Dialect/SparseTensor/roundtrip.mlir |  28 ++
 .../SparseTensor/sparse_itertion_licm.mlir    |  26 ++
 7 files changed, 519 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 5e523ec428aefb..081a9b8cad8d62 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -17,9 +17,13 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/TensorEncoding.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+#include "llvm/ADT/bit.h"
+
 //===----------------------------------------------------------------------===//
 //
 // Type aliases to help code be more self-documenting. Unfortunately
@@ -41,6 +45,40 @@ using Level = uint64_t;
 /// including the value `ShapedType::kDynamic` (for shapes).
 using Size = int64_t;
 
+/// A simple wrapper to encode a bitset of (at most 64) levels.
+class LevelSet {
+  uint64_t bits = 0;
+
+public:
+  LevelSet() = default;
+  explicit LevelSet(uint64_t bits) : bits(bits) {}
+  operator uint64_t() const { return bits; }
+
+  LevelSet &set(unsigned i) {
+    assert(i < 64);
+    bits |= uint64_t(1) << i;
+    return *this;
+  }
+
+  LevelSet &operator|=(LevelSet lhs) {
+    bits |= static_cast<uint64_t>(lhs);
+    return *this;
+  }
+
+  LevelSet &lshift(unsigned offset) {
+    bits = bits << offset;
+    return *this;
+  }
+
+  bool operator[](unsigned i) const {
+    assert(i < 64);
+    return (bits & (uint64_t(1) << i)) != 0;
+  }
+
+  unsigned count() const { return llvm::popcount(bits); }
+  bool empty() const { return bits == 0; }
+};
+
 } // namespace sparse_tensor
 } // namespace mlir
 
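As a reading aid (not part of the patch), a minimal C++ sketch of the intended `LevelSet` semantics: level `i` is a member iff bit `i` is set, and `lshift` re-bases a set when it is appended after existing levels:

```cpp
#include <cassert>
// Assumes the LevelSet wrapper added above.
using mlir::sparse_tensor::LevelSet;

LevelSet used;
used.set(0).set(2);       // used = {0, 2}
assert(used[0] && !used[1] && used[2] && used.count() == 2);

LevelSet inner;
inner.set(0);             // inner = {0}, relative to its own space
used |= inner.lshift(3);  // re-based by 3 levels: adds {3}
assert(used.count() == 3 && used[3]);
```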
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 4a9b9169ae4b86..d5398a98f5b171 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -19,6 +19,21 @@ class SparseTensor_Attr<string name,
                         list<Trait> traits = []>
     : AttrDef<SparseTensor_Dialect, name, traits>;
 
+//===----------------------------------------------------------------------===//
+// A simple bitset attribute wrapped over a single int64_t to encode a set of
+// sparse tensor levels.
+//===----------------------------------------------------------------------===//
+
+def LevelSetAttr :
+    TypedAttrBase<
+      I64, "IntegerAttr",
+      And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
+           CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType().isInteger(64)">]>,
+      "LevelSet attribute"> {
+  let returnType = [{::mlir::sparse_tensor::LevelSet}];
+  let convertFromStorage = [{::mlir::sparse_tensor::LevelSet($_self.getValue().getZExtValue())}];
+}
+
 //===----------------------------------------------------------------------===//
 // These attributes are just like `IndexAttr` except that they clarify whether
 // the index refers to a dimension (an axis of the semantic tensor) or a level
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 4e4441c640ed95..08140b9d2b6192 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -15,6 +15,8 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
 include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
 
 //===----------------------------------------------------------------------===//
 // Base class.
@@ -1304,7 +1306,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
 
 def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
     ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
-                 "ForeachOp"]>]> {
+                 "ForeachOp", "IterateOp"]>]> {
   let summary = "Yield from sparse_tensor set-like operations";
   let description = [{
       Yields a value from within a `binary`, `unary`, `reduce`,
@@ -1513,6 +1515,103 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
   let hasVerifier = 1;
 }
 
+def IterateOp : SparseTensor_Op<"iterate",
+    [RecursiveMemoryEffects, RecursivelySpeculatable,
+     DeclareOpInterfaceMethods<LoopLikeOpInterface,
+      ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
+       "getYieldedValuesMutable"]>,
+     DeclareOpInterfaceMethods<RegionBranchOpInterface,
+      ["getEntrySuccessorOperands"]>,
+     SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
+
+  let arguments = (ins AnySparseIterSpace:$iterSpace,
+                       Variadic<AnyType>:$initArgs,
+                       LevelSetAttr:$crdUsedLvls);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$region);
+
+  let summary = "Iterate over a sparse iteration space";
+  let description = [{
+      The `sparse_tensor.iterate` operation represents a loop over the
+      provided iteration space extracted from a specific sparse tensor.
+      The operation defines an SSA value for a sparse iterator that points
+      to the current stored element in the sparse tensor, as well as SSA
+      values for the coordinates of that element. The coordinates are
+      always of `index` type, regardless of the underlying sparse tensor
+      storage. Unused coordinates can be elided with `_` symbols, which
+      usually leads to simpler generated code after sparsification. For
+      example:
+
+      ```mlir
+      // The coordinate for level 0 is not used when iterating over a 2-D
+      // iteration space.
+      sparse_tensor.iterate %iterator in %space at(_, %crd_1)
+        : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
+      ```
+
+      `sparse_tensor.iterate` can also operate on loop-carried variables and
+      returns their final values after loop termination. The initial values
+      of the variables are passed as additional SSA operands, following the
+      iteration space operand. The operation region has one argument for the
+      iterator, variadic arguments for the specified (used) coordinates, and
+      one argument for each loop-carried variable, representing its value at
+      the current iteration.
+      The body region must contain exactly one block that terminates with
+      `sparse_tensor.yield`.
+
+      `sparse_tensor.iterate` results hold the final values after the last
+      iteration. If the `sparse_tensor.iterate` defines any values, a yield
+      must be explicitly present.
+      The number and types of the `sparse_tensor.iterate` results must match
+      the initial values in the iter_args binding and the yield operands.
+
+      A nested `sparse_tensor.iterate` example that prints all the coordinates
+      stored in the sparse input:
+
+      ```mlir
+      func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
+        // Iterates over the first level of %sp
+        %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+        sparse_tensor.iterate %it1 in %l1 at(%crd0)
+            : !sparse_tensor.iter_space<#COO, lvls = 0 to 1> {
+          // Iterates over the second level of %sp
+          %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
+              : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
+          sparse_tensor.iterate %it2 in %l2 at(%crd1)
+              : !sparse_tensor.iter_space<#COO, lvls = 1 to 2> {
+             vector.print %crd0 : index
+             vector.print %crd1 : index
+          }
+        }
+      }
+      ```
+  }];
+
+  let extraClassDeclaration = [{
+    unsigned getSpaceDim() {
+      return getIterSpace().getType().getSpaceDim();
+    }
+    BlockArgument getIterator() {
+      return getRegion().getArguments().front();
+    }
+    Block::BlockArgListType getCrds() {
+      // The first block argument is the iterator; the next arguments are the
+      // used coordinates (followed by the loop-carried iter_args).
+      return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
+    }
+    unsigned getNumRegionIterArgs() {
+      return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
+    }
+  }];
+
+  let hasVerifier = 1;
+  let hasRegionVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Debugging and Test-Only Operations.
 //===----------------------------------------------------------------------===//
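Since the examples in the op description do not exercise `iter_args`, here is a small sketch of a loop-carried reduction in the proposed syntax (mirroring the roundtrip test in this patch; `#CSR`, `%space`, and `%init` are assumed to be defined elsewhere):

```mlir
%sum = sparse_tensor.iterate %it in %space at(%crd) iter_args(%acc = %init)
    : !sparse_tensor.iter_space<#CSR, lvls = 0> -> index {
  %next = arith.addi %acc, %crd : index
  sparse_tensor.yield %next : index
}
```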
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 516b0943bdcfac..36908def09f403 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2027,6 +2027,106 @@ static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
   printLevelRange(p, lo, hi);
 }
 
+static ParseResult
+parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
+                     SmallVectorImpl<OpAsmParser::Argument> &iterators,
+                     SmallVectorImpl<OpAsmParser::Argument> &iterArgs) {
+  SmallVector<OpAsmParser::UnresolvedOperand> spaces;
+  SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
+
+  // Parses "%iters, ... in %spaces, ..."
+  if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
+      parser.parseOperandList(spaces))
+    return failure();
+
+  if (iterators.size() != spaces.size())
+    return parser.emitError(
+        parser.getNameLoc(),
+        "mismatch in number of sparse iterators and sparse spaces");
+
+  // Parse "at(%crd0, _, ...)"
+  LevelSet crdUsedLvlSet;
+  bool hasUsedCrds = succeeded(parser.parseOptionalKeyword("at"));
+  unsigned lvlCrdCnt = 0;
+  if (hasUsedCrds) {
+    ParseResult crdList = parser.parseCommaSeparatedList(
+        OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
+          if (parser.parseOptionalKeyword("_")) {
+            if (parser.parseArgument(iterArgs.emplace_back()))
+              return failure();
+            // Always use IndexType for the coordinate.
+            crdUsedLvlSet.set(lvlCrdCnt);
+            iterArgs.back().type = parser.getBuilder().getIndexType();
+          }
+          lvlCrdCnt += 1;
+          return success();
+        });
+    if (failed(crdList)) {
+      return parser.emitError(
+          parser.getNameLoc(),
+          "expecting SSA value or \"_\" for level coordinates");
+    }
+  }
+  // Set the CrdUsedLvl bitset.
+  state.addAttribute("crdUsedLvls",
+                     parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
+
+  // Parse "iter_args(%arg = %init, ...)"
+  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+  if (hasIterArgs)
+    if (parser.parseAssignmentList(iterArgs, initArgs))
+      return failure();
+
+  SmallVector<Type> iterSpaceTps;
+  // Parse ": sparse_tensor.iter_space -> ret".
+  if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
+    return failure();
+  if (iterSpaceTps.size() != spaces.size())
+    return parser.emitError(parser.getNameLoc(),
+                            "mismatch in number of iteration space operands "
+                            "and iteration space types");
+
+  for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
+    IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
+    if (!spaceTp)
+      return parser.emitError(parser.getNameLoc(),
+                              "expected sparse_tensor.iter_space type for "
+                              "iteration space operands");
+    if (hasUsedCrds && spaceTp.getSpaceDim() != lvlCrdCnt)
+      return parser.emitError(parser.getNameLoc(),
+                              "mismatch in number of iteration space dimensions "
+                              "and specified coordinates");
+    it.type = spaceTp.getIteratorType();
+  }
+
+  if (hasIterArgs)
+    if (parser.parseArrowTypeList(state.types))
+      return failure();
+
+  // Resolves input operands.
+  if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
+                             state.operands))
+    return failure();
+
+  if (hasIterArgs) {
+    unsigned numCrds = crdUsedLvlSet.count();
+    // Strip off the leading args that are used for coordinates.
+    MutableArrayRef args = MutableArrayRef(iterArgs).drop_front(numCrds);
+    if (args.size() != initArgs.size() || args.size() != state.types.size()) {
+      return parser.emitError(
+          parser.getNameLoc(),
+          "mismatch in number of iteration arguments and return values");
+    }
+
+    for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
+      it.type = tp;
+      if (parser.resolveOperand(init, tp, state.operands))
+        return failure();
+    }
+  }
+  return success();
+}
+
 LogicalResult ExtractIterSpaceOp::inferReturnTypes(
     MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
     DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
@@ -2063,6 +2163,161 @@ LogicalResult ExtractIterSpaceOp::verify() {
   return success();
 }
 
+ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::Argument iterator;
+  OpAsmParser::UnresolvedOperand iterSpace;
+
+  SmallVector<OpAsmParser::Argument> iters, iterArgs;
+  if (parseSparseSpaceLoop(parser, result, iters, iterArgs))
+    return failure();
+  if (iters.size() != 1)
+    return parser.emitError(parser.getNameLoc(),
+                            "expected only one iterator/iteration space");
+
+  iters.append(iterArgs);
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, iters))
+    return failure();
+
+  IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  return success();
+}
+
+/// Prints the initialization list in the form of
+///   <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
+/// where 'inner' values are assumed to be region arguments and 'outer' values
+/// are regular SSA values.
+static void printInitializationList(OpAsmPrinter &p,
+                                    Block::BlockArgListType blocksArgs,
+                                    ValueRange initializers,
+                                    StringRef prefix = "") {
+  assert(blocksArgs.size() == initializers.size() &&
+         "expected same length of arguments and initializers");
+  if (initializers.empty())
+    return;
+
+  p << prefix << '(';
+  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
+    p << std::get<0>(it) << " = " << std::get<1>(it);
+  });
+  p << ")";
+}
+
+static void printUsedCrdsList(OpAsmPrinter &p, unsigned spaceDim,
+                              Block::BlockArgListType blocksArgs,
+                              LevelSet crdUsedLvls) {
+  if (crdUsedLvls.empty())
+    return;
+
+  p << " at(";
+  for (unsigned i = 0; i < spaceDim; i++) {
+    if (crdUsedLvls[i]) {
+      p << blocksArgs.front();
+      blocksArgs = blocksArgs.drop_front();
+    } else {
+      p << "_";
+    }
+    if (i != spaceDim - 1)
+      p << ", ";
+  }
+  assert(blocksArgs.empty());
+  p << ")";
+}
+
+void IterateOp::print(OpAsmPrinter &p) {
+  p << " " << getIterator() << " in " << getIterSpace();
+  printUsedCrdsList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
+  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
+
+  p << " : " << getIterSpace().getType() << " ";
+  if (!getInitArgs().empty())
+    p << "-> (" << getInitArgs().getTypes() << ") ";
+
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/!getInitArgs().empty());
+}
+
+LogicalResult IterateOp::verify() {
+  if (getInitArgs().size() != getNumResults()) {
+    return emitOpError(
+        "mismatch in number of loop-carried values and defined values");
+  }
+  return success();
+}
+
+LogicalResult IterateOp::verifyRegions() {
+  if (getIterator().getType() != getIterSpace().getType().getIteratorType())
+    return emitOpError("mismatch in iterator and iteration space type");
+  if (getNumRegionIterArgs() != getNumResults())
+    return emitOpError(
+        "mismatch in number of basic block args and defined values");
+
+  auto initArgs = getInitArgs();
+  auto iterArgs = getRegionIterArgs();
+  auto yieldVals = getYieldedValues();
+  auto opResults = getResults();
+  if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
+                        opResults.size()})) {
+    return emitOpError() << "mismatch in number of iter args and results";
+  }
+
+  for (auto [i, init, iter, yield, ret] :
+       llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
+    if (init.getType() != ret.getType())
+      return emitOpError() << "type mismatch between " << i
+                           << "th iter operand and defined value";
+    if (iter.getType() != ret.getType())
+      return emitOpError() << "type mismatch between " << i
+                           << "th iter region arg and defined value";
+    if (yield.getType() != ret.getType())
+      return emitOpError() << "type mismatch between " << i
+                           << "th yield value and defined value";
+  }
+
+  return success();
+}
+
+/// IterateOp implementations of OpInterface methods.
+SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
+
+MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
+  return getInitArgsMutable();
+}
+
+Block::BlockArgListType IterateOp::getRegionIterArgs() {
+  return getRegion().getArguments().take_back(getNumRegionIterArgs());
+}
+
+std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
+  return cast<sparse_tensor::YieldOp>(
+             getRegion().getBlocks().front().getTerminator())
+      .getResultsMutable();
+}
+
+std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
+
+OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  return getInitArgs();
+}
+
+void IterateOp::getSuccessorRegions(RegionBranchPoint point,
+                                    SmallVectorImpl<RegionSuccessor> &regions) {
+  // Both the operation itself and the region may branch into the body or
+  // back to the operation itself.
+  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
+  // It is possible for the loop not to enter the body.
+  regions.push_back(RegionSuccessor(getResults()));
+}
+
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Dialect Setup.
+//===----------------------------------------------------------------------===//
+
 /// Materialize a single constant operation from a given attribute value with
 /// the desired resultant type.
 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 3fa696e1600a93..b13024cd4ed99d 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -1094,3 +1094,60 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
   %l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
   return
 }
+
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  // expected-error @+1 {{'sparse_tensor.iterate' op different number of region iter_args and yielded values: 2 != 1}}
+  %r1, %r2 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%si = %i, %sj = %j): !sparse_tensor.iter_space<#COO, lvls = 0> -> (index, index) {
+    sparse_tensor.yield %si : index
+  }
+  return %r1 : index
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+// expected-note @+1 {{prior use here}}
+func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index) -> f32 {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  // expected-error @+1 {{use of value '%i' expects different type than prior uses: 'f32' vs 'index'}}
+  %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0> -> f32 {
+    sparse_tensor.yield %outer : f32
+  }
+  return %r1 : f32
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  // expected-error @+1 {{'sparse_tensor.iterate' op 0-th region iter_arg and 0-th yielded value have different type: 'index' != 'f32'}}
+  %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%si = %i): !sparse_tensor.iter_space<#COO, lvls = 0> -> index {
+    %y = arith.constant 1.0 :  f32
+    sparse_tensor.yield %y : f32
+  }
+  return %r1 : index
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index d34071279e5129..e9a898f16b41d2 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -763,3 +763,31 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
   %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
   return %l1, %l2 : !sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>
 }
+
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+// CHECK-LABEL:   func.func @sparse_iterate(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<4x8xf32, #sparse{{[0-9]*}}>,
+// CHECK-SAME:      %[[VAL_1:.*]]: index,
+// CHECK-SAME:      %[[VAL_2:.*]]: index) -> index {
+// CHECK:           %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0 : tensor<4x8xf32, #sparse{{[0-9]*}}>
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] at(%[[VAL_6:.*]]) iter_args(%[[VAL_7:.*]] = %[[VAL_1]]) : !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> -> (index) {
+// CHECK:             sparse_tensor.yield %[[VAL_7]] : index
+// CHECK:           }
+// CHECK:           return %[[VAL_4]] : index
+// CHECK:         }
+func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
+    sparse_tensor.yield %outer : index
+  }
+  return %r1 : index
+}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir b/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir
new file mode 100644
index 00000000000000..e7158d04b37feb
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s --loop-invariant-code-motion | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : dense,
+    j : compressed
+  )
+}>
+
+// Make sure that pure operations are hoisted outside the loop.
+//
+// CHECK: sparse_tensor.values
+// CHECK: sparse_tensor.positions
+// CHECK: sparse_tensor.coordinates
+// CHECK: sparse_tensor.iterate
+func.func @sparse_iterate(%sp : tensor<?x?xf64, #CSR>) {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf64, #CSR>
+  sparse_tensor.iterate %it1 in %l1 at (%crd) : !sparse_tensor.iter_space<#CSR, lvls = 0> {
+    %0 = sparse_tensor.values %sp : tensor<?x?xf64, #CSR> to memref<?xf64>
+    %1 = sparse_tensor.positions %sp { level = 1 : index } : tensor<?x?xf64, #CSR> to memref<?xindex>
+    %2 = sparse_tensor.coordinates  %sp { level = 1 : index } : tensor<?x?xf64, #CSR> to memref<?xindex>
+    "test.op"(%0, %1, %2) : (memref<?xf64>, memref<?xindex>, memref<?xindex>) -> ()
+  }
+
+  return
+}

From 6adcdb6106b3c05e27e41404dd8dce5264d75794 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming@google.com>
Date: Tue, 16 Apr 2024 17:00:45 +0000
Subject: [PATCH 2/3] [mlir][sparse] implement sparse space collapse pass.

---
 .../Dialect/SparseTensor/Transforms/Passes.h  |   6 +
 .../Dialect/SparseTensor/Transforms/Passes.td |  16 ++
 .../SparseTensor/Transforms/CMakeLists.txt    |   1 +
 .../Transforms/SparseSpaceCollapse.cpp        | 183 ++++++++++++++++++
 .../SparseTensor/sparse_space_collapse.mlir   |  33 ++++
 5 files changed, 239 insertions(+)
 create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
 create mode 100644 mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index d6d038ef65bdf4..3043a0c4dc4109 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -248,6 +248,12 @@ std::unique_ptr<Pass> createSparsificationAndBufferizationPass(
     bool enableBufferInitialization, unsigned vectorLength,
     bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen);
 
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<Pass> createSparseSpaceCollapsePass();
+
 //===----------------------------------------------------------------------===//
 // Registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 4706d5ba2f218c..d2265dd08205ad 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -463,4 +463,20 @@ def SparsificationAndBufferization : Pass<"sparsification-and-bufferization", "M
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
+  let summary = "(experimental) sparse space collapsing pass";
+  let description = [{
+     This pass collapses consecutive sparse iteration spaces (extracted from
+     the same tensor) into one multi-level iteration space.
+  }];
+  let constructor = "mlir::createSparseSpaceCollapsePass()";
+  let dependentDialects = [
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
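To make the intended transformation concrete, a condensed before/after sketch (pseudo-IR with types and unused details elided; see the new test below for a complete example):

```mlir
// Before: two perfectly nested single-level loops over the same tensor.
%s0 = sparse_tensor.extract_iteration_space %sp lvls = 0 ...
sparse_tensor.iterate %it0 in %s0 ... {
  %s1 = sparse_tensor.extract_iteration_space %sp at %it0 lvls = 1 ...
  sparse_tensor.iterate %it1 in %s1 ... { ... }
}

// After: a single loop over the collapsed two-level space.
%s01 = sparse_tensor.extract_iteration_space %sp lvls = 0 to 2 ...
sparse_tensor.iterate %it01 in %s01 ... { ... }
```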
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index af3a1b48f45af9..2a29ee8a7a87cb 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   SparseGPUCodegen.cpp
   SparseReinterpretMap.cpp
   SparseStorageSpecifierToLLVM.cpp
+  SparseSpaceCollapse.cpp
   SparseTensorCodegen.cpp
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
new file mode 100644
index 00000000000000..bc469992d97103
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -0,0 +1,183 @@
+//===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
+namespace mlir {
+
+#define GEN_PASS_DEF_SPARSESPACECOLLAPSE
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
+
+namespace sparse_tensor {
+
+struct CollapseSpaceInfo {
+  ExtractIterSpaceOp space;
+  IterateOp loop;
+};
+
+bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) {
+  auto pIterArgs = parent.getRegionIterArgs();
+  auto nInitArgs = node.getInits();
+  if (pIterArgs.size() != nInitArgs.size())
+    return false;
+
+  // Two loops are collapsible if they are perfectly nested.
+  auto pYields = parent.getYieldedValues();
+  auto nResult = node.getLoopResults().value();
+
+  bool yieldEq =
+      llvm::all_of(llvm::zip_equal(pYields, nResult), [](auto zipped) {
+        return std::get<0>(zipped) == std::get<1>(zipped);
+      });
+
+  // Parent iter_args should be passed directly to the node's init_args.
+  bool iterArgEq =
+      llvm::all_of(llvm::zip_equal(pIterArgs, nInitArgs), [](auto zipped) {
+        return std::get<0>(zipped) == std::get<1>(zipped);
+      });
+
+  return yieldEq && iterArgEq;
+}
+
+bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
+                     ExtractIterSpaceOp curSpace) {
+
+  auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp {
+    Value spaceVal = space.getResultSpace();
+    if (spaceVal.hasOneUse())
+      return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin());
+    return nullptr;
+  };
+
+  if (toCollapse.empty()) {
+    // Collapse root.
+    if (auto itOp = getIterateOpOverSpace(curSpace)) {
+      CollapseSpaceInfo &info = toCollapse.emplace_back();
+      info.space = curSpace;
+      info.loop = itOp;
+      return true;
+    }
+    return false;
+  }
+
+  auto parent = toCollapse.back().space;
+  auto pItOp = toCollapse.back().loop;
+  auto nItOp = getIterateOpOverSpace(curSpace);
+
+  // Can only collapse spaces extracted from the same tensor.
+  if (parent.getTensor() != curSpace.getTensor())
+    return false;
+
+  // Can only collapse consecutive simple iterations on one tensor (i.e., no
+  // co-iteration).
+  if (!nItOp || nItOp->getBlock() != curSpace->getBlock() ||
+      pItOp.getIterator() != curSpace.getParentIter() ||
+      curSpace->getParentOp() != pItOp.getOperation())
+    return false;
+
+  if (pItOp && !isCollapsableLoops(pItOp, nItOp))
+    return false;
+
+  CollapseSpaceInfo &info = toCollapse.emplace_back();
+  info.space = curSpace;
+  info.loop = nItOp;
+  return true;
+}
+
+void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) {
+  if (toCollapse.size() < 2)
+    return;
+
+  ExtractIterSpaceOp root = toCollapse.front().space;
+  ExtractIterSpaceOp leaf = toCollapse.back().space;
+  Location loc = root.getLoc();
+
+  assert(root->hasOneUse() && leaf->hasOneUse());
+
+  // Insert the collapsed operation at the same scope as the root operation.
+  OpBuilder builder(root);
+
+  // Construct the collapsed iteration space.
+  auto collapsedSpace = builder.create<ExtractIterSpaceOp>(
+      loc, root.getTensor(), root.getParentIter(), root.getLoLvl(),
+      leaf.getHiLvl());
+
+  auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
+  auto innermost = toCollapse.back().loop;
+
+  IRMapping mapper;
+  mapper.map(leaf, collapsedSpace.getResultSpace());
+  for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs()))
+    mapper.map(std::get<0>(z), std::get<1>(z));
+
+  auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper));
+  builder.setInsertionPointToStart(cloned.getBody());
+
+  LevelSet crdUsedLvls;
+  unsigned shift = 0, argIdx = 1;
+  for (auto info : toCollapse.drop_back()) {
+    LevelSet set = info.loop.getCrdUsedLvls();
+    crdUsedLvls |= set.lshift(shift);
+    shift += info.loop.getSpaceDim();
+    for (BlockArgument crd : info.loop.getCrds()) {
+      BlockArgument collapsedCrd = cloned.getBody()->insertArgument(
+          argIdx++, builder.getIndexType(), crd.getLoc());
+      crd.replaceAllUsesWith(collapsedCrd);
+    }
+  }
+  crdUsedLvls |= innermost.getCrdUsedLvls().lshift(shift);
+  cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
+  cloned.setCrdUsedLvls(crdUsedLvls);
+
+  rItOp.replaceAllUsesWith(cloned.getResults());
+  // Erase collapsed loops.
+  rItOp.erase();
+  root.erase();
+}
+
+struct SparseSpaceCollapsePass
+    : public impl::SparseSpaceCollapseBase<SparseSpaceCollapsePass> {
+  SparseSpaceCollapsePass() = default;
+
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+
+    // A naive (experimental) implementation to collapse consecutive sparse
+    // spaces. It does NOT handle complex cases where multiple spaces are
+    // extracted in the same basic block. E.g.,
+    //
+    // %space1 = extract_space %t1 ...
+    // %space2 = extract_space %t2 ...
+    // sparse_tensor.iterate(%sp1) ...
+    //
+    SmallVector<CollapseSpaceInfo> toCollapse;
+    func->walk([&](ExtractIterSpaceOp op) {
+      if (!legalToCollapse(toCollapse, op)) {
+        // If it is not legal to collapse one more space, collapse the spaces
+        // collected so far and start a new chain.
+        collapseSparseSpace(toCollapse);
+        toCollapse.clear();
+      }
+    });
+
+    collapseSparseSpace(toCollapse);
+  }
+};
+
+} // namespace sparse_tensor
+
+std::unique_ptr<Pass> createSparseSpaceCollapsePass() {
+  return std::make_unique<sparse_tensor::SparseSpaceCollapsePass>();
+}
+
+} // namespace mlir
diff --git a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
new file mode 100644
index 00000000000000..392dfe01884ba8
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s --sparse-space-collapse | FileCheck %s
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+// CHECK-LABEL:   func.func @sparse_sparse_collapse(
+// CHECK-SAME:         %[[VAL_0:.*]]: tensor<4x8xf32, #sparse>,
+// CHECK-SAME:         %[[VAL_1:.*]]: index) {
+// CHECK:           %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0 to 2 : tensor<4x8xf32, #sparse>
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] at(%[[VAL_6:.*]], _) iter_args(%[[VAL_7:.*]] = %[[VAL_1]])
+// CHECK:             %[[VAL_8:.*]] = "test.op"(%[[VAL_7]]) : (index) -> index
+// CHECK:             sparse_tensor.yield %[[VAL_8]] : index
+// CHECK:           }
+// CHECK:           "test.sink"(%[[VAL_4]]) : (index) -> ()
+// CHECK:           return
+// CHECK:         }
+func.func @sparse_sparse_collapse(%sp : tensor<4x8xf32, #COO>, %i : index) {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  %r1 = sparse_tensor.iterate %it1 in %l1 at(%crd0) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
+    %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
+    %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
+      %k ="test.op"(%inner) : (index) -> index
+      sparse_tensor.yield %k : index
+    }
+    sparse_tensor.yield %r2 : index
+  }
+  "test.sink"(%r1) : (index) -> ()
+  return
+}

From c5c7bb105250a8547aef4cb6b1a12e9d01bd5c02 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming@google.com>
Date: Wed, 17 Apr 2024 18:51:24 +0000
Subject: [PATCH 3/3] [mlir][sparse] implement lowering rules for
 ExtractIterSpaceOp.

---
 .../mlir/Dialect/SparseTensor/IR/Enums.h      |   4 +
 .../Dialect/SparseTensor/Transforms/Passes.h  |  15 ++
 .../Dialect/SparseTensor/Transforms/Passes.td |  15 ++
 .../SparseTensor/Transforms/CMakeLists.txt    |   1 +
 .../Transforms/SparseIterationToScf.cpp       |  78 ++++++++
 .../Transforms/SparseTensorPasses.cpp         |  29 +++
 .../Transforms/Utils/LoopEmitter.h            |   2 +-
 .../Transforms/Utils/SparseTensorIterator.cpp | 178 ++++++++++++------
 .../Transforms/Utils/SparseTensorIterator.h   |  76 +++++++-
 .../SparseTensor/sparse_iteration_to_scf.mlir |  23 +++
 10 files changed, 356 insertions(+), 65 deletions(-)
 create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
 create mode 100644 mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 33f613a46bad84..96ee7111fea2cf 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -357,6 +357,10 @@ struct LevelType {
     return hasSparseSemantic();
   }
 
+  constexpr unsigned getNumBuffer() const {
+    return hasDenseSemantic() ? 0 : (isWithPosLT() ? 2 : 1);
+  }
+
   std::string toMLIRString() const {
     std::string lvlStr = toFormatString(getLvlFmt());
     std::string propStr = "";
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 3043a0c4dc4109..c9164e39a3a75e 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -16,6 +16,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
 
 //===----------------------------------------------------------------------===//
 // Include the generated pass header (which needs some early definitions).
@@ -143,6 +144,20 @@ void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns);
 
 std::unique_ptr<Pass> createLowerForeachToSCFPass();
 
+//===----------------------------------------------------------------------===//
+// The LowerSparseIterationToSCF pass.
+//===----------------------------------------------------------------------===//
+
+/// Type converter for iter_space and iterator.
+struct SparseIterationTypeConverter : public OneToNTypeConverter {
+  SparseIterationTypeConverter();
+};
+
+void populateLowerSparseIterationToSCFPatterns(TypeConverter &converter,
+                                               RewritePatternSet &patterns);
+
+std::unique_ptr<Pass> createLowerSparseIterationToSCFPass();
+
 //===----------------------------------------------------------------------===//
 // The SparseTensorConversion pass.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index d2265dd08205ad..6f252b99d02c84 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -479,4 +479,19 @@ def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
   ];
 }
 
+def LowerSparseIterationToSCF : Pass<"lower-sparse-iteration-to-scf", "func::FuncOp"> {
+  let summary = "lower sparse_tensor.iterate/coiterate into scf loops";
+  let description = [{
+     This pass lowers `sparse_tensor.iterate` operations into `scf.for/while` operations.
+     The pass is not yet stabilized.
+  }];
+  let constructor = "mlir::createLowerSparseIterationToSCFPass()";
+  let dependentDialects = [
+    "memref::MemRefDialect",
+    "scf::SCFDialect",
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
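For local experiments, the new pass can be invoked directly with `mlir-opt input.mlir --lower-sparse-iteration-to-scf` (flag name as registered above).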
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 2a29ee8a7a87cb..e4acfa8889e5f8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   SparseAssembler.cpp
   SparseBufferRewriting.cpp
   SparseGPUCodegen.cpp
+  SparseIterationToScf.cpp
   SparseReinterpretMap.cpp
   SparseStorageSpecifierToLLVM.cpp
   SparseSpaceCollapse.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
new file mode 100644
index 00000000000000..d89b0b192ffcd2
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -0,0 +1,78 @@
+//===- SparseIterationToScf.cpp - sparse iteration to scf conversion -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "Utils/CodegenUtils.h"
+#include "Utils/SparseTensorIterator.h"
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
+                      SmallVectorImpl<Type> &fields) {
+  // Position and coordinate buffer in the sparse structure.
+  if (enc.getLvlType(lvl).isWithPosLT())
+    fields.push_back(enc.getPosMemRefType());
+  if (enc.getLvlType(lvl).isWithCrdLT())
+    fields.push_back(enc.getCrdMemRefType());
+  // One index for the level size (the result of an LvlOp).
+  fields.push_back(IndexType::get(enc.getContext()));
+}
+
+static std::optional<LogicalResult>
+convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
+
+  auto idxTp = IndexType::get(itSp.getContext());
+  for (Level l = itSp.getLoLvl(); l < itSp.getHiLvl(); l++)
+    convertLevelType(itSp.getEncoding(), l, fields);
+
+  // Two indices for the lower and upper bounds (only one pair is needed for
+  // the entire iteration space).
+  fields.append({idxTp, idxTp});
+  return success();
+}
+
+namespace {
+
+/// Sparse codegen rule for the extract_iteration_space operator.
+class ExtractIterSpaceConverter
+    : public OneToNOpConversionPattern<ExtractIterSpaceOp> {
+public:
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ExtractIterSpaceOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+
+    // Construct the iteration space.
+    SparseIterationSpace space(loc, rewriter, op.getTensor(), 0,
+                               op.getLvlRange(), adaptor.getParentIter());
+
+    SmallVector<Value> result = space.toValues();
+    rewriter.replaceOp(op, result, resultMapping);
+    return success();
+  }
+};
+
+} // namespace
+
+mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
+  addConversion([](Type type) { return type; });
+  addConversion(convertIterSpaceType);
+
+  addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
+                              ValueRange inputs,
+                              Location loc) -> std::optional<Value> {
+    return builder
+        .create<UnrealizedConversionCastOp>(loc, TypeRange(spTp), inputs)
+        .getResult(0);
+  });
+}
+
+void mlir::populateLowerSparseIterationToSCFPatterns(
+    TypeConverter &converter, RewritePatternSet &patterns) {
+  patterns.add<ExtractIterSpaceConverter>(converter, patterns.getContext());
+}
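As a worked example of the conversion above (my reading of `convertLevelType`/`convertIterSpaceType`, not normative): `!sparse_tensor.iter_space<#CSR, lvls = 0 to 2>`, with level 0 dense and level 1 compressed, flattens to six values:

```mlir
// level 0 (dense):      index            // level size
// level 1 (compressed): memref<?xindex>  // positions
//                       memref<?xindex>  // coordinates
//                       index            // level size
// space bounds:         index, index     // lo/hi position pair
```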
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index b42d58634a36c4..ffbc85e9a17f5e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -26,6 +26,7 @@ namespace mlir {
 #define GEN_PASS_DEF_SPARSEREINTERPRETMAP
 #define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
 #define GEN_PASS_DEF_SPARSIFICATIONPASS
+#define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF
 #define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
 #define GEN_PASS_DEF_LOWERFOREACHTOSCF
 #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
@@ -153,10 +154,34 @@ struct LowerForeachToSCFPass
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
     populateLowerForeachToSCFPatterns(patterns);
+
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
 
+struct LowerSparseIterationToSCFPass
+    : public impl::LowerSparseIterationToSCFBase<
+          LowerSparseIterationToSCFPass> {
+  LowerSparseIterationToSCFPass() = default;
+  LowerSparseIterationToSCFPass(const LowerSparseIterationToSCFPass &) =
+      default;
+
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    SparseIterationTypeConverter converter;
+    ConversionTarget target(*ctx);
+
+    // The actual conversion.
+    target.addIllegalOp<ExtractIterSpaceOp, IterateOp>();
+    populateLowerSparseIterationToSCFPatterns(converter, patterns);
+
+    if (failed(applyPartialOneToNConversion(getOperation(), converter,
+                                            std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
 struct SparseTensorConversionPass
     : public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
   SparseTensorConversionPass() = default;
@@ -439,6 +464,10 @@ std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
   return std::make_unique<LowerForeachToSCFPass>();
 }
 
+std::unique_ptr<Pass> mlir::createLowerSparseIterationToSCFPass() {
+  return std::make_unique<LowerSparseIterationToSCFPass>();
+}
+
 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
   return std::make_unique<SparseTensorConversionPass>();
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index 59c3e49264dbe1..34312df912997b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -222,7 +222,7 @@ class LoopEmitter {
   ///
   SmallVector<Value> getValPosits(TensorId tid) const {
     SmallVector<Value> batchCrds = iters[tid].back().back()->getBatchCrds();
-    Value lastLvlPos = iters[tid].back().back()->getCurPosition().first;
+    Value lastLvlPos = iters[tid].back().back()->getCurPosition().front();
     batchCrds.push_back(lastLvlPos);
     return batchCrds;
   };
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index 60dca3c55dec3d..9378cb267bf788 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -94,8 +94,10 @@ class DenseLevel : public SparseTensorLevel {
 
   ValueRange getLvlBuffers() const override { return {}; }
 
-  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
-                        Value max) const override {
+  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
+                        ValueRange parentPos) const override {
+    assert(parentPos.size() == 1 && "Dense level cannot be non-unique.");
+    Value p = parentPos.front();
     Value posLo = MULI(p, lvlSize);
     return {posLo, lvlSize};
   }
@@ -112,9 +114,9 @@ class BatchLevel : public SparseTensorLevel {
 
   ValueRange getLvlBuffers() const override { return {}; }
 
-  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
-                        Value max) const override {
-    assert(max == nullptr && "Dense level can not be non-unique.");
+  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange,
+                        ValueRange parentPos) const override {
+    assert(parentPos.size() == 1 && "Dense level cannot be non-unique.");
     // No need to linearize the position for non-annotated tensors.
     return {C_IDX(0), lvlSize};
   }
@@ -127,9 +129,11 @@ class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
       : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        Value p, Value max) const override {
-    assert(max == nullptr &&
+                        ValueRange parentPos) const override {
+    assert(parentPos.size() == 1 &&
            "compressed level must be the first non-unique level.");
+    Value p = parentPos.front();
 
     SmallVector<Value> memCrd(batchPrefix);
     memCrd.push_back(p);
@@ -147,11 +151,11 @@ class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
       : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        Value p, Value max) const override {
-    assert(max == nullptr &&
+                        ValueRange parentPos) const override {
+    assert(parentPos.size() == 1 &&
            "loose-compressed level must be the first non-unique level.");
     SmallVector<Value> memCrd(batchPrefix);
-
+    Value p = parentPos.front();
     p = MULI(p, C_IDX(2));
     memCrd.push_back(p);
     Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
@@ -168,13 +172,23 @@ class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
       : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        Value p, Value segHi) const override {
+                        ValueRange parentPos) const override {
+    assert(parentPos.size() == 1 || parentPos.size() == 2);
+    Value p = parentPos.front();
+    Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;
+
     if (segHi == nullptr)
       return {p, ADDI(p, C_IDX(1))};
-
     // Use the segHi as the loop upper bound.
     return {p, segHi};
   }
+
+  ValuePair
+  collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix,
+                       std::pair<Value, Value> parentRange) const override {
+    // Singleton level keeps the same range after collapsing.
+    return parentRange;
+  }
 };
 
 class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
@@ -184,11 +198,12 @@ class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
       : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        Value p, Value max) const override {
-    assert(max == nullptr && isUnique() && "n:m level can not be non-unique.");
+                        ValueRange parentPos) const override {
+    assert(parentPos.size() == 1 && isUnique() &&
+           "n:m level cannot be non-unique.");
     // Each n:m blk has exactly n specified elements.
     auto n = getN(lt);
-    Value posLo = MULI(p, C_IDX(n));
+    Value posLo = MULI(parentPos.front(), C_IDX(n));
     return {posLo, ADDI(posLo, C_IDX(n))};
   }
 };
@@ -316,23 +331,21 @@ class TrivialIterator : public ConcreteIterator {
       posHi = vs.back();
   };
 
-  ValuePair getCurPosition() const override { return {getItPos(), nullptr}; }
-
   void genInitImpl(OpBuilder &b, Location l,
                    const SparseIterator *parent) override {
 
     if (isBatchIterator() && batchCrds.size() <= stl.lvl)
       batchCrds.resize(stl.lvl + 1, nullptr);
 
-    Value pos = C_IDX(0);
-    Value hi = nullptr;
+    Value c0 = C_IDX(0);
+    ValueRange pPos = c0;
     // If the parent iterator is a batch iterator, we also start from 0 (but
     // on a different batch).
     if (parent && !parent->isBatchIterator())
-      std::tie(pos, hi) = parent->getCurPosition();
+      pPos = parent->getCurPosition();
 
     ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
-    std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pos, hi);
+    std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
     // Seek to the lowest position.
     seek(posLo);
   }
@@ -406,21 +419,19 @@ class DedupIterator : public ConcreteIterator {
     return {b.getIndexType(), b.getIndexType()};
   }
 
-  ValuePair getCurPosition() const override { return {getPos(), getSegHi()}; }
-
   void genInitImpl(OpBuilder &b, Location l,
                    const SparseIterator *parent) override {
+    Value c0 = C_IDX(0);
+    ValueRange pPos = c0;
 
-    Value pos = C_IDX(0);
-    Value hi = nullptr;
     // If the parent iterator is a batch iterator, we also start from 0 (but
     // on a different batch).
     if (parent && !parent->isBatchIterator())
-      std::tie(pos, hi) = parent->getCurPosition();
+      pPos = parent->getCurPosition();
 
     Value posLo;
     ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
-    std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pos, hi);
+    std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
 
     seek({posLo, genSegmentHigh(b, l, posLo)});
   }
@@ -505,7 +516,7 @@ class FilterIterator : public SparseIterator {
 
   SmallVector<Value> serialize() const override { return wrap->serialize(); };
   void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
-  ValuePair getCurPosition() const override { return wrap->getCurPosition(); }
+  ValueRange getCurPosition() const override { return wrap->getCurPosition(); }
 
   void genInitImpl(OpBuilder &b, Location l,
                    const SparseIterator *parent) override {
@@ -756,9 +767,8 @@ class SubSectIterator : public SparseIterator {
   Value upperBound(OpBuilder &b, Location l) const override {
     return subSect.subSectSz;
   }
-  std::pair<Value, Value> getCurPosition() const override {
-    return wrap->getCurPosition();
-  };
+
+  ValueRange getCurPosition() const override { return wrap->getCurPosition(); }
 
   Value getNxLvlTupleId(OpBuilder &b, Location l) const {
     if (randomAccessible()) {
@@ -1328,10 +1338,85 @@ ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
   return getCursor();
 }
 
+//===----------------------------------------------------------------------===//
+// SparseIterationSpace Implementation
+//===----------------------------------------------------------------------===//
+
+mlir::sparse_tensor::SparseIterationSpace::SparseIterationSpace(
+    Location l, OpBuilder &b, Value t, unsigned tid,
+    std::pair<Level, Level> lvlRange, ValueRange parentPos)
+    : lvls() {
+  auto [lvlLo, lvlHi] = lvlRange;
+
+  Value c0 = C_IDX(0);
+  if (parentPos.empty())
+    parentPos = c0;
+
+  for (Level lvl = lvlLo; lvl < lvlHi; lvl++)
+    lvls.emplace_back(makeSparseTensorLevel(b, l, t, tid, lvl));
+
+  bound = lvls.front()->peekRangeAt(b, l, /*batchPrefix=*/{}, parentPos);
+  for (auto &lvl : getLvlRef().drop_front())
+    bound = lvl->collapseRangeBetween(b, l, /*batchPrefix=*/{}, bound);
+}
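+
+// For reference, a minimal (hypothetical) use of the constructor above,
+// covering levels [0, 2) of `tensor` with no parent position; `loc`,
+// `rewriter`, and `tensor` are illustrative names only:
+//
+//   SparseIterationSpace space(loc, rewriter, tensor, /*tid=*/0,
+//                              /*lvlRange=*/{0, 2}, /*parentPos=*/{});
+//   Value lo = space.getBoundLo(); // lower bound of the collapsed range
+//   Value hi = space.getBoundHi(); // upper bound of the collapsed range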
+
+SparseIterationSpace mlir::sparse_tensor::SparseIterationSpace::fromValues(
+    IterSpaceType dstTp, ValueRange values, unsigned int tid) {
+  // Reconstruct every sparse tensor level.
+  SparseIterationSpace space;
+  for (auto [i, lt] : llvm::enumerate(dstTp.getLvlTypes())) {
+    unsigned bufferCnt = 0;
+    if (lt.isWithPosLT())
+      bufferCnt++;
+    if (lt.isWithCrdLT())
+      bufferCnt++;
+    // Sparse tensor buffers.
+    ValueRange buffers = values.take_front(bufferCnt);
+    values = values.drop_front(bufferCnt);
+
+    // Level size.
+    Value sz = values.front();
+    values = values.drop_front();
+    space.lvls.push_back(
+        makeSparseTensorLevel(lt, sz, buffers, tid, i + dstTp.getLoLvl()));
+  }
+  // Two bounds.
+  space.bound = std::make_pair(values[0], values[1]);
+  values = values.drop_front(2);
+
+  // Must have consumed all values.
+  assert(values.empty());
+  return space;
+}
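+
+// Round-trip sketch relating `toValues` and `fromValues`, assuming `space`
+// and `dstTp` describe the same set of levels:
+//
+//   SmallVector<Value> flat = space.toValues();
+//   SparseIterationSpace clone =
+//       SparseIterationSpace::fromValues(dstTp, flat, /*tid=*/0);
+//
+// `clone` then refers to the same level buffers, sizes, and bounds.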
+
 //===----------------------------------------------------------------------===//
 // SparseIterator factory functions.
 //===----------------------------------------------------------------------===//
 
+/// Helper function to create a TensorLevel object from the given level type,
+/// level size, and level buffers.
+std::unique_ptr<SparseTensorLevel>
+sparse_tensor::makeSparseTensorLevel(LevelType lt, Value sz, ValueRange b,
+                                     unsigned t, Level l) {
+  assert(lt.getNumBuffer() == b.size());
+  switch (lt.getLvlFmt()) {
+  case LevelFormat::Dense:
+    return std::make_unique<DenseLevel>(t, l, sz);
+  case LevelFormat::Batch:
+    return std::make_unique<BatchLevel>(t, l, sz);
+  case LevelFormat::Compressed:
+    return std::make_unique<CompressedLevel>(t, l, lt, sz, b[0], b[1]);
+  case LevelFormat::LooseCompressed:
+    return std::make_unique<LooseCompressedLevel>(t, l, lt, sz, b[0], b[1]);
+  case LevelFormat::Singleton:
+    return std::make_unique<SingletonLevel>(t, l, lt, sz, b[0]);
+  case LevelFormat::NOutOfM:
+    return std::make_unique<NOutOfMLevel>(t, l, lt, sz, b[0]);
+  case LevelFormat::Undef:
+    llvm_unreachable("undefined level format");
+  }
+  llvm_unreachable("unrecognizable level format");
+}
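+
+// Hypothetical call site for a compressed level, which carries two buffers
+// (`posBuf` and `crdBuf` are illustrative names); the buffer order must
+// match the switch above: positions first, then coordinates:
+//
+//   auto lvl = makeSparseTensorLevel(lt, sz, {posBuf, crdBuf}, tid, l);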
+
 std::unique_ptr<SparseTensorLevel>
 sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
                                      unsigned tid, Level lvl) {
@@ -1341,33 +1426,16 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
   Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult()
                                : b.create<tensor::DimOp>(l, t, lvl).getResult();
 
-  switch (lt.getLvlFmt()) {
-  case LevelFormat::Dense:
-    return std::make_unique<DenseLevel>(tid, lvl, sz);
-  case LevelFormat::Batch:
-    return std::make_unique<BatchLevel>(tid, lvl, sz);
-  case LevelFormat::Compressed: {
-    Value pos = b.create<ToPositionsOp>(l, t, lvl);
-    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
-    return std::make_unique<CompressedLevel>(tid, lvl, lt, sz, pos, crd);
-  }
-  case LevelFormat::LooseCompressed: {
+  SmallVector<Value, 2> buffers;
+  if (lt.isWithPosLT()) {
     Value pos = b.create<ToPositionsOp>(l, t, lvl);
-    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
-    return std::make_unique<LooseCompressedLevel>(tid, lvl, lt, sz, pos, crd);
+    buffers.push_back(pos);
   }
-  case LevelFormat::Singleton: {
-    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
-    return std::make_unique<SingletonLevel>(tid, lvl, lt, sz, crd);
+  if (lt.isWithCrdLT()) {
+    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
+    buffers.push_back(crd);
   }
-  case LevelFormat::NOutOfM: {
-    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
-    return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, crd);
-  }
-  case LevelFormat::Undef:
-    llvm_unreachable("undefined level format");
-  }
-  llvm_unreachable("unrecognizable level format");
+  return makeSparseTensorLevel(lt, sz, buffers, tid, lvl);
 }
 
 std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 9d69a233555986..29e6dcd96c2133 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -15,6 +15,9 @@
 namespace mlir {
 namespace sparse_tensor {
 
+// Forward declaration.
+class SparseIterator;
+
 /// The base class for all types of sparse tensor levels. It provides interfaces
 /// to query the loop range (see `peekRangeAt`) and look up the coordinates (see
 /// `peekCrdAt`).
@@ -49,8 +52,14 @@ class SparseTensorLevel {
   /// The second value in `parentPos` is only used when the level is
   /// `non-unique` and deduplication is required; it specifies the max upper
   /// bound of the non-unique segment.
   virtual std::pair<Value, Value> peekRangeAt(OpBuilder &b, Location l,
-                                              ValueRange batchPrefix, Value p,
-                                              Value segHi = Value()) const = 0;
+                                              ValueRange batchPrefix,
+                                              ValueRange parentPos) const = 0;
+
+  virtual std::pair<Value, Value>
+  collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix,
+                       std::pair<Value, Value> parentRange) const {
+    llvm_unreachable("Not Implemented");
+  };
 
   Level getLevel() const { return lvl; }
   LevelType getLT() const { return lt; }
@@ -80,6 +89,52 @@ enum class IterKind : uint8_t {
   kFilter,
 };
 
+class SparseIterationSpace {
+public:
+  SparseIterationSpace() = default;
+
+  // Constructs an N-D iteration space.
+  SparseIterationSpace(Location loc, OpBuilder &b, Value t, unsigned tid,
+                       std::pair<Level, Level> lvlRange, ValueRange parentPos);
+
+  // Constructs a 1-D iteration space.
+  SparseIterationSpace(Location loc, OpBuilder &b, Value t, unsigned tid,
+                       Level lvl, ValueRange parentPos)
+      : SparseIterationSpace(loc, b, t, tid, {lvl, lvl + 1}, parentPos) {}
+
+  bool isUnique() const { return lvls.back()->isUnique(); }
+
+  unsigned getSpaceDim() const { return lvls.size(); }
+
+  // Reconstructs an iteration space directly from the provided ValueRange.
+  static SparseIterationSpace fromValues(IterSpaceType dstTp, ValueRange values,
+                                         unsigned tid);
+
+  // The inverse operation of `fromValues`.
+  SmallVector<Value> toValues() const {
+    SmallVector<Value> vals;
+    for (auto &stl : lvls) {
+      llvm::append_range(vals, stl->getLvlBuffers());
+      vals.push_back(stl->getSize());
+    }
+    vals.append({bound.first, bound.second});
+    return vals;
+  }
+
+  const SparseTensorLevel &getLastLvl() const { return *lvls.back(); }
+  ArrayRef<std::unique_ptr<SparseTensorLevel>> getLvlRef() const {
+    return lvls;
+  }
+
+  Value getBoundLo() const { return bound.first; }
+  Value getBoundHi() const { return bound.second; }
+
+private:
+  SmallVector<std::unique_ptr<SparseTensorLevel>> lvls;
+  std::pair<Value, Value> bound;
+};
+
 /// Helper class that generates loop conditions, etc, to traverse a
 /// sparse tensor level.
 class SparseIterator {
@@ -208,9 +263,7 @@ class SparseIterator {
   // Not every type of iterator supports the operation, e.g., non-empty
   // subsection iterator does not because it represents a range of coordinates
   // instead of just one.
-  virtual std::pair<Value, Value> getCurPosition() const {
-    llvm_unreachable("unsupported");
-  };
+  virtual ValueRange getCurPosition() const { return getCursor(); };
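+  // (The default returns the full cursor, which coincides with the current
+  // position: {pos} for a trivial iterator, {pos, segHi} for a dedup
+  // iterator; wrapper iterators override this to delegate to the iterator
+  // they wrap.)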
 
   // Returns a pair of values for *upper*, *lower* bound respectively.
   virtual std::pair<Value, Value> genForCond(OpBuilder &b, Location l) {
@@ -288,10 +341,15 @@ std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &builder,
                                                          Location loc, Value t,
                                                          unsigned tid, Level l);
 
-/// Helper function to create a simple SparseIterator object that iterate over
-/// the SparseTensorLevel.
-std::unique_ptr<SparseIterator> makeSimpleIterator(const SparseTensorLevel &stl,
-                                                   SparseEmitStrategy strategy);
+/// Helper function to create a TensorLevel object from the given level type,
+/// level size, and level buffers.
+std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz,
+                                                         ValueRange buffers,
+                                                         unsigned tid, Level l);
+/// Helper function to create a simple SparseIterator object that iterates
+/// over the SparseTensorLevel.
+std::unique_ptr<SparseIterator> makeSimpleIterator(
+    const SparseTensorLevel &stl,
+    SparseEmitStrategy strategy = SparseEmitStrategy::kFunctional);
 
 /// Helper function to create a synthetic SparseIterator object that iterates
 /// over a dense space specified by [0,`sz`).
diff --git a/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
new file mode 100644
index 00000000000000..b51c11ebf8a8c1
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt %s --lower-sparse-iteration-to-scf | FileCheck %s
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+// CHECK-LABEL:   func.func @sparse_1D_space(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x?xf32, #sparse{{[0-9]*}}>) -> !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> {
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[LVL_SIZE:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[C0]] : tensor<?x?xf32, #sparse{{[0-9]*}}>
+// CHECK:           %[[POS_MEM:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK:           %[[CRD_MEM:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK:           %[[POS_LO:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C0]]] : memref<?xindex>
+// CHECK:           %[[POS_HI:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK:           %[[ITER_SPACE:.*]] = builtin.unrealized_conversion_cast %[[POS_MEM]], %[[CRD_MEM]], %[[LVL_SIZE]], %[[POS_LO]], %[[POS_HI]]
+func.func @sparse_1D_space(%sp : tensor<?x?xf32, #COO>) -> !sparse_tensor.iter_space<#COO, lvls = 0> {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf32, #COO>
+  return %l1 : !sparse_tensor.iter_space<#COO, lvls = 0>
+}
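+
+// A space covering both COO levels would be extracted analogously (sketch
+// only, not exercised by this test):
+//
+//   %l2 = sparse_tensor.extract_iteration_space %sp lvls = 0 to 2
+//       : tensor<?x?xf32, #COO>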


