[Mlir-commits] [mlir] [WIP][mlir][sparse] Setting up sparse_tensor.iterator-related Ops. (PR #85958)

Peiming Liu llvmlistbot at llvm.org
Thu Mar 28 09:33:34 PDT 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/85958

>From 8e0f27f5adfc970387748056aa86fa04552d63e5 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 18 Mar 2024 19:40:46 +0000
Subject: [PATCH 1/5] test parse iterate operation

---
 .../SparseTensor/IR/SparseTensorOps.td        | 17 +++++
 .../SparseTensor/IR/SparseTensorTypes.td      | 76 +++++++++++++++++++
 2 files changed, 93 insertions(+)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 29cf8c32447ecf..6efeb6007d649e 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1418,6 +1418,23 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Operations.
+//===----------------------------------------------------------------------===//
+
+def IterateOp : SparseTensor_Op<"iterate",
+    [RecursiveMemoryEffects]> {
+
+  let arguments = (ins AnySparseIterSpace:$iterSpace,
+                       Variadic<AnyType>:$initArgs);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$region);
+
+  let extraClassDeclaration = [{}];
+
+  let hasCustomAssemblyFormat = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Debugging and Test-Only Operations.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
index 185cff46ae25d5..54a8e4d7ecd398 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
@@ -72,4 +72,80 @@ def SparseTensorStorageSpecifier
     : Type<CPred<"::llvm::isa<::mlir::sparse_tensor::StorageSpecifierType>($_self)">, "metadata",
           "::mlir::sparse_tensor::StorageSpecifierType">;
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Types.
+//===----------------------------------------------------------------------===//
+
+def LevelTypeArrayParameter : ArrayRefParameter<"::mlir::sparse_tensor::LevelType", "level-types"> {
+  let printer = [{
+    auto lvlStrings = llvm::map_range($_self, [](auto lt){ return lt.toMLIRString(); });
+    $_printer << "[" << llvm::join(lvlStrings, ",") << "]";
+  }];
+
+  let parser = [{ [&]() -> FailureOr<SmallVector<::mlir::sparse_tensor::LevelType>> {
+    SmallVector<::mlir::sparse_tensor::LevelType> ret;
+
+    const auto res = $_parser.parseCommaSeparatedList(
+      mlir::OpAsmParser::Delimiter::Square,
+      [&]() -> ParseResult {
+        ::mlir::sparse_tensor::ir_detail::LvlTypeParser lParser;
+        auto lvlTpOrFail = lParser.parseLvlType($_parser);
+        if (failed(lvlTpOrFail))
+          return failure();
+        ret.emplace_back(*lvlTpOrFail);
+        return success();
+      }, " in level-type list");
+
+    if (failed(res))
+      return failure();
+    return ret;
+  }() }];
+}
+
+def SparseTensor_Iterator : SparseTensor_Type<"Iterator"> {
+  let mnemonic = "iterator";
+
+  let parameters = (ins
+     LevelTypeArrayParameter: $lvlTypes
+  );
+
+  // let skipDefaultBuilders = 1;
+  // let hasCustomAssemblyFormat = 1;
+  let assemblyFormat="`<` $lvlTypes `>`";
+}
+
+
+def SparseTensor_IterSpace : SparseTensor_Type<"IterSpace"> {
+  let mnemonic = "iter_space";
+
+  let parameters = (ins
+     LevelTypeArrayParameter: $lvlTypes
+  );
+
+  let extraClassDeclaration = [{
+     ::mlir::sparse_tensor::IteratorType getIteratorType() const {
+        return IteratorType::get(getContext(), getLvlTypes());
+     }
+  }];
+
+  // let skipDefaultBuilders = 1;
+  // let hasCustomAssemblyFormat = 1;
+  let assemblyFormat="`<` $lvlTypes `>`";
+}
+
+def IsSparseSparseIterSpaceTypePred
+    : CPred<"::llvm::isa<::mlir::sparse_tensor::IterSpaceType>($_self)">;
+
+def IsSparseSparseIteratorTypePred
+    : CPred<"::llvm::isa<::mlir::sparse_tensor::IteratorType>($_self)">;
+
+def AnySparseIterSpace
+    : Type<IsSparseSparseIterSpaceTypePred, "sparse iteration space",
+          "::mlir::sparse_tensor::IterSpaceType">;
+
+def AnySparseIterator
+    : Type<IsSparseSparseIteratorTypePred, "sparse iterator",
+          "::mlir::sparse_tensor::IteratorType">;
+
+
 #endif // SPARSETENSOR_TYPES

>From 947f87e0a48e8c893a2b0025d8707d0ffcb5ccd7 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 20 Mar 2024 16:14:12 +0000
Subject: [PATCH 2/5] test sparse space collapse

---
 .../Dialect/SparseTensor/IR/SparseTensor.h    |   2 +
 .../SparseTensor/IR/SparseTensorOps.td        |  41 ++-
 .../SparseTensor/IR/SparseTensorTypes.td      |  15 +-
 .../Dialect/SparseTensor/Transforms/Passes.h  |   6 +
 .../Dialect/SparseTensor/Transforms/Passes.td |  16 ++
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 253 +++++++++++++++++-
 .../SparseTensor/Transforms/CMakeLists.txt    |   1 +
 .../Transforms/SparseSpaceCollapse.cpp        | 152 +++++++++++
 8 files changed, 479 insertions(+), 7 deletions(-)
 create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 5e523ec428aefb..78692307820bc5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -17,7 +17,9 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/TensorEncoding.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 6efeb6007d649e..f0eaf2191fdbd1 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -15,6 +15,8 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
 include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
 
 //===----------------------------------------------------------------------===//
 // Base class.
@@ -1422,16 +1424,51 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
 // Sparse Tensor Iteration Operations.
 //===----------------------------------------------------------------------===//
 
+def ExtractIterSpaceOp : SparseTensor_Op<"iteration.extract_space",
+    [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+
+  let arguments = (ins AnySparseTensor:$tensor,
+                       Optional<AnySparseIterator>:$parentIter,
+                       LevelAttr:$loLvl, LevelAttr:$hiLvl);
+
+  let results = (outs AnySparseIterSpace:$resultSpace);
+
+  let extraClassDeclaration = [{
+    unsigned getSpaceDim() {
+      return getHiLvl() - getLoLvl();
+    }
+  }];
+
+  let hasVerifier = 1;
+  let assemblyFormat = "$tensor (`at`$parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
+                       " attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
+}
+
 def IterateOp : SparseTensor_Op<"iterate",
-    [RecursiveMemoryEffects]> {
+    [RecursiveMemoryEffects, RecursivelySpeculatable,
+     DeclareOpInterfaceMethods<LoopLikeOpInterface,
+      ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
+      "getSingleInductionVar", "getYieldedValuesMutable"]>,
+     DeclareOpInterfaceMethods<RegionBranchOpInterface,
+      ["getEntrySuccessorOperands"]>,
+     SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
 
   let arguments = (ins AnySparseIterSpace:$iterSpace,
                        Variadic<AnyType>:$initArgs);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
 
-  let extraClassDeclaration = [{}];
+  let extraClassDeclaration = [{
+    BlockArgument getIterator() {
+      return getRegion().getArguments().front();
+    }
+    unsigned getNumRegionIterArgs() {
+      return getRegion().getArguments().size() - 1;
+    }
+  }];
 
+  let hasVerifier = 1;
+  let hasRegionVerifier = 1;
   let hasCustomAssemblyFormat = 1;
 }
 
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
index 54a8e4d7ecd398..aa674b613e71db 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
@@ -109,8 +109,13 @@ def SparseTensor_Iterator : SparseTensor_Type<"Iterator"> {
      LevelTypeArrayParameter: $lvlTypes
   );
 
-  // let skipDefaultBuilders = 1;
-  // let hasCustomAssemblyFormat = 1;
+  let extraClassDeclaration = [{
+     unsigned getSpaceDim() {
+       return getLvlTypes().size();
+     }
+  }];
+
+
   let assemblyFormat="`<` $lvlTypes `>`";
 }
 
@@ -123,13 +128,15 @@ def SparseTensor_IterSpace : SparseTensor_Type<"IterSpace"> {
   );
 
   let extraClassDeclaration = [{
+     unsigned getSpaceDim() {
+       return getLvlTypes().size();
+     }
+
      ::mlir::sparse_tensor::IteratorType getIteratorType() const {
         return IteratorType::get(getContext(), getLvlTypes());
      }
   }];
 
