[Mlir-commits] [mlir] [mlir][sparse] implement sparse space collapse pass. (PR #89003)
Peiming Liu
llvmlistbot at llvm.org
Mon Jun 10 14:14:13 PDT 2024
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/89003
>From cbe0ef1fb34e59e7fc4562e2b3b25584f22410c3 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 16 Apr 2024 17:00:45 +0000
Subject: [PATCH 1/4] [mlir][sparse] implement sparse space collapse pass.
---
.../Dialect/SparseTensor/Transforms/Passes.h | 6 +
.../Dialect/SparseTensor/Transforms/Passes.td | 16 ++
.../SparseTensor/Transforms/CMakeLists.txt | 1 +
.../Transforms/SparseSpaceCollapse.cpp | 183 ++++++++++++++++++
.../SparseTensor/sparse_space_collapse.mlir | 33 ++++
5 files changed, 239 insertions(+)
create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
create mode 100644 mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index d6d038ef65bdf..3043a0c4dc410 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -248,6 +248,12 @@ std::unique_ptr<Pass> createSparsificationAndBufferizationPass(
bool enableBufferInitialization, unsigned vectorLength,
bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen);
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<Pass> createSparseSpaceCollapsePass();
+
//===----------------------------------------------------------------------===//
// Registration.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 2f844cee5ff52..42badd1bd9756 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -464,4 +464,20 @@ def SparsificationAndBufferization : Pass<"sparsification-and-bufferization", "M
];
}
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
+  let summary = "(experimental) sparse space collapsing pass";
+ let description = [{
+ This pass collapse consecutive sparse spaces (extracted from the same tensor)
+    into one multi-dimensional space. The pass is not yet stabilized.
+ }];
+ let constructor = "mlir::createSparseSpaceCollapsePass()";
+ let dependentDialects = [
+ "sparse_tensor::SparseTensorDialect",
+ ];
+}
+
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
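For orientation, a rough before/after sketch of the rewrite described above (illustrative only; value names are made up and the op syntax follows the regression test added at the end of this patch):

    // Before: two perfectly nested loops over spaces extracted from %sp.
    %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
    %r1 = sparse_tensor.iterate %it1 in %l1 at(%crd0) iter_args(%outer = %i) ... {
      %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : ...
      %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer) ... {
        ...
      }
      sparse_tensor.yield %r2 : index
    }

    // After: a single loop over the collapsed two-level space.
    %l = sparse_tensor.extract_iteration_space %sp lvls = 0 to 2 : tensor<4x8xf32, #COO>
    %r = sparse_tensor.iterate %it in %l at(%crd0, _) iter_args(%arg = %i) ... {
      ...
    }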
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index af3a1b48f45af..2a29ee8a7a87c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
SparseGPUCodegen.cpp
SparseReinterpretMap.cpp
SparseStorageSpecifierToLLVM.cpp
+ SparseSpaceCollapse.cpp
SparseTensorCodegen.cpp
SparseTensorConversion.cpp
SparseTensorPasses.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
new file mode 100644
index 0000000000000..bc469992d9710
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -0,0 +1,183 @@
+//===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
+namespace mlir {
+
+#define GEN_PASS_DEF_SPARSESPACECOLLAPSE
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
+
+namespace sparse_tensor {
+
+struct CollapseSpaceInfo {
+ ExtractIterSpaceOp space;
+ IterateOp loop;
+};
+
+bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) {
+ auto pIterArgs = parent.getRegionIterArgs();
+ auto nInitArgs = node.getInits();
+ if (pIterArgs.size() != nInitArgs.size())
+ return false;
+
+ // Two loops are collapsable if they are perfectly nested.
+ auto pYields = parent.getYieldedValues();
+ auto nResult = node.getLoopResults().value();
+
+ bool yieldEq =
+ llvm::all_of(llvm::zip_equal(pYields, nResult), [](auto zipped) {
+ return std::get<0>(zipped) == std::get<1>(zipped);
+ });
+
+ // Parent iter_args should be passed directly to the node's init_args.
+ bool iterArgEq =
+ llvm::all_of(llvm::zip_equal(pIterArgs, nInitArgs), [](auto zipped) {
+ return std::get<0>(zipped) == std::get<1>(zipped);
+ });
+
+ return yieldEq && iterArgEq;
+}
+
+bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
+ ExtractIterSpaceOp curSpace) {
+
+ auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp {
+ Value spaceVal = space.getResultSpace();
+ if (spaceVal.hasOneUse())
+ return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin());
+ return nullptr;
+ };
+
+ if (toCollapse.empty()) {
+ // Collapse root.
+ if (auto itOp = getIterateOpOverSpace(curSpace)) {
+ CollapseSpaceInfo &info = toCollapse.emplace_back();
+ info.space = curSpace;
+ info.loop = itOp;
+ return true;
+ }
+ return false;
+ }
+
+ auto parent = toCollapse.back().space;
+ auto pItOp = toCollapse.back().loop;
+ auto nItOp = getIterateOpOverSpace(curSpace);
+
+ // Can only collapse spaces extracted from the same tensor.
+ if (parent.getTensor() != curSpace.getTensor())
+ return false;
+
+  // Can only collapse consecutive simple iterations over one tensor (i.e., no
+ // coiteration).
+ if (!nItOp || nItOp->getBlock() != curSpace->getBlock() ||
+ pItOp.getIterator() != curSpace.getParentIter() ||
+ curSpace->getParentOp() != pItOp.getOperation())
+ return false;
+
+ if (pItOp && !isCollapsableLoops(pItOp, nItOp))
+ return false;
+
+ CollapseSpaceInfo &info = toCollapse.emplace_back();
+ info.space = curSpace;
+ info.loop = nItOp;
+ return true;
+}
+
+void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) {
+ if (toCollapse.size() < 2)
+ return;
+
+ ExtractIterSpaceOp root = toCollapse.front().space;
+ ExtractIterSpaceOp leaf = toCollapse.back().space;
+ Location loc = root.getLoc();
+
+ assert(root->hasOneUse() && leaf->hasOneUse());
+
+ // Insert collapsed operation at the same scope as root operation.
+ OpBuilder builder(root);
+
+ // Construct the collapsed iteration space.
+ auto collapsedSpace = builder.create<ExtractIterSpaceOp>(
+ loc, root.getTensor(), root.getParentIter(), root.getLoLvl(),
+ leaf.getHiLvl());
+
+ auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
+ auto innermost = toCollapse.back().loop;
+
+ IRMapping mapper;
+ mapper.map(leaf, collapsedSpace.getResultSpace());
+ for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs()))
+ mapper.map(std::get<0>(z), std::get<1>(z));
+
+ auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper));
+ builder.setInsertionPointToStart(cloned.getBody());
+
+ LevelSet crdUsedLvls;
+ unsigned shift = 0, argIdx = 1;
+ for (auto info : toCollapse.drop_back()) {
+ LevelSet set = info.loop.getCrdUsedLvls();
+ crdUsedLvls |= set.lshift(shift);
+ shift += info.loop.getSpaceDim();
+ for (BlockArgument crd : info.loop.getCrds()) {
+ BlockArgument collapsedCrd = cloned.getBody()->insertArgument(
+ argIdx++, builder.getIndexType(), crd.getLoc());
+ crd.replaceAllUsesWith(collapsedCrd);
+ }
+ }
+ crdUsedLvls |= innermost.getCrdUsedLvls().lshift(shift);
+ cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
+ cloned.setCrdUsedLvls(crdUsedLvls);
+
+ rItOp.replaceAllUsesWith(cloned.getResults());
+ // Erase collapsed loops.
+ rItOp.erase();
+ root.erase();
+}
+
+struct SparseSpaceCollapsePass
+ : public impl::SparseSpaceCollapseBase<SparseSpaceCollapsePass> {
+ SparseSpaceCollapsePass() = default;
+
+ void runOnOperation() override {
+ func::FuncOp func = getOperation();
+
+ // A naive (experimental) implementation to collapse consecutive sparse
+ // spaces. It does NOT handle complex cases where multiple spaces are
+ // extracted in the same basic block. E.g.,
+ //
+ // %space1 = extract_space %t1 ...
+ // %space2 = extract_space %t2 ...
    //      sparse_tensor.iterate(%space1) ...
+ //
+ SmallVector<CollapseSpaceInfo> toCollapse;
+ func->walk([&](ExtractIterSpaceOp op) {
+ if (!legalToCollapse(toCollapse, op)) {
+ // if not legal to collapse one more space, collapse the existing ones
+ // and clear.
+ collapseSparseSpace(toCollapse);
+ toCollapse.clear();
+ }
+ });
+
+ collapseSparseSpace(toCollapse);
+ }
+};
+
+} // namespace sparse_tensor
+
+std::unique_ptr<Pass> createSparseSpaceCollapsePass() {
+ return std::make_unique<sparse_tensor::SparseSpaceCollapsePass>();
+}
+
+} // namespace mlir
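As a worked example of the crdUsedLvls bookkeeping in collapseSparseSpace above (illustrative, based on the test case below): the outer loop iterates level 0 and binds its coordinate, the inner loop iterates level 1 and binds none, and each loop's used-level set is shifted past the levels already consumed before being merged:

    outer loop: crdUsedLvls = {0}, spaceDim = 1   ->  contributes {0} << 0 = {0}
    inner loop: crdUsedLvls = {}                  ->  contributes {}  << 1 = {}
    collapsed:  crdUsedLvls = {0}                 ->  collapsed loop prints as at(%crd0, _)

which matches the `at(%[[VAL_6:.*]], _)` pattern in the CHECK lines of the test.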
diff --git a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
new file mode 100644
index 0000000000000..392dfe01884ba
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s --sparse-space-collapse | FileCheck %s
+
+#COO = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed(nonunique),
+ j : singleton(soa)
+ )
+}>
+
+// CHECK-LABEL: func.func @sparse_sparse_collapse(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x8xf32, #sparse>,
+// CHECK-SAME: %[[VAL_1:.*]]: index) {
+// CHECK: %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0 to 2 : tensor<4x8xf32, #sparse>
+// CHECK: %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] at(%[[VAL_6:.*]], _) iter_args(%[[VAL_7:.*]] = %[[VAL_1]])
+// CHECK: %[[VAL_8:.*]] = "test.op"(%[[VAL_7]]) : (index) -> index
+// CHECK: sparse_tensor.yield %[[VAL_8]] : index
+// CHECK: }
+// CHECK: "test.sink"(%[[VAL_4]]) : (index) -> ()
+// CHECK: return
+// CHECK: }
+func.func @sparse_sparse_collapse(%sp : tensor<4x8xf32, #COO>, %i : index) {
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+ %r1 = sparse_tensor.iterate %it1 in %l1 at(%crd0) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
+ %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
+ %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
+    %k = "test.op"(%inner) : (index) -> index
+ sparse_tensor.yield %k : index
+ }
+ sparse_tensor.yield %r2 : index
+ }
+ "test.sink"(%r1) : (index) -> ()
+ return
+}
>From 99247e986552609460a2a400de839471cc17f0ec Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 21 May 2024 18:19:54 +0000
Subject: [PATCH 2/4] rebase
---
.../SparseTensor/Transforms/SparseSpaceCollapse.cpp | 4 ++--
mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir | 8 ++++++--
2 files changed, 8 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
index bc469992d9710..4d06603a59862 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -53,7 +53,7 @@ bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
ExtractIterSpaceOp curSpace) {
auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp {
- Value spaceVal = space.getResultSpace();
+ Value spaceVal = space.getExtractedSpace();
if (spaceVal.hasOneUse())
return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin());
return nullptr;
@@ -116,7 +116,7 @@ void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) {
auto innermost = toCollapse.back().loop;
IRMapping mapper;
- mapper.map(leaf, collapsedSpace.getResultSpace());
+ mapper.map(leaf, collapsedSpace.getExtractedSpace());
for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs()))
mapper.map(std::get<0>(z), std::get<1>(z));
diff --git a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
index 392dfe01884ba..baa6199f12bc3 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
@@ -19,9 +19,13 @@
// CHECK: return
// CHECK: }
func.func @sparse_sparse_collapse(%sp : tensor<4x8xf32, #COO>, %i : index) {
- %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
+ : tensor<4x8xf32, #COO>
+ -> !sparse_tensor.iter_space<#COO, lvls = 0>
%r1 = sparse_tensor.iterate %it1 in %l1 at(%crd0) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
- %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
+ %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
+ : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
+ -> !sparse_tensor.iter_space<#COO, lvls = 1>
%r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
%k = "test.op"(%inner) : (index) -> index
sparse_tensor.yield %k : index
>From 6354d181864b75cc72289833b903e80bdfee852e Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 10 Jun 2024 17:47:38 +0000
Subject: [PATCH 3/4] address comments
---
.../Dialect/SparseTensor/Transforms/Passes.td | 2 +-
.../Transforms/SparseSpaceCollapse.cpp | 15 ++++++++-------
2 files changed, 9 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 42badd1bd9756..d7f9adcb75bb6 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -471,7 +471,7 @@ def SparsificationAndBufferization : Pass<"sparsification-and-bufferization", "M
def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
let summary = "(experimental) sparse space collapsing pass";
let description = [{
- This pass collapse consecutive sparse spaces (extracted from the same tensor)
+ This pass collapses consecutive sparse spaces (extracted from the same tensor)
into one multi-dimensional space. The pass is not yet stabilized.
}];
let constructor = "mlir::createSparseSpaceCollapsePass()";
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
index 4d06603a59862..9d5a5b84e4101 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -14,11 +14,14 @@
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
namespace mlir {
-
#define GEN_PASS_DEF_SPARSESPACECOLLAPSE
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace sparse_tensor;
-namespace sparse_tensor {
+namespace {
struct CollapseSpaceInfo {
ExtractIterSpaceOp space;
@@ -174,10 +177,8 @@ struct SparseSpaceCollapsePass
}
};
-} // namespace sparse_tensor
+} // namespace
-std::unique_ptr<Pass> createSparseSpaceCollapsePass() {
- return std::make_unique<sparse_tensor::SparseSpaceCollapsePass>();
+std::unique_ptr<Pass> mlir::createSparseSpaceCollapsePass() {
+ return std::make_unique<SparseSpaceCollapsePass>();
}
-
-} // namespace mlir
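With the factory function now defined under the mlir namespace, a pipeline would schedule this function-level pass in the usual way; a minimal C++ sketch (illustrative, not part of the patch):

    #include "mlir/Dialect/Func/IR/FuncOps.h"
    #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
    #include "mlir/Pass/PassManager.h"

    // Nest the pass under func.func, matching the pass's declared root op.
    void addSparseSpaceCollapse(mlir::PassManager &pm) {
      pm.addNestedPass<mlir::func::FuncOp>(mlir::createSparseSpaceCollapsePass());
    }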
>From 730c68562302096be6c1c2619e3be3711da9704e Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 10 Jun 2024 21:13:54 +0000
Subject: [PATCH 4/4] address comments
---
.../Transforms/SparseSpaceCollapse.cpp | 21 ++++++++++++++++---
1 file changed, 18 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
index 9d5a5b84e4101..924046fcd9961 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -18,6 +18,8 @@ namespace mlir {
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
} // namespace mlir
+#define DEBUG_TYPE "sparse-space-collapse"
+
using namespace mlir;
using namespace sparse_tensor;
@@ -78,18 +80,31 @@ bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
auto nItOp = getIterateOpOverSpace(curSpace);
// Can only collapse spaces extracted from the same tensor.
- if (parent.getTensor() != curSpace.getTensor())
+ if (parent.getTensor() != curSpace.getTensor()) {
+ LLVM_DEBUG({
+ llvm::dbgs()
+ << "failed to collpase spaces extracted from different tensors.";
+ });
return false;
+ }
// Can only collapse consecutive simple iterations over one tensor (i.e., no
// coiteration).
if (!nItOp || nItOp->getBlock() != curSpace->getBlock() ||
pItOp.getIterator() != curSpace.getParentIter() ||
- curSpace->getParentOp() != pItOp.getOperation())
+ curSpace->getParentOp() != pItOp.getOperation()) {
+ LLVM_DEBUG(
+ { llvm::dbgs() << "failed to collapse non-consecutive IterateOps."; });
return false;
+ }
- if (pItOp && !isCollapsableLoops(pItOp, nItOp))
+ if (pItOp && !isCollapsableLoops(pItOp, nItOp)) {
+ LLVM_DEBUG({
+ llvm::dbgs()
+ << "failed to collapse IterateOps that are not perfectly nested.";
+ });
return false;
+ }
CollapseSpaceInfo &info = toCollapse.emplace_back();
info.space = curSpace;
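Since DEBUG_TYPE is "sparse-space-collapse", the diagnostics added in this patch can be surfaced in an assertions-enabled build with the standard LLVM debug flag, e.g.:

    mlir-opt input.mlir --sparse-space-collapse --debug-only=sparse-space-collapse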