[Mlir-commits] [mlir] a6d09d4 - Add a `-verify-roundtrip` option to `mlir-opt` intended to validate custom printer/parser completeness
Mehdi Amini
llvmlistbot at llvm.org
Thu May 25 15:16:01 PDT 2023
Author: Mehdi Amini
Date: 2023-05-25T15:15:47-07:00
New Revision: a6d09d4b1ac224ed90ee8ff6c964a2bff39421c7
URL: https://github.com/llvm/llvm-project/commit/a6d09d4b1ac224ed90ee8ff6c964a2bff39421c7
DIFF: https://github.com/llvm/llvm-project/commit/a6d09d4b1ac224ed90ee8ff6c964a2bff39421c7.diff
LOG: Add a `-verify-roundtrip` option to `mlir-opt` intended to validate custom printer/parser completeness
Running:
MLIR_OPT_CHECK_IR_ROUNDTRIP=1 ninja check-mlir
will now exercise all of our tests with a round-trip through bytecode and a comparison for equality.
Reviewed By: rriddle, ftynse, jpienaar
Differential Revision: https://reviews.llvm.org/D90088
Added:
Modified:
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
mlir/test/lit.cfg.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 3631e41c1234d..f2bae58c80b3a 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1103,7 +1103,7 @@ class OpPrintingFlags {
OpPrintingFlags &enableDebugInfo(bool enable = true, bool prettyForm = false);
- /// Always print operations in the generic form.
- OpPrintingFlags &printGenericOpForm();
+ /// Print operations in the generic form if `enable` is true (the default).
+ OpPrintingFlags &printGenericOpForm(bool enable = true);
/// Skip printing regions.
OpPrintingFlags &skipRegions(bool skip = true);
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index 8f1969cd3ad36..222a51e8db77e 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -163,6 +163,13 @@ class MlirOptMainConfig {
}
bool shouldVerifyPasses() const { return verifyPassesFlag; }
+ /// Set whether to verify that the input IR round-trips after parsing.
+ MlirOptMainConfig &verifyRoundtrip(bool verify) {
+ verifyRoundtripFlag = verify;
+ return *this;
+ }
+ bool shouldVerifyRoundtrip() const { return verifyRoundtripFlag; }
+
protected:
/// Allow operation with no registered dialects.
/// This option is for convenience during testing only and discouraged in
@@ -212,6 +219,9 @@ class MlirOptMainConfig {
/// Run the verifier after each transformation pass.
bool verifyPassesFlag = true;
+
+ /// Verify that the input IR round-trips perfectly.
+ bool verifyRoundtripFlag = false;
};
/// This defines the function type used to setup the pass manager. This can be
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 206a097a02802..a75bc584ef3a5 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -219,8 +219,8 @@ OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool enable,
}
-/// Always print operations in the generic form.
-OpPrintingFlags &OpPrintingFlags::printGenericOpForm() {
- printGenericOpFormFlag = true;
+/// Print operations in the generic form if `enable` is true.
+OpPrintingFlags &OpPrintingFlags::printGenericOpForm(bool enable) {
+ printGenericOpFormFlag = enable;
return *this;
}
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 2aad687e41957..644113058bdc1 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -139,6 +139,11 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
cl::desc("Run the verifier after each transformation pass"),
cl::location(verifyPassesFlag), cl::init(true));
+ // Stored in verifyRoundtripFlag (read back via shouldVerifyRoundtrip());
+ // defaults to off.
+ static cl::opt<bool, /*ExternalStorage=*/true> verifyRoundtrip(
+ "verify-roundtrip",
+ cl::desc("Round-trip the IR after parsing and ensure it succeeds"),
+ cl::location(verifyRoundtripFlag), cl::init(false));
+
static cl::list<std::string> passPlugins(
"load-pass-plugin", cl::desc("Load passes from plugin library"));
/// Set the callback to load a pass plugin.
@@ -213,6 +218,104 @@ void MlirOptMainConfigCLOptions::setDialectPluginsCallback(
});
}
+/// Parse the IRDL file `irdlFile` in `ctx` and load the dialects it defines
+/// (via irdl::loadDialects). The IRDL dialect itself is made available in `ctx`
+/// first. Returns failure if the file cannot be opened or parsed, or if
+/// loading the dialects fails.
+LogicalResult loadIRDLDialects(StringRef irdlFile, MLIRContext &ctx) {
+ DialectRegistry registry;
+ registry.insert<irdl::IRDLDialect>();
+ ctx.appendDialectRegistry(registry);
+
+ // Set up the input file.
+ std::string errorMessage;
+ std::unique_ptr<MemoryBuffer> file = openInputFile(irdlFile, &errorMessage);
+ if (!file) {
+ emitError(UnknownLoc::get(&ctx)) << errorMessage;
+ return failure();
+ }
+
+ // Give the buffer to the source manager.
+ // This will be picked up by the parser.
+ SourceMgr sourceMgr;
+ sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
+
+ SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &ctx);
+
+ // Parse the input file. The parser reports its own diagnostics; bail out
+ // rather than handing a null module to irdl::loadDialects.
+ OwningOpRef<ModuleOp> module(parseSourceFile<ModuleOp>(sourceMgr, &ctx));
+ if (!module)
+ return failure();
+
+ // Load IRDL dialects.
+ return irdl::loadDialects(module.get());
+}
+
+// Return success if the module can correctly round-trip. This is intended to
+// test that the custom printers/parsers are complete.
+static LogicalResult doVerifyRoundTrip(Operation *op,
+ const MlirOptMainConfig &config,
+ bool useBytecode) {
+ // We use a new context to avoid resource handle renaming issue in the diff.
+ MLIRContext roundtripContext;
+ OwningOpRef<Operation *> roundtripModule;
+ roundtripContext.appendDialectRegistry(
+ op->getContext()->getDialectRegistry());
+ if (op->getContext()->allowsUnregisteredDialects())
+ roundtripContext.allowUnregisteredDialects();
+ StringRef irdlFile = config.getIrdlFile();
+ if (!irdlFile.empty() && failed(loadIRDLDialects(irdlFile, roundtripContext)))
+ return failure();
+
+ // Print a first time with custom format (or bytecode) and parse it back to
+ // the roundtripModule.
+ {
+ std::string buffer;
+ llvm::raw_string_ostream ostream(buffer);
+ if (useBytecode) {
+ if (failed(writeBytecodeToFile(op, ostream))) {
+ op->emitOpError() << "failed to write bytecode, cannot verify round-trip.\n";
+ return failure();
+ }
+ } else {
+ op->print(ostream,
+ OpPrintingFlags().printGenericOpForm(false).enableDebugInfo());
+ }
+ FallbackAsmResourceMap fallbackResourceMap;
+ ParserConfig parseConfig(&roundtripContext, /*verifyAfterParse=*/true,
+ &fallbackResourceMap);
+ roundtripModule =
+ parseSourceString<Operation *>(ostream.str(), parseConfig);
+ if (!roundtripModule) {
+ // Name the format that was actually emitted so a textual failure is not
+ // misreported as a bytecode failure.
+ op->emitOpError() << "failed to parse "
+ << (useBytecode ? "bytecode" : "textual IR")
+ << " back, cannot verify round-trip.\n";
+ return failure();
+ }
+ }
+
+ // Print in the generic form for the reference module and the round-tripped
+ // one and compare the outputs.
+ std::string reference, roundtrip;
+ {
+ llvm::raw_string_ostream ostreamref(reference);
+ op->print(ostreamref,
+ OpPrintingFlags().printGenericOpForm().enableDebugInfo());
+ llvm::raw_string_ostream ostreamrndtrip(roundtrip);
+ roundtripModule.get()->print(
+ ostreamrndtrip,
+ OpPrintingFlags().printGenericOpForm().enableDebugInfo());
+ }
+ if (reference != roundtrip) {
+ // TODO: implement a diff.
+ return op->emitOpError() << "roundTrip testing roundtripped module differs from reference:\n<<<<<<Reference\n"
+ << reference << "\n=====\n"
+ << roundtrip << "\n>>>>>roundtripped\n";
+ }
+
+ return success();
+}
+
+// Round-trip entry point used by performActions: always checks the bytecode
+// round-trip (see the useBytecode argument below).
+static LogicalResult doVerifyRoundTrip(Operation *op,
+ const MlirOptMainConfig &config) {
+ // Textual round-trip isn't fully robust at the moment (for example, implicit
+ // terminators lose their location information), so only bytecode is used.
+
+ return doVerifyRoundTrip(op, config, /*useBytecode=*/true);
+}
+
/// Perform the actions on the input file indicated by the command line flags
/// within the specified context.
///
@@ -247,10 +350,16 @@ performActions(raw_ostream &os,
TimingScope parserTiming = timing.nest("Parser");
OwningOpRef<Operation *> op = parseSourceFileForTool(
sourceMgr, parseConfig, !config.shouldUseExplicitModule());
context->enableMultithreading(wasThreadingEnabled);
+ parserTiming.stop();
if (!op)
return failure();
- parserTiming.stop();
+
+ // Perform round-trip verification if requested. wasThreadingEnabled was
+ // re-applied right after parsing, so no early return skips restoring it.
+ if (config.shouldVerifyRoundtrip() &&
+ failed(doVerifyRoundTrip(op.get(), config)))
+ return failure();
// Prepare the pass manager, applying command-line and reproducer options.
PassManager pm(op.get()->getName(), PassManager::Nesting::Implicit);
@@ -286,33 +395,6 @@ performActions(raw_ostream &os,
return success();
}
-LogicalResult loadIRDLDialects(StringRef irdlFile, MLIRContext &ctx) {
- DialectRegistry registry;
- registry.insert<irdl::IRDLDialect>();
- ctx.appendDialectRegistry(registry);
-
- // Set up the input file.
- std::string errorMessage;
- std::unique_ptr<MemoryBuffer> file = openInputFile(irdlFile, &errorMessage);
- if (!file) {
- emitError(UnknownLoc::get(&ctx)) << errorMessage;
- return failure();
- }
-
- // Give the buffer to the source manager.
- // This will be picked up by the parser.
- SourceMgr sourceMgr;
- sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
-
- SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &ctx);
-
- // Parse the input file.
- OwningOpRef<ModuleOp> module(parseSourceFile<ModuleOp>(sourceMgr, &ctx));
-
- // Load IRDL dialects.
- return irdl::loadDialects(module.get());
-}
-
/// Parses the memory buffer. If successfully, run a series of passes against
/// it and print the result.
static LogicalResult processBuffer(raw_ostream &os,
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index 7c8bd6aea8aba..1fc2e319d19fd 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -65,7 +65,6 @@ def add_runtime(name):
tool_dirs = [config.mlir_tools_dir, config.llvm_tools_dir]
tools = [
- 'mlir-opt',
'mlir-tblgen',
'mlir-translate',
'mlir-lsp-server',
@@ -125,6 +124,11 @@ def add_runtime(name):
ToolSubst('%PYTHON', python_executable, unresolved='ignore'),
])
+# 'mlir-opt' was removed from `tools` above, so register it exactly once here:
+# with --verify-roundtrip when MLIR_OPT_CHECK_IR_ROUNDTRIP is set, and as a
+# plain tool otherwise.
+if "MLIR_OPT_CHECK_IR_ROUNDTRIP" in os.environ:
+ tools.extend([
+ ToolSubst('mlir-opt', 'mlir-opt --verify-roundtrip', unresolved='fatal'),
+ ])
+else:
+ tools.extend(['mlir-opt'])
+
llvm_config.add_tool_substitutions(tools, tool_dirs)
More information about the Mlir-commits
mailing list