[Mlir-commits] [mlir] a6d09d4 - Add a `-verify-roundtrip` option to `mlir-opt` intended to validate custom printer/parser completeness

Mehdi Amini llvmlistbot at llvm.org
Thu May 25 15:16:01 PDT 2023


Author: Mehdi Amini
Date: 2023-05-25T15:15:47-07:00
New Revision: a6d09d4b1ac224ed90ee8ff6c964a2bff39421c7

URL: https://github.com/llvm/llvm-project/commit/a6d09d4b1ac224ed90ee8ff6c964a2bff39421c7
DIFF: https://github.com/llvm/llvm-project/commit/a6d09d4b1ac224ed90ee8ff6c964a2bff39421c7.diff

LOG: Add a `-verify-roundtrip` option to `mlir-opt` intended to validate custom printer/parser completeness

Running:

  MLIR_OPT_CHECK_IR_ROUNDTRIP=1 ninja check-mlir

will now exercise all of our tests with a round-trip through bytecode and a comparison for equality.

Reviewed By: rriddle, ftynse, jpienaar

Differential Revision: https://reviews.llvm.org/D90088

Added: 
    

Modified: 
    mlir/include/mlir/IR/OperationSupport.h
    mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
    mlir/test/lit.cfg.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 3631e41c1234d..f2bae58c80b3a 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1103,7 +1103,7 @@ class OpPrintingFlags {
   OpPrintingFlags &enableDebugInfo(bool enable = true, bool prettyForm = false);
 
   /// Always print operations in the generic form.
-  OpPrintingFlags &printGenericOpForm();
+  OpPrintingFlags &printGenericOpForm(bool enable = true);
 
   /// Skip printing regions.
   OpPrintingFlags &skipRegions(bool skip = true);

diff  --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index 8f1969cd3ad36..222a51e8db77e 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -163,6 +163,13 @@ class MlirOptMainConfig {
   }
   bool shouldVerifyPasses() const { return verifyPassesFlag; }
 
+  /// Set whether to round-trip the parsed IR and check it prints identically.
+  MlirOptMainConfig &verifyRoundtrip(bool verify) {
+    verifyRoundtripFlag = verify;
+    return *this;
+  }
+  bool shouldVerifyRoundtrip() const { return verifyRoundtripFlag; }
+
 protected:
   /// Allow operation with no registered dialects.
   /// This option is for convenience during testing only and discouraged in
@@ -212,6 +219,9 @@ class MlirOptMainConfig {
 
   /// Run the verifier after each transformation pass.
   bool verifyPassesFlag = true;
+
+  /// Round-trip the parsed input IR and check that it is reproduced exactly.
+  bool verifyRoundtripFlag = false;
 };
 
 /// This defines the function type used to setup the pass manager. This can be

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 206a097a02802..a75bc584ef3a5 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -219,8 +219,8 @@ OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool enable,
 }
 
 /// Always print operations in the generic form.
-OpPrintingFlags &OpPrintingFlags::printGenericOpForm() {
-  printGenericOpFormFlag = true;
+OpPrintingFlags &OpPrintingFlags::printGenericOpForm(bool enable) {
+  printGenericOpFormFlag = enable;
   return *this;
 }
 

diff  --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 2aad687e41957..644113058bdc1 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -139,6 +139,11 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
         cl::desc("Run the verifier after each transformation pass"),
         cl::location(verifyPassesFlag), cl::init(true));
 
+    static cl::opt<bool, /*ExternalStorage=*/true> verifyRoundtrip(
+        "verify-roundtrip",
+        cl::desc("Round-trip the IR after parsing and ensure it succeeds"),
+        cl::location(verifyRoundtripFlag), cl::init(false));
+
     static cl::list<std::string> passPlugins(
         "load-pass-plugin", cl::desc("Load passes from plugin library"));
     /// Set the callback to load a pass plugin.
@@ -213,6 +218,104 @@ void MlirOptMainConfigCLOptions::setDialectPluginsCallback(
   });
 }
 
