[Mlir-commits] [mlir] [mlir][sparse] introduce `sparse_tensor.iterate` operation (PR #88955)

Peiming Liu llvmlistbot at llvm.org
Tue May 21 10:18:08 PDT 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/88955

>From 0723a02a1c57ae639b11801ca7c2cb77b3abc6b7 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 12 Apr 2024 22:03:06 +0000
Subject: [PATCH 1/5] [mlir][sparse] introduce `sparse_tensor.iterate`
 operation

---
 .../Dialect/SparseTensor/IR/SparseTensor.h    |  38 +++
 .../SparseTensor/IR/SparseTensorAttrDefs.td   |  15 ++
 .../SparseTensor/IR/SparseTensorOps.td        | 101 ++++++-
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 255 ++++++++++++++++++
 mlir/test/Dialect/SparseTensor/invalid.mlir   |  57 ++++
 mlir/test/Dialect/SparseTensor/roundtrip.mlir |  28 ++
 .../SparseTensor/sparse_itertion_licm.mlir    |  26 ++
 7 files changed, 519 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 3cf81d2e58f21..239118a6575d8 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -17,9 +17,13 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/TensorEncoding.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+#include "llvm/ADT/bit.h"
+
 //===----------------------------------------------------------------------===//
 //
 // Type aliases to help code be more self-documenting. Unfortunately
@@ -54,6 +58,40 @@ struct COOSegment {
   }
 };
 
+/// A simple wrapper to encode a bitset of defined  (at most 64) levels.
+class LevelSet {
+  uint64_t bits = 0;
+
+public:
+  LevelSet() = default;
+  explicit LevelSet(uint64_t bits) : bits(bits) {}
+  operator uint64_t() const { return bits; }
+
+  LevelSet &set(unsigned i) {
+    assert(i < 64);
+    bits |= 1 << i;
+    return *this;
+  }
+
+  LevelSet &operator|=(LevelSet lhs) {
+    bits |= static_cast<uint64_t>(lhs);
+    return *this;
+  }
+
+  LevelSet &lshift(unsigned offset) {
+    bits = bits << offset;
+    return *this;
+  }
+
+  bool operator[](unsigned i) const { // True iff level `i` is in the set.
+    assert(i < 64);
+    return (bits & (static_cast<uint64_t>(0x01u) << i)) != 0; // 64-bit shift.
+  }
+
+  unsigned count() const { return llvm::popcount(bits); }
+  bool empty() const { return bits == 0; }
+};
+
 } // namespace sparse_tensor
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 53dd8e39438cc..c26193af09e81 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -19,6 +19,21 @@ class SparseTensor_Attr<string name,
                         list<Trait> traits = []>
     : AttrDef<SparseTensor_Dialect, name, traits>;
 
+//===----------------------------------------------------------------------===//
+// A simple bitset attribute wrapped over a single int64_t to encode a set of
+// sparse tensor levels.
+//===----------------------------------------------------------------------===//
+
+def LevelSetAttr :
+    TypedAttrBase<
+      I64, "IntegerAttr",
+      And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
+           CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType().isInteger(64)">]>,
+      "LevelSet attribute"> {
+  let returnType = [{::mlir::sparse_tensor::LevelSet}];
+  let convertFromStorage = [{::mlir::sparse_tensor::LevelSet($_self.getValue().getZExtValue())}];
+}
+
 //===----------------------------------------------------------------------===//
 // These attributes are just like `IndexAttr` except that they clarify whether
 // the index refers to a dimension (an axis of the semantic tensor) or a level
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 4e4441c640ed9..0ef03a5afa733 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -15,6 +15,8 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
 include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
 
 //===----------------------------------------------------------------------===//
 // Base class.
@@ -1304,7 +1306,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
 
 def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
     ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
