[Mlir-commits] [mlir] [mlir] Retain original identifier names for debugging (PR #79704)
Perry Gibson
llvmlistbot at llvm.org
Sat Jan 27 11:00:38 PST 2024
https://github.com/Wheest updated https://github.com/llvm/llvm-project/pull/79704
>From 7e4a6a7ce6018cf2ee8da7138bd84a3c66ee82a8 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Fri, 26 Jan 2024 14:34:14 +0000
Subject: [PATCH 1/9] Added single result SSA processing
---
mlir/include/mlir/IR/AsmState.h | 12 +++++++++---
mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h | 10 ++++++++++
mlir/lib/AsmParser/Parser.cpp | 17 +++++++++++++++++
mlir/lib/IR/AsmPrinter.cpp | 17 +++++++++++------
mlir/lib/Tools/mlir-opt/MlirOptMain.cpp | 10 ++++++++--
5 files changed, 55 insertions(+), 11 deletions(-)
diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index 42cbedcf9f88374..9c4eadb04cdf2fe 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -144,8 +144,7 @@ class AsmResourceBlob {
/// Return the underlying data as an array of the given type. This is an
/// inherrently unsafe operation, and should only be used when the data is
/// known to be of the correct type.
- template <typename T>
- ArrayRef<T> getDataAs() const {
+ template <typename T> ArrayRef<T> getDataAs() const {
return llvm::ArrayRef<T>((const T *)data.data(), data.size() / sizeof(T));
}
@@ -464,8 +463,10 @@ class ParserConfig {
/// `fallbackResourceMap` is an optional fallback handler that can be used to
/// parse external resources not explicitly handled by another parser.
ParserConfig(MLIRContext *context, bool verifyAfterParse = true,
- FallbackAsmResourceMap *fallbackResourceMap = nullptr)
+ FallbackAsmResourceMap *fallbackResourceMap = nullptr,
+ bool retainIdentifierNames = false)
: context(context), verifyAfterParse(verifyAfterParse),
+ retainIdentifierNames(retainIdentifierNames),
fallbackResourceMap(fallbackResourceMap) {
assert(context && "expected valid MLIR context");
}
@@ -476,6 +477,10 @@ class ParserConfig {
/// Returns if the parser should verify the IR after parsing.
bool shouldVerifyAfterParse() const { return verifyAfterParse; }
+ /// Returns whether the parser should retain the identifier names it
+ /// encounters during parsing.
+ bool shouldRetainIdentifierNames() const { return retainIdentifierNames; }
+
/// Returns the parsing configurations associated to the bytecode read.
BytecodeReaderConfig &getBytecodeReaderConfig() const {
return const_cast<BytecodeReaderConfig &>(bytecodeReaderConfig);
@@ -513,6 +518,7 @@ class ParserConfig {
private:
MLIRContext *context;
bool verifyAfterParse;
+ bool retainIdentifierNames;
DenseMap<StringRef, std::unique_ptr<AsmResourceParser>> resourceParsers;
FallbackAsmResourceMap *fallbackResourceMap;
BytecodeReaderConfig bytecodeReaderConfig;
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index 6e90fad1618d21b..a85dca186a4f3c9 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -176,6 +176,13 @@ class MlirOptMainConfig {
/// Reproducer file generation (no crash required).
StringRef getReproducerFilename() const { return generateReproducerFileFlag; }
+ /// Set whether to retain the original identifier names when printing.
+ MlirOptMainConfig &retainIdentifierNames(bool retain) {
+ retainIdentifierNamesFlag = retain;
+ return *this;
+ }
+ bool shouldRetainIdentifierNames() const { return retainIdentifierNamesFlag; }
+
protected:
/// Allow operation with no registered dialects.
/// This option is for convenience during testing only and discouraged in
@@ -226,6 +233,9 @@ class MlirOptMainConfig {
/// the corresponding line. This is meant for implementing diagnostic tests.
bool verifyDiagnosticsFlag = false;
+ /// Retain identifier names in the output (e.g., `%my_var` instead of `%0`).
+ bool retainIdentifierNamesFlag = false;
+
/// Run the verifier after each transformation pass.
bool verifyPassesFlag = true;
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 00f2b0c0c2f12f9..8d3861a2f018d00 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -1223,6 +1223,23 @@ ParseResult OperationParser::parseOperation() {
}
}
+ // If enabled, store the SSA name(s) for the operation
+ if (state.config.shouldRetainIdentifierNames()) {
+ if (opResI == 1) {
+ for (ResultRecord &resIt : resultIDs) {
+ for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
+ op->setDiscardableAttr(
+ "mlir.ssaName",
+ StringAttr::get(getContext(),
+ std::get<0>(resIt).drop_front(1)));
+ }
+ }
+ } else if (opResI > 1) {
+ emitError(
+ "have not yet implemented support for multiple return values");
+ }
+ }
+
// Add this operation to the assembly state if it was provided to populate.
} else if (state.asmState) {
state.asmState->finalizeOperationDefinition(
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 6b8b7473bf0f8a1..164b0f97fc1cd9e 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -72,13 +72,12 @@ OpAsmParser::~OpAsmParser() = default;
MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
/// Parse a type list.
-/// This is out-of-line to work-around https://github.com/llvm/llvm-project/issues/62918
+/// This is out-of-line to work-around
+/// https://github.com/llvm/llvm-project/issues/62918
ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) {
- return parseCommaSeparatedList(
- [&]() { return parseType(result.emplace_back()); });
- }
-
-
+ return parseCommaSeparatedList(
+ [&]() { return parseType(result.emplace_back()); });
+}
//===----------------------------------------------------------------------===//
// DialectAsmPrinter
@@ -1579,6 +1578,12 @@ void SSANameState::numberValuesInOp(Operation &op) {
}
Value resultBegin = op.getResult(0);
+ // Get the original SSA for the result if available
+ if (StringAttr ssaNameAttr = op.getAttrOfType<StringAttr>("mlir.ssaName")) {
+ setValueName(resultBegin, ssaNameAttr.strref());
+ op.removeDiscardableAttr("mlir.ssaName");
+ }
+
// If the first result wasn't numbered, give it a default number.
if (valueIDs.try_emplace(resultBegin, nextValueID).second)
++nextValueID;
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 5395aa2b502d788..c4482435861590d 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -149,6 +149,11 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
cl::desc("Round-trip the IR after parsing and ensure it succeeds"),
cl::location(verifyRoundtripFlag), cl::init(false));
+ static cl::opt<bool, /*ExternalStorage=*/true> retainIdentifierNames(
+ "retain-identifier-names",
+ cl::desc("Retain the original names of identifiers when printing"),
+ cl::location(retainIdentifierNamesFlag), cl::init(false));
+
static cl::list<std::string> passPlugins(
"load-pass-plugin", cl::desc("Load passes from plugin library"));
@@ -359,8 +364,9 @@ performActions(raw_ostream &os,
// untouched.
PassReproducerOptions reproOptions;
FallbackAsmResourceMap fallbackResourceMap;
- ParserConfig parseConfig(context, /*verifyAfterParse=*/true,
- &fallbackResourceMap);
+ ParserConfig parseConfig(
+ context, /*verifyAfterParse=*/true, &fallbackResourceMap,
+ /*retainIdentifierName=*/config.shouldRetainIdentifierNames());
if (config.shouldRunReproducer())
reproOptions.attachResourceParser(parseConfig);
>From 5a629d2cd3e6f736af20c808d53206f086ea5b05 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Fri, 26 Jan 2024 14:52:10 +0000
Subject: [PATCH 2/9] Moved SSA name storage to a helper function
---
mlir/lib/AsmParser/Parser.cpp | 40 ++++++++++++++++++++++-------------
1 file changed, 25 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 8d3861a2f018d00..282077865c665a9 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -611,6 +611,10 @@ class OperationParser : public Parser {
/// an object of type 'OperationName'. Otherwise, failure is returned.
FailureOr<OperationName> parseCustomOperationName();
+ /// Store the SSA names for the current operation as attrs for debug purposes.
+ ParseResult storeSSANames(Operation *&op,
+ SmallVector<ResultRecord, 1> resultIDs);
+
//===--------------------------------------------------------------------===//
// Region Parsing
//===--------------------------------------------------------------------===//
@@ -1224,21 +1228,8 @@ ParseResult OperationParser::parseOperation() {
}
// If enabled, store the SSA name(s) for the operation
- if (state.config.shouldRetainIdentifierNames()) {
- if (opResI == 1) {
- for (ResultRecord &resIt : resultIDs) {
- for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
- op->setDiscardableAttr(
- "mlir.ssaName",
- StringAttr::get(getContext(),
- std::get<0>(resIt).drop_front(1)));
- }
- }
- } else if (opResI > 1) {
- emitError(
- "have not yet implemented support for multiple return values");
- }
- }
+ if (state.config.shouldRetainIdentifierNames())
+ storeSSANames(op, resultIDs);
// Add this operation to the assembly state if it was provided to populate.
} else if (state.asmState) {
@@ -1285,6 +1276,25 @@ OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) {
/*allowEmptyList=*/false);
}
+/// Store the SSA names for the current operation as attrs for debug purposes.
+ParseResult
+OperationParser::storeSSANames(Operation *&op,
+ SmallVector<ResultRecord, 1> resultIDs) {
+ if (op->getNumResults() == 0)
+ emitError("Operation has no results\n");
+ else if (op->getNumResults() > 1)
+ emitError("have not yet implemented support for multiple return values\n");
+
+ for (ResultRecord &resIt : resultIDs) {
+ for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
+ op->setDiscardableAttr(
+ "mlir.ssaName",
+ StringAttr::get(getContext(), std::get<0>(resIt).drop_front(1)));
+ }
+ }
+ return success();
+}
+
namespace {
// RAII-style guard for cleaning up the regions in the operation state before
// deleting them. Within the parser, regions may get deleted if parsing failed,
>From f2f8047bd80df0f6da6a8db4bfe976491157993e Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Fri, 26 Jan 2024 15:39:18 +0000
Subject: [PATCH 3/9] Added initial block name handling
---
mlir/lib/AsmParser/Parser.cpp | 13 +++++++++++++
mlir/lib/IR/AsmPrinter.cpp | 33 ++++++++++++++++++++++++++++-----
2 files changed, 41 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 282077865c665a9..bb3fd5d8729406a 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -1292,6 +1292,19 @@ OperationParser::storeSSANames(Operation *&op,
StringAttr::get(getContext(), std::get<0>(resIt).drop_front(1)));
}
}
+
+ // Find the name of the block that contains this operation.
+ Block *blockPtr = op->getBlock();
+ for (const auto &map : blocksByName) {
+ for (const auto &entry : map) {
+ if (entry.second.block == blockPtr) {
+ op->setDiscardableAttr("mlir.blockName",
+ StringAttr::get(getContext(), entry.first));
+ llvm::outs() << "Block name: " << entry.first << "\n";
+ }
+ }
+ }
+
return success();
}
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 164b0f97fc1cd9e..70d38c9bd8b4f89 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1298,6 +1298,10 @@ class SSANameState {
/// conflicts, it is automatically renamed.
StringRef uniqueValueName(StringRef name);
+ /// Set the original identifier names if available. Used in debugging with
+ /// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
+ void setRetainedIdentifierNames(Operation &op);
+
/// This is the value ID for each SSA value. If this returns NameSentinel,
/// then the valueID has an entry in valueNames.
DenseMap<Value, unsigned> valueIDs;
@@ -1578,11 +1582,9 @@ void SSANameState::numberValuesInOp(Operation &op) {
}
Value resultBegin = op.getResult(0);
- // Get the original SSA for the result if available
- if (StringAttr ssaNameAttr = op.getAttrOfType<StringAttr>("mlir.ssaName")) {
- setValueName(resultBegin, ssaNameAttr.strref());
- op.removeDiscardableAttr("mlir.ssaName");
- }
+ // Set the original identifier names if available. Used in debugging with
+ // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
+ setRetainedIdentifierNames(op);
// If the first result wasn't numbered, give it a default number.
if (valueIDs.try_emplace(resultBegin, nextValueID).second)
@@ -1595,6 +1597,27 @@ void SSANameState::numberValuesInOp(Operation &op) {
}
}
+void SSANameState::setRetainedIdentifierNames(Operation &op) {
+ // Get the original SSA for the result(s) if available
+ Value resultBegin = op.getResult(0);
+ if (StringAttr ssaNameAttr = op.getAttrOfType<StringAttr>("mlir.ssaName")) {
+ setValueName(resultBegin, ssaNameAttr.strref());
+ op.removeDiscardableAttr("mlir.ssaName");
+ }
+ unsigned numResults = op.getNumResults();
+ if (numResults > 1)
+ llvm::outs()
+ << "have not yet implemented support for multiple return values\n";
+
+ // Get the original SSA name for the block if available
+ if (StringAttr blockNameAttr =
+ op.getAttrOfType<StringAttr>("mlir.blockName")) {
+ blockNames[op.getBlock()] = {-1, blockNameAttr.strref()};
+ op.removeDiscardableAttr("mlir.blockName");
+ }
+ return;
+}
+
void SSANameState::getResultIDAndNumber(
OpResult result, Value &lookupValue,
std::optional<int> &lookupResultNo) const {
>From 1aa3814f6eb47ec8d51d2eb68e40e469194e2744 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Fri, 26 Jan 2024 16:18:27 +0000
Subject: [PATCH 4/9] Added block arg name handling
---
mlir/lib/AsmParser/Parser.cpp | 25 ++++++++++++++++++++-----
mlir/lib/IR/AsmPrinter.cpp | 14 +++++++++++++-
2 files changed, 33 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index bb3fd5d8729406a..95b1b2fe1f8f07e 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -614,6 +614,7 @@ class OperationParser : public Parser {
/// Store the SSA names for the current operation as attrs for debug purposes.
ParseResult storeSSANames(Operation *&op,
SmallVector<ResultRecord, 1> resultIDs);
+ DenseMap<BlockArgument, StringRef> blockArgNames;
//===--------------------------------------------------------------------===//
// Region Parsing
@@ -1203,6 +1204,11 @@ ParseResult OperationParser::parseOperation() {
<< op->getNumResults() << " results but was provided "
<< numExpectedResults << " to bind";
+ // If enabled, store the SSA name(s) for the operation
+ llvm::outs() << "parsing operation: " << op->getName() << "\n";
+ if (state.config.shouldRetainIdentifierNames())
+ storeSSANames(op, resultIDs);
+
// Add this operation to the assembly state if it was provided to populate.
if (state.asmState) {
unsigned resultIt = 0;
@@ -1227,10 +1233,6 @@ ParseResult OperationParser::parseOperation() {
}
}
- // If enabled, store the SSA name(s) for the operation
- if (state.config.shouldRetainIdentifierNames())
- storeSSANames(op, resultIDs);
-
// Add this operation to the assembly state if it was provided to populate.
} else if (state.asmState) {
state.asmState->finalizeOperationDefinition(
@@ -1300,7 +1302,16 @@ OperationParser::storeSSANames(Operation *&op,
if (entry.second.block == blockPtr) {
op->setDiscardableAttr("mlir.blockName",
StringAttr::get(getContext(), entry.first));
- llvm::outs() << "Block name: " << entry.first << "\n";
+
+ // Store block arguments, if present
+ llvm::SmallVector<llvm::StringRef, 1> argNames;
+
+ for (BlockArgument arg : blockPtr->getArguments()) {
+ auto it = blockArgNames.find(arg);
+ if (it != blockArgNames.end())
+ argNames.push_back(it->second.drop_front(1));
+ }
+ op->setAttr("mlir.blockArgNames", builder.getStrArrayAttr(argNames));
}
}
}
@@ -2395,6 +2406,10 @@ ParseResult OperationParser::parseOptionalBlockArgList(Block *owner) {
} else {
auto loc = getEncodedSourceLocation(useInfo.location);
arg = owner->addArgument(type, loc);
+
+ // Optionally store argument name for debug purposes
+ if (state.config.shouldRetainIdentifierNames())
+ blockArgNames.insert({arg, useInfo.name});
}
// If the argument has an explicit loc(...) specifier, parse and apply
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 70d38c9bd8b4f89..34e0bdfccda94e5 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1609,12 +1609,24 @@ void SSANameState::setRetainedIdentifierNames(Operation &op) {
llvm::outs()
<< "have not yet implemented support for multiple return values\n";
- // Get the original SSA name for the block if available
+ // Get the original name for the block if available
if (StringAttr blockNameAttr =
op.getAttrOfType<StringAttr>("mlir.blockName")) {
blockNames[op.getBlock()] = {-1, blockNameAttr.strref()};
op.removeDiscardableAttr("mlir.blockName");
}
+
+ // Get the original name for the block args if available
+ if (ArrayAttr blockArgNamesAttr =
+ op.getAttrOfType<ArrayAttr>("mlir.blockArgNames")) {
+ auto blockArgNames = blockArgNamesAttr.getValue();
+ auto blockArgs = op.getBlock()->getArguments();
+ for (int i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) {
+ auto blockArgName = blockArgNames[i].cast<StringAttr>();
+ setValueName(blockArgs[i], cast<StringAttr>(blockArgNames[i]).strref());
+ }
+ op.removeDiscardableAttr("mlir.blockArgNames");
+ }
return;
}
>From c0ec451d9ba155478ee5bded93cfaa912480e0ce Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Sat, 27 Jan 2024 11:24:53 +0000
Subject: [PATCH 5/9] Added support for operations with no results
---
mlir/lib/AsmParser/Parser.cpp | 26 ++++++++++----------------
mlir/lib/IR/AsmPrinter.cpp | 27 +++++++++++++++------------
2 files changed, 25 insertions(+), 28 deletions(-)
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 95b1b2fe1f8f07e..36266be5af033c0 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -612,8 +612,7 @@ class OperationParser : public Parser {
FailureOr<OperationName> parseCustomOperationName();
/// Store the SSA names for the current operation as attrs for debug purposes.
- ParseResult storeSSANames(Operation *&op,
- SmallVector<ResultRecord, 1> resultIDs);
+ void storeSSANames(Operation *&op, ArrayRef<ResultRecord> resultIDs);
DenseMap<BlockArgument, StringRef> blockArgNames;
//===--------------------------------------------------------------------===//
@@ -1204,11 +1203,6 @@ ParseResult OperationParser::parseOperation() {
<< op->getNumResults() << " results but was provided "
<< numExpectedResults << " to bind";
- // If enabled, store the SSA name(s) for the operation
- llvm::outs() << "parsing operation: " << op->getName() << "\n";
- if (state.config.shouldRetainIdentifierNames())
- storeSSANames(op, resultIDs);
-
// Add this operation to the assembly state if it was provided to populate.
if (state.asmState) {
unsigned resultIt = 0;
@@ -1279,15 +1273,12 @@ OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) {
}
/// Store the SSA names for the current operation as attrs for debug purposes.
-ParseResult
-OperationParser::storeSSANames(Operation *&op,
- SmallVector<ResultRecord, 1> resultIDs) {
- if (op->getNumResults() == 0)
- emitError("Operation has no results\n");
- else if (op->getNumResults() > 1)
+void OperationParser::storeSSANames(Operation *&op,
+ ArrayRef<ResultRecord> resultIDs) {
+ if (op->getNumResults() > 1)
emitError("have not yet implemented support for multiple return values\n");
- for (ResultRecord &resIt : resultIDs) {
+ for (const ResultRecord &resIt : resultIDs) {
for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
op->setDiscardableAttr(
"mlir.ssaName",
@@ -1315,8 +1306,6 @@ OperationParser::storeSSANames(Operation *&op,
}
}
}
-
- return success();
}
namespace {
@@ -2082,6 +2071,11 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
// Otherwise, create the operation and try to parse a location for it.
Operation *op = opBuilder.create(opState);
+
+ // If enabled, store the SSA name(s) for the operation
+ if (state.config.shouldRetainIdentifierNames())
+ storeSSANames(op, resultIDs);
+
if (parseTrailingLocationSpecifier(op))
return nullptr;
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 34e0bdfccda94e5..d4be3b7802e6880 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1571,6 +1571,10 @@ void SSANameState::numberValuesInOp(Operation &op) {
}
}
+ // Set the original identifier names if available. Used in debugging with
+ // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
+ setRetainedIdentifierNames(op);
+
unsigned numResults = op.getNumResults();
if (numResults == 0) {
// If value users should be printed, operations with no result need an id.
@@ -1582,10 +1586,6 @@ void SSANameState::numberValuesInOp(Operation &op) {
}
Value resultBegin = op.getResult(0);
- // Set the original identifier names if available. Used in debugging with
- // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
- setRetainedIdentifierNames(op);
-
// If the first result wasn't numbered, give it a default number.
if (valueIDs.try_emplace(resultBegin, nextValueID).second)
++nextValueID;
@@ -1599,15 +1599,17 @@ void SSANameState::numberValuesInOp(Operation &op) {
void SSANameState::setRetainedIdentifierNames(Operation &op) {
// Get the original SSA for the result(s) if available
- Value resultBegin = op.getResult(0);
- if (StringAttr ssaNameAttr = op.getAttrOfType<StringAttr>("mlir.ssaName")) {
- setValueName(resultBegin, ssaNameAttr.strref());
- op.removeDiscardableAttr("mlir.ssaName");
- }
unsigned numResults = op.getNumResults();
if (numResults > 1)
llvm::outs()
<< "have not yet implemented support for multiple return values\n";
+ else if (numResults == 1) {
+ Value resultBegin = op.getResult(0);
+ if (StringAttr ssaNameAttr = op.getAttrOfType<StringAttr>("mlir.ssaName")) {
+ setValueName(resultBegin, ssaNameAttr.strref());
+ op.removeDiscardableAttr("mlir.ssaName");
+ }
+ }
// Get the original name for the block if available
if (StringAttr blockNameAttr =
@@ -1621,9 +1623,10 @@ void SSANameState::setRetainedIdentifierNames(Operation &op) {
op.getAttrOfType<ArrayAttr>("mlir.blockArgNames")) {
auto blockArgNames = blockArgNamesAttr.getValue();
auto blockArgs = op.getBlock()->getArguments();
- for (int i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) {
- auto blockArgName = blockArgNames[i].cast<StringAttr>();
- setValueName(blockArgs[i], cast<StringAttr>(blockArgNames[i]).strref());
+ for (size_t i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) {
+ auto blockArgName = blockArgNames[i].cast<StringAttr>().strref();
+ if (!usedNames.count(blockArgName))
+ setValueName(blockArgs[i], blockArgName);
}
op.removeDiscardableAttr("mlir.blockArgNames");
}
>From a9e3891f4ed0511de6797cfe986fe8d3dcbf8f08 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Sat, 27 Jan 2024 12:21:39 +0000
Subject: [PATCH 6/9] Added initial unit tests
---
mlir/test/IR/print-retain-identifiers.mlir | 54 ++++++++++++++++++++++
1 file changed, 54 insertions(+)
create mode 100644 mlir/test/IR/print-retain-identifiers.mlir
diff --git a/mlir/test/IR/print-retain-identifiers.mlir b/mlir/test/IR/print-retain-identifiers.mlir
new file mode 100644
index 000000000000000..663f7301b817bc0
--- /dev/null
+++ b/mlir/test/IR/print-retain-identifiers.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt -retain-identifier-names %s | FileCheck %s
+
+
+//===----------------------------------------------------------------------===//
+// Test SSA results (with single return values)
+//===----------------------------------------------------------------------===//
+
+// CHECK: func.func @add_one(%arg0: f64, %arg1: f64) -> f64 {
+func.func @add_one(%arg0: f64, %arg1: f64) -> f64 {
+ // CHECK: %my_constant = arith.constant 1.000000e+00 : f64
+ %my_constant = arith.constant 1.000000e+00 : f64
+ // CHECK: %my_output = arith.addf %arg0, %my_constant : f64
+ %my_output = arith.addf %arg0, %my_constant : f64
+ // CHECK: return %my_output : f64
+ return %my_output : f64
+}
+
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test basic blocks and their arguments
+//===----------------------------------------------------------------------===//
+
+func.func @simple(i64, i1) -> i64 {
+^bb_alpha(%a: i64, %cond: i1):
+ // CHECK: cf.cond_br %cond, ^bb_beta, ^bb_gamma
+ cf.cond_br %cond, ^bb_beta, ^bb_gamma
+
+// CHECK: ^bb_beta: // pred: ^bb_alpha
+^bb_beta:
+ // CHECK: cf.br ^bb_delta(%a : i64)
+ cf.br ^bb_delta(%a: i64)
+
+// CHECK: ^bb_gamma: // pred: ^bb_alpha
+^bb_gamma:
+ // CHECK: %b = arith.addi %a, %a : i64
+ %b = arith.addi %a, %a : i64
+ // CHECK: cf.br ^bb_delta(%b : i64)
+ cf.br ^bb_delta(%b: i64)
+
+// CHECK: ^bb_delta(%c: i64): // 2 preds: ^bb_gamma, ^bb_beta
+^bb_delta(%c: i64):
+ // CHECK: cf.br ^bb_eps(%c, %a : i64, i64)
+ cf.br ^bb_eps(%c, %a : i64, i64)
+
+// CHECK: ^bb_eps(%d: i64, %e: i64): // pred: ^bb_delta
+^bb_eps(%d : i64, %e : i64):
+ // CHECK: %f = arith.addi %d, %e : i64
+ %f = arith.addi %d, %e : i64
+ return %f : i64
+}
+
+// -----
>From 467c3a38abfbbc22c25b0874a3943ee7ad8bc78b Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Sat, 27 Jan 2024 14:49:16 +0000
Subject: [PATCH 7/9] Added operation operand name preservation
---
mlir/lib/AsmParser/Parser.cpp | 34 +++++++++++++++++-----
mlir/lib/IR/AsmPrinter.cpp | 13 +++++++++
mlir/test/IR/print-retain-identifiers.mlir | 8 ++---
3 files changed, 43 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 36266be5af033c0..7327776ae5a55d1 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -613,7 +613,7 @@ class OperationParser : public Parser {
/// Store the SSA names for the current operation as attrs for debug purposes.
void storeSSANames(Operation *&op, ArrayRef<ResultRecord> resultIDs);
- DenseMap<BlockArgument, StringRef> blockArgNames;
+ DenseMap<Value, StringRef> argNames;
//===--------------------------------------------------------------------===//
// Region Parsing
@@ -1286,7 +1286,19 @@ void OperationParser::storeSSANames(Operation *&op,
}
}
- // Find the name of the block that contains this operation.
+ // Store the name information of the arguments of this operation.
+ if (op->getNumOperands() > 0) {
+ llvm::SmallVector<llvm::StringRef, 1> opArgNames;
+ for (auto &operand : op->getOpOperands()) {
+ auto it = argNames.find(operand.get());
+ if (it != argNames.end())
+ opArgNames.push_back(it->second.drop_front(1));
+ }
+ op->setDiscardableAttr("mlir.opArgNames",
+ builder.getStrArrayAttr(opArgNames));
+ }
+
+ // Store the name information of the block that contains this operation.
Block *blockPtr = op->getBlock();
for (const auto &map : blocksByName) {
for (const auto &entry : map) {
@@ -1295,14 +1307,15 @@ void OperationParser::storeSSANames(Operation *&op,
StringAttr::get(getContext(), entry.first));
// Store block arguments, if present
- llvm::SmallVector<llvm::StringRef, 1> argNames;
+ llvm::SmallVector<llvm::StringRef, 1> blockArgNames;
for (BlockArgument arg : blockPtr->getArguments()) {
- auto it = blockArgNames.find(arg);
- if (it != blockArgNames.end())
- argNames.push_back(it->second.drop_front(1));
+ auto it = argNames.find(arg);
+ if (it != argNames.end())
+ blockArgNames.push_back(it->second.drop_front(1));
}
- op->setAttr("mlir.blockArgNames", builder.getStrArrayAttr(argNames));
+ op->setAttr("mlir.blockArgNames",
+ builder.getStrArrayAttr(blockArgNames));
}
}
}
@@ -1712,6 +1725,11 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
SmallVectorImpl<Value> &result) override {
if (auto value = parser.resolveSSAUse(operand, type)) {
result.push_back(value);
+
+ // Optionally store argument name for debug purposes
+ if (parser.getState().config.shouldRetainIdentifierNames())
+ parser.argNames.insert({value, operand.name});
+
return success();
}
return failure();
@@ -2403,7 +2421,7 @@ ParseResult OperationParser::parseOptionalBlockArgList(Block *owner) {
// Optionally store argument name for debug purposes
if (state.config.shouldRetainIdentifierNames())
- blockArgNames.insert({arg, useInfo.name});
+ argNames.insert({arg, useInfo.name});
}
// If the argument has an explicit loc(...) specifier, parse and apply
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index d4be3b7802e6880..5755b9021dbad7f 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1611,6 +1611,19 @@ void SSANameState::setRetainedIdentifierNames(Operation &op) {
}
}
+ // Get the original name for the op args if available
+ if (ArrayAttr opArgNamesAttr =
+ op.getAttrOfType<ArrayAttr>("mlir.opArgNames")) {
+ auto opArgNames = opArgNamesAttr.getValue();
+ auto opArgs = op.getOperands();
+ for (size_t i = 0; i < opArgs.size() && i < opArgNames.size(); ++i) {
+ auto opArgName = opArgNames[i].cast<StringAttr>().strref();
+ if (!usedNames.count(opArgName))
+ setValueName(opArgs[i], opArgName);
+ }
+ op.removeDiscardableAttr("mlir.opArgNames");
+ }
+
// Get the original name for the block if available
if (StringAttr blockNameAttr =
op.getAttrOfType<StringAttr>("mlir.blockName")) {
diff --git a/mlir/test/IR/print-retain-identifiers.mlir b/mlir/test/IR/print-retain-identifiers.mlir
index 663f7301b817bc0..05a6b9b42d8e820 100644
--- a/mlir/test/IR/print-retain-identifiers.mlir
+++ b/mlir/test/IR/print-retain-identifiers.mlir
@@ -5,12 +5,12 @@
// Test SSA results (with single return values)
//===----------------------------------------------------------------------===//
-// CHECK: func.func @add_one(%arg0: f64, %arg1: f64) -> f64 {
-func.func @add_one(%arg0: f64, %arg1: f64) -> f64 {
+// CHECK: func.func @add_one(%my_input: f64, %arg1: f64) -> f64 {
+func.func @add_one(%my_input: f64, %arg1: f64) -> f64 {
// CHECK: %my_constant = arith.constant 1.000000e+00 : f64
%my_constant = arith.constant 1.000000e+00 : f64
- // CHECK: %my_output = arith.addf %arg0, %my_constant : f64
- %my_output = arith.addf %arg0, %my_constant : f64
+ // CHECK: %my_output = arith.addf %my_input, %my_constant : f64
+ %my_output = arith.addf %my_input, %my_constant : f64
// CHECK: return %my_output : f64
return %my_output : f64
}
>From d062cb861c283594413c8edc7dca57f9494b1f3e Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Sat, 27 Jan 2024 18:27:53 +0000
Subject: [PATCH 8/9] Added support for result groups
---
mlir/lib/AsmParser/Parser.cpp | 19 ++++++-----
mlir/lib/IR/AsmPrinter.cpp | 36 ++++++++++++--------
mlir/test/IR/print-retain-identifiers.mlir | 38 ++++++++++++++++++++++
3 files changed, 72 insertions(+), 21 deletions(-)
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 7327776ae5a55d1..247e99e61c2c01b 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -1275,15 +1275,18 @@ OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) {
/// Store the SSA names for the current operation as attrs for debug purposes.
void OperationParser::storeSSANames(Operation *&op,
ArrayRef<ResultRecord> resultIDs) {
- if (op->getNumResults() > 1)
- emitError("have not yet implemented support for multiple return values\n");
-
- for (const ResultRecord &resIt : resultIDs) {
- for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
- op->setDiscardableAttr(
- "mlir.ssaName",
- StringAttr::get(getContext(), std::get<0>(resIt).drop_front(1)));
+
+ // Store the name(s) of the result(s) of this operation.
+ if (op->getNumResults() > 0) {
+ llvm::SmallVector<llvm::StringRef, 1> resultNames;
+ for (const ResultRecord &resIt : resultIDs) {
+ resultNames.push_back(std::get<0>(resIt).drop_front(1));
+ // Insert empty string for sub-results/result groups
+ for (unsigned int i = 1; i < std::get<1>(resIt); ++i)
+ resultNames.push_back(llvm::StringRef());
}
+ op->setDiscardableAttr("mlir.resultNames",
+ builder.getStrArrayAttr(resultNames));
}
// Store the name information of the arguments of this operation.
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 5755b9021dbad7f..84603bb6ebfba3a 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1300,7 +1300,8 @@ class SSANameState {
/// Set the original identifier names if available. Used in debugging with
/// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
- void setRetainedIdentifierNames(Operation &op);
+ void setRetainedIdentifierNames(Operation &op,
+ SmallVector<int, 2> &resultGroups);
/// This is the value ID for each SSA value. If this returns NameSentinel,
/// then the valueID has an entry in valueNames.
@@ -1573,7 +1574,7 @@ void SSANameState::numberValuesInOp(Operation &op) {
// Set the original identifier names if available. Used in debugging with
// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
- setRetainedIdentifierNames(op);
+ setRetainedIdentifierNames(op, resultGroups);
unsigned numResults = op.getNumResults();
if (numResults == 0) {
@@ -1597,18 +1598,25 @@ void SSANameState::numberValuesInOp(Operation &op) {
}
}
-void SSANameState::setRetainedIdentifierNames(Operation &op) {
- // Get the original SSA for the result(s) if available
- unsigned numResults = op.getNumResults();
- if (numResults > 1)
- llvm::outs()
- << "have not yet implemented support for multiple return values\n";
- else if (numResults == 1) {
- Value resultBegin = op.getResult(0);
- if (StringAttr ssaNameAttr = op.getAttrOfType<StringAttr>("mlir.ssaName")) {
- setValueName(resultBegin, ssaNameAttr.strref());
- op.removeDiscardableAttr("mlir.ssaName");
+void SSANameState::setRetainedIdentifierNames(
+ Operation &op, SmallVector<int, 2> &resultGroups) {
+ // Get the original names for the results if available
+ if (ArrayAttr resultNamesAttr =
+ op.getAttrOfType<ArrayAttr>("mlir.resultNames")) {
+ auto resultNames = resultNamesAttr.getValue();
+ auto results = op.getResults();
+ // Conservative in the case that the #results has changed
+ for (size_t i = 0; i < results.size() && i < resultNames.size(); ++i) {
+ auto resultName = resultNames[i].cast<StringAttr>().strref();
+ if (!resultName.empty()) {
+ if (!usedNames.count(resultName))
+ setValueName(results[i], resultName);
+ // If a result has a name, it is the start of a result group.
+ if (i > 0)
+ resultGroups.push_back(i);
+ }
}
+ op.removeDiscardableAttr("mlir.resultNames");
}
// Get the original name for the op args if available
@@ -1616,6 +1624,7 @@ void SSANameState::setRetainedIdentifierNames(Operation &op) {
op.getAttrOfType<ArrayAttr>("mlir.opArgNames")) {
auto opArgNames = opArgNamesAttr.getValue();
auto opArgs = op.getOperands();
+ // Conservative in the case that the #operands has changed
for (size_t i = 0; i < opArgs.size() && i < opArgNames.size(); ++i) {
auto opArgName = opArgNames[i].cast<StringAttr>().strref();
if (!usedNames.count(opArgName))
@@ -1636,6 +1645,7 @@ void SSANameState::setRetainedIdentifierNames(Operation &op) {
op.getAttrOfType<ArrayAttr>("mlir.blockArgNames")) {
auto blockArgNames = blockArgNamesAttr.getValue();
auto blockArgs = op.getBlock()->getArguments();
+ // Conservative in the case that the #args has changed
for (size_t i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) {
auto blockArgName = blockArgNames[i].cast<StringAttr>().strref();
if (!usedNames.count(blockArgName))
diff --git a/mlir/test/IR/print-retain-identifiers.mlir b/mlir/test/IR/print-retain-identifiers.mlir
index 05a6b9b42d8e820..b3e4f075b3936a8 100644
--- a/mlir/test/IR/print-retain-identifiers.mlir
+++ b/mlir/test/IR/print-retain-identifiers.mlir
@@ -52,3 +52,41 @@ func.func @simple(i64, i1) -> i64 {
}
// -----
+
+//===----------------------------------------------------------------------===//
+// Test multiple return values, each bound to its own result name
+//===----------------------------------------------------------------------===//
+
+func.func @select_min_max(%a: f64, %b: f64) -> (f64, f64) {
+ %gt = arith.cmpf "ogt", %a, %b : f64
+ // CHECK: %min, %max = scf.if %gt -> (f64, f64) {
+ %min, %max = scf.if %gt -> (f64, f64) {
+ scf.yield %b, %a : f64, f64
+ } else {
+ scf.yield %a, %b : f64, f64
+ }
+ // CHECK: return %min, %max : f64, f64
+ return %min, %max : f64, f64
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test multiple return values, with a grouped value tuple
+//===----------------------------------------------------------------------===//
+
+func.func @select_max(%a: f64, %b: f64, %c: f64, %d: f64) -> (f64, f64, f64, f64) {
+ // %max is the larger of %a and %b; the smaller one, plus %c and %d
+ // (swapped on the else branch), are returned as the other results.
+ %gt = arith.cmpf "ogt", %a, %b : f64
+ // CHECK: %max, %others:2, %alt = scf.if %gt -> (f64, f64, f64, f64) {
+ %max, %others:2, %alt = scf.if %gt -> (f64, f64, f64, f64) {
+ scf.yield %a, %b, %c, %d : f64, f64, f64, f64
+ } else {
+ scf.yield %b, %a, %d, %c : f64, f64, f64, f64
+ }
+ // CHECK: return %max, %others#0, %others#1, %alt : f64, f64, f64, f64
+ return %max, %others#0, %others#1, %alt : f64, f64, f64, f64
+}
+
+// -----
>From fb557df2c4afec5aaaf876428769b70d29306244 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Sat, 27 Jan 2024 18:58:35 +0000
Subject: [PATCH 9/9] Fix clang-format issue in AsmState.h
---
mlir/include/mlir/IR/AsmState.h | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index 9c4eadb04cdf2fe..36f80712efb6ac5 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -144,7 +144,8 @@ class AsmResourceBlob {
/// Return the underlying data as an array of the given type. This is an
/// inherrently unsafe operation, and should only be used when the data is
/// known to be of the correct type.
- template <typename T> ArrayRef<T> getDataAs() const {
+ template <typename T>
+ ArrayRef<T> getDataAs() const {
return llvm::ArrayRef<T>((const T *)data.data(), data.size() / sizeof(T));
}
More information about the Mlir-commits
mailing list