[Mlir-commits] [mlir] [MLIR] Add options to generate-runtime-verification to enable faster pass running (PR #160331)

Hanchenng Wu llvmlistbot at llvm.org
Mon Oct 6 14:10:44 PDT 2025


https://github.com/HanchengWu updated https://github.com/llvm/llvm-project/pull/160331

>From 1bcb096dc0ed34c368237f515a503c89c5146d86 Mon Sep 17 00:00:00 2001
From: Henry Wu <henrywu at mathworks.com>
Date: Tue, 23 Sep 2025 10:51:30 -0400
Subject: [PATCH] Reuse AsmState to enable fast op->print in
 generate-runtime-verification pass, and add location only pass option.

---
 .../RuntimeVerifiableOpInterface.td           |  8 +-
 mlir/include/mlir/Transforms/Passes.h         |  1 +
 mlir/include/mlir/Transforms/Passes.td        |  7 ++
 .../Transforms/RuntimeOpVerification.cpp      | 10 +-
 .../Transforms/RuntimeOpVerification.cpp      | 94 ++++++++++---------
 .../Transforms/RuntimeOpVerification.cpp      | 47 +++++-----
 .../RuntimeVerifiableOpInterface.cpp          | 26 -----
 .../GenerateRuntimeVerification.cpp           | 57 ++++++++++-
 .../Dialect/Linalg/runtime-verification.mlir  |  7 ++
 9 files changed, 157 insertions(+), 100 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
index 6fd0df59d9d2e..21104834a9dc3 100644
--- a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
+++ b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
@@ -32,15 +32,11 @@ def RuntimeVerifiableOpInterface : OpInterface<"RuntimeVerifiableOpInterface"> {
       /*retTy=*/"void",
       /*methodName=*/"generateRuntimeVerification",
       /*args=*/(ins "::mlir::OpBuilder &":$builder,
-                    "::mlir::Location":$loc)
+                    "::mlir::Location":$loc,
+                    "::llvm::function_ref<std::string(::mlir::Operation *, ::llvm::StringRef)>":$generateErrorMessage)
     >,
   ];
 
-  let extraClassDeclaration = [{
-    /// Generate the error message that will be printed to the user when 
-    /// verification fails.
-    static std::string generateErrorMessage(Operation *op, const std::string &msg);
-  }];
 }
 
 #endif // MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 1c035f2a843ff..17c323a042ec2 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -47,6 +47,7 @@ class GreedyRewriteConfig;
 #define GEN_PASS_DECL_TOPOLOGICALSORT
 #define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
 #define GEN_PASS_DECL_BUBBLEDOWNMEMORYSPACECASTS
+#define GEN_PASS_DECL_GENERATERUNTIMEVERIFICATION
 #include "mlir/Transforms/Passes.h.inc"
 
 /// Creates an instance of the Canonicalizer pass, configured with default
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index b2b7f20a497e3..28b4a01cf0ecd 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -270,8 +270,15 @@ def GenerateRuntimeVerification : Pass<"generate-runtime-verification"> {
     passes that are suspected to introduce faulty IR.
   }];
   let constructor = "mlir::createGenerateRuntimeVerificationPass()";
+  let options = [
+    Option<"verboseLevel", "verbose-level", "unsigned", /*default=*/"1",
+           "Verbosity level for runtime verification messages: "
+           "0 = Minimum (only source location), "
+           "1 = Detailed (include full operation details, names, types, shapes, etc.)">
+  ];
 }
 