-                 "ForeachOp"]>]> {
+                 "ForeachOp", "IterateOp"]>]> {
   let summary = "Yield from sparse_tensor set-like operations";
   let description = [{
       Yields a value from within a `binary`, `unary`, `reduce`,
@@ -1513,6 +1515,103 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
   let hasVerifier = 1;
 }
 
+def IterateOp : SparseTensor_Op<"iterate",
+    [RecursiveMemoryEffects, RecursivelySpeculatable,
+     DeclareOpInterfaceMethods<LoopLikeOpInterface,
+      ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
+       "getYieldedValuesMutable"]>,
+     DeclareOpInterfaceMethods<RegionBranchOpInterface,
+      ["getEntrySuccessorOperands"]>,
+     SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
+
+  let summary = "Iterate over a sparse iteration space";
+  let description = [{
+      The `sparse_tensor.iterate` operations represents a loop over the
+      provided iteration space extracted from a specific sparse tensor.
+      The operation defines an SSA value for a sparse iterator that points
+      to the current stored element in the sparse tensor and SSA values
+      for coordinates of the stored element. The coordinates are always
+      converted to `index` type regardless of the underlying sparse tensor
+      storage. When coordinates are not used, the SSA values can be skipped
+      by `_` symbols, which usually leads to simpler generated code after
+      sparsification. For example:
+
+      ```mlir
+      // The coordinate for level 0 is not used when iterating over a 2-D
+      // iteration space.
+      sparse_tensor.iterate %iterator in %space at(_, %crd_1)
+        : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
+      ```
+
+      `sparse_tensor.iterate` can also operate on loop-carried variables
+      and returns the final values after loop termination.
+      The initial values of the variables are passed as additional SSA operands
+      to the iterator SSA value and used coordinate SSA values mentioned
+      above. The operation region has an argument for the iterator, variadic
+      arguments for specified (used) coordinates and followed by one argument
+      for each loop-carried variable, representing the value of the variable
+      at the current iteration.
+      The body region must contain exactly one block that terminates with
+      `sparse_tensor.yield`.
+
+      `sparse_tensor.iterate` results hold the final values after the last
+      iteration. If the `sparse_tensor.iterate` defines any values, a yield
+      must be explicitly present.
+      The number and types of the `sparse_tensor.iterate` results must match
+      the initial values in the iter_args binding and the yield operands.
+
+
+      A nested `sparse_tensor.iterate` example that prints all the coordinates
+      stored in the sparse input:
+
+      ```mlir
+      func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
+        // Iterates over the first level of %sp
+        %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+        %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd0)
+            : !sparse_tensor.iter_space<#COO, lvls = 0 to 1>  {
+          // Iterates over the second level of %sp
+          %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
+              : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
+          %r2 = sparse_tensor.iterate %it2 in %l2 at (crd1)
+              : !sparse_tensor.iter_space<#COO, lvls = 1 to 2>  {
+             vector.print %crd0 : index
+             vector.print %crd1 : index
+          }
+        }
+      }
+
+      ```
+  }];
+
+  let arguments = (ins AnySparseIterSpace:$iterSpace,
+                       Variadic<AnyType>:$initArgs,
+                       LevelSetAttr:$crdUsedLvls);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$region);
+
+  let extraClassDeclaration = [{
+    unsigned getSpaceDim() {
+      return getIterSpace().getType().getSpaceDim();
+    }
+    BlockArgument getIterator() {
+      return getRegion().getArguments().front();
+    }
+    Block::BlockArgListType getCrds() {
+      // The first block argument is iterator, the remaining arguments are
+      // referenced coordinates.
+      return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
+    }
+    unsigned getNumRegionIterArgs() {
+      return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
+    }
+  }];
+
+  let hasVerifier = 1;
+  let hasRegionVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Debugging and Test-Only Operations.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 4adb1c19096a2..0a9254573538d 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2130,6 +2130,106 @@ static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
   printLevelRange(p, lo, hi);
 }
 
