[Mlir-commits] [mlir] [mlir][linalg] Enable fuse consumer (PR #89893)
donald chen
llvmlistbot at llvm.org
Thu Apr 25 07:01:42 PDT 2024
================
@@ -0,0 +1,294 @@
+//===- TestLinalgFuseConsumer.cpp - Test Linalg fuse consumer ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing consumer fusion of Linalg ops.
+// This is a temporary pass used to verify the correctness of the tiling
+// interface of Linalg ops and the related consumer-fusion interface. It should
+// be replaced by the corresponding fusion transform op once that op is
+// implemented.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Region.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "fuse-consumer"
+
+namespace {
+struct TestLinalgFuseConsumer
+ : public PassWrapper<TestLinalgFuseConsumer, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFuseConsumer)
+
+ TestLinalgFuseConsumer() = default;
+ TestLinalgFuseConsumer(const TestLinalgFuseConsumer &pass)
+ : PassWrapper(pass){};
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<scf::SCFDialect, linalg::LinalgDialect,
+ tensor::TensorDialect>();
+ }
+ StringRef getArgument() const final { return "test-linalg-fuse-consumer"; }
+ StringRef getDescription() const final {
+ return "Test Linalg fuse consumer interface";
+ }
+
+ void runOnOperation() override {
+ Operation *consumerOp = nullptr, *containingOp = nullptr;
+ auto walkRes = getOperation()->walk([&](Operation *op) {
+ if (op->hasAttr("consumer")) {
+ if (consumerOp) {
+ return WalkResult::interrupt();
+ }
+ consumerOp = op;
+ }
+ if (op->hasAttr("containing")) {
+ if (containingOp) {
+ return WalkResult::interrupt();
+ }
+ containingOp = op;
+ }
+ return WalkResult::advance();
+ });
+
+ if (!consumerOp || !containingOp || walkRes.wasInterrupted()) {
+ emitError(getOperation()->getLoc())
----------------
cxy-1993 wrote:
Good idea, done.
https://github.com/llvm/llvm-project/pull/89893
More information about the Mlir-commits
mailing list