[Mlir-commits] [mlir] c6d85ba - [mlir][sparse] implement sparse space collapse pass. (#89003)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jun 11 12:10:58 PDT 2024


Author: Peiming Liu
Date: 2024-06-11T12:10:54-07:00
New Revision: c6d85baf9f12f69915559aff5ed6c48b63daafdd

URL: https://github.com/llvm/llvm-project/commit/c6d85baf9f12f69915559aff5ed6c48b63daafdd
DIFF: https://github.com/llvm/llvm-project/commit/c6d85baf9f12f69915559aff5ed6c48b63daafdd.diff

LOG: [mlir][sparse] implement sparse space collapse pass. (#89003)

Added: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
    mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
    mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index d6d038ef65bdf..3043a0c4dc410 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -248,6 +248,12 @@ std::unique_ptr<Pass> createSparsificationAndBufferizationPass(
     bool enableBufferInitialization, unsigned vectorLength,
     bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen);
 
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<Pass> createSparseSpaceCollapsePass();
+
 //===----------------------------------------------------------------------===//
 // Registration.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 2f844cee5ff52..c6554e1c94a4a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -464,4 +464,20 @@ def SparsificationAndBufferization : Pass<"sparsification-and-bufferization", "M
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
+  let summary = "sparse space collapsing pass";
+  let description = [{
+     This pass collapses consecutive sparse spaces (extracted from the same tensor)
+     into one multi-dimensional space. The pass is not yet stabilized.
+  }];
+  let constructor = "mlir::createSparseSpaceCollapsePass()";
+  let dependentDialects = [
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index af3a1b48f45af..2a29ee8a7a87c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   SparseGPUCodegen.cpp
   SparseReinterpretMap.cpp
   SparseStorageSpecifierToLLVM.cpp
+  SparseSpaceCollapse.cpp
   SparseTensorCodegen.cpp
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
new file mode 100644
index 0000000000000..924046fcd9961
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp
@@ -0,0 +1,199 @@
+//===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SPARSESPACECOLLAPSE
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "sparse-space-collapse"
+
+using namespace mlir;
+using namespace sparse_tensor;
+
+namespace {
+
+// A (space, loop) pair recorded while walking the function: the extracted
+// iteration space together with the single IterateOp that consumes it.
+struct CollapseSpaceInfo {
+  ExtractIterSpaceOp space;
+  IterateOp loop;
+};
+
+// Returns true when `node` can be merged into `parent`, i.e., the two loops
+// are perfectly nested: they carry the same number of iter_args, the parent
+// yields exactly the node's results, and the parent's region iter_args are
+// passed unchanged as the node's init_args.
+bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) {
+  auto pIterArgs = parent.getRegionIterArgs();
+  auto nInitArgs = node.getInits();
+  if (pIterArgs.size() != nInitArgs.size())
+    return false;
+
+  // Two loops are collapsable if they are perfectly nested.
+  // NOTE(review): getLoopResults() is unwrapped with .value() — assumes every
+  // candidate loop implements loop results; confirm callers guarantee this.
+  auto pYields = parent.getYieldedValues();
+  auto nResult = node.getLoopResults().value();
+
+  // The parent must yield the node's results, element for element.
+  bool yieldEq =
+      llvm::all_of(llvm::zip_equal(pYields, nResult), [](auto zipped) {
+        return std::get<0>(zipped) == std::get<1>(zipped);
+      });
+
+  // Parent iter_args should be passed directly to the node's init_args.
+  bool iterArgEq =
+      llvm::all_of(llvm::zip_equal(pIterArgs, nInitArgs), [](auto zipped) {
+        return std::get<0>(zipped) == std::get<1>(zipped);
+      });
+
+  return yieldEq && iterArgEq;
+}
+
+// Checks whether `curSpace` can legally extend the chain of spaces already
+// collected in `toCollapse`. On success the space (and the unique IterateOp
+// consuming it) is appended to `toCollapse` and true is returned; on failure
+// `toCollapse` is left untouched and false is returned.
+bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
+                     ExtractIterSpaceOp curSpace) {
+
+  // A space is only considered when it has exactly one user and that user is
+  // an IterateOp (i.e., a simple, non-coiterating loop).
+  auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp {
+    Value spaceVal = space.getExtractedSpace();
+    if (spaceVal.hasOneUse())
+      return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin());
+    return nullptr;
+  };
+
+  if (toCollapse.empty()) {
+    // Collapse root: any space with a unique IterateOp user starts a chain.
+    if (auto itOp = getIterateOpOverSpace(curSpace)) {
+      CollapseSpaceInfo &info = toCollapse.emplace_back();
+      info.space = curSpace;
+      info.loop = itOp;
+      return true;
+    }
+    return false;
+  }
+
+  auto parent = toCollapse.back().space;
+  auto pItOp = toCollapse.back().loop;
+  auto nItOp = getIterateOpOverSpace(curSpace);
+
+  // Can only collapse spaces extracted from the same tensor.
+  if (parent.getTensor() != curSpace.getTensor()) {
+    LLVM_DEBUG({
+      llvm::dbgs()
+          << "failed to collapse spaces extracted from different tensors.";
+    });
+    return false;
+  }
+
+  // Can only collapse consecutive simple iteration on one tensor (i.e., no
+  // coiteration): the new space must live directly inside the parent loop and
+  // be driven by the parent loop's iterator.
+  if (!nItOp || nItOp->getBlock() != curSpace->getBlock() ||
+      pItOp.getIterator() != curSpace.getParentIter() ||
+      curSpace->getParentOp() != pItOp.getOperation()) {
+    LLVM_DEBUG(
+        { llvm::dbgs() << "failed to collapse non-consecutive IterateOps."; });
+    return false;
+  }
+
+  // The two loops must additionally be perfectly nested.
+  if (pItOp && !isCollapsableLoops(pItOp, nItOp)) {
+    LLVM_DEBUG({
+      llvm::dbgs()
+          << "failed to collapse IterateOps that are not perfectly nested.";
+    });
+    return false;
+  }
+
+  CollapseSpaceInfo &info = toCollapse.emplace_back();
+  info.space = curSpace;
+  info.loop = nItOp;
+  return true;
+}
+
+// Rewrites the chain of perfectly-nested loops recorded in `toCollapse` into
+// a single IterateOp over one multi-dimensional iteration space. A no-op when
+// fewer than two spaces were collected.
+void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) {
+  if (toCollapse.size() < 2)
+    return;
+
+  ExtractIterSpaceOp root = toCollapse.front().space;
+  ExtractIterSpaceOp leaf = toCollapse.back().space;
+  Location loc = root.getLoc();
+
+  // legalToCollapse only admits spaces with a single IterateOp user.
+  assert(root->hasOneUse() && leaf->hasOneUse());
+
+  // Insert collapsed operation at the same scope as root operation.
+  OpBuilder builder(root);
+
+  // Construct the collapsed iteration space: spans from the root's low level
+  // to the leaf's high level on the shared tensor.
+  auto collapsedSpace = builder.create<ExtractIterSpaceOp>(
+      loc, root.getTensor(), root.getParentIter(), root.getLoLvl(),
+      leaf.getHiLvl());
+
+  auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
+  auto innermost = toCollapse.back().loop;
+
+  // Clone the innermost loop body over the collapsed space, redirecting its
+  // space operand and rewiring its init_args to the outermost loop's.
+  IRMapping mapper;
+  mapper.map(leaf, collapsedSpace.getExtractedSpace());
+  for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs()))
+    mapper.map(std::get<0>(z), std::get<1>(z));
+
+  auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper));
+  builder.setInsertionPointToStart(cloned.getBody());
+
+  // Merge the per-loop coordinate-usage bitsets, shifting each loop's set by
+  // the number of space dimensions of the loops above it, and re-insert the
+  // outer loops' coordinate block arguments into the cloned body.
+  LevelSet crdUsedLvls;
+  unsigned shift = 0, argIdx = 1;
+  for (auto info : toCollapse.drop_back()) {
+    LevelSet set = info.loop.getCrdUsedLvls();
+    crdUsedLvls |= set.lshift(shift);
+    shift += info.loop.getSpaceDim();
+    for (BlockArgument crd : info.loop.getCrds()) {
+      BlockArgument collapsedCrd = cloned.getBody()->insertArgument(
+          argIdx++, builder.getIndexType(), crd.getLoc());
+      crd.replaceAllUsesWith(collapsedCrd);
+    }
+  }
+  crdUsedLvls |= innermost.getCrdUsedLvls().lshift(shift);
+  cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
+  cloned.setCrdUsedLvls(crdUsedLvls);
+
+  rItOp.replaceAllUsesWith(cloned.getResults());
+  // Erase collapsed loops.
+  rItOp.erase();
+  root.erase();
+}
+
+// Pass driver: walks each ExtractIterSpaceOp in the function, greedily grows
+// a chain of collapsible spaces, and collapses the chain whenever it can no
+// longer be extended.
+struct SparseSpaceCollapsePass
+    : public impl::SparseSpaceCollapseBase<SparseSpaceCollapsePass> {
+  SparseSpaceCollapsePass() = default;
+
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+
+    // A naive (experimental) implementation to collapse consecutive sparse
+    // spaces. It does NOT handle complex cases where multiple spaces are
+    // extracted in the same basic block. E.g.,
+    //
+    // %space1 = extract_space %t1 ...
+    // %space2 = extract_space %t2 ...
+    // sparse_tensor.iterate(%space1) ...
+    //
+    SmallVector<CollapseSpaceInfo> toCollapse;
+    func->walk([&](ExtractIterSpaceOp op) {
+      if (!legalToCollapse(toCollapse, op)) {
+        // if not legal to collapse one more space, collapse the existing ones
+        // and clear.
+        collapseSparseSpace(toCollapse);
+        toCollapse.clear();
+      }
+    });
+
+    // Collapse whatever chain remains at the end of the walk.
+    collapseSparseSpace(toCollapse);
+  }
+};
+
+} // namespace
+
+// Factory registered via the `constructor` field in Passes.td.
+std::unique_ptr<Pass> mlir::createSparseSpaceCollapsePass() {
+  return std::make_unique<SparseSpaceCollapsePass>();
+}

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
new file mode 100644
index 0000000000000..baa6199f12bc3
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s --sparse-space-collapse | FileCheck %s
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+// CHECK-LABEL:   func.func @sparse_sparse_collapse(
+// CHECK-SAME:         %[[VAL_0:.*]]: tensor<4x8xf32, #sparse>,
+// CHECK-SAME:         %[[VAL_1:.*]]: index) {
+// CHECK:           %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0 to 2 : tensor<4x8xf32, #sparse>
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] at(%[[VAL_6:.*]], _) iter_args(%[[VAL_7:.*]] = %[[VAL_1]])
+// CHECK:             %[[VAL_8:.*]] = "test.op"(%[[VAL_7]]) : (index) -> index
+// CHECK:             sparse_tensor.yield %[[VAL_8]] : index
+// CHECK:           }
+// CHECK:           "test.sink"(%[[VAL_4]]) : (index) -> ()
+// CHECK:           return
+// CHECK:         }
+// Two perfectly nested iterate loops over the same COO tensor; the pass is
+// expected to collapse them into a single loop over lvls = 0 to 2 (see the
+// CHECK lines above).
+func.func @sparse_sparse_collapse(%sp : tensor<4x8xf32, #COO>, %i : index) {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
+      : tensor<4x8xf32, #COO>
+     -> !sparse_tensor.iter_space<#COO, lvls = 0>
+  %r1 = sparse_tensor.iterate %it1 in %l1 at(%crd0) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
+    %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
+        : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
+       -> !sparse_tensor.iter_space<#COO, lvls = 1>
+    %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
+      %k ="test.op"(%inner) : (index) -> index
+      sparse_tensor.yield %k : index
+    }
+    sparse_tensor.yield %r2 : index
+  }
+  "test.sink"(%r1) : (index) -> ()
+  return
+}


        


More information about the Mlir-commits mailing list