[Mlir-commits] [mlir] [WIP][mlir][sparse] Setting up sparse_tensor.iterator-related Ops. (PR #85958)

Peiming Liu llvmlistbot at llvm.org
Fri Mar 22 11:41:48 PDT 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/85958

>From 1e76a2a5417212bb2e5891a00c0069c4118ce778 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 18 Mar 2024 19:40:46 +0000
Subject: [PATCH 1/3] test parse iterate operation

---
 .../SparseTensor/IR/SparseTensorOps.td        | 17 +++++
 .../SparseTensor/IR/SparseTensorTypes.td      | 76 +++++++++++++++++++
 2 files changed, 93 insertions(+)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 29cf8c32447ecf..6efeb6007d649e 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1418,6 +1418,23 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Operations.
+//===----------------------------------------------------------------------===//
+
+def IterateOp : SparseTensor_Op<"iterate", // Operation with the mnemonic "iterate".
+    [RecursiveMemoryEffects]> {
+
+  let arguments = (ins AnySparseIterSpace:$iterSpace, // Must satisfy the sparse iteration space type constraint.
+                       Variadic<AnyType>:$initArgs); // Zero or more operands of any type.
+  let results = (outs Variadic<AnyType>:$results); // Zero or more results of any type.
+  let regions = (region SizedRegion<1>:$region); // A single region holding exactly one block.
+
+  let extraClassDeclaration = [{}]; // No extra C++ members are declared yet.
+
+  let hasCustomAssemblyFormat = 1; // Parser/printer are written by hand in C++ rather than declared here.
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Debugging and Test-Only Operations.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
index 185cff46ae25d5..54a8e4d7ecd398 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
@@ -72,4 +72,80 @@ def SparseTensorStorageSpecifier
     : Type<CPred<"::llvm::isa<::mlir::sparse_tensor::StorageSpecifierType>($_self)">, "metadata",
           "::mlir::sparse_tensor::StorageSpecifierType">;
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Types.
+//===----------------------------------------------------------------------===//
+
+def LevelTypeArrayParameter : ArrayRefParameter<"::mlir::sparse_tensor::LevelType", "level-types"> { // Type parameter holding an array of LevelType values.
+  let printer = [{
+    auto lvlStrings = llvm::map_range($_self, [](auto lt){ return lt.toMLIRString(); });
+    $_printer << "[" << llvm::join(lvlStrings, ",") << "]";
+  }]; // Printer: emits "[" + the toMLIRString() of each level type joined by "," + "]".
+
+  let parser = [{ [&]() -> FailureOr<SmallVector<::mlir::sparse_tensor::LevelType>> {
+    SmallVector<::mlir::sparse_tensor::LevelType> ret;
+
+    const auto res = $_parser.parseCommaSeparatedList(
+      mlir::OpAsmParser::Delimiter::Square,
+      [&]() -> ParseResult {
+        ::mlir::sparse_tensor::ir_detail::LvlTypeParser lParser;
+        auto lvlTpOrFail = lParser.parseLvlType($_parser);
+        if (failed(lvlTpOrFail))
+          return failure();
+        ret.emplace_back(*lvlTpOrFail);
+        return success();
+      }, " in level-type list");
+
+    if (failed(res))
+      return failure();
+    return ret;
+  }() }]; // Parser: reads a square-bracketed, comma-separated list, one LvlTypeParser::parseLvlType call per element; fails if any element fails.
+}
+
+def SparseTensor_Iterator : SparseTensor_Type<"Iterator"> { // Dialect type with the mnemonic "iterator".
+  let mnemonic = "iterator";
+
+  let parameters = (ins
+     LevelTypeArrayParameter: $lvlTypes // The only parameter: an array of level types.
+  );
+
+  // let skipDefaultBuilders = 1;
+  // let hasCustomAssemblyFormat = 1;
+  let assemblyFormat="`<` $lvlTypes `>`"; // Written as `<`, the level-type list, then `>`.
+}
+
+
+def SparseTensor_IterSpace : SparseTensor_Type<"IterSpace"> { // Dialect type with the mnemonic "iter_space".
+  let mnemonic = "iter_space";
+
+  let parameters = (ins
+     LevelTypeArrayParameter: $lvlTypes // The only parameter: an array of level types.
+  );
+
+  let extraClassDeclaration = [{
+     ::mlir::sparse_tensor::IteratorType getIteratorType() const {
+        return IteratorType::get(getContext(), getLvlTypes());
+     }
+  }]; // getIteratorType(): builds the IteratorType that carries the same level types as this space.
+
+  // let skipDefaultBuilders = 1;
+  // let hasCustomAssemblyFormat = 1;
+  let assemblyFormat="`<` $lvlTypes `>`"; // Written as `<`, the level-type list, then `>`.
+}
+
+def IsSparseSparseIterSpaceTypePred // NOTE(review): "SparseSparse" looks like an accidental repetition - confirm before renaming.
+    : CPred<"::llvm::isa<::mlir::sparse_tensor::IterSpaceType>($_self)">; // Holds when the type is an IterSpaceType.
+
+def IsSparseSparseIteratorTypePred
+    : CPred<"::llvm::isa<::mlir::sparse_tensor::IteratorType>($_self)">; // Holds when the type is an IteratorType.
+
+def AnySparseIterSpace // Type constraint for operands/results that must be an IterSpaceType.
+    : Type<IsSparseSparseIterSpaceTypePred, "sparse iteration space",
+          "::mlir::sparse_tensor::IterSpaceType">;
+
+def AnySparseIterator // Type constraint for operands/results that must be an IteratorType.
+    : Type<IsSparseSparseIteratorTypePred, "sparse iterator",
+          "::mlir::sparse_tensor::IteratorType">;
+
+
 #endif // SPARSETENSOR_TYPES

>From e64944645a4301b393ec4e8c77ee58aaf4cfc19d Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 20 Mar 2024 16:14:12 +0000
Subject: [PATCH 2/3] test sparse space collapse

---
 .../Dialect/SparseTensor/IR/SparseTensor.h    |   2 +
 .../SparseTensor/IR/SparseTensorOps.td        |  41 ++-
 .../SparseTensor/IR/SparseTensorTypes.td      |  15 +-
 .../Dialect/SparseTensor/Transforms/Passes.h  |   6 +
 .../Dialect/SparseTensor/Transforms/Passes.td |  16 ++
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 253 +++++++++++++++++-
 .../SparseTensor/Transforms/CMakeLists.txt    |   1 +
 .../Transforms/SparseSpaceCollapse.cpp        | 152 +++++++++++
 8 files changed, 479 insertions(+), 7 deletions(-)
 create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 5e523ec428aefb..78692307820bc5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -17,7 +17,9 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/TensorEncoding.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 6efeb6007d649e..f0eaf2191fdbd1 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -15,6 +15,8 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
 include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
 
 //===----------------------------------------------------------------------===//
 // Base class.
@@ -1422,16 +1424,51 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
 // Sparse Tensor Iteration Operations.
 //===----------------------------------------------------------------------===//
 
+def ExtractIterSpaceOp : SparseTensor_Op<"iteration.extract_space", // Operation with the mnemonic "iteration.extract_space".
+    [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> { // The result type is inferred by a C++ inferReturnTypes implementation.
+
+  let arguments = (ins AnySparseTensor:$tensor,
+                       Optional<AnySparseIterator>:$parentIter, // May be omitted.
+                       LevelAttr:$loLvl, LevelAttr:$hiLvl); // Lower and upper bound of the level range.
+
+  let results = (outs AnySparseIterSpace:$resultSpace);
+
+  let extraClassDeclaration = [{
+    unsigned getSpaceDim() {
+      return getHiLvl() - getLoLvl();
+    }
+  }]; // getSpaceDim(): number of levels covered, computed as hiLvl - loLvl.
+
+  let hasVerifier = 1;
+  let assemblyFormat = "$tensor (`at`$parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
+                       " attr-dict `:` type($tensor) (`,` type($parentIter)^)?"; // The level range goes through the custom LevelRange parser/printer.
+}
+
 def IterateOp : SparseTensor_Op<"iterate",
-    [RecursiveMemoryEffects]> {
+    [RecursiveMemoryEffects, RecursivelySpeculatable,
+     DeclareOpInterfaceMethods<LoopLikeOpInterface,
+      ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
+      "getSingleInductionVar", "getYieldedValuesMutable"]>,
+     DeclareOpInterfaceMethods<RegionBranchOpInterface,
+      ["getEntrySuccessorOperands"]>,
+     SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
 
   let arguments = (ins AnySparseIterSpace:$iterSpace,
                        Variadic<AnyType>:$initArgs);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
 
-  let extraClassDeclaration = [{}];
+  let extraClassDeclaration = [{
+    BlockArgument getIterator() {
+      return getRegion().getArguments().front();
+    }
+    unsigned getNumRegionIterArgs() {
+      return getRegion().getArguments().size() - 1;
+    }
+  }];
 
+  let hasVerifier = 1;
+  let hasRegionVerifier = 1;
   let hasCustomAssemblyFormat = 1;
 }
 
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
index 54a8e4d7ecd398..aa674b613e71db 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
@@ -109,8 +109,13 @@ def SparseTensor_Iterator : SparseTensor_Type<"Iterator"> {
      LevelTypeArrayParameter: $lvlTypes
   );
 
-  // let skipDefaultBuilders = 1;
-  // let hasCustomAssemblyFormat = 1;
+  let extraClassDeclaration = [{
+     unsigned getSpaceDim() {
+       return getLvlTypes().size();
+     }
+  }];
+
+
   let assemblyFormat="`<` $lvlTypes `>`";
 }
 
@@ -123,13 +128,15 @@ def SparseTensor_IterSpace : SparseTensor_Type<"IterSpace"> {
   );
 
   let extraClassDeclaration = [{
+     unsigned getSpaceDim() {
+       return getLvlTypes().size();
+     }
+
      ::mlir::sparse_tensor::IteratorType getIteratorType() const {
         return IteratorType::get(getContext(), getLvlTypes());
      }
   }];
 
-  // let skipDefaultBuilders = 1;
-  // let hasCustomAssemblyFormat = 1;
   let assemblyFormat="`<` $lvlTypes `>`";
 }
 
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 61b07d222d156b..0e9f5120f7b3dc 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -247,6 +247,12 @@ std::unique_ptr<Pass> createSparsificationAndBufferizationPass(
     bool enableBufferInitialization, unsigned vectorLength,
     bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen);
 
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<Pass> createSparseSpaceCollapsePass();
+
 //===----------------------------------------------------------------------===//
 // Registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 58e2d6f32386c3..3ab75c23dbefa0 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -454,4 +454,20 @@ def SparsificationAndBufferization : Pass<"sparsification-and-bufferization", "M
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
+  let summary = "(experimental) sparse space collapsing pass";
+  let description = [{
+     This pass collapses consecutive sparse spaces (extracted from the same tensor)
+     into one multi-dimensional space.
+  }];
+  let constructor = "mlir::createSparseSpaceCollapsePass()";
+  let dependentDialects = [
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 6da51bb6b9cacf..6afa3e6309dc65 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1912,13 +1912,264 @@ LogicalResult YieldOp::verify() {
   auto *parentOp = (*this)->getParentOp();
   if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
       isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
-      isa<ForeachOp>(parentOp))
+      isa<ForeachOp>(parentOp) || isa<IterateOp>(parentOp))
     return success();
 
   return emitOpError("expected parent op to be sparse_tensor unary, binary, "
                      "reduce, select or foreach");
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Operations.
+//===----------------------------------------------------------------------===//
+
+/// Parses a level range written either as `lo` or as `lo to hi` and stores
+/// both bounds as index-typed integer attributes. When the `to` clause is
+/// absent the upper bound is `lo + 1`.
+static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
+                                   IntegerAttr &lvlHiAttr) {
+  Level lo;
+  if (failed(parser.parseInteger(lo)))
+    return failure();
+
+  // Default to a one-level range; an explicit `to <hi>` overrides it.
+  Level hi = lo + 1;
+  if (succeeded(parser.parseOptionalKeyword("to")) &&
+      failed(parser.parseInteger(hi)))
+    return failure();
+
+  Type indexTp = parser.getBuilder().getIndexType();
+  lvlLoAttr = IntegerAttr::get(indexTp, lo);
+  lvlHiAttr = IntegerAttr::get(indexTp, hi);
+  return success();
+}
+
+/// Prints the level range of `op`: just `lo` when the range covers a single
+/// level, otherwise `lo to hi`. The bounds are read from `op`; the attribute
+/// parameters are not consulted.
+static void printLevelRange(OpAsmPrinter &p, ExtractIterSpaceOp op,
+                            IntegerAttr lvlLo, IntegerAttr lvlHi) {
+  const auto lo = op.getLoLvl();
+  const auto hi = op.getHiLvl();
+  p << lo;
+  if (lo + 1 != hi)
+    p << " to " << hi;
+}
+
+/// Infers the result type of an ExtractIterSpaceOp: an IterSpaceType built
+/// from the tensor's level types in the range [loLvl, hiLvl).
+///
+/// Returns failure (with a diagnostic when `loc` is provided) if the range is
+/// empty, reversed, or extends past the tensor's level rank. Without this
+/// guard the slice below would trip an assertion on malformed input, because
+/// type inference runs before the op verifier.
+LogicalResult ExtractIterSpaceOp::inferReturnTypes(
+    MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
+    DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
+    SmallVectorImpl<mlir::Type> &ret) {
+
+  ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
+  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
+  const Level loLvl = adaptor.getLoLvl();
+  const Level hiLvl = adaptor.getHiLvl();
+  if (loLvl >= hiLvl || hiLvl > stt.getLvlRank())
+    return emitOptionalError(
+        loc, "expected level low < level high <= level rank of the tensor");
+
+  ArrayRef<LevelType> lts = stt.getLvlTypes().slice(loLvl, hiLvl - loLvl);
+  ret.push_back(IterSpaceType::get(ctx, lts));
+  return success();
+}
+
+/// Verifies an ExtractIterSpaceOp:
+///  - the level range is non-empty and lies within the tensor's level rank;
+///  - the result space's level types equal the tensor's level types in range;
+///  - a parent iterator is present exactly when the range does not start at
+///    level 0;
+///  - the parent iterator's level types equal the tensor levels immediately
+///    preceding the range.
+LogicalResult ExtractIterSpaceOp::verify() {
+  SparseTensorType stt = getSparseTensorType(getTensor());
+  if (getLoLvl() >= getHiLvl())
+    return emitOpError("expected smaller level low than level high");
+
+  // Reject out-of-range levels up front: ArrayRef::slice asserts on an
+  // out-of-bounds range instead of producing a diagnostic.
+  if (getHiLvl() > stt.getLvlRank())
+    return emitOpError(
+        "expected level high to be no larger than the tensor's level rank");
+
+  ArrayRef<LevelType> lts = stt.getLvlTypes().slice(getLoLvl(), getSpaceDim());
+  if (!getResultSpace().getType().getLvlTypes().equals(lts)) {
+    return emitOpError(
+        "mismatch in iteration space level types and tensor level types");
+  }
+
+  // The condition rejects a parent iterator at level 0 and a missing one at
+  // any other level, i.e. the iterator is required iff level low is not 0.
+  TypedValue<IteratorType> pIter = getParentIter();
+  if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
+    return emitOpError(
+        "parent iterator should be specified iff level low is not 0");
+  }
+
+  if (pIter) {
+    unsigned pDim = pIter.getType().getSpaceDim();
+    // The `<` check short-circuits so that the subtraction cannot wrap.
+    if (getLoLvl() < pDim || !stt.getLvlTypes()
+                                  .slice(getLoLvl() - pDim, pDim)
+                                  .equals(pIter.getType().getLvlTypes())) {
+      return emitOpError(
+          "mismatch in parent iterator level types and tensor level types");
+    }
+  }
+
+  return success();
+}
+
+/// Parses an IterateOp in its custom form:
+///   %iterator `in` %space (`iter_args` `(` %arg `=` %init, ... `)`)?
+///     `:` space-type (`->` result-types)? region attr-dict?
+/// The region's first argument is the iterator; the remaining arguments are
+/// the loop-carried values, typed after the corresponding result types.
+/// Emits a parse error (instead of asserting) when the type after `:` is not
+/// an iteration space type.
+ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::Argument iterator;
+  OpAsmParser::UnresolvedOperand iterSpace;
+
+  // Parses %iters in %spaces
+  if (parser.parseArgument(iterator) || parser.parseKeyword("in") ||
+      parser.parseOperand(iterSpace)) {
+    return failure();
+  }
+
+  // Parse the optional initial iteration arguments.
+  SmallVector<OpAsmParser::Argument> regionArgs;
+  SmallVector<OpAsmParser::UnresolvedOperand> operands;
+  // Region arguments starts with iterators and follows by optional
+  // user-provided iter_args.
+  regionArgs.push_back(iterator);
+  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+  if (hasIterArgs)
+    if (parser.parseAssignmentList(regionArgs, operands))
+      return failure();
+
+  // parse ": sparse_tensor.iter_space -> ret"
+  Type iterSpaceTps;
+  if (parser.parseColon())
+    return failure();
+  llvm::SMLoc spaceTypeLoc = parser.getCurrentLocation();
+  if (parser.parseType(iterSpaceTps))
+    return failure();
+
+  // The iterator type is derived from the space type below, so anything
+  // other than an iteration space type must be rejected here.
+  auto iterSpaceTp = llvm::dyn_cast<IterSpaceType>(iterSpaceTps);
+  if (!iterSpaceTp)
+    return parser.emitError(spaceTypeLoc,
+                            "expected a sparse iteration space type");
+
+  if (hasIterArgs)
+    if (parser.parseArrowTypeList(result.types))
+      return failure();
+
+  if (regionArgs.size() != result.types.size() + 1) {
+    return parser.emitError(
+        parser.getNameLoc(),
+        "mismatch in number of loop-carried values and defined values");
+  }
+
+  // Resolves input operands.
+  if (parser.resolveOperand(iterSpace, iterSpaceTps, result.operands))
+    return failure();
+
+  if (hasIterArgs) {
+    for (auto &&[regionArg, operand, type] :
+         llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
+      regionArg.type = type;
+      if (parser.resolveOperand(operand, type, result.operands))
+        return failure();
+    }
+  }
+
+  Region *body = result.addRegion();
+  regionArgs.front().type = iterSpaceTp.getIteratorType();
+  if (parser.parseRegion(*body, regionArgs))
+    return failure();
+
+  IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  return success();
+}
+
+/// Emits `<prefix>(%inner = %outer, %inner2 = %outer2, <...>)`, pairing each
+/// region argument ('inner') with the SSA value that initializes it
+/// ('outer'). Emits nothing when there are no initializers.
+static void printInitializationList(OpAsmPrinter &p,
+                                    Block::BlockArgListType blocksArgs,
+                                    ValueRange initializers,
+                                    StringRef prefix = "") {
+  assert(blocksArgs.size() == initializers.size() &&
+         "expected same length of arguments and initializers");
+  const unsigned numInits = initializers.size();
+  if (numInits == 0)
+    return;
+
+  p << prefix << '(';
+  for (unsigned idx = 0; idx < numInits; ++idx) {
+    if (idx != 0)
+      p << ", ";
+    p << blocksArgs[idx] << " = " << initializers[idx];
+  }
+  p << ")";
+}
+
+/// Prints an IterateOp in the form accepted by IterateOp::parse:
+///   %iterator `in` %space (`iter_args(...)`)? `:` space-type
+///     (`->` `(` result-types `)`)? region attr-dict?
+/// The block terminator is printed only when there are loop-carried values.
+void IterateOp::print(OpAsmPrinter &p) {
+  p << " " << getIterator() << " in " << getIterSpace();
+
+  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
+
+  p << " : " << getIterSpace().getType() << " ";
+  if (!getInitArgs().empty())
+    p << "-> (" << getInitArgs().getTypes() << ") ";
+
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/!getInitArgs().empty());
+
+  // The parser accepts an attribute dictionary after the region; print it so
+  // that attributes attached to the op survive a print/parse round-trip.
+  p.printOptionalAttrDict((*this)->getAttrs());
+}
+
+/// Checks that the op has exactly one result per initial loop-carried value.
+LogicalResult IterateOp::verify() {
+  const bool countsAgree = getInitArgs().size() == getNumResults();
+  if (countsAgree)
+    return success();
+  return emitOpError(
+      "mismatch in number of loop-carried values and defined values");
+}
+
+/// Verifies the body region against the op signature: the first block
+/// argument must have the iterator type of the iterated space, and the init
+/// operands, remaining block arguments, yielded values and op results must
+/// agree in number and, position by position, in type.
+LogicalResult IterateOp::verifyRegions() {
+  Type expectedIterTp = getIterSpace().getType().getIteratorType();
+  if (getIterator().getType() != expectedIterTp)
+    return emitOpError("mismatch in iterator and iteration space type");
+  if (getNumRegionIterArgs() != getNumResults())
+    return emitOpError(
+        "mismatch in number of basic block args and defined values");
+
+  auto inits = getInitArgs();
+  auto bbArgs = getRegionIterArgs();
+  auto yielded = getYieldedValues();
+  auto results = getResults();
+  const unsigned numCarried = results.size();
+  if (inits.size() != numCarried || bbArgs.size() != numCarried ||
+      yielded.size() != numCarried) {
+    return emitOpError() << "number mismatch between iter args and results.";
+  }
+
+  // Every loop-carried position is compared against the result type.
+  for (unsigned idx = 0; idx < numCarried; ++idx) {
+    Type resultTp = results[idx].getType();
+    if (inits[idx].getType() != resultTp)
+      return emitOpError() << "types mismatch between " << idx
+                           << "th iter operand and defined value";
+    if (bbArgs[idx].getType() != resultTp)
+      return emitOpError() << "types mismatch between " << idx
+                           << "th iter region arg and defined value";
+    if (yielded[idx].getType() != resultTp)
+      return emitOpError() << "types mismatch between " << idx
+                           << "th yield value and defined value";
+  }
+  return success();
+}
+
+/// IterateOp implemented interfaces' methods.
+SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; } // The op's only region is reported as the loop region.
+
+std::optional<Value> IterateOp::getSingleInductionVar() {
+  return getIterator(); // The iterator block argument is reported as the induction variable.
+}
+
+MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
+  return getInitArgsMutable(); // The init operands are the `initArgs` operand group.
+}
+
+Block::BlockArgListType IterateOp::getRegionIterArgs() {
+  return getRegion().getArguments().drop_front(); // Skips the leading iterator argument.
+}
+
+std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
+  return cast<sparse_tensor::YieldOp>( // Asserts that the first block's terminator is a sparse_tensor::YieldOp.
+             getRegion().getBlocks().front().getTerminator())
+      .getResultMutable();
+}
+
+std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); } // All op results are loop results.
+
+OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  return getInitArgs(); // `point` is not inspected; the init operands are returned for any branch point.
+}
+
+void IterateOp::getSuccessorRegions(RegionBranchPoint point,
+                                    SmallVectorImpl<RegionSuccessor> &regions) {
+  // `point` is not inspected: whether control comes from the op itself or from
+  // the body, it may branch into the body or back out to the op's results.
+  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); // Successor 1: the body, receiving the loop-carried block arguments.
+  // It is possible for the loop not to enter the body at all.
+  regions.push_back(RegionSuccessor(getResults())); // Successor 2: the parent op, receiving the op results.
+}
+
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Dialect Setups.
+//===----------------------------------------------------------------------===//
+
 /// Materialize a single constant operation from a given attribute value with
 /// the desired resultant type.
 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 3c0f82fc00bb9d..8840da9aa56ef7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   SparseGPUCodegen.cpp
   SparseReinterpretMap.cpp
   SparseStorageSpecifierToLLVM.cpp