+
 def Inliner : Pass<"inline"> {
   let summary = "Inline function calls";
   let constructor = "mlir::createInlinerPass()";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
index eac0e47b18a7d..15eb51a6dcab2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
@@ -31,8 +31,10 @@ template <typename T>
 struct StructuredOpInterface
     : public RuntimeVerifiableOpInterface::ExternalModel<
           StructuredOpInterface<T>, T> {
-  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
-                                   Location loc) const {
+  void
+  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
+                              function_ref<std::string(Operation *, StringRef)>
+                                  generateErrorMessage) const {
     auto linalgOp = llvm::cast<LinalgOp>(op);
 
     SmallVector<Range> loopRanges = linalgOp.createLoopRanges(builder, loc);
@@ -70,7 +72,7 @@ struct StructuredOpInterface
             builder.createOrFold<index::MinSOp>(loc, startIndex, endIndex);
         auto cmpOp = builder.createOrFold<index::CmpOp>(
             loc, index::IndexCmpPredicate::SGE, min, zero);
-        auto msg = RuntimeVerifiableOpInterface::generateErrorMessage(
+        auto msg = generateErrorMessage(
             linalgOp, "unexpected negative result on dimension #" +
                           std::to_string(dim) + " of input/output operand #" +
                           std::to_string(opOperand.getOperandNumber()));
@@ -100,7 +102,7 @@ struct StructuredOpInterface
 
         cmpOp = builder.createOrFold<index::CmpOp>(
             loc, predicate, inferredDimSize, actualDimSize);
-        msg = RuntimeVerifiableOpInterface::generateErrorMessage(
+        msg = generateErrorMessage(
             linalgOp, "dimension #" + std::to_string(dim) +
                           " of input/output operand #" +
                           std::to_string(opOperand.getOperandNumber()) +
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index d3a77c026379e..291da1f76ca9b 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -37,8 +37,10 @@ Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
 struct AssumeAlignmentOpInterface
     : public RuntimeVerifiableOpInterface::ExternalModel<
           AssumeAlignmentOpInterface, AssumeAlignmentOp> {
-  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
-                                   Location loc) const {
+  void
+  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
+                              function_ref<std::string(Operation *, StringRef)>
+                                  generateErrorMessage) const {
     auto assumeOp = cast<AssumeAlignmentOp>(op);
     Value ptr = ExtractAlignedPointerAsIndexOp::create(builder, loc,
                                                        assumeOp.getMemref());
@@ -48,9 +50,9 @@ struct AssumeAlignmentOpInterface
     Value isAligned =
         arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, rest,
                               arith::ConstantIndexOp::create(builder, loc, 0));
-    cf::AssertOp::create(builder, loc, isAligned,
-                         RuntimeVerifiableOpInterface::generateErrorMessage(
-                             op, "memref is not aligned to " +
+    cf::AssertOp::create(
+        builder, loc, isAligned,
+        generateErrorMessage(op, "memref is not aligned to " +
                                      std::to_string(assumeOp.getAlignment())));
   }
 };
@@ -58,8 +60,10 @@ struct AssumeAlignmentOpInterface
 struct CastOpInterface
     : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
                                                          CastOp> {
-  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
-                                   Location loc) const {
+  void
+  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
+                              function_ref<std::string(Operation *, StringRef)>
+                                  generateErrorMessage) const {
     auto castOp = cast<CastOp>(op);
     auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
 
@@ -76,8 +80,7 @@ struct CastOpInterface
       Value isSameRank = arith::CmpIOp::create(
           builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank);
       cf::AssertOp::create(builder, loc, isSameRank,
-                           RuntimeVerifiableOpInterface::generateErrorMessage(
-                               op, "rank mismatch"));
+                           generateErrorMessage(op, "rank mismatch"));
     }
 
     // Get source offset and strides. We do not have an op to get offsets and
@@ -116,8 +119,8 @@ struct CastOpInterface
           builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
       cf::AssertOp::create(
           builder, loc, isSameSz,
-          RuntimeVerifiableOpInterface::generateErrorMessage(
-              op, "size mismatch of dim " + std::to_string(it.index())));
+          generateErrorMessage(op, "size mismatch of dim " +
+                                       std::to_string(it.index())));
     }
 
     // Get result offset and strides.
@@ -135,8 +138,7 @@ struct CastOpInterface
       Value isSameOffset = arith::CmpIOp::create(
           builder, loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
       cf::AssertOp::create(builder, loc, isSameOffset,
-                           RuntimeVerifiableOpInterface::generateErrorMessage(
-                               op, "offset mismatch"));
+                           generateErrorMessage(op, "offset mismatch"));
     }
 
     // Check strides.
@@ -153,8 +155,8 @@ struct CastOpInterface
           builder, loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
       cf::AssertOp::create(
           builder, loc, isSameStride,
-          RuntimeVerifiableOpInterface::generateErrorMessage(
-              op, "stride mismatch of dim " + std::to_string(it.index())));
+          generateErrorMessage(op, "stride mismatch of dim " +
+                                       std::to_string(it.index())));
     }
   }
 };
