[Mlir-commits] [mlir] Enable pass instrumentation to signal failures. (PR #163126)

Jacques Pienaar llvmlistbot at llvm.org
Wed Dec 10 22:03:34 PST 2025


https://github.com/jpienaar updated https://github.com/llvm/llvm-project/pull/163126

>From 987b607e9e416cac817f01f24a8cf1ef522cc5a7 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jacques+gh at japienaar.info>
Date: Mon, 13 Oct 2025 04:58:42 +0000
Subject: [PATCH 1/3] Enable pass instrumentation to signal failures.

---
 mlir/include/mlir/Pass/Pass.h                |  4 +
 mlir/include/mlir/Pass/PassInstrumentation.h |  2 +
 mlir/lib/Pass/Pass.cpp                       | 33 ++++---
 mlir/unittests/Pass/PassManagerTest.cpp      | 98 ++++++++++++++++++++
 4 files changed, 124 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index 448a688243491..4e005e1ed5bf6 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -17,6 +17,7 @@
 #include <optional>
 
 namespace mlir {
+class PassInstrumentation;
 namespace detail {
 class OpToOpPassAdaptor;
 struct OpPassManagerImpl;
@@ -341,6 +342,9 @@ class Pass {
 
   /// Allow access to 'passOptions'.
   friend class PassInfo;
+
+  /// Allow access to 'signalPassFailure'.
+  friend class PassInstrumentation;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Pass/PassInstrumentation.h b/mlir/include/mlir/Pass/PassInstrumentation.h
index 917bac4b22288..25a8e77be75ee 100644
--- a/mlir/include/mlir/Pass/PassInstrumentation.h
+++ b/mlir/include/mlir/Pass/PassInstrumentation.h
@@ -80,6 +80,8 @@ class PassInstrumentation {
   /// name of the analysis that was computed, its TypeID, as well as the
   /// current operation being analyzed.
   virtual void runAfterAnalysis(StringRef name, TypeID id, Operation *op) {}
+
+  static void signalPassFailure(Pass *pass);
 };
 
 /// This class holds a collection of PassInstrumentation objects, and invokes
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 75f882606e0ab..f947fbc36f8c1 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -599,17 +599,20 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
   if (pi)
     pi->runBeforePass(pass, op);
 
-  bool passFailed = false;
-  op->getContext()->executeAction<PassExecutionAction>(
-      [&]() {
-        // Invoke the virtual runOnOperation method.
-        if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
-          adaptor->runOnOperation(verifyPasses);
-        else
-          pass->runOnOperation();
-        passFailed = pass->passState->irAndPassFailed.getInt();
-      },
-      {op}, *pass);
+  bool passFailed = pass->passState->irAndPassFailed.getInt();
+  if (!passFailed) {
+    op->getContext()->executeAction<PassExecutionAction>(
+        [&]() {
+          // Invoke the virtual runOnOperation method.
+          if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
+            adaptor->runOnOperation(verifyPasses);
+          else
+            pass->runOnOperation();
+          passFailed = pass->passState->irAndPassFailed.getInt();
+        },
+        {op}, *pass);
+  }
+
 
   // Invalidate any non preserved analyses.
   am.invalidate(pass->passState->preservedAnalyses);
@@ -640,10 +643,12 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
 
   // Instrument after the pass has run.
   if (pi) {
-    if (passFailed)
+    if (passFailed) {
       pi->runAfterPassFailed(pass, op);
-    else
+    } else {
       pi->runAfterPass(pass, op);
+      passFailed = passFailed || pass->passState->irAndPassFailed.getInt();
+    }
   }
 
   // Return if the pass signaled a failure.
@@ -1198,6 +1203,11 @@ void PassInstrumentation::runBeforePipeline(
 void PassInstrumentation::runAfterPipeline(
     std::optional<OperationName> name, const PipelineParentInfo &parentInfo) {}

+/// Forwards to Pass::signalPassFailure (accessible as a friend of Pass).
+void PassInstrumentation::signalPassFailure(Pass *pass) {
+  pass->signalPassFailure();
+}
+
 //===----------------------------------------------------------------------===//
 // PassInstrumentor
 //===----------------------------------------------------------------------===//
diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp
index 7e618811eabf4..86c793384db11 100644
--- a/mlir/unittests/Pass/PassManagerTest.cpp
+++ b/mlir/unittests/Pass/PassManagerTest.cpp
@@ -14,6 +14,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassInstrumentation.h"
 #include "gtest/gtest.h"
 
 #include <memory>
@@ -117,6 +118,103 @@ struct AddSecondAttrFunctionPass
   }
 };
 
+/// Counts callbacks seen for AddAttrFunctionPass and can signal its failure.
+struct TestPassInstrumentation : public PassInstrumentation {
+  int beforePassCallbackCount = 0;
+  int afterPassCallbackCount = 0;
+  int afterPassFailedCallbackCount = 0;
+  // If set, signal pass failure from runBeforePass / runAfterPass.
+  bool failBeforePass = false;
+  bool failAfterPass = false;
+
+  void runBeforePass(Pass *pass, Operation *op) override {
+    if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>())
+      return;
+    ++beforePassCallbackCount;
+    if (failBeforePass)
+      signalPassFailure(pass);
+  }
+  void runAfterPass(Pass *pass, Operation *op) override {
+    if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>())
+      return;
+    ++afterPassCallbackCount;
+    if (failAfterPass)
+      signalPassFailure(pass);
+  }
+  void runAfterPassFailed(Pass *pass, Operation *op) override {
+    if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>())
+      return;
+    ++afterPassFailedCallbackCount;
+  }
+};
+
+TEST(PassManagerTest, PassInstrumentation) {
+  MLIRContext context;
+  context.loadDialect<func::FuncDialect>();
+  Builder b(&context);
+  // Run every combination of failing in runBeforePass and/or runAfterPass.
+  // Module with a single private func.func for the nested pass to run on.
+  OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
+  auto func = func::FuncOp::create(b.getUnknownLoc(), "test_func",
+                                   b.getFunctionType({}, {}));
+  func.setPrivate();
+  module->push_back(func);
+  // Number of times each instrumentation callback fired for the pass.
+  struct InstrumentationCounts {
+    int beforePass;
+    int afterPass;
+    int afterPassFailed;
+  };
+  // Runs the pipeline with the given fail flags; returns result and counts.
+  auto runInstrumentation =
+      [&](bool failBefore,
+          bool failAfter) -> std::pair<LogicalResult, InstrumentationCounts> {
+    // Run AddAttrFunctionPass on func.func under a fresh instrumentation.
+    auto pm = PassManager::on<ModuleOp>(&context);
+    auto instrumentation = std::make_unique<TestPassInstrumentation>();
+    auto *instrumentationPtr = instrumentation.get();
+    instrumentation->failBeforePass = failBefore;
+    instrumentation->failAfterPass = failAfter;
+    pm.addInstrumentation(std::move(instrumentation));
+    pm.addNestedPass<func::FuncOp>(std::make_unique<AddAttrFunctionPass>());
+    LogicalResult result = pm.run(module.get());
+
+    InstrumentationCounts counts = {
+        .beforePass = instrumentationPtr->beforePassCallbackCount,
+        .afterPass = instrumentationPtr->afterPassCallbackCount,
+        .afterPassFailed = instrumentationPtr->afterPassFailedCallbackCount};
+    return {result, counts};
+  };
+
+  for (bool failBefore : {false, true}) {
+    for (bool failAfter : {false, true}) {
+      auto [result, counts] = runInstrumentation(failBefore, failAfter);
+      // A pre-pass failure skips the pass; runAfterPassFailed fires instead.
+      InstrumentationCounts expected;
+      if (failBefore) {
+        EXPECT_TRUE(failed(result))
+            << "failBefore=" << failBefore << ", failAfter=" << failAfter;
+        expected = {.beforePass = 1, .afterPass = 0, .afterPassFailed = 1};
+      } else if (failAfter) {
+        EXPECT_TRUE(failed(result))
+            << "failBefore=" << failBefore << ", failAfter=" << failAfter;
+        expected = {.beforePass = 1, .afterPass = 1, .afterPassFailed = 0};
+      } else {
+        EXPECT_TRUE(succeeded(result))
+            << "failBefore=" << failBefore << ", failAfter=" << failAfter;
+        expected = {.beforePass = 1, .afterPass = 1, .afterPassFailed = 0};
+      }
+
+      EXPECT_EQ(counts.beforePass, expected.beforePass)
+          << "failBefore=" << failBefore << ", failAfter=" << failAfter;
+      EXPECT_EQ(counts.afterPass, expected.afterPass)
+          << "failBefore=" << failBefore << ", failAfter=" << failAfter;
+      EXPECT_EQ(counts.afterPassFailed, expected.afterPassFailed)
+          << "failBefore=" << failBefore << ", failAfter=" << failAfter;
+    }
+  }
+}
+
 TEST(PassManagerTest, ExecutionAction) {
   MLIRContext context;
   context.loadDialect<func::FuncDialect>();

>From b754842be85055bf88c09061be152d68f2896cf7 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jacques+gh at japienaar.info>
Date: Thu, 11 Dec 2025 04:51:00 +0000
Subject: [PATCH 2/3] Address comments

---
 mlir/include/mlir/Pass/PassInstrumentation.h | 4 +++-
 mlir/lib/Pass/Pass.cpp                       | 2 ++
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Pass/PassInstrumentation.h b/mlir/include/mlir/Pass/PassInstrumentation.h
index 25a8e77be75ee..4ceff9d657aa4 100644
--- a/mlir/include/mlir/Pass/PassInstrumentation.h
+++ b/mlir/include/mlir/Pass/PassInstrumentation.h
@@ -81,7 +81,9 @@ class PassInstrumentation {
   /// current operation being analyzed.
   virtual void runAfterAnalysis(StringRef name, TypeID id, Operation *op) {}
 
-  static void signalPassFailure(Pass *pass);
+  /// Marks `pass` as failed. From runBeforePass this skips running the pass;
+  /// from runAfterPass it turns a successful run into a failure.
+  void signalPassFailure(Pass *pass);
 };
 
 /// This class holds a collection of PassInstrumentation objects, and invokes
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index f947fbc36f8c1..e28036ff1c1f0 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -599,6 +599,8 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
   if (pi)
     pi->runBeforePass(pass, op);
 
+  // Pass instrumentation may signal pass failure in runBeforePass to flag
+  // unmet invariants (preconditions); if so, skip running the pass.
   bool passFailed = pass->passState->irAndPassFailed.getInt();
   if (!passFailed) {
     op->getContext()->executeAction<PassExecutionAction>(

>From 4bec816cd37e149538520c95249807753d1af08c Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jacques+gh at japienaar.info>
Date: Thu, 11 Dec 2025 06:03:19 +0000
Subject: [PATCH 3/3] Avoid C++20 designated initializers

---
 mlir/unittests/Pass/PassManagerTest.cpp | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp
index 86c793384db11..9a6973d3a2411 100644
--- a/mlir/unittests/Pass/PassManagerTest.cpp
+++ b/mlir/unittests/Pass/PassManagerTest.cpp
@@ -180,9 +180,9 @@ TEST(PassManagerTest, PassInstrumentation) {
     LogicalResult result = pm.run(module.get());
 
     InstrumentationCounts counts = {
-        .beforePass = instrumentationPtr->beforePassCallbackCount,
-        .afterPass = instrumentationPtr->afterPassCallbackCount,
-        .afterPassFailed = instrumentationPtr->afterPassFailedCallbackCount};
+        instrumentationPtr->beforePassCallbackCount,
+        instrumentationPtr->afterPassCallbackCount,
+        instrumentationPtr->afterPassFailedCallbackCount};
     return {result, counts};
   };
 
@@ -194,15 +194,15 @@ TEST(PassManagerTest, PassInstrumentation) {
       if (failBefore) {
         EXPECT_TRUE(failed(result))
             << "failBefore=" << failBefore << ", failAfter=" << failAfter;
-        expected = {.beforePass = 1, .afterPass = 0, .afterPassFailed = 1};
+        expected = {/*beforePass=*/1, /*afterPass=*/0, /*afterPassFailed=*/1};
       } else if (failAfter) {
         EXPECT_TRUE(failed(result))
             << "failBefore=" << failBefore << ", failAfter=" << failAfter;
-        expected = {.beforePass = 1, .afterPass = 1, .afterPassFailed = 0};
+        expected = {/*beforePass=*/1, /*afterPass=*/1, /*afterPassFailed=*/0};
       } else {
         EXPECT_TRUE(succeeded(result))
             << "failBefore=" << failBefore << ", failAfter=" << failAfter;
-        expected = {.beforePass = 1, .afterPass = 1, .afterPassFailed = 0};
+        expected = {/*beforePass=*/1, /*afterPass=*/1, /*afterPassFailed=*/0};
       }
 
       EXPECT_EQ(counts.beforePass, expected.beforePass)



More information about the Mlir-commits mailing list