+static ParseResult
+parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
+                     SmallVectorImpl<OpAsmParser::Argument> &iterators,
+                     SmallVectorImpl<OpAsmParser::Argument> &iterArgs) {
+  SmallVector<OpAsmParser::UnresolvedOperand> spaces;
+  SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
+
+  // Parses "%iters, ... in %spaces, ..."
+  if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
+      parser.parseOperandList(spaces))
+    return failure();
+
+  if (iterators.size() != spaces.size())
+    return parser.emitError(
+        parser.getNameLoc(),
+        "mismatch in number of sparse iterators and sparse spaces");
+
+  // Parse "at(%crd0, _, ...)"
+  LevelSet crdUsedLvlSet;
+  bool hasUsedCrds = succeeded(parser.parseOptionalKeyword("at"));
+  unsigned lvlCrdCnt = 0;
+  if (hasUsedCrds) {
+    ParseResult crdList = parser.parseCommaSeparatedList(
+        OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
+          if (parser.parseOptionalKeyword("_")) {
+            if (parser.parseArgument(iterArgs.emplace_back()))
+              return failure();
+            // Always use IndexType for the coordinate.
+            crdUsedLvlSet.set(lvlCrdCnt);
+            iterArgs.back().type = parser.getBuilder().getIndexType();
+          }
+          lvlCrdCnt += 1;
+          return success();
+        });
+    if (failed(crdList)) {
+      return parser.emitError(
+          parser.getNameLoc(),
+          "expecting SSA value or \"_\" for level coordinates");
+    }
+  }
+  // Set the CrdUsedLvl bitset.
+  state.addAttribute("crdUsedLvls",
+                     parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
+
+  // Parse "iter_args(%arg = %init, ...)"
+  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+  if (hasIterArgs)
+    if (parser.parseAssignmentList(iterArgs, initArgs))
+      return failure();
+
+  SmallVector<Type> iterSpaceTps;
+  // parse ": sparse_tensor.iter_space -> ret"
+  if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
+    return failure();
+  if (iterSpaceTps.size() != spaces.size())
+    return parser.emitError(parser.getNameLoc(),
+                            "mismatch in number of iteration space operands "
+                            "and iteration space types");
+
+  for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
+    IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
+    if (!spaceTp)
+      return parser.emitError(parser.getNameLoc(),
+                              "expected sparse_tensor.iter_space type for "
+                              "iteration space operands");
+    if (hasUsedCrds && spaceTp.getSpaceDim() != lvlCrdCnt)
+      return parser.emitError(parser.getNameLoc(),
+                              "mismatch in number of iteration space dimension "
+                              "and specified coordinates");
+    it.type = spaceTp.getIteratorType();
+  }
+
+  if (hasIterArgs)
+    if (parser.parseArrowTypeList(state.types))
+      return failure();
+
+  // Resolves input operands.
+  if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
+                             state.operands))
+    return failure();
+
+  if (hasIterArgs) {
+    unsigned numCrds = crdUsedLvlSet.count();
+    // Strip off leading args that are used for coordinates.
+    MutableArrayRef args = MutableArrayRef(iterArgs).drop_front(numCrds);
+    if (args.size() != initArgs.size() || args.size() != state.types.size()) {
+      return parser.emitError(
+          parser.getNameLoc(),
+          "mismatch in number of iteration arguments and return values");
+    }
+
+    for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
+      it.type = tp;
+      if (parser.resolveOperand(init, tp, state.operands))
+        return failure();
+    }
+  }
+  return success();
+}
+
 LogicalResult ExtractIterSpaceOp::inferReturnTypes(
     MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
     DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
@@ -2166,6 +2266,161 @@ LogicalResult ExtractIterSpaceOp::verify() {
   return success();
 }
 
+ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::Argument iterator;
+  OpAsmParser::UnresolvedOperand iterSpace;
+
+  SmallVector<OpAsmParser::Argument> iters, iterArgs;
+  if (parseSparseSpaceLoop(parser, result, iters, iterArgs))
+    return failure();
+  if (iters.size() != 1)
+    return parser.emitError(parser.getNameLoc(),
+                            "expected only one iterator/iteration space");
+
+  iters.append(iterArgs);
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, iters))
+    return failure();
+
+  IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  return success();
+}
+
+/// Prints the initialization list in the form of
+///   <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
+/// where 'inner' values are assumed to be region arguments and 'outer' values
+/// are regular SSA values.
+static void printInitializationList(OpAsmPrinter &p,
+                                    Block::BlockArgListType blocksArgs,
+                                    ValueRange initializers,
+                                    StringRef prefix = "") {
+  assert(blocksArgs.size() == initializers.size() &&
+         "expected same length of arguments and initializers");
+  if (initializers.empty())
+    return;
+
+  p << prefix << '(';
+  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
+    p << std::get<0>(it) << " = " << std::get<1>(it);
+  });
+  p << ")";
+}
+
+static void printUsedCrdsList(OpAsmPrinter &p, unsigned spaceDim,
+                              Block::BlockArgListType blocksArgs,
+                              LevelSet crdUsedLvls) {
+  if (crdUsedLvls.empty())
+    return;
+
+  p << " at(";
+  for (unsigned i = 0; i < spaceDim; i++) {
+    if (crdUsedLvls[i]) {
+      p << blocksArgs.front();
+      blocksArgs = blocksArgs.drop_front();
+    } else {
+      p << "_";
+    }
+    if (i != spaceDim - 1)
+      p << ", ";
+  }
+  assert(blocksArgs.empty());
+  p << ")";
+}
+
+void IterateOp::print(OpAsmPrinter &p) {
+  p << " " << getIterator() << " in " << getIterSpace();
+  printUsedCrdsList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
+  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
+
+  p << " : " << getIterSpace().getType() << " ";
+  if (!getInitArgs().empty())
+    p << "-> (" << getInitArgs().getTypes() << ") ";
+
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/!getInitArgs().empty());
+}
+
+LogicalResult IterateOp::verify() {
+  if (getInitArgs().size() != getNumResults()) {
+    return emitOpError(
+        "mismatch in number of loop-carried values and defined values");
+  }
+  return success();
+}
+
+LogicalResult IterateOp::verifyRegions() {
+  if (getIterator().getType() != getIterSpace().getType().getIteratorType())
+    return emitOpError("mismatch in iterator and iteration space type");
+  if (getNumRegionIterArgs() != getNumResults())
+    return emitOpError(
+        "mismatch in number of basic block args and defined values");
+
+  auto initArgs = getInitArgs();
+  auto iterArgs = getRegionIterArgs();
+  auto yieldVals = getYieldedValues();
+  auto opResults = getResults();
+  if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
+                        opResults.size()})) {
+    return emitOpError() << "number mismatch between iter args and results.";
+  }
+
+  for (auto [i, init, iter, yield, ret] :
+       llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
+    if (init.getType() != ret.getType())
+      return emitOpError() << "types mismatch between " << i
+                           << "th iter operand and defined value";
+    if (iter.getType() != ret.getType())
+      return emitOpError() << "types mismatch between " << i
+                           << "th iter region arg and defined value";
+    if (yield.getType() != ret.getType())
+      return emitOpError() << "types mismatch between " << i
+                           << "th yield value and defined value";
+  }
+
+  return success();
+}
+
+/// IterateOp implemented OpInterfaces' methods.
+SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
+
+MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
+  return getInitArgsMutable();
+}
+
+Block::BlockArgListType IterateOp::getRegionIterArgs() {
+  return getRegion().getArguments().take_back(getNumRegionIterArgs());
+}
+
+std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
+  return cast<sparse_tensor::YieldOp>(
+             getRegion().getBlocks().front().getTerminator())
+      .getResultsMutable();
+}
+
+std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
+
+OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  return getInitArgs();
+}
+
+void IterateOp::getSuccessorRegions(RegionBranchPoint point,
+                                    SmallVectorImpl<RegionSuccessor> &regions) {
+  // Both the operation itself and the region may be branching into the body or
+  // back into the operation itself.
+  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
+  // It is possible for loop not to enter the body.
+  regions.push_back(RegionSuccessor(getResults()));
+}
+
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Dialect Setups.
+//===----------------------------------------------------------------------===//
+
 /// Materialize a single constant operation from a given attribute value with
 /// the desired resultant type.
 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 3fa696e1600a9..b13024cd4ed99 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -1094,3 +1094,60 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
   %l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
   return
 }