-  // let skipDefaultBuilders = 1;
-  // let hasCustomAssemblyFormat = 1;
   let assemblyFormat="`<` $lvlTypes `>`";
 }
 
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 61b07d222d156b..0e9f5120f7b3dc 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -247,6 +247,12 @@ std::unique_ptr<Pass> createSparsificationAndBufferizationPass(
     bool enableBufferInitialization, unsigned vectorLength,
     bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen);
 
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<Pass> createSparseSpaceCollapsePass();
+
 //===----------------------------------------------------------------------===//
 // Registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 58e2d6f32386c3..3ab75c23dbefa0 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -454,4 +454,20 @@ def SparsificationAndBufferization : Pass<"sparsification-and-bufferization", "M
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
+  let summary = "(experimental) sparse space collapsing pass";
+  let description = [{
+     This pass collapses consecutive sparse spaces (extracted from the same
+     tensor) into one multi-dimensional space.
+  }];
+  let constructor = "mlir::createSparseSpaceCollapsePass()";
+  let dependentDialects = [
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 6da51bb6b9cacf..6afa3e6309dc65 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1912,13 +1912,264 @@ LogicalResult YieldOp::verify() {
   auto *parentOp = (*this)->getParentOp();
   if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
       isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
-      isa<ForeachOp>(parentOp))
+      isa<ForeachOp>(parentOp) || isa<IterateOp>(parentOp))
     return success();
 
   return emitOpError("expected parent op to be sparse_tensor unary, binary, "
                      "reduce, select or foreach");
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Operations.
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
+                                   IntegerAttr &lvlHiAttr) {
+  Level lvlLo, lvlHi;
+
+  if (parser.parseInteger(lvlLo))
+    return failure();
+
+  if (succeeded(parser.parseOptionalKeyword("to"))) {
+    if (parser.parseInteger(lvlHi))
+      return failure();
+  } else {
+    lvlHi = lvlLo + 1;
+  }
+  lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
+  lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
+  return success();
+}
+
+static void printLevelRange(OpAsmPrinter &p, ExtractIterSpaceOp op,
+                            IntegerAttr lvlLo, IntegerAttr lvlHi) {
+  if (op.getLoLvl() + 1 == op.getHiLvl())
+    p << op.getLoLvl();
+  else
+    p << op.getLoLvl() << " to " << op.getHiLvl();
+}
+
+LogicalResult ExtractIterSpaceOp::inferReturnTypes(
+    MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
+    DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
+    SmallVectorImpl<mlir::Type> &ret) {
+  ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
+  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
+  Level lo = adaptor.getLoLvl(), hi = adaptor.getHiLvl();
+  if (lo >= hi || hi > stt.getLvlRank()) // Runs before verify() when parsing.
+    return emitOptionalError(loc, "invalid level range for the sparse tensor");
+  ret.push_back(IterSpaceType::get(ctx, stt.getLvlTypes().slice(lo, hi - lo)));
+  return success();
+}
+
+LogicalResult ExtractIterSpaceOp::verify() {
+  SparseTensorType stt = getSparseTensorType(getTensor());
+  if (getLoLvl() >= getHiLvl())
+    return emitOpError("expected smaller level low than level high");
+  if (getHiLvl() > stt.getLvlRank())
+    return emitOpError("expected level high to not exceed the level rank");
+
+  ArrayRef<LevelType> lts = stt.getLvlTypes().slice(getLoLvl(), getSpaceDim());
+  if (!getResultSpace().getType().getLvlTypes().equals(lts))
+    return emitOpError(
+        "mismatch in iteration space level types and tensor level types");
+
+  // A parent iterator is required exactly when the space does not start at
+  // level 0; it must then cover the levels right below `loLvl`.
+  TypedValue<IteratorType> pIter = getParentIter();
+  if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0))
+    return emitOpError("parent iterator is needed iff level low is non-zero");
+
+  if (pIter) {
+    unsigned pDim = pIter.getType().getSpaceDim();
+    if (getLoLvl() < pDim || !stt.getLvlTypes()
+                                  .slice(getLoLvl() - pDim, pDim)
+                                  .equals(pIter.getType().getLvlTypes()))
+      return emitOpError(
+          "mismatch in parent iterator level types and tensor level types");
+  }
+  return success();
+}
+
+ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::Argument iterator;
+  OpAsmParser::UnresolvedOperand iterSpace;
+
+  // Parses "%iter in %space".
+  if (parser.parseArgument(iterator) || parser.parseKeyword("in") ||
+      parser.parseOperand(iterSpace)) {
+    return failure();
+  }
+
+  // Region arguments start with the iterator, followed by the optional
+  // user-provided iter_args.
+  SmallVector<OpAsmParser::Argument> regionArgs;
+  SmallVector<OpAsmParser::UnresolvedOperand> operands;
+  regionArgs.push_back(iterator);
+  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+  if (hasIterArgs && parser.parseAssignmentList(regionArgs, operands))
+    return failure();
+
+  // Parses ": <iter_space type> (-> (<result types>))?". The space type is
+  // validated here because the iterator type is derived from it below.
+  Type iterSpaceTps;
+  llvm::SMLoc spaceTpLoc;
+  if (parser.parseColon() || parser.getCurrentLocation(&spaceTpLoc) ||
+      parser.parseType(iterSpaceTps))
+    return failure();
+  auto spaceTp = llvm::dyn_cast<IterSpaceType>(iterSpaceTps);
+  if (!spaceTp)
+    return parser.emitError(spaceTpLoc, "expected sparse iteration space type");
+  if (hasIterArgs && parser.parseArrowTypeList(result.types))
+    return failure();
+
+  if (regionArgs.size() != result.types.size() + 1) {
+    return parser.emitError(
+        parser.getNameLoc(),
+        "mismatch in number of loop-carried values and defined values");
+  }
+
+  // Resolves input operands.
+  if (parser.resolveOperand(iterSpace, iterSpaceTps, result.operands))
+    return failure();
+  if (hasIterArgs) {
+    for (auto argOperandType :
+         llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
+      Type type = std::get<2>(argOperandType);
+      std::get<0>(argOperandType).type = type;
+      if (parser.resolveOperand(std::get<1>(argOperandType), type,
+                                result.operands))
+        return failure();
+    }
+  }
+
+  Region *body = result.addRegion();
+  regionArgs.front().type = spaceTp.getIteratorType();
+  if (parser.parseRegion(*body, regionArgs))
+    return failure();
+
+  IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  return success();
+}
+
+/// Prints the initialization list in the form of
+///   <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
+/// where 'inner' values are assumed to be region arguments and 'outer' values
+/// are regular SSA values.
+static void printInitializationList(OpAsmPrinter &p,
+                                    Block::BlockArgListType blocksArgs,
+                                    ValueRange initializers,
+                                    StringRef prefix = "") {
+  assert(blocksArgs.size() == initializers.size() &&
+         "expected same length of arguments and initializers");
+  if (initializers.empty())
+    return;
+
+  p << prefix << '(';
+  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
+    p << std::get<0>(it) << " = " << std::get<1>(it);
+  });
+  p << ")";
+}
+
+void IterateOp::print(OpAsmPrinter &p) {
+  p << " " << getIterator() << " in " << getIterSpace();
+  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
+
+  p << " : " << getIterSpace().getType() << " ";
+  if (!getInitArgs().empty())
+    p << "-> (" << getInitArgs().getTypes() << ") ";
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/!getInitArgs().empty());
+  // parse() reads the attribute dictionary after the region; mirror that.
+  p.printOptionalAttrDict((*this)->getAttrs());
+}
+
+LogicalResult IterateOp::verify() {
+  if (getInitArgs().size() != getNumResults()) {
+    return emitOpError(
+        "mismatch in number of loop-carried values and defined values");
+  }
+  return success();
+}
+
+LogicalResult IterateOp::verifyRegions() {
+  if (getIterator().getType() != getIterSpace().getType().getIteratorType())
+    return emitOpError("mismatch in iterator and iteration space type");
+  if (getNumRegionIterArgs() != getNumResults())
+    return emitOpError(
+        "mismatch in number of basic block args and defined values");
+
+  auto initArgs = getInitArgs();
+  auto iterArgs = getRegionIterArgs();
+  auto yieldVals = getYieldedValues();
+  auto opResults = getResults();
+  if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
+                        opResults.size()})) {
+    return emitOpError() << "number mismatch between iter args and results.";
+  }
+
+  unsigned i = 0;
+  for (auto e : llvm::zip_equal(initArgs, iterArgs, yieldVals, opResults)) {
+    if (std::get<0>(e).getType() != std::get<3>(e).getType())
+      return emitOpError() << "types mismatch between " << i
+                           << "th iter operand and defined value";
+    if (std::get<1>(e).getType() != std::get<3>(e).getType())
+      return emitOpError() << "types mismatch between " << i
+                           << "th iter region arg and defined value";
+    if (std::get<2>(e).getType() != std::get<3>(e).getType())
+      return emitOpError() << "types mismatch between " << i
+                           << "th yield value and defined value";
+
+    ++i;
+  }
+  return success();
+}
+
+/// IterateOp implemented interfaces' methods.
+SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
+
+std::optional<Value> IterateOp::getSingleInductionVar() {
+  return getIterator();
+}
+
+MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
+  return getInitArgsMutable();
+}
+
+Block::BlockArgListType IterateOp::getRegionIterArgs() {
+  return getRegion().getArguments().drop_front();
+}
+
+std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
+  return cast<sparse_tensor::YieldOp>(
+             getRegion().getBlocks().front().getTerminator())
+      .getResultMutable();
+}
+
+std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
+
+OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  return getInitArgs();
+}
+
+void IterateOp::getSuccessorRegions(RegionBranchPoint point,
+                                    SmallVectorImpl<RegionSuccessor> &regions) {
+  // Both the operation itself and the region may be branching into the body or
+  // back into the operation itself.
+  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
+  // It is possible for loop not to enter the body.
+  regions.push_back(RegionSuccessor(getResults()));
+}
+
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Dialect Setups.
+//===----------------------------------------------------------------------===//
+
 /// Materialize a single constant operation from a given attribute value with
 /// the desired resultant type.
 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 3c0f82fc00bb9d..8840da9aa56ef7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   SparseGPUCodegen.cpp
   SparseReinterpretMap.cpp
   SparseStorageSpecifierToLLVM.cpp