+  SparseSpaceCollapse.cpp
   SparseTensorCodegen.cpp
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
new file mode 100644
index 00000000000000..f3207ede9585b4
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -0,0 +1,152 @@
+//===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
+namespace mlir {
+
+#define GEN_PASS_DEF_SPARSESPACECOLLAPSE
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
+
+namespace sparse_tensor {
+
+/// Returns true when the loop `node`, nested in `parent`, merely forwards the
+/// loop-carried values of `parent`:
+///  - the region iter_args of `parent` are, in order, the inits of `node`;
+///  - the results of `node` are, in order, the values yielded by `parent`.
+/// Returns false (rather than asserting) when `node` exposes no loop results
+/// or when the compared ranges differ in length.
+bool isCollapsableIterations(LoopLikeOpInterface parent,
+                             LoopLikeOpInterface node) {
+  auto pIterArgs = parent.getRegionIterArgs();
+  auto nInitArgs = node.getInits();
+  if (pIterArgs.size() != nInitArgs.size())
+    return false;
+
+  auto pYields = parent.getYieldedValues();
+  auto nResult = node.getLoopResults();
+  // zip_equal asserts on a length mismatch, so check the sizes explicitly.
+  if (!nResult || pYields.size() != nResult->size())
+    return false;
+
+  bool yieldEq =
+      llvm::all_of(llvm::zip_equal(pYields, *nResult), [](auto zipped) {
+        return std::get<0>(zipped) == std::get<1>(zipped);
+      });
+  if (!yieldEq)
+    return false;
+
+  // Parent iter_args should be passed directly to the node's init_args.
+  return llvm::all_of(llvm::zip_equal(pIterArgs, nInitArgs), [](auto zipped) {
+    return std::get<0>(zipped) == std::get<1>(zipped);
+  });
+}
+
+/// Decides whether the space extracted by `node` may be merged into the one
+/// extracted by `parent`. This requires that both come from the same tensor,
+/// that `parent` has a single use, that `node` sits directly inside an
+/// IterateOp which iterates over `parent`'s result and lives in the same
+/// block as `parent`, and - when `parent` is itself nested in an IterateOp -
+/// that the two loops only forward their loop-carried values.
+bool legalToCollapse(ExtractIterSpaceOp parent, ExtractIterSpaceOp node) {
+  auto parentLoop = llvm::dyn_cast<IterateOp>(parent->getParentOp());
+  auto nodeLoop = llvm::dyn_cast<IterateOp>(node->getParentOp());
+
+  // Can only collapse spaces extracted from the same tensor.
+  const bool sameTensor = parent.getTensor() == node.getTensor();
+  if (!(sameTensor && parent->hasOneUse()))
+    return false;
+
+  // Can only collapse consecutive simple iteration on one tensor (i.e., no
+  // coiteration).
+  if (!nodeLoop)
+    return false;
+  const bool iteratesParentSpace =
+      nodeLoop.getIterSpace() == parent.getResult();
+  const bool inParentBlock = nodeLoop->getBlock() == parent->getBlock();
+  if (!(iteratesParentSpace && inParentBlock))
+    return false;
+
+  // TODO: Make sure all other operations in the same basic block as `node` can
+  // be collapsed and sink into the collapsed iteration (through Interfaces
+  // defined in TD files).
+  const bool loopsBlockCollapse =
+      parentLoop && !isCollapsableIterations(parentLoop, nodeLoop);
+  return !loopsBlockCollapse;
+}
+
+/// Collapses a chain of iteration spaces (`toCollapse.front()` is the
+/// outermost, `toCollapse.back()` the innermost) into one space covering
+/// levels [root.loLvl, leaf.hiLvl), and replaces the loop over the root space
+/// with a clone of the loop over the leaf space that iterates the collapsed
+/// space instead.
+///
+/// Does nothing when fewer than two spaces are given, when the leaf space
+/// does not have exactly one use, when that use is not an IterateOp, or when
+/// that loop does not simply forward the loop-carried values of its parent.
+/// All bail-out checks run before any IR is created, so an early return never
+/// leaves a dead extract_space operation behind.
+void collapseSparseSpace(ArrayRef<ExtractIterSpaceOp> toCollapse) {
+  if (toCollapse.size() < 2)
+    return;
+
+  ExtractIterSpaceOp root = toCollapse.front();
+  ExtractIterSpaceOp leaf = toCollapse.back();
+  Location loc = root.getLoc();
+
+  if (!leaf->hasOneUse())
+    return;
+  assert(root->hasOneUse());
+
+  auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
+  auto pItOp = llvm::cast<IterateOp>(leaf->getParentOp());
+
+  // This could either be IterateOp or (TODO: in the future) CoIterateOp.
+  auto loop = llvm::dyn_cast<IterateOp>(*leaf->getUsers().begin());
+  if (!loop || !isCollapsableIterations(pItOp, loop))
+    return;
+
+  // Insert collapsed operation at the same scope as root operation.
+  OpBuilder builder(toCollapse.front());
+
+  // Construct the collapsed iteration space.
+  auto collapsedSpace = builder.create<ExtractIterSpaceOp>(
+      loc, root.getTensor(), root.getParentIter(), root.getLoLvl(),
+      leaf.getHiLvl());
+
+  // Redirect the innermost loop to the collapsed space and to the init
+  // values of the outermost loop while cloning it.
+  IRMapping mapper;
+  mapper.map(leaf, collapsedSpace.getResultSpace());
+  for (auto z : llvm::zip_equal(loop.getInitArgs(), rItOp.getInitArgs()))
+    mapper.map(std::get<0>(z), std::get<1>(z));
+
+  auto cloned = llvm::cast<IterateOp>(builder.clone(*loop, mapper));
+  cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
+
+  rItOp.replaceAllUsesWith(cloned.getResults());
+  // Erase collapsed loops.
+  rItOp.erase();
+  root.erase();
+}
+
+struct SparseSpaceCollapsePass
+    : public impl::SparseSpaceCollapseBase<SparseSpaceCollapsePass> {
+  SparseSpaceCollapsePass() = default;
+
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+
+    // A naive (experimental) implementation that collapses consecutive sparse
+    // spaces. It does NOT handle complex cases where multiple spaces are
+    // extracted in the same basic block. E.g.,
+    //
+    // %space1 = extract_space %t1 ...
+    // %space2 = extract_space %t2 ...
+    // sparse_tensor.iterate(%space1) ...
+    //
+    SmallVector<ExtractIterSpaceOp> toCollapse; // Current chain of spaces that are candidates for merging.
+    func->walk([&](ExtractIterSpaceOp op) {
+      if (toCollapse.empty()) {
+        // Root space to collapse.
+        toCollapse.push_back(op);
+      } else {
+        if (legalToCollapse(toCollapse.back(), op)) {
+          toCollapse.push_back(op);
+        } else {
+          collapseSparseSpace(toCollapse); // NOTE(review): this may erase ops while the walk is still running, and `op` is not kept as the next root - confirm both are intended.
+          toCollapse.clear();
+        }
+      }
+    });
+
+    collapseSparseSpace(toCollapse); // Flush the chain that is still pending after the walk.
+  }
+};
+
+} // namespace sparse_tensor
+
+/// Factory for the (experimental) sparse space collapsing pass.
+std::unique_ptr<Pass> createSparseSpaceCollapsePass() {
+  std::unique_ptr<Pass> pass =
+      std::make_unique<sparse_tensor::SparseSpaceCollapsePass>();
+  return pass;
+}
+
+} // namespace mlir