@@ -162,8 +164,10 @@ struct CastOpInterface
 struct CopyOpInterface
     : public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
                                                          CopyOp> {
-  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
-                                   Location loc) const {
+  void
+  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
+                              function_ref<std::string(Operation *, StringRef)>
+                                  generateErrorMessage) const {
     auto copyOp = cast<CopyOp>(op);
     BaseMemRefType sourceType = copyOp.getSource().getType();
     BaseMemRefType targetType = copyOp.getTarget().getType();
@@ -193,9 +197,9 @@ struct CopyOpInterface
       Value targetDim = getDimSize(copyOp.getTarget(), rankedTargetType, i);
       Value sameDimSize = arith::CmpIOp::create(
           builder, loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
-      cf::AssertOp::create(builder, loc, sameDimSize,
-                           RuntimeVerifiableOpInterface::generateErrorMessage(
-                               op, "size of " + std::to_string(i) +
+      cf::AssertOp::create(
+          builder, loc, sameDimSize,
+          generateErrorMessage(op, "size of " + std::to_string(i) +
                                        "-th source/target dim does not match"));
     }
   }
@@ -204,16 +208,17 @@ struct CopyOpInterface
 struct DimOpInterface
     : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
                                                          DimOp> {
-  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
-                                   Location loc) const {
+  void
+  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
+                              function_ref<std::string(Operation *, StringRef)>
+                                  generateErrorMessage) const {
     auto dimOp = cast<DimOp>(op);
     Value rank = RankOp::create(builder, loc, dimOp.getSource());
     Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
     cf::AssertOp::create(
         builder, loc,
         generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
-        RuntimeVerifiableOpInterface::generateErrorMessage(
-            op, "index is out of bounds"));
+        generateErrorMessage(op, "index is out of bounds"));
   }
 };
 
@@ -223,8 +228,10 @@ template <typename LoadStoreOp>
 struct LoadStoreOpInterface
     : public RuntimeVerifiableOpInterface::ExternalModel<
           LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
-  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
-                                   Location loc) const {
+  void
+  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
+                              function_ref<std::string(Operation *, StringRef)>
+                                  generateErrorMessage) const {
     auto loadStoreOp = cast<LoadStoreOp>(op);
 
     auto memref = loadStoreOp.getMemref();
@@ -245,16 +252,17 @@ struct LoadStoreOpInterface
                 : inBounds;
     }
     cf::AssertOp::create(builder, loc, assertCond,
-                         RuntimeVerifiableOpInterface::generateErrorMessage(
-                             op, "out-of-bounds access"));
+                         generateErrorMessage(op, "out-of-bounds access"));
   }
 };
 
 struct SubViewOpInterface
     : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
                                                          SubViewOp> {
-  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
-                                   Location loc) const {
+  void
+  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
+                              function_ref<std::string(Operation *, StringRef)>
+                                  generateErrorMessage) const {
     auto subView = cast<SubViewOp>(op);
     MemRefType sourceType = subView.getSource().getType();
 
@@ -277,10 +285,10 @@ struct SubViewOpInterface
       Value dimSize = metadataOp.getSizes()[i];
       Value offsetInBounds =
           generateInBoundsCheck(builder, loc, offset, zero, dimSize);
-      cf::AssertOp::create(
-          builder, loc, offsetInBounds,
-          RuntimeVerifiableOpInterface::generateErrorMessage(
-              op, "offset " + std::to_string(i) + " is out-of-bounds"));
+      cf::AssertOp::create(builder, loc, offsetInBounds,
+                           generateErrorMessage(op, "offset " +
+                                                        std::to_string(i) +
+                                                        " is out-of-bounds"));
 
       // Verify that slice does not run out-of-bounds.
       Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
@@ -292,9 +300,9 @@ struct SubViewOpInterface
           generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
       cf::AssertOp::create(
           builder, loc, lastPosInBounds,
-          RuntimeVerifiableOpInterface::generateErrorMessage(
-              op, "subview runs out-of-bounds along dimension " +
-                      std::to_string(i)));
+          generateErrorMessage(op,
+                               "subview runs out-of-bounds along dimension " +
+                                   std::to_string(i)));
     }
   }
 };