+  SparseSpaceCollapse.cpp
   SparseTensorCodegen.cpp
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
new file mode 100644
index 00000000000000..f3207ede9585b4
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -0,0 +1,152 @@
+//===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
+namespace mlir {
+
+#define GEN_PASS_DEF_SPARSESPACECOLLAPSE
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
+
+namespace sparse_tensor {
+
+bool isCollapsableIterations(LoopLikeOpInterface parent,
+                             LoopLikeOpInterface node) {
+  auto pIterArgs = parent.getRegionIterArgs();
+  auto nInitArgs = node.getInits();
+  if (pIterArgs.size() != nInitArgs.size())
+    return false;
+
+  auto pYields = parent.getYieldedValues();
+  auto nResult = node.getLoopResults().value();
+
+  bool yieldEq =
+      llvm::all_of(llvm::zip_equal(pYields, nResult), [](auto zipped) {
+        return std::get<0>(zipped) == std::get<1>(zipped);
+      });
+
+  // Parent iter_args should be passed directly to the node's init_args.
+  bool iterArgEq =
+      llvm::all_of(llvm::zip_equal(pIterArgs, nInitArgs), [](auto zipped) {
+        return std::get<0>(zipped) == std::get<1>(zipped);
+      });
+
+  return yieldEq && iterArgEq;
+}
+
+bool legalToCollapse(ExtractIterSpaceOp parent, ExtractIterSpaceOp node) {
+  auto pItOp = llvm::dyn_cast<IterateOp>(parent->getParentOp());
+  auto nItOp = llvm::dyn_cast<IterateOp>(node->getParentOp());
+
+  // Can only collapse spaces extracted from the same tensor.
+  if (parent.getTensor() != node.getTensor() || !parent->hasOneUse())
+    return false;
+
+  // Can only collapse consecutive simple iteration on one tensor (i.e., no
+  // coiteration).
+  if (!nItOp || nItOp.getIterSpace() != parent.getResult() ||
+      nItOp->getBlock() != parent->getBlock())
+    return false;
+
+  if (pItOp && !isCollapsableIterations(pItOp, nItOp))
+    return false;
+
+  // TODO: Make sure all other operations in the same basic block as `node` can
+  // be collapsed and sink into the collapsed iteration (through Interfaces
+  // defined in TD files).
+  return true;
+}
+
+void collapseSparseSpace(ArrayRef<ExtractIterSpaceOp> toCollapse) {
+  if (toCollapse.size() < 2)
+    return;
+
+  ExtractIterSpaceOp root = toCollapse.front();
+  ExtractIterSpaceOp leaf = toCollapse.back();
+  Location loc = root.getLoc();
+
+  if (!leaf->hasOneUse())
+    return;
+  assert(root->hasOneUse());
+
+  // Insert collapsed operation at the same scope as root operation.
+  OpBuilder builder(toCollapse.front());
+
+  // Construct the collapsed iteration space.
+  auto collapsedSpace = builder.create<ExtractIterSpaceOp>(
+      loc, root.getTensor(), root.getParentIter(), root.getLoLvl(),
+      leaf.getHiLvl());
+
+  auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
+  auto pItOp = llvm::cast<IterateOp>(leaf->getParentOp());
+
+  // This could either be IterateOp or (TODO: in the future) CoIterateOp.
+  auto loop = llvm::dyn_cast<IterateOp>(*leaf->getUsers().begin());
+  if (!loop || !isCollapsableIterations(pItOp, loop))
+    return;
+
+  IRMapping mapper;
+  mapper.map(leaf, collapsedSpace.getResultSpace());
+  for (auto z : llvm::zip_equal(loop.getInitArgs(), rItOp.getInitArgs()))
+    mapper.map(std::get<0>(z), std::get<1>(z));
+
+  auto cloned = llvm::cast<IterateOp>(builder.clone(*loop, mapper));
+  cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
+
+  rItOp.replaceAllUsesWith(cloned.getResults());
+  // Erase collapsed loops.
+  rItOp.erase();
+  root.erase();
+}
+
+struct SparseSpaceCollapsePass
+    : public impl::SparseSpaceCollapseBase<SparseSpaceCollapsePass> {
+  SparseSpaceCollapsePass() = default;
+
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+
+    // A naive (experimental) implementation to collapse consecutive sparse
+    // spaces. It does NOT handle complex cases where multiple spaces are
+    // extracted in the same basic block. E.g.,
+    //
+    // %space1 = extract_space %t1 ...
+    // %space2 = extract_space %t2 ...
+    // sparse_tensor.iterate(%space1) ...
+    //
+    SmallVector<ExtractIterSpaceOp> toCollapse;
+    func->walk([&](ExtractIterSpaceOp op) {
+      // Extend the current chain when `op` can be collapsed into it.
+      if (!toCollapse.empty() && legalToCollapse(toCollapse.back(), op)) {
+        toCollapse.push_back(op);
+        return;
+      }
+      // Otherwise flush the current chain and start a new one rooted at `op`.
+      // NOTE(review): flushing erases ops while walking; collecting chains
+      // first and collapsing after the walk would be safer.
+      collapseSparseSpace(toCollapse);
+      toCollapse.clear();
+      toCollapse.push_back(op);
+    });
+
+    collapseSparseSpace(toCollapse);
+  }
+};
+
+} // namespace sparse_tensor
+
+std::unique_ptr<Pass> createSparseSpaceCollapsePass() {
+  return std::make_unique<sparse_tensor::SparseSpaceCollapsePass>();
+}
+
+} // namespace mlir

>From c5840b322cfa05a557b58bdca0201bfbcb31dc3b Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 22 Mar 2024 18:41:34 +0000
Subject: [PATCH 3/5] test collapsing coordinate extraction from iterator.

---
 .../SparseTensor/IR/SparseTensorInterfaces.h  |   2 +
 .../SparseTensor/IR/SparseTensorInterfaces.td |  15 ++
 .../SparseTensor/IR/SparseTensorOps.td        |  17 +-
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 167 +++++++++++-------
 .../Transforms/SparseSpaceCollapse.cpp        | 119 +++++++++----
 5 files changed, 217 insertions(+), 103 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
index c0f31762ee071f..115e08b2cf8b14 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_
 
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
 
 namespace mlir {
 class PatternRewriter;
@@ -20,6 +21,7 @@ class StageWithSortSparseOp;
 namespace detail {
 LogicalResult stageWithSortImpl(sparse_tensor::StageWithSortSparseOp op,
                                 PatternRewriter &rewriter, Value &tmpBufs);
+
 } // namespace detail
 } // namespace sparse_tensor
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
index 05eed0483f2c8a..ee1c0b52b47e45 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
@@ -42,5 +42,20 @@ def StageWithSortSparseOpInterface : OpInterface<"StageWithSortSparseOp"> {
   ];
 }
 
+def SparseCollapsableOpInterface : OpInterface<"SparseCollapsableOp"> {
+  let description = [{ TODO }];
+
+  let cppNamespace = "::mlir::sparse_tensor";
+
+  let methods = [
+    InterfaceMethod<
+    /*desc=*/"test",
+    /*retTy=*/"ValueRange",
+    /*methodName=*/"collaspeOpInto",
+    /*args=*/(ins "::mlir::OpBuilder &":$builder,
+                  "::mlir::ArrayRef<::mlir::Operation *>":$loops,
+                  "::mlir::Operation *":$collapsed)>,
+  ];
+}
 
 #endif // SPARSETENSOR_IR_SPARSETENSORINTERFACES
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index f0eaf2191fdbd1..467030a8d221af 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1280,7 +1280,9 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
   let hasVerifier = 1;
 }
 
-def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator]>,
+def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
+    ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp", "ForeachOp",
+                 "IterateOp"]>]>,
     Arguments<(ins Optional<AnyType>:$result)> {
   let summary = "Yield from sparse_tensor set-like operations";
   let description = [{
@@ -1311,7 +1313,6 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator]>,
   let assemblyFormat = [{
         $result attr-dict `:` type($result)
   }];
-  let hasVerifier = 1;
 }
 
 def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
@@ -1440,10 +1441,20 @@ def ExtractIterSpaceOp : SparseTensor_Op<"iteration.extract_space",
   }];
 
   let hasVerifier = 1;
-  let assemblyFormat = "$tensor (`at`$parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
+  let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
                        " attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
 }
 
+def CoordinateOp : SparseTensor_Op<"iteration.coordinate",
+    [Pure, DeclareOpInterfaceMethods<SparseCollapsableOpInterface>]> {
+  let arguments = (ins AnySparseIterator:$iterator);
+  let results = (outs Variadic<Index>:$crds);
+
+  let extraClassDeclaration = [{ }];
+  // let hasVerifier = 1;
+  let assemblyFormat = " $iterator attr-dict `:` type($iterator) `->` type($crds)";
+}
+
 def IterateOp : SparseTensor_Op<"iterate",
     [RecursiveMemoryEffects, RecursivelySpeculatable,
      DeclareOpInterfaceMethods<LoopLikeOpInterface,
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 6afa3e6309dc65..54be3c3b4c3e5f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1907,18 +1907,6 @@ LogicalResult SortOp::verify() {
   return success();
 }
 
-LogicalResult YieldOp::verify() {
-  // Check for compatible parent.
-  auto *parentOp = (*this)->getParentOp();
-  if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
-      isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
-      isa<ForeachOp>(parentOp) || isa<IterateOp>(parentOp))
-    return success();
-
-  return emitOpError("expected parent op to be sparse_tensor unary, binary, "
-                     "reduce, select or foreach");
-}
-
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Iteration Operations.
 //===----------------------------------------------------------------------===//
@@ -1936,17 +1924,82 @@ static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
   } else {
     lvlHi = lvlLo + 1;
   }
+
+  if (lvlHi <= lvlLo)
+    return parser.emitError(parser.getNameLoc(),
+                            "expect larger level upper bound than lower bound");
+
   lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
   lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
   return success();
 }
 
