[Mlir-commits] [mlir] [mlir] Add filtering callback to GenerateRuntimeVerification pass (PR #150013)
Thomas Hashem
llvmlistbot at llvm.org
Tue Jul 22 05:56:00 PDT 2025
https://github.com/hashemthomas1 created https://github.com/llvm/llvm-project/pull/150013
Users can now create this pass with a custom callback that filters out operations for which no runtime verification code should be generated.
>From b3e02d110aca9dc47a57b7acd681c246c6d26f60 Mon Sep 17 00:00:00 2001
From: Thomas <hashemthomas1 at gmail.com>
Date: Tue, 22 Jul 2025 12:30:27 +0300
Subject: [PATCH] [mlir] Add filtering callback to GenerateRuntimeVerification
pass
Users can now create this pass with a custom callback that filters out
operations for which no runtime verification code should be generated.
---
mlir/include/mlir/Transforms/Passes.h | 8 ++++
.../GenerateRuntimeVerification.cpp | 47 +++++++++++++++++--
2 files changed, 52 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 9cd2ef34e15ea..4749a45e51c1f 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -26,6 +26,7 @@
namespace mlir {
class GreedyRewriteConfig;
+class RuntimeVerifiableOpInterface;
//===----------------------------------------------------------------------===//
// Passes
@@ -77,6 +78,13 @@ std::unique_ptr<Pass> createPrintIRPass(const PrintIRPassOptions & = {});
/// Creates a pass that generates IR to verify ops at runtime.
std::unique_ptr<Pass> createGenerateRuntimeVerificationPass();
+/// Creates an instance of the generate runtime verification pass that only
+/// generates runtime verification code for verifiable ops accepted by
+/// `shouldHandleVerifiableOpFn`; ops for which it returns false are left
+/// unverified. Use the overload without arguments to verify every op.
+std::unique_ptr<Pass> createGenerateRuntimeVerificationPass(
+    std::function<bool(RuntimeVerifiableOpInterface)>
+        shouldHandleVerifiableOpFn);
+
/// Creates a loop invariant code motion pass that hoists loop invariant
/// instructions out of the loop.
std::unique_ptr<Pass> createLoopInvariantCodeMotionPass();
diff --git a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
index a40bc2b3272fc..214510ca8ccd4 100644
--- a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
+++ b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
@@ -17,16 +17,46 @@ namespace mlir {
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir
+#define DEBUG_TYPE "generate-runtime-verification"
+// Debug stream whose output is prefixed with the pass name, i.e.
+// "[generate-runtime-verification] ".
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
+// Writes `X` followed by a newline to DBGS(), wrapped in LLVM_DEBUG.
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
using namespace mlir;
+/// Filter installed by the pass's default constructor: every verifiable op
+/// is accepted, so no op is skipped.
+static bool
+defaultShouldHandleVerifiableOpFn(RuntimeVerifiableOpInterface /*op*/) {
+  return true;
+}
+
namespace {
struct GenerateRuntimeVerificationPass
: public impl::GenerateRuntimeVerificationBase<
GenerateRuntimeVerificationPass> {
+
+  /// Creates the pass with a filter that accepts every verifiable op.
+  GenerateRuntimeVerificationPass();
+  GenerateRuntimeVerificationPass(const GenerateRuntimeVerificationPass &) =
+      default;
+  /// Creates the pass with a custom filter: only ops for which
+  /// `shouldHandleVerifiableOpFn` returns true get verification code.
+  /// `explicit` so a bare callable never converts implicitly into a pass.
+  explicit GenerateRuntimeVerificationPass(
+      std::function<bool(RuntimeVerifiableOpInterface)>
+          shouldHandleVerifiableOpFn);
+
void runOnOperation() override;
+
+private:
+  // Decides which verifiable ops get runtime verification code; ops for
+  // which it returns false are skipped. The default constructor installs a
+  // filter that accepts every op.
+  std::function<bool(RuntimeVerifiableOpInterface)> shouldHandleVerifiableOpFn;
};
} // namespace
+GenerateRuntimeVerificationPass::GenerateRuntimeVerificationPass()
+    : shouldHandleVerifiableOpFn(defaultShouldHandleVerifiableOpFn) {}
+
+GenerateRuntimeVerificationPass::GenerateRuntimeVerificationPass(
+    std::function<bool(RuntimeVerifiableOpInterface)> filterFn)
+    : shouldHandleVerifiableOpFn(std::move(filterFn)) {
+  // runOnOperation() invokes the filter unconditionally, and invoking an empty
+  // std::function throws std::bad_function_call (fatal in builds without
+  // exceptions). Treat an empty filter as "handle every op" instead.
+  if (!shouldHandleVerifiableOpFn)
+    shouldHandleVerifiableOpFn = defaultShouldHandleVerifiableOpFn;
+}
+
void GenerateRuntimeVerificationPass::runOnOperation() {
// The implementation of the RuntimeVerifiableOpInterface may create ops that
// can be verified. We don't want to generate verification for IR that
@@ -38,11 +68,22 @@ void GenerateRuntimeVerificationPass::runOnOperation() {
OpBuilder builder(getOperation()->getContext());
for (RuntimeVerifiableOpInterface verifiableOp : ops) {
- builder.setInsertionPoint(verifiableOp);
- verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc());
- };
+ if (shouldHandleVerifiableOpFn(verifiableOp)) {
+ builder.setInsertionPoint(verifiableOp);
+ verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc());
+ } else {
+ LDBG("Skipping operation: " << verifiableOp.getOperation());
+ }
+ }
}
std::unique_ptr<Pass> mlir::createGenerateRuntimeVerificationPass() {
return std::make_unique<GenerateRuntimeVerificationPass>();
}
+
+/// Same as the overload above, but the resulting pass only generates runtime
+/// verification code for verifiable ops accepted by `filterFn`.
+std::unique_ptr<Pass> mlir::createGenerateRuntimeVerificationPass(
+    std::function<bool(RuntimeVerifiableOpInterface)> filterFn) {
+  auto pass =
+      std::make_unique<GenerateRuntimeVerificationPass>(std::move(filterFn));
+  return pass;
+}
More information about the Mlir-commits
mailing list