[Mlir-commits] [mlir] [WIP][mlir][sparse] Setting up sparse_tensor.iterator-related Ops. (PR #85958)

Peiming Liu llvmlistbot at llvm.org
Tue Apr 9 11:48:57 PDT 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/85958

>From eb2b84b25232c639ddd0bee0da2c51fbdc413a59 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 18 Mar 2024 19:40:46 +0000
Subject: [PATCH 1/6] test parse iterate operation

---
 .../SparseTensor/IR/SparseTensorOps.td        | 17 +++++
 .../SparseTensor/IR/SparseTensorTypes.td      | 76 +++++++++++++++++++
 2 files changed, 93 insertions(+)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 5df8a176459b7c..dfa08167929492 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1433,6 +1433,23 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Operations.
+//===----------------------------------------------------------------------===//
+
+def IterateOp : SparseTensor_Op<"iterate",
+    [RecursiveMemoryEffects]> {
+
+  let arguments = (ins AnySparseIterSpace:$iterSpace,
+                       Variadic<AnyType>:$initArgs);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$region);
+
+  let extraClassDeclaration = [{}];
+
+  let hasCustomAssemblyFormat = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Debugging and Test-Only Operations.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
index 185cff46ae25d5..54a8e4d7ecd398 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
@@ -72,4 +72,80 @@ def SparseTensorStorageSpecifier
     : Type<CPred<"::llvm::isa<::mlir::sparse_tensor::StorageSpecifierType>($_self)">, "metadata",
           "::mlir::sparse_tensor::StorageSpecifierType">;
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Types.
+//===----------------------------------------------------------------------===//
+
+def LevelTypeArrayParameter : ArrayRefParameter<"::mlir::sparse_tensor::LevelType", "level-types"> {
+  let printer = [{
+    auto lvlStrings = llvm::map_range($_self, [](auto lt){ return lt.toMLIRString(); });
+    $_printer << "[" << llvm::join(lvlStrings, ",") << "]";
+  }];
+
+  let parser = [{ [&]() -> FailureOr<SmallVector<::mlir::sparse_tensor::LevelType>> {
+    SmallVector<::mlir::sparse_tensor::LevelType> ret;
+
+    const auto res = $_parser.parseCommaSeparatedList(
+      mlir::OpAsmParser::Delimiter::Square,
+      [&]() -> ParseResult {
+        ::mlir::sparse_tensor::ir_detail::LvlTypeParser lParser;
+        auto lvlTpOrFail = lParser.parseLvlType($_parser);
+        if (failed(lvlTpOrFail))
+          return failure();
+        ret.emplace_back(*lvlTpOrFail);
+        return success();
+      }, " in level-type list");
+
+    if (failed(res))
+      return failure();
+    return ret;
+  }() }];
+}
+
+def SparseTensor_Iterator : SparseTensor_Type<"Iterator"> {
+  let mnemonic = "iterator";
+
+  let parameters = (ins
+     LevelTypeArrayParameter: $lvlTypes
+  );
+
+  // let skipDefaultBuilders = 1;
+  // let hasCustomAssemblyFormat = 1;
+  let assemblyFormat="`<` $lvlTypes `>`";
+}
+
+
+def SparseTensor_IterSpace : SparseTensor_Type<"IterSpace"> {
+  let mnemonic = "iter_space";
+
+  let parameters = (ins
+     LevelTypeArrayParameter: $lvlTypes
+  );
+
+  let extraClassDeclaration = [{
+     ::mlir::sparse_tensor::IteratorType getIteratorType() const {
+        return IteratorType::get(getContext(), getLvlTypes());
+     }
+  }];
+
+  // let skipDefaultBuilders = 1;
+  // let hasCustomAssemblyFormat = 1;
+  let assemblyFormat="`<` $lvlTypes `>`";
+}
+
+def IsSparseSparseIterSpaceTypePred
+    : CPred<"::llvm::isa<::mlir::sparse_tensor::IterSpaceType>($_self)">;
+
+def IsSparseSparseIteratorTypePred
+    : CPred<"::llvm::isa<::mlir::sparse_tensor::IteratorType>($_self)">;
+
+def AnySparseIterSpace
+    : Type<IsSparseSparseIterSpaceTypePred, "sparse iteration space",
+          "::mlir::sparse_tensor::IterSpaceType">;
+
+def AnySparseIterator
+    : Type<IsSparseSparseIteratorTypePred, "sparse iterator",
+          "::mlir::sparse_tensor::IteratorType">;
+
+
 #endif // SPARSETENSOR_TYPES

>From b43f1ff4c3537649ef68450fdcd5d2654979f189 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 20 Mar 2024 16:14:12 +0000
Subject: [PATCH 2/6] test sparse space collapse

---
 .../Dialect/SparseTensor/IR/SparseTensor.h    |   2 +
 .../SparseTensor/IR/SparseTensorOps.td        |  41 ++-
 .../SparseTensor/IR/SparseTensorTypes.td      |  15 +-
 .../Dialect/SparseTensor/Transforms/Passes.h  |   6 +
 .../Dialect/SparseTensor/Transforms/Passes.td |  16 ++
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 251 ++++++++++++++++++
 .../SparseTensor/Transforms/CMakeLists.txt    |   1 +
 .../Transforms/SparseSpaceCollapse.cpp        | 152 +++++++++++
 8 files changed, 478 insertions(+), 6 deletions(-)
 create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 5e523ec428aefb..78692307820bc5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -17,7 +17,9 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/TensorEncoding.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index dfa08167929492..e1fa48d202a375 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -15,6 +15,8 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
 include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
 
 //===----------------------------------------------------------------------===//
 // Base class.
@@ -1437,16 +1439,51 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
 // Sparse Tensor Iteration Operations.
 //===----------------------------------------------------------------------===//
 
+def ExtractIterSpaceOp : SparseTensor_Op<"iteration.extract_space",
+    [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+
+  let arguments = (ins AnySparseTensor:$tensor,
+                       Optional<AnySparseIterator>:$parentIter,
+                       LevelAttr:$loLvl, LevelAttr:$hiLvl);
+
+  let results = (outs AnySparseIterSpace:$resultSpace);
+
+  let extraClassDeclaration = [{
+    unsigned getSpaceDim() {
+      return getHiLvl() - getLoLvl();
+    }
+  }];
+
+  let hasVerifier = 1;
+  let assemblyFormat = "$tensor (`at`$parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
+                       " attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
+}
+
 def IterateOp : SparseTensor_Op<"iterate",
-    [RecursiveMemoryEffects]> {
+    [RecursiveMemoryEffects, RecursivelySpeculatable,
+     DeclareOpInterfaceMethods<LoopLikeOpInterface,
+      ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
+      "getSingleInductionVar", "getYieldedValuesMutable"]>,
+     DeclareOpInterfaceMethods<RegionBranchOpInterface,
+      ["getEntrySuccessorOperands"]>,
+     SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
 
   let arguments = (ins AnySparseIterSpace:$iterSpace,
                        Variadic<AnyType>:$initArgs);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
 
-  let extraClassDeclaration = [{}];
+  let extraClassDeclaration = [{
+    BlockArgument getIterator() {
+      return getRegion().getArguments().front();
+    }
+    unsigned getNumRegionIterArgs() {
+      return getRegion().getArguments().size() - 1;
+    }
+  }];
 
+  let hasVerifier = 1;
+  let hasRegionVerifier = 1;
   let hasCustomAssemblyFormat = 1;
 }
 
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
index 54a8e4d7ecd398..aa674b613e71db 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
@@ -109,8 +109,13 @@ def SparseTensor_Iterator : SparseTensor_Type<"Iterator"> {
      LevelTypeArrayParameter: $lvlTypes
   );
 
-  // let skipDefaultBuilders = 1;
-  // let hasCustomAssemblyFormat = 1;
+  let extraClassDeclaration = [{
+     unsigned getSpaceDim() {
+       return getLvlTypes().size();
+     }
+  }];
+
+
   let assemblyFormat="`<` $lvlTypes `>`";
 }
 
@@ -123,13 +128,15 @@ def SparseTensor_IterSpace : SparseTensor_Type<"IterSpace"> {
   );
 
   let extraClassDeclaration = [{
+     unsigned getSpaceDim() {
+       return getLvlTypes().size();
+     }
+
      ::mlir::sparse_tensor::IteratorType getIteratorType() const {
         return IteratorType::get(getContext(), getLvlTypes());
      }
   }];
 