-static void printLevelRange(OpAsmPrinter &p, ExtractIterSpaceOp op,
-                            IntegerAttr lvlLo, IntegerAttr lvlHi) {
-  if (op.getLoLvl() + 1 == op.getHiLvl())
-    p << op.getLoLvl();
+static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
+                            IntegerAttr lvlHi) {
+  unsigned lo = lvlLo.getValue().getZExtValue();
+  unsigned hi = lvlHi.getValue().getZExtValue();
+  if (lo + 1 == hi)
+    p << lo;
   else
-    p << op.getLoLvl() << " to " << op.getHiLvl();
+    p << lo << " to " << hi;
+}
+
+ParseResult
+parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
+                     SmallVectorImpl<OpAsmParser::Argument> &iterators,
+                     SmallVectorImpl<OpAsmParser::Argument> &iterArgs) {
+  SmallVector<OpAsmParser::UnresolvedOperand> spaces;
+  SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
+  // Parses "%iters, ... in %spaces, ..."
+  if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
+      parser.parseOperandList(spaces))
+    return failure();
+
+  if (iterators.size() != spaces.size())
+    return parser.emitError(
+        parser.getNameLoc(),
+        "mismatch in number of sparse iterators and sparse spaces");
+
+  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+  if (hasIterArgs)
+    if (parser.parseAssignmentList(iterArgs, initArgs))
+      return failure();
+
+  SmallVector<Type> iterSpaceTps;
+  // Parses ": sparse_tensor.iter_space, ... -> ret"
+  if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
+    return failure();
+  if (iterSpaceTps.size() != spaces.size())
+    return parser.emitError(parser.getNameLoc(),
+                            "mismatch in number of iteration space operands "
+                            "and iteration space types");
+
+  for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
+    IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
+    if (!spaceTp)
+      return parser.emitError(parser.getNameLoc(),
+                              "expected sparse_tensor.iter_space type for "
+                              "iteration space operands");
+    it.type = spaceTp.getIteratorType();
+  }
+
+  if (hasIterArgs)
+    if (parser.parseArrowTypeList(state.types))
+      return failure();
+
+  // Resolves input operands.
+  if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
+                             state.operands))
+    return failure();
+
+  if (hasIterArgs) {
+    for (auto [it, init, tp] : llvm::zip(iterArgs, initArgs, state.types)) {
+      it.type = tp;
+      if (parser.resolveOperand(init, tp, state.operands))
+        return failure();
+    }
+  }
+  return success();
 }
 
 LogicalResult ExtractIterSpaceOp::inferReturnTypes(
@@ -1991,60 +2044,45 @@ LogicalResult ExtractIterSpaceOp::verify() {
   return success();
 }
 
+ValueRange CoordinateOp::collaspeOpInto(OpBuilder &builder,
+                                        ArrayRef<Operation *> loops,
+                                        Operation *collapsed) {
+  assert(llvm::all_of(loops,
+                      [](Operation *l) { return llvm::isa<IterateOp>(l); }));
+  auto finalLoop = llvm::cast<IterateOp>(collapsed);
+  SmallVector<Type> retTps(finalLoop.getIterSpace().getType().getSpaceDim(),
+                           builder.getIndexType());
+  auto collapsedCoords =
+      builder.create<CoordinateOp>(getLoc(), retTps, finalLoop.getIterator());
+
+  for (Operation *l : loops) {
+    if (getIterator().getParentBlock()->getParentOp() == l) {
+      auto space = llvm::cast<IterateOp>(l)
+                       .getIterSpace()
+                       .getDefiningOp<ExtractIterSpaceOp>();
+
+      return collapsedCoords.getResults().slice(space.getLoLvl(),
+                                                space.getSpaceDim());
+    }
+  }
+  llvm_unreachable(
+      "Can not find the corresponding iterate space for the collapsable op.");
+}
+
 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
   OpAsmParser::Argument iterator;
   OpAsmParser::UnresolvedOperand iterSpace;
 
-  // Parses %iters in %spaces
-  if (parser.parseArgument(iterator) || parser.parseKeyword("in") ||
-      parser.parseOperand(iterSpace)) {
+  SmallVector<OpAsmParser::Argument> iters, iterArgs;
+  if (parseSparseSpaceLoop(parser, result, iters, iterArgs))
     return failure();
-  }
-
-  // Parse the optional initial iteration arguments.
-  SmallVector<OpAsmParser::Argument> regionArgs;
-  SmallVector<OpAsmParser::UnresolvedOperand> operands;
-  // Region arguments starts with iterators and follows by optional
-  // user-provided iter_args.
-  regionArgs.push_back(iterator);
-  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
-  if (hasIterArgs)
-    if (parser.parseAssignmentList(regionArgs, operands))
-      return failure();
-
-  // parse ": sparse_tensor.iter_space -> ret"
-  Type iterSpaceTps;
-  if (parser.parseColon() || parser.parseType(iterSpaceTps))
-    return failure();
-  if (hasIterArgs)
-    if (parser.parseArrowTypeList(result.types))
-      return failure();
-
-  if (regionArgs.size() != result.types.size() + 1) {
-    return parser.emitError(
-        parser.getNameLoc(),
-        "mismatch in number of loop-carried values and defined values");
-  }
-
-  // Resolves input operands.
-  if (parser.resolveOperand(iterSpace, iterSpaceTps, result.operands))
-    return failure();
-
-  if (hasIterArgs) {
-    for (auto argOperandType :
-         llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
-      Type type = std::get<2>(argOperandType);
-      std::get<0>(argOperandType).type = type;
-      if (parser.resolveOperand(std::get<1>(argOperandType), type,
-                                result.operands))
-        return failure();
-    }
-  }
+  if (iters.size() != 1)
+    return parser.emitError(parser.getNameLoc(),
+                            "expected only one iterator/iteration space");
 
+  iters.append(iterArgs);
   Region *body = result.addRegion();
-  regionArgs.front().type =
-      iterSpaceTps.cast<IterSpaceType>().getIteratorType();
-  if (parser.parseRegion(*body, regionArgs))
+  if (parser.parseRegion(*body, iters))
     return failure();
 
   IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
@@ -2078,7 +2116,6 @@ static void printInitializationList(OpAsmPrinter &p,
 
 void IterateOp::print(OpAsmPrinter &p) {
   p << " " << getIterator() << " in " << getIterSpace();
-
   printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
 
   p << " : " << getIterSpace().getType() << " ";
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
index f3207ede9585b4..752b2dfc2a0070 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -20,8 +20,14 @@ namespace mlir {
 
 namespace sparse_tensor {
 
-bool isCollapsableIterations(LoopLikeOpInterface parent,
-                             LoopLikeOpInterface node) {
+struct CollapseSpaceInfo {
+  ExtractIterSpaceOp space;
+  // Coiteration as well (if make sense)?
+  IterateOp loop;
+  SmallVector<SparseCollapsableOp> collapseOps;
+};
+
+bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) {
   auto pIterArgs = parent.getRegionIterArgs();
   auto nInitArgs = node.getInits();
   if (pIterArgs.size() != nInitArgs.size())
@@ -44,43 +50,81 @@ bool isCollapsableIterations(LoopLikeOpInterface parent,
   return yieldEq && iterArgEq;
 }
 
-bool legalToCollapse(ExtractIterSpaceOp parent, ExtractIterSpaceOp node) {
-  auto pItOp = llvm::dyn_cast<IterateOp>(parent->getParentOp());
-  auto nItOp = llvm::dyn_cast<IterateOp>(node->getParentOp());
+bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
+                     ExtractIterSpaceOp curSpace) {
+
+  auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp {
+    Value spaceVal = space.getResultSpace();
+    if (spaceVal.hasOneUse())
+      return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin());
+    return nullptr;
+  };
+
+  if (toCollapse.empty()) {
+    // Collapse root.
+    if (auto itOp = getIterateOpOverSpace(curSpace)) {
+      CollapseSpaceInfo &info = toCollapse.emplace_back();
+      info.space = curSpace;
+      info.loop = itOp;
+      // No operations need to be collapsed at the root level;
+      info.collapseOps = {};
+      return true;
+    }
+    return false;
+  }
+
+  auto parent = toCollapse.back().space;
+  auto pItOp = toCollapse.back().loop;
+  auto nItOp = getIterateOpOverSpace(curSpace);
 
   // Can only collapse spaces extracted from the same tensor.
-  if (parent.getTensor() != node.getTensor() || !parent->hasOneUse())
+  if (parent.getTensor() != curSpace.getTensor())
     return false;
 
   // Can only collapse consecutive simple iteration on one tensor (i.e., no
   // coiteration).
-  if (!nItOp || nItOp.getIterSpace() != parent.getResult() ||
-      nItOp->getBlock() != parent->getBlock())
+  if (!nItOp || nItOp->getBlock() != curSpace->getBlock() ||
+      pItOp.getIterator() != curSpace.getParentIter() ||
+      curSpace->getParentOp() != pItOp.getOperation())
     return false;
 
-  if (pItOp && !isCollapsableIterations(pItOp, nItOp))
+  if (pItOp && !isCollapsableLoops(pItOp, nItOp))
     return false;
 
   // TODO: Make sure all other operations in the same basic block as `node` can
   // be collapsed and sink into the collapsed iteration (through Interfaces
   // defined in TD files).
+  SmallVector<SparseCollapsableOp> collapsableOps;
+  for (Operation &op : *pItOp.getBody()) {
+    if (&op == curSpace.getOperation() || &op == nItOp.getOperation() ||
+        &op == pItOp.getBody()->getTerminator())
+      continue;
+    // All other ops in parent loop need to be collapsable.
+    auto collapsableOp = llvm::dyn_cast<SparseCollapsableOp>(&op);
+    if (!collapsableOp)
+      return false;
+    collapsableOps.push_back(collapsableOp);
+  }
+
+  CollapseSpaceInfo &info = toCollapse.emplace_back();
+  info.space = curSpace;
+  info.loop = nItOp;
+  info.collapseOps = std::move(collapsableOps);
   return true;
 }
 
-void collapseSparseSpace(ArrayRef<ExtractIterSpaceOp> toCollapse) {
+void collapseSparseSpace(SmallVectorImpl<CollapseSpaceInfo> &toCollapse) {
   if (toCollapse.size() < 2)
     return;
 
-  ExtractIterSpaceOp root = toCollapse.front();
-  ExtractIterSpaceOp leaf = toCollapse.back();
+  ExtractIterSpaceOp root = toCollapse.front().space;
+  ExtractIterSpaceOp leaf = toCollapse.back().space;
   Location loc = root.getLoc();
 
-  if (!leaf->hasOneUse())
-    return;
-  assert(root->hasOneUse());
+  assert(root->hasOneUse() && leaf->hasOneUse());
 
   // Insert collapsed operation at the same scope as root operation.
-  OpBuilder builder(toCollapse.front());
+  OpBuilder builder(root);
 
   // Construct the collapsed iteration space.
   auto collapsedSpace = builder.create<ExtractIterSpaceOp>(
@@ -88,19 +132,29 @@ void collapseSparseSpace(ArrayRef<ExtractIterSpaceOp> toCollapse) {
       leaf.getHiLvl());
 
   auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
-  auto pItOp = llvm::cast<IterateOp>(leaf->getParentOp());
-
-  // This could either be IterateOp or (TODO: in the future) CoIterateOp.
-  auto loop = llvm::dyn_cast<IterateOp>(*leaf->getUsers().begin());
-  if (!loop || !isCollapsableIterations(pItOp, loop))
-    return;
+  auto innermost = toCollapse.back().loop;
 
   IRMapping mapper;
   mapper.map(leaf, collapsedSpace.getResultSpace());
-  for (auto z : llvm::zip_equal(loop.getInitArgs(), rItOp.getInitArgs()))
+  for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs()))
     mapper.map(std::get<0>(z), std::get<1>(z));
 
-  auto cloned = llvm::cast<IterateOp>(builder.clone(*loop, mapper));
+  auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper));
+  builder.setInsertionPointToStart(cloned.getBody());
+  SmallVector<Operation *> loops =
+      llvm::map_to_vector(toCollapse, [](CollapseSpaceInfo &info) {
+        return info.loop.getOperation();
+      });
+
+  for (const CollapseSpaceInfo &info : toCollapse) {
+    for (SparseCollapsableOp op : info.collapseOps) {
+      ValueRange colVals = op.collaspeOpInto(builder, loops, cloned);
+      for (auto [o, r] : llvm::zip(op->getResults(), colVals))
+        o.replaceAllUsesWith(r);
+      op.erase();
+    }
+  }
+
   cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
 
   rItOp.replaceAllUsesWith(cloned.getResults());
@@ -124,18 +178,13 @@ struct SparseSpaceCollapsePass
     // %space2 = extract_space %t2 ...
     // sparse_tensor.iterate(%sp1) ...
     //
-    SmallVector<ExtractIterSpaceOp> toCollapse;
+    SmallVector<CollapseSpaceInfo> toCollapse;
     func->walk([&](ExtractIterSpaceOp op) {
-      if (toCollapse.empty()) {
-        // Root space to collapse.
-        toCollapse.push_back(op);
-      } else {
-        if (legalToCollapse(toCollapse.back(), op)) {
-          toCollapse.push_back(op);
-        } else {
-          collapseSparseSpace(toCollapse);
-          toCollapse.clear();
-        }
+      if (!legalToCollapse(toCollapse, op)) {
+        // `op` cannot extend the chain: collapse the spaces gathered so far
+        // and reset. NOTE(review): `op` is not retried as a new root here.
+        collapseSparseSpace(toCollapse);
+        toCollapse.clear();
       }
     });
 

>From f1d12b1ce28a94499b2c8cd4a0c4d79232955067 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 25 Mar 2024 22:41:52 +0000
Subject: [PATCH 4/5] fuse crds into iterate operation.

---
 .../Dialect/SparseTensor/IR/SparseTensor.h    |  36 ++++++
 .../SparseTensor/IR/SparseTensorAttrDefs.td   |  15 +++
 .../SparseTensor/IR/SparseTensorOps.td        |  30 +++--
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 119 +++++++++++++-----
 .../Transforms/SparseSpaceCollapse.cpp        |  44 ++-----
 5 files changed, 170 insertions(+), 74 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 78692307820bc5..081a9b8cad8d62 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -22,6 +22,8 @@
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+#include "llvm/ADT/bit.h"
+
 //===----------------------------------------------------------------------===//
 //
 // Type aliases to help code be more self-documenting. Unfortunately
@@ -43,6 +45,40 @@ using Level = uint64_t;
 /// including the value `ShapedType::kDynamic` (for shapes).
 using Size = int64_t;
 
+/// A bitset of (at most 64) levels: bit `i` is set iff level `i` is in it.
+class LevelSet {
+  uint64_t bits = 0;
+
+public:
+  LevelSet() = default;
+  explicit LevelSet(uint64_t bits) : bits(bits) {}
+  operator uint64_t() const { return bits; }
+
+  LevelSet &set(unsigned i) {
+    assert(i < 64);
+    bits |= static_cast<uint64_t>(1) << i; // 64-bit shift: `1 << i` is UB.
+    return *this;
+  }
+
+  LevelSet &operator|=(LevelSet rhs) {
+    bits |= static_cast<uint64_t>(rhs);
+    return *this;
+  }
+
+  LevelSet &lshift(unsigned offset) {
+    bits = offset < 64 ? bits << offset : 0; // Shifting by >= 64 is UB.
+    return *this;
+  }
+
+  bool operator[](unsigned i) const {
+    assert(i < 64);
+    return (bits & (static_cast<uint64_t>(1) << i)) != 0;
+  }
+
+  unsigned count() const { return llvm::popcount(bits); }
+  bool empty() const { return bits == 0; }
+};
+
 } // namespace sparse_tensor
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index d3be8a3009ba1e..36c075f52f8e5b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -19,6 +19,21 @@ class SparseTensor_Attr<string name,
                         list<Trait> traits = []>
     : AttrDef<SparseTensor_Dialect, name, traits>;
 
+//===----------------------------------------------------------------------===//
+// A bitset attribute, wrapping a single i64, that encodes a set of levels.
+//===----------------------------------------------------------------------===//
+
+def LevelSetAttr :
+    TypedAttrBase<
+      I64, "IntegerAttr",
+      And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
+           CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType().isInteger(64)">]>,
+      "LevelSet attribute"> {
+  let returnType = [{::mlir::sparse_tensor::LevelSet}];
+  let convertFromStorage = [{::mlir::sparse_tensor::LevelSet($_self.getValue().getZExtValue())}];
+  let constBuilderCall = [{$_builder.getI64IntegerAttr(static_cast<uint64_t>($0))}];
+}
+
 //===----------------------------------------------------------------------===//
 // These attributes are just like `IndexAttr` except that they clarify whether
 // the index refers to a dimension (an axis of the semantic tensor) or a level
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 467030a8d221af..9a918760c3190d 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1445,36 +1445,42 @@ def ExtractIterSpaceOp : SparseTensor_Op<"iteration.extract_space",
                        " attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
 }
 
-def CoordinateOp : SparseTensor_Op<"iteration.coordinate",
-    [Pure, DeclareOpInterfaceMethods<SparseCollapsableOpInterface>]> {
-  let arguments = (ins AnySparseIterator:$iterator);
-  let results = (outs Variadic<Index>:$crds);
-
-  let extraClassDeclaration = [{ }];
-  // let hasVerifier = 1;
-  let assemblyFormat = " $iterator attr-dict `:` type($iterator) `->` type($crds)";
-}
+// def CoordinateOp : SparseTensor_Op<"iteration.coordinate",
+//     [Pure, DeclareOpInterfaceMethods<SparseCollapsableOpInterface>]> {
+//   let arguments = (ins AnySparseIterator:$iterator);
+//   let results = (outs Variadic<Index>:$crds);
+//   let extraClassDeclaration = [{ }];
+//   // let hasVerifier = 1;
+//   let assemblyFormat = " $iterator attr-dict `:` type($iterator) `->` type($crds)";
+// }
 
 def IterateOp : SparseTensor_Op<"iterate",
     [RecursiveMemoryEffects, RecursivelySpeculatable,
      DeclareOpInterfaceMethods<LoopLikeOpInterface,
       ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
-      "getSingleInductionVar", "getYieldedValuesMutable"]>,
+       "getYieldedValuesMutable"]>,
      DeclareOpInterfaceMethods<RegionBranchOpInterface,
       ["getEntrySuccessorOperands"]>,
      SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
 
   let arguments = (ins AnySparseIterSpace:$iterSpace,
-                       Variadic<AnyType>:$initArgs);
+                       Variadic<AnyType>:$initArgs,
+                       LevelSetAttr:$crdUsedLvls);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
 
   let extraClassDeclaration = [{
+    unsigned getSpaceDim() {
+      return getIterSpace().getType().getSpaceDim();
+    }
     BlockArgument getIterator() {
       return getRegion().getArguments().front();
     }
+    Block::BlockArgListType getCrds() {
+      return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
+    }
     unsigned getNumRegionIterArgs() {
-      return getRegion().getArguments().size() - 1;
+      return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
     }
   }];
 
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 54be3c3b4c3e5f..2b54c2fda3d739 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -42,6 +42,7 @@ namespace mlir::sparse_tensor {
 llvm::hash_code hash_value(LevelType lt) {
   return llvm::hash_value(static_cast<uint64_t>(lt));
 }
+
 } // namespace mlir::sparse_tensor
 
 //===----------------------------------------------------------------------===//
@@ -1944,12 +1945,13 @@ static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
     p << lo << " to " << hi;
 }
 