@@ -302,8 +310,10 @@ struct SubViewOpInterface
 struct ExpandShapeOpInterface
     : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                          ExpandShapeOp> {
-  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
-                                   Location loc) const {
+  void
+  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
+                              function_ref<std::string(Operation *, StringRef)>
+                                  generateErrorMessage) const {
     auto expandShapeOp = cast<ExpandShapeOp>(op);
 
     // Verify that the expanded dim sizes are a product of the collapsed dim
@@ -333,9 +343,9 @@ struct ExpandShapeOpInterface
       Value isModZero = arith::CmpIOp::create(
           builder, loc, arith::CmpIPredicate::eq, mod,
           arith::ConstantIndexOp::create(builder, loc, 0));
-      cf::AssertOp::create(builder, loc, isModZero,
-                           RuntimeVerifiableOpInterface::generateErrorMessage(
-                               op, "static result dims in reassoc group do not "
+      cf::AssertOp::create(
+          builder, loc, isModZero,
+          generateErrorMessage(op, "static result dims in reassoc group do not "
                                    "divide src dim evenly"));
     }
   }
diff --git a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
index 838ff1f987c63..c031118606823 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
@@ -35,8 +35,10 @@ Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
 struct CastOpInterface
     : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
                                                          CastOp> {
-  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
-                                   Location loc) const {
+  void
+  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
+                              function_ref<std::string(Operation *, StringRef)>
+                                  generateErrorMessage) const {
     auto castOp = cast<CastOp>(op);
     auto srcType = cast<TensorType>(castOp.getSource().getType());
 
@@ -53,8 +55,7 @@ struct CastOpInterface
       Value isSameRank = arith::CmpIOp::create(
           builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank);
       cf::AssertOp::create(builder, loc, isSameRank,
-                           RuntimeVerifiableOpInterface::generateErrorMessage(
-                               op, "rank mismatch"));
+                           generateErrorMessage(op, "rank mismatch"));
     }
 
     // Check dimension sizes.
@@ -76,8 +77,8 @@ struct CastOpInterface
           builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
       cf::AssertOp::create(
           builder, loc, isSameSz,
-          RuntimeVerifiableOpInterface::generateErrorMessage(
-              op, "size mismatch of dim " + std::to_string(it.index())));
+          generateErrorMessage(op, "size mismatch of dim " +
+                                       std::to_string(it.index())));
     }
   }
 };
@@ -85,16 +86,17 @@ struct CastOpInterface
 struct DimOpInterface
     : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
                                                          DimOp> {
-  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
-                                   Location loc) const {
+  void
+  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
+                              function_ref<std::string(Operation *, StringRef)>
+                                  generateErrorMessage) const {
     auto dimOp = cast<DimOp>(op);
     Value rank = RankOp::create(builder, loc, dimOp.getSource());
     Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
     cf::AssertOp::create(
         builder, loc,
         generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
-        RuntimeVerifiableOpInterface::generateErrorMessage(
-            op, "index is out of bounds"));
+        generateErrorMessage(op, "index is out of bounds"));
   }
 };
 
