[Mlir-commits] [mlir] [mlir] Retain original identifier names for debugging v2 (PR #119944)
Maksim Levental
llvmlistbot at llvm.org
Fri Dec 13 18:55:30 PST 2024
https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/119944
Based on https://github.com/llvm/llvm-project/pull/79704
>From b8505eb41d6cb595683daf69fbda299164c7b700 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Fri, 26 Jan 2024 14:34:14 +0000
Subject: [PATCH 01/11] Added single result SSA processing
---
mlir/include/mlir/IR/AsmState.h | 12 +++++++++---
mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h | 10 ++++++++++
mlir/lib/AsmParser/Parser.cpp | 17 +++++++++++++++++
mlir/lib/IR/AsmPrinter.cpp | 9 ++++++++-
mlir/lib/Tools/mlir-opt/MlirOptMain.cpp | 10 ++++++++--
5 files changed, 52 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index edbd3bb6fc15db..c030c243bac4d9 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -146,8 +146,7 @@ class AsmResourceBlob {
/// Return the underlying data as an array of the given type. This is an
/// inherrently unsafe operation, and should only be used when the data is
/// known to be of the correct type.
- template <typename T>
- ArrayRef<T> getDataAs() const {
+ template <typename T> ArrayRef<T> getDataAs() const {
return llvm::ArrayRef<T>((const T *)data.data(), data.size() / sizeof(T));
}
@@ -471,8 +470,10 @@ class ParserConfig {
/// `fallbackResourceMap` is an optional fallback handler that can be used to
/// parse external resources not explicitly handled by another parser.
ParserConfig(MLIRContext *context, bool verifyAfterParse = true,
- FallbackAsmResourceMap *fallbackResourceMap = nullptr)
+ FallbackAsmResourceMap *fallbackResourceMap = nullptr,
+ bool retainIdentifierNames = false)
: context(context), verifyAfterParse(verifyAfterParse),
+ retainIdentifierNames(retainIdentifierNames),
fallbackResourceMap(fallbackResourceMap) {
assert(context && "expected valid MLIR context");
}
@@ -483,6 +484,10 @@ class ParserConfig {
/// Returns if the parser should verify the IR after parsing.
bool shouldVerifyAfterParse() const { return verifyAfterParse; }
+ /// Returns if the parser should retain identifier names collected during
+ /// parsing.
+ bool shouldRetainIdentifierNames() const { return retainIdentifierNames; }
+
/// Returns the parsing configurations associated to the bytecode read.
BytecodeReaderConfig &getBytecodeReaderConfig() const {
return const_cast<BytecodeReaderConfig &>(bytecodeReaderConfig);
@@ -520,6 +525,7 @@ class ParserConfig {
private:
MLIRContext *context;
bool verifyAfterParse;
+ bool retainIdentifierNames;
DenseMap<StringRef, std::unique_ptr<AsmResourceParser>> resourceParsers;
FallbackAsmResourceMap *fallbackResourceMap;
BytecodeReaderConfig bytecodeReaderConfig;
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index 160585e7da5486..4d28c180250546 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -215,6 +215,13 @@ class MlirOptMainConfig {
/// Reproducer file generation (no crash required).
StringRef getReproducerFilename() const { return generateReproducerFileFlag; }
+ /// Retain the original identifier names when printing the output.
+ MlirOptMainConfig &retainIdentifierNames(bool retain) {
+ retainIdentifierNamesFlag = retain;
+ return *this;
+ }
+ bool shouldRetainIdentifierNames() const { return retainIdentifierNamesFlag; }
+
protected:
/// Allow operation with no registered dialects.
/// This option is for convenience during testing only and discouraged in
@@ -281,6 +288,9 @@ class MlirOptMainConfig {
/// the corresponding line. This is meant for implementing diagnostic tests.
bool verifyDiagnosticsFlag = false;
+ /// Retain identifier names in the output (e.g., `%my_var` instead of `%0`).
+ bool retainIdentifierNamesFlag = false;
+
/// Run the verifier after each transformation pass.
bool verifyPassesFlag = true;
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index e3db248164672c..043e696aebe4e9 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -1283,6 +1283,23 @@ ParseResult OperationParser::parseOperation() {
}
}
+ // If enabled, store the SSA name(s) for the operation
+ if (state.config.shouldRetainIdentifierNames()) {
+ if (opResI == 1) {
+ for (ResultRecord &resIt : resultIDs) {
+ for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
+ op->setDiscardableAttr(
+ "mlir.ssaName",
+ StringAttr::get(getContext(),
+ std::get<0>(resIt).drop_front(1)));
+ }
+ }
+ } else if (opResI > 1) {
+ emitError(
+ "have not yet implemented support for multiple return values");
+ }
+ }
+
// Add this operation to the assembly state if it was provided to populate.
} else if (state.asmState) {
state.asmState->finalizeOperationDefinition(
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 61b90bc9b0a7bb..37dab681e57000 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -73,7 +73,8 @@ OpAsmParser::~OpAsmParser() = default;
MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
/// Parse a type list.
-/// This is out-of-line to work-around https://github.com/llvm/llvm-project/issues/62918
+/// This is out-of-line to work-around
+/// https://github.com/llvm/llvm-project/issues/62918
ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) {
return parseCommaSeparatedList(
[&]() { return parseType(result.emplace_back()); });
@@ -1607,6 +1608,12 @@ void SSANameState::numberValuesInOp(Operation &op) {
}
Value resultBegin = op.getResult(0);
+ // Get the original SSA for the result if available
+ if (StringAttr ssaNameAttr = op.getAttrOfType<StringAttr>("mlir.ssaName")) {
+ setValueName(resultBegin, ssaNameAttr.strref());
+ op.removeDiscardableAttr("mlir.ssaName");
+ }
+
// If the first result wasn't numbered, give it a default number.
if (valueIDs.try_emplace(resultBegin, nextValueID).second)
++nextValueID;
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 9bbf91de183051..f9f60832a3d798 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -189,6 +189,11 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
cl::desc("Round-trip the IR after parsing and ensure it succeeds"),
cl::location(verifyRoundtripFlag), cl::init(false));
+ static cl::opt<bool, /*ExternalStorage=*/true> retainIdentifierNames(
+ "retain-identifier-names",
+ cl::desc("Retain the original names of identifiers when printing"),
+ cl::location(retainIdentifierNamesFlag), cl::init(false));
+
static cl::list<std::string> passPlugins(
"load-pass-plugin", cl::desc("Load passes from plugin library"));
@@ -439,8 +444,9 @@ performActions(raw_ostream &os,
// untouched.
PassReproducerOptions reproOptions;
FallbackAsmResourceMap fallbackResourceMap;
- ParserConfig parseConfig(context, config.shouldVerifyOnParsing(),
- &fallbackResourceMap);
+ ParserConfig parseConfig(
+ context, config.shouldVerifyOnParsing(), &fallbackResourceMap,
+ config.shouldRetainIdentifierNames());
if (config.shouldRunReproducer())
reproOptions.attachResourceParser(parseConfig);
>From 93b950dffcaf44e230295b347be34b48ff80de34 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Fri, 26 Jan 2024 14:52:10 +0000
Subject: [PATCH 02/11] Moved SSA name stored to function
---
mlir/lib/AsmParser/Parser.cpp | 40 ++++++++++++++++++++++-------------
1 file changed, 25 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 043e696aebe4e9..97c9c8a742289f 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -671,6 +671,10 @@ class OperationParser : public Parser {
/// an object of type 'OperationName'. Otherwise, failure is returned.
FailureOr<OperationName> parseCustomOperationName();
+ /// Store the SSA names for the current operation as attrs for debug purposes.
+ ParseResult storeSSANames(Operation *&op,
+ SmallVector<ResultRecord, 1> resultIDs);
+
//===--------------------------------------------------------------------===//
// Region Parsing
//===--------------------------------------------------------------------===//
@@ -1284,21 +1288,8 @@ ParseResult OperationParser::parseOperation() {
}
// If enabled, store the SSA name(s) for the operation
- if (state.config.shouldRetainIdentifierNames()) {
- if (opResI == 1) {
- for (ResultRecord &resIt : resultIDs) {
- for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
- op->setDiscardableAttr(
- "mlir.ssaName",
- StringAttr::get(getContext(),
- std::get<0>(resIt).drop_front(1)));
- }
- }
- } else if (opResI > 1) {
- emitError(
- "have not yet implemented support for multiple return values");
- }
- }
+ if (state.config.shouldRetainIdentifierNames())
+ storeSSANames(op, resultIDs);
// Add this operation to the assembly state if it was provided to populate.
} else if (state.asmState) {
@@ -1345,6 +1336,25 @@ OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) {
/*allowEmptyList=*/false);
}
+/// Store the SSA names for the current operation as attrs for debug purposes.
+ParseResult
+OperationParser::storeSSANames(Operation *&op,
+ SmallVector<ResultRecord, 1> resultIDs) {
+ if (op->getNumResults() == 0)
+ emitError("Operation has no results\n");
+ else if (op->getNumResults() > 1)
+ emitError("have not yet implemented support for multiple return values\n");
+
+ for (ResultRecord &resIt : resultIDs) {
+ for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
+ op->setDiscardableAttr(
+ "mlir.ssaName",
+ StringAttr::get(getContext(), std::get<0>(resIt).drop_front(1)));
+ }
+ }
+ return success();
+}
+
namespace {
// RAII-style guard for cleaning up the regions in the operation state before
// deleting them. Within the parser, regions may get deleted if parsing failed,
>From 4ce92ebc24f868b94ab100cff61a298301d186fd Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Fri, 26 Jan 2024 15:39:18 +0000
Subject: [PATCH 03/11] Added initial block name handling
---
mlir/lib/AsmParser/Parser.cpp | 13 +++++++++++++
mlir/lib/IR/AsmPrinter.cpp | 33 ++++++++++++++++++++++++++++-----
2 files changed, 41 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 97c9c8a742289f..5042e460c4de41 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -1352,6 +1352,19 @@ OperationParser::storeSSANames(Operation *&op,
StringAttr::get(getContext(), std::get<0>(resIt).drop_front(1)));
}
}
+
+ // Find the name of the block that contains this operation.
+ Block *blockPtr = op->getBlock();
+ for (const auto &map : blocksByName) {
+ for (const auto &entry : map) {
+ if (entry.second.block == blockPtr) {
+ op->setDiscardableAttr("mlir.blockName",
+ StringAttr::get(getContext(), entry.first));
+ llvm::outs() << "Block name: " << entry.first << "\n";
+ }
+ }
+ }
+
return success();
}
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 37dab681e57000..a316b4a70e74e3 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1322,6 +1322,10 @@ class SSANameState {
/// conflicts, it is automatically renamed.
StringRef uniqueValueName(StringRef name);
+ /// Set the original identifier names if available. Used in debugging with
+ /// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
+ void setRetainedIdentifierNames(Operation &op);
+
/// This is the value ID for each SSA value. If this returns NameSentinel,
/// then the valueID has an entry in valueNames.
DenseMap<Value, unsigned> valueIDs;
@@ -1608,11 +1612,9 @@ void SSANameState::numberValuesInOp(Operation &op) {
}
Value resultBegin = op.getResult(0);
- // Get the original SSA for the result if available
- if (StringAttr ssaNameAttr = op.getAttrOfType<StringAttr>("mlir.ssaName")) {
- setValueName(resultBegin, ssaNameAttr.strref());
- op.removeDiscardableAttr("mlir.ssaName");
- }
+ // Set the original identifier names if available. Used in debugging with
+ // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
+ setRetainedIdentifierNames(op);
// If the first result wasn't numbered, give it a default number.
if (valueIDs.try_emplace(resultBegin, nextValueID).second)
@@ -1625,6 +1627,27 @@ void SSANameState::numberValuesInOp(Operation &op) {
}
}
+void SSANameState::setRetainedIdentifierNames(Operation &op) {
+ // Get the original SSA for the result(s) if available
+ Value resultBegin = op.getResult(0);
+ if (StringAttr ssaNameAttr = op.getAttrOfType<StringAttr>("mlir.ssaName")) {
+ setValueName(resultBegin, ssaNameAttr.strref());
+ op.removeDiscardableAttr("mlir.ssaName");
+ }
+ unsigned numResults = op.getNumResults();
+ if (numResults > 1)
+ llvm::outs()
+ << "have not yet implemented support for multiple return values\n";
+
+ // Get the original SSA name for the block if available
+ if (StringAttr blockNameAttr =
+ op.getAttrOfType<StringAttr>("mlir.blockName")) {
+ blockNames[op.getBlock()] = {-1, blockNameAttr.strref()};
+ op.removeDiscardableAttr("mlir.blockName");
+ }
+ return;
+}
+
void SSANameState::getResultIDAndNumber(
OpResult result, Value &lookupValue,
std::optional<int> &lookupResultNo) const {
>From dae9882bd7b470509bc293150f62730e6cbd4260 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Fri, 26 Jan 2024 16:18:27 +0000
Subject: [PATCH 04/11] Added block arg name handling
---
mlir/lib/AsmParser/Parser.cpp | 25 ++++++++++++++++++++-----
mlir/lib/IR/AsmPrinter.cpp | 14 +++++++++++++-
2 files changed, 33 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 5042e460c4de41..1d9b1c5a52711a 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -674,6 +674,7 @@ class OperationParser : public Parser {
/// Store the SSA names for the current operation as attrs for debug purposes.
ParseResult storeSSANames(Operation *&op,
SmallVector<ResultRecord, 1> resultIDs);
+ DenseMap<BlockArgument, StringRef> blockArgNames;
//===--------------------------------------------------------------------===//
// Region Parsing
@@ -1263,6 +1264,11 @@ ParseResult OperationParser::parseOperation() {
<< op->getNumResults() << " results but was provided "
<< numExpectedResults << " to bind";
+ // If enabled, store the SSA name(s) for the operation
+ llvm::outs() << "parsing operation: " << op->getName() << "\n";
+ if (state.config.shouldRetainIdentifierNames())
+ storeSSANames(op, resultIDs);
+
// Add this operation to the assembly state if it was provided to populate.
if (state.asmState) {
unsigned resultIt = 0;
@@ -1287,10 +1293,6 @@ ParseResult OperationParser::parseOperation() {
}
}
- // If enabled, store the SSA name(s) for the operation
- if (state.config.shouldRetainIdentifierNames())
- storeSSANames(op, resultIDs);
-
// Add this operation to the assembly state if it was provided to populate.
} else if (state.asmState) {
state.asmState->finalizeOperationDefinition(
@@ -1360,7 +1362,16 @@ OperationParser::storeSSANames(Operation *&op,
if (entry.second.block == blockPtr) {
op->setDiscardableAttr("mlir.blockName",
StringAttr::get(getContext(), entry.first));
- llvm::outs() << "Block name: " << entry.first << "\n";
+
+ // Store block arguments, if present
+ llvm::SmallVector<llvm::StringRef, 1> argNames;
+
+ for (BlockArgument arg : blockPtr->getArguments()) {
+ auto it = blockArgNames.find(arg);
+ if (it != blockArgNames.end())
+ argNames.push_back(it->second.drop_front(1));
+ }
+ op->setAttr("mlir.blockArgNames", builder.getStrArrayAttr(argNames));
}
}
}
@@ -2456,6 +2467,10 @@ ParseResult OperationParser::parseOptionalBlockArgList(Block *owner) {
} else {
auto loc = getEncodedSourceLocation(useInfo.location);
arg = owner->addArgument(type, loc);
+
+ // Optionally store argument name for debug purposes
+ if (state.config.shouldRetainIdentifierNames())
+ blockArgNames.insert({arg, useInfo.name});
}
// If the argument has an explicit loc(...) specifier, parse and apply
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index a316b4a70e74e3..89fd312a8af4fb 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1639,12 +1639,24 @@ void SSANameState::setRetainedIdentifierNames(Operation &op) {
llvm::outs()
<< "have not yet implemented support for multiple return values\n";
- // Get the original SSA name for the block if available
+ // Get the original name for the block if available
if (StringAttr blockNameAttr =
op.getAttrOfType<StringAttr>("mlir.blockName")) {
blockNames[op.getBlock()] = {-1, blockNameAttr.strref()};
op.removeDiscardableAttr("mlir.blockName");
}
+
+ // Get the original name for the block args if available
+ if (ArrayAttr blockArgNamesAttr =
+ op.getAttrOfType<ArrayAttr>("mlir.blockArgNames")) {
+ auto blockArgNames = blockArgNamesAttr.getValue();
+ auto blockArgs = op.getBlock()->getArguments();
+ for (int i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) {
+ auto blockArgName = blockArgNames[i].cast<StringAttr>();
+ setValueName(blockArgs[i], cast<StringAttr>(blockArgNames[i]).strref());
+ }
+ op.removeDiscardableAttr("mlir.blockArgNames");
+ }
return;
}
>From fcc4a52dbfb8e88e0263426bbef4c5bae3e8f815 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Sat, 27 Jan 2024 11:24:53 +0000
Subject: [PATCH 05/11] Added support for operations with no arguments
---
mlir/lib/AsmParser/Parser.cpp | 26 ++++++++++----------------
mlir/lib/IR/AsmPrinter.cpp | 27 +++++++++++++++------------
2 files changed, 25 insertions(+), 28 deletions(-)
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 1d9b1c5a52711a..08c1942b50ec34 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -672,8 +672,7 @@ class OperationParser : public Parser {
FailureOr<OperationName> parseCustomOperationName();
/// Store the SSA names for the current operation as attrs for debug purposes.
- ParseResult storeSSANames(Operation *&op,
- SmallVector<ResultRecord, 1> resultIDs);
+ void storeSSANames(Operation *&op, ArrayRef<ResultRecord> resultIDs);
DenseMap<BlockArgument, StringRef> blockArgNames;
//===--------------------------------------------------------------------===//
@@ -1264,11 +1263,6 @@ ParseResult OperationParser::parseOperation() {
<< op->getNumResults() << " results but was provided "
<< numExpectedResults << " to bind";
- // If enabled, store the SSA name(s) for the operation
- llvm::outs() << "parsing operation: " << op->getName() << "\n";
- if (state.config.shouldRetainIdentifierNames())
- storeSSANames(op, resultIDs);
-
// Add this operation to the assembly state if it was provided to populate.
if (state.asmState) {
unsigned resultIt = 0;
@@ -1339,15 +1333,12 @@ OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) {
}
/// Store the SSA names for the current operation as attrs for debug purposes.
-ParseResult
-OperationParser::storeSSANames(Operation *&op,
- SmallVector<ResultRecord, 1> resultIDs) {
- if (op->getNumResults() == 0)
- emitError("Operation has no results\n");
- else if (op->getNumResults() > 1)
+void OperationParser::storeSSANames(Operation *&op,
+ ArrayRef<ResultRecord> resultIDs) {
+ if (op->getNumResults() > 1)
emitError("have not yet implemented support for multiple return values\n");
- for (ResultRecord &resIt : resultIDs) {
+ for (const ResultRecord &resIt : resultIDs) {
for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
op->setDiscardableAttr(
"mlir.ssaName",
@@ -1375,8 +1366,6 @@ OperationParser::storeSSANames(Operation *&op,
}
}
}
-
- return success();
}
namespace {
@@ -2144,6 +2133,11 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
// Otherwise, create the operation and try to parse a location for it.
Operation *op = opBuilder.create(opState);
+
+ // If enabled, store the SSA name(s) for the operation
+ if (state.config.shouldRetainIdentifierNames())
+ storeSSANames(op, resultIDs);
+
if (parseTrailingLocationSpecifier(op))
return nullptr;
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 89fd312a8af4fb..65036d9237a8ce 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1601,6 +1601,10 @@ void SSANameState::numberValuesInOp(Operation &op) {
}
}
+ // Set the original identifier names if available. Used in debugging with
+ // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
+ setRetainedIdentifierNames(op);
+
unsigned numResults = op.getNumResults();
if (numResults == 0) {
// If value users should be printed, operations with no result need an id.
@@ -1612,10 +1616,6 @@ void SSANameState::numberValuesInOp(Operation &op) {
}
Value resultBegin = op.getResult(0);
- // Set the original identifier names if available. Used in debugging with
- // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
- setRetainedIdentifierNames(op);
-
// If the first result wasn't numbered, give it a default number.
if (valueIDs.try_emplace(resultBegin, nextValueID).second)
++nextValueID;
@@ -1629,15 +1629,17 @@ void SSANameState::numberValuesInOp(Operation &op) {
void SSANameState::setRetainedIdentifierNames(Operation &op) {
// Get the original SSA for the result(s) if available
- Value resultBegin = op.getResult(0);
- if (StringAttr ssaNameAttr = op.getAttrOfType<StringAttr>("mlir.ssaName")) {
- setValueName(resultBegin, ssaNameAttr.strref());
- op.removeDiscardableAttr("mlir.ssaName");
- }
unsigned numResults = op.getNumResults();
if (numResults > 1)
llvm::outs()
<< "have not yet implemented support for multiple return values\n";
+ else if (numResults == 1) {
+ Value resultBegin = op.getResult(0);
+ if (StringAttr ssaNameAttr = op.getAttrOfType<StringAttr>("mlir.ssaName")) {
+ setValueName(resultBegin, ssaNameAttr.strref());
+ op.removeDiscardableAttr("mlir.ssaName");
+ }
+ }
// Get the original name for the block if available
if (StringAttr blockNameAttr =
@@ -1651,9 +1653,10 @@ void SSANameState::setRetainedIdentifierNames(Operation &op) {
op.getAttrOfType<ArrayAttr>("mlir.blockArgNames")) {
auto blockArgNames = blockArgNamesAttr.getValue();
auto blockArgs = op.getBlock()->getArguments();
- for (int i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) {
- auto blockArgName = blockArgNames[i].cast<StringAttr>();
- setValueName(blockArgs[i], cast<StringAttr>(blockArgNames[i]).strref());
+ for (size_t i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) {
+ auto blockArgName = blockArgNames[i].cast<StringAttr>().strref();
+ if (!usedNames.count(blockArgName))
+ setValueName(blockArgs[i], blockArgName);
}
op.removeDiscardableAttr("mlir.blockArgNames");
}
>From bbb2dbd703016e07c87ab25b8eed99b86542351a Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Sat, 27 Jan 2024 12:21:39 +0000
Subject: [PATCH 06/11] Added initial unit tests
---
mlir/test/IR/print-retain-identifiers.mlir | 54 ++++++++++++++++++++++
1 file changed, 54 insertions(+)
create mode 100644 mlir/test/IR/print-retain-identifiers.mlir
diff --git a/mlir/test/IR/print-retain-identifiers.mlir b/mlir/test/IR/print-retain-identifiers.mlir
new file mode 100644
index 00000000000000..663f7301b817bc
--- /dev/null
+++ b/mlir/test/IR/print-retain-identifiers.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt -retain-identifier-names %s | FileCheck %s
+
+
+//===----------------------------------------------------------------------===//
+// Test SSA results (with single return values)
+//===----------------------------------------------------------------------===//
+
+// CHECK: func.func @add_one(%arg0: f64, %arg1: f64) -> f64 {
+func.func @add_one(%arg0: f64, %arg1: f64) -> f64 {
+ // CHECK: %my_constant = arith.constant 1.000000e+00 : f64
+ %my_constant = arith.constant 1.000000e+00 : f64
+ // CHECK: %my_output = arith.addf %arg0, %my_constant : f64
+ %my_output = arith.addf %arg0, %my_constant : f64
+ // CHECK: return %my_output : f64
+ return %my_output : f64
+}
+
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test basic blocks and their arguments
+//===----------------------------------------------------------------------===//
+
+func.func @simple(i64, i1) -> i64 {
+^bb_alpha(%a: i64, %cond: i1):
+ // CHECK: cf.cond_br %cond, ^bb_beta, ^bb_gamma
+ cf.cond_br %cond, ^bb_beta, ^bb_gamma
+
+// CHECK: ^bb_beta: // pred: ^bb_alpha
+^bb_beta:
+ // CHECK: cf.br ^bb_delta(%a : i64)
+ cf.br ^bb_delta(%a: i64)
+
+// CHECK: ^bb_gamma: // pred: ^bb_alpha
+^bb_gamma:
+ // CHECK: %b = arith.addi %a, %a : i64
+ %b = arith.addi %a, %a : i64
+ // CHECK: cf.br ^bb_delta(%b : i64)
+ cf.br ^bb_delta(%b: i64)
+
+// CHECK: ^bb_delta(%c: i64): // 2 preds: ^bb_gamma, ^bb_beta
+^bb_delta(%c: i64):
+ // CHECK: cf.br ^bb_eps(%c, %a : i64, i64)
+ cf.br ^bb_eps(%c, %a : i64, i64)
+
+// CHECK: ^bb_eps(%d: i64, %e: i64): // pred: ^bb_delta
+^bb_eps(%d : i64, %e : i64):
+ // CHECK: %f = arith.addi %d, %e : i64
+ %f = arith.addi %d, %e : i64
+ return %f : i64
+}
+
+// -----
>From 0b4485f2b790ecf415fd2d341ca4fc2e8c6cafa1 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Sat, 27 Jan 2024 14:49:16 +0000
Subject: [PATCH 07/11] Added operation operand name preservation
---
mlir/lib/AsmParser/Parser.cpp | 34 +++++++++++++++++-----
mlir/lib/IR/AsmPrinter.cpp | 13 +++++++++
mlir/test/IR/print-retain-identifiers.mlir | 8 ++---
3 files changed, 43 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 08c1942b50ec34..75c5a1a9f1c27e 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -673,7 +673,7 @@ class OperationParser : public Parser {
/// Store the SSA names for the current operation as attrs for debug purposes.
void storeSSANames(Operation *&op, ArrayRef<ResultRecord> resultIDs);
- DenseMap<BlockArgument, StringRef> blockArgNames;
+ DenseMap<Value, StringRef> argNames;
//===--------------------------------------------------------------------===//
// Region Parsing
@@ -1346,7 +1346,19 @@ void OperationParser::storeSSANames(Operation *&op,
}
}
- // Find the name of the block that contains this operation.
+ // Store the name information of the arguments of this operation.
+ if (op->getNumOperands() > 0) {
+ llvm::SmallVector<llvm::StringRef, 1> opArgNames;
+ for (auto &operand : op->getOpOperands()) {
+ auto it = argNames.find(operand.get());
+ if (it != argNames.end())
+ opArgNames.push_back(it->second.drop_front(1));
+ }
+ op->setDiscardableAttr("mlir.opArgNames",
+ builder.getStrArrayAttr(opArgNames));
+ }
+
+ // Store the name information of the block that contains this operation.
Block *blockPtr = op->getBlock();
for (const auto &map : blocksByName) {
for (const auto &entry : map) {
@@ -1355,14 +1367,15 @@ void OperationParser::storeSSANames(Operation *&op,
StringAttr::get(getContext(), entry.first));
// Store block arguments, if present
- llvm::SmallVector<llvm::StringRef, 1> argNames;
+ llvm::SmallVector<llvm::StringRef, 1> blockArgNames;
for (BlockArgument arg : blockPtr->getArguments()) {
- auto it = blockArgNames.find(arg);
- if (it != blockArgNames.end())
- argNames.push_back(it->second.drop_front(1));
+ auto it = argNames.find(arg);
+ if (it != argNames.end())
+ blockArgNames.push_back(it->second.drop_front(1));
}
- op->setAttr("mlir.blockArgNames", builder.getStrArrayAttr(argNames));
+ op->setAttr("mlir.blockArgNames",
+ builder.getStrArrayAttr(blockArgNames));
}
}
}
@@ -1772,6 +1785,11 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
SmallVectorImpl<Value> &result) override {
if (auto value = parser.resolveSSAUse(operand, type)) {
result.push_back(value);
+
+ // Optionally store argument name for debug purposes
+ if (parser.getState().config.shouldRetainIdentifierNames())
+ parser.argNames.insert({value, operand.name});
+
return success();
}
return failure();
@@ -2464,7 +2482,7 @@ ParseResult OperationParser::parseOptionalBlockArgList(Block *owner) {
// Optionally store argument name for debug purposes
if (state.config.shouldRetainIdentifierNames())
- blockArgNames.insert({arg, useInfo.name});
+ argNames.insert({arg, useInfo.name});
}
// If the argument has an explicit loc(...) specifier, parse and apply
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 65036d9237a8ce..6b5183e5fb5913 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1641,6 +1641,19 @@ void SSANameState::setRetainedIdentifierNames(Operation &op) {
}
}
+ // Get the original name for the op args if available
+ if (ArrayAttr opArgNamesAttr =
+ op.getAttrOfType<ArrayAttr>("mlir.opArgNames")) {
+ auto opArgNames = opArgNamesAttr.getValue();
+ auto opArgs = op.getOperands();
+ for (size_t i = 0; i < opArgs.size() && i < opArgNames.size(); ++i) {
+ auto opArgName = opArgNames[i].cast<StringAttr>().strref();
+ if (!usedNames.count(opArgName))
+ setValueName(opArgs[i], opArgName);
+ }
+ op.removeDiscardableAttr("mlir.opArgNames");
+ }
+
// Get the original name for the block if available
if (StringAttr blockNameAttr =
op.getAttrOfType<StringAttr>("mlir.blockName")) {
diff --git a/mlir/test/IR/print-retain-identifiers.mlir b/mlir/test/IR/print-retain-identifiers.mlir
index 663f7301b817bc..05a6b9b42d8e82 100644
--- a/mlir/test/IR/print-retain-identifiers.mlir
+++ b/mlir/test/IR/print-retain-identifiers.mlir
@@ -5,12 +5,12 @@
// Test SSA results (with single return values)
//===----------------------------------------------------------------------===//
-// CHECK: func.func @add_one(%arg0: f64, %arg1: f64) -> f64 {
-func.func @add_one(%arg0: f64, %arg1: f64) -> f64 {
+// CHECK: func.func @add_one(%my_input: f64, %arg1: f64) -> f64 {
+func.func @add_one(%my_input: f64, %arg1: f64) -> f64 {
// CHECK: %my_constant = arith.constant 1.000000e+00 : f64
%my_constant = arith.constant 1.000000e+00 : f64
- // CHECK: %my_output = arith.addf %arg0, %my_constant : f64
- %my_output = arith.addf %arg0, %my_constant : f64
+ // CHECK: %my_output = arith.addf %my_input, %my_constant : f64
+ %my_output = arith.addf %my_input, %my_constant : f64
// CHECK: return %my_output : f64
return %my_output : f64
}
>From 7016d22fe7a9681f0181c57294b3a2e14855513f Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Sat, 27 Jan 2024 18:27:53 +0000
Subject: [PATCH 08/11] Added support for result groups
---
mlir/lib/AsmParser/Parser.cpp | 19 ++++++-----
mlir/lib/IR/AsmPrinter.cpp | 36 ++++++++++++--------
mlir/test/IR/print-retain-identifiers.mlir | 38 ++++++++++++++++++++++
3 files changed, 72 insertions(+), 21 deletions(-)
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 75c5a1a9f1c27e..90217b287e7e20 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -1335,15 +1335,18 @@ OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) {
/// Store the SSA names for the current operation as attrs for debug purposes.
void OperationParser::storeSSANames(Operation *&op,
ArrayRef<ResultRecord> resultIDs) {
- if (op->getNumResults() > 1)
- emitError("have not yet implemented support for multiple return values\n");
-
- for (const ResultRecord &resIt : resultIDs) {
- for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
- op->setDiscardableAttr(
- "mlir.ssaName",
- StringAttr::get(getContext(), std::get<0>(resIt).drop_front(1)));
+
+ // Store the name(s) of the result(s) of this operation.
+ if (op->getNumResults() > 0) {
+ llvm::SmallVector<llvm::StringRef, 1> resultNames;
+ for (const ResultRecord &resIt : resultIDs) {
+ resultNames.push_back(std::get<0>(resIt).drop_front(1));
+ // Insert empty string for sub-results/result groups
+ for (unsigned int i = 1; i < std::get<1>(resIt); ++i)
+ resultNames.push_back(llvm::StringRef());
}
+ op->setDiscardableAttr("mlir.resultNames",
+ builder.getStrArrayAttr(resultNames));
}
// Store the name information of the arguments of this operation.
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 6b5183e5fb5913..29449ca0a3e56d 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1324,7 +1324,8 @@ class SSANameState {
/// Set the original identifier names if available. Used in debugging with
/// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
- void setRetainedIdentifierNames(Operation &op);
+ void setRetainedIdentifierNames(Operation &op,
+ SmallVector<int, 2> &resultGroups);
/// This is the value ID for each SSA value. If this returns NameSentinel,
/// then the valueID has an entry in valueNames.
@@ -1603,7 +1604,7 @@ void SSANameState::numberValuesInOp(Operation &op) {
// Set the original identifier names if available. Used in debugging with
// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
- setRetainedIdentifierNames(op);
+ setRetainedIdentifierNames(op, resultGroups);
unsigned numResults = op.getNumResults();
if (numResults == 0) {
@@ -1627,18 +1628,25 @@ void SSANameState::numberValuesInOp(Operation &op) {
}
}
-void SSANameState::setRetainedIdentifierNames(Operation &op) {
- // Get the original SSA for the result(s) if available
- unsigned numResults = op.getNumResults();
- if (numResults > 1)
- llvm::outs()
- << "have not yet implemented support for multiple return values\n";
- else if (numResults == 1) {
- Value resultBegin = op.getResult(0);
- if (StringAttr ssaNameAttr = op.getAttrOfType<StringAttr>("mlir.ssaName")) {
- setValueName(resultBegin, ssaNameAttr.strref());
- op.removeDiscardableAttr("mlir.ssaName");
+void SSANameState::setRetainedIdentifierNames(
+ Operation &op, SmallVector<int, 2> &resultGroups) {
+ // Get the original names for the results if available
+ if (ArrayAttr resultNamesAttr =
+ op.getAttrOfType<ArrayAttr>("mlir.resultNames")) {
+ auto resultNames = resultNamesAttr.getValue();
+ auto results = op.getResults();
+ // Conservative in the case that the #results has changed
+ for (size_t i = 0; i < results.size() && i < resultNames.size(); ++i) {
+ auto resultName = resultNames[i].cast<StringAttr>().strref();
+ if (!resultName.empty()) {
+ if (!usedNames.count(resultName))
+ setValueName(results[i], resultName);
+ // If a result has a name, it is the start of a result group.
+ if (i > 0)
+ resultGroups.push_back(i);
+ }
}
+ op.removeDiscardableAttr("mlir.resultNames");
}
// Get the original name for the op args if available
@@ -1646,6 +1654,7 @@ void SSANameState::setRetainedIdentifierNames(Operation &op) {
op.getAttrOfType<ArrayAttr>("mlir.opArgNames")) {
auto opArgNames = opArgNamesAttr.getValue();
auto opArgs = op.getOperands();
+ // Conservative in the case that the #operands has changed
for (size_t i = 0; i < opArgs.size() && i < opArgNames.size(); ++i) {
auto opArgName = opArgNames[i].cast<StringAttr>().strref();
if (!usedNames.count(opArgName))
@@ -1666,6 +1675,7 @@ void SSANameState::setRetainedIdentifierNames(Operation &op) {
op.getAttrOfType<ArrayAttr>("mlir.blockArgNames")) {
auto blockArgNames = blockArgNamesAttr.getValue();
auto blockArgs = op.getBlock()->getArguments();
+ // Conservative in the case that the #args has changed
for (size_t i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) {
auto blockArgName = blockArgNames[i].cast<StringAttr>().strref();
if (!usedNames.count(blockArgName))
diff --git a/mlir/test/IR/print-retain-identifiers.mlir b/mlir/test/IR/print-retain-identifiers.mlir
index 05a6b9b42d8e82..b3e4f075b3936a 100644
--- a/mlir/test/IR/print-retain-identifiers.mlir
+++ b/mlir/test/IR/print-retain-identifiers.mlir
@@ -52,3 +52,41 @@ func.func @simple(i64, i1) -> i64 {
}
// -----
+
+//===----------------------------------------------------------------------===//
+// Test multiple return values
+//===----------------------------------------------------------------------===//
+
+func.func @select_min_max(%a: f64, %b: f64) -> (f64, f64) {
+ %gt = arith.cmpf "ogt", %a, %b : f64
+ // CHECK: %min, %max = scf.if %gt -> (f64, f64) {
+ %min, %max = scf.if %gt -> (f64, f64) {
+ scf.yield %b, %a : f64, f64
+ } else {
+ scf.yield %a, %b : f64, f64
+ }
+ // CHECK: return %min, %max : f64, f64
+ return %min, %max : f64, f64
+}
+
+// -----
+
+////===----------------------------------------------------------------------===//
+// Test multiple return values, with a grouped value tuple
+//===----------------------------------------------------------------------===//
+
+func.func @select_max(%a: f64, %b: f64, %c: f64, %d: f64) -> (f64, f64, f64, f64) {
+ // Find the max between %a and %b,
+ // with %c and %d being other values that are returned.
+ %gt = arith.cmpf "ogt", %a, %b : f64
+ // CHECK: %max, %others:2, %alt = scf.if %gt -> (f64, f64, f64, f64) {
+ %max, %others:2, %alt = scf.if %gt -> (f64, f64, f64, f64) {
+ scf.yield %b, %a, %c, %d : f64, f64, f64, f64
+ } else {
+ scf.yield %a, %b, %d, %c : f64, f64, f64, f64
+ }
+ // CHECK: return %max, %others#0, %others#1, %alt : f64, f64, f64, f64
+ return %max, %others#0, %others#1, %alt : f64, f64, f64, f64
+}
+
+// -----
>From afa5f2cdc338135ae7877112cca2bcac8f624044 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Sat, 27 Jan 2024 18:58:35 +0000
Subject: [PATCH 09/11] Fix clang-format issue in AsmState.h
---
mlir/include/mlir/IR/AsmState.h | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index c030c243bac4d9..e13a9324b1f669 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -146,7 +146,8 @@ class AsmResourceBlob {
/// Return the underlying data as an array of the given type. This is an
/// inherrently unsafe operation, and should only be used when the data is
/// known to be of the correct type.
- template <typename T> ArrayRef<T> getDataAs() const {
+ template <typename T>
+ ArrayRef<T> getDataAs() const {
return llvm::ArrayRef<T>((const T *)data.data(), data.size() / sizeof(T));
}
>From a61ae25d11b70101d203e9c416e4f61977e86fa6 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Sun, 28 Jan 2024 16:22:52 +0000
Subject: [PATCH 10/11] Added system to handle when we use default names
---
mlir/lib/IR/AsmPrinter.cpp | 45 ++++++++++++++++++++++----------------
1 file changed, 26 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 29449ca0a3e56d..77d60680a0b556 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1001,7 +1001,8 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
/// store the new copy,
static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
StringRef allowedPunctChars = "$._-",
- bool allowTrailingDigit = true) {
+ bool allowTrailingDigit = true,
+ bool allowNumeric = false) {
assert(!name.empty() && "Shouldn't have an empty name here");
auto validChar = [&](char ch) {
@@ -1021,16 +1022,17 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
// Check to see if this name is valid. If it starts with a digit, then it
// could conflict with the autogenerated numeric ID's, so add an underscore
- // prefix to avoid problems.
- if (isdigit(name[0]) || (!validChar(name[0]) && name[0] != ' ')) {
+ // prefix to avoid problems. This can be overridden by setting allowNumeric.
+ if ((isdigit(name[0]) || (!validChar(name[0]) && name[0] != ' ')) && !allowNumeric) {
buffer.push_back('_');
copyNameToBuffer();
return buffer;
}
// If the name ends with a trailing digit, add a '_' to avoid potential
- // conflicts with autogenerated ID's.
- if (!allowTrailingDigit && isdigit(name.back())) {
+ // conflicts with autogenerated ID's. This can be overridden by setting
+ // allowNumeric.
+ if (!allowTrailingDigit && isdigit(name.back()) && !allowNumeric) {
copyNameToBuffer();
buffer.push_back('_');
return buffer;
@@ -1316,11 +1318,11 @@ class SSANameState {
std::optional<int> &lookupResultNo) const;
/// Set a special value name for the given value.
- void setValueName(Value value, StringRef name);
+ void setValueName(Value value, StringRef name, bool allowNumeric = false);
/// Uniques the given value name within the printer. If the given name
/// conflicts, it is automatically renamed.
- StringRef uniqueValueName(StringRef name);
+ StringRef uniqueValueName(StringRef name, bool allowNumeric = false);
/// Set the original identifier names if available. Used in debugging with
/// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
@@ -1571,7 +1573,10 @@ void SSANameState::numberValuesInOp(Operation &op) {
// Function used to set the special result names for the operation.
SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
auto setResultNameFn = [&](Value result, StringRef name) {
- assert(!valueIDs.count(result) && "result numbered multiple times");
+ // Case where the result has already been named
+ if (valueIDs.count(result))
+ return;
+ // assert(!valueIDs.count(result) && "result numbered multiple times");
assert(result.getDefiningOp() == &op && "result not defined by 'op'");
setValueName(result, name);
@@ -1595,6 +1600,10 @@ void SSANameState::numberValuesInOp(Operation &op) {
blockNames[block] = {-1, name};
};
+ // Set the original identifier names if available. Used in debugging with
+ // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
+ setRetainedIdentifierNames(op, resultGroups);
+
if (!printerFlags.shouldPrintGenericOpForm()) {
if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
asmInterface.getAsmBlockNames(setBlockNameFn);
@@ -1602,10 +1611,6 @@ void SSANameState::numberValuesInOp(Operation &op) {
}
}
- // Set the original identifier names if available. Used in debugging with
- // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
- setRetainedIdentifierNames(op, resultGroups);
-
unsigned numResults = op.getNumResults();
if (numResults == 0) {
// If value users should be printed, operations with no result need an id.
@@ -1640,7 +1645,7 @@ void SSANameState::setRetainedIdentifierNames(
auto resultName = resultNames[i].cast<StringAttr>().strref();
if (!resultName.empty()) {
if (!usedNames.count(resultName))
- setValueName(results[i], resultName);
+ setValueName(results[i], resultName, /*allowNumeric=*/true);
// If a result has a name, it is the start of a result group.
if (i > 0)
resultGroups.push_back(i);
@@ -1658,7 +1663,7 @@ void SSANameState::setRetainedIdentifierNames(
for (size_t i = 0; i < opArgs.size() && i < opArgNames.size(); ++i) {
auto opArgName = opArgNames[i].cast<StringAttr>().strref();
if (!usedNames.count(opArgName))
- setValueName(opArgs[i], opArgName);
+ setValueName(opArgs[i], opArgName, /*allowNumeric=*/true);
}
op.removeDiscardableAttr("mlir.opArgNames");
}
@@ -1679,7 +1684,7 @@ void SSANameState::setRetainedIdentifierNames(
for (size_t i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) {
auto blockArgName = blockArgNames[i].cast<StringAttr>().strref();
if (!usedNames.count(blockArgName))
- setValueName(blockArgs[i], blockArgName);
+ setValueName(blockArgs[i], blockArgName, /*allowNumeric=*/true);
}
op.removeDiscardableAttr("mlir.blockArgNames");
}
@@ -1725,7 +1730,8 @@ void SSANameState::getResultIDAndNumber(
lookupValue = owner->getResult(groupResultNo);
}
-void SSANameState::setValueName(Value value, StringRef name) {
+void SSANameState::setValueName(Value value, StringRef name,
+ bool allowNumeric) {
// If the name is empty, the value uses the default numbering.
if (name.empty()) {
valueIDs[value] = nextValueID++;
@@ -1733,12 +1739,13 @@ void SSANameState::setValueName(Value value, StringRef name) {
}
valueIDs[value] = NameSentinel;
- valueNames[value] = uniqueValueName(name);
+ valueNames[value] = uniqueValueName(name, allowNumeric);
}
-StringRef SSANameState::uniqueValueName(StringRef name) {
+StringRef SSANameState::uniqueValueName(StringRef name, bool allowNumeric) {
SmallString<16> tmpBuffer;
- name = sanitizeIdentifier(name, tmpBuffer);
+ name = sanitizeIdentifier(name, tmpBuffer, /*allowedPunctChars=*/"$._-",
+ /*allowTrailingDigit=*/true, allowNumeric);
// Check to see if this name is already unique.
if (!usedNames.count(name)) {
>From 28099524777fd7e90ed7073b17baea2708ec356e Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Mon, 29 Jan 2024 13:03:04 +0000
Subject: [PATCH 11/11] Added region arg support & ambiguous name test
---
mlir/lib/AsmParser/Parser.cpp | 28 +++--
mlir/lib/IR/AsmPrinter.cpp | 118 ++++++++++++---------
mlir/test/IR/print-retain-identifiers.mlir | 27 ++++-
3 files changed, 113 insertions(+), 60 deletions(-)
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 90217b287e7e20..baa52fcf2da65c 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -671,8 +671,9 @@ class OperationParser : public Parser {
/// an object of type 'OperationName'. Otherwise, failure is returned.
FailureOr<OperationName> parseCustomOperationName();
- /// Store the SSA names for the current operation as attrs for debug purposes.
- void storeSSANames(Operation *&op, ArrayRef<ResultRecord> resultIDs);
+ /// Store the identifier names for the current operation as attrs for debug
+ /// purposes.
+ void storeIdentifierNames(Operation *&op, ArrayRef<ResultRecord> resultIDs);
DenseMap<Value, StringRef> argNames;
//===--------------------------------------------------------------------===//
@@ -1333,8 +1334,8 @@ OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) {
}
/// Store the SSA names for the current operation as attrs for debug purposes.
-void OperationParser::storeSSANames(Operation *&op,
- ArrayRef<ResultRecord> resultIDs) {
+void OperationParser::storeIdentifierNames(Operation *&op,
+ ArrayRef<ResultRecord> resultIDs) {
// Store the name(s) of the result(s) of this operation.
if (op->getNumResults() > 0) {
@@ -1382,6 +1383,18 @@ void OperationParser::storeSSANames(Operation *&op,
}
}
}
+
+ // Store names of region arguments (e.g., for FuncOps)
+ if (op->getNumRegions() > 0 && op->getRegion(0).getNumArguments() > 0) {
+ llvm::SmallVector<llvm::StringRef, 1> regionArgNames;
+ for (BlockArgument arg : op->getRegion(0).getArguments()) {
+ auto it = argNames.find(arg);
+ if (it != argNames.end()) {
+ regionArgNames.push_back(it->second.drop_front(1));
+ }
+ }
+ op->setAttr("mlir.regionArgNames", builder.getStrArrayAttr(regionArgNames));
+ }
}
namespace {
@@ -2155,9 +2168,9 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
// Otherwise, create the operation and try to parse a location for it.
Operation *op = opBuilder.create(opState);
- // If enabled, store the SSA name(s) for the operation
+ // If enabled, store the original identifier name(s) for the operation
if (state.config.shouldRetainIdentifierNames())
- storeSSANames(op, resultIDs);
+ storeIdentifierNames(op, resultIDs);
if (parseTrailingLocationSpecifier(op))
return nullptr;
@@ -2307,6 +2320,9 @@ ParseResult OperationParser::parseRegionBody(Region &region, SMLoc startLoc,
if (state.asmState)
state.asmState->addDefinition(arg, argInfo.location);
+ if (state.config.shouldRetainIdentifierNames())
+ argNames.insert({arg, argInfo.name});
+
// Record the definition for this argument.
if (addDefinition(argInfo, arg))
return failure();
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 77d60680a0b556..06b819b4a8675c 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1327,7 +1327,9 @@ class SSANameState {
/// Set the original identifier names if available. Used in debugging with
/// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
void setRetainedIdentifierNames(Operation &op,
- SmallVector<int, 2> &resultGroups);
+ SmallVector<int, 2> &resultGroups,
+ bool hasRegion = false);
+ void setRetainedIdentifierNames(Region &region);
/// This is the value ID for each SSA value. If this returns NameSentinel,
/// then the valueID has an entry in valueNames.
@@ -1522,6 +1524,9 @@ void SSANameState::numberValuesInRegion(Region &region) {
setValueName(arg, name);
};
+ // Use manually specified region arg names if available
+ setRetainedIdentifierNames(region);
+
if (!printerFlags.shouldPrintGenericOpForm()) {
if (Operation *op = region.getParentOp()) {
if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
@@ -1633,62 +1638,73 @@ void SSANameState::numberValuesInOp(Operation &op) {
}
}
-void SSANameState::setRetainedIdentifierNames(
- Operation &op, SmallVector<int, 2> &resultGroups) {
- // Get the original names for the results if available
- if (ArrayAttr resultNamesAttr =
- op.getAttrOfType<ArrayAttr>("mlir.resultNames")) {
- auto resultNames = resultNamesAttr.getValue();
- auto results = op.getResults();
- // Conservative in the case that the #results has changed
- for (size_t i = 0; i < results.size() && i < resultNames.size(); ++i) {
- auto resultName = resultNames[i].cast<StringAttr>().strref();
- if (!resultName.empty()) {
- if (!usedNames.count(resultName))
- setValueName(results[i], resultName, /*allowNumeric=*/true);
- // If a result has a name, it is the start of a result group.
- if (i > 0)
- resultGroups.push_back(i);
- }
- }
- op.removeDiscardableAttr("mlir.resultNames");
- }
+void SSANameState::setRetainedIdentifierNames(Operation &op,
+ SmallVector<int, 2> &resultGroups,
+ bool hasRegion) {
+
+ // Lambda which fetches the list of relevant attributes (e.g.,
+ // mlir.resultNames) and associates them with the relevant values
+ auto handleNamedAttributes =
+ [this](Operation &op, const Twine &attrName, auto getValuesFunc,
+ std::optional<std::function<void(int)>> customAction =
+ std::nullopt) {
+ if (ArrayAttr namesAttr = op.getAttrOfType<ArrayAttr>(attrName.str())) {
+ auto names = namesAttr.getValue();
+ auto values = getValuesFunc();
+ // Conservative in case the number of values has changed
+ for (size_t i = 0; i < values.size() && i < names.size(); ++i) {
+ auto name = names[i].cast<StringAttr>().strref();
+ if (!name.empty()) {
+ if (!this->usedNames.count(name))
+ this->setValueName(values[i], name, true);
+ if (customAction.has_value())
+ customAction.value()(i);
+ }
+ }
+ op.removeDiscardableAttr(attrName.str());
+ }
+ };
- // Get the original name for the op args if available
- if (ArrayAttr opArgNamesAttr =
- op.getAttrOfType<ArrayAttr>("mlir.opArgNames")) {
- auto opArgNames = opArgNamesAttr.getValue();
- auto opArgs = op.getOperands();
- // Conservative in the case that the #operands has changed
- for (size_t i = 0; i < opArgs.size() && i < opArgNames.size(); ++i) {
- auto opArgName = opArgNames[i].cast<StringAttr>().strref();
- if (!usedNames.count(opArgName))
- setValueName(opArgs[i], opArgName, /*allowNumeric=*/true);
+ if (hasRegion) {
+ // Get the original name(s) for the region arg(s) if available (e.g., for
+ // FuncOp args). Requires hasRegion flag to ensure scoping is correct
+ if (hasRegion && op.getNumRegions() > 0 &&
+ op.getRegion(0).getNumArguments() > 0) {
+ handleNamedAttributes(op, "mlir.regionArgNames",
+ [&]() { return op.getRegion(0).getArguments(); });
+ }
+ } else {
+ // Get the original names for the results if available
+ handleNamedAttributes(
+ op, "mlir.resultNames", [&]() { return op.getResults(); },
+ [&resultGroups](int i) { /*handles result groups*/
+ if (i > 0)
+ resultGroups.push_back(i);
+ });
+
+ // Get the original name for the op args if available
+ handleNamedAttributes(op, "mlir.opArgNames",
+ [&]() { return op.getOperands(); });
+
+ // Get the original name for the block if available
+ if (StringAttr blockNameAttr =
+ op.getAttrOfType<StringAttr>("mlir.blockName")) {
+ blockNames[op.getBlock()] = {-1, blockNameAttr.strref()};
+ op.removeDiscardableAttr("mlir.blockName");
}
- op.removeDiscardableAttr("mlir.opArgNames");
- }
- // Get the original name for the block if available
- if (StringAttr blockNameAttr =
- op.getAttrOfType<StringAttr>("mlir.blockName")) {
- blockNames[op.getBlock()] = {-1, blockNameAttr.strref()};
- op.removeDiscardableAttr("mlir.blockName");
+ // Get the original name(s) for the block arg(s) if available
+ handleNamedAttributes(op, "mlir.blockArgNames",
+ [&]() { return op.getBlock()->getArguments(); });
}
+ return;
+}
- // Get the original name for the block args if available
- if (ArrayAttr blockArgNamesAttr =
- op.getAttrOfType<ArrayAttr>("mlir.blockArgNames")) {
- auto blockArgNames = blockArgNamesAttr.getValue();
- auto blockArgs = op.getBlock()->getArguments();
- // Conservative in the case that the #args has changed
- for (size_t i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) {
- auto blockArgName = blockArgNames[i].cast<StringAttr>().strref();
- if (!usedNames.count(blockArgName))
- setValueName(blockArgs[i], blockArgName, /*allowNumeric=*/true);
- }
- op.removeDiscardableAttr("mlir.blockArgNames");
+void SSANameState::setRetainedIdentifierNames(Region &region) {
+ if (Operation *op = region.getParentOp()) {
+ SmallVector<int, 2> resultGroups;
+ setRetainedIdentifierNames(*op, resultGroups, true);
}
- return;
}
void SSANameState::getResultIDAndNumber(
diff --git a/mlir/test/IR/print-retain-identifiers.mlir b/mlir/test/IR/print-retain-identifiers.mlir
index b3e4f075b3936a..65aa0507d205fa 100644
--- a/mlir/test/IR/print-retain-identifiers.mlir
+++ b/mlir/test/IR/print-retain-identifiers.mlir
@@ -5,8 +5,8 @@
// Test SSA results (with single return values)
//===----------------------------------------------------------------------===//
-// CHECK: func.func @add_one(%my_input: f64, %arg1: f64) -> f64 {
-func.func @add_one(%my_input: f64, %arg1: f64) -> f64 {
+// CHECK: func.func @add_one(%my_input: f64) -> f64 {
+func.func @add_one(%my_input: f64) -> f64 {
// CHECK: %my_constant = arith.constant 1.000000e+00 : f64
%my_constant = arith.constant 1.000000e+00 : f64
// CHECK: %my_output = arith.addf %my_input, %my_constant : f64
@@ -71,7 +71,7 @@ func.func @select_min_max(%a: f64, %b: f64) -> (f64, f64) {
// -----
-////===----------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
// Test multiple return values, with a grouped value tuple
//===----------------------------------------------------------------------===//
@@ -90,3 +90,24 @@ func.func @select_max(%a: f64, %b: f64, %c: f64, %d: f64) -> (f64, f64, f64, f64
}
// -----
+
+//===----------------------------------------------------------------------===//
+// Test identifiers which may clash with OpAsmOpInterface names (e.g., cst, %1, etc)
+//===----------------------------------------------------------------------===//
+
+// CHECK: func.func @clash(%arg1: f64, %arg0: f64, %arg2: f64) -> f64 {
+func.func @clash(%arg1: f64, %arg0: f64, %arg2: f64) -> f64 {
+ %my_constant = arith.constant 1.000000e+00 : f64
+ // CHECK: %cst = arith.constant 2.000000e+00 : f64
+ %cst = arith.constant 2.000000e+00 : f64
+ // CHECK: %cst_1 = arith.constant 3.000000e+00 : f64
+ %cst_1 = arith.constant 3.000000e+00 : f64
+ // CHECK: %1 = arith.addf %arg1, %cst : f64
+ %1 = arith.addf %arg1, %cst : f64
+ // CHECK: %0 = arith.addf %arg1, %cst_1 : f64
+ %0 = arith.addf %arg1, %cst_1 : f64
+ // CHECK: return %1 : f64
+ return %1 : f64
+}
+
+// -----
More information about the Mlir-commits
mailing list