[Mlir-commits] [mlir] [mlir] Add filtering callback to GenerateRuntimeVerification pass (PR #150013)

Thomas Hashem llvmlistbot at llvm.org
Tue Jul 22 05:56:00 PDT 2025


https://github.com/hashemthomas1 created https://github.com/llvm/llvm-project/pull/150013

This patch lets users create this pass with a custom callback that filters out operations for which runtime verification should not be generated.

>From b3e02d110aca9dc47a57b7acd681c246c6d26f60 Mon Sep 17 00:00:00 2001
From: Thomas <hashemthomas1 at gmail.com>
Date: Tue, 22 Jul 2025 12:30:27 +0300
Subject: [PATCH] [mlir] Add filtering callback to GenerateRuntimeVerification
 pass

Let users create this pass with a custom callback that filters out
operations for which runtime verification should not be generated.
---
 mlir/include/mlir/Transforms/Passes.h         |  8 ++++
 .../GenerateRuntimeVerification.cpp           | 47 +++++++++++++++++--
 2 files changed, 52 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 9cd2ef34e15ea..4749a45e51c1f 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -26,6 +26,7 @@
 namespace mlir {
 
 class GreedyRewriteConfig;
+class RuntimeVerifiableOpInterface;
 
 //===----------------------------------------------------------------------===//
 // Passes
@@ -77,6 +78,13 @@ std::unique_ptr<Pass> createPrintIRPass(const PrintIRPassOptions & = {});
 /// Creates a pass that generates IR to verify ops at runtime.
 std::unique_ptr<Pass> createGenerateRuntimeVerificationPass();
 
+/// Creates a pass that generates IR to verify ops at runtime, but only for the
+/// verifiable ops for which `shouldHandleVerifiableOpFn` returns true; all
+/// other verifiable ops are left without runtime checks.
+std::unique_ptr<Pass> createGenerateRuntimeVerificationPass(
+    std::function<bool(RuntimeVerifiableOpInterface)>
+        shouldHandleVerifiableOpFn);
+
 /// Creates a loop invariant code motion pass that hoists loop invariant
 /// instructions out of the loop.
 std::unique_ptr<Pass> createLoopInvariantCodeMotionPass();
diff --git a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
index a40bc2b3272fc..214510ca8ccd4 100644
--- a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
+++ b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
@@ -17,16 +17,46 @@ namespace mlir {
 #include "mlir/Transforms/Passes.h.inc"
 } // namespace mlir
 
+#define DEBUG_TYPE "generate-runtime-verification"
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
 using namespace mlir;
 
+static bool defaultShouldHandleVerifiableOpFn(RuntimeVerifiableOpInterface) {
+  // Accept every op: generate verification for all verifiable ops.
+  return true;
+}
+
 namespace {
 struct GenerateRuntimeVerificationPass
     : public impl::GenerateRuntimeVerificationBase<
           GenerateRuntimeVerificationPass> {
+
+  GenerateRuntimeVerificationPass();
+  GenerateRuntimeVerificationPass(const GenerateRuntimeVerificationPass &) =
+      default;
+  GenerateRuntimeVerificationPass(
+      std::function<bool(RuntimeVerifiableOpInterface)>
+          shouldHandleVerifiableOpFn);
+
   void runOnOperation() override;
+
+private:
+  // Filter selecting the verifiable ops to generate verification for. An empty
+  // filter passed to the constructor is replaced by the accept-all default.
+  std::function<bool(RuntimeVerifiableOpInterface)> shouldHandleVerifiableOpFn;
 };
 } // namespace
 
+GenerateRuntimeVerificationPass::GenerateRuntimeVerificationPass()
+    : shouldHandleVerifiableOpFn(defaultShouldHandleVerifiableOpFn) {}
+
+GenerateRuntimeVerificationPass::GenerateRuntimeVerificationPass(
+    std::function<bool(RuntimeVerifiableOpInterface)> filterFn)
+    : shouldHandleVerifiableOpFn(
+          filterFn ? std::move(filterFn) : defaultShouldHandleVerifiableOpFn) {}
+
 void GenerateRuntimeVerificationPass::runOnOperation() {
   // The implementation of the RuntimeVerifiableOpInterface may create ops that
   // can be verified. We don't want to generate verification for IR that
@@ -38,11 +68,22 @@ void GenerateRuntimeVerificationPass::runOnOperation() {
 
   OpBuilder builder(getOperation()->getContext());
   for (RuntimeVerifiableOpInterface verifiableOp : ops) {
-    builder.setInsertionPoint(verifiableOp);
-    verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc());
-  };
+    if (!shouldHandleVerifiableOpFn(verifiableOp)) {
+      LDBG("Skipping operation: " << *verifiableOp.getOperation());
+      continue;
+    }
+    builder.setInsertionPoint(verifiableOp);
+    verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc());
+  }
 }
 
 std::unique_ptr<Pass> mlir::createGenerateRuntimeVerificationPass() {
   return std::make_unique<GenerateRuntimeVerificationPass>();
 }
+
+std::unique_ptr<Pass> mlir::createGenerateRuntimeVerificationPass(
+    std::function<bool(RuntimeVerifiableOpInterface)>
+        shouldHandleVerifiableOpFn) {
+  return std::make_unique<GenerateRuntimeVerificationPass>(
+      std::move(shouldHandleVerifiableOpFn));
+}



More information about the Mlir-commits mailing list