[Mlir-commits] [mlir] [mlir][linalg] Enable fuse consumer (PR #89893)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Thu Apr 25 01:19:39 PDT 2024
================
@@ -0,0 +1,294 @@
+//===- TestLinalgFuseConsumer.cpp - Test Linalg fuse consumer ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing consumer fusion of Linalg ops.
+// It is a temporary pass used to verify the correctness of the tiling
+// interface implemented by Linalg ops and of the related consumer-fusion
+// interface. It should be replaced by the corresponding fusion transform op
+// once that op is completed.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Region.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "fuse-consumer"
+
+namespace {
+struct TestLinalgFuseConsumer
+ : public PassWrapper<TestLinalgFuseConsumer, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFuseConsumer)
+
+ TestLinalgFuseConsumer() = default;
+ /// Copy constructor; forwards the copy to the PassWrapper base.
+ /// NOTE(review): presumably required so the pass can be cloned - confirm
+ /// against PassWrapper.
+ TestLinalgFuseConsumer(const TestLinalgFuseConsumer &pass)
+ : PassWrapper(pass) {}
+ /// Registers the SCF, Linalg and Tensor dialects as dependencies of this
+ /// pass.
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<scf::SCFDialect, linalg::LinalgDialect,
+ tensor::TensorDialect>();
+ }
+ /// Returns the pass argument string, "test-linalg-fuse-consumer".
+ StringRef getArgument() const final {
+ return "test-linalg-fuse-consumer";
+ }
+ /// Returns a short human-readable description of this test pass.
+ StringRef getDescription() const final {
+ StringRef description = "Test Linalg fuse consumer interface";
+ return description;
+ }
+
+ void runOnOperation() override {
+ Operation *consumerOp = nullptr, *containingOp = nullptr;
+ auto walkRes = getOperation()->walk([&](Operation *op) {
+ if (op->hasAttr("consumer")) {
+ if (consumerOp) {
+ return WalkResult::interrupt();
+ }
+ consumerOp = op;
+ }
+ if (op->hasAttr("containing")) {
+ if (containingOp) {
+ return WalkResult::interrupt();
+ }
+ containingOp = op;
+ }
+ return WalkResult::advance();
+ });
+
+ if (!consumerOp || !containingOp || walkRes.wasInterrupted()) {
+ emitError(getOperation()->getLoc())
+ << "expect 1 consumer and 1 containing op.";
+ return;
+ }
+
+ // Check consumer has tiling interface.
+ auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+ if (!tileableConsumer) {
+ emitError(consumerOp->getLoc())
+ << "consumer is not a TileableInterface: " << *consumerOp;
+ return;
+ }
+
+ // Check containing op is "scf::ForallOp".
+ auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+ if (!forallOp) {
+ emitError(containingOp->getLoc())
+ << "containing op is not a scf.forall: " << containingOp;
+ return;
+ }
+
+ // Check dominance.
+ DominanceInfo domInfo(getOperation());
+ if (llvm::any_of(consumerOp->getOperands(), [&](Value v) {
+ return v.getDefiningOp() != containingOp &&
+ !domInfo.properlyDominates(v, containingOp);
+ })) {
+ emitError(consumerOp->getLoc())
+ << "consumer's operand can't dominate containing op";
+ return;
+ }
+
+ // Check consumer don't use more than one result of containingOp.
+ Value bridge(nullptr);
+ SmallVector<unsigned> operandNums;
+ for (auto [idx, opd] : llvm::enumerate((consumerOp->getOperands()))) {
+ if (opd.getDefiningOp() == containingOp) {
+ operandNums.push_back(idx);
+ if (!bridge) {
+ bridge = opd;
+ } else if (bridge != opd) {
+ emitError(consumerOp->getLoc())
+ << "consumer's operand use more than one containingOp's result";
+ return;
+ }
+ }
+ }
+
+ // Check consumer has DestinationStyleOpInterface.
+ auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+ if (!dstOp) {
+ emitError(consumerOp->getLoc())
+ << "consumer op should have destination style op interface";
+ return;
+ }
+
+ // Check consumer doon't use scf.forall's output as init.
+ SmallVector<Value> dpsInits = llvm::to_vector<4>(
+ llvm::map_range(dstOp.getDpsInits(), [](Value v) { return v; }));
+ if (llvm::is_contained(dpsInits, bridge)) {
+ emitError(consumerOp->getLoc())
+ << "consumer op take result of scf.forall as init";
+ return;
+ }
+
+ // Check result was inserted only once.
+ int64_t bridgeResultIdx = cast<OpResult>(bridge).getResultNumber();
+ auto bridgeBlockArg = forallOp.getRegionOutArgs()[bridgeResultIdx];
+ scf::InParallelOp terminatorOp = forallOp.getTerminator();
+
+ tensor::ParallelInsertSliceOp targetInsertOp(nullptr);
+ for (Operation &op : terminatorOp.getRegion().front().getOperations()) {
+ auto parallelInsertSliceOp = cast<tensor::ParallelInsertSliceOp>(op);
+ if (parallelInsertSliceOp.getDest() == bridgeBlockArg) {
+ if (!targetInsertOp) {
+ targetInsertOp = parallelInsertSliceOp;
+ } else {
+ emitError(containingOp->getLoc())
+ << "containingOp's result inserted multi time";
+ return;
+ }
+ }
+ }
+
+ if (!targetInsertOp) {
+ emitError(containingOp->getLoc())
+ << "containingOp's result was not inserted";
+ return;
+ }
+
+ SmallVector<OpFoldResult> offsets = targetInsertOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = targetInsertOp.getMixedSizes();
+ SmallVector<OpFoldResult> strides = targetInsertOp.getMixedStrides();
+
+ // Check all insert stride is 1.
+ if (llvm::any_of(strides, [](OpFoldResult foldRes) {
+ if (auto attr = foldRes.dyn_cast<Attribute>()) {
+ return cast<IntegerAttr>(attr).getInt() != 1;
+ }
+ return true;
+ })) {
+ emitError(containingOp->getLoc())
+ << "containingOp's result yield with stride";
+ return;
+ }
+
+ IRRewriter rewriter(terminatorOp);
+ Location loc = forallOp.getLoc();
+
+ SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+
+ // Try to get iter domain position from input position.
+ if (failed(tileableConsumer.getIterationDomainTileFromOperandTile(
+ rewriter, operandNums.front(), offsets, sizes, iterDomainOffsets,
+ iterDomainSizes))) {
+ emitError(consumerOp->getLoc())
+ << "can't get iter domain position from input position";
+ return;
+ }
+
+ // Try to get all containing op result's position from iter domain position.
+ llvm::SmallVector<std::pair<llvm::SmallVector<OpFoldResult>,
+ llvm::SmallVector<OpFoldResult>>>
----------------
ftynse wrote:
No need to prefix with `llvm::`.
https://github.com/llvm/llvm-project/pull/89893
More information about the Mlir-commits
mailing list