-ParseResult
+static ParseResult
 parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
                      SmallVectorImpl<OpAsmParser::Argument> &iterators,
                      SmallVectorImpl<OpAsmParser::Argument> &iterArgs) {
   SmallVector<OpAsmParser::UnresolvedOperand> spaces;
   SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
+
   // Parses "%iters, ... in %spaces, ..."
   if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
       parser.parseOperandList(spaces))
@@ -1960,6 +1962,34 @@ parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
         parser.getNameLoc(),
         "mismatch in number of sparse iterators and sparse spaces");
 
+  // Parse "at(%crd0, _, ...)"
+  LevelSet crdUsedLvlSet;
+  bool hasUsedCrds = succeeded(parser.parseOptionalKeyword("at"));
+  unsigned lvlCrdCnt = 0;
+  if (hasUsedCrds) {
+    ParseResult crdList = parser.parseCommaSeparatedList(
+        OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
+          if (parser.parseOptionalKeyword("_")) {
+            if (parser.parseArgument(iterArgs.emplace_back()))
+              return failure();
+            // Always use IndexType for the coordinate.
+            crdUsedLvlSet.set(lvlCrdCnt);
+            iterArgs.back().type = parser.getBuilder().getIndexType();
+          }
+          lvlCrdCnt += 1;
+          return success();
+        });
+    if (failed(crdList)) {
+      return parser.emitError(
+          parser.getNameLoc(),
+          "expecting SSA value or \"_\" for level coordinates");
+    }
+  }
+  // Record the levels whose coordinates are used as a bitset attribute.
+  state.addAttribute("crdUsedLvls",
+                     parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
+
+  // Parse "iter_args(%arg = %init, ...)"
   bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
   if (hasIterArgs)
     if (parser.parseAssignmentList(iterArgs, initArgs))
@@ -1980,6 +2010,10 @@ parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
       return parser.emitError(parser.getNameLoc(),
                               "expected sparse_tensor.iter_space type for "
                               "iteration space operands");
+    if (hasUsedCrds && spaceTp.getSpaceDim() != lvlCrdCnt)
+      return parser.emitError(parser.getNameLoc(),
+                              "mismatch in number of iteration space dimension "
+                              "and specified coordinates");
     it.type = spaceTp.getIteratorType();
   }
 
