[Mlir-commits] [mlir] Enable pass instrumentation to signal failures. (PR #163126)
Jacques Pienaar
llvmlistbot at llvm.org
Wed Dec 10 20:51:36 PST 2025
https://github.com/jpienaar updated https://github.com/llvm/llvm-project/pull/163126
>From 987b607e9e416cac817f01f24a8cf1ef522cc5a7 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jacques+gh at japienaar.info>
Date: Mon, 13 Oct 2025 04:58:42 +0000
Subject: [PATCH 1/2] Enable pass instrumentation to signal failures.
---
mlir/include/mlir/Pass/Pass.h | 4 +
mlir/include/mlir/Pass/PassInstrumentation.h | 2 +
mlir/lib/Pass/Pass.cpp | 33 ++++---
mlir/unittests/Pass/PassManagerTest.cpp | 98 ++++++++++++++++++++
4 files changed, 124 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index 448a688243491..4e005e1ed5bf6 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -17,6 +17,7 @@
#include <optional>
namespace mlir {
+class PassInstrumentation;
namespace detail {
class OpToOpPassAdaptor;
struct OpPassManagerImpl;
@@ -341,6 +342,9 @@ class Pass {
/// Allow access to 'passOptions'.
friend class PassInfo;
+
+ /// Allow access to 'signalPassFailure'.
+ friend class PassInstrumentation;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Pass/PassInstrumentation.h b/mlir/include/mlir/Pass/PassInstrumentation.h
index 917bac4b22288..25a8e77be75ee 100644
--- a/mlir/include/mlir/Pass/PassInstrumentation.h
+++ b/mlir/include/mlir/Pass/PassInstrumentation.h
@@ -80,6 +80,8 @@ class PassInstrumentation {
/// name of the analysis that was computed, its TypeID, as well as the
/// current operation being analyzed.
virtual void runAfterAnalysis(StringRef name, TypeID id, Operation *op) {}
+
+ static void signalPassFailure(Pass *pass);
};
/// This class holds a collection of PassInstrumentation objects, and invokes
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 75f882606e0ab..f947fbc36f8c1 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -599,17 +599,20 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
if (pi)
pi->runBeforePass(pass, op);
- bool passFailed = false;
- op->getContext()->executeAction<PassExecutionAction>(
- [&]() {
- // Invoke the virtual runOnOperation method.
- if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
- adaptor->runOnOperation(verifyPasses);
- else
- pass->runOnOperation();
- passFailed = pass->passState->irAndPassFailed.getInt();
- },
- {op}, *pass);
+ bool passFailed = pass->passState->irAndPassFailed.getInt();
+ if (!passFailed) {
+ op->getContext()->executeAction<PassExecutionAction>(
+ [&]() {
+ // Invoke the virtual runOnOperation method.
+ if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
+ adaptor->runOnOperation(verifyPasses);
+ else
+ pass->runOnOperation();
+ passFailed = pass->passState->irAndPassFailed.getInt();
+ },
+ {op}, *pass);
+ }
+
// Invalidate any non preserved analyses.
am.invalidate(pass->passState->preservedAnalyses);
@@ -640,10 +643,12 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
// Instrument after the pass has run.
if (pi) {
- if (passFailed)
+ if (passFailed) {
pi->runAfterPassFailed(pass, op);
- else
+ } else {
pi->runAfterPass(pass, op);
+ passFailed = passFailed || pass->passState->irAndPassFailed.getInt();
+ }
}
// Return if the pass signaled a failure.
@@ -1198,6 +1203,8 @@ void PassInstrumentation::runBeforePipeline(
void PassInstrumentation::runAfterPipeline(
std::optional<OperationName> name, const PipelineParentInfo &parentInfo) {}
+/// Forwards to the non-public Pass::signalPassFailure; this is permitted
+/// because PassInstrumentation is declared a friend of Pass.
+void PassInstrumentation::signalPassFailure(Pass *pass) {
+  pass->signalPassFailure();
+}
+
//===----------------------------------------------------------------------===//
// PassInstrumentor
//===----------------------------------------------------------------------===//
diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp
index 7e618811eabf4..86c793384db11 100644
--- a/mlir/unittests/Pass/PassManagerTest.cpp
+++ b/mlir/unittests/Pass/PassManagerTest.cpp
@@ -14,6 +14,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassInstrumentation.h"
#include "gtest/gtest.h"
#include <memory>
@@ -117,6 +118,103 @@ struct AddSecondAttrFunctionPass
}
};
+/// PassInstrumentation that counts the callbacks it receives for
+/// AddAttrFunctionPass and can optionally signal a failure for that pass from
+/// runBeforePass and/or runAfterPass. Callbacks for other passes are ignored.
+struct TestPassInstrumentation : public PassInstrumentation {
+  int beforePassCallbackCount = 0;
+  int afterPassCallbackCount = 0;
+  int afterPassFailedCallbackCount = 0;
+
+  /// When set, signal a failure for the tracked pass in runBeforePass.
+  bool failBeforePass = false;
+  /// When set, signal a failure for the tracked pass in runAfterPass.
+  bool failAfterPass = false;
+
+  void runBeforePass(Pass *pass, Operation *op) override {
+    if (!isTrackedPass(pass))
+      return;
+    ++beforePassCallbackCount;
+    if (failBeforePass)
+      signalPassFailure(pass);
+  }
+
+  void runAfterPass(Pass *pass, Operation *op) override {
+    if (!isTrackedPass(pass))
+      return;
+    ++afterPassCallbackCount;
+    if (failAfterPass)
+      signalPassFailure(pass);
+  }
+
+  void runAfterPassFailed(Pass *pass, Operation *op) override {
+    if (isTrackedPass(pass))
+      ++afterPassFailedCallbackCount;
+  }
+
+private:
+  /// Returns true if `pass` is the one pass this instrumentation observes.
+  static bool isTrackedPass(Pass *pass) {
+    return pass->getTypeID() == TypeID::get<AddAttrFunctionPass>();
+  }
+};
+
+TEST(PassManagerTest, PassInstrumentation) {
+  MLIRContext context;
+  context.loadDialect<func::FuncDialect>();
+  Builder b(&context);
+
+  // Create a module with 1 function.
+  OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
+  auto func = func::FuncOp::create(b.getUnknownLoc(), "test_func",
+                                   b.getFunctionType({}, {}));
+  func.setPrivate();
+  module->push_back(func);
+
+  /// Snapshot of the callback counters recorded by TestPassInstrumentation.
+  struct InstrumentationCounts {
+    int beforePass = 0;
+    int afterPass = 0;
+    int afterPassFailed = 0;
+  };
+
+  // Runs AddAttrFunctionPass nested on the function under a fresh
+  // TestPassInstrumentation configured with the given failure flags, and
+  // returns the run result together with the recorded callback counts.
+  auto runInstrumentation =
+      [&](bool failBefore,
+          bool failAfter) -> std::pair<LogicalResult, InstrumentationCounts> {
+    auto pm = PassManager::on<ModuleOp>(&context);
+    auto instrumentation = std::make_unique<TestPassInstrumentation>();
+    // Ownership moves into the pass manager; keep a raw pointer so the
+    // counters can be read before `pm` goes out of scope.
+    auto *instrumentationPtr = instrumentation.get();
+    instrumentation->failBeforePass = failBefore;
+    instrumentation->failAfterPass = failAfter;
+    pm.addInstrumentation(std::move(instrumentation));
+    pm.addNestedPass<func::FuncOp>(std::make_unique<AddAttrFunctionPass>());
+    LogicalResult result = pm.run(module.get());
+
+    // Positional aggregate initialization: designated initializers are a
+    // C++20 feature, so keep this C++17-compatible.
+    InstrumentationCounts counts{
+        /*beforePass=*/instrumentationPtr->beforePassCallbackCount,
+        /*afterPass=*/instrumentationPtr->afterPassCallbackCount,
+        /*afterPassFailed=*/instrumentationPtr->afterPassFailedCallbackCount};
+    return {result, counts};
+  };
+
+  for (bool failBefore : {false, true}) {
+    for (bool failAfter : {false, true}) {
+      SCOPED_TRACE(::testing::Message() << "failBefore=" << failBefore
+                                        << ", failAfter=" << failAfter);
+      auto [result, counts] = runInstrumentation(failBefore, failAfter);
+
+      // The run fails if either callback signals a failure.
+      EXPECT_EQ(failed(result), failBefore || failAfter);
+
+      // runBeforePass always fires. A failure signaled there skips the pass,
+      // which is then reported via runAfterPassFailed instead of
+      // runAfterPass. A failure signaled in runAfterPass is raised only after
+      // runAfterPass has already been invoked.
+      InstrumentationCounts expected{/*beforePass=*/1, /*afterPass=*/1,
+                                     /*afterPassFailed=*/0};
+      if (failBefore)
+        expected = {/*beforePass=*/1, /*afterPass=*/0, /*afterPassFailed=*/1};
+
+      EXPECT_EQ(counts.beforePass, expected.beforePass);
+      EXPECT_EQ(counts.afterPass, expected.afterPass);
+      EXPECT_EQ(counts.afterPassFailed, expected.afterPassFailed);
+    }
+  }
+}
+
TEST(PassManagerTest, ExecutionAction) {
MLIRContext context;
context.loadDialect<func::FuncDialect>();
>From b754842be85055bf88c09061be152d68f2896cf7 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jacques+gh at japienaar.info>
Date: Thu, 11 Dec 2025 04:51:00 +0000
Subject: [PATCH 2/2] Address comments
---
mlir/include/mlir/Pass/PassInstrumentation.h | 4 +++-
mlir/lib/Pass/Pass.cpp | 2 ++
2 files changed, 5 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Pass/PassInstrumentation.h b/mlir/include/mlir/Pass/PassInstrumentation.h
index 25a8e77be75ee..4ceff9d657aa4 100644
--- a/mlir/include/mlir/Pass/PassInstrumentation.h
+++ b/mlir/include/mlir/Pass/PassInstrumentation.h
@@ -81,7 +81,9 @@ class PassInstrumentation {
/// current operation being analyzed.
virtual void runAfterAnalysis(StringRef name, TypeID id, Operation *op) {}
- static void signalPassFailure(Pass *pass);
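+ /// Helper method to enable instrumentation to signal pass failure. Used, for
+ /// example, when pre- or post-conditions of the pass fail. Signaling failure
+ /// from runBeforePass causes the pass itself to be skipped.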
+ void signalPassFailure(Pass *pass);
};
/// This class holds a collection of PassInstrumentation objects, and invokes
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index f947fbc36f8c1..e28036ff1c1f0 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -599,6 +599,8 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
if (pi)
pi->runBeforePass(pass, op);
+ // Pass instrumentation can use pass failure to flag unmet invariants
+ // (preconditions) of the pass. Skip running the pass if it is already in a
+ // failure state.
bool passFailed = pass->passState->irAndPassFailed.getInt();
if (!passFailed) {
op->getContext()->executeAction<PassExecutionAction>(
More information about the Mlir-commits mailing list