[Mlir-commits] [mlir] c6d85ba - [mlir][sparse] implement sparse space collapse pass. (#89003)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jun 11 12:10:58 PDT 2024
Author: Peiming Liu
Date: 2024-06-11T12:10:54-07:00
New Revision: c6d85baf9f12f69915559aff5ed6c48b63daafdd
URL: https://github.com/llvm/llvm-project/commit/c6d85baf9f12f69915559aff5ed6c48b63daafdd
DIFF: https://github.com/llvm/llvm-project/commit/c6d85baf9f12f69915559aff5ed6c48b63daafdd.diff
LOG: [mlir][sparse] implement sparse space collapse pass. (#89003)
Added:
mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
Modified:
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index d6d038ef65bdf..3043a0c4dc410 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -248,6 +248,12 @@ std::unique_ptr<Pass> createSparsificationAndBufferizationPass(
bool enableBufferInitialization, unsigned vectorLength,
bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen);
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+/// Creates a pass that collapses consecutive sparse iteration spaces
+/// extracted from the same tensor into one multi-dimensional space.
+std::unique_ptr<Pass> createSparseSpaceCollapsePass();
+
//===----------------------------------------------------------------------===//
// Registration.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 2f844cee5ff52..c6554e1c94a4a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -464,4 +464,20 @@ def SparsificationAndBufferization : Pass<"sparsification-and-bufferization", "M
];
}
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
+  let summary = "sparse space collapsing pass";
+  let description = [{
+    This pass collapses consecutive sparse spaces (extracted from the same tensor)
+    into one multi-dimensional space. Only perfectly nested
+    `sparse_tensor.iterate` loops, where each inner space is extracted at the
+    iterator of the enclosing loop, are collapsed. The pass is not yet
+    stabilized.
+  }];
+  let constructor = "mlir::createSparseSpaceCollapsePass()";
+  let dependentDialects = [
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index af3a1b48f45af..2a29ee8a7a87c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
SparseGPUCodegen.cpp
SparseReinterpretMap.cpp
SparseStorageSpecifierToLLVM.cpp
+ SparseSpaceCollapse.cpp
SparseTensorCodegen.cpp
SparseTensorConversion.cpp
SparseTensorPasses.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
new file mode 100644
index 0000000000000..924046fcd9961
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -0,0 +1,199 @@
+//===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SPARSESPACECOLLAPSE
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "sparse-space-collapse"
+
+using namespace mlir;
+using namespace sparse_tensor;
+
+namespace {
+
+/// One link in a chain of spaces to collapse: an extracted iteration space
+/// together with the `sparse_tensor.iterate` loop that consumes it.
+struct CollapseSpaceInfo {
+ /// The extracted iteration space.
+ ExtractIterSpaceOp space;
+ /// The IterateOp that is the sole user of `space`.
+ IterateOp loop;
+};
+
+/// Returns true if the loop-carried values of `node` and its enclosing loop
+/// `parent` allow the two loops to be merged into one:
+///   * `parent` yields exactly the results of `node`, in order, and
+///   * every region iter_arg of `parent` is passed, in order, as the
+///     corresponding init_arg of `node` and has no other use.
+/// The single-use requirement matters because collapsing keeps only a clone
+/// of the innermost loop whose init_args are remapped to those of the
+/// outermost loop; any other use of a parent iter_arg would be rewritten to
+/// that (loop-invariant) init value or be left dangling once the enclosing
+/// loops are erased.
+bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) {
+  auto pIterArgs = parent.getRegionIterArgs();
+  auto nInitArgs = node.getInits();
+  if (pIterArgs.size() != nInitArgs.size())
+    return false;
+
+  // Two loops are collapsable if they are perfectly nested.
+  auto pYields = parent.getYieldedValues();
+  auto nResult = node.getLoopResults().value();
+
+  bool yieldEq =
+      llvm::all_of(llvm::zip_equal(pYields, nResult), [](auto zipped) {
+        return std::get<0>(zipped) == std::get<1>(zipped);
+      });
+
+  // Parent iter_args should be passed directly (and exclusively) to the
+  // node's init_args.
+  bool iterArgEq =
+      llvm::all_of(llvm::zip_equal(pIterArgs, nInitArgs), [](auto zipped) {
+        Value pIterArg = std::get<0>(zipped);
+        return pIterArg == std::get<1>(zipped) && pIterArg.hasOneUse();
+      });
+
+  return yieldEq && iterArgEq;
+}
+
+/// Tries to append `curSpace` (and the single IterateOp consuming it) to the
+/// collapse chain `toCollapse`.
+///
+/// An empty chain accepts any space whose extracted value has exactly one
+/// use that is an IterateOp; that space becomes the chain's root. A non-empty
+/// chain additionally requires that `curSpace`
+///   * is extracted from the same tensor as the previous space,
+///   * is extracted at the previous loop's iterator, directly in that loop's
+///     body, and is iterated in that same block, and
+///   * forms a perfect nest with the previous loop (see below).
+///
+/// Returns true and appends to `toCollapse` on success; otherwise returns
+/// false and leaves `toCollapse` unchanged.
+bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
+                     ExtractIterSpaceOp curSpace) {
+  auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp {
+    Value spaceVal = space.getExtractedSpace();
+    if (spaceVal.hasOneUse())
+      return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin());
+    return nullptr;
+  };
+
+  if (toCollapse.empty()) {
+    // Collapse root.
+    if (auto itOp = getIterateOpOverSpace(curSpace)) {
+      CollapseSpaceInfo &info = toCollapse.emplace_back();
+      info.space = curSpace;
+      info.loop = itOp;
+      return true;
+    }
+    return false;
+  }
+
+  auto parent = toCollapse.back().space;
+  auto pItOp = toCollapse.back().loop;
+  auto nItOp = getIterateOpOverSpace(curSpace);
+
+  // Can only collapse spaces extracted from the same tensor.
+  if (parent.getTensor() != curSpace.getTensor()) {
+    LLVM_DEBUG({
+      llvm::dbgs()
+          << "failed to collapse spaces extracted from different tensors.\n";
+    });
+    return false;
+  }
+
+  // Can only collapse consecutive simple iteration on one tensor (i.e., no
+  // coiteration).
+  if (!nItOp || nItOp->getBlock() != curSpace->getBlock() ||
+      pItOp.getIterator() != curSpace.getParentIter() ||
+      curSpace->getParentOp() != pItOp.getOperation()) {
+    LLVM_DEBUG(
+        { llvm::dbgs() << "failed to collapse non-consecutive IterateOps.\n"; });
+    return false;
+  }
+
+  // Collapsing keeps only a clone of the innermost loop and erases every
+  // enclosing loop, so the parent loop must be perfectly nested: its body may
+  // hold nothing but `curSpace`, `nItOp` and the terminator, its iterator may
+  // only feed `curSpace`, and its loop-carried values must flow straight
+  // through `nItOp`. Anything else would be silently dropped or left with
+  // dangling operands.
+  Block *pBody = pItOp.getBody();
+  Operation *pTerminator = pBody->getTerminator();
+  bool onlyNestedLoop = llvm::all_of(*pBody, [&](Operation &op) {
+    return &op == curSpace.getOperation() || &op == nItOp.getOperation() ||
+           &op == pTerminator;
+  });
+  if (!onlyNestedLoop || !pItOp.getIterator().hasOneUse() ||
+      !isCollapsableLoops(pItOp, nItOp)) {
+    LLVM_DEBUG({
+      llvm::dbgs()
+          << "failed to collapse IterateOps that are not perfectly nested.\n";
+    });
+    return false;
+  }
+
+  CollapseSpaceInfo &info = toCollapse.emplace_back();
+  info.space = curSpace;
+  info.loop = nItOp;
+  return true;
+}
+
+/// Rewrites a chain of nested sparse iterations into a single IterateOp over
+/// one multi-dimensional space. Chains with fewer than two entries are left
+/// untouched.
+///
+/// `toCollapse` is ordered from the outermost (root) to the innermost (leaf)
+/// space. A new space spanning from the root's low level to the leaf's high
+/// level is extracted, the innermost loop is cloned over it, and the
+/// coordinate arguments of all enclosing loops are inserted into the clone's
+/// body arguments (outermost first, starting at index 1). The root loop,
+/// together with everything nested in it, and the root space are erased.
+void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) {
+  if (toCollapse.size() < 2)
+    return;
+
+  ExtractIterSpaceOp root = toCollapse.front().space;
+  ExtractIterSpaceOp leaf = toCollapse.back().space;
+  Location loc = root.getLoc();
+
+  assert(root->hasOneUse() && leaf->hasOneUse());
+
+  auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
+  auto innermost = toCollapse.back().loop;
+
+  // Insert the collapsed operations right before the root loop. The root
+  // space only has to dominate its loop, which may sit in a nested region
+  // (e.g. inside another loop); building at the root space instead would
+  // hoist the collapsed loop out of that region and could place it before
+  // the definitions of its init_args.
+  OpBuilder builder(rItOp);
+
+  // Construct the collapsed iteration space.
+  auto collapsedSpace = builder.create<ExtractIterSpaceOp>(
+      loc, root.getTensor(), root.getParentIter(), root.getLoLvl(),
+      leaf.getHiLvl());
+
+  // Clone the innermost loop over the collapsed space, seeding its
+  // loop-carried values with the init_args of the root loop.
+  IRMapping mapper;
+  mapper.map(leaf, collapsedSpace.getExtractedSpace());
+  for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs()))
+    mapper.map(std::get<0>(z), std::get<1>(z));
+
+  auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper));
+  builder.setInsertionPointToStart(cloned.getBody());
+
+  // Add one coordinate argument per coordinate of every enclosing loop and
+  // redirect all uses of the old coordinates to it. The used-coordinate
+  // level sets are merged, each shifted by the dimensions of the spaces
+  // before it.
+  LevelSet crdUsedLvls;
+  unsigned shift = 0, argIdx = 1;
+  for (auto info : toCollapse.drop_back()) {
+    LevelSet set = info.loop.getCrdUsedLvls();
+    crdUsedLvls |= set.lshift(shift);
+    shift += info.loop.getSpaceDim();
+    for (BlockArgument crd : info.loop.getCrds()) {
+      BlockArgument collapsedCrd = cloned.getBody()->insertArgument(
+          argIdx++, builder.getIndexType(), crd.getLoc());
+      crd.replaceAllUsesWith(collapsedCrd);
+    }
+  }
+  crdUsedLvls |= innermost.getCrdUsedLvls().lshift(shift);
+  cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
+  cloned.setCrdUsedLvls(crdUsedLvls);
+
+  rItOp.replaceAllUsesWith(cloned.getResults());
+  // Erase collapsed loops (the whole nest rooted at the root loop).
+  rItOp.erase();
+  root.erase();
+}
+
+/// Function pass that greedily collapses chains of nested sparse iteration
+/// spaces into single multi-dimensional spaces.
+struct SparseSpaceCollapsePass
+    : public impl::SparseSpaceCollapseBase<SparseSpaceCollapsePass> {
+  SparseSpaceCollapsePass() = default;
+
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+
+    // A naive (experimental) implementation to collapse consecutive sparse
+    // spaces. It does NOT handle complex cases where multiple spaces are
+    // extracted in the same basic block. E.g.,
+    //
+    //   %space1 = extract_space %t1 ...
+    //   %space2 = extract_space %t2 ...
+    //   sparse_tensor.iterate(%space1) ...
+    //
+    // Collapsing a chain erases its root loop together with everything
+    // nested in it. Doing that from inside the walk callback would destroy
+    // ancestors of the op being visited, so chains are only recorded during
+    // the walk and rewritten afterwards.
+    SmallVector<SmallVector<CollapseSpaceInfo>> chains;
+    SmallVector<CollapseSpaceInfo> toCollapse;
+    auto closeChain = [&]() {
+      if (toCollapse.size() >= 2)
+        chains.push_back(std::move(toCollapse));
+      toCollapse.clear();
+    };
+
+    func->walk([&](ExtractIterSpaceOp op) {
+      if (legalToCollapse(toCollapse, op))
+        return;
+      // `op` cannot extend the current chain: close that chain and give
+      // `op` a chance to start a new one.
+      closeChain();
+      legalToCollapse(toCollapse, op);
+    });
+    closeChain();
+
+    // Chains are recorded in program order, so no op recorded for an earlier
+    // chain lives inside a later chain's loops, while a later chain may be
+    // nested inside an earlier one. Rewriting in reverse order therefore
+    // never erases an op that a still-pending chain refers to.
+    for (auto &chain : llvm::reverse(chains))
+      collapseSparseSpace(chain);
+  }
+};
+
+} // namespace
+
+/// Creates a fresh instance of the `sparse-space-collapse` pass.
+std::unique_ptr<Pass> mlir::createSparseSpaceCollapsePass() {
+  std::unique_ptr<Pass> pass = std::make_unique<SparseSpaceCollapsePass>();
+  return pass;
+}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
new file mode 100644
index 0000000000000..baa6199f12bc3
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s --sparse-space-collapse | FileCheck %s
+
+#COO = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed(nonunique),
+ j : singleton(soa)
+ )
+}>
+
+// CHECK-LABEL: func.func @sparse_sparse_collapse(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x8xf32, #sparse>,
+// CHECK-SAME: %[[VAL_1:.*]]: index) {
+// CHECK: %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0 to 2 : tensor<4x8xf32, #sparse>
+// CHECK: %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] at(%[[VAL_6:.*]], _) iter_args(%[[VAL_7:.*]] = %[[VAL_1]])
+// CHECK: %[[VAL_8:.*]] = "test.op"(%[[VAL_7]]) : (index) -> index
+// CHECK: sparse_tensor.yield %[[VAL_8]] : index
+// CHECK: }
+// CHECK: "test.sink"(%[[VAL_4]]) : (index) -> ()
+// CHECK: return
+// CHECK: }
+// Two perfectly nested sparse_tensor.iterate loops over levels 0 and 1 of the
+// same COO tensor, threading one index iter_arg through both, are expected to
+// become a single loop over a space with `lvls = 0 to 2`.
+func.func @sparse_sparse_collapse(%sp : tensor<4x8xf32, #COO>, %i : index) {
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
+ : tensor<4x8xf32, #COO>
+ -> !sparse_tensor.iter_space<#COO, lvls = 0>
+ %r1 = sparse_tensor.iterate %it1 in %l1 at(%crd0) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
+ // Inner space is extracted at the outer iterator %it1.
+ %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
+ : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
+ -> !sparse_tensor.iter_space<#COO, lvls = 1>
+ %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
+ %k ="test.op"(%inner) : (index) -> index
+ sparse_tensor.yield %k : index
+ }
+ sparse_tensor.yield %r2 : index
+ }
+ "test.sink"(%r1) : (index) -> ()
+ return
+}
More information about the Mlir-commits
mailing list