@@ -104,8 +106,10 @@ template <typename OpTy>
 struct ExtractInsertOpInterface
     : public RuntimeVerifiableOpInterface::ExternalModel<
           ExtractInsertOpInterface<OpTy>, OpTy> {
-  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
-                                   Location loc) const {
+  void
+  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
+                              function_ref<std::string(Operation *, StringRef)>
+                                  generateErrorMessage) const {
     auto extractInsertOp = cast<OpTy>(op);
 
     Value tensor;
@@ -135,16 +139,17 @@ struct ExtractInsertOpInterface
                 : inBounds;
     }
     cf::AssertOp::create(builder, loc, assertCond,
-                         RuntimeVerifiableOpInterface::generateErrorMessage(
-                             op, "out-of-bounds access"));
+                         generateErrorMessage(op, "out-of-bounds access"));
   }
 };
 
 struct ExtractSliceOpInterface
     : public RuntimeVerifiableOpInterface::ExternalModel<
           ExtractSliceOpInterface, ExtractSliceOp> {
-  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
-                                   Location loc) const {
+  void
+  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
+                              function_ref<std::string(Operation *, StringRef)>
+                                  generateErrorMessage) const {
     auto extractSliceOp = cast<ExtractSliceOp>(op);
     RankedTensorType sourceType = extractSliceOp.getSource().getType();
 
@@ -166,10 +171,10 @@ struct ExtractSliceOpInterface
           loc, extractSliceOp.getSource(), i);
       Value offsetInBounds =
           generateInBoundsCheck(builder, loc, offset, zero, dimSize);
-      cf::AssertOp::create(
-          builder, loc, offsetInBounds,
-          RuntimeVerifiableOpInterface::generateErrorMessage(
-              op, "offset " + std::to_string(i) + " is out-of-bounds"));
+      cf::AssertOp::create(builder, loc, offsetInBounds,
+                           generateErrorMessage(op, "offset " +
+                                                        std::to_string(i) +
+                                                        " is out-of-bounds"));
 
       // Verify that slice does not run out-of-bounds.
       Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
@@ -181,7 +186,7 @@ struct ExtractSliceOpInterface
           generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
       cf::AssertOp::create(
           builder, loc, lastPosInBounds,
-          RuntimeVerifiableOpInterface::generateErrorMessage(
+          generateErrorMessage(
               op, "extract_slice runs out-of-bounds along dimension " +
                       std::to_string(i)));
     }
diff --git a/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp b/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
index 8aa194befb420..f9a54f950d7ff 100644
--- a/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
+++ b/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
@@ -8,31 +8,5 @@
 
 #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
 
-namespace mlir {
-class Location;
-class OpBuilder;
-
-/// Generate an error message string for the given op and the specified error.
-std::string
-RuntimeVerifiableOpInterface::generateErrorMessage(Operation *op,
-                                                   const std::string &msg) {
-  std::string buffer;
-  llvm::raw_string_ostream stream(buffer);
-  OpPrintingFlags flags;
-  // We may generate a lot of error messages and so we need to ensure the
-  // printing is fast.
-  flags.elideLargeElementsAttrs();
-  flags.printGenericOpForm();
-  flags.skipRegions();
-  flags.useLocalScope();
-  stream << "ERROR: Runtime op verification failed\n";
-  op->print(stream, flags);
-  stream << "\n^ " << msg;
-  stream << "\nLocation: ";
-  op->getLoc().print(stream);
-  return buffer;
-}
-} // namespace mlir
-
 /// Include the definitions of the interface.
 #include "mlir/Interfaces/RuntimeVerifiableOpInterface.cpp.inc"
diff --git a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
index a40bc2b3272fc..63c71cd6fb44d 100644
--- a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
+++ b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Transforms/Passes.h"
 
+#include "mlir/IR/AsmState.h"
 #include "mlir/IR/Builders.h"
@@ -25,9 +26,51 @@ struct GenerateRuntimeVerificationPass
           GenerateRuntimeVerificationPass> {
   void runOnOperation() override;
 };