-  // let skipDefaultBuilders = 1;
-  // let hasCustomAssemblyFormat = 1;
   let assemblyFormat="`<` $lvlTypes `>`";
 }
 
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 61b07d222d156b..0e9f5120f7b3dc 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -247,6 +247,12 @@ std::unique_ptr<Pass> createSparsificationAndBufferizationPass(
     bool enableBufferInitialization, unsigned vectorLength,
     bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen);
 
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<Pass> createSparseSpaceCollapsePass();
+
 //===----------------------------------------------------------------------===//
 // Registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 58e2d6f32386c3..3ab75c23dbefa0 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -454,4 +454,20 @@ def SparsificationAndBufferization : Pass<"sparsification-and-bufferization", "M
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
+  let summary = "(experimental) sparse space collapsing pass";
+  let description = [{
+     This pass collapses consecutive sparse spaces (extracted from the same
+     tensor) into one multi-dimensional space.
+  }];
+  let constructor = "mlir::createSparseSpaceCollapsePass()";
+  let dependentDialects = [
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index e4d93c5623b9c4..f26f4a0a14707f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1909,6 +1909,257 @@ LogicalResult SortOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Operations.
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
+                                   IntegerAttr &lvlHiAttr) {
+  Level lvlLo, lvlHi;
+
+  if (parser.parseInteger(lvlLo))
+    return failure();
+
+  if (succeeded(parser.parseOptionalKeyword("to"))) {
+    if (parser.parseInteger(lvlHi))
+      return failure();
+  } else {
+    lvlHi = lvlLo + 1;
+  }
+  lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
+  lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
+  return success();
+}
+
+static void printLevelRange(OpAsmPrinter &p, ExtractIterSpaceOp op,
+                            IntegerAttr lvlLo, IntegerAttr lvlHi) {
+  if (op.getLoLvl() + 1 == op.getHiLvl())
+    p << op.getLoLvl();
+  else
+    p << op.getLoLvl() << " to " << op.getHiLvl();
+}
+
+LogicalResult ExtractIterSpaceOp::inferReturnTypes(
+    MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
+    DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
+    SmallVectorImpl<mlir::Type> &ret) {
+
+  ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
+  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
+  ArrayRef<LevelType> lts = stt.getLvlTypes().slice(
+      adaptor.getLoLvl(), adaptor.getHiLvl() - adaptor.getLoLvl());
+  ret.push_back(IterSpaceType::get(ctx, lts));
+  return success();
+}
+
+LogicalResult ExtractIterSpaceOp::verify() {
+  SparseTensorType stt = getSparseTensorType(getTensor());
+  if (getLoLvl() >= getHiLvl())
+    return emitOpError("expected smaller level low than level high");
+
+  ArrayRef<LevelType> lts = stt.getLvlTypes().slice(getLoLvl(), getSpaceDim());
+  if (!getResultSpace().getType().getLvlTypes().equals(lts))
+    return emitOpError(
+        "mismatch in iteration space level types and tensor level types");
+
+  // A parent iterator must be provided exactly when the level low is not 0.
+  TypedValue<IteratorType> pIter = getParentIter();
+  if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0))
+    return emitOpError(
+        "parent iterator should be specified iff level low is not 0");
+
+  if (pIter) {
+    unsigned pDim = pIter.getType().getSpaceDim();
+    if (getLoLvl() < pDim || !stt.getLvlTypes()
+                                  .slice(getLoLvl() - pDim, pDim)
+                                  .equals(pIter.getType().getLvlTypes())) {
+      return emitOpError(
+          "mismatch in parent iterator level types and tensor level types");
+    }
+  }
+
+  return success();
+}
+
+ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::Argument iterator;
+  OpAsmParser::UnresolvedOperand iterSpace;
+
+  // Parses `%iter in %space`.
+  if (parser.parseArgument(iterator) || parser.parseKeyword("in") ||
+      parser.parseOperand(iterSpace)) {
+    return failure();
+  }
+
+  // Parse the optional initial iteration arguments.
+  SmallVector<OpAsmParser::Argument> regionArgs;
+  SmallVector<OpAsmParser::UnresolvedOperand> operands;
+  // Region arguments start with the iterator, followed by the optional
+  // user-provided iter_args.
+  regionArgs.push_back(iterator);
+  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+  if (hasIterArgs)
+    if (parser.parseAssignmentList(regionArgs, operands))
+      return failure();
+
+  // Parse `: <iter_space type>`; any other type kind is diagnosed here.
+  IterSpaceType iterSpaceTps;
+  if (parser.parseColon() || parser.parseType(iterSpaceTps))
+    return failure();
+  if (hasIterArgs)
+    if (parser.parseArrowTypeList(result.types))
+      return failure();
+
+  if (regionArgs.size() != result.types.size() + 1) {
+    return parser.emitError(
+        parser.getNameLoc(),
+        "mismatch in number of loop-carried values and defined values");
+  }
+
+  // Resolves input operands.
+  if (parser.resolveOperand(iterSpace, iterSpaceTps, result.operands))
+    return failure();
+
+  if (hasIterArgs) {
+    for (auto argOperandType :
+         llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
+      Type type = std::get<2>(argOperandType);
+      std::get<0>(argOperandType).type = type;
+      if (parser.resolveOperand(std::get<1>(argOperandType), type,
+                                result.operands))
+        return failure();
+    }
+  }
+
+  Region *body = result.addRegion();
+  // The leading block argument is the iterator over the parsed space.
+  regionArgs.front().type = iterSpaceTps.getIteratorType();
+  if (parser.parseRegion(*body, regionArgs))
+    return failure();
+
+  IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  return success();
+}
+
+/// Prints the initialization list in the form of
+///   <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
+/// where 'inner' values are assumed to be region arguments and 'outer' values
+/// are regular SSA values.
+static void printInitializationList(OpAsmPrinter &p,
+                                    Block::BlockArgListType blocksArgs,
+                                    ValueRange initializers,
+                                    StringRef prefix = "") {
+  assert(blocksArgs.size() == initializers.size() &&
+         "expected same length of arguments and initializers");
+  if (initializers.empty())
+    return;
+
+  p << prefix << '(';
+  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
+    p << std::get<0>(it) << " = " << std::get<1>(it);
+  });
+  p << ")";
+}
+
+void IterateOp::print(OpAsmPrinter &p) {
+  p << " " << getIterator() << " in " << getIterSpace();
+
+  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
+
+  p << " : " << getIterSpace().getType() << " ";
+  if (!getInitArgs().empty())
+    p << "-> (" << getInitArgs().getTypes() << ") ";
+
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/!getInitArgs().empty());
+}
+
+LogicalResult IterateOp::verify() {
+  if (getInitArgs().size() != getNumResults()) {
+    return emitOpError(
+        "mismatch in number of loop-carried values and defined values");
+  }
+  return success();
+}
+
+LogicalResult IterateOp::verifyRegions() {
+  if (getIterator().getType() != getIterSpace().getType().getIteratorType())
+    return emitOpError("mismatch in iterator and iteration space type");
+  if (getNumRegionIterArgs() != getNumResults())
+    return emitOpError(
+        "mismatch in number of basic block args and defined values");
+
+  auto initArgs = getInitArgs();
+  auto iterArgs = getRegionIterArgs();
+  auto yieldVals = getYieldedValues();
+  auto opResults = getResults();
+  if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
+                        opResults.size()})) {
+    return emitOpError() << "number mismatch between iter args and results.";
+  }
+
+  unsigned i = 0;
+  for (auto e : llvm::zip_equal(initArgs, iterArgs, yieldVals, opResults)) {
+    if (std::get<0>(e).getType() != std::get<3>(e).getType())
+      return emitOpError() << "types mismatch between " << i
+                           << "th iter operand and defined value";
+    if (std::get<1>(e).getType() != std::get<3>(e).getType())
+      return emitOpError() << "types mismatch between " << i
+                           << "th iter region arg and defined value";
+    if (std::get<2>(e).getType() != std::get<3>(e).getType())
+      return emitOpError() << "types mismatch between " << i
+                           << "th yield value and defined value";
+
+    ++i;
+  }
+  return success();
+}
+
+/// IterateOp implemented interfaces' methods.
+SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
+
+std::optional<Value> IterateOp::getSingleInductionVar() {
+  return getIterator();
+}
+
+MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
+  return getInitArgsMutable();
+}
+
+Block::BlockArgListType IterateOp::getRegionIterArgs() {
+  return getRegion().getArguments().drop_front();
+}
+
+std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
+  return cast<sparse_tensor::YieldOp>(
+             getRegion().getBlocks().front().getTerminator())
+      .getResultMutable();
+}
+
+std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
+
+OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  return getInitArgs();
+}
+
+void IterateOp::getSuccessorRegions(RegionBranchPoint point,
+                                    SmallVectorImpl<RegionSuccessor> &regions) {
+  // Both the operation itself and the region may be branching into the body or
+  // back into the operation itself.
+  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
+  // It is possible for loop not to enter the body.
+  regions.push_back(RegionSuccessor(getResults()));
+}
+
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Dialect Setups.
+//===----------------------------------------------------------------------===//
+
 /// Materialize a single constant operation from a given attribute value with
 /// the desired resultant type.
 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index af3a1b48f45af9..2a29ee8a7a87cb 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   SparseGPUCodegen.cpp
   SparseReinterpretMap.cpp
   SparseStorageSpecifierToLLVM.cpp