+
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  // expected-error @+1 {{'sparse_tensor.iterate' op different number of region iter_args and yielded values: 2 != 1}}
+  %r1, %r2 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%si = %i, %sj = %j): !sparse_tensor.iter_space<#COO, lvls = 0> -> (index, index) {
+    sparse_tensor.yield %si : index
+  }
+  return %r1 : index
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+// expected-note@+1 {{prior use here}}
+func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index) -> f32 {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  // expected-error @+1 {{use of value '%i' expects different type than prior uses: 'f32' vs 'index'}}
+  %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0> -> f32 {
+    sparse_tensor.yield %outer : f32
+  }
+  return %r1 : f32
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  // expected-error @+1 {{'sparse_tensor.iterate' op 0-th region iter_arg and 0-th yielded value have different type: 'index' != 'f32'}}
+  %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%si = %i): !sparse_tensor.iter_space<#COO, lvls = 0> -> index {
+    %y = arith.constant 1.0 :  f32
+    sparse_tensor.yield %y : f32
+  }
+  return %r1 : index
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index d34071279e512..e9a898f16b41d 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -763,3 +763,31 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
   %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
   return %l1, %l2 : !sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>
 }
+
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+// CHECK-LABEL:   func.func @sparse_iterate(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<4x8xf32, #sparse{{[0-9]*}}>,
+// CHECK-SAME:      %[[VAL_1:.*]]: index,
+// CHECK-SAME:      %[[VAL_2:.*]]: index) -> index {
+// CHECK:           %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0 : tensor<4x8xf32, #sparse{{[0-9]*}}>
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] at(%[[VAL_6:.*]]) iter_args(%[[VAL_7:.*]] = %[[VAL_1]]) : !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> -> (index) {
+// CHECK:             sparse_tensor.yield %[[VAL_7]] : index
+// CHECK:           }
+// CHECK:           return %[[VAL_4]] : index
+// CHECK:         }
+func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
+    sparse_tensor.yield %outer : index
+  }
+  return %r1 : index
+}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir b/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir
new file mode 100644
index 0000000000000..e7158d04b37fe
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s --loop-invariant-code-motion | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : dense,
+    j : compressed
+  )
+}>
+
+// Make sure that pure instructions are hoisted outside the loop.
+//
+// CHECK: sparse_tensor.values
+// CHECK: sparse_tensor.positions
+// CHECK: sparse_tensor.coordinate
+// CHECK: sparse_tensor.iterate
+func.func @sparse_iterate(%sp : tensor<?x?xf64, #CSR>) {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf64, #CSR>
+  sparse_tensor.iterate %it1 in %l1 at (%crd) : !sparse_tensor.iter_space<#CSR, lvls = 0> {
+    %0 = sparse_tensor.values %sp : tensor<?x?xf64, #CSR> to memref<?xf64>
+    %1 = sparse_tensor.positions %sp { level = 1 : index } : tensor<?x?xf64, #CSR> to memref<?xindex>
+    %2 = sparse_tensor.coordinates  %sp { level = 1 : index } : tensor<?x?xf64, #CSR> to memref<?xindex>
+    "test.op"(%0, %1, %2) : (memref<?xf64>, memref<?xindex>, memref<?xindex>) -> ()
+  }
+
+  return
+}