+
+/// Default error message generator for runtime verification failures.
+///
+/// Every message starts with a fixed header and ends with the failure
+/// description ("^ <msg>") followed by the op location. The verbosity level
+/// controls what is printed in between:
+/// - Level 0: nothing; only the header, description and location.
+/// - Level 1: the op itself, printed through the shared AsmState.
+///
+/// The AsmState is owned by the caller and must outlive this generator.
+class DefaultErrMsgGenerator {
+private:
+  unsigned vLevel;
+  AsmState &state;
+
+public:
+  DefaultErrMsgGenerator(unsigned verboseLevel, AsmState &asmState)
+      : vLevel(verboseLevel), state(asmState) {}
+
+  std::string operator()(Operation *op, StringRef msg) const {
+    std::string buffer;
+    llvm::raw_string_ostream stream(buffer);
+    stream << "ERROR: Runtime op verification failed\n";
+    if (vLevel >= 1) {
+      op->print(stream, state);
+      stream << "\n";
+    }
+    stream << "^ " << msg << "\nLocation: ";
+    op->getLoc().print(stream);
+    return buffer;
+  }
+
+  unsigned getVerboseLevel() const { return vLevel; }
+};
 } // namespace
 
 void GenerateRuntimeVerificationPass::runOnOperation() {
+  // Check verboseLevel is in range [0, 1].
+  if (verboseLevel > 1) {
+    getOperation()->emitError(
+        "generate-runtime-verification pass: verbose-level must be 0 or 1");
+    signalPassFailure();
+    return;
+  }
+
   // The implementation of the RuntimeVerifiableOpInterface may create ops that
   // can be verified. We don't want to generate verification for IR that
   // performs verification, so gather all runtime-verifiable ops first.
@@ -36,10 +79,22 @@ void GenerateRuntimeVerificationPass::runOnOperation() {
     ops.push_back(verifiableOp);
   });
 
+  // We may generate a lot of error messages and so we need to ensure the
+  // printing is fast: build the printing state once and reuse it for all ops.
+  OpPrintingFlags flags;
+  flags.elideLargeElementsAttrs();
+  flags.printGenericOpForm();
+  flags.skipRegions();
+  flags.useLocalScope();
+  AsmState state(getOperation(), flags);
+
+  DefaultErrMsgGenerator defaultErrMsgGenerator(verboseLevel.getValue(), state);
+
   OpBuilder builder(getOperation()->getContext());
   for (RuntimeVerifiableOpInterface verifiableOp : ops) {
     builder.setInsertionPoint(verifiableOp);
-    verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc());
+    verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc(),
+                                             defaultErrMsgGenerator);
   };
 }
 
diff --git a/mlir/test/Dialect/Linalg/runtime-verification.mlir b/mlir/test/Dialect/Linalg/runtime-verification.mlir
index a4f29d8457e58..07e96c823c889 100644
--- a/mlir/test/Dialect/Linalg/runtime-verification.mlir
+++ b/mlir/test/Dialect/Linalg/runtime-verification.mlir
@@ -1,13 +1,18 @@
 // RUN: mlir-opt %s -generate-runtime-verification | FileCheck %s
+// RUN: mlir-opt %s --generate-runtime-verification="verbose-level=0" | FileCheck %s --check-prefix=VERBOSE0
 
 // Most of the tests for linalg runtime-verification are implemented as integration tests.
 
 #identity = affine_map<(d0) -> (d0)>
 
 // CHECK-LABEL: @static_dims
+// VERBOSE0-LABEL: @static_dims
 func.func @static_dims(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>) {
     // CHECK: %[[TRUE:.*]] = index.bool.constant true
     // CHECK: cf.assert %[[TRUE]]
+    // VERBOSE0: %[[TRUE:.*]] = index.bool.constant true
+    // VERBOSE0: cf.assert %[[TRUE]]
+    // VERBOSE0-SAME: ERROR: Runtime op verification failed\0A^ {{.*}}dimension #0 of input/output operand #0{{.*}}\0ALocation: loc(
     %result = tensor.empty() : tensor<5xf32> 
     %0 = linalg.generic {
       indexing_maps = [#identity, #identity, #identity],
@@ -26,9 +31,11 @@ func.func @static_dims(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> (tensor<5x
 #map = affine_map<() -> ()>
 
 // CHECK-LABEL: @scalars
+// VERBOSE0-LABEL: @scalars
 func.func @scalars(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
     // No runtime checks are required if the operands are all scalars
     // CHECK-NOT: cf.assert
+    // VERBOSE0-NOT: cf.assert
     %result = tensor.empty() : tensor<f32> 
     %0 = linalg.generic {
       indexing_maps = [#map, #map, #map],



More information about the Mlir-commits mailing list