@@ -1993,7 +2027,10 @@ parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
     return failure();
 
   if (hasIterArgs) {
-    for (auto [it, init, tp] : llvm::zip(iterArgs, initArgs, state.types)) {
+    unsigned numCrds = crdUsedLvlSet.count();
+    // Skip the leading arguments that hold the used coordinates.
+    MutableArrayRef args = MutableArrayRef(iterArgs).drop_front(numCrds);
+    for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
       it.type = tp;
       if (parser.resolveOperand(init, tp, state.operands))
         return failure();
@@ -2044,30 +2081,32 @@ LogicalResult ExtractIterSpaceOp::verify() {
   return success();
 }
 
-ValueRange CoordinateOp::collaspeOpInto(OpBuilder &builder,
-                                        ArrayRef<Operation *> loops,
-                                        Operation *collapsed) {
-  assert(llvm::all_of(loops,
-                      [](Operation *l) { return llvm::isa<IterateOp>(l); }));
-  auto finalLoop = llvm::cast<IterateOp>(collapsed);
-  SmallVector<Type> retTps(finalLoop.getIterSpace().getType().getSpaceDim(),
-                           builder.getIndexType());
-  auto collapsedCoords =
-      builder.create<CoordinateOp>(getLoc(), retTps, finalLoop.getIterator());
-
-  for (Operation *l : loops) {
-    if (getIterator().getParentBlock()->getParentOp() == l) {
-      auto space = llvm::cast<IterateOp>(l)
-                       .getIterSpace()
-                       .getDefiningOp<ExtractIterSpaceOp>();
-
-      return collapsedCoords.getResults().slice(space.getLoLvl(),
-                                                space.getSpaceDim());
-    }
-  }
-  llvm_unreachable(
-      "Can not find the corresponding iterate space for the collapsable op.");
-}
+// ValueRange CoordinateOp::collaspeOpInto(OpBuilder &builder,
+//                                         ArrayRef<Operation *> loops,
+//                                         Operation *collapsed) {
+//   assert(llvm::all_of(loops,
+//                       [](Operation *l) { return llvm::isa<IterateOp>(l); }));
+//   auto finalLoop = llvm::cast<IterateOp>(collapsed);
+//   SmallVector<Type> retTps(finalLoop.getIterSpace().getType().getSpaceDim(),
+//                            builder.getIndexType());
+//   auto collapsedCoords =
+//       builder.create<CoordinateOp>(getLoc(), retTps,
+//       finalLoop.getIterator());
+
+//   for (Operation *l : loops) {
+//     if (getIterator().getParentBlock()->getParentOp() == l) {
+//       auto space = llvm::cast<IterateOp>(l)
+//                        .getIterSpace()
+//                        .getDefiningOp<ExtractIterSpaceOp>();
+
+//       return collapsedCoords.getResults().slice(space.getLoLvl(),
+//                                                 space.getSpaceDim());
+//     }
+//   }
+//   llvm_unreachable(
+//       "Can not find the corresponding iterate space for the collapsable
+//       op.");
+// }
 
 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
   OpAsmParser::Argument iterator;
@@ -2114,8 +2153,30 @@ static void printInitializationList(OpAsmPrinter &p,
   p << ")";
 }
 
+static void printUsedCrdsList(OpAsmPrinter &p, unsigned spaceDim,
+                              Block::BlockArgListType blocksArgs,
+                              LevelSet crdUsedLvls) {
+  if (crdUsedLvls.empty())
+    return;
+
+  p << " at(";
+  for (unsigned i = 0; i < spaceDim; i++) {
+    if (crdUsedLvls[i]) {
+      p << blocksArgs.front();
+      blocksArgs = blocksArgs.drop_front();
+    } else {
+      p << "_";
+    }
+    if (i != spaceDim - 1)
+      p << ", ";
+  }
+  assert(blocksArgs.empty());
+  p << ")";
+}
+
 void IterateOp::print(OpAsmPrinter &p) {
   p << " " << getIterator() << " in " << getIterSpace();
+  printUsedCrdsList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
   printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
 
   p << " : " << getIterSpace().getType() << " ";
@@ -2170,16 +2231,12 @@ LogicalResult IterateOp::verifyRegions() {
 /// IterateOp implemented interfaces' methods.
 SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
 
-std::optional<Value> IterateOp::getSingleInductionVar() {
-  return getIterator();
-}
-
 MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
   return getInitArgsMutable();
 }
 
 Block::BlockArgListType IterateOp::getRegionIterArgs() {
-  return getRegion().getArguments().drop_front();
+  return getRegion().getArguments().take_back(getNumRegionIterArgs());
 }
 
 std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
index 752b2dfc2a0070..39c9a9292c9be9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -24,7 +24,6 @@ struct CollapseSpaceInfo {
   ExtractIterSpaceOp space;
   // Coiteration as well (if make sense)?
   IterateOp loop;
-  SmallVector<SparseCollapsableOp> collapseOps;
 };
 
 bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) {
@@ -66,8 +65,6 @@ bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
       CollapseSpaceInfo &info = toCollapse.emplace_back();
       info.space = curSpace;
       info.loop = itOp;
-      // No operations need to be collapsed at the root level;
-      info.collapseOps = {};
       return true;
     }
     return false;
@@ -91,29 +88,13 @@ bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
   if (pItOp && !isCollapsableLoops(pItOp, nItOp))
     return false;
 
-  // TODO: Make sure all other operations in the same basic block as `node` can
-  // be collapsed and sink into the collapsed iteration (through Interfaces
-  // defined in TD files).
-  SmallVector<SparseCollapsableOp> collapsableOps;
-  for (Operation &op : *pItOp.getBody()) {
-    if (&op == curSpace.getOperation() || &op == nItOp.getOperation() ||
-        &op == pItOp.getBody()->getTerminator())
-      continue;
-    // All other ops in parent loop need to be collapsable.
-    auto collapsableOp = llvm::dyn_cast<SparseCollapsableOp>(&op);
-    if (!collapsableOp)
-      return false;
-    collapsableOps.push_back(collapsableOp);
-  }
-
   CollapseSpaceInfo &info = toCollapse.emplace_back();
   info.space = curSpace;
   info.loop = nItOp;
-  info.collapseOps = std::move(collapsableOps);
   return true;
 }
 
-void collapseSparseSpace(SmallVectorImpl<CollapseSpaceInfo> &toCollapse) {
+void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) {
   if (toCollapse.size() < 2)
     return;
 
@@ -141,21 +122,22 @@ void collapseSparseSpace(SmallVectorImpl<CollapseSpaceInfo> &toCollapse) {
 
   auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper));
   builder.setInsertionPointToStart(cloned.getBody());
-  SmallVector<Operation *> loops =
-      llvm::map_to_vector(toCollapse, [](CollapseSpaceInfo &info) {
-        return info.loop.getOperation();
-      });
 
-  for (const CollapseSpaceInfo &info : toCollapse) {
-    for (SparseCollapsableOp op : info.collapseOps) {
-      ValueRange colVals = op.collaspeOpInto(builder, loops, cloned);
-      for (auto [o, r] : llvm::zip(op->getResults(), colVals))
-        o.replaceAllUsesWith(r);
-      op.erase();
+  LevelSet crdUsedLvls;
+  unsigned shift = 0, argIdx = 1;
+  for (auto info : toCollapse.drop_back()) {
+    LevelSet set = info.loop.getCrdUsedLvls();
+    crdUsedLvls |= set.lshift(shift);
+    shift += info.loop.getSpaceDim();
+    for (BlockArgument crd : info.loop.getCrds()) {
+      BlockArgument collapsedCrd = cloned.getBody()->insertArgument(
+          argIdx++, builder.getIndexType(), crd.getLoc());
+      crd.replaceAllUsesWith(collapsedCrd);
     }
   }
-
+  crdUsedLvls |= innermost.getCrdUsedLvls().lshift(shift);
   cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
+  cloned.setCrdUsedLvls(crdUsedLvls);
 
   rItOp.replaceAllUsesWith(cloned.getResults());
   // Erase collapsed loops.

>From 0336e00d04d83d05940310c7e8bec892cb390bb1 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 28 Mar 2024 16:32:38 +0000
Subject: [PATCH 5/5] setup lowering passes

---
 .../mlir/Dialect/SparseTensor/IR/Enums.h      |   4 +
 .../SparseTensor/IR/SparseTensorOps.td        |  24 ++--
 .../Dialect/SparseTensor/Transforms/Passes.h  |  16 +++
 .../Dialect/SparseTensor/Transforms/Passes.td |  13 ++
 .../SparseTensor/Transforms/CMakeLists.txt    |   1 +
 .../Transforms/SparseIterationToScf.cpp       |  76 ++++++++++++
 .../Transforms/SparseTensorPasses.cpp         |  27 ++++
 .../Transforms/Utils/SparseTensorLevel.cpp    | 117 +++++++++++-------
 .../Transforms/Utils/SparseTensorLevel.h      |   6 +
 9 files changed, 227 insertions(+), 57 deletions(-)
 create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 33f613a46bad84..96ee7111fea2cf 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -357,6 +357,10 @@ struct LevelType {
     return hasSparseSemantic();
   }
 
+  constexpr unsigned getNumBuffer() const {
+    return hasDenseSemantic() ? 0 : (isWithPosLT() ? 2 : 1);
+  }
+
   std::string toMLIRString() const {
     std::string lvlStr = toFormatString(getLvlFmt());
     std::string propStr = "";
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 9a918760c3190d..540cfa880a13e2 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -263,7 +263,7 @@ def SparseTensor_ReinterpretMapOp : SparseTensor_Op<"reinterpret_map", [NoMemory
 }
 
 def SparseTensor_ToPositionsOp : SparseTensor_Op<"positions",
-      [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
+      [Pure, AlwaysSpeculatable, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
     Arguments<(ins AnySparseTensor:$tensor, LevelAttr:$level)>,
     Results<(outs AnyNon0RankedMemRef:$result)> {
   let summary = "Extracts the `level`-th positions array of the `tensor`";
@@ -285,12 +285,12 @@ def SparseTensor_ToPositionsOp : SparseTensor_Op<"positions",
-       : tensor<64x64xf64, #CSR> to memref<?xindex>
+       : tensor<64x64xf64, #CSR>
     ```
   }];
-  let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($result)";
+  let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
   let hasVerifier = 1;
 }
 
 def SparseTensor_ToCoordinatesOp : SparseTensor_Op<"coordinates",
-      [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
+      [Pure, AlwaysSpeculatable, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
     Arguments<(ins AnySparseTensor:$tensor, LevelAttr:$level)>,
     Results<(outs AnyNon0RankedMemRef:$result)> {
   let summary = "Extracts the `level`-th coordinates array of the `tensor`";
@@ -312,12 +312,12 @@ def SparseTensor_ToCoordinatesOp : SparseTensor_Op<"coordinates",
-       : tensor<64x64xf64, #CSR> to memref<?xindex>
+       : tensor<64x64xf64, #CSR>
     ```
   }];
-  let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($result)";
+  let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
   let hasVerifier = 1;
 }
 
 def SparseTensor_ToCoordinatesBufferOp : SparseTensor_Op<"coordinates_buffer",
-      [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
+      [Pure, AlwaysSpeculatable, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
     Arguments<(ins AnySparseTensor:$tensor)>,
     Results<(outs AnyNon0RankedMemRef:$result)> {
   let summary = "Extracts the linear coordinates array from a tensor";
@@ -340,16 +340,15 @@ def SparseTensor_ToCoordinatesBufferOp : SparseTensor_Op<"coordinates_buffer",
     Example:
 
     ```mlir
-    %1 = sparse_tensor.coordinates_buffer %0
-       : tensor<64x64xf64, #COO> to memref<?xindex>
+    %1 = sparse_tensor.coordinates_buffer %0 : tensor<64x64xf64, #COO>
     ```
   }];
-  let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($result)";
+  let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
   let hasVerifier = 1;
 }
 
 def SparseTensor_ToValuesOp : SparseTensor_Op<"values",
-      [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
+      [Pure, AlwaysSpeculatable, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
     Arguments<(ins AnySparseTensor:$tensor)>,
     Results<(outs AnyNon0RankedMemRef:$result)> {
   let summary = "Extracts numerical values array from a tensor";
@@ -367,10 +366,10 @@ def SparseTensor_ToValuesOp : SparseTensor_Op<"values",
     Example:
 
     ```mlir
-    %1 = sparse_tensor.values %0 : tensor<64x64xf64, #CSR> to memref<?xf64>
+    %1 = sparse_tensor.values %0 : tensor<64x64xf64, #CSR>
     ```
   }];
-  let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($result)";
+  let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
   let hasVerifier = 1;
 }
 
@@ -1438,6 +1437,9 @@ def ExtractIterSpaceOp : SparseTensor_Op<"iteration.extract_space",
     unsigned getSpaceDim() {
       return getHiLvl() - getLoLvl();
     }
+    ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() {
+      return getResultSpace().getType().getLvlTypes();
+    }
   }];
 
   let hasVerifier = 1;
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 0e9f5120f7b3dc..8d2b9fe571e20b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -16,6 +16,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
 
 //===----------------------------------------------------------------------===//
 // Include the generated pass header (which needs some early definitions).
@@ -142,6 +143,21 @@ void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns);
 
 std::unique_ptr<Pass> createLowerForeachToSCFPass();
 
+//===----------------------------------------------------------------------===//
+// The LowerSparseIterationToSCF pass.
+//===----------------------------------------------------------------------===//
+
+/// 1:N type converter for sparse iteration types (currently iter_space only).
+class SparseIterationTypeConverter : public OneToNTypeConverter {
+public:
+  SparseIterationTypeConverter();
+};
+
+void populateLowerSparseIterationToSCFPatterns(TypeConverter &converter,
+                                               RewritePatternSet &patterns);
+
+std::unique_ptr<Pass> createLowerSparseIterationToSCFPass();
+
 //===----------------------------------------------------------------------===//
 // The SparseTensorConversion pass.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 3ab75c23dbefa0..f27c64a7dee84e 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -210,6 +210,19 @@ def LowerForeachToSCF : Pass<"lower-sparse-foreach-to-scf", "func::FuncOp"> {
   ];
 }
 