>From e28739e8cb7b9edb758cd213041675b1963702c9 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 17 Apr 2024 21:50:31 +0000
Subject: [PATCH 2/5] address comments

---
 mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h      | 4 ++--
 .../mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td      | 2 +-
 mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td  | 4 ++--
 mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp      | 2 +-
 4 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 239118a6575d8..338adf17be891 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -58,7 +58,7 @@ struct COOSegment {
   }
 };
 
-/// A simple wrapper to encode a bitset of defined  (at most 64) levels.
+/// A simple wrapper to encode a bitset of defined (at most 64) levels.
 class LevelSet {
   uint64_t bits = 0;
 
@@ -69,7 +69,7 @@ class LevelSet {
 
   LevelSet &set(unsigned i) {
     assert(i < 64);
-    bits |= 1 << i;
+    bits |= static_cast<uint64_t>(0x01u) << i;
     return *this;
   }
 
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index c26193af09e81..69b212cce4ceb 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -20,7 +20,7 @@ class SparseTensor_Attr<string name,
     : AttrDef<SparseTensor_Dialect, name, traits>;
 
 //===----------------------------------------------------------------------===//
-// A simple bitset attribute wrapped over a single int64_t to encode a set of
+// A simple bitset attribute wrapped around a single int64_t to encode a set of
 // sparse tensor levels.
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 0ef03a5afa733..735a35c139834 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1526,8 +1526,8 @@ def IterateOp : SparseTensor_Op<"iterate",
 
   let summary = "Iterate over a sparse iteration space";
   let description = [{
-      The `sparse_tensor.iterate` operations represents a loop over the
-      provided iteration space extracted from a specific sparse tensor.
+      The `sparse_tensor.iterate` operation represents a loop (nest) over
+      the provided iteration space extracted from a specific sparse tensor.
       The operation defines an SSA value for a sparse iterator that points
       to the current stored element in the sparse tensor and SSA values
       for coordinates of the stored element. The coordinates are always
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 0a9254573538d..b5b92f9bfa0d2 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2385,7 +2385,7 @@ LogicalResult IterateOp::verifyRegions() {
   return success();
 }
 
-/// IterateOp implemented OpInterfaces' methods.
+/// OpInterfaces' methods implemented by IterateOp.
 SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
 
 MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {

>From 1d781d84da5c1c3d1012dd9cddfe26a9ecef9616 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 20 May 2024 21:00:26 +0000
Subject: [PATCH 3/5] use coord instead of crd

---
 .../mlir/Dialect/SparseTensor/IR/SparseTensorOps.td       | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 735a35c139834..2cc367931ce45 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1568,15 +1568,15 @@ def IterateOp : SparseTensor_Op<"iterate",
       func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
         // Iterates over the first level of %sp
         %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
-        %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd0)
+        %r1 = sparse_tensor.iterate %it1 in %l1 at (%coord0)
             : !sparse_tensor.iter_space<#COO, lvls = 0 to 1>  {
           // Iterates over the second level of %sp
           %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
               : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
-          %r2 = sparse_tensor.iterate %it2 in %l2 at (crd1)
+          %r2 = sparse_tensor.iterate %it2 in %l2 at (%coord1)
               : !sparse_tensor.iter_space<#COO, lvls = 1 to 2>  {
-             vector.print %crd0 : index
-             vector.print %crd1 : index
+             vector.print %coord0 : index
+             vector.print %coord1 : index
           }
         }
       }

>From da154f428beff961577be30419e47d4569727d20 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 20 May 2024 21:25:46 +0000
Subject: [PATCH 4/5] make iter_space type explicit

---
 .../mlir/Dialect/SparseTensor/IR/SparseTensorOps.td  | 12 ++++++------
 mlir/test/Dialect/SparseTensor/invalid.mlir          | 12 ++++++++----
 mlir/test/Dialect/SparseTensor/roundtrip.mlir        |  5 +++--
 .../Dialect/SparseTensor/sparse_itertion_licm.mlir   |  1 +
 4 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 2cc367931ce45..6b337d58cf297 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1510,7 +1510,7 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
                        LevelAttr:$loLvl, LevelAttr:$hiLvl);
   let results = (outs AnySparseIterSpace:$resultSpace);
   let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