>From 69cf58f3d4375ab076ef16e98d9c513b9eea1205 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 22 Mar 2024 18:41:34 +0000
Subject: [PATCH 3/3] test collapsing coordinate extraction from iterator.

---
 .../SparseTensor/IR/SparseTensorInterfaces.h  |   2 +
 .../SparseTensor/IR/SparseTensorInterfaces.td |  15 ++
 .../SparseTensor/IR/SparseTensorOps.td        |  17 +-
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 167 +++++++++++-------
 .../Transforms/SparseSpaceCollapse.cpp        | 119 +++++++++----
 5 files changed, 217 insertions(+), 103 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
index c0f31762ee071f..115e08b2cf8b14 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_
 
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
 
 namespace mlir {
 class PatternRewriter;
@@ -20,6 +21,7 @@ class StageWithSortSparseOp;
 namespace detail {
 LogicalResult stageWithSortImpl(sparse_tensor::StageWithSortSparseOp op,
                                 PatternRewriter &rewriter, Value &tmpBufs);
+
 } // namespace detail
 } // namespace sparse_tensor
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
index 05eed0483f2c8a..ee1c0b52b47e45 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
@@ -42,5 +42,20 @@ def StageWithSortSparseOpInterface : OpInterface<"StageWithSortSparseOp"> {
   ];
 }
 