+def LowerSparseIterationToSCF : Pass<"lower-sparse-iteration-to-scf", "func::FuncOp"> {
+  let summary = "lower sparse_tensor.iterate/coiterate into scf loops";
+  let description = [{
+     Lowers sparse_tensor iteration operations into scf loops. TODO: expand.
+  }];
+  let constructor = "mlir::createLowerSparseIterationToSCFPass()";
+  let dependentDialects = [
+    "memref::MemRefDialect",
+    "scf::SCFDialect",
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 def SparseTensorConversionPass : Pass<"sparse-tensor-conversion", "ModuleOp"> {
   let summary = "Convert sparse tensors and primitives to library calls";
   let description = [{
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 8840da9aa56ef7..c615f9ad2370c7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   SparseAssembler.cpp
   SparseBufferRewriting.cpp
   SparseGPUCodegen.cpp
+  SparseIterationToScf.cpp
   SparseReinterpretMap.cpp
   SparseStorageSpecifierToLLVM.cpp
   SparseSpaceCollapse.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
new file mode 100644
index 00000000000000..267eff724590e7
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -0,0 +1,76 @@
+
+#include "Utils/CodegenUtils.h"
+#include "Utils/SparseTensorLevel.h"
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+static std::optional<LogicalResult>
+convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
+  if (itSp.getSpaceDim() > 1)
+    return failure(); // TODO: support multi-level iteration spaces.
+
+  auto idxTp = IndexType::get(itSp.getContext());
+  // FIXME: this assumes that the Pos/CrdBitWidth in the sparse tensor encoding
+  // keep their defaults, i.e. that all level buffers are memref<?xindex>.
+  auto sparseMemRef = MemRefType::get({ShapedType::kDynamic}, idxTp);
+  for (LevelType lt : itSp.getLvlTypes()) {
+    // Position and coordinate buffer in the sparse structure.
+    if (lt.isWithPosLT())
+      fields.push_back(sparseMemRef);
+    if (lt.isWithCrdLT())
+      fields.push_back(sparseMemRef);
+  }
+  // Two indices for lower and upper bound.
+  fields.append({idxTp, idxTp});
+  return success();
+}
+
+namespace {
+
+/// Lowers `sparse_tensor.iteration.extract_space` into the buffers of the
+/// extracted level followed by the (lo, hi) bounds that `peekRangeAt` yields
+/// at position 0. Only single-level spaces without a parent are supported.
+class ExtractIterSpaceConverter
+    : public OneToNOpConversionPattern<ExtractIterSpaceOp> {
+public:
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ExtractIterSpaceOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    // Bail out before creating any IR: a pattern that fails must leave the IR
+    // untouched. TODO: support multi-level spaces and parent iterators.
+    if (op.getSpaceDim() > 1 || op.getParentIter())
+      return rewriter.notifyMatchFailure(op, "unsupported iteration space");
+    Location loc = op.getLoc();
+    const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+    std::unique_ptr<SparseTensorLevel> lvl =
+        makeSparseTensorLevel(rewriter, loc, op.getTensor(), 0, op.getLoLvl());
+
+    // Buffers first, then the two bounds (see convertIterSpaceType).
+    SmallVector<Value> result = llvm::to_vector(lvl->getLvlBuffers());
+    // TODO: handle batch.
+    std::pair<Value, Value> bounds = lvl->peekRangeAt(
+        rewriter, loc, /*batchPrefix=*/{}, constantIndex(rewriter, loc, 0));
+    result.append({bounds.first, bounds.second});
+
+    rewriter.replaceOp(op, result, resultMapping);
+    return success();
+  }
+};
+
+} // namespace
+
+mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
+  addConversion([](Type type) { return type; }); // Fallback: keep the type.
+  addConversion(convertIterSpaceType); // Added last, hence tried first.
+}
+
+void mlir::populateLowerSparseIterationToSCFPatterns(
+    TypeConverter &converter, RewritePatternSet &patterns) {
+  patterns.add<ExtractIterSpaceConverter>(converter, patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index d4c17928d4ca15..3d1a070330a476 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -26,6 +26,7 @@ namespace mlir {
 #define GEN_PASS_DEF_SPARSEREINTERPRETMAP
 #define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
 #define GEN_PASS_DEF_SPARSIFICATIONPASS
+#define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF
 #define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
 #define GEN_PASS_DEF_LOWERFOREACHTOSCF
 #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
@@ -152,10 +153,32 @@ struct LowerForeachToSCFPass
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
     populateLowerForeachToSCFPatterns(patterns);
+
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
 
+struct LowerSparseIterationToSCFPass
+    : public impl::LowerSparseIterationToSCFBase<
+          LowerSparseIterationToSCFPass> {
+  LowerSparseIterationToSCFPass() = default;
+  LowerSparseIterationToSCFPass(const LowerSparseIterationToSCFPass &) =
+      default;
+
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    SparseIterationTypeConverter converter;
+    populateLowerSparseIterationToSCFPatterns(converter, patterns);
+
+    // Note: applyPartialOneToNConversion does not take a ConversionTarget, so
+    // no op legality is declared here (a target built here would be unused).
+    if (failed(applyPartialOneToNConversion(getOperation(), converter,
+                                            std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
 struct SparseTensorConversionPass
     : public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
   SparseTensorConversionPass() = default;
@@ -438,6 +461,10 @@ std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
   return std::make_unique<LowerForeachToSCFPass>();
 }
 
+std::unique_ptr<Pass> mlir::createLowerSparseIterationToSCFPass() {
+  return std::make_unique<LowerSparseIterationToSCFPass>();
+}
+
 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
   return std::make_unique<SparseTensorConversionPass>();
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index bc27fae5d19480..3b501953ef0abe 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -46,21 +46,41 @@ using ValueTuple = std::tuple<Value, Value, Value>;
 
 namespace {
 
+template <bool hasPosBuffer>
 class SparseLevel : public SparseTensorLevel {
+  // Either {positions, coordinates} (array of size 2) or {coordinates} (array
+  // of size 1), depending on whether the level stores a position array.
+  using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>,
+                                     std::array<Value, 1>>;
+
 public:
   SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
-              Value crdBuffer)
-      : SparseTensorLevel(tid, lvl, lt, lvlSize), crdBuffer(crdBuffer) {}
+              BufferT buffers)
+      : SparseTensorLevel(tid, lvl, lt, lvlSize), buffers(buffers) {}
+
+  ValueRange getLvlBuffers() const override { return buffers; }
 
   Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                   Value iv) const override {
     SmallVector<Value> memCrd(batchPrefix);
     memCrd.push_back(iv);
-    return genIndexLoad(b, l, crdBuffer, memCrd);
+    return genIndexLoad(b, l, getCrdBuf(), memCrd);
   }
 
 protected:
-  const Value crdBuffer;
+  template <typename T = void, typename = std::enable_if_t<hasPosBuffer, T>>
+  Value getPosBuf() const {
+    return buffers[0];
+  }
+
+  Value getCrdBuf() const {
+    if constexpr (hasPosBuffer)
+      return buffers[1];
+    else
+      return buffers[0];
+  }
+
+  const BufferT buffers;
 };
 
 class DenseLevel : public SparseTensorLevel {
@@ -72,6 +92,8 @@ class DenseLevel : public SparseTensorLevel {
     llvm_unreachable("locate random-accessible level instead");
   }
 
+  ValueRange getLvlBuffers() const override { return {}; }
+
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
                         Value max) const override {
     Value posLo = MULI(p, lvlSize);
@@ -88,6 +110,8 @@ class BatchLevel : public SparseTensorLevel {
     llvm_unreachable("locate random-accessible level instead");
   }
 
+  ValueRange getLvlBuffers() const override { return {}; }
+
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
                         Value max) const override {
     assert(max == nullptr && "Dense level can not be non-unique.");
@@ -96,11 +120,11 @@ class BatchLevel : public SparseTensorLevel {
   }
 };
 
-class CompressedLevel : public SparseLevel {
+class CompressedLevel : public SparseLevel<true> {
 public:
   CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                   Value posBuffer, Value crdBuffer)
-      : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+      : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                         Value p, Value max) const override {
@@ -109,21 +133,18 @@ class CompressedLevel : public SparseLevel {
 
     SmallVector<Value> memCrd(batchPrefix);
     memCrd.push_back(p);
-    Value pLo = genIndexLoad(b, l, posBuffer, memCrd);
+    Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
     memCrd.back() = ADDI(p, C_IDX(1));
-    Value pHi = genIndexLoad(b, l, posBuffer, memCrd);
+    Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
     return {pLo, pHi};
   }
-
-private:
-  const Value posBuffer;
 };
 
-class LooseCompressedLevel : public SparseLevel {
+class LooseCompressedLevel : public SparseLevel<true> {
 public:
   LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                        Value posBuffer, Value crdBuffer)
-      : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+      : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                         Value p, Value max) const override {
@@ -133,21 +154,18 @@ class LooseCompressedLevel : public SparseLevel {
 
     p = MULI(p, C_IDX(2));
     memCrd.push_back(p);
-    Value pLo = genIndexLoad(b, l, posBuffer, memCrd);
+    Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
     memCrd.back() = ADDI(p, C_IDX(1));
-    Value pHi = genIndexLoad(b, l, posBuffer, memCrd);
+    Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
     return {pLo, pHi};
   }
-
-private:
-  const Value posBuffer;
 };
 
-class SingletonLevel : public SparseLevel {
+class SingletonLevel : public SparseLevel<false> {
 public:
   SingletonLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                  Value crdBuffer)
-      : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
+      : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                         Value p, Value segHi) const override {
@@ -159,11 +177,11 @@ class SingletonLevel : public SparseLevel {
   }
 };
 
-class NOutOfMLevel : public SparseLevel {
+class NOutOfMLevel : public SparseLevel<false> {
 public:
   NOutOfMLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                Value crdBuffer)
-      : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
+      : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                         Value p, Value max) const override {
@@ -1314,6 +1332,30 @@ ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
 // SparseIterator factory functions.
 //===----------------------------------------------------------------------===//
 
+/// Creates a SparseTensorLevel from already materialized level buffers `b`.
+std::unique_ptr<SparseTensorLevel>
+sparse_tensor::makeSparseTensorLevel(LevelType lt, Value sz, ValueRange b,
+                                     unsigned t, Level l) {
+  assert(lt.getNumBuffer() == b.size() && "wrong number of level buffers");
+  switch (lt.getLvlFmt()) {
+  case LevelFormat::Dense:
+    return std::make_unique<DenseLevel>(t, l, sz);
+  case LevelFormat::Batch:
+    return std::make_unique<BatchLevel>(t, l, sz);
+  case LevelFormat::Compressed:
+    return std::make_unique<CompressedLevel>(t, l, lt, sz, b[0], b[1]);
+  case LevelFormat::LooseCompressed:
+    return std::make_unique<LooseCompressedLevel>(t, l, lt, sz, b[0], b[1]);
+  case LevelFormat::Singleton:
+    return std::make_unique<SingletonLevel>(t, l, lt, sz, b[0]);
+  case LevelFormat::NOutOfM:
+    return std::make_unique<NOutOfMLevel>(t, l, lt, sz, b[0]);
+  case LevelFormat::Undef:
+    llvm_unreachable("undefined level format");
+  }
+  llvm_unreachable("unrecognizable level format");
+}
+
 std::unique_ptr<SparseTensorLevel>
 sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
                                      unsigned tid, Level lvl) {
@@ -1323,33 +1365,16 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
   Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult()
                                : b.create<tensor::DimOp>(l, t, lvl).getResult();
 
-  switch (lt.getLvlFmt()) {
-  case LevelFormat::Dense:
-    return std::make_unique<DenseLevel>(tid, lvl, sz);
-  case LevelFormat::Batch:
-    return std::make_unique<BatchLevel>(tid, lvl, sz);
-  case LevelFormat::Compressed: {
-    Value pos = b.create<ToPositionsOp>(l, t, lvl);
-    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
-    return std::make_unique<CompressedLevel>(tid, lvl, lt, sz, pos, crd);
-  }
-  case LevelFormat::LooseCompressed: {
+  SmallVector<Value, 2> buffers;
+  if (lt.isWithPosLT()) {
     Value pos = b.create<ToPositionsOp>(l, t, lvl);
-    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
-    return std::make_unique<LooseCompressedLevel>(tid, lvl, lt, sz, pos, crd);
-  }
-  case LevelFormat::Singleton: {
-    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
-    return std::make_unique<SingletonLevel>(tid, lvl, lt, sz, crd);
+    buffers.push_back(pos);
   }
-  case LevelFormat::NOutOfM: {
-    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
-    return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, crd);
+  if (lt.isWithCrdLT()) {
+    Value pos = b.create<ToCoordinatesOp>(l, t, lvl);
+    buffers.push_back(pos);
   }
-  case LevelFormat::Undef:
-    llvm_unreachable("undefined level format");
-  }
-  llvm_unreachable("unrecognizable level format");
+  return makeSparseTensorLevel(lt, sz, buffers, tid, lvl);
 }
 
 std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
index 9f92eecdf75cb6..46188fc112bd95 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
@@ -55,6 +55,7 @@ class SparseTensorLevel {
   Level getLevel() const { return lvl; }
   LevelType getLT() const { return lt; }
   Value getSize() const { return lvlSize; }
+  virtual ValueRange getLvlBuffers() const = 0;
 
   //
   // Level properties
@@ -287,6 +288,11 @@ std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &builder,
                                                          Location loc, Value t,
                                                          unsigned tid, Level l);
 
+/// Creates a SparseTensorLevel of type `lt` from already materialized buffers.
+std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz,
+                                                         ValueRange buffers,
+                                                         unsigned tid, Level l);
+
 /// Helper function to create a simple SparseIterator object that iterate over
 /// the SparseTensorLevel.
 std::unique_ptr<SparseIterator> makeSimpleIterator(const SparseTensorLevel &stl,



More information about the Mlir-commits mailing list