-                       " attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
+                       " attr-dict `:` type($tensor) (`,` type($parentIter)^)? `->` type($resultSpace)";
 
   let hasVerifier = 1;
 }
@@ -1543,8 +1543,8 @@ def IterateOp : SparseTensor_Op<"iterate",
         : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
       ```
 
-      `sparse_tensor.iterate` can also operate on loop-carried variables
-      and returns the final values after loop termination.
+      `sparse_tensor.iterate` can also operate on loop-carried variables.
+      It returns the final values after loop termination.
       The initial values of the variables are passed as additional SSA operands
       to the iterator SSA value and used coordinate SSA values mentioned
       above. The operation region has an argument for the iterator, variadic
@@ -1554,9 +1554,9 @@ def IterateOp : SparseTensor_Op<"iterate",
       The body region must contain exactly one block that terminates with
       `sparse_tensor.yield`.
 
-      `sparse_tensor.iterate` results hold the final values after the last
-      iteration. If the `sparse_tensor.iterate` defines any values, a yield
-      must be explicitly present.
+      The results of a `sparse_tensor.iterate` hold the final values after
+      the last iteration. If the `sparse_tensor.iterate` defines any values,
+      a yield must be explicitly present.
       The number and types of the `sparse_tensor.iterate` results must match
       the initial values in the iter_args binding and the yield operands.
 
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index b13024cd4ed99..eb0dc01be25b9 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -1025,6 +1025,7 @@ func.func @sparse_print(%arg0: tensor<10x10xf64>) {
 func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 2>) {
   // expected-error@+1 {{'sparse_tensor.extract_iteration_space' expect larger level upper bound than lower bound}}
   %l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 to 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 2>
+                                                                       -> !sparse_tensor.iter_space<#COO, lvls = 0 to 2>
   return
 }
 
@@ -1040,6 +1041,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
 func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
   // expected-error@+1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be specified iff level lower bound equals 0}}
   %l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+                                                                  -> !sparse_tensor.iter_space<#COO, lvls = 1>
   return
 }
 
@@ -1054,7 +1056,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
 
 func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>) {
   // expected-error@+1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be specified iff level lower bound equals 0}}
-  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 1 : tensor<4x8xf32, #COO>
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 1 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 1>
   return
 }
 
@@ -1077,6 +1079,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>) {
 func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#CSR, lvls = 0>) {
   // expected-error@+1 {{'sparse_tensor.extract_iteration_space' op mismatch in parent iterator encoding and iteration space encoding.}}
   %l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#CSR, lvls = 0>
+                                                                 -> !sparse_tensor.iter_space<#COO, lvls = 1>
   return
 }
 
@@ -1092,6 +1095,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
 func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
   // expected-error@+1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be used to extract an iteration space from a consecutive level.}}
   %l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+                                                                  -> !sparse_tensor.iter_space<#COO, lvls = 2>
   return
 }
 
@@ -1106,7 +1110,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
 }>
 
 func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
-  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
   // expected-error @+1 {{'sparse_tensor.iterate' op different number of region iter_args and yielded values: 2 != 1}}
   %r1, %r2 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%si = %i, %sj = %j): !sparse_tensor.iter_space<#COO, lvls = 0> -> (index, index) {
     sparse_tensor.yield %si : index
@@ -1125,7 +1129,7 @@ func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -
 
 // expected-note@+1 {{prior use here}}
 func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index) -> f32 {
-  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
   // expected-error @+1 {{use of value '%i' expects different type than prior uses: 'f32' vs 'index'}}
   %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0> -> f32 {
     sparse_tensor.yield %outer : f32
@@ -1143,7 +1147,7 @@ func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index) -> f32 {
 }>
 
 func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
-  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
   // expected-error @+1 {{'sparse_tensor.iterate' op 0-th region iter_arg and 0-th yielded value have different type: 'index' != 'f32'}}
   %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%si = %i): !sparse_tensor.iter_space<#COO, lvls = 0> -> index {
     %y = arith.constant 1.0 :  f32
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index e9a898f16b41d..bce0b41a99828 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -758,9 +758,10 @@ func.func @sparse_has_runtime() -> i1 {
 func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>)
   -> (!sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>) {
   // Extracting the iteration space for the first level needs no parent iterator.
-  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
   // Extracting the iteration space for the second level needs a parent iterator.
   %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+                                                                 -> !sparse_tensor.iter_space<#COO, lvls = 1>
   return %l1, %l2 : !sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>
 }
 
@@ -785,7 +786,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
 // CHECK:           return %[[VAL_4]] : index
 // CHECK:         }
 func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
-  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
   %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
     sparse_tensor.yield %outer : index
   }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir b/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir
index e7158d04b37fe..f70fab3b7251d 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir
@@ -15,6 +15,7 @@
 // CHECK: sparse_tensor.iterate
 func.func @sparse_iterate(%sp : tensor<?x?xf64, #CSR>) {
   %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf64, #CSR>
+                                                         -> !sparse_tensor.iter_space<#CSR, lvls = 0>
   sparse_tensor.iterate %it1 in %l1 at (%crd) : !sparse_tensor.iter_space<#CSR, lvls = 0> {
     %0 = sparse_tensor.values %sp : tensor<?x?xf64, #CSR> to memref<?xf64>
     %1 = sparse_tensor.positions %sp { level = 1 : index } : tensor<?x?xf64, #CSR> to memref<?xindex>

>From 3f0e4eb6e8e6448f150e42464f69830a4351efba Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 21 May 2024 17:17:53 +0000
Subject: [PATCH 5/5] update example code

---
 mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 6b337d58cf297..976081849b104 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1489,6 +1489,7 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
       // Extracts a 1-D iteration space from a COO tensor at level 1.
       %space = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1
         : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+       -> !sparse_tensor.iter_space<#COO, lvls = 1>
       ```
   }];
 
@@ -1567,12 +1568,14 @@ def IterateOp : SparseTensor_Op<"iterate",
       ```mlir
       func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
         // Iterates over the first level of %sp
-        %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+        %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
+            : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0 to 1>
         %r1 = sparse_tensor.iterate %it1 in %l1 at (%coord0)
             : !sparse_tensor.iter_space<#COO, lvls = 0 to 1>  {
           // Iterates over the second level of %sp
           %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
               : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
+             -> !sparse_tensor.iter_space<#COO, lvls = 1 to 2>
           %r2 = sparse_tensor.iterate %it2 in %l2 at (%coord1)
               : !sparse_tensor.iter_space<#COO, lvls = 1 to 2>  {
              vector.print %coord0 : index



More information about the Mlir-commits mailing list