+  SparseSpaceCollapse.cpp
   SparseTensorCodegen.cpp
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
new file mode 100644
index 00000000000000..f3207ede9585b4
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -0,0 +1,152 @@
+//===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
+namespace mlir {
+
+#define GEN_PASS_DEF_SPARSESPACECOLLAPSE
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
+
+namespace sparse_tensor {
+
+bool isCollapsableIterations(LoopLikeOpInterface parent,
+                             LoopLikeOpInterface node) {
+  auto pIterArgs = parent.getRegionIterArgs();
+  auto nInitArgs = node.getInits();
+  if (pIterArgs.size() != nInitArgs.size())
+    return false;
+
+  auto pYields = parent.getYieldedValues();
+  auto nResult = node.getLoopResults().value();
+
+  bool yieldEq =
+      llvm::all_of(llvm::zip_equal(pYields, nResult), [](auto zipped) {
+        return std::get<0>(zipped) == std::get<1>(zipped);
+      });
+
+  // Parent iter_args should be passed directly to the node's init_args.
+  bool iterArgEq =
+      llvm::all_of(llvm::zip_equal(pIterArgs, nInitArgs), [](auto zipped) {
+        return std::get<0>(zipped) == std::get<1>(zipped);
+      });
+
+  return yieldEq && iterArgEq;
+}
+
+bool legalToCollapse(ExtractIterSpaceOp parent, ExtractIterSpaceOp node) {
+  auto pItOp = llvm::dyn_cast<IterateOp>(parent->getParentOp());
+  auto nItOp = llvm::dyn_cast<IterateOp>(node->getParentOp());
+
+  // Can only collapse spaces extracted from the same tensor.
+  if (parent.getTensor() != node.getTensor() || !parent->hasOneUse())
+    return false;
+
+  // Can only collapse consecutive simple iteration on one tensor (i.e., no
+  // coiteration).
+  if (!nItOp || nItOp.getIterSpace() != parent.getResult() ||
+      nItOp->getBlock() != parent->getBlock())
+    return false;
+
+  if (pItOp && !isCollapsableIterations(pItOp, nItOp))
+    return false;
+
+  // TODO: Make sure all other operations in the same basic block as `node` can
+  // be collapsed and sink into the collapsed iteration (through Interfaces
+  // defined in TD files).
+  return true;
+}
+
+void collapseSparseSpace(ArrayRef<ExtractIterSpaceOp> toCollapse) {
+  if (toCollapse.size() < 2)
+    return;
+
+  ExtractIterSpaceOp root = toCollapse.front();
+  ExtractIterSpaceOp leaf = toCollapse.back();
+  Location loc = root.getLoc();
+
+  if (!leaf->hasOneUse())
+    return;
+  assert(root->hasOneUse());
+
+  auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
+  auto pItOp = llvm::cast<IterateOp>(leaf->getParentOp());
+
+  // This could either be IterateOp or (TODO: in the future) CoIterateOp.
+  // Bail out before creating any IR so that no dead ops are left behind.
+  auto loop = llvm::dyn_cast<IterateOp>(*leaf->getUsers().begin());
+  if (!loop || !isCollapsableIterations(pItOp, loop))
+    return;
+
+  // Insert collapsed operation at the same scope as root operation.
+  OpBuilder builder(toCollapse.front());
+  // Construct the collapsed iteration space.
+  auto collapsedSpace = builder.create<ExtractIterSpaceOp>(
+      loc, root.getTensor(), root.getParentIter(), root.getLoLvl(),
+      leaf.getHiLvl());
+
+  IRMapping mapper;
+  mapper.map(leaf, collapsedSpace.getResultSpace());
+  for (auto z : llvm::zip_equal(loop.getInitArgs(), rItOp.getInitArgs()))
+    mapper.map(std::get<0>(z), std::get<1>(z));
+
+  auto cloned = llvm::cast<IterateOp>(builder.clone(*loop, mapper));
+  cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
+
+  rItOp.replaceAllUsesWith(cloned.getResults());
+  // Erase collapsed loops.
+  rItOp.erase();
+  root.erase();
+}
+
+struct SparseSpaceCollapsePass
+    : public impl::SparseSpaceCollapseBase<SparseSpaceCollapsePass> {
+  SparseSpaceCollapsePass() = default;
+
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+
+    // A naive (experimental) implementation to collapse consecutive sparse
+    // spaces. It does NOT handle complex cases where multiple spaces are
+    // extracted in the same basic block. E.g.,
+    //
+    // %space1 = extract_space %t1 ...
+    // %space2 = extract_space %t2 ...
+    // sparse_tensor.iterate(%sp1) ...
+    //
+    SmallVector<ExtractIterSpaceOp> toCollapse;
+    func->walk([&](ExtractIterSpaceOp op) {
+      if (toCollapse.empty()) {
+        // Root space to collapse.
+        toCollapse.push_back(op);
+      } else {
+        if (legalToCollapse(toCollapse.back(), op)) {
+          toCollapse.push_back(op);
+        } else {
+          collapseSparseSpace(toCollapse);
+          toCollapse.clear();
+        }
+      }
+    });
+
+    collapseSparseSpace(toCollapse);
+  }
+};
+
+} // namespace sparse_tensor
+
+std::unique_ptr<Pass> createSparseSpaceCollapsePass() {
+  return std::make_unique<sparse_tensor::SparseSpaceCollapsePass>();
+}
+
+} // namespace mlir

>From afb0f9c141e0e8ca30a01024cdcfe4b8f770b3ab Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 22 Mar 2024 18:41:34 +0000
Subject: [PATCH 3/6] test collapsing coordinate extraction from iterator.

---
 .../SparseTensor/IR/SparseTensorInterfaces.h  |   2 +
 .../SparseTensor/IR/SparseTensorInterfaces.td |  15 ++
 .../SparseTensor/IR/SparseTensorOps.td        |  14 +-
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 155 ++++++++++++------
 .../Transforms/SparseSpaceCollapse.cpp        | 119 ++++++++++----
 5 files changed, 215 insertions(+), 90 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
index c0f31762ee071f..115e08b2cf8b14 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_
 
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
 
 namespace mlir {
 class PatternRewriter;
@@ -20,6 +21,7 @@ class StageWithSortSparseOp;
 namespace detail {
 LogicalResult stageWithSortImpl(sparse_tensor::StageWithSortSparseOp op,
                                 PatternRewriter &rewriter, Value &tmpBufs);
+
 } // namespace detail
 } // namespace sparse_tensor
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
index 05eed0483f2c8a..ee1c0b52b47e45 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
@@ -42,5 +42,20 @@ def StageWithSortSparseOpInterface : OpInterface<"StageWithSortSparseOp"> {
   ];
 }
 
+def SparseCollapsableOpInterface : OpInterface<"SparseCollapsableOp"> {
+  let description = [{ TODO }];
+
+  let cppNamespace = "::mlir::sparse_tensor";
+
+  let methods = [
+    InterfaceMethod<
+    /*desc=*/"test",
+    /*retTy=*/"ValueRange",
+    /*methodName=*/"collaspeOpInto",
+    /*args=*/(ins "::mlir::OpBuilder &":$builder,
+                  "::mlir::ArrayRef<::mlir::Operation *>":$loops,
+                  "::mlir::Operation *":$collapsed)>,
+  ];
+}
 
 #endif // SPARSETENSOR_IR_SPARSETENSORINTERFACES
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index e1fa48d202a375..cf7c7755c7427a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1282,7 +1282,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
 
 def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
     ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
-                 "ForeachOp"]>]>,
+                 "ForeachOp", "IterateOp"]>]>,
     Arguments<(ins Variadic<AnyType>:$results)> {
   let summary = "Yield from sparse_tensor set-like operations";
   let description = [{
@@ -1455,10 +1455,20 @@ def ExtractIterSpaceOp : SparseTensor_Op<"iteration.extract_space",
   }];
 
   let hasVerifier = 1;
-  let assemblyFormat = "$tensor (`at`$parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
+  let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
                        " attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
 }
 
+def CoordinateOp : SparseTensor_Op<"iteration.coordinate",
+    [Pure, DeclareOpInterfaceMethods<SparseCollapsableOpInterface>]> {
+  let arguments = (ins AnySparseIterator:$iterator);
+  let results = (outs Variadic<Index>:$crds);
+
+  let extraClassDeclaration = [{ }];
+  // let hasVerifier = 1;
+  let assemblyFormat = " $iterator attr-dict `:` type($iterator) `->` type($crds)";
+}
+
 def IterateOp : SparseTensor_Op<"iterate",
     [RecursiveMemoryEffects, RecursivelySpeculatable,
      DeclareOpInterfaceMethods<LoopLikeOpInterface,
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index f26f4a0a14707f..bd9f790443b5fd 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1926,17 +1926,82 @@ static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
   } else {
     lvlHi = lvlLo + 1;
   }
+
+  if (lvlHi <= lvlLo)
+    return parser.emitError(parser.getNameLoc(), "expect larger level upper "
+                                                 "bound than lower bound");
+
   lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
   lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
   return success();
 }
 