+def SparseCollapsableOpInterface : OpInterface<"SparseCollapsableOp"> { // Op interface exposed in C++ as SparseCollapsableOp.
+  let description = [{ TODO }];
+
+  let cppNamespace = "::mlir::sparse_tensor";
+
+  let methods = [
+    InterfaceMethod<
+    /*desc=*/"test",
+    /*retTy=*/"ValueRange",
+    /*methodName=*/"collaspeOpInto", // NOTE(review): "collaspe" looks like a misspelling of "collapse"; renaming also requires updating every implementer.
+    /*args=*/(ins "::mlir::OpBuilder &":$builder,
+                  "::mlir::ArrayRef<::mlir::Operation *>":$loops,
+                  "::mlir::Operation *":$collapsed)>,
+  ];
+}
 
 #endif // SPARSETENSOR_IR_SPARSETENSORINTERFACES
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index f0eaf2191fdbd1..467030a8d221af 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1280,7 +1280,9 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
   let hasVerifier = 1;
 }
 
-def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator]>,
+def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
+    ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp", "ForeachOp",
+                 "IterateOp"]>]>,
     Arguments<(ins Optional<AnyType>:$result)> {
   let summary = "Yield from sparse_tensor set-like operations";
   let description = [{
@@ -1311,7 +1313,6 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator]>,
   let assemblyFormat = [{
         $result attr-dict `:` type($result)
   }];
-  let hasVerifier = 1;
 }
 
 def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
@@ -1440,10 +1441,20 @@ def ExtractIterSpaceOp : SparseTensor_Op<"iteration.extract_space",
   }];
 
   let hasVerifier = 1;
