[Mlir-commits] [mlir] Strengthen bytecode roundtrip test (PR #82946)
Matteo Franciolini
llvmlistbot at llvm.org
Sun Feb 25 16:48:40 PST 2024
https://github.com/mfrancio created https://github.com/llvm/llvm-project/pull/82946
This patch implements a mechanism to fail a bytecode roundtrip test only
if the corresponding textual roundtrip succeeds. Currently, the test
assumes that the printers/parsers for the textual format are correctly
implemented, which is not the case for some of the upstream dialects
(for example, SPIR-V).
In addition, it fixes the bytecode roundtrip of the
`IR/affine-map.mlir` test, which relies on the parser to simplify some
of the affine maps. However, this simplification is not complete once
parsing finishes: some expressions are further uniqued after an
additional roundtrip through a file, in either text or bytecode
format.
From 39f3a099d458b3341f2961ddd0c0ef0a102ed3b0 Mon Sep 17 00:00:00 2001
From: Matteo Franciolini <mfranciolini at tesla.com>
Date: Sun, 25 Feb 2024 16:19:47 -0800
Subject: [PATCH] Strengthen bytecode roundtrip test
This patch implements a mechanism to fail a bytecode roundtrip test only
if the corresponding textual roundtrip succeeds. Currently, the test
assumes that the printers/parsers for the textual format are correctly
implemented, which is not the case for some of the upstream dialects
(for example, SPIR-V).
In addition, it fixes the bytecode roundtrip of the
`IR/affine-map.mlir` test, which relies on the parser to simplify some
of the affine maps. However, this simplification is not complete once
parsing finishes: some expressions are further uniqued after an
additional roundtrip through a file, in either text or bytecode
format.
---
mlir/lib/Tools/mlir-opt/MlirOptMain.cpp | 83 +++++++++++++++++--------
1 file changed, 57 insertions(+), 26 deletions(-)
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index f01c7631decb77..36fe09d20b6e58 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -27,6 +27,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OwningOpRef.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
@@ -266,18 +267,62 @@ static LogicalResult doVerifyRoundTrip(Operation *op,
const MlirOptMainConfig &config,
bool useBytecode) {
// We use a new context to avoid resource handle renaming issue in the diff.
+ auto initializeNewContext = [&](MLIRContext &newContext) -> LogicalResult {
+ OwningOpRef<Operation *> roundtripModule;
+ newContext.appendDialectRegistry(op->getContext()->getDialectRegistry());
+ if (op->getContext()->allowsUnregisteredDialects())
+ newContext.allowUnregisteredDialects();
+ StringRef irdlFile = config.getIrdlFile();
+ if (!irdlFile.empty() && failed(loadIRDLDialects(irdlFile, newContext)))
+ return failure();
+ return success();
+ };
MLIRContext roundtripContext;
- OwningOpRef<Operation *> roundtripModule;
- roundtripContext.appendDialectRegistry(
- op->getContext()->getDialectRegistry());
- if (op->getContext()->allowsUnregisteredDialects())
- roundtripContext.allowUnregisteredDialects();
- StringRef irdlFile = config.getIrdlFile();
- if (!irdlFile.empty() && failed(loadIRDLDialects(irdlFile, roundtripContext)))
+ if (failed(initializeNewContext(roundtripContext)))
return failure();
+ auto parseMLIRString =
+ [&](const std::string &stringBuffer, MLIRContext &context) -> OwningOpRef<Operation *> {
+ FallbackAsmResourceMap fallbackResourceMap;
+ ParserConfig parseConfig(&context, /*verifyAfterParse=*/true,
+ &fallbackResourceMap);
+ OwningOpRef<Operation *> roundtripModule =
+ parseSourceString<Operation *>(stringBuffer, parseConfig);
+ return roundtripModule;
+ };
+
+ // Print the operation to string. If we are going to verify the roundtrip to
+ // bytecode, make sure first that a roundtrip to text of the same IR is
+ // possible.
+ std::string reference;
+ {
+ llvm::raw_string_ostream ostream(reference);
+ op->print(ostream,
+ OpPrintingFlags().printGenericOpForm().enableDebugInfo());
+ // When testing a bytecode roundtrip, we don't want to report failure if the
+ // textual roundtrip also fails.
+ if (useBytecode) {
+ MLIRContext textualContext;
+ if (failed(initializeNewContext(textualContext)))
+ return failure();
+ OwningOpRef<Operation *> textualModule =
+ parseMLIRString(ostream.str(), textualContext);
+ // If we can't parse the string back, we can't guarantee that bytecode
+ // will be parsed correctly.
+ if (!textualModule)
+ return success();
+
+ // Clear the reference, and print the textual roundtrip.
+ reference.clear();
+ llvm::raw_string_ostream ostreamref(reference);
+ textualModule->print(
+ ostreamref, OpPrintingFlags().printGenericOpForm().enableDebugInfo());
+ }
+ }
+
// Print a first time with custom format (or bytecode) and parse it back to
// the roundtripModule.
+ std::string roundtrip;
{
std::string buffer;
llvm::raw_string_ostream ostream(buffer);
@@ -291,25 +336,11 @@ static LogicalResult doVerifyRoundTrip(Operation *op,
op->print(ostream,
OpPrintingFlags().printGenericOpForm(false).enableDebugInfo());
}
- FallbackAsmResourceMap fallbackResourceMap;
- ParserConfig parseConfig(&roundtripContext, /*verifyAfterParse=*/true,
- &fallbackResourceMap);
- roundtripModule =
- parseSourceString<Operation *>(ostream.str(), parseConfig);
- if (!roundtripModule) {
- op->emitOpError()
- << "failed to parse bytecode back, cannot verify round-trip.\n";
- return failure();
- }
- }
-
- // Print in the generic form for the reference module and the round-tripped
- // one and compare the outputs.
- std::string reference, roundtrip;
- {
- llvm::raw_string_ostream ostreamref(reference);
- op->print(ostreamref,
- OpPrintingFlags().printGenericOpForm().enableDebugInfo());
+ OwningOpRef<Operation *> roundtripModule =
+ parseMLIRString(ostream.str(), roundtripContext);
+ if (!roundtripModule)
+ return op->emitOpError()
+ << "failed to parse bytecode back, cannot verify round-trip.\n";
llvm::raw_string_ostream ostreamrndtrip(roundtrip);
roundtripModule.get()->print(
ostreamrndtrip,
More information about the Mlir-commits
mailing list