-static void printLevelRange(OpAsmPrinter &p, ExtractIterSpaceOp op,
-                            IntegerAttr lvlLo, IntegerAttr lvlHi) {
-  if (op.getLoLvl() + 1 == op.getHiLvl())
-    p << op.getLoLvl();
+static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
+                            IntegerAttr lvlHi) {
+  unsigned lo = lvlLo.getValue().getZExtValue();
+  unsigned hi = lvlHi.getValue().getZExtValue();
+  if (lo + 1 == hi)
+    p << lo;
   else
-    p << op.getLoLvl() << " to " << op.getHiLvl();
+    p << lo << " to " << hi;
+}
+
+ParseResult
+parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
+                     SmallVectorImpl<OpAsmParser::Argument> &iterators,
+                     SmallVectorImpl<OpAsmParser::Argument> &iterArgs) {
+  SmallVector<OpAsmParser::UnresolvedOperand> spaces;
+  SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
+  // Parses "%iters, ... in %spaces, ..."
+  if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
+      parser.parseOperandList(spaces))
+    return failure();
+
+  if (iterators.size() != spaces.size())
+    return parser.emitError(
+        parser.getNameLoc(),
+        "mismatch in number of sparse iterators and sparse spaces");
+
+  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+  if (hasIterArgs)
+    if (parser.parseAssignmentList(iterArgs, initArgs))
+      return failure();
+
+  SmallVector<Type> iterSpaceTps;
+  // parse ": sparse_tensor.iter_space -> ret"
+  if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
+    return failure();
+  if (iterSpaceTps.size() != spaces.size())
+    return parser.emitError(parser.getNameLoc(),
+                            "mismatch in number of iteration space operands "
+                            "and iteration space types");
+
+  for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
+    IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
+    if (!spaceTp)
+      return parser.emitError(parser.getNameLoc(),
+                              "expected sparse_tensor.iter_space type for "
+                              "iteration space operands");
+    it.type = spaceTp.getIteratorType();
+  }
+
+  if (hasIterArgs && parser.parseArrowTypeList(state.types))
+    return failure();
+  if (initArgs.size() != state.types.size())
+    return parser.emitError(parser.getNameLoc(), "mismatch in number of "
+                                                 "iter_args and results");
+  if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
+                             state.operands))
+    return failure();
+
+  if (hasIterArgs) {
+    for (auto [it, init, tp] : llvm::zip(iterArgs, initArgs, state.types)) {
+      it.type = tp;
+      if (parser.resolveOperand(init, tp, state.operands))
+        return failure();
+    }
+  }
+  return success();
 }
 
 LogicalResult ExtractIterSpaceOp::inferReturnTypes(
@@ -1981,60 +2046,45 @@ LogicalResult ExtractIterSpaceOp::verify() {
   return success();
 }
 
+ValueRange CoordinateOp::collaspeOpInto(OpBuilder &builder,
+                                        ArrayRef<Operation *> loops,
+                                        Operation *collapsed) {
+  assert(llvm::all_of(loops,
+                      [](Operation *l) { return llvm::isa<IterateOp>(l); }));
+  auto finalLoop = llvm::cast<IterateOp>(collapsed);
+  SmallVector<Type> retTps(finalLoop.getIterSpace().getType().getSpaceDim(),
+                           builder.getIndexType());
+  auto collapsedCoords =
+      builder.create<CoordinateOp>(getLoc(), retTps, finalLoop.getIterator());
+
+  for (Operation *l : loops) {
+    if (getIterator().getParentBlock()->getParentOp() == l) {
+      auto space = llvm::cast<IterateOp>(l)
+                       .getIterSpace()
+                       .getDefiningOp<ExtractIterSpaceOp>();
+
+      return collapsedCoords.getResults().slice(space.getLoLvl(),
+                                                space.getSpaceDim());
+    }
+  }
+  llvm_unreachable(
+      "Can not find the corresponding iterate space for the collapsable op.");
+}
+
 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
   OpAsmParser::Argument iterator;
   OpAsmParser::UnresolvedOperand iterSpace;
 
-  // Parses %iters in %spaces
-  if (parser.parseArgument(iterator) || parser.parseKeyword("in") ||
-      parser.parseOperand(iterSpace)) {
-    return failure();
-  }
-
-  // Parse the optional initial iteration arguments.
-  SmallVector<OpAsmParser::Argument> regionArgs;
-  SmallVector<OpAsmParser::UnresolvedOperand> operands;
-  // Region arguments starts with iterators and follows by optional
-  // user-provided iter_args.
-  regionArgs.push_back(iterator);
-  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
-  if (hasIterArgs)
-    if (parser.parseAssignmentList(regionArgs, operands))
-      return failure();
-
-  // parse ": sparse_tensor.iter_space -> ret"
-  Type iterSpaceTps;
-  if (parser.parseColon() || parser.parseType(iterSpaceTps))
+  SmallVector<OpAsmParser::Argument> iters, iterArgs;
+  if (parseSparseSpaceLoop(parser, result, iters, iterArgs))
     return failure();
-  if (hasIterArgs)
-    if (parser.parseArrowTypeList(result.types))
-      return failure();
-
-  if (regionArgs.size() != result.types.size() + 1) {
-    return parser.emitError(
-        parser.getNameLoc(),
-        "mismatch in number of loop-carried values and defined values");
-  }
-
-  // Resolves input operands.
-  if (parser.resolveOperand(iterSpace, iterSpaceTps, result.operands))
-    return failure();
-
-  if (hasIterArgs) {
-    for (auto argOperandType :
-         llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
-      Type type = std::get<2>(argOperandType);
-      std::get<0>(argOperandType).type = type;
-      if (parser.resolveOperand(std::get<1>(argOperandType), type,
-                                result.operands))
-        return failure();
-    }
-  }
+  if (iters.size() != 1)
+    return parser.emitError(parser.getNameLoc(),
+                            "expected only one iterator/iteration space");
 
+  iters.append(iterArgs);
   Region *body = result.addRegion();
-  regionArgs.front().type =
-      iterSpaceTps.cast<IterSpaceType>().getIteratorType();
-  if (parser.parseRegion(*body, regionArgs))
+  if (parser.parseRegion(*body, iters))
     return failure();
 
   IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
@@ -2068,7 +2118,6 @@ static void printInitializationList(OpAsmPrinter &p,
 
 void IterateOp::print(OpAsmPrinter &p) {
   p << " " << getIterator() << " in " << getIterSpace();
-
   printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
 
   p << " : " << getIterSpace().getType() << " ";
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
index f3207ede9585b4..752b2dfc2a0070 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -20,8 +20,14 @@ namespace mlir {
 
 namespace sparse_tensor {
 
-bool isCollapsableIterations(LoopLikeOpInterface parent,
-                             LoopLikeOpInterface node) {
+struct CollapseSpaceInfo {
+  ExtractIterSpaceOp space;
+  // Coiteration as well (if make sense)?
+  IterateOp loop;
+  SmallVector<SparseCollapsableOp> collapseOps;
+};
+
+bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) {
   auto pIterArgs = parent.getRegionIterArgs();
   auto nInitArgs = node.getInits();
   if (pIterArgs.size() != nInitArgs.size())
@@ -44,43 +50,81 @@ bool isCollapsableIterations(LoopLikeOpInterface parent,
   return yieldEq && iterArgEq;
 }
 
-bool legalToCollapse(ExtractIterSpaceOp parent, ExtractIterSpaceOp node) {
-  auto pItOp = llvm::dyn_cast<IterateOp>(parent->getParentOp());
-  auto nItOp = llvm::dyn_cast<IterateOp>(node->getParentOp());
+bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
+                     ExtractIterSpaceOp curSpace) {
+
+  auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp {
+    Value spaceVal = space.getResultSpace();
+    if (spaceVal.hasOneUse())
+      return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin());
+    return nullptr;
+  };
+
+  if (toCollapse.empty()) {
+    // Collapse root.
+    if (auto itOp = getIterateOpOverSpace(curSpace)) {
+      CollapseSpaceInfo &info = toCollapse.emplace_back();
+      info.space = curSpace;
+      info.loop = itOp;
+      // No operations need to be collapsed at the root level;
+      info.collapseOps = {};
+      return true;
+    }
+    return false;
+  }
+
+  auto parent = toCollapse.back().space;
+  auto pItOp = toCollapse.back().loop;
+  auto nItOp = getIterateOpOverSpace(curSpace);
 
   // Can only collapse spaces extracted from the same tensor.
-  if (parent.getTensor() != node.getTensor() || !parent->hasOneUse())
+  if (parent.getTensor() != curSpace.getTensor())
     return false;
 
   // Can only collapse consecutive simple iteration on one tensor (i.e., no
   // coiteration).
-  if (!nItOp || nItOp.getIterSpace() != parent.getResult() ||
-      nItOp->getBlock() != parent->getBlock())
+  if (!nItOp || nItOp->getBlock() != curSpace->getBlock() ||
+      pItOp.getIterator() != curSpace.getParentIter() ||
+      curSpace->getParentOp() != pItOp.getOperation())
     return false;
 
-  if (pItOp && !isCollapsableIterations(pItOp, nItOp))
+  if (pItOp && !isCollapsableLoops(pItOp, nItOp))
     return false;
 
   // TODO: Make sure all other operations in the same basic block as `node` can
   // be collapsed and sink into the collapsed iteration (through Interfaces
   // defined in TD files).
+  SmallVector<SparseCollapsableOp> collapsableOps;
+  for (Operation &op : *pItOp.getBody()) {
+    if (&op == curSpace.getOperation() || &op == nItOp.getOperation() ||
+        &op == pItOp.getBody()->getTerminator())
+      continue;
+    // All other ops in parent loop need to be collapsable.
+    auto collapsableOp = llvm::dyn_cast<SparseCollapsableOp>(&op);
+    if (!collapsableOp)
+      return false;
+    collapsableOps.push_back(collapsableOp);
+  }
+
+  CollapseSpaceInfo &info = toCollapse.emplace_back();
+  info.space = curSpace;
+  info.loop = nItOp;
+  info.collapseOps = std::move(collapsableOps);
   return true;
 }
 
-void collapseSparseSpace(ArrayRef<ExtractIterSpaceOp> toCollapse) {
+void collapseSparseSpace(SmallVectorImpl<CollapseSpaceInfo> &toCollapse) {
   if (toCollapse.size() < 2)
     return;
 
-  ExtractIterSpaceOp root = toCollapse.front();
-  ExtractIterSpaceOp leaf = toCollapse.back();
+  ExtractIterSpaceOp root = toCollapse.front().space;
+  ExtractIterSpaceOp leaf = toCollapse.back().space;
   Location loc = root.getLoc();
 
-  if (!leaf->hasOneUse())
-    return;
-  assert(root->hasOneUse());
+  assert(root->hasOneUse() && leaf->hasOneUse());
 
   // Insert collapsed operation at the same scope as root operation.
-  OpBuilder builder(toCollapse.front());
+  OpBuilder builder(root);
 
   // Construct the collapsed iteration space.
   auto collapsedSpace = builder.create<ExtractIterSpaceOp>(
@@ -88,19 +132,29 @@ void collapseSparseSpace(ArrayRef<ExtractIterSpaceOp> toCollapse) {
       leaf.getHiLvl());
 
   auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
-  auto pItOp = llvm::cast<IterateOp>(leaf->getParentOp());
-
-  // This could either be IterateOp or (TODO: in the future) CoIterateOp.
-  auto loop = llvm::dyn_cast<IterateOp>(*leaf->getUsers().begin());
-  if (!loop || !isCollapsableIterations(pItOp, loop))
-    return;
+  auto innermost = toCollapse.back().loop;
 
   IRMapping mapper;
   mapper.map(leaf, collapsedSpace.getResultSpace());
-  for (auto z : llvm::zip_equal(loop.getInitArgs(), rItOp.getInitArgs()))
+  for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs()))
     mapper.map(std::get<0>(z), std::get<1>(z));
 
-  auto cloned = llvm::cast<IterateOp>(builder.clone(*loop, mapper));
+  auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper));
+  builder.setInsertionPointToStart(cloned.getBody());
+  SmallVector<Operation *> loops =
+      llvm::map_to_vector(toCollapse, [](CollapseSpaceInfo &info) {
+        return info.loop.getOperation();
+      });
+
+  for (const CollapseSpaceInfo &info : toCollapse) {
+    for (SparseCollapsableOp op : info.collapseOps) {
+      ValueRange colVals = op.collaspeOpInto(builder, loops, cloned);
+      for (auto [o, r] : llvm::zip(op->getResults(), colVals))
+        o.replaceAllUsesWith(r);
+      op.erase();
+    }
+  }
+
   cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
 
   rItOp.replaceAllUsesWith(cloned.getResults());
@@ -124,18 +178,13 @@ struct SparseSpaceCollapsePass
     // %space2 = extract_space %t2 ...
     // sparse_tensor.iterate(%sp1) ...
     //
-    SmallVector<ExtractIterSpaceOp> toCollapse;
+    SmallVector<CollapseSpaceInfo> toCollapse;
     func->walk([&](ExtractIterSpaceOp op) {
-      if (toCollapse.empty()) {
-        // Root space to collapse.
-        toCollapse.push_back(op);
-      } else {
-        if (legalToCollapse(toCollapse.back(), op)) {
-          toCollapse.push_back(op);
-        } else {
-          collapseSparseSpace(toCollapse);
-          toCollapse.clear();
-        }
+      if (!legalToCollapse(toCollapse, op)) {
+        // `op` cannot extend the current chain: collapse the spaces gathered
+        // so far and reset. NOTE(review): `op` is not retried as a new root.
+        collapseSparseSpace(toCollapse);
+        toCollapse.clear();
       }
     });
 

>From 2c6221cc65d044f9f513c33b677edef10f66f4dc Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 25 Mar 2024 22:41:52 +0000
Subject: [PATCH 4/6] fuse crds into iterate operation.

---
 .../Dialect/SparseTensor/IR/SparseTensor.h    |  36 ++++++
 .../SparseTensor/IR/SparseTensorAttrDefs.td   |  15 +++
 .../SparseTensor/IR/SparseTensorOps.td        |  30 +++--
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 119 +++++++++++++-----
 .../Transforms/SparseSpaceCollapse.cpp        |  44 ++-----
 5 files changed, 170 insertions(+), 74 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 78692307820bc5..081a9b8cad8d62 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -22,6 +22,8 @@
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+#include "llvm/ADT/bit.h"
+
 //===----------------------------------------------------------------------===//
 //
 // Type aliases to help code be more self-documenting. Unfortunately
@@ -43,6 +45,40 @@ using Level = uint64_t;
 /// including the value `ShapedType::kDynamic` (for shapes).
 using Size = int64_t;
 
+/// A bitset of at most 64 levels: bit `i` is set iff level `i` is a member.
+class LevelSet {
+  uint64_t bits = 0;
+
+public:
+  LevelSet() = default;
+  explicit LevelSet(uint64_t bits) : bits(bits) {}
+  operator uint64_t() const { return bits; }
+
+  LevelSet &set(unsigned i) {
+    assert(i < 64);
+    bits |= uint64_t(1) << i; // 64-bit shift: `1 << i` is UB for i >= 31.
+    return *this;
+  }
+
+  LevelSet &operator|=(LevelSet rhs) {
+    bits |= static_cast<uint64_t>(rhs);
+    return *this;
+  }
+
+  LevelSet &lshift(unsigned offset) {
+    bits = offset < 64 ? bits << offset : 0; // Shifting by >= 64 is UB.
+    return *this;
+  }
+
+  bool operator[](unsigned i) const {
+    assert(i < 64);
+    return (bits & (uint64_t(1) << i)) != 0;
+  }
+
+  unsigned count() const { return llvm::popcount(bits); }
+  bool empty() const { return bits == 0; }
+};
+
 } // namespace sparse_tensor
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index d3be8a3009ba1e..36c075f52f8e5b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -19,6 +19,21 @@ class SparseTensor_Attr<string name,
                         list<Trait> traits = []>
     : AttrDef<SparseTensor_Dialect, name, traits>;
 
+//===----------------------------------------------------------------------===//
+// A simple bitset attribute wrapped over a single int64_t to encode a set of
+// sparse tensor levels.
+//===----------------------------------------------------------------------===//
+
+def LevelSetAttr :
+    TypedAttrBase<
+      I64, "IntegerAttr",
+      And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
+           CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType().isInteger(64)">]>,
+      "LevelSet attribute"> {
+  let returnType = [{::mlir::sparse_tensor::LevelSet}];
+  let convertFromStorage = [{::mlir::sparse_tensor::LevelSet($_self.getValue().getZExtValue())}];
+}
+
 //===----------------------------------------------------------------------===//
 // These attributes are just like `IndexAttr` except that they clarify whether
 // the index refers to a dimension (an axis of the semantic tensor) or a level
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index cf7c7755c7427a..a734262f141079 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1459,36 +1459,42 @@ def ExtractIterSpaceOp : SparseTensor_Op<"iteration.extract_space",
                        " attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
 }
 
-def CoordinateOp : SparseTensor_Op<"iteration.coordinate",
-    [Pure, DeclareOpInterfaceMethods<SparseCollapsableOpInterface>]> {
-  let arguments = (ins AnySparseIterator:$iterator);
-  let results = (outs Variadic<Index>:$crds);
-
-  let extraClassDeclaration = [{ }];
-  // let hasVerifier = 1;
-  let assemblyFormat = " $iterator attr-dict `:` type($iterator) `->` type($crds)";
-}
+// def CoordinateOp : SparseTensor_Op<"iteration.coordinate",
+//     [Pure, DeclareOpInterfaceMethods<SparseCollapsableOpInterface>]> {
+//   let arguments = (ins AnySparseIterator:$iterator);
+//   let results = (outs Variadic<Index>:$crds);
+//   let extraClassDeclaration = [{ }];
+//   // let hasVerifier = 1;
+//   let assemblyFormat = " $iterator attr-dict `:` type($iterator) `->` type($crds)";
+// }
 
 def IterateOp : SparseTensor_Op<"iterate",
     [RecursiveMemoryEffects, RecursivelySpeculatable,
      DeclareOpInterfaceMethods<LoopLikeOpInterface,
       ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
-      "getSingleInductionVar", "getYieldedValuesMutable"]>,
+       "getYieldedValuesMutable"]>,
      DeclareOpInterfaceMethods<RegionBranchOpInterface,
       ["getEntrySuccessorOperands"]>,
      SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
 
   let arguments = (ins AnySparseIterSpace:$iterSpace,
-                       Variadic<AnyType>:$initArgs);
+                       Variadic<AnyType>:$initArgs,
+                       LevelSetAttr:$crdUsedLvls);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
 
   let extraClassDeclaration = [{
+    unsigned getSpaceDim() {
+      return getIterSpace().getType().getSpaceDim();
+    }
     BlockArgument getIterator() {
       return getRegion().getArguments().front();
     }
+    Block::BlockArgListType getCrds() {
+      return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
+    }
     unsigned getNumRegionIterArgs() {
-      return getRegion().getArguments().size() - 1;
+      return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
     }
   }];
 
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index bd9f790443b5fd..6c047e2a79248e 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -42,6 +42,7 @@ namespace mlir::sparse_tensor {
 llvm::hash_code hash_value(LevelType lt) {
   return llvm::hash_value(static_cast<uint64_t>(lt));
 }
+
 } // namespace mlir::sparse_tensor
 
 //===----------------------------------------------------------------------===//
@@ -1946,12 +1947,13 @@ static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
     p << lo << " to " << hi;
 }
 
