[Mlir-commits] [mlir] [mlir] Retain original identifier names for debugging v2 (PR #119944)
Maksim Levental
llvmlistbot at llvm.org
Sat Dec 14 12:54:17 PST 2024
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/119944
>From f77018548f0924c4561664966cc9bd644a7a125b Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Fri, 13 Dec 2024 23:39:19 -0500
Subject: [PATCH] [mlir] retain identifier names
---
mlir/include/mlir/IR/AsmState.h | 9 +-
mlir/include/mlir/IR/OperationSupport.h | 7 ++
mlir/include/mlir/IR/Value.h | 18 ++-
.../include/mlir/Tools/mlir-opt/MlirOptMain.h | 9 ++
mlir/lib/AsmParser/Parser.cpp | 103 ++++++++++++++----
mlir/lib/IR/AsmPrinter.cpp | 60 ++++++++--
mlir/lib/IR/Operation.cpp | 9 +-
mlir/lib/Tools/mlir-opt/MlirOptMain.cpp | 8 +-
8 files changed, 183 insertions(+), 40 deletions(-)
diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index edbd3bb6fc15db..e13a9324b1f669 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -471,8 +471,10 @@ class ParserConfig {
/// `fallbackResourceMap` is an optional fallback handler that can be used to
/// parse external resources not explicitly handled by another parser.
ParserConfig(MLIRContext *context, bool verifyAfterParse = true,
- FallbackAsmResourceMap *fallbackResourceMap = nullptr)
+ FallbackAsmResourceMap *fallbackResourceMap = nullptr,
+ bool retainIdentifierNames = false)
: context(context), verifyAfterParse(verifyAfterParse),
+ retainIdentifierNames(retainIdentifierNames),
fallbackResourceMap(fallbackResourceMap) {
assert(context && "expected valid MLIR context");
}
@@ -483,6 +485,10 @@ class ParserConfig {
/// Returns if the parser should verify the IR after parsing.
bool shouldVerifyAfterParse() const { return verifyAfterParse; }
+ /// Returns if the parser should retain the identifier names it reads, by
+ /// recording them as NameLoc locations on the parsed values.
+ bool shouldRetainIdentifierNames() const { return retainIdentifierNames; }
+
/// Returns the parsing configurations associated to the bytecode read.
BytecodeReaderConfig &getBytecodeReaderConfig() const {
return const_cast<BytecodeReaderConfig &>(bytecodeReaderConfig);
@@ -520,6 +526,7 @@ class ParserConfig {
private:
MLIRContext *context;
bool verifyAfterParse;
+ bool retainIdentifierNames;
DenseMap<StringRef, std::unique_ptr<AsmResourceParser>> resourceParsers;
FallbackAsmResourceMap *fallbackResourceMap;
BytecodeReaderConfig bytecodeReaderConfig;
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 1b93f3d3d04fe8..c82743acd78094 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1221,6 +1221,10 @@ class OpPrintingFlags {
/// Return if printer should use unique SSA IDs.
bool shouldPrintUniqueSSAIDs() const;
+ /// Returns if the printer should name values after the identifier names
+ /// retained by the parser (stored as NameLoc locations).
+ bool shouldPrintRetainedIdentifierNames() const;
+
private:
/// Elide large elements attributes if the number of elements is larger than
/// the upper limit.
@@ -1254,6 +1258,9 @@ class OpPrintingFlags {
/// Print unique SSA IDs for values, block arguments and naming conflicts
bool printUniqueSSAIDsFlag : 1;
+
+ /// Print the original identifier names retained by the parser.
+ bool printRetainedIdentifierNamesFlag : 1;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index a7344c64e6730d..d9335fbe7a5a7e 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -367,7 +367,8 @@ namespace detail {
/// This class provides the implementation for an operation result.
class alignas(8) OpResultImpl : public ValueImpl {
public:
- using ValueImpl::ValueImpl;
+ OpResultImpl(Type type, Kind kind, Location loc)
+ : ValueImpl(type, kind), loc(loc) {}
static bool classof(const ValueImpl *value) {
return value->getKind() != ValueImpl::Kind::BlockArgument;
@@ -390,14 +391,17 @@ class alignas(8) OpResultImpl : public ValueImpl {
static unsigned getMaxInlineResults() {
return static_cast<unsigned>(Kind::OutOfLineOpResult);
}
+
+ /// Per-result location; starts out as the owning operation's location.
+ Location loc;
};
/// This class provides the implementation for an operation result whose index
/// can be represented "inline" in the underlying ValueImpl.
struct InlineOpResult : public OpResultImpl {
public:
- InlineOpResult(Type type, unsigned resultNo)
- : OpResultImpl(type, static_cast<ValueImpl::Kind>(resultNo)) {
+ InlineOpResult(Type type, unsigned resultNo, Location loc)
+ : OpResultImpl(type, static_cast<ValueImpl::Kind>(resultNo), loc) {
assert(resultNo < getMaxInlineResults());
}
@@ -413,8 +417,8 @@ struct InlineOpResult : public OpResultImpl {
/// cannot be represented "inline", and thus requires an additional index field.
class OutOfLineOpResult : public OpResultImpl {
public:
- OutOfLineOpResult(Type type, uint64_t outOfLineIndex)
- : OpResultImpl(type, Kind::OutOfLineOpResult),
+ OutOfLineOpResult(Type type, uint64_t outOfLineIndex, Location loc)
+ : OpResultImpl(type, Kind::OutOfLineOpResult, loc),
outOfLineIndex(outOfLineIndex) {}
static bool classof(const OpResultImpl *value) {
@@ -468,6 +472,10 @@ class OpResult : public Value {
/// Returns the number of this result.
unsigned getResultNumber() const { return getImpl()->getResultNumber(); }
+ /// This result's own location. Not virtual: call it on an OpResult.
+ Location getLoc() const { return getImpl()->loc; }
+ void setLoc(Location loc) { getImpl()->loc = loc; }
+
private:
/// Get a raw pointer to the internal implementation.
detail::OpResultImpl *getImpl() const {
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index 160585e7da5486..d8cc11f13c815a 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -198,6 +198,13 @@ class MlirOptMainConfig {
}
bool shouldVerifyPasses() const { return verifyPassesFlag; }
+  /// Set whether the parser retains the original identifier names.
+  MlirOptMainConfig &retainIdentifierNames(bool retain) {
+    retainIdentifierNamesFlag = retain;
+    return *this;
+  }
+  bool shouldRetainIdentifierNames() const { return retainIdentifierNamesFlag; }
+
/// Set whether to run the verifier on parsing.
MlirOptMainConfig &verifyOnParsing(bool verify) {
disableVerifierOnParsingFlag = !verify;
@@ -284,6 +291,9 @@ class MlirOptMainConfig {
/// Run the verifier after each transformation pass.
bool verifyPassesFlag = true;
+  /// Retain the original identifier names (e.g. `%my_var`) when parsing.
+  bool retainIdentifierNamesFlag = false;
+
/// Disable the verifier on parsing.
bool disableVerifierOnParsingFlag = false;
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index e3db248164672c..d487b0bbdb31aa 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -543,6 +543,10 @@ Type Parser::codeCompleteDialectSymbol(const llvm::StringMap<Type> &aliases) {
//===----------------------------------------------------------------------===//
namespace {
+/// This is the structure of a result specifier in the assembly syntax,
+/// including the name, number of results, and location.
+using ResultRecord = std::tuple<StringRef, unsigned, SMLoc>;
+
/// This class provides support for parsing operations and regions of
/// operations.
class OperationParser : public Parser {
@@ -618,7 +622,8 @@ class OperationParser : public Parser {
ParseResult parseSuccessors(SmallVectorImpl<Block *> &destinations);
/// Parse an operation instance that is in the generic form.
- Operation *parseGenericOperation();
+ Operation *parseGenericOperation(
+ std::optional<ArrayRef<ResultRecord>> resultIDs = std::nullopt);
/// Parse different components, viz., use-info of operand(s), successor(s),
/// region(s), attribute(s) and function-type, of the generic form of an
@@ -659,10 +664,6 @@ class OperationParser : public Parser {
/// token is actually an alias, which means it must not contain a dot.
ParseResult parseLocationAlias(LocationAttr &loc);
- /// This is the structure of a result specifier in the assembly syntax,
- /// including the name, number of results, and location.
- using ResultRecord = std::tuple<StringRef, unsigned, SMLoc>;
-
/// Parse an operation instance that is in the op-defined custom form.
/// resultInfo specifies information about the "%name =" specifiers.
Operation *parseCustomOperation(ArrayRef<ResultRecord> resultIDs);
@@ -1238,7 +1239,7 @@ ParseResult OperationParser::parseOperation() {
if (nameTok.is(Token::bare_identifier) || nameTok.isKeyword())
op = parseCustomOperation(resultIDs);
else if (nameTok.is(Token::string))
- op = parseGenericOperation();
+ op = parseGenericOperation(resultIDs);
else if (nameTok.isCodeCompletionFor(Token::string))
return codeCompleteStringDialectOrOperationName(nameTok.getStringValue());
else if (nameTok.isCodeCompletion())
@@ -1344,6 +1345,44 @@ struct CleanupOpStateRegions {
}
OperationState &state;
};
+
+/// Return the name (without the leading `%`) of the result specifier in
+/// `resultIDs` that binds result `resultNo`, together with the index of that
+/// result within the specifier. Returns {"", ~0U} if no specifier covers it.
+std::pair<StringRef, unsigned> getResultName(ArrayRef<ResultRecord> resultIDs,
+ unsigned resultNo) {
+ // Scan for the resultID that contains this result number.
+ for (const auto &entry : resultIDs) {
+ if (resultNo < std::get<1>(entry)) {
+ // Don't pass on the leading %.
+ StringRef name = std::get<0>(entry).drop_front();
+ return {name, resultNo};
+ }
+ resultNo -= std::get<1>(entry);
+ }
+
+ // Invalid result number.
+ return {"", ~0U};
+}
+
+/// Return the source location of the result specifier in `resultIDs` that
+/// binds result `resultNo`, together with the index of that result within the
+/// specifier. Returns {SMLoc(), ~0U} if no specifier covers it.
+std::pair<SMLoc, unsigned> getResultLoc(ArrayRef<ResultRecord> resultIDs,
+ unsigned resultNo) {
+ // Scan for the resultID that contains this result number.
+ for (const auto &entry : resultIDs) {
+ if (resultNo < std::get<1>(entry)) {
+ SMLoc loc = std::get<2>(entry);
+ return {loc, resultNo};
+ }
+ resultNo -= std::get<1>(entry);
+ }
+
+ // Invalid result number.
+ return {SMLoc{}, ~0U};
+}
+
} // namespace
ParseResult OperationParser::parseGenericOperationAfterOpName(
@@ -1457,7 +1496,8 @@ ParseResult OperationParser::parseGenericOperationAfterOpName(
return success();
}
-Operation *OperationParser::parseGenericOperation() {
+Operation *OperationParser::parseGenericOperation(
+ std::optional<ArrayRef<ResultRecord>> maybeResultIDs) {
// Get location information for the operation.
auto srcLocation = getEncodedSourceLocation(getToken().getLoc());
@@ -1531,6 +1571,20 @@ Operation *OperationParser::parseGenericOperation() {
// Create the operation and try to parse a location for it.
Operation *op = opBuilder.create(result);
+  if (state.config.shouldRetainIdentifierNames() && maybeResultIDs) {
+    for (OpResult opResult : op->getResults()) {
+      unsigned resultNum = opResult.getResultNumber();
+      auto [resName, subIndex] = getResultName(*maybeResultIDs, resultNum);
+      // Results not covered by a `%name` specifier keep the location they
+      // were created with.
+      if (subIndex == ~0U)
+        continue;
+      Location resultLoc = getEncodedSourceLocation(
+          getResultLoc(*maybeResultIDs, resultNum).first);
+      opResult.setLoc(NameLoc::get(
+          StringAttr::get(state.config.getContext(), resName), resultLoc));
+    }
+  }
if (parseTrailingLocationSpecifier(op))
return nullptr;
@@ -1571,7 +1625,7 @@ namespace {
class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
public:
CustomOpAsmParser(
- SMLoc nameLoc, ArrayRef<OperationParser::ResultRecord> resultIDs,
+ SMLoc nameLoc, ArrayRef<ResultRecord> resultIDs,
function_ref<ParseResult(OpAsmParser &, OperationState &)> parseAssembly,
bool isIsolatedFromAbove, StringRef opName, OperationParser &parser)
: AsmParserImpl<OpAsmParser>(nameLoc, parser), resultIDs(resultIDs),
@@ -1634,18 +1688,7 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
/// getResultName(3) == {"z", 0 }
std::pair<StringRef, unsigned>
getResultName(unsigned resultNo) const override {
- // Scan for the resultID that contains this result number.
- for (const auto &entry : resultIDs) {
- if (resultNo < std::get<1>(entry)) {
- // Don't pass on the leading %.
- StringRef name = std::get<0>(entry).drop_front();
- return {name, resultNo};
- }
- resultNo -= std::get<1>(entry);
- }
-
- // Invalid result number.
- return {"", ~0U};
+ return ::getResultName(resultIDs, resultNo);
}
/// Return the number of declared SSA results. This returns 4 for the foo.op
@@ -1962,7 +2005,7 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
private:
/// Information about the result name specifiers.
- ArrayRef<OperationParser::ResultRecord> resultIDs;
+ ArrayRef<ResultRecord> resultIDs;
/// The abstract information of the operation.
function_ref<ParseResult(OpAsmParser &, OperationState &)> parseAssembly;
@@ -2093,6 +2136,22 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
// Otherwise, create the operation and try to parse a location for it.
Operation *op = opBuilder.create(opState);
+
+  if (state.config.shouldRetainIdentifierNames()) {
+    for (OpResult opResult : op->getResults()) {
+      unsigned resultNum = opResult.getResultNumber();
+      auto [resName, subIndex] = getResultName(resultIDs, resultNum);
+      // Results not covered by a `%name` specifier keep the location they
+      // were created with.
+      if (subIndex == ~0U)
+        continue;
+      Location resultLoc =
+          getEncodedSourceLocation(getResultLoc(resultIDs, resultNum).first);
+      opResult.setLoc(NameLoc::get(
+          StringAttr::get(state.config.getContext(), resName), resultLoc));
+    }
+  }
+
if (parseTrailingLocationSpecifier(op))
return nullptr;
@@ -2159,8 +2218,21 @@ OperationParser::parseTrailingLocationSpecifier(OpOrArgument opOrArgument) {
if (parseToken(Token::r_paren, "expected ')' in location"))
return failure();
-  if (auto *op = llvm::dyn_cast_if_present<Operation *>(opOrArgument))
-    op->setLoc(directLoc);
-  else
-    opOrArgument.get<BlockArgument>().setLoc(directLoc);
+  // The explicit location replaces the previous one. When identifier names
+  // are retained, a NameLoc on a result or block argument keeps its name and
+  // is re-anchored at the explicit location instead of being discarded.
+  auto withRetainedName = [&](Location oldLoc) -> Location {
+    auto nameLoc = dyn_cast<NameLoc>(oldLoc);
+    if (!nameLoc || !state.config.shouldRetainIdentifierNames())
+      return directLoc;
+    return NameLoc::get(nameLoc.getName(), directLoc);
+  };
+  if (auto *op = llvm::dyn_cast_if_present<Operation *>(opOrArgument)) {
+    op->setLoc(directLoc);
+    for (OpResult result : op->getResults())
+      result.setLoc(withRetainedName(result.getLoc()));
+  } else {
+    BlockArgument arg = opOrArgument.get<BlockArgument>();
+    arg.setLoc(withRetainedName(arg.getLoc()));
+  }
return success();
@@ -2235,6 +2307,11 @@ ParseResult OperationParser::parseRegionBody(Region &region, SMLoc startLoc,
Location loc = entryArg.sourceLoc.has_value()
? *entryArg.sourceLoc
: getEncodedSourceLocation(argInfo.location);
+ if (state.config.shouldRetainIdentifierNames()) {
+ loc = NameLoc::get(StringAttr::get(state.config.getContext(),
+ entryArg.ssaName.name.drop_front(1)),
+ loc);
+ }
BlockArgument arg = block->addArgument(entryArg.type, loc);
// Add a definition of this arg to the assembly state if provided.
@@ -2415,6 +2492,11 @@ ParseResult OperationParser::parseOptionalBlockArgList(Block *owner) {
return emitError("argument and block argument type mismatch");
} else {
auto loc = getEncodedSourceLocation(useInfo.location);
+ if (state.config.shouldRetainIdentifierNames()) {
+ loc = NameLoc::get(StringAttr::get(state.config.getContext(),
+ useInfo.name.drop_front(1)),
+ loc);
+ }
arg = owner->addArgument(type, loc);
}
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 61b90bc9b0a7bb..29b205828703cd 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -73,7 +73,8 @@ OpAsmParser::~OpAsmParser() = default;
MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
/// Parse a type list.
-/// This is out-of-line to work-around https://github.com/llvm/llvm-project/issues/62918
+/// This is out-of-line to work-around
+/// https://github.com/llvm/llvm-project/issues/62918
ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) {
return parseCommaSeparatedList(
[&]() { return parseType(result.emplace_back()); });
@@ -195,6 +196,10 @@ struct AsmPrinterOptions {
"mlir-print-unique-ssa-ids", llvm::cl::init(false),
llvm::cl::desc("Print unique SSA ID numbers for values, block arguments "
"and naming conflicts across all regions")};
+
+  llvm::cl::opt<bool> printRetainedIdentifierNames{
+      "mlir-print-retained-identifier-names", llvm::cl::init(false),
+      llvm::cl::desc("Print the retained original names of identifiers")};
};
} // namespace
@@ -212,7 +217,8 @@ OpPrintingFlags::OpPrintingFlags()
: printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
printGenericOpFormFlag(false), skipRegionsFlag(false),
assumeVerifiedFlag(false), printLocalScope(false),
- printValueUsersFlag(false), printUniqueSSAIDsFlag(false) {
+ printValueUsersFlag(false), printUniqueSSAIDsFlag(false),
+ printRetainedIdentifierNamesFlag(false) {
// Initialize based upon command line options, if they are available.
if (!clOptions.isConstructed())
return;
@@ -231,6 +237,7 @@ OpPrintingFlags::OpPrintingFlags()
skipRegionsFlag = clOptions->skipRegionsOpt;
printValueUsersFlag = clOptions->printValueUsers;
printUniqueSSAIDsFlag = clOptions->printUniqueSSAIDs;
+  printRetainedIdentifierNamesFlag = clOptions->printRetainedIdentifierNames;
}
/// Enable the elision of large elements attributes, by printing a '...'
@@ -362,6 +369,11 @@ bool OpPrintingFlags::shouldPrintUniqueSSAIDs() const {
return printUniqueSSAIDsFlag || shouldPrintGenericOpForm();
}
+/// Return if the printer should use identifier names retained by the parser.
+bool OpPrintingFlags::shouldPrintRetainedIdentifierNames() const {
+  return printRetainedIdentifierNamesFlag;
+}
+
//===----------------------------------------------------------------------===//
// NewLineCounter
//===----------------------------------------------------------------------===//
@@ -1511,7 +1523,11 @@ void SSANameState::numberValuesInRegion(Region &region) {
assert(!valueIDs.count(arg) && "arg numbered multiple times");
assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == &region &&
"arg not defined in current region");
-    setValueName(arg, name);
+    NameLoc nameLoc = dyn_cast<NameLoc>(arg.getLoc());
+    if (printerFlags.shouldPrintRetainedIdentifierNames() && nameLoc)
+      setValueName(arg, nameLoc.getName());
+    else
+      setValueName(arg, name);
};
if (!printerFlags.shouldPrintGenericOpForm()) {
@@ -1553,7 +1569,11 @@ void SSANameState::numberValuesInBlock(Block &block) {
specialNameBuffer.resize(strlen("arg"));
specialName << nextArgumentID++;
}
-    setValueName(arg, specialName.str());
+    NameLoc nameLoc = dyn_cast<NameLoc>(arg.getLoc());
+    if (printerFlags.shouldPrintRetainedIdentifierNames() && nameLoc)
+      setValueName(arg, nameLoc.getName());
+    else
+      setValueName(arg, specialName.str());
}
// Number the operations in this block.
@@ -1567,7 +1587,13 @@ void SSANameState::numberValuesInOp(Operation &op) {
auto setResultNameFn = [&](Value result, StringRef name) {
assert(!valueIDs.count(result) && "result numbered multiple times");
assert(result.getDefiningOp() == &op && "result not defined by 'op'");
-    setValueName(result, name);
+    // Go through OpResult: its getLoc() is not virtual, so calling getLoc() on
+    // the `Value` parameter would not reach the per-result location.
+    NameLoc nameLoc = dyn_cast<NameLoc>(llvm::cast<OpResult>(result).getLoc());
+    if (printerFlags.shouldPrintRetainedIdentifierNames() && nameLoc)
+      setValueName(result, nameLoc.getName());
+    else
+      setValueName(result, name);
// Record the result number for groups not anchored at 0.
if (int resultNo = llvm::cast<OpResult>(result).getResultNumber())
@@ -1608,8 +1634,18 @@ void SSANameState::numberValuesInOp(Operation &op) {
Value resultBegin = op.getResult(0);
-  // If the first result wasn't numbered, give it a default number.
-  if (valueIDs.try_emplace(resultBegin, nextValueID).second)
-    ++nextValueID;
+  // Name the results that carry a NameLoc recorded by the parser. This loop
+  // skips ops implementing OpAsmOpInterface.
+  if (printerFlags.shouldPrintRetainedIdentifierNames() &&
+      !isa<OpAsmOpInterface>(&op)) {
+    for (OpResult opResult : op.getOpResults())
+      if (auto nameLoc = dyn_cast<NameLoc>(opResult.getLoc()))
+        setResultNameFn(opResult, nameLoc.getName());
+  }
+
+  // If the first result wasn't numbered, give it a default number. Other
+  // unnamed results get no separate ID, as when names are not retained.
+  if (valueIDs.try_emplace(resultBegin, nextValueID).second)
+    ++nextValueID;
// If this operation has multiple result groups, mark it.
if (resultGroups.size() != 1) {
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index fe0fee0f8db2ce..bfd87ef790f994 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -125,10 +125,11 @@ Operation *Operation::create(Location location, OperationName name,
// Initialize the results.
auto resultTypeIt = resultTypes.begin();
for (unsigned i = 0; i < numInlineResults; ++i, ++resultTypeIt)
- new (op->getInlineOpResult(i)) detail::InlineOpResult(*resultTypeIt, i);
+ new (op->getInlineOpResult(i))
+ detail::InlineOpResult(*resultTypeIt, i, op->getLoc());
for (unsigned i = 0; i < numTrailingResults; ++i, ++resultTypeIt) {
new (op->getOutOfLineOpResult(i))
- detail::OutOfLineOpResult(*resultTypeIt, i);
+ detail::OutOfLineOpResult(*resultTypeIt, i, op->getLoc());
}
// Initialize the regions.
@@ -322,8 +323,8 @@ void Operation::setAttrs(DictionaryAttr newAttrs) {
}
void Operation::setAttrs(ArrayRef<NamedAttribute> newAttrs) {
if (getPropertiesStorageSize()) {
- // We're spliting the providing array of attributes by removing the inherentAttr
- // which will be stored in the properties.
+ // We're splitting the provided array of attributes by removing the
+ // inherent attributes, which will be stored in the properties.
SmallVector<NamedAttribute> discardableAttrs;
discardableAttrs.reserve(newAttrs.size());
for (NamedAttribute attr : newAttrs) {
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 9bbf91de183051..b90900bdd6c0b2 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -184,6 +184,11 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
cl::desc("Disable the verifier on parsing (very unsafe)"),
cl::location(disableVerifierOnParsingFlag), cl::init(false));
+ static cl::opt<bool, /*ExternalStorage=*/true> retainIdentifierNames(
+ "mlir-retain-identifier-names",
+ cl::desc("Retain the original names of identifiers when parsing"),
+ cl::location(retainIdentifierNamesFlag), cl::init(false));
+
static cl::opt<bool, /*ExternalStorage=*/true> verifyRoundtrip(
"verify-roundtrip",
cl::desc("Round-trip the IR after parsing and ensure it succeeds"),
@@ -373,7 +378,8 @@ static LogicalResult doVerifyRoundTrip(Operation *op,
}
FallbackAsmResourceMap fallbackResourceMap;
ParserConfig parseConfig(&roundtripContext, config.shouldVerifyOnParsing(),
- &fallbackResourceMap);
+ &fallbackResourceMap,
+ config.shouldRetainIdentifierNames());
roundtripModule = parseSourceString<Operation *>(buffer, parseConfig);
if (!roundtripModule) {
op->emitOpError() << "failed to parse " << testType
More information about the Mlir-commits
mailing list