[Mlir-commits] [mlir] [WIP][mlir][sparse] Setting up sparse_tensor.iterator-related Ops. (PR #85958)

Peiming Liu llvmlistbot at llvm.org
Wed Mar 20 09:26:48 PDT 2024


https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/85958

This is a WIP PR that prototypes new sparse loop structures that retain sparse semantics. The ultimate goal of the project is to make better use of MLIR's progressive lowering strategy inside the MLIR sparsifier, so that high-level "sparse-aware" optimizations can be performed.

The draft PR currently provides a naive implementation that showcases how consecutive "sparse spaces" can be collapsed, much like dense loops, by executing
```
mlir-opt test.mlir --sparse-space-collapse
```
such that the nested loops in
```
// test.mlir
#COO = #sparse_tensor.encoding<{ map = (i, j) -> (i : compressed(nonunique), j : singleton)}>

func.func @sparse_slice_stride(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) {
  %l1 = sparse_tensor.iteration.extract_space %sp lvls = 0 : tensor<4x8xf32, #COO>
  %r1 = sparse_tensor.iterate %it1 in %l1 iter_args(%outer = %i): !sparse_tensor.iter_space<[compressed(nonunique)]> -> index {
    %l2 = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<[compressed(nonunique)]>
    %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<[singleton]> -> index {
      %k = arith.addi %inner, %j : index
      sparse_tensor.yield %k : index
    }
    sparse_tensor.yield %r2 : index
  }
  "test.op"(%r1) : (index) -> ()
  return
}
```
are transformed into a single collapsed loop:
```
#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton) }>
module {
  func.func @sparse_slice_stride(%arg0: tensor<4x8xf32, #sparse>, %arg1: index, %arg2: index) {
    %0 = sparse_tensor.iteration.extract_space %arg0 lvls = 0 to 2 : tensor<4x8xf32, #sparse>
    %1 = sparse_tensor.iterate %arg3 in %0 iter_args(%arg4 = %arg1) : !sparse_tensor.iter_space<[compressed(nonunique),singleton]> -> (index) {
      %2 = arith.addi %arg4, %arg2 : index
      sparse_tensor.yield %2 : index
    }
    "test.op"(%1) : (index) -> ()
    return
  }
}

```
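For concreteness, here is a minimal single-level sketch of the two new ops, based purely on the syntax proposed in this PR (`@count_sketch` and the constants are illustrative, not a test in this patch). It counts the stored entries of a sparse vector with one loop-carried value:
```
#SV = #sparse_tensor.encoding<{ map = (i) -> (i : compressed) }>

func.func @count_sketch(%sv : tensor<8xf32, #SV>) -> index {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  // Extract the iteration space covering level 0 only.
  %space = sparse_tensor.iteration.extract_space %sv lvls = 0 : tensor<8xf32, #SV>
  // Iterate the space; %it is the iterator, %cnt the loop-carried value.
  %r = sparse_tensor.iterate %it in %space iter_args(%cnt = %c0)
      : !sparse_tensor.iter_space<[compressed]> -> index {
    %next = arith.addi %cnt, %c1 : index
    sparse_tensor.yield %next : index
  }
  return %r : index
}
```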




From 1e76a2a5417212bb2e5891a00c0069c4118ce778 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 18 Mar 2024 19:40:46 +0000
Subject: [PATCH 1/2] test parse iterate operation

---
 .../SparseTensor/IR/SparseTensorOps.td        | 17 +++++
 .../SparseTensor/IR/SparseTensorTypes.td      | 76 +++++++++++++++++++
 2 files changed, 93 insertions(+)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 29cf8c32447ecf..6efeb6007d649e 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1418,6 +1418,23 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Operations.
+//===----------------------------------------------------------------------===//
+
+def IterateOp : SparseTensor_Op<"iterate",
+    [RecursiveMemoryEffects]> {
+
+  let arguments = (ins AnySparseIterSpace:$iterSpace,
+                       Variadic<AnyType>:$initArgs);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$region);
+
+  let extraClassDeclaration = [{}];
+
+  let hasCustomAssemblyFormat = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Debugging and Test-Only Operations.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
index 185cff46ae25d5..54a8e4d7ecd398 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
@@ -72,4 +72,80 @@ def SparseTensorStorageSpecifier
     : Type<CPred<"::llvm::isa<::mlir::sparse_tensor::StorageSpecifierType>($_self)">, "metadata",
           "::mlir::sparse_tensor::StorageSpecifierType">;
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Types.
+//===----------------------------------------------------------------------===//
+
+def LevelTypeArrayParameter : ArrayRefParameter<"::mlir::sparse_tensor::LevelType", "level-types"> {
+  let printer = [{
+    auto lvlStrings = llvm::map_range($_self, [](auto lt){ return lt.toMLIRString(); });
+    $_printer << "[" << llvm::join(lvlStrings, ",") << "]";
+  }];
+
+  let parser = [{ [&]() -> FailureOr<SmallVector<::mlir::sparse_tensor::LevelType>> {
+    SmallVector<::mlir::sparse_tensor::LevelType> ret;
+
+    const auto res = $_parser.parseCommaSeparatedList(
+      mlir::OpAsmParser::Delimiter::Square,
+      [&]() -> ParseResult {
+        ::mlir::sparse_tensor::ir_detail::LvlTypeParser lParser;
+        auto lvlTpOrFail = lParser.parseLvlType($_parser);
+        if (failed(lvlTpOrFail))
+          return failure();
+        ret.emplace_back(*lvlTpOrFail);
+        return success();
+      }, " in level-type list");
+
+    if (failed(res))
+      return failure();
+    return ret;
+  }() }];
+}
+
+def SparseTensor_Iterator : SparseTensor_Type<"Iterator"> {
+  let mnemonic = "iterator";
+
+  let parameters = (ins
+     LevelTypeArrayParameter: $lvlTypes
+  );
+
+  // let skipDefaultBuilders = 1;
+  // let hasCustomAssemblyFormat = 1;
+  let assemblyFormat="`<` $lvlTypes `>`";
+}
+
+
+def SparseTensor_IterSpace : SparseTensor_Type<"IterSpace"> {
+  let mnemonic = "iter_space";
+
+  let parameters = (ins
+     LevelTypeArrayParameter: $lvlTypes
+  );
+
+  let extraClassDeclaration = [{
+     ::mlir::sparse_tensor::IteratorType getIteratorType() const {
+        return IteratorType::get(getContext(), getLvlTypes());
+     }
+  }];
+
+  // let skipDefaultBuilders = 1;
+  // let hasCustomAssemblyFormat = 1;
+  let assemblyFormat="`<` $lvlTypes `>`";
+}
+
+def IsSparseSparseIterSpaceTypePred
+    : CPred<"::llvm::isa<::mlir::sparse_tensor::IterSpaceType>($_self)">;
+
+def IsSparseSparseIteratorTypePred
+    : CPred<"::llvm::isa<::mlir::sparse_tensor::IteratorType>($_self)">;
+
+def AnySparseIterSpace
+    : Type<IsSparseSparseIterSpaceTypePred, "sparse iteration space",
+          "::mlir::sparse_tensor::IterSpaceType">;
+
+def AnySparseIterator
+    : Type<IsSparseSparseIteratorTypePred, "sparse iterator",
+          "::mlir::sparse_tensor::IteratorType">;
+
+
 #endif // SPARSETENSOR_TYPES

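Patch 1's custom `LevelTypeArrayParameter` printer joins the level types with commas inside square brackets, so (as a sketch of the round-trip, matching the collapsed example above) the two new types render as:
```
!sparse_tensor.iter_space<[compressed(nonunique),singleton]>
!sparse_tensor.iterator<[compressed(nonunique),singleton]>
```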
From e64944645a4301b393ec4e8c77ee58aaf4cfc19d Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 20 Mar 2024 16:14:12 +0000
Subject: [PATCH 2/2] test sparse space collapse

---
 .../Dialect/SparseTensor/IR/SparseTensor.h    |   2 +
 .../SparseTensor/IR/SparseTensorOps.td        |  41 ++-
 .../SparseTensor/IR/SparseTensorTypes.td      |  15 +-
 .../Dialect/SparseTensor/Transforms/Passes.h  |   6 +
 .../Dialect/SparseTensor/Transforms/Passes.td |  16 ++
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 253 +++++++++++++++++-
 .../SparseTensor/Transforms/CMakeLists.txt    |   1 +
 .../Transforms/SparseSpaceCollapse.cpp        | 152 +++++++++++
 8 files changed, 479 insertions(+), 7 deletions(-)
 create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 5e523ec428aefb..78692307820bc5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -17,7 +17,9 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/TensorEncoding.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 6efeb6007d649e..f0eaf2191fdbd1 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -15,6 +15,8 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
 include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
 
 //===----------------------------------------------------------------------===//
 // Base class.
@@ -1422,16 +1424,51 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
 // Sparse Tensor Iteration Operations.
 //===----------------------------------------------------------------------===//
 
+def ExtractIterSpaceOp : SparseTensor_Op<"iteration.extract_space",
+    [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+
+  let arguments = (ins AnySparseTensor:$tensor,
+                       Optional<AnySparseIterator>:$parentIter,
+                       LevelAttr:$loLvl, LevelAttr:$hiLvl);
+
+  let results = (outs AnySparseIterSpace:$resultSpace);
+
+  let extraClassDeclaration = [{
+    unsigned getSpaceDim() {
+      return getHiLvl() - getLoLvl();
+    }
+  }];
+
+  let hasVerifier = 1;
+  let assemblyFormat = "$tensor (`at`$parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
+                       " attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
+}
+
 def IterateOp : SparseTensor_Op<"iterate",
-    [RecursiveMemoryEffects]> {
+    [RecursiveMemoryEffects, RecursivelySpeculatable,
+     DeclareOpInterfaceMethods<LoopLikeOpInterface,
+      ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
+      "getSingleInductionVar", "getYieldedValuesMutable"]>,
+     DeclareOpInterfaceMethods<RegionBranchOpInterface,
+      ["getEntrySuccessorOperands"]>,
+     SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
 
   let arguments = (ins AnySparseIterSpace:$iterSpace,
                        Variadic<AnyType>:$initArgs);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
 
-  let extraClassDeclaration = [{}];
+  let extraClassDeclaration = [{
+    BlockArgument getIterator() {
+      return getRegion().getArguments().front();
+    }
+    unsigned getNumRegionIterArgs() {
+      return getRegion().getArguments().size() - 1;
+    }
+  }];
 
+  let hasVerifier = 1;
+  let hasRegionVerifier = 1;
   let hasCustomAssemblyFormat = 1;
 }
 
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
index 54a8e4d7ecd398..aa674b613e71db 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
@@ -109,8 +109,13 @@ def SparseTensor_Iterator : SparseTensor_Type<"Iterator"> {
      LevelTypeArrayParameter: $lvlTypes
   );
 
-  // let skipDefaultBuilders = 1;
-  // let hasCustomAssemblyFormat = 1;
+  let extraClassDeclaration = [{
+     unsigned getSpaceDim() {
+       return getLvlTypes().size();
+     }
+  }];
+
+
   let assemblyFormat="`<` $lvlTypes `>`";
 }
 
@@ -123,13 +128,15 @@ def SparseTensor_IterSpace : SparseTensor_Type<"IterSpace"> {
   );
 
   let extraClassDeclaration = [{
+     unsigned getSpaceDim() {
+       return getLvlTypes().size();
+     }
+
      ::mlir::sparse_tensor::IteratorType getIteratorType() const {
         return IteratorType::get(getContext(), getLvlTypes());
      }
   }];
 
-  // let skipDefaultBuilders = 1;
-  // let hasCustomAssemblyFormat = 1;
   let assemblyFormat="`<` $lvlTypes `>`";
 }
 
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 61b07d222d156b..0e9f5120f7b3dc 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -247,6 +247,12 @@ std::unique_ptr<Pass> createSparsificationAndBufferizationPass(
     bool enableBufferInitialization, unsigned vectorLength,
     bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen);
 
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<Pass> createSparseSpaceCollapsePass();
+
 //===----------------------------------------------------------------------===//
 // Registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 58e2d6f32386c3..3ab75c23dbefa0 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -454,4 +454,20 @@ def SparsificationAndBufferization : Pass<"sparsification-and-bufferization", "M
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
+  let summary = "(experimental) sparse space collapsing pass";
+  let description = [{
+     This pass collapses consecutive sparse iteration spaces (extracted from the
+     same tensor) into one multi-dimensional iteration space.
+  }];
+  let constructor = "mlir::createSparseSpaceCollapsePass()";
+  let dependentDialects = [
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 6da51bb6b9cacf..6afa3e6309dc65 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1912,13 +1912,264 @@ LogicalResult YieldOp::verify() {
   auto *parentOp = (*this)->getParentOp();
   if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
       isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
-      isa<ForeachOp>(parentOp))
+      isa<ForeachOp>(parentOp) || isa<IterateOp>(parentOp))
     return success();
 
   return emitOpError("expected parent op to be sparse_tensor unary, binary, "
                      "reduce, select or foreach");
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Operations.
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
+                                   IntegerAttr &lvlHiAttr) {
+  Level lvlLo, lvlHi;
+
+  if (parser.parseInteger(lvlLo))
+    return failure();
+
+  if (succeeded(parser.parseOptionalKeyword("to"))) {
+    if (parser.parseInteger(lvlHi))
+      return failure();
+  } else {
+    lvlHi = lvlLo + 1;
+  }
+  lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
+  lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
+  return success();
+}
+
+static void printLevelRange(OpAsmPrinter &p, ExtractIterSpaceOp op,
+                            IntegerAttr lvlLo, IntegerAttr lvlHi) {
+  if (op.getLoLvl() + 1 == op.getHiLvl())
+    p << op.getLoLvl();
+  else
+    p << op.getLoLvl() << " to " << op.getHiLvl();
+}
+
+LogicalResult ExtractIterSpaceOp::inferReturnTypes(
+    MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
+    DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
+    SmallVectorImpl<mlir::Type> &ret) {
+
+  ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
+  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
+  ArrayRef<LevelType> lts = stt.getLvlTypes().slice(
+      adaptor.getLoLvl(), adaptor.getHiLvl() - adaptor.getLoLvl());
+  ret.push_back(IterSpaceType::get(ctx, lts));
+  return success();
+}
+
+LogicalResult ExtractIterSpaceOp::verify() {
+  SparseTensorType stt = getSparseTensorType(getTensor());
+  if (getLoLvl() >= getHiLvl())
+    return emitOpError("expected smaller level low than level high");
+
+  ArrayRef<LevelType> lts = stt.getLvlTypes().slice(getLoLvl(), getSpaceDim());
+  if (!getResultSpace().getType().getLvlTypes().equals(lts)) {
+    return emitOpError(
+        "mismatch in iteration space level types and tensor level types");
+  }
+
+  TypedValue<IteratorType> pIter = getParentIter();
+  if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
+    return emitOpError("parent iterator is only needed iff level low equals 0");
+  }
+
+  if (pIter) {
+    unsigned pDim = pIter.getType().getSpaceDim();
+    if (getLoLvl() < pDim || !stt.getLvlTypes()
+                                  .slice(getLoLvl() - pDim, pDim)
+                                  .equals(pIter.getType().getLvlTypes())) {
+      return emitOpError(
+          "mismatch in parent iterator level types and tensor level types");
+    }
+  }
+
+  return success();
+}
+
+ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::Argument iterator;
+  OpAsmParser::UnresolvedOperand iterSpace;
+
+  // Parses "%iterator in %iterSpace".
+  if (parser.parseArgument(iterator) || parser.parseKeyword("in") ||
+      parser.parseOperand(iterSpace)) {
+    return failure();
+  }
+
+  // Parse the optional initial iteration arguments.
+  SmallVector<OpAsmParser::Argument> regionArgs;
+  SmallVector<OpAsmParser::UnresolvedOperand> operands;
+  // Region arguments start with the iterator and are followed by optional
+  // user-provided iter_args.
+  regionArgs.push_back(iterator);
+  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+  if (hasIterArgs)
+    if (parser.parseAssignmentList(regionArgs, operands))
+      return failure();
+
+  // parse ": sparse_tensor.iter_space -> ret"
+  Type iterSpaceTps;
+  if (parser.parseColon() || parser.parseType(iterSpaceTps))
+    return failure();
+  if (hasIterArgs)
+    if (parser.parseArrowTypeList(result.types))
+      return failure();
+
+  if (regionArgs.size() != result.types.size() + 1) {
+    return parser.emitError(
+        parser.getNameLoc(),
+        "mismatch in number of loop-carried values and defined values");
+  }
+
+  // Resolves input operands.
+  if (parser.resolveOperand(iterSpace, iterSpaceTps, result.operands))
+    return failure();
+
+  if (hasIterArgs) {
+    for (auto argOperandType :
+         llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
+      Type type = std::get<2>(argOperandType);
+      std::get<0>(argOperandType).type = type;
+      if (parser.resolveOperand(std::get<1>(argOperandType), type,
+                                result.operands))
+        return failure();
+    }
+  }
+
+  Region *body = result.addRegion();
+  regionArgs.front().type =
+      llvm::cast<IterSpaceType>(iterSpaceTps).getIteratorType();
+  if (parser.parseRegion(*body, regionArgs))
+    return failure();
+
+  IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  return success();
+}
+
+/// Prints the initialization list in the form of
+///   <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
+/// where 'inner' values are assumed to be region arguments and 'outer' values
+/// are regular SSA values.
+static void printInitializationList(OpAsmPrinter &p,
+                                    Block::BlockArgListType blocksArgs,
+                                    ValueRange initializers,
+                                    StringRef prefix = "") {
+  assert(blocksArgs.size() == initializers.size() &&
+         "expected same length of arguments and initializers");
+  if (initializers.empty())
+    return;
+
+  p << prefix << '(';
+  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
+    p << std::get<0>(it) << " = " << std::get<1>(it);
+  });
+  p << ")";
+}
+
+void IterateOp::print(OpAsmPrinter &p) {
+  p << " " << getIterator() << " in " << getIterSpace();
+
+  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
+
+  p << " : " << getIterSpace().getType() << " ";
+  if (!getInitArgs().empty())
+    p << "-> (" << getInitArgs().getTypes() << ") ";
+
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/!getInitArgs().empty());
+}
+
+LogicalResult IterateOp::verify() {
+  if (getInitArgs().size() != getNumResults()) {
+    return emitOpError(
+        "mismatch in number of loop-carried values and defined values");
+  }
+  return success();
+}
+
+LogicalResult IterateOp::verifyRegions() {
+  if (getIterator().getType() != getIterSpace().getType().getIteratorType())
+    return emitOpError("mismatch in iterator and iteration space type");
+  if (getNumRegionIterArgs() != getNumResults())
+    return emitOpError(
+        "mismatch in number of basic block args and defined values");
+
+  auto initArgs = getInitArgs();
+  auto iterArgs = getRegionIterArgs();
+  auto yieldVals = getYieldedValues();
+  auto opResults = getResults();
+  if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
+                        opResults.size()})) {
+    return emitOpError() << "number mismatch between iter args and results.";
+  }
+
+  unsigned i = 0;
+  for (auto e : llvm::zip_equal(initArgs, iterArgs, yieldVals, opResults)) {
+    if (std::get<0>(e).getType() != std::get<3>(e).getType())
+      return emitOpError() << "type mismatch between " << i
+                           << "th iter operand and defined value";
+    if (std::get<1>(e).getType() != std::get<3>(e).getType())
+      return emitOpError() << "type mismatch between " << i
+                           << "th iter region arg and defined value";
+    if (std::get<2>(e).getType() != std::get<3>(e).getType())
+      return emitOpError() << "type mismatch between " << i
+                           << "th yield value and defined value";
+
+    ++i;
+  }
+  return success();
+}
+
+/// IterateOp implemented interfaces' methods.
+SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
+
+std::optional<Value> IterateOp::getSingleInductionVar() {
+  return getIterator();
+}
+
+MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
+  return getInitArgsMutable();
+}
+
+Block::BlockArgListType IterateOp::getRegionIterArgs() {
+  return getRegion().getArguments().drop_front();
+}
+
+std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
+  return cast<sparse_tensor::YieldOp>(
+             getRegion().getBlocks().front().getTerminator())
+      .getResultMutable();
+}
+
+std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
+
+OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  return getInitArgs();
+}
+
+void IterateOp::getSuccessorRegions(RegionBranchPoint point,
+                                    SmallVectorImpl<RegionSuccessor> &regions) {
+  // Both the operation itself and the region may branch into the body or
+  // back to the operation itself.
+  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
+  // It is possible for the loop not to enter the body.
+  regions.push_back(RegionSuccessor(getResults()));
+}
+
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Dialect Setups.
+//===----------------------------------------------------------------------===//
+
 /// Materialize a single constant operation from a given attribute value with
 /// the desired resultant type.
 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 3c0f82fc00bb9d..8840da9aa56ef7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   SparseGPUCodegen.cpp
   SparseReinterpretMap.cpp
   SparseStorageSpecifierToLLVM.cpp
+  SparseSpaceCollapse.cpp
   SparseTensorCodegen.cpp
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
new file mode 100644
index 00000000000000..f3207ede9585b4
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -0,0 +1,152 @@
+//===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
+namespace mlir {
+
+#define GEN_PASS_DEF_SPARSESPACECOLLAPSE
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
+
+namespace sparse_tensor {
+
+bool isCollapsableIterations(LoopLikeOpInterface parent,
+                             LoopLikeOpInterface node) {
+  auto pIterArgs = parent.getRegionIterArgs();
+  auto nInitArgs = node.getInits();
+  if (pIterArgs.size() != nInitArgs.size())
+    return false;
+
+  auto pYields = parent.getYieldedValues();
+  auto nResult = node.getLoopResults().value();
+
+  bool yieldEq =
+      llvm::all_of(llvm::zip_equal(pYields, nResult), [](auto zipped) {
+        return std::get<0>(zipped) == std::get<1>(zipped);
+      });
+
+  // Parent iter_args should be passed directly to the node's init_args.
+  bool iterArgEq =
+      llvm::all_of(llvm::zip_equal(pIterArgs, nInitArgs), [](auto zipped) {
+        return std::get<0>(zipped) == std::get<1>(zipped);
+      });
+
+  return yieldEq && iterArgEq;
+}
+
+bool legalToCollapse(ExtractIterSpaceOp parent, ExtractIterSpaceOp node) {
+  auto pItOp = llvm::dyn_cast<IterateOp>(parent->getParentOp());
+  auto nItOp = llvm::dyn_cast<IterateOp>(node->getParentOp());
+
+  // Can only collapse spaces extracted from the same tensor.
+  if (parent.getTensor() != node.getTensor() || !parent->hasOneUse())
+    return false;
+
+  // Can only collapse consecutive simple iterations over one tensor (i.e., no
+  // coiteration).
+  if (!nItOp || nItOp.getIterSpace() != parent.getResult() ||
+      nItOp->getBlock() != parent->getBlock())
+    return false;
+
+  if (pItOp && !isCollapsableIterations(pItOp, nItOp))
+    return false;
+
+  // TODO: Make sure all other operations in the same basic block as `node` can
+  // be collapsed and sink into the collapsed iteration (through Interfaces
+  // defined in TD files).
+  return true;
+}
+
+void collapseSparseSpace(ArrayRef<ExtractIterSpaceOp> toCollapse) {
+  if (toCollapse.size() < 2)
+    return;
+
+  ExtractIterSpaceOp root = toCollapse.front();
+  ExtractIterSpaceOp leaf = toCollapse.back();
+  Location loc = root.getLoc();
+
+  if (!leaf->hasOneUse())
+    return;
+  assert(root->hasOneUse());
+
+  // Insert collapsed operation at the same scope as root operation.
+  OpBuilder builder(toCollapse.front());
+
+  // Construct the collapsed iteration space.
+  auto collapsedSpace = builder.create<ExtractIterSpaceOp>(
+      loc, root.getTensor(), root.getParentIter(), root.getLoLvl(),
+      leaf.getHiLvl());
+
+  auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
+  auto pItOp = llvm::cast<IterateOp>(leaf->getParentOp());
+
+  // This could either be IterateOp or (TODO: in the future) CoIterateOp.
+  auto loop = llvm::dyn_cast<IterateOp>(*leaf->getUsers().begin());
+  if (!loop || !isCollapsableIterations(pItOp, loop))
+    return;
+
+  IRMapping mapper;
+  mapper.map(leaf, collapsedSpace.getResultSpace());
+  for (auto z : llvm::zip_equal(loop.getInitArgs(), rItOp.getInitArgs()))
+    mapper.map(std::get<0>(z), std::get<1>(z));
+
+  auto cloned = llvm::cast<IterateOp>(builder.clone(*loop, mapper));
+  cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
+
+  rItOp.replaceAllUsesWith(cloned.getResults());
+  // Erase collapsed loops.
+  rItOp.erase();
+  root.erase();
+}
+
+struct SparseSpaceCollapsePass
+    : public impl::SparseSpaceCollapseBase<SparseSpaceCollapsePass> {
+  SparseSpaceCollapsePass() = default;
+
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+
+    // A naive (experimental) implementation to collapse consecutive sparse
+    // spaces. It does NOT handle complex cases where multiple spaces are
+    // extracted in the same basic block. E.g.,
+    //
+    // %space1 = extract_space %t1 ...
+    // %space2 = extract_space %t2 ...
+    // sparse_tensor.iterate(%space1) ...
+    //
+    SmallVector<ExtractIterSpaceOp> toCollapse;
+    func->walk([&](ExtractIterSpaceOp op) {
+      if (toCollapse.empty()) {
+        // Root space to collapse.
+        toCollapse.push_back(op);
+      } else {
+        if (legalToCollapse(toCollapse.back(), op)) {
+          toCollapse.push_back(op);
+        } else {
+          collapseSparseSpace(toCollapse);
+          toCollapse.clear();
+        }
+      }
+    });
+
+    collapseSparseSpace(toCollapse);
+  }
+};
+
+} // namespace sparse_tensor
+
+std::unique_ptr<Pass> createSparseSpaceCollapsePass() {
+  return std::make_unique<sparse_tensor::SparseSpaceCollapsePass>();
+}
+
+} // namespace mlir
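
To make the pass's deliberately narrow applicability concrete, here is a hedged sketch of input that `legalToCollapse` rejects, since the inner space is extracted from a different tensor than the outer one (illustrative names; not a test in this patch):
```
#CSR = #sparse_tensor.encoding<{ map = (i, j) -> (i : dense, j : compressed) }>

// Stays as two loops: %s1 comes from %b while %s0 comes from %a, so
// parent.getTensor() != node.getTensor() and no collapsing happens.
func.func @no_collapse_sketch(%a : tensor<4x8xf32, #CSR>, %b : tensor<4x8xf32, #CSR>) {
  %s0 = sparse_tensor.iteration.extract_space %a lvls = 0 : tensor<4x8xf32, #CSR>
  sparse_tensor.iterate %i in %s0 : !sparse_tensor.iter_space<[dense]> {
    %s1 = sparse_tensor.iteration.extract_space %b at %i lvls = 1
        : tensor<4x8xf32, #CSR>, !sparse_tensor.iterator<[dense]>
    sparse_tensor.iterate %j in %s1 : !sparse_tensor.iter_space<[compressed]> {
      "test.use"(%j) : (!sparse_tensor.iterator<[compressed]>) -> ()
    }
  }
  return
}
```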



More information about the Mlir-commits mailing list