-ParseResult
+static ParseResult
 parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
                      SmallVectorImpl<OpAsmParser::Argument> &iterators,
                      SmallVectorImpl<OpAsmParser::Argument> &iterArgs) {
   SmallVector<OpAsmParser::UnresolvedOperand> spaces;
   SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
+
   // Parses "%iters, ... in %spaces, ..."
   if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
       parser.parseOperandList(spaces))
@@ -1962,6 +1964,34 @@ parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
         parser.getNameLoc(),
         "mismatch in number of sparse iterators and sparse spaces");
 
+  // Parse "at(%crd0, _, ...)"
+  LevelSet crdUsedLvlSet;
+  bool hasUsedCrds = succeeded(parser.parseOptionalKeyword("at"));
+  unsigned lvlCrdCnt = 0;
+  if (hasUsedCrds) {
+    ParseResult crdList = parser.parseCommaSeparatedList(
+        OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
+          if (parser.parseOptionalKeyword("_")) {
+            if (parser.parseArgument(iterArgs.emplace_back()))
+              return failure();
+            // Always use IndexType for the coordinate.
+            crdUsedLvlSet.set(lvlCrdCnt);
+            iterArgs.back().type = parser.getBuilder().getIndexType();
+          }
+          lvlCrdCnt += 1;
+          return success();
+        });
+    if (failed(crdList)) {
+      return parser.emitError(
+          parser.getNameLoc(),
+          "expecting SSA value or \"_\" for level coordinates");
+    }
+  }
+  // Set the CrdUsedLvl bitset.
+  state.addAttribute("crdUsedLvls",
+                     parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
+
+  // Parse "iter_args(%arg = %init, ...)"
   bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
   if (hasIterArgs)
     if (parser.parseAssignmentList(iterArgs, initArgs))
@@ -1982,6 +2012,10 @@ parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
       return parser.emitError(parser.getNameLoc(),
                               "expected sparse_tensor.iter_space type for "
                               "iteration space operands");
+    if (hasUsedCrds && spaceTp.getSpaceDim() != lvlCrdCnt)
+      return parser.emitError(parser.getNameLoc(),
+                              "mismatch in number of iteration space dimension "
+                              "and specified coordinates");
     it.type = spaceTp.getIteratorType();
   }
 