-  let assemblyFormat = "$tensor (`at`$parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
+  let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
                        " attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
 }
 
+def CoordinateOp : SparseTensor_Op<"iteration.coordinate",
+    [Pure, DeclareOpInterfaceMethods<SparseCollapsableOpInterface>]> {
+  let arguments = (ins AnySparseIterator:$iterator);
+  let results = (outs Variadic<Index>:$crds);
+  // NOTE(review): presumably one index per level of the iterator's space.
+  let extraClassDeclaration = [{ }];
+  // TODO: the verifier is disabled for now: let hasVerifier = 1;
+  let assemblyFormat = " $iterator attr-dict `:` type($iterator) `->` type($crds)";
+}
+
 def IterateOp : SparseTensor_Op<"iterate",
     [RecursiveMemoryEffects, RecursivelySpeculatable,
      DeclareOpInterfaceMethods<LoopLikeOpInterface,
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 6afa3e6309dc65..54be3c3b4c3e5f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1907,18 +1907,6 @@ LogicalResult SortOp::verify() {
   return success();
 }
 
-LogicalResult YieldOp::verify() {
-  // Check for compatible parent.
-  auto *parentOp = (*this)->getParentOp();
-  if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
-      isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
-      isa<ForeachOp>(parentOp) || isa<IterateOp>(parentOp))
-    return success();
-
-  return emitOpError("expected parent op to be sparse_tensor unary, binary, "
-                     "reduce, select or foreach");
-}
-
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Iteration Operations.
 //===----------------------------------------------------------------------===//
