[Mlir-commits] [mlir] [mlir][sparse] implement sparse space collapse pass. (PR #89003)

Peiming Liu llvmlistbot at llvm.org
Mon Jun 10 14:14:13 PDT 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/89003

>From cbe0ef1fb34e59e7fc4562e2b3b25584f22410c3 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 16 Apr 2024 17:00:45 +0000
Subject: [PATCH 1/4] [mlir][sparse] implement sparse space collapse pass.

---
 .../Dialect/SparseTensor/Transforms/Passes.h  |   6 +
 .../Dialect/SparseTensor/Transforms/Passes.td |  16 ++
 .../SparseTensor/Transforms/CMakeLists.txt    |   1 +
 .../Transforms/SparseSpaceCollapse.cpp        | 183 ++++++++++++++++++
 .../SparseTensor/sparse_space_collapse.mlir   |  33 ++++
 5 files changed, 239 insertions(+)
 create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
 create mode 100644 mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index d6d038ef65bdf..3043a0c4dc410 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -248,6 +248,12 @@ std::unique_ptr<Pass> createSparsificationAndBufferizationPass(
     bool enableBufferInitialization, unsigned vectorLength,
     bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen);
 
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<Pass> createSparseSpaceCollapsePass();
+
 //===----------------------------------------------------------------------===//
 // Registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 2f844cee5ff52..42badd1bd9756 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -464,4 +464,20 @@ def SparsificationAndBufferization : Pass<"sparsification-and-bufferization", "M
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
+  let summary = "(experimental) sparse space collpasing pass";
+  let description = [{
+     This pass collapse consecutive sparse spaces (extracted from the same tensor)
+     into one multi-dimensional space. The pass is not yet stablized.
+  }];
+  let constructor = "mlir::createSparseSpaceCollapsePass()";
+  let dependentDialects = [
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index af3a1b48f45af..2a29ee8a7a87c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   SparseGPUCodegen.cpp
   SparseReinterpretMap.cpp
   SparseStorageSpecifierToLLVM.cpp
+  SparseSpaceCollapse.cpp
   SparseTensorCodegen.cpp
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
new file mode 100644
index 0000000000000..bc469992d9710
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -0,0 +1,183 @@
+//===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
+namespace mlir {
+
+#define GEN_PASS_DEF_SPARSESPACECOLLAPSE
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
+
+namespace sparse_tensor {
+
+struct CollapseSpaceInfo {
+  ExtractIterSpaceOp space;
+  IterateOp loop;
+};
+
+bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) {
+  auto pIterArgs = parent.getRegionIterArgs();
+  auto nInitArgs = node.getInits();
+  if (pIterArgs.size() != nInitArgs.size())
+    return false;
+
+  // Perfectly nested: the parent yields exactly the node's results, in order.
+  auto pYields = parent.getYieldedValues();
+  auto nResult = node.getLoopResults().value();
+
+  bool yieldEq =
+      llvm::all_of(llvm::zip_equal(pYields, nResult), [](auto zipped) {
+        return std::get<0>(zipped) == std::get<1>(zipped);
+      });
+
+  // Parent iter_args should be passed directly to the node's init_args.
+  bool iterArgEq =
+      llvm::all_of(llvm::zip_equal(pIterArgs, nInitArgs), [](auto zipped) {
+        return std::get<0>(zipped) == std::get<1>(zipped);
+      });
+
+  return yieldEq && iterArgEq;
+}
+
+bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
+                     ExtractIterSpaceOp curSpace) {
+
+  auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp {
+    Value spaceVal = space.getResultSpace();
+    if (spaceVal.hasOneUse())
+      return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin());
+    return nullptr;
+  };
+
+  if (toCollapse.empty()) {
+    // Start a new collapse group rooted at curSpace.
+    if (auto itOp = getIterateOpOverSpace(curSpace)) {
+      CollapseSpaceInfo &info = toCollapse.emplace_back();
+      info.space = curSpace;
+      info.loop = itOp;
+      return true;
+    }
+    return false;
+  }
+
+  auto parent = toCollapse.back().space;
+  auto pItOp = toCollapse.back().loop;
+  auto nItOp = getIterateOpOverSpace(curSpace);
+
+  // Can only collapse spaces extracted from the same tensor.
+  if (parent.getTensor() != curSpace.getTensor())
+    return false;
+
+  // Can only collapse consecutive simple iteration on one tensor (i.e., no
+  // coiteration).
+  if (!nItOp || nItOp->getBlock() != curSpace->getBlock() ||
+      pItOp.getIterator() != curSpace.getParentIter() ||
+      curSpace->getParentOp() != pItOp.getOperation())
+    return false;
+
+  if (pItOp && !isCollapsableLoops(pItOp, nItOp))
+    return false;
+
+  CollapseSpaceInfo &info = toCollapse.emplace_back();
+  info.space = curSpace;
+  info.loop = nItOp;
+  return true;
+}
+
+void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) {
+  if (toCollapse.size() < 2)
+    return;
+
+  ExtractIterSpaceOp root = toCollapse.front().space;
+  ExtractIterSpaceOp leaf = toCollapse.back().space;
+  Location loc = root.getLoc();
+
+  assert(root->hasOneUse() && leaf->hasOneUse());
+
+  // Insert the collapsed ops right before the root space extraction.
+  OpBuilder builder(root);
+
+  // Build one space spanning root's lowest level up to leaf's highest level.
+  auto collapsedSpace = builder.create<ExtractIterSpaceOp>(
+      loc, root.getTensor(), root.getParentIter(), root.getLoLvl(),
+      leaf.getHiLvl());
+
+  auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
+  auto innermost = toCollapse.back().loop;
+
+  IRMapping mapper;
+  mapper.map(leaf, collapsedSpace.getResultSpace());
+  for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs()))
+    mapper.map(std::get<0>(z), std::get<1>(z));
+
+  auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper));
+  builder.setInsertionPointToStart(cloned.getBody());
+
+  LevelSet crdUsedLvls;
+  unsigned shift = 0, argIdx = 1;
+  for (auto info : toCollapse.drop_back()) {
+    LevelSet set = info.loop.getCrdUsedLvls();
+    crdUsedLvls |= set.lshift(shift);
+    shift += info.loop.getSpaceDim();
+    for (BlockArgument crd : info.loop.getCrds()) {
+      BlockArgument collapsedCrd = cloned.getBody()->insertArgument(
+          argIdx++, builder.getIndexType(), crd.getLoc());
+      crd.replaceAllUsesWith(collapsedCrd);
+    }
+  }
+  crdUsedLvls |= innermost.getCrdUsedLvls().lshift(shift);
+  cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
+  cloned.setCrdUsedLvls(crdUsedLvls);
+
+  rItOp.replaceAllUsesWith(cloned.getResults());
+  // Erase the outermost loop (it nests the other loops) and the root space.
+  rItOp.erase();
+  root.erase();
+}
+
+struct SparseSpaceCollapsePass
+    : public impl::SparseSpaceCollapseBase<SparseSpaceCollapsePass> {
+  SparseSpaceCollapsePass() = default;
+
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+
+    // A naive (experimental) implementation to collapse consecutive sparse
+    // spaces. It does NOT handle complex cases where multiple spaces are
+    // extracted in the same basic block. E.g.,
+    //
+    // %space1 = extract_space %t1 ...
+    // %space2 = extract_space %t2 ...
+    // sparse_tensor.iterate(%space1) ...
+    //
+    SmallVector<CollapseSpaceInfo> toCollapse;
+    func->walk([&](ExtractIterSpaceOp op) {
+      if (!legalToCollapse(toCollapse, op)) {
+        // If it is not legal to collapse one more space, collapse the existing
+        // ones and clear the list.
+        collapseSparseSpace(toCollapse);
+        toCollapse.clear();
+      }
+    });
+
+    collapseSparseSpace(toCollapse);
+  }
+};
+
+} // namespace sparse_tensor
+
+std::unique_ptr<Pass> createSparseSpaceCollapsePass() {
+  return std::make_unique<sparse_tensor::SparseSpaceCollapsePass>();
+}
+
+} // namespace mlir
diff --git a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
new file mode 100644
index 0000000000000..392dfe01884ba
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s --sparse-space-collapse | FileCheck %s
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+// CHECK-LABEL:   func.func @sparse_sparse_collapse(
+// CHECK-SAME:         %[[VAL_0:.*]]: tensor<4x8xf32, #sparse>,
+// CHECK-SAME:         %[[VAL_1:.*]]: index) {
+// CHECK:           %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0 to 2 : tensor<4x8xf32, #sparse>
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] at(%[[VAL_6:.*]], _) iter_args(%[[VAL_7:.*]] = %[[VAL_1]])
+// CHECK:             %[[VAL_8:.*]] = "test.op"(%[[VAL_7]]) : (index) -> index
+// CHECK:             sparse_tensor.yield %[[VAL_8]] : index
+// CHECK:           }
+// CHECK:           "test.sink"(%[[VAL_4]]) : (index) -> ()
+// CHECK:           return
+// CHECK:         }
+func.func @sparse_sparse_collapse(%sp : tensor<4x8xf32, #COO>, %i : index) {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  %r1 = sparse_tensor.iterate %it1 in %l1 at(%crd0) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
+    %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
+    %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
+      %k ="test.op"(%inner) : (index) -> index
+      sparse_tensor.yield %k : index
+    }
+    sparse_tensor.yield %r2 : index
+  }
+  "test.sink"(%r1) : (index) -> ()
+  return
+}