@@ -1995,7 +2029,10 @@ parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
     return failure();
 
   if (hasIterArgs) {
-    for (auto [it, init, tp] : llvm::zip(iterArgs, initArgs, state.types)) {
+    unsigned numCrds = crdUsedLvlSet.count();
+    // Skip the leading args bound to coordinates; the rest are iter_args.
+    MutableArrayRef args = MutableArrayRef(iterArgs).drop_front(numCrds);
+    for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
       it.type = tp;
       if (parser.resolveOperand(init, tp, state.operands))
         return failure();
@@ -2046,30 +2083,32 @@ LogicalResult ExtractIterSpaceOp::verify() {
   return success();
 }
 
-ValueRange CoordinateOp::collaspeOpInto(OpBuilder &builder,
-                                        ArrayRef<Operation *> loops,
-                                        Operation *collapsed) {
-  assert(llvm::all_of(loops,
-                      [](Operation *l) { return llvm::isa<IterateOp>(l); }));
-  auto finalLoop = llvm::cast<IterateOp>(collapsed);
-  SmallVector<Type> retTps(finalLoop.getIterSpace().getType().getSpaceDim(),
-                           builder.getIndexType());
-  auto collapsedCoords =
-      builder.create<CoordinateOp>(getLoc(), retTps, finalLoop.getIterator());
-
-  for (Operation *l : loops) {
-    if (getIterator().getParentBlock()->getParentOp() == l) {
-      auto space = llvm::cast<IterateOp>(l)
-                       .getIterSpace()
-                       .getDefiningOp<ExtractIterSpaceOp>();
-
-      return collapsedCoords.getResults().slice(space.getLoLvl(),
-                                                space.getSpaceDim());
-    }
-  }
-  llvm_unreachable(
-      "Can not find the corresponding iterate space for the collapsable op.");
-}
+// ValueRange CoordinateOp::collaspeOpInto(OpBuilder &builder,
+//                                         ArrayRef<Operation *> loops,
+//                                         Operation *collapsed) {
+//   assert(llvm::all_of(loops,
+//                       [](Operation *l) { return llvm::isa<IterateOp>(l); }));
+//   auto finalLoop = llvm::cast<IterateOp>(collapsed);
+//   SmallVector<Type> retTps(finalLoop.getIterSpace().getType().getSpaceDim(),
+//                            builder.getIndexType());
+//   auto collapsedCoords =
+//       builder.create<CoordinateOp>(getLoc(), retTps,
+//       finalLoop.getIterator());
+
+//   for (Operation *l : loops) {
+//     if (getIterator().getParentBlock()->getParentOp() == l) {
+//       auto space = llvm::cast<IterateOp>(l)
+//                        .getIterSpace()
+//                        .getDefiningOp<ExtractIterSpaceOp>();
+
+//       return collapsedCoords.getResults().slice(space.getLoLvl(),
+//                                                 space.getSpaceDim());
+//     }
+//   }
+//   llvm_unreachable(
+//       "Can not find the corresponding iterate space for the collapsable
+//       op.");
+// }
 
 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
   OpAsmParser::Argument iterator;
@@ -2116,8 +2155,30 @@ static void printInitializationList(OpAsmPrinter &p,
   p << ")";
 }
 
+static void printUsedCrdsList(OpAsmPrinter &p, unsigned spaceDim,
+                              Block::BlockArgListType blocksArgs,
+                              LevelSet crdUsedLvls) {
+  if (crdUsedLvls.empty())
+    return;
+
+  p << " at(";
+  for (unsigned i = 0; i < spaceDim; i++) {
+    if (crdUsedLvls[i]) {
+      p << blocksArgs.front();
+      blocksArgs = blocksArgs.drop_front();
+    } else {
+      p << "_";
+    }
+    if (i != spaceDim - 1)
+      p << ", ";
+  }
+  assert(blocksArgs.empty());
+  p << ")";
+}
+
 void IterateOp::print(OpAsmPrinter &p) {
   p << " " << getIterator() << " in " << getIterSpace();
+  printUsedCrdsList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
   printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
 
   p << " : " << getIterSpace().getType() << " ";
@@ -2172,16 +2233,12 @@ LogicalResult IterateOp::verifyRegions() {
 /// IterateOp implemented interfaces' methods.
 SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
 
-std::optional<Value> IterateOp::getSingleInductionVar() {
-  return getIterator();
-}
-
 MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
   return getInitArgsMutable();
 }
 
 Block::BlockArgListType IterateOp::getRegionIterArgs() {
-  return getRegion().getArguments().drop_front();
+  return getRegion().getArguments().take_back(getNumRegionIterArgs());
 }
 
 std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
index 752b2dfc2a0070..39c9a9292c9be9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -24,7 +24,6 @@ struct CollapseSpaceInfo {
   ExtractIterSpaceOp space;
   // Coiteration as well (if make sense)?
   IterateOp loop;
-  SmallVector<SparseCollapsableOp> collapseOps;
 };
 
 bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) {
@@ -66,8 +65,6 @@ bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
       CollapseSpaceInfo &info = toCollapse.emplace_back();
       info.space = curSpace;
       info.loop = itOp;
-      // No operations need to be collapsed at the root level;
-      info.collapseOps = {};
       return true;
     }
     return false;
@@ -91,29 +88,13 @@ bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
   if (pItOp && !isCollapsableLoops(pItOp, nItOp))
     return false;
 
-  // TODO: Make sure all other operations in the same basic block as `node` can
-  // be collapsed and sink into the collapsed iteration (through Interfaces
-  // defined in TD files).
-  SmallVector<SparseCollapsableOp> collapsableOps;
-  for (Operation &op : *pItOp.getBody()) {
-    if (&op == curSpace.getOperation() || &op == nItOp.getOperation() ||
-        &op == pItOp.getBody()->getTerminator())
-      continue;
-    // All other ops in parent loop need to be collapsable.
-    auto collapsableOp = llvm::dyn_cast<SparseCollapsableOp>(&op);
-    if (!collapsableOp)
-      return false;
-    collapsableOps.push_back(collapsableOp);
-  }
-
   CollapseSpaceInfo &info = toCollapse.emplace_back();
   info.space = curSpace;
   info.loop = nItOp;
-  info.collapseOps = std::move(collapsableOps);
   return true;
 }
 
-void collapseSparseSpace(SmallVectorImpl<CollapseSpaceInfo> &toCollapse) {
+void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) {
   if (toCollapse.size() < 2)
     return;
 
@@ -141,21 +122,22 @@ void collapseSparseSpace(SmallVectorImpl<CollapseSpaceInfo> &toCollapse) {
 
   auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper));
   builder.setInsertionPointToStart(cloned.getBody());
-  SmallVector<Operation *> loops =
-      llvm::map_to_vector(toCollapse, [](CollapseSpaceInfo &info) {
-        return info.loop.getOperation();
-      });
 
-  for (const CollapseSpaceInfo &info : toCollapse) {
-    for (SparseCollapsableOp op : info.collapseOps) {
-      ValueRange colVals = op.collaspeOpInto(builder, loops, cloned);
-      for (auto [o, r] : llvm::zip(op->getResults(), colVals))
-        o.replaceAllUsesWith(r);
-      op.erase();
+  LevelSet crdUsedLvls;
+  unsigned shift = 0, argIdx = 1;
+  for (auto info : toCollapse.drop_back()) {
+    LevelSet set = info.loop.getCrdUsedLvls();
+    crdUsedLvls |= set.lshift(shift);
+    shift += info.loop.getSpaceDim();
+    for (BlockArgument crd : info.loop.getCrds()) {
+      BlockArgument collapsedCrd = cloned.getBody()->insertArgument(
+          argIdx++, builder.getIndexType(), crd.getLoc());
+      crd.replaceAllUsesWith(collapsedCrd);
     }
   }
-
+  crdUsedLvls |= innermost.getCrdUsedLvls().lshift(shift);
   cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
+  cloned.setCrdUsedLvls(crdUsedLvls);
 
   rItOp.replaceAllUsesWith(cloned.getResults());
   // Erase collapsed loops.

>From 592636c3a93a2516fea6c7ca9716d5f426be9689 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 28 Mar 2024 16:32:38 +0000
Subject: [PATCH 5/6] setup lowering passes

---
 .../mlir/Dialect/SparseTensor/IR/Enums.h      |  4 +
 .../SparseTensor/IR/SparseTensorOps.td        | 18 +++--
 .../Dialect/SparseTensor/Transforms/Passes.h  | 16 ++++
 .../Dialect/SparseTensor/Transforms/Passes.td | 13 ++++
 .../SparseTensor/Transforms/CMakeLists.txt    |  1 +
 .../Transforms/SparseIterationToScf.cpp       | 76 +++++++++++++++++++
 .../Transforms/SparseTensorPasses.cpp         | 27 +++++++
 .../Transforms/Utils/SparseTensorIterator.cpp | 55 ++++++++------
 .../Transforms/Utils/SparseTensorIterator.h   |  5 ++
 9 files changed, 183 insertions(+), 32 deletions(-)
 create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 33f613a46bad84..96ee7111fea2cf 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -357,6 +357,10 @@ struct LevelType {
     return hasSparseSemantic();
   }
 
+  constexpr unsigned getNumBuffer() const {
+    return hasDenseSemantic() ? 0 : (isWithPosLT() ? 2 : 1);
+  }
+
   std::string toMLIRString() const {
     std::string lvlStr = toFormatString(getLvlFmt());
     std::string propStr = "";
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index a734262f141079..8f65637a23406c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -263,7 +263,7 @@ def SparseTensor_ReinterpretMapOp : SparseTensor_Op<"reinterpret_map", [NoMemory
 }
 
 def SparseTensor_ToPositionsOp : SparseTensor_Op<"positions",
-      [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
+      [Pure, AlwaysSpeculatable, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
     Arguments<(ins AnySparseTensor:$tensor, LevelAttr:$level)>,
     Results<(outs AnyNon0RankedMemRef:$result)> {
   let summary = "Extracts the `level`-th positions array of the `tensor`";
@@ -285,12 +285,12 @@ def SparseTensor_ToPositionsOp : SparseTensor_Op<"positions",
        : tensor<64x64xf64, #CSR> to memref<?xindex>
     ```
   }];
-  let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($result)";
+  let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($result)"  ;
   let hasVerifier = 1;
 }
 
 def SparseTensor_ToCoordinatesOp : SparseTensor_Op<"coordinates",
-      [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
+      [Pure, AlwaysSpeculatable, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
     Arguments<(ins AnySparseTensor:$tensor, LevelAttr:$level)>,
     Results<(outs AnyNon0RankedMemRef:$result)> {
   let summary = "Extracts the `level`-th coordinates array of the `tensor`";
@@ -317,7 +317,7 @@ def SparseTensor_ToCoordinatesOp : SparseTensor_Op<"coordinates",
 }
 
 def SparseTensor_ToCoordinatesBufferOp : SparseTensor_Op<"coordinates_buffer",
-      [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
+      [Pure, AlwaysSpeculatable, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
     Arguments<(ins AnySparseTensor:$tensor)>,
     Results<(outs AnyNon0RankedMemRef:$result)> {
   let summary = "Extracts the linear coordinates array from a tensor";
@@ -340,8 +340,7 @@ def SparseTensor_ToCoordinatesBufferOp : SparseTensor_Op<"coordinates_buffer",
     Example:
 
     ```mlir
-    %1 = sparse_tensor.coordinates_buffer %0
-       : tensor<64x64xf64, #COO> to memref<?xindex>
+    %1 = sparse_tensor.coordinates_buffer %0 : tensor<64x64xf64, #COO> to memref<?xindex>
     ```
   }];
   let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($result)";
@@ -349,7 +348,7 @@ def SparseTensor_ToCoordinatesBufferOp : SparseTensor_Op<"coordinates_buffer",
 }
 
 def SparseTensor_ToValuesOp : SparseTensor_Op<"values",
-      [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
+      [Pure, AlwaysSpeculatable, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
     Arguments<(ins AnySparseTensor:$tensor)>,
     Results<(outs AnyNon0RankedMemRef:$result)> {
   let summary = "Extracts numerical values array from a tensor";
@@ -367,7 +366,7 @@ def SparseTensor_ToValuesOp : SparseTensor_Op<"values",
     Example:
 
     ```mlir
-    %1 = sparse_tensor.values %0 : tensor<64x64xf64, #CSR> to memref<?xf64>
+    %1 = sparse_tensor.values %0 : tensor<64x64xf64, #CSR> to memref<?xf64>
     ```
   }];
   let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($result)";
@@ -1452,6 +1451,9 @@ def ExtractIterSpaceOp : SparseTensor_Op<"iteration.extract_space",
     unsigned getSpaceDim() {
       return getHiLvl() - getLoLvl();
     }
+    ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() {
+      return getResultSpace().getType().getLvlTypes();
+    }
   }];
 
   let hasVerifier = 1;
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 0e9f5120f7b3dc..8d2b9fe571e20b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -16,6 +16,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
 
 //===----------------------------------------------------------------------===//
 // Include the generated pass header (which needs some early definitions).
@@ -142,6 +143,21 @@ void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns);
 
 std::unique_ptr<Pass> createLowerForeachToSCFPass();
 
+//===----------------------------------------------------------------------===//
+// The LowerSparseIterationToSCF pass.
+//===----------------------------------------------------------------------===//
+
+/// Type converter for iter_space and iterator.
+class SparseIterationTypeConverter : public OneToNTypeConverter {
+public:
+  SparseIterationTypeConverter();
+};
+
+void populateLowerSparseIterationToSCFPatterns(TypeConverter &converter,
+                                               RewritePatternSet &patterns);
+
+std::unique_ptr<Pass> createLowerSparseIterationToSCFPass();
+
 //===----------------------------------------------------------------------===//
 // The SparseTensorConversion pass.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 3ab75c23dbefa0..f27c64a7dee84e 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -210,6 +210,19 @@ def LowerForeachToSCF : Pass<"lower-sparse-foreach-to-scf", "func::FuncOp"> {
   ];
 }
 
+def LowerSparseIterationToSCF : Pass<"lower-sparse-iteration-to-scf", "func::FuncOp"> {
+  let summary = "lower sparse_tensor.iterate/coiterate into scf loops";
+  let description = [{
+     Lowers sparse iteration spaces and `sparse_tensor.iterate` to scf loops.
+  }];
+  let constructor = "mlir::createLowerSparseIterationToSCFPass()";
+  let dependentDialects = [
+    "memref::MemRefDialect",
+    "scf::SCFDialect",
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 def SparseTensorConversionPass : Pass<"sparse-tensor-conversion", "ModuleOp"> {
   let summary = "Convert sparse tensors and primitives to library calls";
   let description = [{
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 2a29ee8a7a87cb..e4acfa8889e5f8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   SparseAssembler.cpp
   SparseBufferRewriting.cpp
   SparseGPUCodegen.cpp
+  SparseIterationToScf.cpp
   SparseReinterpretMap.cpp
   SparseStorageSpecifierToLLVM.cpp
   SparseSpaceCollapse.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
new file mode 100644
index 00000000000000..267eff724590e7
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -0,0 +1,76 @@
+
+#include "Utils/CodegenUtils.h"
+#include "Utils/SparseTensorLevel.h"
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+static std::optional<LogicalResult>
+convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
+  if (itSp.getSpaceDim() > 1)
+    return failure(); // TODO: support multi-level iteration spaces.
+
+  auto idxTp = IndexType::get(itSp.getContext());
+  // FIXME: this assumes that the Pos/CrdBitWidth in the sparse tensor encoding
+  // is NOT overridden, i.e., every level buffer is a memref<?xindex>.
+  auto sparseMemRef = MemRefType::get({ShapedType::kDynamic}, idxTp);
+  for (LevelType lt : itSp.getLvlTypes()) {
+    // Position and coordinate buffer in the sparse structure.
+    if (lt.isWithPosLT())
+      fields.push_back(sparseMemRef);
+    if (lt.isWithCrdLT())
+      fields.push_back(sparseMemRef);
+  }
+  // Two indices for lower and upper bound.
+  fields.append({idxTp, idxTp});
+  return success();
+}
+
+namespace {
+
+/// Sparse codegen rule for number of entries operator.
+class ExtractIterSpaceConverter
+    : public OneToNOpConversionPattern<ExtractIterSpaceOp> {
+public:
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ExtractIterSpaceOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    if (op.getSpaceDim() > 1 || op.getParentIter())
+      return rewriter.notifyMatchFailure(op, "not implemented");
+    Location loc = op.getLoc();
+
+    const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+    std::unique_ptr<SparseTensorLevel> lvl =
+        makeSparseTensorLevel(rewriter, loc, op.getTensor(), 0, op.getLoLvl());
+    // Results mirror convertIterSpaceType: level buffers, then the two bounds.
+    SmallVector<Value> result = llvm::to_vector(lvl->getLvlBuffers());
+    if (!op.getParentIter()) {
+      // TODO: handle batch.
+      std::pair<Value, Value> bounds = lvl->peekRangeAt(
+          rewriter, loc, /*batchPrefix*/ {}, constantIndex(rewriter, loc, 0));
+      result.append({bounds.first, bounds.second});
+    } else {
+      llvm_unreachable("parent iterators are rejected above");
+    }
+
+    rewriter.replaceOp(op, result, resultMapping);
+    return success();
+  }
+};
+
+} // namespace
+
+mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
+  addConversion([](Type type) { return type; });
+  addConversion(convertIterSpaceType);
+}
+
+void mlir::populateLowerSparseIterationToSCFPatterns(
+    TypeConverter &converter, RewritePatternSet &patterns) {
+  patterns.add<ExtractIterSpaceConverter>(converter, patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index acea25f023980a..327518a4d03943 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -26,6 +26,7 @@ namespace mlir {
 #define GEN_PASS_DEF_SPARSEREINTERPRETMAP
 #define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
 #define GEN_PASS_DEF_SPARSIFICATIONPASS
+#define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF
 #define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
 #define GEN_PASS_DEF_LOWERFOREACHTOSCF
 #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
@@ -152,10 +153,32 @@ struct LowerForeachToSCFPass
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
     populateLowerForeachToSCFPatterns(patterns);
+
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
 
+struct LowerSparseIterationToSCFPass
+    : public impl::LowerSparseIterationToSCFBase<
+          LowerSparseIterationToSCFPass> {
+  LowerSparseIterationToSCFPass() = default;
+  LowerSparseIterationToSCFPass(const LowerSparseIterationToSCFPass &) =
+      default;
+
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    SparseIterationTypeConverter converter;
+    // applyPartialOneToNConversion takes no ConversionTarget: ops that no
+    // pattern converts are left in place rather than reported as illegal.
+
+    populateLowerSparseIterationToSCFPatterns(converter, patterns);
+    if (failed(applyPartialOneToNConversion(getOperation(), converter,
+                                            std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
 struct SparseTensorConversionPass
     : public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
   SparseTensorConversionPass() = default;
@@ -438,6 +461,10 @@ std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
   return std::make_unique<LowerForeachToSCFPass>();
 }
 
+std::unique_ptr<Pass> mlir::createLowerSparseIterationToSCFPass() {
+  return std::make_unique<LowerSparseIterationToSCFPass>();
+}
+
 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
   return std::make_unique<SparseTensorConversionPass>();
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index 60dca3c55dec3d..6b8150a01c1e31 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -1332,6 +1332,30 @@ ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
 // SparseIterator factory functions.
 //===----------------------------------------------------------------------===//
 
+/// Creates a SparseTensorLevel of type `lt` from size `sz` and buffers `b`.
+std::unique_ptr<SparseTensorLevel>
+sparse_tensor::makeSparseTensorLevel(LevelType lt, Value sz, ValueRange b,
+                                     unsigned t, Level l) {
+  assert(lt.getNumBuffer() == b.size());
+  switch (lt.getLvlFmt()) {
+  case LevelFormat::Dense:
+    return std::make_unique<DenseLevel>(t, l, sz);
+  case LevelFormat::Batch:
+    return std::make_unique<BatchLevel>(t, l, sz);
+  case LevelFormat::Compressed:
+    return std::make_unique<CompressedLevel>(t, l, lt, sz, b[0], b[1]);
+  case LevelFormat::LooseCompressed:
+    return std::make_unique<LooseCompressedLevel>(t, l, lt, sz, b[0], b[1]);
+  case LevelFormat::Singleton:
+    return std::make_unique<SingletonLevel>(t, l, lt, sz, b[0]);
+  case LevelFormat::NOutOfM:
+    return std::make_unique<NOutOfMLevel>(t, l, lt, sz, b[0]);
+  case LevelFormat::Undef:
+    llvm_unreachable("undefined level format");
+  }
+  llvm_unreachable("unrecognizable level format");
+}
+
 std::unique_ptr<SparseTensorLevel>
 sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
                                      unsigned tid, Level lvl) {
@@ -1341,33 +1365,16 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
   Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult()
                                : b.create<tensor::DimOp>(l, t, lvl).getResult();
 
-  switch (lt.getLvlFmt()) {
-  case LevelFormat::Dense:
-    return std::make_unique<DenseLevel>(tid, lvl, sz);
-  case LevelFormat::Batch:
-    return std::make_unique<BatchLevel>(tid, lvl, sz);
-  case LevelFormat::Compressed: {
-    Value pos = b.create<ToPositionsOp>(l, t, lvl);
-    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
-    return std::make_unique<CompressedLevel>(tid, lvl, lt, sz, pos, crd);
-  }
-  case LevelFormat::LooseCompressed: {
+  SmallVector<Value, 2> buffers;
+  if (lt.isWithPosLT()) {
     Value pos = b.create<ToPositionsOp>(l, t, lvl);
-    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
-    return std::make_unique<LooseCompressedLevel>(tid, lvl, lt, sz, pos, crd);
+    buffers.push_back(pos);
   }
-  case LevelFormat::Singleton: {
-    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
-    return std::make_unique<SingletonLevel>(tid, lvl, lt, sz, crd);
+  if (lt.isWithCrdLT()) {
+    Value pos = b.create<ToCoordinatesOp>(l, t, lvl);
+    buffers.push_back(pos);
   }
-  case LevelFormat::NOutOfM: {
-    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
-    return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, crd);
-  }
-  case LevelFormat::Undef:
-    llvm_unreachable("undefined level format");
-  }
-  llvm_unreachable("unrecognizable level format");
+  return makeSparseTensorLevel(lt, sz, buffers, tid, lvl);
 }
 
 std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 9d69a233555986..036e750165c2b1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -288,6 +288,11 @@ std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &builder,
                                                          Location loc, Value t,
                                                          unsigned tid, Level l);
 
+/// Creates a SparseTensorLevel of type `lt` from size `sz` and `buffers`.
+std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz,
+                                                         ValueRange buffers,
+                                                         unsigned tid, Level l);
+
 /// Helper function to create a simple SparseIterator object that iterate over
 /// the SparseTensorLevel.
 std::unique_ptr<SparseIterator> makeSimpleIterator(const SparseTensorLevel &stl,

>From cd3f8bcb1fe8df639d8724f8bf6e46a9ab9c1dcf Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 1 Apr 2024 16:24:39 +0000
Subject: [PATCH 6/6] implement lower pass (WIP)

---
 .../SparseTensor/IR/SparseTensorTypes.td      |   4 +
 .../SparseTensor/IR/SparseTensorDialect.cpp   |   2 +-
 .../Transforms/SparseIterationToScf.cpp       | 107 +++++++++++++++++-
 .../Transforms/SparseTensorRewriting.cpp      |   6 +-
 .../Transforms/Utils/SparseTensorIterator.h   |   5 +-
 5 files changed, 116 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
index aa674b613e71db..e7b6a8af670e13 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
@@ -113,6 +113,10 @@ def SparseTensor_Iterator : SparseTensor_Type<"Iterator"> {
      unsigned getSpaceDim() {
        return getLvlTypes().size();
      }
+     bool isUnique() {
+       // As long as the last level is unique, the entire iterator is unique.
+       return !getLvlTypes().back().isa<LevelPropNonDefault::Nonunique>();
+     }
   }];
 
 
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 6c047e2a79248e..78e04a2c7d059c 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2244,7 +2244,7 @@ Block::BlockArgListType IterateOp::getRegionIterArgs() {
 std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
   return cast<sparse_tensor::YieldOp>(
              getRegion().getBlocks().front().getTerminator())
-      .getResultMutable();
+      .getResultsMutable();
 }
 
 std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index 267eff724590e7..0b9f79152edc3c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -1,7 +1,8 @@
 
 #include "Utils/CodegenUtils.h"
-#include "Utils/SparseTensorLevel.h"
+#include "Utils/SparseTensorIterator.h"
 
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Transforms/OneToNTypeConversion.h"
@@ -30,6 +31,21 @@ convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
   return success();
 }
 
+static std::optional<LogicalResult>
+convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
+  if (itTp.getSpaceDim() > 1)
+    return failure(); // TODO: support multi-level iterators.
+
+  auto idxTp = IndexType::get(itTp.getContext());
+  // TODO: This assumes there is no batch dimension in the sparse tensor.
+  if (!itTp.isUnique()) {
+    // Segment high for non-unique iterator.
+    fields.push_back(idxTp);
+  }
+  fields.push_back(idxTp);
+  return success();
+}
+
 namespace {
 
 /// Sparse codegen rule for number of entries operator.
@@ -63,14 +79,101 @@ class ExtractIterSpaceConverter
   }
 };
 
+class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
+public:
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+  LogicalResult
+  matchAndRewrite(IterateOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    if (op.getSpaceDim() > 1 || !op.getCrdUsedLvls().empty())
+      return rewriter.notifyMatchFailure(op, "not implemented");
+    // Compute the converted body signature before creating any IR, so that a
+    // failure here leaves the op untouched.
+    Block *loopBody = op.getBody();
+    OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
+    if (failed(typeConverter->convertSignatureArgs(
+            loopBody->getArgumentTypes(), bodyTypeMapping)))
+      return failure();
+
+    Location loc = op.getLoc();
+    LevelType lt = op.getIterSpace().getType().getLvlTypes().front();
+    // Converted iter_space values: the level buffers, then the two bounds.
+    ValueRange buffers = adaptor.getIterSpace().take_front(lt.getNumBuffer());
+    // TODO: Introduce a class to represent a sparse iter_space, which is a
+    // combination of sparse levels and posRange.
+    // ValueRange posRange = adaptor.getIterSpace().take_back(2);
+
+    std::unique_ptr<SparseTensorLevel> stl = makeSparseTensorLevel(
+        lt, /*sz=*/nullptr, buffers, /*tid=*/0, /*lvl=*/0);
+
+    // TODO: decouple sparse iterator with sparse levels.
+    std::unique_ptr<SparseIterator> it = makeSimpleIterator(*stl);
+
+    // FIXME: only works for the first level.
+    it->genInit(rewriter, loc, /*parent*/ nullptr);
+    if (it->iteratableByFor()) {
+      // TODO
+      llvm_unreachable("not yet implemented.");
+    } else {
+      SmallVector<Value> ivs;
+      llvm::append_range(ivs, it->getCursor());
+      for (ValueRange inits : adaptor.getInitArgs())
+        llvm::append_range(ivs, inits);
+
+      assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
+
+      TypeRange types = ValueRange(ivs).getTypes();
+      auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
+      SmallVector<Location> l(types.size(), op.getIterator().getLoc());
+
+      // Generates loop conditions.
+      Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
+      rewriter.setInsertionPointToStart(before);
+      ValueRange bArgs = before->getArguments();
+      auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
+      assert(remArgs.size() == adaptor.getInitArgs().size());
+      rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
+
+      // Generates loop body.
+      rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
+      Region &dstRegion = whileOp.getAfter();
+      // TODO: handle uses of coordinate!
+      rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
+      ValueRange aArgs = whileOp.getAfterArguments();
+      auto yieldOp = llvm::cast<sparse_tensor::YieldOp>(
+          whileOp.getAfterBody()->getTerminator());
+
+      rewriter.setInsertionPointToEnd(whileOp.getAfterBody());
+
+      aArgs = it->linkNewScope(aArgs);
+      ValueRange nx = it->forward(rewriter, loc);
+      SmallVector<Value> yields;
+      llvm::append_range(yields, nx);
+      llvm::append_range(yields, yieldOp.getResults());
+
+      // replace sparse_tensor.yield with scf.yield.
+      rewriter.eraseOp(yieldOp);
+      rewriter.create<scf::YieldOp>(loc, yields);
+
+      const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+      rewriter.replaceOp(
+          op, whileOp.getResults().drop_front(it->getCursor().size()),
+          resultMapping);
+    }
+    return success();
+  }
+};
+
 } // namespace
 
 mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
   addConversion([](Type type) { return type; });
   addConversion(convertIterSpaceType);
+  addConversion(convertIteratorType);
 }
 
 void mlir::populateLowerSparseIterationToSCFPatterns(
     TypeConverter &converter, RewritePatternSet &patterns) {
-  patterns.add<ExtractIterSpaceConverter>(converter, patterns.getContext());
+  patterns.add<ExtractIterSpaceConverter, SparseIterateOpConverter>(
+      converter, patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index b117c1694e45b8..a88f66e47d2f38 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -579,7 +579,7 @@ struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
     rewriter.setInsertionPointToStart(&semiring.getAbsentRegion().front());
     auto zero =
         rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(rtp));
-    rewriter.create<sparse_tensor::YieldOp>(loc, zero);
+    rewriter.create<sparse_tensor::YieldOp>(loc, zero->getResults());
     rewriter.setInsertionPointAfter(semiring);
     // CustomReduce {
     //    x = x REDUC y, identity
@@ -821,7 +821,7 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
 
           auto t =
               builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
-          builder.create<sparse_tensor::YieldOp>(loc, t);
+          builder.create<sparse_tensor::YieldOp>(loc, t->getResults());
         });
 
     Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
@@ -906,7 +906,7 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
                      srcDcvs, dstSizes, dstDcvs);
           auto t =
               builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
-          builder.create<sparse_tensor::YieldOp>(loc, t);
+          builder.create<sparse_tensor::YieldOp>(loc, t->getResults());
         });
 
     Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 036e750165c2b1..ba252b5d32c025 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -295,8 +295,9 @@ std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz,
 
 /// Helper function to create a simple SparseIterator object that iterate over
 /// the SparseTensorLevel.
-std::unique_ptr<SparseIterator> makeSimpleIterator(const SparseTensorLevel &stl,
-                                                   SparseEmitStrategy strategy);
+std::unique_ptr<SparseIterator> makeSimpleIterator(
+    const SparseTensorLevel &stl,
+    SparseEmitStrategy strategy = SparseEmitStrategy::kFunctional);
 
 /// Helper function to create a synthetic SparseIterator object that iterate
 /// over a dense space specified by [0,`sz`).



More information about the Mlir-commits mailing list