[Mlir-commits] [mlir] [mlir] Retain original identifier names for debugging (PR #79704)
llvmlistbot at llvm.org
Sat Jan 27 10:50:38 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-core
Author: Perry Gibson (Wheest)
<details>
<summary>Changes</summary>
This PR implements retention of MLIR identifier names (e.g., `%my_val`, `^bb_foo`) for debugging and development purposes.
A wider discussion of this feature is in [this discourse thread](https://discourse.llvm.org/t/retain-original-identifier-names-for-debugging/76417/1).
As a motivating example, IR generation currently drops all meaningful identifier names, even though those names could be useful for developers trying to understand their passes, or for other tooling (e.g., [MLIR code formatters](https://discourse.llvm.org/t/clang-format-or-some-other-auto-format-for-mlir-files/75258/15)).
```mlir
func.func @add_one(%my_input: f64) -> f64 {
%my_constant = arith.constant 1.00000e+00 : f64
%my_output = arith.addf %my_input, %my_constant : f64
return %my_output : f64
}
```
↓ becomes
```mlir
func.func @add_one(%arg0: f64) -> f64 {
%cst = arith.constant 1.000000e+00 : f64
%0 = arith.addf %arg0, %cst : f64
return %0 : f64
}
```
The solution this PR implements is to store this metadata inside the attribute dictionary of operations, under a special namespace (e.g., `mlir.resultNames`). This means that this optional feature (turned on by the `mlir-opt` flag `--retain-identifier-names`) does not incur any additional overhead, except in text parsing and printing (`AsmParser/Parser.cpp` and `IR/AsmPrinter.cpp`).
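To make the attribute-based approach concrete, here is a minimal, self-contained sketch of the round trip. It does not use the MLIR API: `ToyOp`, `storeResultNames`, and `resultNamesForPrinting` are hypothetical stand-ins, and only the attribute key `mlir.resultNames` comes from the PR. The assumption being illustrated is that the parser attaches names to the operation when the flag is set, and the printer consumes and removes them.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for an operation: a name, a result count, and a
// dictionary of discardable attributes (simplified here to string arrays).
struct ToyOp {
  std::string opName;
  std::size_t numResults;
  std::map<std::string, std::vector<std::string>> attrs;
};

// Parser side: record the textual result names on the op, but only when the
// feature is enabled. With retain == false nothing is stored at all.
void storeResultNames(ToyOp &op, const std::vector<std::string> &names,
                      bool retain) {
  if (!retain || op.numResults == 0)
    return;
  op.attrs["mlir.resultNames"] = names;
}

// Printer side: use a stored name when one exists, otherwise fall back to the
// next numeric ID. The attribute is removed once it has been consumed.
std::vector<std::string> resultNamesForPrinting(ToyOp &op, unsigned &nextID) {
  std::vector<std::string> printed;
  auto it = op.attrs.find("mlir.resultNames");
  for (std::size_t i = 0; i < op.numResults; ++i) {
    bool hasName = it != op.attrs.end() && i < it->second.size() &&
                   !it->second[i].empty();
    if (hasName)
      printed.push_back("%" + it->second[i]);
    else
      printed.push_back("%" + std::to_string(nextID++));
  }
  if (it != op.attrs.end())
    op.attrs.erase(it);
  return printed;
}
```

In this sketch, calling `storeResultNames` with `retain` set to `false` leaves the attribute dictionary untouched, which is the sense in which the feature is meant to cost nothing when it is disabled.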
Alternative solutions, such as adding a string field to the `Value` class, reparsing location information, and adapting `OpAsmInterface`, are discussed in the [relevant discourse thread](https://discourse.llvm.org/t/retain-original-identifier-names-for-debugging/76417/1).
I've implemented some initial test cases in `mlir/test/IR/print-retain-identifiers.mlir`.
These cover:
- retaining the result names of operations
- retaining the names of basic blocks (and their arguments)
- handling result groups
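The result-group handling can be sketched in isolation. In the diff, `storeSSANames` writes the group name once and pads the remaining results of the group with empty strings, and the printer treats every later non-empty name as the start of a new group. The snippet below is a standalone model of that encoding using only the standard library; `ResultGroup`, `flattenResultNames`, and `groupStarts` are illustrative names, not part of the PR.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// (name, count) as written in the source text, e.g. `%others:2` is
// represented as {"others", 2}. The leading '%' is assumed to be stripped.
using ResultGroup = std::pair<std::string, unsigned>;

// Flatten groups into one entry per result: the group name at the first
// position of the group, empty strings for its remaining results.
std::vector<std::string>
flattenResultNames(const std::vector<ResultGroup> &groups) {
  std::vector<std::string> flat;
  for (const ResultGroup &group : groups) {
    assert(group.second >= 1 && "a result group covers at least one result");
    flat.push_back(group.first);
    for (unsigned i = 1; i < group.second; ++i)
      flat.emplace_back();
  }
  return flat;
}

// Recover the indices at which a new group begins. Index 0 is implicit, so
// only later non-empty names are reported.
std::vector<unsigned> groupStarts(const std::vector<std::string> &flat) {
  std::vector<unsigned> starts;
  for (unsigned i = 1; i < flat.size(); ++i)
    if (!flat[i].empty())
      starts.push_back(i);
  return starts;
}
```

For the `%max, %others:2, %alt` case from the last test, the flattened array would be `max`, `others`, an empty string, and `alt`, with new groups starting at indices 1 and 3.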
One case that I know does not work is when a `func.func` argument is unused. This is because we recover the SSA names of these arguments from the operations that use them. You can see in the first test case that the 2nd argument is not used, so it always falls back to its default name, `%arg1`.
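The limitation with unused arguments follows directly from recovering names through uses, which a small model can show. This is a simplified sketch rather than the parser's actual data structures: `ToyUse` and `recoverArgNames` are hypothetical, and the `%argN` fallback mirrors the `%arg1` seen in the first test function.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// One use of a function argument by some operation: which argument was used,
// and the name it was written with (without the leading '%').
struct ToyUse {
  unsigned argIndex;
  std::string writtenName;
};

// Recover argument names from uses only. An argument that no operation uses
// never gets an entry, so it falls back to the default `%arg<index>` name.
std::vector<std::string> recoverArgNames(unsigned numArgs,
                                         const std::vector<ToyUse> &uses) {
  std::map<unsigned, std::string> seen;
  for (const ToyUse &use : uses) {
    assert(use.argIndex < numArgs && "use refers to a missing argument");
    // emplace does not overwrite, so the first recorded name is kept.
    seen.emplace(use.argIndex, use.writtenName);
  }
  std::vector<std::string> names;
  for (unsigned i = 0; i < numArgs; ++i) {
    auto it = seen.find(i);
    names.push_back(it != seen.end() ? "%" + it->second
                                     : "%arg" + std::to_string(i));
  }
  return names;
}
```

With two arguments and a single use of argument 0 under the name `my_input`, this yields `%my_input` and `%arg1`. Recording names at the point where arguments are declared, rather than where they are used, would be one way to close the gap, though that is not what this PR does for `func.func` arguments.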
I do not have test cases for how the system handles code transformations, and am open to suggestions for additional tests to include. Also note that this is my most substantial contribution to the codebase to date, so I may require some shepherding with regard to coding style or use of core LLVM library constructs.
---
Full diff: https://github.com/llvm/llvm-project/pull/79704.diff
6 Files Affected:
- (modified) mlir/include/mlir/IR/AsmState.h (+9-3)
- (modified) mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h (+10)
- (modified) mlir/lib/AsmParser/Parser.cpp (+70)
- (modified) mlir/lib/IR/AsmPrinter.cpp (+72-6)
- (modified) mlir/lib/Tools/mlir-opt/MlirOptMain.cpp (+8-2)
- (added) mlir/test/IR/print-retain-identifiers.mlir (+92)
``````````diff
diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index 42cbedcf9f8837..9c4eadb04cdf2f 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -144,8 +144,7 @@ class AsmResourceBlob {
/// Return the underlying data as an array of the given type. This is an
/// inherrently unsafe operation, and should only be used when the data is
/// known to be of the correct type.
- template <typename T>
- ArrayRef<T> getDataAs() const {
+ template <typename T> ArrayRef<T> getDataAs() const {
return llvm::ArrayRef<T>((const T *)data.data(), data.size() / sizeof(T));
}
@@ -464,8 +463,10 @@ class ParserConfig {
/// `fallbackResourceMap` is an optional fallback handler that can be used to
/// parse external resources not explicitly handled by another parser.
ParserConfig(MLIRContext *context, bool verifyAfterParse = true,
- FallbackAsmResourceMap *fallbackResourceMap = nullptr)
+ FallbackAsmResourceMap *fallbackResourceMap = nullptr,
+ bool retainIdentifierNames = false)
: context(context), verifyAfterParse(verifyAfterParse),
+ retainIdentifierNames(retainIdentifierNames),
fallbackResourceMap(fallbackResourceMap) {
assert(context && "expected valid MLIR context");
}
@@ -476,6 +477,10 @@ class ParserConfig {
/// Returns if the parser should verify the IR after parsing.
bool shouldVerifyAfterParse() const { return verifyAfterParse; }
+ /// Returns if the parser should retain identifier names collected during
+ /// parsing.
+ bool shouldRetainIdentifierNames() const { return retainIdentifierNames; }
+
/// Returns the parsing configurations associated to the bytecode read.
BytecodeReaderConfig &getBytecodeReaderConfig() const {
return const_cast<BytecodeReaderConfig &>(bytecodeReaderConfig);
@@ -513,6 +518,7 @@ class ParserConfig {
private:
MLIRContext *context;
bool verifyAfterParse;
+ bool retainIdentifierNames;
DenseMap<StringRef, std::unique_ptr<AsmResourceParser>> resourceParsers;
FallbackAsmResourceMap *fallbackResourceMap;
BytecodeReaderConfig bytecodeReaderConfig;
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index 6e90fad1618d21..a85dca186a4f3c 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -176,6 +176,13 @@ class MlirOptMainConfig {
/// Reproducer file generation (no crash required).
StringRef getReproducerFilename() const { return generateReproducerFileFlag; }
+ /// Set whether the original identifier names should be retained.
+ MlirOptMainConfig &retainIdentifierNames(bool retain) {
+ retainIdentifierNamesFlag = retain;
+ return *this;
+ }
+ bool shouldRetainIdentifierNames() const { return retainIdentifierNamesFlag; }
+
protected:
/// Allow operation with no registered dialects.
/// This option is for convenience during testing only and discouraged in
@@ -226,6 +233,9 @@ class MlirOptMainConfig {
/// the corresponding line. This is meant for implementing diagnostic tests.
bool verifyDiagnosticsFlag = false;
+ /// Retain identifier names in the output (e.g., `%my_var` instead of `%0`).
+ bool retainIdentifierNamesFlag = false;
+
/// Run the verifier after each transformation pass.
bool verifyPassesFlag = true;
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 00f2b0c0c2f12f..247e99e61c2c01 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -611,6 +611,10 @@ class OperationParser : public Parser {
/// an object of type 'OperationName'. Otherwise, failure is returned.
FailureOr<OperationName> parseCustomOperationName();
+ /// Store the SSA names for the current operation as attrs for debug purposes.
+ void storeSSANames(Operation *&op, ArrayRef<ResultRecord> resultIDs);
+ DenseMap<Value, StringRef> argNames;
+
//===--------------------------------------------------------------------===//
// Region Parsing
//===--------------------------------------------------------------------===//
@@ -1268,6 +1272,58 @@ OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) {
/*allowEmptyList=*/false);
}
+/// Store the SSA names for the current operation as attrs for debug purposes.
+void OperationParser::storeSSANames(Operation *&op,
+ ArrayRef<ResultRecord> resultIDs) {
+
+ // Store the name(s) of the result(s) of this operation.
+ if (op->getNumResults() > 0) {
+ llvm::SmallVector<llvm::StringRef, 1> resultNames;
+ for (const ResultRecord &resIt : resultIDs) {
+ resultNames.push_back(std::get<0>(resIt).drop_front(1));
+ // Insert empty string for sub-results/result groups
+ for (unsigned int i = 1; i < std::get<1>(resIt); ++i)
+ resultNames.push_back(llvm::StringRef());
+ }
+ op->setDiscardableAttr("mlir.resultNames",
+ builder.getStrArrayAttr(resultNames));
+ }
+
+ // Store the name information of the arguments of this operation.
+ if (op->getNumOperands() > 0) {
+ llvm::SmallVector<llvm::StringRef, 1> opArgNames;
+ for (auto &operand : op->getOpOperands()) {
+ auto it = argNames.find(operand.get());
+ if (it != argNames.end())
+ opArgNames.push_back(it->second.drop_front(1));
+ }
+ op->setDiscardableAttr("mlir.opArgNames",
+ builder.getStrArrayAttr(opArgNames));
+ }
+
+ // Store the name information of the block that contains this operation.
+ Block *blockPtr = op->getBlock();
+ for (const auto &map : blocksByName) {
+ for (const auto &entry : map) {
+ if (entry.second.block == blockPtr) {
+ op->setDiscardableAttr("mlir.blockName",
+ StringAttr::get(getContext(), entry.first));
+
+ // Store block arguments, if present
+ llvm::SmallVector<llvm::StringRef, 1> blockArgNames;
+
+ for (BlockArgument arg : blockPtr->getArguments()) {
+ auto it = argNames.find(arg);
+ if (it != argNames.end())
+ blockArgNames.push_back(it->second.drop_front(1));
+ }
+ op->setAttr("mlir.blockArgNames",
+ builder.getStrArrayAttr(blockArgNames));
+ }
+ }
+ }
+}
+
namespace {
// RAII-style guard for cleaning up the regions in the operation state before
// deleting them. Within the parser, regions may get deleted if parsing failed,
@@ -1672,6 +1728,11 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
SmallVectorImpl<Value> &result) override {
if (auto value = parser.resolveSSAUse(operand, type)) {
result.push_back(value);
+
+ // Optionally store argument name for debug purposes
+ if (parser.getState().config.shouldRetainIdentifierNames())
+ parser.argNames.insert({value, operand.name});
+
return success();
}
return failure();
@@ -2031,6 +2092,11 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
// Otherwise, create the operation and try to parse a location for it.
Operation *op = opBuilder.create(opState);
+
+ // If enabled, store the SSA name(s) for the operation
+ if (state.config.shouldRetainIdentifierNames())
+ storeSSANames(op, resultIDs);
+
if (parseTrailingLocationSpecifier(op))
return nullptr;
@@ -2355,6 +2421,10 @@ ParseResult OperationParser::parseOptionalBlockArgList(Block *owner) {
} else {
auto loc = getEncodedSourceLocation(useInfo.location);
arg = owner->addArgument(type, loc);
+
+ // Optionally store argument name for debug purposes
+ if (state.config.shouldRetainIdentifierNames())
+ argNames.insert({arg, useInfo.name});
}
// If the argument has an explicit loc(...) specifier, parse and apply
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 6b8b7473bf0f8a..84603bb6ebfba3 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -72,13 +72,12 @@ OpAsmParser::~OpAsmParser() = default;
MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
/// Parse a type list.
-/// This is out-of-line to work-around https://github.com/llvm/llvm-project/issues/62918
+/// This is out-of-line to work-around
+/// https://github.com/llvm/llvm-project/issues/62918
ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) {
- return parseCommaSeparatedList(
- [&]() { return parseType(result.emplace_back()); });
- }
-
-
+ return parseCommaSeparatedList(
+ [&]() { return parseType(result.emplace_back()); });
+}
//===----------------------------------------------------------------------===//
// DialectAsmPrinter
@@ -1299,6 +1298,11 @@ class SSANameState {
/// conflicts, it is automatically renamed.
StringRef uniqueValueName(StringRef name);
+ /// Set the original identifier names if available. Used in debugging with
+ /// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
+ void setRetainedIdentifierNames(Operation &op,
+ SmallVector<int, 2> &resultGroups);
+
/// This is the value ID for each SSA value. If this returns NameSentinel,
/// then the valueID has an entry in valueNames.
DenseMap<Value, unsigned> valueIDs;
@@ -1568,6 +1572,10 @@ void SSANameState::numberValuesInOp(Operation &op) {
}
}
+ // Set the original identifier names if available. Used in debugging with
+ // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
+ setRetainedIdentifierNames(op, resultGroups);
+
unsigned numResults = op.getNumResults();
if (numResults == 0) {
// If value users should be printed, operations with no result need an id.
@@ -1590,6 +1598,64 @@ void SSANameState::numberValuesInOp(Operation &op) {
}
}
+void SSANameState::setRetainedIdentifierNames(
+ Operation &op, SmallVector<int, 2> &resultGroups) {
+ // Get the original names for the results if available
+ if (ArrayAttr resultNamesAttr =
+ op.getAttrOfType<ArrayAttr>("mlir.resultNames")) {
+ auto resultNames = resultNamesAttr.getValue();
+ auto results = op.getResults();
+ // Conservative in the case that the #results has changed
+ for (size_t i = 0; i < results.size() && i < resultNames.size(); ++i) {
+ auto resultName = resultNames[i].cast<StringAttr>().strref();
+ if (!resultName.empty()) {
+ if (!usedNames.count(resultName))
+ setValueName(results[i], resultName);
+ // If a result has a name, it is the start of a result group.
+ if (i > 0)
+ resultGroups.push_back(i);
+ }
+ }
+ op.removeDiscardableAttr("mlir.resultNames");
+ }
+
+ // Get the original name for the op args if available
+ if (ArrayAttr opArgNamesAttr =
+ op.getAttrOfType<ArrayAttr>("mlir.opArgNames")) {
+ auto opArgNames = opArgNamesAttr.getValue();
+ auto opArgs = op.getOperands();
+ // Conservative in the case that the #operands has changed
+ for (size_t i = 0; i < opArgs.size() && i < opArgNames.size(); ++i) {
+ auto opArgName = opArgNames[i].cast<StringAttr>().strref();
+ if (!usedNames.count(opArgName))
+ setValueName(opArgs[i], opArgName);
+ }
+ op.removeDiscardableAttr("mlir.opArgNames");
+ }
+
+ // Get the original name for the block if available
+ if (StringAttr blockNameAttr =
+ op.getAttrOfType<StringAttr>("mlir.blockName")) {
+ blockNames[op.getBlock()] = {-1, blockNameAttr.strref()};
+ op.removeDiscardableAttr("mlir.blockName");
+ }
+
+ // Get the original name for the block args if available
+ if (ArrayAttr blockArgNamesAttr =
+ op.getAttrOfType<ArrayAttr>("mlir.blockArgNames")) {
+ auto blockArgNames = blockArgNamesAttr.getValue();
+ auto blockArgs = op.getBlock()->getArguments();
+ // Conservative in the case that the #args has changed
+ for (size_t i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) {
+ auto blockArgName = blockArgNames[i].cast<StringAttr>().strref();
+ if (!usedNames.count(blockArgName))
+ setValueName(blockArgs[i], blockArgName);
+ }
+ op.removeDiscardableAttr("mlir.blockArgNames");
+ }
+ return;
+}
+
void SSANameState::getResultIDAndNumber(
OpResult result, Value &lookupValue,
std::optional<int> &lookupResultNo) const {
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 5395aa2b502d78..c4482435861590 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -149,6 +149,11 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
cl::desc("Round-trip the IR after parsing and ensure it succeeds"),
cl::location(verifyRoundtripFlag), cl::init(false));
+ static cl::opt<bool, /*ExternalStorage=*/true> retainIdentifierNames(
+ "retain-identifier-names",
+ cl::desc("Retain the original names of identifiers when printing"),
+ cl::location(retainIdentifierNamesFlag), cl::init(false));
+
static cl::list<std::string> passPlugins(
"load-pass-plugin", cl::desc("Load passes from plugin library"));
@@ -359,8 +364,9 @@ performActions(raw_ostream &os,
// untouched.
PassReproducerOptions reproOptions;
FallbackAsmResourceMap fallbackResourceMap;
- ParserConfig parseConfig(context, /*verifyAfterParse=*/true,
- &fallbackResourceMap);
+ ParserConfig parseConfig(
+ context, /*verifyAfterParse=*/true, &fallbackResourceMap,
+ /*retainIdentifierNames=*/config.shouldRetainIdentifierNames());
if (config.shouldRunReproducer())
reproOptions.attachResourceParser(parseConfig);
diff --git a/mlir/test/IR/print-retain-identifiers.mlir b/mlir/test/IR/print-retain-identifiers.mlir
new file mode 100644
index 00000000000000..b3e4f075b3936a
--- /dev/null
+++ b/mlir/test/IR/print-retain-identifiers.mlir
@@ -0,0 +1,92 @@
+// RUN: mlir-opt -retain-identifier-names %s | FileCheck %s
+
+
+//===----------------------------------------------------------------------===//
+// Test SSA results (with single return values)
+//===----------------------------------------------------------------------===//
+
+// CHECK: func.func @add_one(%my_input: f64, %arg1: f64) -> f64 {
+func.func @add_one(%my_input: f64, %arg1: f64) -> f64 {
+ // CHECK: %my_constant = arith.constant 1.000000e+00 : f64
+ %my_constant = arith.constant 1.000000e+00 : f64
+ // CHECK: %my_output = arith.addf %my_input, %my_constant : f64
+ %my_output = arith.addf %my_input, %my_constant : f64
+ // CHECK: return %my_output : f64
+ return %my_output : f64
+}
+
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test basic blocks and their arguments
+//===----------------------------------------------------------------------===//
+
+func.func @simple(i64, i1) -> i64 {
+^bb_alpha(%a: i64, %cond: i1):
+ // CHECK: cf.cond_br %cond, ^bb_beta, ^bb_gamma
+ cf.cond_br %cond, ^bb_beta, ^bb_gamma
+
+// CHECK: ^bb_beta: // pred: ^bb_alpha
+^bb_beta:
+ // CHECK: cf.br ^bb_delta(%a : i64)
+ cf.br ^bb_delta(%a: i64)
+
+// CHECK: ^bb_gamma: // pred: ^bb_alpha
+^bb_gamma:
+ // CHECK: %b = arith.addi %a, %a : i64
+ %b = arith.addi %a, %a : i64
+ // CHECK: cf.br ^bb_delta(%b : i64)
+ cf.br ^bb_delta(%b: i64)
+
+// CHECK: ^bb_delta(%c: i64): // 2 preds: ^bb_gamma, ^bb_beta
+^bb_delta(%c: i64):
+ // CHECK: cf.br ^bb_eps(%c, %a : i64, i64)
+ cf.br ^bb_eps(%c, %a : i64, i64)
+
+// CHECK: ^bb_eps(%d: i64, %e: i64): // pred: ^bb_delta
+^bb_eps(%d : i64, %e : i64):
+ // CHECK: %f = arith.addi %d, %e : i64
+ %f = arith.addi %d, %e : i64
+ return %f : i64
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test multiple return values
+//===----------------------------------------------------------------------===//
+
+func.func @select_min_max(%a: f64, %b: f64) -> (f64, f64) {
+ %gt = arith.cmpf "ogt", %a, %b : f64
+ // CHECK: %min, %max = scf.if %gt -> (f64, f64) {
+ %min, %max = scf.if %gt -> (f64, f64) {
+ scf.yield %b, %a : f64, f64
+ } else {
+ scf.yield %a, %b : f64, f64
+ }
+ // CHECK: return %min, %max : f64, f64
+ return %min, %max : f64, f64
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test multiple return values, with a grouped value tuple
+//===----------------------------------------------------------------------===//
+
+func.func @select_max(%a: f64, %b: f64, %c: f64, %d: f64) -> (f64, f64, f64, f64) {
+ // Find the max between %a and %b,
+ // with %c and %d being other values that are returned.
+ %gt = arith.cmpf "ogt", %a, %b : f64
+ // CHECK: %max, %others:2, %alt = scf.if %gt -> (f64, f64, f64, f64) {
+ %max, %others:2, %alt = scf.if %gt -> (f64, f64, f64, f64) {
+ scf.yield %b, %a, %c, %d : f64, f64, f64, f64
+ } else {
+ scf.yield %a, %b, %d, %c : f64, f64, f64, f64
+ }
+ // CHECK: return %max, %others#0, %others#1, %alt : f64, f64, f64, f64
+ return %max, %others#0, %others#1, %alt : f64, f64, f64, f64
+}
+
+// -----
``````````
</details>
https://github.com/llvm/llvm-project/pull/79704