[Mlir-commits] [mlir] [mlir] implement `-verify-diagnostics=only-expected` (PR #135131)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 10 10:50:26 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

<details>
<summary>Changes</summary>

This PR implements `-verify-diagnostics=only-expected`, a "best effort" verification mode: diagnostics annotated with `expected-*` are still checked, but `unexpected` diagnostics and `near-misses` are not reported as failures. The purpose is to enable narrowly scoped checking of verification remarks (just as lit tests `CHECK` only a subset of the output lines).

---
Full diff: https://github.com/llvm/llvm-project/pull/135131.diff


9 Files Affected:

- (modified) mlir/examples/transform-opt/mlir-transform-opt.cpp (+31-13) 
- (modified) mlir/include/mlir/IR/Diagnostics.h (+7-2) 
- (modified) mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h (+13-3) 
- (modified) mlir/lib/IR/Diagnostics.cpp (+7-4) 
- (modified) mlir/lib/Tools/mlir-opt/MlirOptMain.cpp (+22-6) 
- (modified) mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp (+21-10) 
- (modified) mlir/test/Pass/full_diagnostics.mlir (+1) 
- (added) mlir/test/Pass/full_diagnostics_only_expected.mlir (+17) 
- (added) mlir/test/mlir-translate/verify-only-expected.mlir (+49) 


``````````diff
diff --git a/mlir/examples/transform-opt/mlir-transform-opt.cpp b/mlir/examples/transform-opt/mlir-transform-opt.cpp
index 10e16096211ad..ce23e2d67101c 100644
--- a/mlir/examples/transform-opt/mlir-transform-opt.cpp
+++ b/mlir/examples/transform-opt/mlir-transform-opt.cpp
@@ -38,11 +38,20 @@ struct MlirTransformOptCLOptions {
       cl::desc("Allow operations coming from an unregistered dialect"),
       cl::init(false)};
 
-  cl::opt<bool> verifyDiagnostics{
-      "verify-diagnostics",
-      cl::desc("Check that emitted diagnostics match expected-* lines "
-               "on the corresponding line"),
-      cl::init(false)};
+  cl::opt<mlir::SourceMgrDiagnosticVerifierHandler::Level> verifyDiagnostics{
+      "verify-diagnostics", llvm::cl::ValueOptional,
+      cl::desc("Check that emitted diagnostics match "
+               "expected-* lines on the corresponding line"),
+      cl::values(
+          clEnumValN(
+              mlir::SourceMgrDiagnosticVerifierHandler::Level::All, "all",
+              "Check all diagnostics (expected, unexpected, near-misses)"),
+          clEnumValN(
+              mlir::SourceMgrDiagnosticVerifierHandler::Level::All, "",
+              "Check all diagnostics (expected, unexpected, near-misses)"),
+          clEnumValN(
+              mlir::SourceMgrDiagnosticVerifierHandler::Level::OnlyExpected,
+              "only-expected", "Check only expected diagnostics"))};
 
   cl::opt<std::string> payloadFilename{cl::Positional, cl::desc("<input file>"),
                                        cl::init("-")};
@@ -102,12 +111,17 @@ class DiagnosticHandlerWrapper {
 
   /// Constructs the diagnostic handler of the specified kind of the given
   /// source manager and context.
-  DiagnosticHandlerWrapper(Kind kind, llvm::SourceMgr &mgr,
-                           mlir::MLIRContext *context) {
-    if (kind == Kind::EmitDiagnostics)
+  DiagnosticHandlerWrapper(
+      Kind kind, llvm::SourceMgr &mgr, mlir::MLIRContext *context,
+      std::optional<mlir::SourceMgrDiagnosticVerifierHandler::Level> level =
+          {}) {
+    if (kind == Kind::EmitDiagnostics) {
       handler = new mlir::SourceMgrDiagnosticHandler(mgr, context);
-    else
-      handler = new mlir::SourceMgrDiagnosticVerifierHandler(mgr, context);
+    } else {
+      assert(level.has_value() && "expected level");
+      handler =
+          new mlir::SourceMgrDiagnosticVerifierHandler(mgr, context, *level);
+    }
   }
 
   /// This object is non-copyable but movable.
@@ -150,7 +164,9 @@ class TransformSourceMgr {
 public:
   /// Constructs the source manager indicating whether diagnostic messages will
   /// be verified later on.
-  explicit TransformSourceMgr(bool verifyDiagnostics)
+  explicit TransformSourceMgr(
+      std::optional<mlir::SourceMgrDiagnosticVerifierHandler::Level>
+          verifyDiagnostics)
       : verifyDiagnostics(verifyDiagnostics) {}
 
   /// Deconstructs the source manager. Note that `checkResults` must have been
@@ -179,7 +195,8 @@ class TransformSourceMgr {
     // verification needs to happen and store it.
     if (verifyDiagnostics) {
       diagHandlers.emplace_back(
-          DiagnosticHandlerWrapper::Kind::VerifyDiagnostics, mgr, &context);
+          DiagnosticHandlerWrapper::Kind::VerifyDiagnostics, mgr, &context,
+          verifyDiagnostics);
     } else {
       diagHandlers.emplace_back(DiagnosticHandlerWrapper::Kind::EmitDiagnostics,
                                 mgr, &context);
@@ -204,7 +221,8 @@ class TransformSourceMgr {
 
 private:
   /// Indicates whether diagnostic message verification is requested.
-  const bool verifyDiagnostics;
+  const std::optional<mlir::SourceMgrDiagnosticVerifierHandler::Level>
+      verifyDiagnostics;
 
   /// Indicates that diagnostic message verification has taken place, and the
   /// deconstruction is therefore safe.
diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h
index 36c433c63b26d..77fa022146759 100644
--- a/mlir/include/mlir/IR/Diagnostics.h
+++ b/mlir/include/mlir/IR/Diagnostics.h
@@ -626,9 +626,12 @@ struct SourceMgrDiagnosticVerifierHandlerImpl;
 /// corresponding line of the source file.
 class SourceMgrDiagnosticVerifierHandler : public SourceMgrDiagnosticHandler {
 public:
+  enum class Level { None = 0, All, OnlyExpected };
   SourceMgrDiagnosticVerifierHandler(llvm::SourceMgr &srcMgr, MLIRContext *ctx,
-                                     raw_ostream &out);
-  SourceMgrDiagnosticVerifierHandler(llvm::SourceMgr &srcMgr, MLIRContext *ctx);
+                                     raw_ostream &out,
+                                     Level level = Level::All);
+  SourceMgrDiagnosticVerifierHandler(llvm::SourceMgr &srcMgr, MLIRContext *ctx,
+                                     Level level = Level::All);
   ~SourceMgrDiagnosticVerifierHandler();
 
   /// Returns the status of the handler and verifies that all expected
@@ -644,6 +647,8 @@ class SourceMgrDiagnosticVerifierHandler : public SourceMgrDiagnosticHandler {
   void process(FileLineColLoc loc, StringRef msg, DiagnosticSeverity kind);
 
   std::unique_ptr<detail::SourceMgrDiagnosticVerifierHandlerImpl> impl;
+
+  Level level = Level::All;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index 09bd86b9581df..ab9395791f179 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -185,11 +185,20 @@ class MlirOptMainConfig {
 
   /// Set whether to check that emitted diagnostics match `expected-*` lines on
   /// the corresponding line. This is meant for implementing diagnostic tests.
-  MlirOptMainConfig &verifyDiagnostics(bool verify) {
+  MlirOptMainConfig &
+  verifyDiagnostics(SourceMgrDiagnosticVerifierHandler::Level verify) {
     verifyDiagnosticsFlag = verify;
     return *this;
   }
-  bool shouldVerifyDiagnostics() const { return verifyDiagnosticsFlag; }
+
+  bool shouldVerifyDiagnostics() const {
+    return verifyDiagnosticsFlag !=
+           SourceMgrDiagnosticVerifierHandler::Level::None;
+  }
+
+  SourceMgrDiagnosticVerifierHandler::Level verifyDiagnosticsLevel() const {
+    return verifyDiagnosticsFlag;
+  }
 
   /// Set whether to run the verifier after each transformation pass.
   MlirOptMainConfig &verifyPasses(bool verify) {
@@ -279,7 +288,8 @@ class MlirOptMainConfig {
 
   /// Set whether to check that emitted diagnostics match `expected-*` lines on
   /// the corresponding line. This is meant for implementing diagnostic tests.
-  bool verifyDiagnosticsFlag = false;
+  SourceMgrDiagnosticVerifierHandler::Level verifyDiagnosticsFlag =
+      SourceMgrDiagnosticVerifierHandler::Level::None;
 
   /// Run the verifier after each transformation pass.
   bool verifyPassesFlag = true;
diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
index 19b32120f5890..4ac4bf4a1b37c 100644
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -795,9 +795,9 @@ SourceMgrDiagnosticVerifierHandlerImpl::computeExpectedDiags(
 }
 
 SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
-    llvm::SourceMgr &srcMgr, MLIRContext *ctx, raw_ostream &out)
+    llvm::SourceMgr &srcMgr, MLIRContext *ctx, raw_ostream &out, Level level)
     : SourceMgrDiagnosticHandler(srcMgr, ctx, out),
-      impl(new SourceMgrDiagnosticVerifierHandlerImpl()) {
+      impl(new SourceMgrDiagnosticVerifierHandlerImpl()), level(level) {
   // Compute the expected diagnostics for each of the current files in the
   // source manager.
   for (unsigned i = 0, e = mgr.getNumBuffers(); i != e; ++i)
@@ -815,8 +815,8 @@ SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
 }
 
 SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
-    llvm::SourceMgr &srcMgr, MLIRContext *ctx)
-    : SourceMgrDiagnosticVerifierHandler(srcMgr, ctx, llvm::errs()) {}
+    llvm::SourceMgr &srcMgr, MLIRContext *ctx, Level level)
+    : SourceMgrDiagnosticVerifierHandler(srcMgr, ctx, llvm::errs(), level) {}
 
 SourceMgrDiagnosticVerifierHandler::~SourceMgrDiagnosticVerifierHandler() {
   // Ensure that all expected diagnostics were handled.
@@ -886,6 +886,9 @@ void SourceMgrDiagnosticVerifierHandler::process(FileLineColLoc loc,
     }
   }
 
+  if (level == Level::OnlyExpected)
+    return;
+
   // Otherwise, emit an error for the near miss.
   if (nearMiss)
     mgr.PrintMessage(os, nearMiss->fileLoc, llvm::SourceMgr::DK_Error,
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 9bbf91de18305..2774dfe7e69a1 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -168,11 +168,26 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
         cl::desc("Split marker to use for merging the ouput"),
         cl::location(outputSplitMarkerFlag), cl::init(kDefaultSplitMarker));
 
-    static cl::opt<bool, /*ExternalStorage=*/true> verifyDiagnostics(
-        "verify-diagnostics",
-        cl::desc("Check that emitted diagnostics match "
-                 "expected-* lines on the corresponding line"),
-        cl::location(verifyDiagnosticsFlag), cl::init(false));
+    static cl::opt<SourceMgrDiagnosticVerifierHandler::Level,
+                   /*ExternalStorage=*/true>
+        verifyDiagnostics{
+            "verify-diagnostics", llvm::cl::ValueOptional,
+            cl::desc("Check that emitted diagnostics match "
+                     "expected-* lines on the corresponding line"),
+            cl::location(verifyDiagnosticsFlag),
+            cl::values(
+                clEnumValN(SourceMgrDiagnosticVerifierHandler::Level::All,
+                           "all",
+                           "Check all diagnostics (expected, unexpected, "
+                           "near-misses)"),
+                // Implicit value: when passed with no arguments, e.g.
+                // `--verify-diagnostics` or `--verify-diagnostics=`.
+                clEnumValN(SourceMgrDiagnosticVerifierHandler::Level::All, "",
+                           "Check all diagnostics (expected, unexpected, "
+                           "near-misses)"),
+                clEnumValN(
+                    SourceMgrDiagnosticVerifierHandler::Level::OnlyExpected,
+                    "only-expected", "Check only expected diagnostics"))};
 
     static cl::opt<bool, /*ExternalStorage=*/true> verifyPasses(
         "verify-each",
@@ -542,7 +557,8 @@ static LogicalResult processBuffer(raw_ostream &os,
     return performActions(os, sourceMgr, &context, config);
   }
 
-  SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr, &context);
+  SourceMgrDiagnosticVerifierHandler sourceMgrHandler(
+      *sourceMgr, &context, config.verifyDiagnosticsLevel());
 
   // Do any processing requested by command line flags.  We don't care whether
   // these actions succeed or fail, we only care what diagnostics they produce
diff --git a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
index 56773f599d5ce..ceccd1ea69e96 100644
--- a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
+++ b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
@@ -58,7 +58,8 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
 
   static llvm::cl::opt<bool> allowUnregisteredDialects(
       "allow-unregistered-dialect",
-      llvm::cl::desc("Allow operation with no registered dialects (discouraged: testing only!)"),
+      llvm::cl::desc("Allow operation with no registered dialects "
+                     "(discouraged: testing only!)"),
       llvm::cl::init(false));
 
   static llvm::cl::opt<std::string> inputSplitMarker{
@@ -72,11 +73,21 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
                      "default marker and process each chunk independently"),
       llvm::cl::init("")};
 
-  static llvm::cl::opt<bool> verifyDiagnostics(
-      "verify-diagnostics",
-      llvm::cl::desc("Check that emitted diagnostics match "
-                     "expected-* lines on the corresponding line"),
-      llvm::cl::init(false));
+  static llvm::cl::opt<SourceMgrDiagnosticVerifierHandler::Level>
+      verifyDiagnostics{
+          "verify-diagnostics", llvm::cl::ValueOptional,
+          llvm::cl::desc("Check that emitted diagnostics match "
+                         "expected-* lines on the corresponding line"),
+          llvm::cl::values(
+              clEnumValN(
+                  SourceMgrDiagnosticVerifierHandler::Level::All, "all",
+                  "Check all diagnostics (expected, unexpected, near-misses)"),
+              clEnumValN(
+                  SourceMgrDiagnosticVerifierHandler::Level::All, "",
+                  "Check all diagnostics (expected, unexpected, near-misses)"),
+              clEnumValN(
+                  SourceMgrDiagnosticVerifierHandler::Level::OnlyExpected,
+                  "only-expected", "Check only expected diagnostics"))};
 
   static llvm::cl::opt<bool> errorDiagnosticsOnly(
       "error-diagnostics-only",
@@ -149,17 +160,17 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
 
       MLIRContext context;
       context.allowUnregisteredDialects(allowUnregisteredDialects);
-      context.printOpOnDiagnostic(!verifyDiagnostics);
+      context.printOpOnDiagnostic(!bool(verifyDiagnostics.getNumOccurrences()));
       auto sourceMgr = std::make_shared<llvm::SourceMgr>();
       sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
 
-      if (verifyDiagnostics) {
+      if (verifyDiagnostics.getNumOccurrences()) {
         // In the diagnostic verification flow, we ignore whether the
         // translation failed (in most cases, it is expected to fail) and we do
         // not filter non-error diagnostics even if `errorDiagnosticsOnly` is
         // set. Instead, we check if the diagnostics were produced as expected.
-        SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr,
-                                                            &context);
+        SourceMgrDiagnosticVerifierHandler sourceMgrHandler(
+            *sourceMgr, &context, verifyDiagnostics);
         (void)(*translationRequested)(sourceMgr, os, &context);
         result = sourceMgrHandler.verify();
       } else if (errorDiagnosticsOnly) {
diff --git a/mlir/test/Pass/full_diagnostics.mlir b/mlir/test/Pass/full_diagnostics.mlir
index 881260ce60d32..c355fd071bdb6 100644
--- a/mlir/test/Pass/full_diagnostics.mlir
+++ b/mlir/test/Pass/full_diagnostics.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-pass-failure{gen-diagnostics}))' -verify-diagnostics
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-pass-failure{gen-diagnostics}))' -verify-diagnostics=all
 
 // Test that all errors are reported.
 // expected-error at below {{illegal operation}}
diff --git a/mlir/test/Pass/full_diagnostics_only_expected.mlir b/mlir/test/Pass/full_diagnostics_only_expected.mlir
new file mode 100644
index 0000000000000..de29693604b4f
--- /dev/null
+++ b/mlir/test/Pass/full_diagnostics_only_expected.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-pass-failure{gen-diagnostics}))' -verify-diagnostics=only-expected
+
+// Test that only expected errors are reported.
+// reports {{illegal operation}} but unchecked
+func.func @TestAlwaysIllegalOperationPass1() {
+  return
+}
+
+// expected-error at +1 {{illegal operation}}
+func.func @TestAlwaysIllegalOperationPass2() {
+  return
+}
+
+// reports {{illegal operation}} but unchecked
+func.func @TestAlwaysIllegalOperationPass3() {
+  return
+}
diff --git a/mlir/test/mlir-translate/verify-only-expected.mlir b/mlir/test/mlir-translate/verify-only-expected.mlir
new file mode 100644
index 0000000000000..543a6954c840f
--- /dev/null
+++ b/mlir/test/mlir-translate/verify-only-expected.mlir
@@ -0,0 +1,49 @@
+// Note: borrowed/copied from mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
+
+// Check that verify-diagnostics=only-expected passes with only one actual `expected-error`
+// RUN: mlir-translate %s -verify-diagnostics=only-expected -split-input-file -mlir-to-llvmir
+
+// Check that verify-diagnostics=all fails because we're missing three `expected-error`
+// RUN: not mlir-translate %s -verify-diagnostics=all -split-input-file -mlir-to-llvmir 2>&1 | FileCheck %s --check-prefix=CHECK-VERIFY-ALL
+// CHECK-VERIFY-ALL:      error: unexpected error: 'arm_sme.intr.write.horiz' op failed to verify that all of {predicate, vector} have same shape
+// CHECK-VERIFY-ALL-NEXT: "arm_sme.intr.write.horiz"
+// CHECK-VERIFY-ALL:      error: unexpected error: 'arm_sme.intr.read.horiz' op failed to verify that all of {vector, res} have same element type
+// CHECK-VERIFY-ALL-NEXT: %res = "arm_sme.intr.read.horiz"
+// CHECK-VERIFY-ALL:      error: unexpected error: 'arm_sme.intr.cntsb' op failed to verify that `res` is i64
+// CHECK-VERIFY-ALL-NEXT: %res = "arm_sme.intr.cntsb"
+
+llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32,
+                                                %nxv4i1 : vector<[4]xi1>,
+                                                %nxv16i8 : vector<[16]xi8>) {
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv4i1, %nxv16i8) <{tile_id = 0 : i32}> :
+      (i32, vector<[4]xi1>, vector<[16]xi8>) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @arm_sme_tile_slice_to_vector_invalid_shapes(
+  %tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv16i8 : vector<[16]xi8>
+) -> vector<[3]xf32> {
+  // expected-error @+1 {{failed to verify that all of {vector, predicate, res} have same shape}}
+  %res = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv4i1, %tileslice) <{tile_id = 0 : i32}> :
+      (vector<[16]xi8>, vector<[4]xi1>, i32) -> vector<[3]xf32>
+  llvm.return %res : vector<[3]xf32>
+}
+
+// -----
+
+llvm.func @arm_sme_tile_slice_to_vector_invalid_element_types(
+  %tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv4f32 : vector<[4]xf32>
+) -> vector<[3]xi32> {
+  %res = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tileslice) <{tile_id = 0 : i32}> :
+      (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
+  llvm.return %res : vector<[4]xi32>
+}
+
+// -----
+
+llvm.func @arm_sme_streaming_vl_invalid_return_type() -> i32 {
+  %res = "arm_sme.intr.cntsb"() : () -> i32
+  llvm.return %res : i32
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/135131


More information about the Mlir-commits mailing list