[Mlir-commits] [mlir] [mlir][sparse] implement sparse space collapse pass. (PR #89003)
Peiming Liu
llvmlistbot at llvm.org
Mon Jun 10 14:18:10 PDT 2024
================
@@ -0,0 +1,183 @@
+//===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
+namespace mlir {
+
+#define GEN_PASS_DEF_SPARSESPACECOLLAPSE
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
+
+namespace sparse_tensor {
+
+struct CollapseSpaceInfo {
+ ExtractIterSpaceOp space;
+ IterateOp loop;
+};
+
+bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) {
+ auto pIterArgs = parent.getRegionIterArgs();
+ auto nInitArgs = node.getInits();
+ if (pIterArgs.size() != nInitArgs.size())
+ return false;
+
+ // Two loops are collapsable if they are perfectly nested.
+ auto pYields = parent.getYieldedValues();
+ auto nResult = node.getLoopResults().value();
+
+ bool yieldEq =
+ llvm::all_of(llvm::zip_equal(pYields, nResult), [](auto zipped) {
+ return std::get<0>(zipped) == std::get<1>(zipped);
+ });
+
+ // Parent iter_args should be passed directly to the node's init_args.
+ bool iterArgEq =
+ llvm::all_of(llvm::zip_equal(pIterArgs, nInitArgs), [](auto zipped) {
+ return std::get<0>(zipped) == std::get<1>(zipped);
+ });
+
+ return yieldEq && iterArgEq;
+}
+
+bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
+ ExtractIterSpaceOp curSpace) {
+
+ auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp {
+ Value spaceVal = space.getResultSpace();
+ if (spaceVal.hasOneUse())
+ return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin());
+ return nullptr;
+ };
+
+ if (toCollapse.empty()) {
+ // Collapse root.
+ if (auto itOp = getIterateOpOverSpace(curSpace)) {
+ CollapseSpaceInfo &info = toCollapse.emplace_back();
+ info.space = curSpace;
+ info.loop = itOp;
+ return true;
+ }
+ return false;
+ }
+
+ auto parent = toCollapse.back().space;
+ auto pItOp = toCollapse.back().loop;
+ auto nItOp = getIterateOpOverSpace(curSpace);
+
+ // Can only collapse spaces extracted from the same tensor.
+ if (parent.getTensor() != curSpace.getTensor())
+ return false;
+
+ // Can only collapse consecutive simple iteration on one tensor (i.e., no
+ // coiteration).
+ if (!nItOp || nItOp->getBlock() != curSpace->getBlock() ||
+ pItOp.getIterator() != curSpace.getParentIter() ||
+ curSpace->getParentOp() != pItOp.getOperation())
+ return false;
+
+ if (pItOp && !isCollapsableLoops(pItOp, nItOp))
+ return false;
+
+ CollapseSpaceInfo &info = toCollapse.emplace_back();
+ info.space = curSpace;
+ info.loop = nItOp;
+ return true;
+}
+
+void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) {
+ if (toCollapse.size() < 2)
+ return;
+
+ ExtractIterSpaceOp root = toCollapse.front().space;
+ ExtractIterSpaceOp leaf = toCollapse.back().space;
+ Location loc = root.getLoc();
+
+ assert(root->hasOneUse() && leaf->hasOneUse());
+
+ // Insert collapsed operation at the same scope as root operation.
+ OpBuilder builder(root);
+
+ // Construct the collapsed iteration space.
+ auto collapsedSpace = builder.create<ExtractIterSpaceOp>(
+ loc, root.getTensor(), root.getParentIter(), root.getLoLvl(),
+ leaf.getHiLvl());
+
+ auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
+ auto innermost = toCollapse.back().loop;
+
+ IRMapping mapper;
+ mapper.map(leaf, collapsedSpace.getResultSpace());
+ for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs()))
+ mapper.map(std::get<0>(z), std::get<1>(z));
+
+ auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper));
----------------
PeimingLiu wrote:
since `innermost` is nested in the outermost loop (`root`), I would assume that it will be erased together with the outermost loop. Am I wrong?
https://github.com/llvm/llvm-project/pull/89003
More information about the Mlir-commits
mailing list