[Mlir-commits] [mlir] [mlir][Pass] Move PassExecutionAction to Pass.h, NFC. (PR #74850)
Aman LaChapelle
llvmlistbot at llvm.org
Fri Dec 8 07:28:41 PST 2023
https://github.com/bzcheeseman created https://github.com/llvm/llvm-project/pull/74850
This patch moves PassExecutionAction from the private lib/Pass/PassDetail.h into the public header Pass.h so that users of the action framework can introspect and intervene in pass managers that might be set up opaquely. This enables one very particular use case: intercepting a PassManager to skip or apply individual passes. Because of this, the patch also adds a unit test for this use case to verify that it does in fact work.
>From 541adead8413010fe3276aade3182990ff165fe0 Mon Sep 17 00:00:00 2001
From: bzcheeseman <aman.lachapelle at gmail.com>
Date: Fri, 8 Dec 2023 07:18:04 -0800
Subject: [PATCH] [mlir][Pass] Move PassExecutionAction to Pass.h, NFC.
This patch moves PassExecutionAction from the private lib/Pass/PassDetail.h into the public header Pass.h so that users of the action framework can introspect and intervene in pass managers that might be set up opaquely. This enables one very particular use case: intercepting a PassManager to skip or apply individual passes. Because of this, the patch also adds a unit test for this use case to verify that it does in fact work.
---
mlir/include/mlir/Pass/Pass.h | 47 +++++++++++
mlir/lib/Pass/Pass.cpp | 10 +++
mlir/lib/Pass/PassDetail.h | 20 -----
mlir/unittests/Pass/CMakeLists.txt | 1 +
mlir/unittests/Pass/PassManagerTest.cpp | 100 ++++++++++++++++++++++++
5 files changed, 158 insertions(+), 20 deletions(-)
diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index 5a4df4324ecd1..121b253eb83fe 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -9,6 +9,7 @@
#ifndef MLIR_PASS_PASS_H
#define MLIR_PASS_PASS_H
+#include "mlir/IR/Action.h"
#include "mlir/Pass/AnalysisManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LogicalResult.h"
@@ -457,6 +458,52 @@ class PassWrapper : public BaseT {
}
};
+/// This class encapsulates the "action" of executing a single pass. This allows
+/// a user of the Action infrastructure to query information about an action in
+/// (for example) a breakpoint context. You could use it like this:
+///
+/// auto onBreakpoint = [&](const ActionActiveStack *backtrace) {
+/// if (auto *passExec =
+/// dyn_cast<PassExecutionAction>(&backtrace->getAction()))
+/// record(passExec->getPass());
+/// return ExecutionContext::Apply;
+/// };
+/// ExecutionContext exeCtx(onBreakpoint);
+///
+class PassExecutionAction : public tracing::ActionImpl<PassExecutionAction> {
+ using Base = tracing::ActionImpl<PassExecutionAction>;
+
+public:
+ /// Define a TypeID for this PassExecutionAction.
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PassExecutionAction)
+ /// Construct a PassExecutionAction. This is called by the OpToOpPassAdaptor
+ /// when it calls `executeAction`.
+ PassExecutionAction(ArrayRef<IRUnit> irUnits, const Pass &pass);
+
+ /// The tag required by ActionImpl to identify this action.
+ static constexpr StringLiteral tag = "pass-execution";
+
+ /// Print a textual version of this action to `os`.
+ void print(raw_ostream &os) const override;
+
+ /// Get the pass that will be executed by this action. This is not a class of
+ /// passes, or all instances of a pass kind, this is a single pass.
+ const Pass &getPass() const { return pass; }
+
+ /// Get the operation that is the base of this pass. For example, an
+ /// OperationPass<ModuleOp> would return a ModuleOp. Returns null if the
+ /// first context IR unit is missing or is not an Operation.
+ Operation *getOp() const;
+
+ /// Reference to the pass being run. Notice that this will *not* extend the
+ /// lifetime of the pass, and so this class is therefore unsafe to keep past
+ /// the lifetime of the `executeAction` call.
+ const Pass &pass;
+
+ /// The base op for this pass. Defaults to null so it is never read as an
+ /// uninitialized pointer; `getOp()` is the authoritative accessor (e.g. it
+ /// yields a ModuleOp for an OperationPass<ModuleOp>).
+ Operation *op = nullptr;
+};
+
} // namespace mlir
#endif // MLIR_PASS_PASS_H
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 658f8844b428d..810d6a357d52c 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -36,11 +36,21 @@ using namespace mlir::detail;
// PassExecutionAction
//===----------------------------------------------------------------------===//
+// Initialize the public `op` field from the context IR units so it is never
+// left uninitialized. `Base` (which owns the IR units) is constructed before
+// any member, so calling getOp() here is safe.
+PassExecutionAction::PassExecutionAction(ArrayRef<IRUnit> irUnits,
+ const Pass &pass)
+ : Base(irUnits), pass(pass), op(getOp()) {}
+
void PassExecutionAction::print(raw_ostream &os) const {
os << llvm::formatv("`{0}` running `{1}` on Operation `{2}`", tag,
pass.getName(), getOp()->getName());
}
+// Return the first context IR unit if it is an Operation, or null when there
+// are no IR units or the first one is not an Operation.
+Operation *PassExecutionAction::getOp() const {
+ ArrayRef<IRUnit> irUnits = getContextIRUnits();
+ return irUnits.empty() ? nullptr
+ : llvm::dyn_cast_if_present<Operation *>(irUnits[0]);
+}
+
//===----------------------------------------------------------------------===//
// Pass
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h
index 727607146a68c..0e964b6d6d36b 100644
--- a/mlir/lib/Pass/PassDetail.h
+++ b/mlir/lib/Pass/PassDetail.h
@@ -15,26 +15,6 @@
#include "llvm/Support/FormatVariadic.h"
namespace mlir {
-/// Encapsulate the "action" of executing a single pass, used for the MLIR
-/// tracing infrastructure.
-struct PassExecutionAction : public tracing::ActionImpl<PassExecutionAction> {
- using Base = tracing::ActionImpl<PassExecutionAction>;
- PassExecutionAction(ArrayRef<IRUnit> irUnits, const Pass &pass)
- : Base(irUnits), pass(pass) {}
- static constexpr StringLiteral tag = "pass-execution";
- void print(raw_ostream &os) const override;
- const Pass &getPass() const { return pass; }
- Operation *getOp() const {
- ArrayRef<IRUnit> irUnits = getContextIRUnits();
- return irUnits.empty() ? nullptr
- : llvm::dyn_cast_if_present<Operation *>(irUnits[0]);
- }
-
-public:
- const Pass &pass;
- Operation *op;
-};
-
namespace detail {
//===----------------------------------------------------------------------===//
diff --git a/mlir/unittests/Pass/CMakeLists.txt b/mlir/unittests/Pass/CMakeLists.txt
index 65f0774123865..802b3bbc6c635 100644
--- a/mlir/unittests/Pass/CMakeLists.txt
+++ b/mlir/unittests/Pass/CMakeLists.txt
@@ -5,5 +5,6 @@ add_mlir_unittest(MLIRPassTests
)
target_link_libraries(MLIRPassTests
PRIVATE
+ MLIRDebug
MLIRFuncDialect
MLIRPass)
diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp
index 9a30f64eaabc2..7ceed3bb3bc3b 100644
--- a/mlir/unittests/Pass/PassManagerTest.cpp
+++ b/mlir/unittests/Pass/PassManagerTest.cpp
@@ -7,6 +7,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Pass/PassManager.h"
+#include "mlir/Debug/BreakpointManagers/TagBreakpointManager.h"
+#include "mlir/Debug/ExecutionContext.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
@@ -86,6 +88,104 @@ TEST(PassManagerTest, OpSpecificAnalysis) {
}
}
+/// Simple pass to annotate a func::FuncOp with a single attribute `didProcess`.
+/// If the function already carries `didProcess`, it additionally sets
+/// `didProcessAgain`, so a second run of this pass kind is detectable.
+struct AddAttrFunctionPass
+ : public PassWrapper<AddAttrFunctionPass, OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddAttrFunctionPass)
+
+ void runOnOperation() override {
+ func::FuncOp op = getOperation();
+ // Take the context from the anchor op itself: a nested pass should not
+ // need to look at its parent, and the function may not live in a module.
+ Builder builder(op.getContext());
+ if (op->hasAttr("didProcess"))
+ op->setAttr("didProcessAgain", builder.getUnitAttr());
+
+ // We always want to set this one.
+ op->setAttr("didProcess", builder.getUnitAttr());
+ }
+};
+
+/// Simple pass to annotate a func::FuncOp with a single attribute
+/// `didProcess2`.
+struct AddSecondAttrFunctionPass
+ : public PassWrapper<AddSecondAttrFunctionPass,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddSecondAttrFunctionPass)
+
+ void runOnOperation() override {
+ func::FuncOp op = getOperation();
+ // Take the context from the anchor op itself rather than from a parent
+ // ModuleOp, which may not exist and which a nested pass need not inspect.
+ Builder builder(op.getContext());
+ op->setAttr("didProcess2", builder.getUnitAttr());
+ }
+};
+
+TEST(PassManagerTest, ExecutionAction) {
+ MLIRContext context;
+ context.loadDialect<func::FuncDialect>();
+ Builder builder(&context);
+
+ // Create a module with a single private function.
+ OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
+ auto f =
+ func::FuncOp::create(builder.getUnknownLoc(), "process_me_once",
+ builder.getFunctionType(std::nullopt, std::nullopt));
+ f.setPrivate();
+ module->push_back(f);
+
+ // Instantiate our passes.
+ auto pm = PassManager::on<ModuleOp>(&context);
+ auto pass = std::make_unique<AddAttrFunctionPass>();
+ auto *passPtr = pass.get();
+ pm.addNestedPass<func::FuncOp>(std::move(pass));
+ pm.addNestedPass<func::FuncOp>(std::make_unique<AddSecondAttrFunctionPass>());
+ // Duplicate the first pass to ensure that we *only* run the *first* pass, not
+ // all instances of this pass kind. Notice that this pass (and the test as a
+ // whole) are built to ensure that we can run just a single pass out of a
+ // pipeline that may contain duplicates.
+ pm.addNestedPass<func::FuncOp>(std::make_unique<AddAttrFunctionPass>());
+
+ // Use the action handler to apply only the first AddAttrFunctionPass and to
+ // skip every other function pass in the pipeline.
+ auto onBreakpoint = [&](const tracing::ActionActiveStack *backtrace)
+ -> tracing::ExecutionContext::Control {
+ // Not a PassExecutionAction, apply the action.
+ auto *passExec = dyn_cast<PassExecutionAction>(&backtrace->getAction());
+ if (!passExec)
+ return tracing::ExecutionContext::Next;
+
+ // If this isn't a function, apply the action. getOp() may return null,
+ // and plain isa<> asserts on a null pointer.
+ if (!llvm::isa_and_nonnull<func::FuncOp>(passExec->getOp()))
+ return tracing::ExecutionContext::Next;
+
+ // Only apply the first function pass. Not all instances of the first pass,
+ // only the first pass. Compare via getThreadingSiblingOrThis() so that a
+ // per-thread copy of `passPtr` is still recognized as the same pass.
+ if (passExec->getPass().getThreadingSiblingOrThis() == passPtr)
+ return tracing::ExecutionContext::Next;
+
+ // Do not apply any other passes in the pass manager.
+ return tracing::ExecutionContext::Skip;
+ };
+
+ // Set up our breakpoint manager.
+ tracing::TagBreakpointManager simpleManager;
+ tracing::ExecutionContext executionCtx(onBreakpoint);
+ executionCtx.addBreakpointManager(&simpleManager);
+ simpleManager.addBreakpoint(PassExecutionAction::tag);
+
+ // Register the execution context in the MLIRContext.
+ context.registerActionHandler(executionCtx);
+
+ // Run the pass manager, expecting our handler to be called.
+ LogicalResult result = pm.run(module.get());
+ EXPECT_TRUE(succeeded(result));
+
+ // Verify that the function was processed by the first pass only: it carries
+ // `didProcess`, but neither `didProcess2` (second pass skipped) nor
+ // `didProcessAgain` (duplicate third pass skipped). Count the functions so
+ // the loop cannot pass vacuously.
+ int numFuncs = 0;
+ for (func::FuncOp func : module->getOps<func::FuncOp>()) {
+ ++numFuncs;
+ ASSERT_TRUE(func->getDiscardableAttr("didProcess"));
+ ASSERT_FALSE(func->getDiscardableAttr("didProcess2"));
+ ASSERT_FALSE(func->getDiscardableAttr("didProcessAgain"));
+ }
+ EXPECT_EQ(numFuncs, 1);
+}
+
namespace {
struct InvalidPass : Pass {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InvalidPass)
More information about the Mlir-commits
mailing list