>From 99247e986552609460a2a400de839471cc17f0ec Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 21 May 2024 18:19:54 +0000
Subject: [PATCH 2/4] rebase

---
 .../SparseTensor/Transforms/SparseSpaceCollapse.cpp       | 4 ++--
 mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir | 8 ++++++--
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
index bc469992d9710..4d06603a59862 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -53,7 +53,7 @@ bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
                      ExtractIterSpaceOp curSpace) {
 
   auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp {
-    Value spaceVal = space.getResultSpace();
+    Value spaceVal = space.getExtractedSpace();
     if (spaceVal.hasOneUse())
       return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin());
     return nullptr;
@@ -116,7 +116,7 @@ void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) {
   auto innermost = toCollapse.back().loop;
 
   IRMapping mapper;
-  mapper.map(leaf, collapsedSpace.getResultSpace());
+  mapper.map(leaf, collapsedSpace.getExtractedSpace());
   for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs()))
     mapper.map(std::get<0>(z), std::get<1>(z));
 
diff --git a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
index 392dfe01884ba..baa6199f12bc3 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
@@ -19,9 +19,13 @@
 // CHECK:           return
 // CHECK:         }
 func.func @sparse_sparse_collapse(%sp : tensor<4x8xf32, #COO>, %i : index) {
-  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
+      : tensor<4x8xf32, #COO>
+     -> !sparse_tensor.iter_space<#COO, lvls = 0>
   %r1 = sparse_tensor.iterate %it1 in %l1 at(%crd0) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
-    %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
+    %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
+        : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
+       -> !sparse_tensor.iter_space<#COO, lvls = 1>
     %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
       %k ="test.op"(%inner) : (index) -> index
       sparse_tensor.yield %k : index

>From 6354d181864b75cc72289833b903e80bdfee852e Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 10 Jun 2024 17:47:38 +0000
Subject: [PATCH 3/4] address comments

---
 .../Dialect/SparseTensor/Transforms/Passes.td     |  2 +-
 .../Transforms/SparseSpaceCollapse.cpp            | 15 ++++++++-------
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 42badd1bd9756..d7f9adcb75bb6 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -471,7 +471,7 @@ def SparsificationAndBufferization : Pass<"sparsification-and-bufferization", "M
 def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
   let summary = "(experimental) sparse space collpasing pass";
   let description = [{
-     This pass collapse consecutive sparse spaces (extracted from the same tensor)
+     This pass collapses consecutive sparse spaces (extracted from the same tensor)
      into one multi-dimensional space. The pass is not yet stablized.
   }];
   let constructor = "mlir::createSparseSpaceCollapsePass()";
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
index 4d06603a59862..9d5a5b84e4101 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -14,11 +14,14 @@
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 
 namespace mlir {
-
 #define GEN_PASS_DEF_SPARSESPACECOLLAPSE
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace sparse_tensor;
 
-namespace sparse_tensor {
+namespace {
 
 struct CollapseSpaceInfo {
   ExtractIterSpaceOp space;
@@ -174,10 +177,8 @@ struct SparseSpaceCollapsePass
   }
 };
 
-} // namespace sparse_tensor
+} // namespace
 
-std::unique_ptr<Pass> createSparseSpaceCollapsePass() {
-  return std::make_unique<sparse_tensor::SparseSpaceCollapsePass>();
+std::unique_ptr<Pass> mlir::createSparseSpaceCollapsePass() {
+  return std::make_unique<SparseSpaceCollapsePass>();
 }
-
-} // namespace mlir

>From 730c68562302096be6c1c2619e3be3711da9704e Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 10 Jun 2024 21:13:54 +0000
Subject: [PATCH 4/4] address comments

---
 .../Transforms/SparseSpaceCollapse.cpp        | 34 +++++++++++++++++--
 1 file changed, 31 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
index 9d5a5b84e4101..924046fcd9961 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -18,6 +18,8 @@ namespace mlir {
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
 } // namespace mlir
 
+#define DEBUG_TYPE "sparse-space-collapse"
+
 using namespace mlir;
 using namespace sparse_tensor;
 
@@ -78,18 +80,44 @@ bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
   auto nItOp = getIterateOpOverSpace(curSpace);
 
   // Can only collapse spaces extracted from the same tensor.
-  if (parent.getTensor() != curSpace.getTensor())
+  if (parent.getTensor() != curSpace.getTensor()) {
+    LLVM_DEBUG({
+      llvm::dbgs()
+          << "failed to collapse spaces extracted from different tensors.\n";
+    });
     return false;
+  }
 
   // Can only collapse consecutive simple iteration on one tensor (i.e., no
   // coiteration).
   if (!nItOp || nItOp->getBlock() != curSpace->getBlock() ||
       pItOp.getIterator() != curSpace.getParentIter() ||
-      curSpace->getParentOp() != pItOp.getOperation())
+      curSpace->getParentOp() != pItOp.getOperation()) {
+    LLVM_DEBUG({
+      llvm::dbgs() << "failed to collapse non-consecutive IterateOps.\n";
+    });
     return false;
+  }
 
-  if (pItOp && !isCollapsableLoops(pItOp, nItOp))
+  // pItOp is non-null here: it has already been dereferenced above.
+  if (!isCollapsableLoops(pItOp, nItOp)) {
+    LLVM_DEBUG({
+      llvm::dbgs()
+          << "failed to collapse IterateOps that are not perfectly nested.\n";
+    });
     return false;
+  }
+
+  // The parent loop body may only hold the extracted space, the nested loop
+  // and the terminator: collapsing erases the parent body, so any other op
+  // would be silently dropped (or leave dangling uses in the cloned loop).
+  if (pItOp.getBody()->getOperations().size() != 3) {
+    LLVM_DEBUG({
+      llvm::dbgs() << "failed to collapse IterateOps whose body contains "
+                      "other operations.\n";
+    });
+    return false;
+  }
 
   CollapseSpaceInfo &info = toCollapse.emplace_back();
   info.space = curSpace;



More information about the Mlir-commits mailing list