+/// Register the IRDL dialect with `ctx`, parse `irdlFile`, and load the
+/// dialects it defines. Open/parse errors are reported as diagnostics.
+LogicalResult loadIRDLDialects(StringRef irdlFile, MLIRContext &ctx) {
+  DialectRegistry registry;
+  registry.insert<irdl::IRDLDialect>();
+  ctx.appendDialectRegistry(registry);
+
+  // Set up the input file.
+  std::string errorMessage;
+  std::unique_ptr<MemoryBuffer> file = openInputFile(irdlFile, &errorMessage);
+  if (!file) {
+    emitError(UnknownLoc::get(&ctx)) << errorMessage;
+    return failure();
+  }
+
+  // Give the buffer to the source manager, where the parser picks it up.
+  SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
+  SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &ctx);
+
+  // Parse the input file; never hand a null module to the IRDL loader.
+  OwningOpRef<ModuleOp> module(parseSourceFile<ModuleOp>(sourceMgr, &ctx));
+  if (!module)
+    return failure();
+  return irdl::loadDialects(module.get());
+}
+
+// Return success if the module can correctly round-trip. This is intended to
+// test that the custom printers/parsers are complete.
+static LogicalResult doVerifyRoundTrip(Operation *op,
+                                       const MlirOptMainConfig &config,
+                                       bool useBytecode) {
+  // We use a new context to avoid resource handle renaming issue in the diff.
+  MLIRContext roundtripContext;
+  OwningOpRef<Operation *> roundtripModule;
+  roundtripContext.appendDialectRegistry(
+      op->getContext()->getDialectRegistry());
+  if (op->getContext()->allowsUnregisteredDialects())
+    roundtripContext.allowUnregisteredDialects();
+  StringRef irdlFile = config.getIrdlFile();
+  if (!irdlFile.empty() && failed(loadIRDLDialects(irdlFile, roundtripContext)))
+    return failure();
+
+  // Print a first time with custom format (or bytecode) and parse it back to
+  // the roundtripModule. Diagnostics name the encoding that was actually used.
+  {
+    StringRef encoding = useBytecode ? "bytecode" : "textual IR";
+    std::string buffer;
+    llvm::raw_string_ostream ostream(buffer);
+    if (useBytecode) {
+      if (failed(writeBytecodeToFile(op, ostream)))
+        return op->emitOpError()
+               << "failed to write bytecode, cannot verify round-trip";
+    } else {
+      op->print(ostream,
+                OpPrintingFlags().printGenericOpForm(false).enableDebugInfo());
+    }
+    FallbackAsmResourceMap fallbackResourceMap;
+    ParserConfig parseConfig(&roundtripContext, /*verifyAfterParse=*/true,
+                             &fallbackResourceMap);
+    roundtripModule =
+        parseSourceString<Operation *>(ostream.str(), parseConfig);
+    if (!roundtripModule)
+      return op->emitOpError() << "failed to parse " << encoding
+                               << " back, cannot verify round-trip";
+  }
+
+  // Print in the generic form for the reference module and the round-tripped
+  // one and compare the outputs.
+  std::string reference, roundtrip;
+  {
+    llvm::raw_string_ostream ostreamref(reference);
+    op->print(ostreamref,
+              OpPrintingFlags().printGenericOpForm().enableDebugInfo());
+    llvm::raw_string_ostream ostreamrndtrip(roundtrip);
+    roundtripModule.get()->print(
+        ostreamrndtrip,
+        OpPrintingFlags().printGenericOpForm().enableDebugInfo());
+  }
+  if (reference != roundtrip) {
+    // TODO implement a diff.
+    return op->emitOpError()
+           << "roundTrip testing roundtripped module differs from "
+              "reference:\n<<<<<<Reference\n"
+           << reference << "\n=====\n" << roundtrip << "\n>>>>>roundtripped\n";
+  }
+
+  return success();
+}
+
+static LogicalResult doVerifyRoundTrip(Operation *op,
+                                       const MlirOptMainConfig &config) {
+  // Only bytecode is exercised: the textual round-trip is not fully robust yet
+  // (for example, implicit terminators lose their location information).
+  constexpr bool useBytecode = true;
+  return doVerifyRoundTrip(op, config, useBytecode);
+}
+
 /// Perform the actions on the input file indicated by the command line flags
 /// within the specified context.
 ///
@@ -247,10 +350,16 @@ performActions(raw_ostream &os,
   TimingScope parserTiming = timing.nest("Parser");
   OwningOpRef<Operation *> op = parseSourceFileForTool(
       sourceMgr, parseConfig, !config.shouldUseExplicitModule());
-  context->enableMultithreading(wasThreadingEnabled);
+  parserTiming.stop();
   if (!op)
     return failure();
-  parserTiming.stop();
+
+  // Perform round-trip verification if requested
+  if (config.shouldVerifyRoundtrip() &&
+      failed(doVerifyRoundTrip(op.get(), config)))
+    return failure();
+
+  context->enableMultithreading(wasThreadingEnabled);
 
   // Prepare the pass manager, applying command-line and reproducer options.
   PassManager pm(op.get()->getName(), PassManager::Nesting::Implicit);
@@ -286,33 +395,6 @@ performActions(raw_ostream &os,
   return success();
 }
 
-LogicalResult loadIRDLDialects(StringRef irdlFile, MLIRContext &ctx) {
-  DialectRegistry registry;
-  registry.insert<irdl::IRDLDialect>();
-  ctx.appendDialectRegistry(registry);
-
-  // Set up the input file.
-  std::string errorMessage;
-  std::unique_ptr<MemoryBuffer> file = openInputFile(irdlFile, &errorMessage);
-  if (!file) {
-    emitError(UnknownLoc::get(&ctx)) << errorMessage;
-    return failure();
-  }
-
-  // Give the buffer to the source manager.
-  // This will be picked up by the parser.
-  SourceMgr sourceMgr;
-  sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
-
-  SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &ctx);
-
-  // Parse the input file.
-  OwningOpRef<ModuleOp> module(parseSourceFile<ModuleOp>(sourceMgr, &ctx));
-
-  // Load IRDL dialects.
-  return irdl::loadDialects(module.get());
-}
-
 /// Parses the memory buffer.  If successfully, run a series of passes against
 /// it and print the result.
 static LogicalResult processBuffer(raw_ostream &os,

diff  --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index 7c8bd6aea8aba..1fc2e319d19fd 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -65,7 +65,6 @@ def add_runtime(name):
 
 tool_dirs = [config.mlir_tools_dir, config.llvm_tools_dir]
 tools = [
-    'mlir-opt',
     'mlir-tblgen',
     'mlir-translate',
     'mlir-lsp-server',
@@ -125,6 +124,11 @@ def add_runtime(name):
   ToolSubst('%PYTHON', python_executable, unresolved='ignore'),
 ])
 
+if "MLIR_OPT_CHECK_IR_ROUNDTRIP" in os.environ:
+  tools.extend([
+    ToolSubst('mlir-opt', 'mlir-opt --verify-roundtrip', unresolved='fatal'),
+  ])
+
 llvm_config.add_tool_substitutions(tools, tool_dirs)
 
 


        


More information about the Mlir-commits mailing list