@@ -1936,17 +1924,82 @@ static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
   } else {
     lvlHi = lvlLo + 1;
   }
+
+  if (lvlHi <= lvlLo)
+    parser.emitError(parser.getNameLoc(),
+                     "expect larger level upper bound than lower bound");
+
   lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
   lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
   return success();
 }
 
-static void printLevelRange(OpAsmPrinter &p, ExtractIterSpaceOp op,
-                            IntegerAttr lvlLo, IntegerAttr lvlHi) {
-  if (op.getLoLvl() + 1 == op.getHiLvl())
-    p << op.getLoLvl();
+static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
+                            IntegerAttr lvlHi) {
+  unsigned lo = lvlLo.getValue().getZExtValue();
+  unsigned hi = lvlHi.getValue().getZExtValue();
+  if (lo + 1 == hi)
+    p << lo; // A range covering a single level prints as just "lo".
   else
-    p << op.getLoLvl() << " to " << op.getHiLvl();
+    p << lo << " to " << hi; // Any other range prints as "lo to hi".
+}
+
+/// Parses the header of a loop over sparse iteration spaces. Types every
+/// iterator and iter_arg, and appends the space operands followed by the init
+/// operands to `state.operands`; result types are stored in `state.types`.
+ParseResult
+parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
+                     SmallVectorImpl<OpAsmParser::Argument> &iterators,
+                     SmallVectorImpl<OpAsmParser::Argument> &iterArgs) {
+  SmallVector<OpAsmParser::UnresolvedOperand> spaces, initArgs;
+  // Parses "%iters, ... in %spaces, ..."
+  if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
+      parser.parseOperandList(spaces))
+    return failure();
+  if (iterators.size() != spaces.size())
+    return parser.emitError(
+        parser.getNameLoc(),
+        "mismatch in number of sparse iterators and sparse spaces");
+
+  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+  if (hasIterArgs && parser.parseAssignmentList(iterArgs, initArgs))
+    return failure();
+
+  SmallVector<Type> iterSpaceTps;
+  // Parses ": sparse_tensor.iter_space, ..."
+  if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
+    return failure();
+  if (iterSpaceTps.size() != spaces.size())
+    return parser.emitError(parser.getNameLoc(),
+                            "mismatch in number of iteration space operands "
+                            "and iteration space types");
+
+  for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
+    IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
+    if (!spaceTp)
+      return parser.emitError(parser.getNameLoc(),
+                              "expected sparse_tensor.iter_space type for "
+                              "iteration space operands");
+    it.type = spaceTp.getIteratorType();
+  }
+
+  if (hasIterArgs && parser.parseArrowTypeList(state.types))
+    return failure();
+  // Every loop-carried value takes its type from exactly one result type.
+  if (iterArgs.size() != state.types.size())
+    return parser.emitError(
+        parser.getNameLoc(),
+        "mismatch in number of loop-carried values and defined values");
+  // Resolves input operands.
+  if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
+                             state.operands))
+    return failure();
+  for (auto [it, init, tp] : llvm::zip_equal(iterArgs, initArgs, state.types)) {
+    it.type = tp;
+    if (parser.resolveOperand(init, tp, state.operands))
+      return failure();
+  }
+  return success();
 }
 
 LogicalResult ExtractIterSpaceOp::inferReturnTypes(
@@ -1991,60 +2044,45 @@ LogicalResult ExtractIterSpaceOp::verify() {
   return success();
 }
 
+/// Re-creates this op on the iterator of the `collapsed` loop and returns the
+/// results that correspond to the space iterated by this op's original loop.
+ValueRange CoordinateOp::collaspeOpInto(OpBuilder &builder,
+                                        ArrayRef<Operation *> loops,
+                                        Operation *collapsed) {
+  assert(llvm::all_of(loops,
+                      [](Operation *l) { return llvm::isa<IterateOp>(l); }));
+  auto mergedLoop = llvm::cast<IterateOp>(collapsed);
+  auto numCrds = mergedLoop.getIterSpace().getType().getSpaceDim();
+  SmallVector<Type> crdTps(numCrds, builder.getIndexType());
+  auto mergedCrds = builder.create<CoordinateOp>(getLoc(), crdTps,
+                                                 mergedLoop.getIterator());
+  // The parent op of the iterator's block must be one of `loops`.
+  Operation *owner = getIterator().getParentBlock()->getParentOp();
+  if (!llvm::is_contained(loops, owner))
+    llvm_unreachable(
+        "Can not find the corresponding iterate space for the collapsable op.");
+  auto space = llvm::cast<IterateOp>(owner)
+                   .getIterSpace()
+                   .getDefiningOp<ExtractIterSpaceOp>();
+  // NOTE(review): the offset is an absolute level; confirm that it should not
+  // be taken relative to the first level of the collapsed space.
+  return mergedCrds.getResults().slice(space.getLoLvl(), space.getSpaceDim());
+}
+
 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
   OpAsmParser::Argument iterator;
   OpAsmParser::UnresolvedOperand iterSpace;
 
-  // Parses %iters in %spaces
-  if (parser.parseArgument(iterator) || parser.parseKeyword("in") ||
-      parser.parseOperand(iterSpace)) {
+  SmallVector<OpAsmParser::Argument> iters, iterArgs;
+  if (parseSparseSpaceLoop(parser, result, iters, iterArgs))
     return failure();
-  }
-
-  // Parse the optional initial iteration arguments.
-  SmallVector<OpAsmParser::Argument> regionArgs;
-  SmallVector<OpAsmParser::UnresolvedOperand> operands;
-  // Region arguments starts with iterators and follows by optional
-  // user-provided iter_args.
-  regionArgs.push_back(iterator);
-  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
-  if (hasIterArgs)
-    if (parser.parseAssignmentList(regionArgs, operands))
-      return failure();
-
-  // parse ": sparse_tensor.iter_space -> ret"
-  Type iterSpaceTps;
-  if (parser.parseColon() || parser.parseType(iterSpaceTps))
-    return failure();
-  if (hasIterArgs)
-    if (parser.parseArrowTypeList(result.types))
-      return failure();
-
-  if (regionArgs.size() != result.types.size() + 1) {
-    return parser.emitError(
-        parser.getNameLoc(),
-        "mismatch in number of loop-carried values and defined values");
-  }
-
-  // Resolves input operands.
-  if (parser.resolveOperand(iterSpace, iterSpaceTps, result.operands))
-    return failure();
-
-  if (hasIterArgs) {
-    for (auto argOperandType :
-         llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
-      Type type = std::get<2>(argOperandType);
-      std::get<0>(argOperandType).type = type;
-      if (parser.resolveOperand(std::get<1>(argOperandType), type,
-                                result.operands))
-        return failure();
-    }
-  }
+  if (iters.size() != 1)
+    return parser.emitError(parser.getNameLoc(),
+                            "expected only one iterator/iteration space");
 
+  iters.append(iterArgs);
   Region *body = result.addRegion();
-  regionArgs.front().type =
-      iterSpaceTps.cast<IterSpaceType>().getIteratorType();
-  if (parser.parseRegion(*body, regionArgs))
+  if (parser.parseRegion(*body, iters))
     return failure();
 
   IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
@@ -2078,7 +2116,6 @@ static void printInitializationList(OpAsmPrinter &p,
 
 void IterateOp::print(OpAsmPrinter &p) {
   p << " " << getIterator() << " in " << getIterSpace();
-
   printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
 
   p << " : " << getIterSpace().getType() << " ";
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
index f3207ede9585b4..752b2dfc2a0070 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -20,8 +20,14 @@ namespace mlir {
 
 namespace sparse_tensor {
 
-bool isCollapsableIterations(LoopLikeOpInterface parent,
-                             LoopLikeOpInterface node) {
+struct CollapseSpaceInfo {
+  ExtractIterSpaceOp space; // One iteration space in the collapse chain.
+  // The single loop using `space`. Coiteration as well (if it makes sense)?
+  IterateOp loop;
+  SmallVector<SparseCollapsableOp> collapseOps; // Ops rebuilt on collapse.
+};
+
+bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) {
   auto pIterArgs = parent.getRegionIterArgs();
   auto nInitArgs = node.getInits();
   if (pIterArgs.size() != nInitArgs.size())
@@ -44,43 +50,81 @@ bool isCollapsableIterations(LoopLikeOpInterface parent,
   return yieldEq && iterArgEq;
 }
 
-bool legalToCollapse(ExtractIterSpaceOp parent, ExtractIterSpaceOp node) {
-  auto pItOp = llvm::dyn_cast<IterateOp>(parent->getParentOp());
-  auto nItOp = llvm::dyn_cast<IterateOp>(node->getParentOp());
+bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
+                     ExtractIterSpaceOp curSpace) {
+  // Returns true iff curSpace got appended to the toCollapse chain.
+  auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp {
+    Value spaceVal = space.getResultSpace();
+    if (spaceVal.hasOneUse())
+      return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin());
+    return nullptr; // The space does not have exactly one use.
+  };
+
+  if (toCollapse.empty()) {
+    // Collapse root: only requires that its single use is an IterateOp.
+    if (auto itOp = getIterateOpOverSpace(curSpace)) {
+      CollapseSpaceInfo &info = toCollapse.emplace_back();
+      info.space = curSpace;
+      info.loop = itOp;
+      // No operations need to be collapsed at the root level.
+      info.collapseOps = {};
+      return true;
+    }
+    return false;
+  }
+
+  auto parent = toCollapse.back().space;
+  auto pItOp = toCollapse.back().loop;
+  auto nItOp = getIterateOpOverSpace(curSpace);
 
   // Can only collapse spaces extracted from the same tensor.
-  if (parent.getTensor() != node.getTensor() || !parent->hasOneUse())
+  if (parent.getTensor() != curSpace.getTensor())
     return false;
 
   // Can only collapse consecutive simple iteration on one tensor (i.e., no
   // coiteration).
-  if (!nItOp || nItOp.getIterSpace() != parent.getResult() ||
-      nItOp->getBlock() != parent->getBlock())
+  if (!nItOp || nItOp->getBlock() != curSpace->getBlock() ||
+      pItOp.getIterator() != curSpace.getParentIter() ||
+      curSpace->getParentOp() != pItOp.getOperation())
     return false;
 
-  if (pItOp && !isCollapsableIterations(pItOp, nItOp))
+  if (pItOp && !isCollapsableLoops(pItOp, nItOp))
     return false;
 
   // TODO: Make sure all other operations in the same basic block as `node` can
   // be collapsed and sink into the collapsed iteration (through Interfaces
   // defined in TD files).
+  SmallVector<SparseCollapsableOp> collapsableOps;
+  for (Operation &op : *pItOp.getBody()) {
+    if (&op == curSpace.getOperation() || &op == nItOp.getOperation() ||
+        &op == pItOp.getBody()->getTerminator())
+      continue;
+    // Every other op in the parent loop body must be collapsable.
+    auto collapsableOp = llvm::dyn_cast<SparseCollapsableOp>(&op);
+    if (!collapsableOp)
+      return false;
+    collapsableOps.push_back(collapsableOp);
+  }
+
+  CollapseSpaceInfo &info = toCollapse.emplace_back();
+  info.space = curSpace;
+  info.loop = nItOp;
+  info.collapseOps = std::move(collapsableOps);
   return true;
 }
 
-void collapseSparseSpace(ArrayRef<ExtractIterSpaceOp> toCollapse) {
+void collapseSparseSpace(SmallVectorImpl<CollapseSpaceInfo> &toCollapse) {
   if (toCollapse.size() < 2)
     return;
 
-  ExtractIterSpaceOp root = toCollapse.front();
-  ExtractIterSpaceOp leaf = toCollapse.back();
+  ExtractIterSpaceOp root = toCollapse.front().space;
+  ExtractIterSpaceOp leaf = toCollapse.back().space;
   Location loc = root.getLoc();
 
-  if (!leaf->hasOneUse())
-    return;
-  assert(root->hasOneUse());
+  assert(root->hasOneUse() && leaf->hasOneUse());
 
   // Insert collapsed operation at the same scope as root operation.
-  OpBuilder builder(toCollapse.front());
+  OpBuilder builder(root);
 
   // Construct the collapsed iteration space.
   auto collapsedSpace = builder.create<ExtractIterSpaceOp>(
@@ -88,19 +132,29 @@ void collapseSparseSpace(ArrayRef<ExtractIterSpaceOp> toCollapse) {
       leaf.getHiLvl());
 
   auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
-  auto pItOp = llvm::cast<IterateOp>(leaf->getParentOp());
-
-  // This could either be IterateOp or (TODO: in the future) CoIterateOp.
-  auto loop = llvm::dyn_cast<IterateOp>(*leaf->getUsers().begin());
-  if (!loop || !isCollapsableIterations(pItOp, loop))
-    return;
+  auto innermost = toCollapse.back().loop;
 
   IRMapping mapper;
   mapper.map(leaf, collapsedSpace.getResultSpace());
-  for (auto z : llvm::zip_equal(loop.getInitArgs(), rItOp.getInitArgs()))
+  for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs()))
     mapper.map(std::get<0>(z), std::get<1>(z));
 
-  auto cloned = llvm::cast<IterateOp>(builder.clone(*loop, mapper));
+  auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper));
+  builder.setInsertionPointToStart(cloned.getBody());
+  SmallVector<Operation *> loops =
+      llvm::map_to_vector(toCollapse, [](CollapseSpaceInfo &info) {
+        return info.loop.getOperation();
+      });
+
+  for (const CollapseSpaceInfo &info : toCollapse) {
+    for (SparseCollapsableOp op : info.collapseOps) {
+      ValueRange colVals = op.collaspeOpInto(builder, loops, cloned);
+      for (auto [o, r] : llvm::zip(op->getResults(), colVals))
+        o.replaceAllUsesWith(r);
+      op.erase();
+    }
+  }
+
   cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
 
   rItOp.replaceAllUsesWith(cloned.getResults());
@@ -124,18 +178,13 @@ struct SparseSpaceCollapsePass
     // %space2 = extract_space %t2 ...
     // sparse_tensor.iterate(%sp1) ...
     //
-    SmallVector<ExtractIterSpaceOp> toCollapse;
+    SmallVector<CollapseSpaceInfo> toCollapse;
     func->walk([&](ExtractIterSpaceOp op) {
-      if (toCollapse.empty()) {
-        // Root space to collapse.
-        toCollapse.push_back(op);
-      } else {
-        if (legalToCollapse(toCollapse.back(), op)) {
-          toCollapse.push_back(op);
-        } else {
-          collapseSparseSpace(toCollapse);
-          toCollapse.clear();
-        }
+      if (!legalToCollapse(toCollapse, op)) {
+        // if not legal to collapse one more space, collapse the existing ones
+        // and clear.
+        collapseSparseSpace(toCollapse);
+        toCollapse.clear();
       }
     });
 



More information about the Mlir-commits mailing list