[Mlir-commits] [mlir] [mlir][sparse] implement lowering rules for ExtractIterSpaceOp. (PR #89143)

Peiming Liu llvmlistbot at llvm.org
Tue May 21 14:39:18 PDT 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/89143

>From 221c7a85de33645cc88c980447bf4678000e4748 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 12 Apr 2024 22:03:06 +0000
Subject: [PATCH 1/8] [mlir][sparse] introduce `sparse_tensor.iterate`
 operation

---
 .../Dialect/SparseTensor/IR/SparseTensor.h    |  38 +++
 .../SparseTensor/IR/SparseTensorAttrDefs.td   |  15 ++
 .../SparseTensor/IR/SparseTensorOps.td        | 101 ++++++-
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 255 ++++++++++++++++++
 mlir/test/Dialect/SparseTensor/invalid.mlir   |  57 ++++
 mlir/test/Dialect/SparseTensor/roundtrip.mlir |  28 ++
 .../SparseTensor/sparse_itertion_licm.mlir    |  26 ++
 7 files changed, 519 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 3cf81d2e58f21..239118a6575d8 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -17,9 +17,13 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/TensorEncoding.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+#include "llvm/ADT/bit.h"
+
 //===----------------------------------------------------------------------===//
 //
 // Type aliases to help code be more self-documenting. Unfortunately
@@ -54,6 +58,40 @@ struct COOSegment {
   }
 };
 
+/// A simple wrapper to encode a bitset of defined  (at most 64) levels.
+class LevelSet {
+  uint64_t bits = 0;
+
+public:
+  LevelSet() = default;
+  explicit LevelSet(uint64_t bits) : bits(bits) {}
+  operator uint64_t() const { return bits; }
+
+  LevelSet &set(unsigned i) {
+    assert(i < 64);
+    bits |= 1 << i;
+    return *this;
+  }
+
+  LevelSet &operator|=(LevelSet lhs) {
+    bits |= static_cast<uint64_t>(lhs);
+    return *this;
+  }
+
+  LevelSet &lshift(unsigned offset) {
+    bits = bits << offset;
+    return *this;
+  }
+
+  bool operator[](unsigned i) const {
+    assert(i < 64);
+    return (bits & (1 << i)) != 0;
+  }
+
+  unsigned count() const { return llvm::popcount(bits); }
+  bool empty() const { return bits == 0; }
+};
+
 } // namespace sparse_tensor
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 53dd8e39438cc..c26193af09e81 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -19,6 +19,21 @@ class SparseTensor_Attr<string name,
                         list<Trait> traits = []>
     : AttrDef<SparseTensor_Dialect, name, traits>;
 
+//===----------------------------------------------------------------------===//
+// A simple bitset attribute wrapped over a single int64_t to encode a set of
+// sparse tensor levels.
+//===----------------------------------------------------------------------===//
+
+def LevelSetAttr :
+    TypedAttrBase<
+      I64, "IntegerAttr",
+      And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
+           CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType().isInteger(64)">]>,
+      "LevelSet attribute"> {
+  let returnType = [{::mlir::sparse_tensor::LevelSet}];
+  let convertFromStorage = [{::mlir::sparse_tensor::LevelSet($_self.getValue().getZExtValue())}];
+}
+
 //===----------------------------------------------------------------------===//
 // These attributes are just like `IndexAttr` except that they clarify whether
 // the index refers to a dimension (an axis of the semantic tensor) or a level
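For a concrete reading of how this attribute encodes coordinate usage (a sketch spelling out the bitset arithmetic by hand, with the SSA names borrowed from the tests below):

```mlir
// Over a 2-D iteration space, naming both coordinates sets bits 0 and 1:
//   at(%crd0, %crd1)  =>  crdUsedLvls = 0b11 = 3 : i64
// Skipping the first coordinate with `_` leaves bit 0 clear:
//   at(_, %crd1)      =>  crdUsedLvls = 0b10 = 2 : i64
// When loops are merged (see the collapse pass in a later patch), the inner
// loop's bits are shifted left by the outer space's dimension and OR-ed in:
//   0b01 | (0b01 << 1) = 0b11
```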
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 4e4441c640ed9..0ef03a5afa733 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -15,6 +15,8 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
 include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
 
 //===----------------------------------------------------------------------===//
 // Base class.
@@ -1304,7 +1306,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
 
 def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
     ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
-                 "ForeachOp"]>]> {
+                 "ForeachOp", "IterateOp"]>]> {
   let summary = "Yield from sparse_tensor set-like operations";
   let description = [{
       Yields a value from within a `binary`, `unary`, `reduce`,
@@ -1513,6 +1515,103 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
   let hasVerifier = 1;
 }
 
+def IterateOp : SparseTensor_Op<"iterate",
+    [RecursiveMemoryEffects, RecursivelySpeculatable,
+     DeclareOpInterfaceMethods<LoopLikeOpInterface,
+      ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
+       "getYieldedValuesMutable"]>,
+     DeclareOpInterfaceMethods<RegionBranchOpInterface,
+      ["getEntrySuccessorOperands"]>,
+     SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
+
+  let summary = "Iterate over a sparse iteration space";
+  let description = [{
+      The `sparse_tensor.iterate` operations represents a loop over the
+      provided iteration space extracted from a specific sparse tensor.
+      The operation defines an SSA value for a sparse iterator that points
+      to the current stored element in the sparse tensor and SSA values
+      for coordinates of the stored element. The coordinates are always
+      converted to `index` type regardless of the underlying sparse tensor
+      storage. When coordinates are not used, the SSA values can be skipped
+      by `_` symbols, which usually leads to simpler generated code after
+      sparsification. For example:
+
+      ```mlir
+      // The coordinate for level 0 is not used when iterating over a 2-D
+      // iteration space.
+      sparse_tensor.iterate %iterator in %space at(_, %crd_1)
+        : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
+      ```
+
+      `sparse_tensor.iterate` can also operate on loop-carried variables
+      and returns the final values after loop termination.
+      The initial values of the variables are passed as additional SSA operands,
+      following the iterator SSA value and the used coordinate SSA values
+      mentioned above. The operation region has an argument for the iterator,
+      variadic arguments for the specified (used) coordinates, followed by one
+      argument for each loop-carried variable, representing the value of the variable
+      at the current iteration.
+      The body region must contain exactly one block that terminates with
+      `sparse_tensor.yield`.
+
+      `sparse_tensor.iterate` results hold the final values after the last
+      iteration. If the `sparse_tensor.iterate` defines any values, a yield
+      must be explicitly present.
+      The number and types of the `sparse_tensor.iterate` results must match
+      the initial values in the iter_args binding and the yield operands.
+
+
+      A nested `sparse_tensor.iterate` example that prints all the coordinates
+      stored in the sparse input:
+
+      ```mlir
+      func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
+        // Iterates over the first level of %sp
+        %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+        %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd0)
+            : !sparse_tensor.iter_space<#COO, lvls = 0 to 1>  {
+          // Iterates over the second level of %sp
+          %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
+              : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
+          %r2 = sparse_tensor.iterate %it2 in %l2 at (%crd1)
+              : !sparse_tensor.iter_space<#COO, lvls = 1 to 2>  {
+             vector.print %crd0 : index
+             vector.print %crd1 : index
+          }
+        }
+      }
+
+      ```
+  }];
+
+  let arguments = (ins AnySparseIterSpace:$iterSpace,
+                       Variadic<AnyType>:$initArgs,
+                       LevelSetAttr:$crdUsedLvls);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$region);
+
+  let extraClassDeclaration = [{
+    unsigned getSpaceDim() {
+      return getIterSpace().getType().getSpaceDim();
+    }
+    BlockArgument getIterator() {
+      return getRegion().getArguments().front();
+    }
+    Block::BlockArgListType getCrds() {
+      // The first block argument is iterator, the remaining arguments are
+      // referenced coordinates.
+      return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
+    }
+    unsigned getNumRegionIterArgs() {
+      return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
+    }
+  }];
+
+  let hasVerifier = 1;
+  let hasRegionVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Debugging and Test-Only Operations.
 //===----------------------------------------------------------------------===//
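As a reader aid for the definition above, here is a small self-contained sketch of an `iterate` with one used coordinate and one loop-carried value (the function name and the summation body are illustrative assumptions; the syntax follows the roundtrip test later in this patch):

```mlir
#COO = #sparse_tensor.encoding<{
  map = (i, j) -> (
    i : compressed(nonunique),
    j : singleton(soa)
  )
}>

// Sums the level-0 coordinates of all stored elements.
func.func @sum_coords(%sp : tensor<4x8xf32, #COO>, %init : index) -> index {
  %space = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
  %sum = sparse_tensor.iterate %it in %space at(%crd) iter_args(%acc = %init)
      : !sparse_tensor.iter_space<#COO, lvls = 0> -> index {
    // %acc carries the running sum; the yielded value feeds the next iteration.
    %next = arith.addi %acc, %crd : index
    sparse_tensor.yield %next : index
  }
  return %sum : index
}
```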
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 4adb1c19096a2..0a9254573538d 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2130,6 +2130,106 @@ static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
   printLevelRange(p, lo, hi);
 }
 
+static ParseResult
+parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
+                     SmallVectorImpl<OpAsmParser::Argument> &iterators,
+                     SmallVectorImpl<OpAsmParser::Argument> &iterArgs) {
+  SmallVector<OpAsmParser::UnresolvedOperand> spaces;
+  SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
+
+  // Parses "%iters, ... in %spaces, ..."
+  if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
+      parser.parseOperandList(spaces))
+    return failure();
+
+  if (iterators.size() != spaces.size())
+    return parser.emitError(
+        parser.getNameLoc(),
+        "mismatch in number of sparse iterators and sparse spaces");
+
+  // Parse "at(%crd0, _, ...)"
+  LevelSet crdUsedLvlSet;
+  bool hasUsedCrds = succeeded(parser.parseOptionalKeyword("at"));
+  unsigned lvlCrdCnt = 0;
+  if (hasUsedCrds) {
+    ParseResult crdList = parser.parseCommaSeparatedList(
+        OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
+          if (parser.parseOptionalKeyword("_")) {
+            if (parser.parseArgument(iterArgs.emplace_back()))
+              return failure();
+            // Always use IndexType for the coordinate.
+            crdUsedLvlSet.set(lvlCrdCnt);
+            iterArgs.back().type = parser.getBuilder().getIndexType();
+          }
+          lvlCrdCnt += 1;
+          return success();
+        });
+    if (failed(crdList)) {
+      return parser.emitError(
+          parser.getNameLoc(),
+          "expecting SSA value or \"_\" for level coordinates");
+    }
+  }
+  // Set the CrdUsedLvl bitset.
+  state.addAttribute("crdUsedLvls",
+                     parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
+
+  // Parse "iter_args(%arg = %init, ...)"
+  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+  if (hasIterArgs)
+    if (parser.parseAssignmentList(iterArgs, initArgs))
+      return failure();
+
+  SmallVector<Type> iterSpaceTps;
+  // Parses ": sparse_tensor.iter_space -> ret".
+  if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
+    return failure();
+  if (iterSpaceTps.size() != spaces.size())
+    return parser.emitError(parser.getNameLoc(),
+                            "mismatch in number of iteration space operands "
+                            "and iteration space types");
+
+  for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
+    IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
+    if (!spaceTp)
+      return parser.emitError(parser.getNameLoc(),
+                              "expected sparse_tensor.iter_space type for "
+                              "iteration space operands");
+    if (hasUsedCrds && spaceTp.getSpaceDim() != lvlCrdCnt)
+      return parser.emitError(parser.getNameLoc(),
+                              "mismatch in number of iteration space dimension "
+                              "and specified coordinates");
+    it.type = spaceTp.getIteratorType();
+  }
+
+  if (hasIterArgs)
+    if (parser.parseArrowTypeList(state.types))
+      return failure();
+
+  // Resolves input operands.
+  if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
+                             state.operands))
+    return failure();
+
+  if (hasIterArgs) {
+    unsigned numCrds = crdUsedLvlSet.count();
+    // Strip off the leading args that are used for coordinates.
+    MutableArrayRef args = MutableArrayRef(iterArgs).drop_front(numCrds);
+    if (args.size() != initArgs.size() || args.size() != state.types.size()) {
+      return parser.emitError(
+          parser.getNameLoc(),
+          "mismatch in number of iteration arguments and return values");
+    }
+
+    for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
+      it.type = tp;
+      if (parser.resolveOperand(init, tp, state.operands))
+        return failure();
+    }
+  }
+  return success();
+}
+
 LogicalResult ExtractIterSpaceOp::inferReturnTypes(
     MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
     DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
@@ -2166,6 +2266,161 @@ LogicalResult ExtractIterSpaceOp::verify() {
   return success();
 }
 
+ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::Argument iterator;
+  OpAsmParser::UnresolvedOperand iterSpace;
+
+  SmallVector<OpAsmParser::Argument> iters, iterArgs;
+  if (parseSparseSpaceLoop(parser, result, iters, iterArgs))
+    return failure();
+  if (iters.size() != 1)
+    return parser.emitError(parser.getNameLoc(),
+                            "expected only one iterator/iteration space");
+
+  iters.append(iterArgs);
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, iters))
+    return failure();
+
+  IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  return success();
+}
+
+/// Prints the initialization list in the form of
+///   <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
+/// where 'inner' values are assumed to be region arguments and 'outer' values
+/// are regular SSA values.
+static void printInitializationList(OpAsmPrinter &p,
+                                    Block::BlockArgListType blocksArgs,
+                                    ValueRange initializers,
+                                    StringRef prefix = "") {
+  assert(blocksArgs.size() == initializers.size() &&
+         "expected same length of arguments and initializers");
+  if (initializers.empty())
+    return;
+
+  p << prefix << '(';
+  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
+    p << std::get<0>(it) << " = " << std::get<1>(it);
+  });
+  p << ")";
+}
+
+static void printUsedCrdsList(OpAsmPrinter &p, unsigned spaceDim,
+                              Block::BlockArgListType blocksArgs,
+                              LevelSet crdUsedLvls) {
+  if (crdUsedLvls.empty())
+    return;
+
+  p << " at(";
+  for (unsigned i = 0; i < spaceDim; i++) {
+    if (crdUsedLvls[i]) {
+      p << blocksArgs.front();
+      blocksArgs = blocksArgs.drop_front();
+    } else {
+      p << "_";
+    }
+    if (i != spaceDim - 1)
+      p << ", ";
+  }
+  assert(blocksArgs.empty());
+  p << ")";
+}
+
+void IterateOp::print(OpAsmPrinter &p) {
+  p << " " << getIterator() << " in " << getIterSpace();
+  printUsedCrdsList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
+  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
+
+  p << " : " << getIterSpace().getType() << " ";
+  if (!getInitArgs().empty())
+    p << "-> (" << getInitArgs().getTypes() << ") ";
+
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/!getInitArgs().empty());
+}
+
+LogicalResult IterateOp::verify() {
+  if (getInitArgs().size() != getNumResults()) {
+    return emitOpError(
+        "mismatch in number of loop-carried values and defined values");
+  }
+  return success();
+}
+
+LogicalResult IterateOp::verifyRegions() {
+  if (getIterator().getType() != getIterSpace().getType().getIteratorType())
+    return emitOpError("mismatch in iterator and iteration space type");
+  if (getNumRegionIterArgs() != getNumResults())
+    return emitOpError(
+        "mismatch in number of basic block args and defined values");
+
+  auto initArgs = getInitArgs();
+  auto iterArgs = getRegionIterArgs();
+  auto yieldVals = getYieldedValues();
+  auto opResults = getResults();
+  if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
+                        opResults.size()})) {
+    return emitOpError() << "number mismatch between iter args and results.";
+  }
+
+  for (auto [i, init, iter, yield, ret] :
+       llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
+    if (init.getType() != ret.getType())
+      return emitOpError() << "types mismatch between " << i
+                           << "th iter operand and defined value";
+    if (iter.getType() != ret.getType())
+      return emitOpError() << "types mismatch between " << i
+                           << "th iter region arg and defined value";
+    if (yield.getType() != ret.getType())
+      return emitOpError() << "types mismatch between " << i
+                           << "th yield value and defined value";
+  }
+
+  return success();
+}
+
+/// IterateOp implemented OpInterfaces' methods.
+SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
+
+MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
+  return getInitArgsMutable();
+}
+
+Block::BlockArgListType IterateOp::getRegionIterArgs() {
+  return getRegion().getArguments().take_back(getNumRegionIterArgs());
+}
+
+std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
+  return cast<sparse_tensor::YieldOp>(
+             getRegion().getBlocks().front().getTerminator())
+      .getResultsMutable();
+}
+
+std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
+
+OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  return getInitArgs();
+}
+
+void IterateOp::getSuccessorRegions(RegionBranchPoint point,
+                                    SmallVectorImpl<RegionSuccessor> &regions) {
+  // Both the operation itself and the region may branch into the body or back
+  // to the operation itself.
+  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
+  // It is possible for the loop not to enter the body.
+  regions.push_back(RegionSuccessor(getResults()));
+}
+
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Dialect Setups.
+//===----------------------------------------------------------------------===//
+
 /// Materialize a single constant operation from a given attribute value with
 /// the desired resultant type.
 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
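To summarize the custom assembly implemented above, `parseSparseSpaceLoop` accepts roughly the following shape (a sketch; `#COO`, `%space`, and the other names are placeholders). Each named coordinate sets its level's bit in `crdUsedLvls`, a `_` skips that level, and the region block receives the iterator first, then the used coordinates, then the iter_args:

```mlir
// %iter in %space [ at(%crd-or-_, ...) ] [ iter_args(%arg = %init, ...) ]
//   : <iter_space types> [ -> <result types> ]
%r = sparse_tensor.iterate %it in %space at(%crd0, _) iter_args(%acc = %i)
    : !sparse_tensor.iter_space<#COO, lvls = 0 to 2> -> index {
  sparse_tensor.yield %acc : index
}
```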
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 3fa696e1600a9..b13024cd4ed99 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -1094,3 +1094,60 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
   %l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
   return
 }
+
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  // expected-error @+1 {{'sparse_tensor.iterate' op different number of region iter_args and yielded values: 2 != 1}}
+  %r1, %r2 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%si = %i, %sj = %j): !sparse_tensor.iter_space<#COO, lvls = 0> -> (index, index) {
+    sparse_tensor.yield %si : index
+  }
+  return %r1 : index
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+// expected-note at +1 {{prior use here}}
+func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index) -> f32 {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  // expected-error @+1 {{use of value '%i' expects different type than prior uses: 'f32' vs 'index'}}
+  %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0> -> f32 {
+    sparse_tensor.yield %outer : f32
+  }
+  return %r1 : f32
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  // expected-error @+1 {{'sparse_tensor.iterate' op 0-th region iter_arg and 0-th yielded value have different type: 'index' != 'f32'}}
+  %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%si = %i): !sparse_tensor.iter_space<#COO, lvls = 0> -> index {
+    %y = arith.constant 1.0 :  f32
+    sparse_tensor.yield %y : f32
+  }
+  return %r1 : index
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index d34071279e512..e9a898f16b41d 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -763,3 +763,31 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
   %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
   return %l1, %l2 : !sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>
 }
+
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+// CHECK-LABEL:   func.func @sparse_iterate(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<4x8xf32, #sparse{{[0-9]*}}>,
+// CHECK-SAME:      %[[VAL_1:.*]]: index,
+// CHECK-SAME:      %[[VAL_2:.*]]: index) -> index {
+// CHECK:           %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0 : tensor<4x8xf32, #sparse{{[0-9]*}}>
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] at(%[[VAL_6:.*]]) iter_args(%[[VAL_7:.*]] = %[[VAL_1]]) : !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> -> (index) {
+// CHECK:             sparse_tensor.yield %[[VAL_7]] : index
+// CHECK:           }
+// CHECK:           return %[[VAL_4]] : index
+// CHECK:         }
+func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
+    sparse_tensor.yield %outer : index
+  }
+  return %r1 : index
+}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir b/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir
new file mode 100644
index 0000000000000..e7158d04b37fe
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s --loop-invariant-code-motion | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : dense,
+    j : compressed
+  )
+}>
+
+// Make sure that pure instructions are hoisted outside the loop.
+//
+// CHECK: sparse_tensor.values
+// CHECK: sparse_tensor.positions
+// CHECK: sparse_tensor.coordinate
+// CHECK: sparse_tensor.iterate
+func.func @sparse_iterate(%sp : tensor<?x?xf64, #CSR>) {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf64, #CSR>
+  sparse_tensor.iterate %it1 in %l1 at (%crd) : !sparse_tensor.iter_space<#CSR, lvls = 0> {
+    %0 = sparse_tensor.values %sp : tensor<?x?xf64, #CSR> to memref<?xf64>
+    %1 = sparse_tensor.positions %sp { level = 1 : index } : tensor<?x?xf64, #CSR> to memref<?xindex>
+    %2 = sparse_tensor.coordinates  %sp { level = 1 : index } : tensor<?x?xf64, #CSR> to memref<?xindex>
+    "test.op"(%0, %1, %2) : (memref<?xf64>, memref<?xindex>, memref<?xindex>) -> ()
+  }
+
+  return
+}
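For what the CHECK lines above assert, the hoisted result should look roughly like this (a hand-written sketch, not actual pass output):

```mlir
// All three pure queries end up above the loop; only the impure test.op stays.
%v = sparse_tensor.values %sp : tensor<?x?xf64, #CSR> to memref<?xf64>
%p = sparse_tensor.positions %sp { level = 1 : index } : tensor<?x?xf64, #CSR> to memref<?xindex>
%c = sparse_tensor.coordinates %sp { level = 1 : index } : tensor<?x?xf64, #CSR> to memref<?xindex>
sparse_tensor.iterate %it1 in %l1 at (%crd) : !sparse_tensor.iter_space<#CSR, lvls = 0> {
  "test.op"(%v, %p, %c) : (memref<?xf64>, memref<?xindex>, memref<?xindex>) -> ()
}
```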

>From c6368bc924503841281ef539adbb084ac7cff530 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 17 Apr 2024 21:50:31 +0000
Subject: [PATCH 2/8] address comments

---
 mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h      | 4 ++--
 .../mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td      | 2 +-
 mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td  | 4 ++--
 mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp      | 2 +-
 4 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 239118a6575d8..338adf17be891 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -58,7 +58,7 @@ struct COOSegment {
   }
 };
 
-/// A simple wrapper to encode a bitset of defined  (at most 64) levels.
+/// A simple wrapper to encode a bitset of defined (at most 64) levels.
 class LevelSet {
   uint64_t bits = 0;
 
@@ -69,7 +69,7 @@ class LevelSet {
 
   LevelSet &set(unsigned i) {
     assert(i < 64);
-    bits |= 1 << i;
+    bits |= static_cast<uint64_t>(0x01u) << i;
     return *this;
   }
 
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index c26193af09e81..69b212cce4ceb 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -20,7 +20,7 @@ class SparseTensor_Attr<string name,
     : AttrDef<SparseTensor_Dialect, name, traits>;
 
 //===----------------------------------------------------------------------===//
-// A simple bitset attribute wrapped over a single int64_t to encode a set of
+// A simple bitset attribute wrapped around a single int64_t to encode a set of
 // sparse tensor levels.
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 0ef03a5afa733..735a35c139834 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1526,8 +1526,8 @@ def IterateOp : SparseTensor_Op<"iterate",
 
   let summary = "Iterate over a sparse iteration space";
   let description = [{
-      The `sparse_tensor.iterate` operations represents a loop over the
-      provided iteration space extracted from a specific sparse tensor.
+      The `sparse_tensor.iterate` operation represents a loop (nest) over
+      the provided iteration space extracted from a specific sparse tensor.
       The operation defines an SSA value for a sparse iterator that points
       to the current stored element in the sparse tensor and SSA values
       for coordinates of the stored element. The coordinates are always
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 0a9254573538d..b5b92f9bfa0d2 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2385,7 +2385,7 @@ LogicalResult IterateOp::verifyRegions() {
   return success();
 }
 
-/// IterateOp implemented OpInterfaces' methods.
+/// OpInterfaces' methods implemented by IterateOp.
 SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
 
 MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {

>From d590b70e925c4da5d4097356e83ac73bc52d9a51 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 20 May 2024 21:00:26 +0000
Subject: [PATCH 3/8] use coord instead of crd

---
 .../mlir/Dialect/SparseTensor/IR/SparseTensorOps.td       | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 735a35c139834..2cc367931ce45 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1568,15 +1568,15 @@ def IterateOp : SparseTensor_Op<"iterate",
       func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
         // Iterates over the first level of %sp
         %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
-        %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd0)
+        %r1 = sparse_tensor.iterate %it1 in %l1 at (%coord0)
             : !sparse_tensor.iter_space<#COO, lvls = 0 to 1>  {
           // Iterates over the second level of %sp
           %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
               : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
-          %r2 = sparse_tensor.iterate %it2 in %l2 at (%crd1)
+          %r2 = sparse_tensor.iterate %it2 in %l2 at (%coord1)
               : !sparse_tensor.iter_space<#COO, lvls = 1 to 2>  {
-             vector.print %crd0 : index
-             vector.print %crd1 : index
+             vector.print %coord0 : index
+             vector.print %coord1 : index
           }
         }
       }

>From 777b98d2d9d6c4f9c8123593eed2c726d94cf0e2 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 20 May 2024 21:25:46 +0000
Subject: [PATCH 4/8] make iter_space type explicit

---
 .../mlir/Dialect/SparseTensor/IR/SparseTensorOps.td  | 12 ++++++------
 mlir/test/Dialect/SparseTensor/invalid.mlir          | 12 ++++++++----
 mlir/test/Dialect/SparseTensor/roundtrip.mlir        |  5 +++--
 .../Dialect/SparseTensor/sparse_itertion_licm.mlir   |  1 +
 4 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 2cc367931ce45..6b337d58cf297 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1510,7 +1510,7 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
                        LevelAttr:$loLvl, LevelAttr:$hiLvl);
   let results = (outs AnySparseIterSpace:$resultSpace);
   let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
-                       " attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
+                       " attr-dict `:` type($tensor) (`,` type($parentIter)^)? `->` type($resultSpace)";
 
   let hasVerifier = 1;
 }
@@ -1543,8 +1543,8 @@ def IterateOp : SparseTensor_Op<"iterate",
         : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
       ```
 
-      `sparse_tensor.iterate` can also operate on loop-carried variables
-      and returns the final values after loop termination.
+      `sparse_tensor.iterate` can also operate on loop-carried variables.
+      It returns the final values after loop termination.
       The initial values of the variables are passed as additional SSA operands
       to the iterator SSA value and used coordinate SSA values mentioned
       above. The operation region has an argument for the iterator, variadic
@@ -1554,9 +1554,9 @@ def IterateOp : SparseTensor_Op<"iterate",
       The body region must contain exactly one block that terminates with
       `sparse_tensor.yield`.
 
-      `sparse_tensor.iterate` results hold the final values after the last
-      iteration. If the `sparse_tensor.iterate` defines any values, a yield
-      must be explicitly present.
+      The results of a `sparse_tensor.iterate` hold the final values after
+      the last iteration. If the `sparse_tensor.iterate` defines any values,
+      a yield must be explicitly present.
       The number and types of the `sparse_tensor.iterate` results must match
       the initial values in the iter_args binding and the yield operands.
 
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index b13024cd4ed99..eb0dc01be25b9 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -1025,6 +1025,7 @@ func.func @sparse_print(%arg0: tensor<10x10xf64>) {
 func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 2>) {
   // expected-error at +1 {{'sparse_tensor.extract_iteration_space' expect larger level upper bound than lower bound}}
   %l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 to 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 2>
+                                                                       -> !sparse_tensor.iter_space<#COO, lvls = 0 to 2>
   return
 }
 
@@ -1040,6 +1041,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
 func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
   // expected-error at +1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be specified iff level lower bound equals 0}}
   %l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+                                                                  -> !sparse_tensor.iter_space<#COO, lvls = 1>
   return
 }
 
@@ -1054,7 +1056,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
 
 func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>) {
   // expected-error at +1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be specified iff level lower bound equals 0}}
-  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 1 : tensor<4x8xf32, #COO>
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 1 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 1>
   return
 }
 
@@ -1077,6 +1079,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>) {
 func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#CSR, lvls = 0>) {
   // expected-error at +1 {{'sparse_tensor.extract_iteration_space' op mismatch in parent iterator encoding and iteration space encoding.}}
   %l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#CSR, lvls = 0>
+                                                                 -> !sparse_tensor.iter_space<#COO, lvls = 1>
   return
 }
 
@@ -1092,6 +1095,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
 func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
   // expected-error at +1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be used to extract an iteration space from a consecutive level.}}
   %l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+                                                                  -> !sparse_tensor.iter_space<#COO, lvls = 2>
   return
 }
 
@@ -1106,7 +1110,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
 }>
 
 func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
-  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
   // expected-error @+1 {{'sparse_tensor.iterate' op different number of region iter_args and yielded values: 2 != 1}}
   %r1, %r2 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%si = %i, %sj = %j): !sparse_tensor.iter_space<#COO, lvls = 0> -> (index, index) {
     sparse_tensor.yield %si : index
@@ -1125,7 +1129,7 @@ func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -
 
 // expected-note at +1 {{prior use here}}
 func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index) -> f32 {
-  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
   // expected-error @+1 {{use of value '%i' expects different type than prior uses: 'f32' vs 'index'}}
   %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0> -> f32 {
     sparse_tensor.yield %outer : f32
@@ -1143,7 +1147,7 @@ func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index) -> f32 {
 }>
 
 func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
-  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
   // expected-error @+1 {{'sparse_tensor.iterate' op 0-th region iter_arg and 0-th yielded value have different type: 'index' != 'f32'}}
   %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%si = %i): !sparse_tensor.iter_space<#COO, lvls = 0> -> index {
     %y = arith.constant 1.0 :  f32
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index e9a898f16b41d..bce0b41a99828 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -758,9 +758,10 @@ func.func @sparse_has_runtime() -> i1 {
 func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>)
   -> (!sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>) {
   // Extracting the iteration space for the first level needs no parent iterator.
-  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
   // Extracting the iteration space for the second level needs a parent iterator.
   %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+                                                                 -> !sparse_tensor.iter_space<#COO, lvls = 1>
   return %l1, %l2 : !sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>
 }
 
@@ -785,7 +786,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
 // CHECK:           return %[[VAL_4]] : index
 // CHECK:         }
 func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
-  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
   %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
     sparse_tensor.yield %outer : index
   }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir b/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir
index e7158d04b37fe..f70fab3b7251d 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir
@@ -15,6 +15,7 @@
 // CHECK: sparse_tensor.iterate
 func.func @sparse_iterate(%sp : tensor<?x?xf64, #CSR>) {
   %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf64, #CSR>
+                                                         -> !sparse_tensor.iter_space<#CSR, lvls = 0>
   sparse_tensor.iterate %it1 in %l1 at (%crd) : !sparse_tensor.iter_space<#CSR, lvls = 0> {
     %0 = sparse_tensor.values %sp : tensor<?x?xf64, #CSR> to memref<?xf64>
     %1 = sparse_tensor.positions %sp { level = 1 : index } : tensor<?x?xf64, #CSR> to memref<?xindex>

>From 2a5fe9639d325d9138948fd4d93369dd424646ae Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 21 May 2024 17:17:53 +0000
Subject: [PATCH 5/8] update example code

---
 .../Dialect/SparseTensor/IR/SparseTensorOps.td     | 14 +++++++++-----
 .../SparseTensor/IR/SparseTensorDialect.cpp        |  2 +-
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 6b337d58cf297..a99a0c00693cc 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1478,7 +1478,7 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
       the returned iteration space covers. `hiLvl - loLvl` defines the dimension of the
       iteration space.
 
-      The type of returned the value is automatically inferred to
+      The type of the returned value must be
       `!sparse_tensor.iter_space<#INPUT_ENCODING, lvls = $loLvl to $hiLvl>`.
       The returned iteration space can then be iterated over by
       `sparse_tensor.iterate` operations to visit every stored element
@@ -1489,6 +1489,7 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
       // Extracts a 1-D iteration space from a COO tensor at level 1.
       %space = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
         : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+        -> !sparse_tensor.iter_space<#COO, lvls = 1>
       ```
   }];
 
@@ -1501,16 +1502,17 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
       return getHiLvl() - getLoLvl();
     }
     ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() {
-      return getResultSpace().getType().getLvlTypes();
+      return getExtractedSpace().getType().getLvlTypes();
     }
   }];
 
   let arguments = (ins AnySparseTensor:$tensor,
                        Optional<AnySparseIterator>:$parentIter,
                        LevelAttr:$loLvl, LevelAttr:$hiLvl);
-  let results = (outs AnySparseIterSpace:$resultSpace);
+  let results = (outs AnySparseIterSpace:$extractedSpace);
   let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
-                       " attr-dict `:` type($tensor) (`,` type($parentIter)^)? `->` type($resultSpace)";
+                       " attr-dict `:` type($tensor) (`,` type($parentIter)^)? "
+                       "`->` qualified(type($extractedSpace))";
 
   let hasVerifier = 1;
 }
@@ -1567,12 +1569,14 @@ def IterateOp : SparseTensor_Op<"iterate",
       ```mlir
       func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
         // Iterates over the first level of %sp
-        %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+        %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
+            : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0 to 1>
         %r1 = sparse_tensor.iterate %it1 in %l1 at (%coord0)
             : !sparse_tensor.iter_space<#COO, lvls = 0 to 1>  {
           // Iterates over the second level of %sp
           %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
               : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
+             -> !sparse_tensor.iter_space<#COO, lvls = 1 to 2>
           %r2 = sparse_tensor.iterate %it2 in %l2 at (%coord1)
               : !sparse_tensor.iter_space<#COO, lvls = 1 to 2>  {
              vector.print %coord0 : index
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index b5b92f9bfa0d2..d2de6bc0dbb85 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2253,7 +2253,7 @@ LogicalResult ExtractIterSpaceOp::verify() {
   }
 
   if (pIter) {
-    IterSpaceType spaceTp = getResultSpace().getType();
+    IterSpaceType spaceTp = getExtractedSpace().getType();
     if (pIter.getType().getEncoding() != spaceTp.getEncoding())
       return emitOpError(
           "mismatch in parent iterator encoding and iteration space encoding.");

>From aa1785cac0d8705e8bd02970c24356ef4463d561 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 16 Apr 2024 17:00:45 +0000
Subject: [PATCH 6/8] [mlir][sparse] implement sparse space collapse pass.

---
 .../Dialect/SparseTensor/Transforms/Passes.h  |   6 +
 .../Dialect/SparseTensor/Transforms/Passes.td |  16 ++
 .../SparseTensor/Transforms/CMakeLists.txt    |   1 +
 .../Transforms/SparseSpaceCollapse.cpp        | 183 ++++++++++++++++++
 .../SparseTensor/sparse_space_collapse.mlir   |  33 ++++
 5 files changed, 239 insertions(+)
 create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
 create mode 100644 mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index bb49d6c256f21..ff29f0a2c219c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -254,6 +254,12 @@ std::unique_ptr<Pass> createSparsificationAndBufferizationPass(
     bool enableBufferInitialization, unsigned vectorLength,
     bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen);
 
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<Pass> createSparseSpaceCollapsePass();
+
 //===----------------------------------------------------------------------===//
 // Registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 94c3ca60030ee..3f9fc54260503 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -500,4 +500,20 @@ def SparsificationAndBufferization : Pass<"sparsification-and-bufferization", "M
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
+  let summary = "(experimental) sparse space collpasing pass";
+  let description = [{
+     This pass collapses consecutive sparse iteration spaces (extracted from
+     the same tensor) into one multi-dimensional iteration space. The pass is
+     not yet stabilized.
+  }];
+  let constructor = "mlir::createSparseSpaceCollapsePass()";
+  let dependentDialects = [
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index af3a1b48f45af..2a29ee8a7a87c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   SparseGPUCodegen.cpp
   SparseReinterpretMap.cpp
   SparseStorageSpecifierToLLVM.cpp
+  SparseSpaceCollapse.cpp
   SparseTensorCodegen.cpp
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
new file mode 100644
index 0000000000000..bc469992d9710
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -0,0 +1,183 @@
+//===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
+namespace mlir {
+
+#define GEN_PASS_DEF_SPARSESPACECOLLAPSE
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
+
+namespace sparse_tensor {
+
+struct CollapseSpaceInfo {
+  ExtractIterSpaceOp space;
+  IterateOp loop;
+};
+
+bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) {
+  auto pIterArgs = parent.getRegionIterArgs();
+  auto nInitArgs = node.getInits();
+  if (pIterArgs.size() != nInitArgs.size())
+    return false;
+
+  // Two loops are collapsable if they are perfectly nested.
+  auto pYields = parent.getYieldedValues();
+  auto nResult = node.getLoopResults().value();
+
+  bool yieldEq =
+      llvm::all_of(llvm::zip_equal(pYields, nResult), [](auto zipped) {
+        return std::get<0>(zipped) == std::get<1>(zipped);
+      });
+
+  // Parent iter_args should be passed directly to the node's init_args.
+  bool iterArgEq =
+      llvm::all_of(llvm::zip_equal(pIterArgs, nInitArgs), [](auto zipped) {
+        return std::get<0>(zipped) == std::get<1>(zipped);
+      });
+
+  return yieldEq && iterArgEq;
+}
+
+bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
+                     ExtractIterSpaceOp curSpace) {
+
+  auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp {
+    Value spaceVal = space.getResultSpace();
+    if (spaceVal.hasOneUse())
+      return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin());
+    return nullptr;
+  };
+
+  if (toCollapse.empty()) {
+    // Collapse root.
+    if (auto itOp = getIterateOpOverSpace(curSpace)) {
+      CollapseSpaceInfo &info = toCollapse.emplace_back();
+      info.space = curSpace;
+      info.loop = itOp;
+      return true;
+    }
+    return false;
+  }
+
+  auto parent = toCollapse.back().space;
+  auto pItOp = toCollapse.back().loop;
+  auto nItOp = getIterateOpOverSpace(curSpace);
+
+  // Can only collapse spaces extracted from the same tensor.
+  if (parent.getTensor() != curSpace.getTensor())
+    return false;
+
+  // Can only collapse consecutive simple iterations over one tensor (i.e., no
+  // co-iteration).
+  if (!nItOp || nItOp->getBlock() != curSpace->getBlock() ||
+      pItOp.getIterator() != curSpace.getParentIter() ||
+      curSpace->getParentOp() != pItOp.getOperation())
+    return false;
+
+  if (pItOp && !isCollapsableLoops(pItOp, nItOp))
+    return false;
+
+  CollapseSpaceInfo &info = toCollapse.emplace_back();
+  info.space = curSpace;
+  info.loop = nItOp;
+  return true;
+}
+
+void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) {
+  if (toCollapse.size() < 2)
+    return;
+
+  ExtractIterSpaceOp root = toCollapse.front().space;
+  ExtractIterSpaceOp leaf = toCollapse.back().space;
+  Location loc = root.getLoc();
+
+  assert(root->hasOneUse() && leaf->hasOneUse());
+
+  // Insert collapsed operation at the same scope as root operation.
+  OpBuilder builder(root);
+
+  // Construct the collapsed iteration space.
+  auto collapsedSpace = builder.create<ExtractIterSpaceOp>(
+      loc, root.getTensor(), root.getParentIter(), root.getLoLvl(),
+      leaf.getHiLvl());
+
+  auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
+  auto innermost = toCollapse.back().loop;
+
+  IRMapping mapper;
+  mapper.map(leaf, collapsedSpace.getResultSpace());
+  for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs()))
+    mapper.map(std::get<0>(z), std::get<1>(z));
+
+  auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper));
+  builder.setInsertionPointToStart(cloned.getBody());
+
+  LevelSet crdUsedLvls;
+  unsigned shift = 0, argIdx = 1;
+  for (auto info : toCollapse.drop_back()) {
+    LevelSet set = info.loop.getCrdUsedLvls();
+    crdUsedLvls |= set.lshift(shift);
+    shift += info.loop.getSpaceDim();
+    for (BlockArgument crd : info.loop.getCrds()) {
+      BlockArgument collapsedCrd = cloned.getBody()->insertArgument(
+          argIdx++, builder.getIndexType(), crd.getLoc());
+      crd.replaceAllUsesWith(collapsedCrd);
+    }
+  }
+  crdUsedLvls |= innermost.getCrdUsedLvls().lshift(shift);
+  cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
+  cloned.setCrdUsedLvls(crdUsedLvls);
+
+  rItOp.replaceAllUsesWith(cloned.getResults());
+  // Erase collapsed loops.
+  rItOp.erase();
+  root.erase();
+}
+
+struct SparseSpaceCollapsePass
+    : public impl::SparseSpaceCollapseBase<SparseSpaceCollapsePass> {
+  SparseSpaceCollapsePass() = default;
+
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+
+    // A naive (experimental) implementation to collapse consecutive sparse
+    // spaces. It does NOT handle complex cases where multiple spaces are
+    // extracted in the same basic block. E.g.,
+    //
+    // %space1 = extract_space %t1 ...
+    // %space2 = extract_space %t2 ...
+    // sparse_tensor.iterate(%sp1) ...
+    //
+    SmallVector<CollapseSpaceInfo> toCollapse;
+    func->walk([&](ExtractIterSpaceOp op) {
+      if (!legalToCollapse(toCollapse, op)) {
+        // If it is not legal to collapse one more space, collapse the existing
+        // ones and clear the list.
+        collapseSparseSpace(toCollapse);
+        toCollapse.clear();
+      }
+    });
+
+    collapseSparseSpace(toCollapse);
+  }
+};
+
+} // namespace sparse_tensor
+
+std::unique_ptr<Pass> createSparseSpaceCollapsePass() {
+  return std::make_unique<sparse_tensor::SparseSpaceCollapsePass>();
+}
+
+} // namespace mlir
diff --git a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
new file mode 100644
index 0000000000000..392dfe01884ba
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s --sparse-space-collapse | FileCheck %s
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+// CHECK-LABEL:   func.func @sparse_sparse_collapse(
+// CHECK-SAME:         %[[VAL_0:.*]]: tensor<4x8xf32, #sparse>,
+// CHECK-SAME:         %[[VAL_1:.*]]: index) {
+// CHECK:           %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0 to 2 : tensor<4x8xf32, #sparse>
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] at(%[[VAL_6:.*]], _) iter_args(%[[VAL_7:.*]] = %[[VAL_1]])
+// CHECK:             %[[VAL_8:.*]] = "test.op"(%[[VAL_7]]) : (index) -> index
+// CHECK:             sparse_tensor.yield %[[VAL_8]] : index
+// CHECK:           }
+// CHECK:           "test.sink"(%[[VAL_4]]) : (index) -> ()
+// CHECK:           return
+// CHECK:         }
+func.func @sparse_sparse_collapse(%sp : tensor<4x8xf32, #COO>, %i : index) {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  %r1 = sparse_tensor.iterate %it1 in %l1 at(%crd0) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
+    %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
+    %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
+      %k ="test.op"(%inner) : (index) -> index
+      sparse_tensor.yield %k : index
+    }
+    sparse_tensor.yield %r2 : index
+  }
+  "test.sink"(%r1) : (index) -> ()
+  return
+}

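As a hand-written sketch of the transformation exercised by this test (the
CHECK lines above show the actual output), two perfectly nested iterations
over consecutive level ranges collapse into one iteration over the merged
range:

    // Before: nested iteration over lvls 0 and 1.
    %s0 = sparse_tensor.extract_iteration_space %sp lvls = 0 ...
    %r = sparse_tensor.iterate %it0 in %s0 ... {
      %s1 = sparse_tensor.extract_iteration_space %sp at %it0 lvls = 1 ...
      %r2 = sparse_tensor.iterate %it1 in %s1 ... { ... }
      sparse_tensor.yield %r2 : index
    }

    // After: a single iteration over the collapsed space lvls = 0 to 2.
    %s = sparse_tensor.extract_iteration_space %sp lvls = 0 to 2 ...
    %r = sparse_tensor.iterate %it in %s ... { ... }
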
>From 1f2857be907fa1b6fb887c620d38a5dae5fa8c32 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 21 May 2024 18:19:54 +0000
Subject: [PATCH 7/8] rebase

---
 .../SparseTensor/Transforms/SparseSpaceCollapse.cpp       | 4 ++--
 mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir | 8 ++++++--
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
index bc469992d9710..4d06603a59862 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -53,7 +53,7 @@ bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
                      ExtractIterSpaceOp curSpace) {
 
   auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp {
-    Value spaceVal = space.getResultSpace();
+    Value spaceVal = space.getExtractedSpace();
     if (spaceVal.hasOneUse())
       return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin());
     return nullptr;
@@ -116,7 +116,7 @@ void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) {
   auto innermost = toCollapse.back().loop;
 
   IRMapping mapper;
-  mapper.map(leaf, collapsedSpace.getResultSpace());
+  mapper.map(leaf, collapsedSpace.getExtractedSpace());
   for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs()))
     mapper.map(std::get<0>(z), std::get<1>(z));
 
diff --git a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
index 392dfe01884ba..baa6199f12bc3 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
@@ -19,9 +19,13 @@
 // CHECK:           return
 // CHECK:         }
 func.func @sparse_sparse_collapse(%sp : tensor<4x8xf32, #COO>, %i : index) {
-  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
+      : tensor<4x8xf32, #COO>
+     -> !sparse_tensor.iter_space<#COO, lvls = 0>
   %r1 = sparse_tensor.iterate %it1 in %l1 at(%crd0) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
-    %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
+    %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
+        : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
+       -> !sparse_tensor.iter_space<#COO, lvls = 1>
     %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
       %k ="test.op"(%inner) : (index) -> index
       sparse_tensor.yield %k : index

>From bfc234a2f59f698dcefd6423a2056127d413a4e0 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 17 Apr 2024 18:51:24 +0000
Subject: [PATCH 8/8] [mlir][sparse] implement lowering rules for
 ExtractIterSpaceOp.

---
 .../mlir/Dialect/SparseTensor/IR/Enums.h      |   4 +
 .../Dialect/SparseTensor/Transforms/Passes.h  |  15 +++
 .../Dialect/SparseTensor/Transforms/Passes.td |  15 +++
 .../SparseTensor/Transforms/CMakeLists.txt    |   1 +
 .../Transforms/SparseIterationToScf.cpp       |  78 ++++++++++++
 .../Transforms/SparseTensorPasses.cpp         |  29 +++++
 .../Transforms/Utils/SparseTensorIterator.cpp | 117 ++++++++++++++----
 .../Transforms/Utils/SparseTensorIterator.h   |  67 +++++++++-
 .../SparseTensor/sparse_iteration_to_scf.mlir |  23 ++++
 .../SparseTensor/sparse_space_collapse.mlir   |  11 +-
 10 files changed, 325 insertions(+), 35 deletions(-)
 create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
 create mode 100644 mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 33f613a46bad8..96ee7111fea2c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -357,6 +357,10 @@ struct LevelType {
     return hasSparseSemantic();
   }
 
+  constexpr unsigned getNumBuffer() const {
+    return hasDenseSemantic() ? 0 : (isWithPosLT() ? 2 : 1);
+  }
+
   std::string toMLIRString() const {
     std::string lvlStr = toFormatString(getLvlFmt());
     std::string propStr = "";
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index ff29f0a2c219c..7260223ad8da3 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -16,6 +16,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
 
 //===----------------------------------------------------------------------===//
 // Include the generated pass header (which needs some early definitions).
@@ -149,6 +150,20 @@ void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns);
 
 std::unique_ptr<Pass> createLowerForeachToSCFPass();
 
+//===----------------------------------------------------------------------===//
+// The LowerSparseIterationToSCF pass.
+//===----------------------------------------------------------------------===//
+
+/// Type converter for iter_space and iterator.
+struct SparseIterationTypeConverter : public OneToNTypeConverter {
+  SparseIterationTypeConverter();
+};
+
+void populateLowerSparseIterationToSCFPatterns(TypeConverter &converter,
+                                               RewritePatternSet &patterns);
+
+std::unique_ptr<Pass> createLowerSparseIterationToSCFPass();
+
 //===----------------------------------------------------------------------===//
 // The SparseTensorConversion pass.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 3f9fc54260503..8e694d5b0e798 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -516,4 +516,19 @@ def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
   ];
 }
 
+def LowerSparseIterationToSCF : Pass<"lower-sparse-iteration-to-scf", "func::FuncOp"> {
+  let summary = "lower sparse_tensor.iterate/coiterate into scf loops";
+  let description = [{
+     This pass lowers `sparse_tensor.iterate` operations into `scf.for` or
+     `scf.while` operations. The pass is not yet stabilized.
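+
+     A minimal sketch of the intended rewrite (illustrative only: `#CSR`,
+     `%lo`, `%hi`, and `%c1` are placeholders, and the exact loop form may
+     change while the pass is stabilized):
+
+     ```mlir
+     // Before: iteration over an extracted sparse iteration space.
+     %r = sparse_tensor.iterate %it in %space iter_args(%v = %init)
+         : !sparse_tensor.iter_space<#CSR, lvls = 0> -> index {
+       ...
+     }
+
+     // After: an scf.while over the decomposed [%lo, %hi) position range;
+     // %r#1 corresponds to the original iterate result.
+     %r:2 = scf.while (%pos = %lo, %v = %init)
+         : (index, index) -> (index, index) {
+       %cond = arith.cmpi ult, %pos, %hi : index
+       scf.condition(%cond) %pos, %v : index, index
+     } do {
+     ^bb0(%pos: index, %v: index):
+       // ... loop body, updating %v ...
+       %nxt = arith.addi %pos, %c1 : index
+       scf.yield %nxt, %v : index, index
+     }
+     ```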
+  }];
+  let constructor = "mlir::createLowerSparseIterationToSCFPass()";
+  let dependentDialects = [
+    "memref::MemRefDialect",
+    "scf::SCFDialect",
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 2a29ee8a7a87c..e4acfa8889e5f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   SparseAssembler.cpp
   SparseBufferRewriting.cpp
   SparseGPUCodegen.cpp
+  SparseIterationToScf.cpp
   SparseReinterpretMap.cpp
   SparseStorageSpecifierToLLVM.cpp
   SparseSpaceCollapse.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
new file mode 100644
index 0000000000000..d89b0b192ffcd
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -0,0 +1,78 @@
+//===- SparseIterationToScf.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-Exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Utils/CodegenUtils.h"
+#include "Utils/SparseTensorIterator.h"
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
+                             SmallVectorImpl<Type> &fields) {
+  // Position and coordinate buffer in the sparse structure.
+  if (enc.getLvlType(lvl).isWithPosLT())
+    fields.push_back(enc.getPosMemRefType());
+  if (enc.getLvlType(lvl).isWithCrdLT())
+    fields.push_back(enc.getCrdMemRefType());
+  // One index for shape bound (result from lvlOp)
+  fields.push_back(IndexType::get(enc.getContext()));
+}
+
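+// Flattens an iteration-space type into the types of its constituent SSA
+// values: per level, the buffers present on that level (positions and/or
+// coordinates) followed by one index for the level size; then a final pair
+// of indices for the loop bounds. E.g., a 2-D COO space
+// (compressed(nonunique), singleton) flattens to
+// [pos0, crd0, sz0, crd1, sz1, lo, hi].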
+static std::optional<LogicalResult>
+convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
+  auto idxTp = IndexType::get(itSp.getContext());
+  for (Level l = itSp.getLoLvl(); l < itSp.getHiLvl(); l++)
+    convertLevelType(itSp.getEncoding(), l, fields);
+
+  // Two indices for the lower and upper bounds (a single pair suffices for
+  // the entire iteration space).
+  fields.append({idxTp, idxTp});
+  return success();
+}
+
+namespace {
+
+/// Sparse codegen rule for the extract_iteration_space operator.
+class ExtractIterSpaceConverter
+    : public OneToNOpConversionPattern<ExtractIterSpaceOp> {
+public:
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ExtractIterSpaceOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+
+    // Construct the iteration space.
+    SparseIterationSpace space(loc, rewriter, op.getTensor(), 0,
+                               op.getLvlRange(), adaptor.getParentIter());
+
+    SmallVector<Value> result = space.toValues();
+    rewriter.replaceOp(op, result, resultMapping);
+    return success();
+  }
+};
+
+} // namespace
+
+mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
+  addConversion([](Type type) { return type; });
+  addConversion(convertIterSpaceType);
+
+  addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
+                              ValueRange inputs,
+                              Location loc) -> std::optional<Value> {
+    return builder
+        .create<UnrealizedConversionCastOp>(loc, TypeRange(spTp), inputs)
+        .getResult(0);
+  });
+}
+
+void mlir::populateLowerSparseIterationToSCFPatterns(
+    TypeConverter &converter, RewritePatternSet &patterns) {
+  patterns.add<ExtractIterSpaceConverter>(converter, patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index f57353b5892b5..12ea69d28f472 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -27,6 +27,7 @@ namespace mlir {
 #define GEN_PASS_DEF_SPARSEREINTERPRETMAP
 #define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
 #define GEN_PASS_DEF_SPARSIFICATIONPASS
+#define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF
 #define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
 #define GEN_PASS_DEF_LOWERFOREACHTOSCF
 #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
@@ -162,10 +163,34 @@ struct LowerForeachToSCFPass
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
     populateLowerForeachToSCFPatterns(patterns);
+
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
 
+struct LowerSparseIterationToSCFPass
+    : public impl::LowerSparseIterationToSCFBase<
+          LowerSparseIterationToSCFPass> {
+  LowerSparseIterationToSCFPass() = default;
+  LowerSparseIterationToSCFPass(const LowerSparseIterationToSCFPass &) =
+      default;
+
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    SparseIterationTypeConverter converter;
+    ConversionTarget target(*ctx);
+
+    // The actual conversion.
+    target.addIllegalOp<ExtractIterSpaceOp, IterateOp>();
+    populateLowerSparseIterationToSCFPatterns(converter, patterns);
+
+    if (failed(applyPartialOneToNConversion(getOperation(), converter,
+                                            std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
 struct SparseTensorConversionPass
     : public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
   SparseTensorConversionPass() = default;
@@ -452,6 +477,10 @@ std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
   return std::make_unique<LowerForeachToSCFPass>();
 }
 
+std::unique_ptr<Pass> mlir::createLowerSparseIterationToSCFPass() {
+  return std::make_unique<LowerSparseIterationToSCFPass>();
+}
+
 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
   return std::make_unique<SparseTensorConversionPass>();
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index dbec46d2616d9..be8e15d6ae6f4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -168,7 +168,7 @@ class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
     ValueRange posRange = posRangeIf.getResults();
     return {posRange.front(), posRange.back()};
   }
-};
+};
 
 class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
 public:
@@ -190,7 +190,7 @@ class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
     Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
     return {pLo, pHi};
   }
-};
+};
 
 class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
 public:
@@ -210,6 +210,13 @@ class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
     // Use the segHi as the loop upper bound.
     return {p, segHi};
   }
+
+  ValuePair
+  collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix,
+                       std::pair<Value, Value> parentRange) const override {
+    // Singleton level keeps the same range after collapsing.
+    return parentRange;
+  }
 };
 
 class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
@@ -1474,10 +1481,85 @@ ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
   return getCursor();
 }
 
+//===----------------------------------------------------------------------===//
+// SparseIterationSpace Implementation
+//===----------------------------------------------------------------------===//
+
+mlir::sparse_tensor::SparseIterationSpace::SparseIterationSpace(
+    Location l, OpBuilder &b, Value t, unsigned tid,
+    std::pair<Level, Level> lvlRange, ValueRange parentPos)
+    : lvls() {
+  auto [lvlLo, lvlHi] = lvlRange;
+
+  Value c0 = C_IDX(0);
+  if (parentPos.empty())
+    parentPos = c0;
+
+  for (Level lvl = lvlLo; lvl < lvlHi; lvl++)
+    lvls.emplace_back(makeSparseTensorLevel(b, l, t, tid, lvl));
+
+  bound = lvls.front()->peekRangeAt(b, l, /*batchPrefix=*/{}, parentPos);
+  for (auto &lvl : getLvlRef().drop_front())
+    bound = lvl->collapseRangeBetween(b, l, /*batchPrefix=*/{}, bound);
+}
+
+SparseIterationSpace mlir::sparse_tensor::SparseIterationSpace::fromValues(
+    IterSpaceType dstTp, ValueRange values, unsigned tid) {
+  // Reconstruct every sparse tensor level.
+  SparseIterationSpace space;
+  for (auto [i, lt] : llvm::enumerate(dstTp.getLvlTypes())) {
+    unsigned bufferCnt = 0;
+    if (lt.isWithPosLT())
+      bufferCnt++;
+    if (lt.isWithCrdLT())
+      bufferCnt++;
+    // Sparse tensor buffers.
+    ValueRange buffers = values.take_front(bufferCnt);
+    values = values.drop_front(bufferCnt);
+
+    // Level size.
+    Value sz = values.front();
+    values = values.drop_front();
+    space.lvls.push_back(
+        makeSparseTensorLevel(lt, sz, buffers, tid, i + dstTp.getLoLvl()));
+  }
+  // Two bounds.
+  space.bound = std::make_pair(values[0], values[1]);
+  values = values.drop_front(2);
+
+  // Must have consumed all values.
+  assert(values.empty());
+  return space;
+}
+
 //===----------------------------------------------------------------------===//
 // SparseIterator factory functions.
 //===----------------------------------------------------------------------===//
 
+/// Helper function to create a SparseTensorLevel object from the given level
+/// type, size, and pre-extracted level buffers.
+std::unique_ptr<SparseTensorLevel>
+sparse_tensor::makeSparseTensorLevel(LevelType lt, Value sz, ValueRange b,
+                                     unsigned t, Level l) {
+  assert(lt.getNumBuffer() == b.size());
+  switch (lt.getLvlFmt()) {
+  case LevelFormat::Dense:
+    return std::make_unique<DenseLevel>(t, l, sz);
+  case LevelFormat::Batch:
+    return std::make_unique<BatchLevel>(t, l, sz);
+  case LevelFormat::Compressed:
+    return std::make_unique<CompressedLevel>(t, l, lt, sz, b[0], b[1]);
+  case LevelFormat::LooseCompressed:
+    return std::make_unique<LooseCompressedLevel>(t, l, lt, sz, b[0], b[1]);
+  case LevelFormat::Singleton:
+    return std::make_unique<SingletonLevel>(t, l, lt, sz, b[0]);
+  case LevelFormat::NOutOfM:
+    return std::make_unique<NOutOfMLevel>(t, l, lt, sz, b[0]);
+  case LevelFormat::Undef:
+    llvm_unreachable("undefined level format");
+  }
+  llvm_unreachable("unrecognizable level format");
+}
+
 std::unique_ptr<SparseTensorLevel>
 sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
                                      unsigned tid, Level lvl) {
@@ -1487,33 +1569,16 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
   Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult()
                                : b.create<tensor::DimOp>(l, t, lvl).getResult();
 
-  switch (lt.getLvlFmt()) {
-  case LevelFormat::Dense:
-    return std::make_unique<DenseLevel>(tid, lvl, sz);
-  case LevelFormat::Batch:
-    return std::make_unique<BatchLevel>(tid, lvl, sz);
-  case LevelFormat::Compressed: {
-    Value pos = b.create<ToPositionsOp>(l, t, lvl);
-    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
-    return std::make_unique<CompressedLevel>(tid, lvl, lt, sz, pos, crd);
-  }
-  case LevelFormat::LooseCompressed: {
+  SmallVector<Value, 2> buffers;
+  if (lt.isWithPosLT()) {
     Value pos = b.create<ToPositionsOp>(l, t, lvl);
-    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
-    return std::make_unique<LooseCompressedLevel>(tid, lvl, lt, sz, pos, crd);
-  }
-  case LevelFormat::Singleton: {
-    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
-    return std::make_unique<SingletonLevel>(tid, lvl, lt, sz, crd);
+    buffers.push_back(pos);
   }
-  case LevelFormat::NOutOfM: {
-    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
-    return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, crd);
+  if (lt.isWithCrdLT()) {
+    Value pos = b.create<ToCoordinatesOp>(l, t, lvl);
+    buffers.push_back(pos);
   }
-  case LevelFormat::Undef:
-    llvm_unreachable("undefined level format");
-  }
-  llvm_unreachable("unrecognizable level format");
+  return makeSparseTensorLevel(lt, sz, buffers, tid, lvl);
 }
 
 std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 120a806536f19..09503d4b6a099 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -15,6 +15,9 @@
 namespace mlir {
 namespace sparse_tensor {
 
+// Forward declaration.
+class SparseIterator;
+
 /// The base class for all types of sparse tensor levels. It provides interfaces
 /// to query the loop range (see `peekRangeAt`) and look up the coordinates (see
 /// `peekCrdAt`).
@@ -50,6 +53,12 @@ class SparseTensorLevel {
   peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
               ValueRange parentPos, Value inPadZone = nullptr) const = 0;
 
+  virtual std::pair<Value, Value>
+  collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix,
+                       std::pair<Value, Value> parentRange) const {
+    llvm_unreachable("Not Implemented");
+  };
+
   Level getLevel() const { return lvl; }
   LevelType getLT() const { return lt; }
   Value getSize() const { return lvlSize; }
@@ -79,6 +88,51 @@ enum class IterKind : uint8_t {
   kPad,
 };
 
+class SparseIterationSpace {
+public:
+  SparseIterationSpace() = default;
+
+  // Constructs an N-D iteration space.
+  SparseIterationSpace(Location loc, OpBuilder &b, Value t, unsigned tid,
+                       std::pair<Level, Level> lvlRange, ValueRange parentPos);
+
+  // Constructs a 1-D iteration space.
+  SparseIterationSpace(Location loc, OpBuilder &b, Value t, unsigned tid,
+                       Level lvl, ValueRange parentPos)
+      : SparseIterationSpace(loc, b, t, tid, {lvl, lvl + 1}, parentPos) {}
+
+  bool isUnique() const { return lvls.back()->isUnique(); }
+
+  unsigned getSpaceDim() const { return lvls.size(); }
+
+  // Reconstructs an iteration space directly from the provided ValueRange.
+  static SparseIterationSpace fromValues(IterSpaceType dstTp, ValueRange values,
+                                         unsigned tid);
+
+  // The inverse operation of `fromValues`.
+  SmallVector<Value> toValues() const {
+    SmallVector<Value> vals;
+    for (auto &stl : lvls) {
+      llvm::append_range(vals, stl->getLvlBuffers());
+      vals.push_back(stl->getSize());
+    }
+    vals.append({bound.first, bound.second});
+    return vals;
+  }
+
+  const SparseTensorLevel &getLastLvl() const { return *lvls.back(); }
+  ArrayRef<std::unique_ptr<SparseTensorLevel>> getLvlRef() const {
+    return lvls;
+  }
+
+  Value getBoundLo() const { return bound.first; }
+  Value getBoundHi() const { return bound.second; }
+
+private:
+  SmallVector<std::unique_ptr<SparseTensorLevel>> lvls;
+  std::pair<Value, Value> bound;
+};
+
 /// Helper class that generates loop conditions, etc, to traverse a
 /// sparse tensor level.
 class SparseIterator {
@@ -287,10 +341,15 @@ std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &b,
                                                          unsigned tid,
                                                          Level lvl);
 
-/// Helper function to create a simple SparseIterator object that iterate over
-/// the SparseTensorLevel.
-std::unique_ptr<SparseIterator> makeSimpleIterator(const SparseTensorLevel &stl,
-                                                   SparseEmitStrategy strategy);
+/// Helper function to create a SparseTensorLevel object from the given level
+/// type, size, and pre-extracted level buffers.
+std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz,
+                                                         ValueRange buffers,
+                                                         unsigned tid, Level l);
+/// Helper function to create a simple SparseIterator object that iterates
+/// over the SparseTensorLevel.
+std::unique_ptr<SparseIterator> makeSimpleIterator(
+    const SparseTensorLevel &stl,
+    SparseEmitStrategy strategy = SparseEmitStrategy::kFunctional);
 
 /// Helper function to create a synthetic SparseIterator object that iterates
 /// over a dense space specified by [0,`sz`).
diff --git a/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
new file mode 100644
index 0000000000000..b51c11ebf8a8c
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt %s --lower-sparse-iteration-to-scf | FileCheck %s
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+// CHECK-LABEL:   func.func @sparse_1D_space(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x?xf32, #sparse{{[0-9]*}}>) -> !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> {
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[LVL_SIZE:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[C0]] : tensor<?x?xf32, #sparse{{[0-9]*}}>
+// CHECK:           %[[POS_MEM:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK:           %[[CRD_MEM:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK:           %[[POS_LO:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C0]]] : memref<?xindex>
+// CHECK:           %[[POS_HI:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK:           %[[ITER_SPACE:.*]] = builtin.unrealized_conversion_cast %[[POS_MEM]], %[[CRD_MEM]], %[[LVL_SIZE]], %[[POS_LO]], %[[POS_HI]]
+func.func @sparse_1D_space(%sp : tensor<?x?xf32, #COO>) -> !sparse_tensor.iter_space<#COO, lvls = 0> {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf32, #COO>
+  return %l1 : !sparse_tensor.iter_space<#COO, lvls = 0>
+}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
index baa6199f12bc3..b99bf915c71f8 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
@@ -18,20 +18,21 @@
 // CHECK:           "test.sink"(%[[VAL_4]]) : (index) -> ()
 // CHECK:           return
 // CHECK:         }
-func.func @sparse_sparse_collapse(%sp : tensor<4x8xf32, #COO>, %i : index) {
+func.func @sparse_sparse_collapse(%sp : tensor<4x8xf32, #COO>) -> index {
+  %i = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
   %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
       : tensor<4x8xf32, #COO>
      -> !sparse_tensor.iter_space<#COO, lvls = 0>
-  %r1 = sparse_tensor.iterate %it1 in %l1 at(%crd0) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
+  %r1 = sparse_tensor.iterate %it1 in %l1 iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
     %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
         : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
        -> !sparse_tensor.iter_space<#COO, lvls = 1>
     %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
-      %k ="test.op"(%inner) : (index) -> index
+      %k = arith.addi %inner, %c1 : index
       sparse_tensor.yield %k : index
     }
     sparse_tensor.yield %r2 : index
   }
-  "test.sink"(%r1) : (index) -> ()
-  return
+  return %r1 : index
 }
