[Mlir-commits] [mlir] [mlir] Retain original identifier names for debugging (PR #79704)

Perry Gibson llvmlistbot at llvm.org
Sat Jan 27 10:50:08 PST 2024


https://github.com/Wheest created https://github.com/llvm/llvm-project/pull/79704

This PR implements retention of MLIR identifier names (e.g., `%my_val`, `^bb_foo`) for debugging and development purposes.

A wider discussion of this feature is in [this discourse thread](https://discourse.llvm.org/t/retain-original-identifier-names-for-debugging/76417/1).

As a motivating example, IR generation currently drops all meaningful identifier names. These names could be useful to developers trying to understand their passes, and to other tooling (e.g., [MLIR code formatters](https://discourse.llvm.org/t/clang-format-or-some-other-auto-format-for-mlir-files/75258/15)).

```mlir
func.func @add_one(%my_input: f64) -> f64 {
    %my_constant = arith.constant 1.00000e+00 : f64
    %my_output = arith.addf %my_input, %my_constant : f64
    return %my_output : f64
}
```

:arrow_down: becomes

```mlir
func.func @add_one(%arg0: f64) -> f64 {
  %cst = arith.constant 1.000000e+00 : f64
  %0 = arith.addf %arg0, %cst : f64
  return %0 : f64
}
```

The solution this PR implements is to store this metadata inside the attribute dictionary of operations, under a special namespace (e.g., `mlir.resultNames`). Because the feature is optional (enabled with the `mlir-opt` flag `--retain-identifier-names`), it incurs no additional overhead except in text parsing and printing (`AsmParser/Parser.cpp` and `IR/AsmPrinter.cpp`).
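To make the store-then-restore idea concrete, here is a minimal sketch that models it with standard containers rather than the MLIR API. `ToyOp`, `storeSsaName`, and `resultName` are hypothetical names invented for illustration; the attribute key follows the `mlir.ssaName` string used in the first patch, and the real implementation in `Parser.cpp` and `AsmPrinter.cpp` differs in detail.

```cpp
#include <cassert>
#include <map>
#include <string>

// Toy stand-in for an operation with a single result and a discardable
// attribute dictionary. This is NOT the MLIR API, only a model of the idea.
struct ToyOp {
  std::string name;                         // e.g. "arith.constant"
  std::map<std::string, std::string> attrs; // discardable attributes
};

// Attribute key assumed for this sketch (taken from the first patch).
const std::string kSsaNameAttr = "mlir.ssaName";

// Parser side: when retention is enabled, stash the identifier (without the
// leading '%') on the operation. With retention disabled nothing is stored.
void storeSsaName(ToyOp &op, const std::string &identifier, bool retain) {
  assert(!identifier.empty() && identifier[0] == '%');
  if (retain)
    op.attrs[kSsaNameAttr] = identifier.substr(1);
}

// Printer side: use the stored name if present and strip the attribute so it
// does not show up in the printed IR; otherwise fall back to a numeric ID.
std::string resultName(ToyOp &op, unsigned &nextValueId) {
  auto it = op.attrs.find(kSsaNameAttr);
  if (it != op.attrs.end()) {
    std::string retained = "%" + it->second;
    op.attrs.erase(it);
    return retained;
  }
  return "%" + std::to_string(nextValueId++);
}
```

In this model, an operation parsed with retention enabled prints with its original name and no leftover attribute, while one parsed with retention disabled gets the next numeric ID, which is the intended zero-cost behaviour when the flag is off.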

Alternative solutions, such as adding a string field to the `Value` class, reparsing location information, and adapting `OpAsmInterface`, are discussed in the [relevant discourse thread](https://discourse.llvm.org/t/retain-original-identifier-names-for-debugging/76417/1).

I've implemented some initial test cases in `mlir/test/IR/print-retain-identifiers.mlir`.

These cover:

- retaining the result names of operations
- retaining the names of basic blocks (and their arguments)
- handling result groups

One case that I know does not work is an unused `func.func` argument. This is because we recover the SSA names of these arguments from the operations that use them. You can see in the first test case that the second argument is not used, so it will always fall back to its default name (`%arg1`).
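The limitation above can be sketched with a toy name table, again using standard containers rather than MLIR types. `NameTable`, `recordUse`, and `argName` are hypothetical; the only assumption carried over from the PR is that a name is recorded when a value is resolved as an operand, so a value that is never used has no entry.

```cpp
#include <map>
#include <string>

// Toy model: values are integer IDs, and a name is recorded only when a
// value is resolved as an operand of some operation.
using ValueId = int;

struct NameTable {
  std::map<ValueId, std::string> names;

  // Called for each operand use seen while parsing; the first name wins.
  void recordUse(ValueId value, const std::string &name) {
    names.emplace(value, name);
  }

  // Name for the function argument at position `index`: the retained name
  // if some use recorded one, otherwise the default "%arg<index>".
  std::string argName(ValueId value, unsigned index) const {
    auto it = names.find(value);
    if (it != names.end())
      return it->second;
    return "%arg" + std::to_string(index);
  }
};
```

An argument that appears as an operand keeps its name, while an argument with no uses never reaches `recordUse` and therefore prints with the default positional name.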

I do not have test cases for how the system handles code transformations, and am open to suggestions for additional tests to include. Also note that this is my most substantial contribution to the codebase to date, so I may require some shepherding with regard to coding style or use of core LLVM library constructs.



From 7e4a6a7ce6018cf2ee8da7138bd84a3c66ee82a8 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Fri, 26 Jan 2024 14:34:14 +0000
Subject: [PATCH 1/8] Added single result SSA processing

---
 mlir/include/mlir/IR/AsmState.h                | 12 +++++++++---
 mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h | 10 ++++++++++
 mlir/lib/AsmParser/Parser.cpp                  | 17 +++++++++++++++++
 mlir/lib/IR/AsmPrinter.cpp                     | 17 +++++++++++------
 mlir/lib/Tools/mlir-opt/MlirOptMain.cpp        | 10 ++++++++--
 5 files changed, 55 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index 42cbedcf9f88374..9c4eadb04cdf2fe 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -144,8 +144,7 @@ class AsmResourceBlob {
   /// Return the underlying data as an array of the given type. This is an
   /// inherrently unsafe operation, and should only be used when the data is
   /// known to be of the correct type.
-  template <typename T>
-  ArrayRef<T> getDataAs() const {
+  template <typename T> ArrayRef<T> getDataAs() const {
     return llvm::ArrayRef<T>((const T *)data.data(), data.size() / sizeof(T));
   }
 
@@ -464,8 +463,10 @@ class ParserConfig {
   /// `fallbackResourceMap` is an optional fallback handler that can be used to
   /// parse external resources not explicitly handled by another parser.
   ParserConfig(MLIRContext *context, bool verifyAfterParse = true,
-               FallbackAsmResourceMap *fallbackResourceMap = nullptr)
+               FallbackAsmResourceMap *fallbackResourceMap = nullptr,
+               bool retainIdentifierNames = false)
       : context(context), verifyAfterParse(verifyAfterParse),
+        retainIdentifierNames(retainIdentifierNames),
         fallbackResourceMap(fallbackResourceMap) {
     assert(context && "expected valid MLIR context");
   }
@@ -476,6 +477,10 @@ class ParserConfig {
   /// Returns if the parser should verify the IR after parsing.
   bool shouldVerifyAfterParse() const { return verifyAfterParse; }
 
+  /// Returns if the parser should retain identifier names collected using
+  /// parsing.
+  bool shouldRetainIdentifierNames() const { return retainIdentifierNames; }
+
   /// Returns the parsing configurations associated to the bytecode read.
   BytecodeReaderConfig &getBytecodeReaderConfig() const {
     return const_cast<BytecodeReaderConfig &>(bytecodeReaderConfig);
@@ -513,6 +518,7 @@ class ParserConfig {
 private:
   MLIRContext *context;
   bool verifyAfterParse;
+  bool retainIdentifierNames;
   DenseMap<StringRef, std::unique_ptr<AsmResourceParser>> resourceParsers;
   FallbackAsmResourceMap *fallbackResourceMap;
   BytecodeReaderConfig bytecodeReaderConfig;
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index 6e90fad1618d21b..a85dca186a4f3c9 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -176,6 +176,13 @@ class MlirOptMainConfig {
   /// Reproducer file generation (no crash required).
   StringRef getReproducerFilename() const { return generateReproducerFileFlag; }
 
+  /// Print the pass-pipeline as text before executing.
+  MlirOptMainConfig &retainIdentifierNames(bool retain) {
+    retainIdentifierNamesFlag = retain;
+    return *this;
+  }
+  bool shouldRetainIdentifierNames() const { return retainIdentifierNamesFlag; }
+
 protected:
   /// Allow operation with no registered dialects.
   /// This option is for convenience during testing only and discouraged in
@@ -226,6 +233,9 @@ class MlirOptMainConfig {
   /// the corresponding line. This is meant for implementing diagnostic tests.
   bool verifyDiagnosticsFlag = false;
 
+  /// Retain identifier names in the output (e.g., `%my_var` instead of `%0`).
+  bool retainIdentifierNamesFlag = false;
+
   /// Run the verifier after each transformation pass.
   bool verifyPassesFlag = true;
 
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 00f2b0c0c2f12f9..8d3861a2f018d00 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -1223,6 +1223,23 @@ ParseResult OperationParser::parseOperation() {
       }
     }
 
+    // If enabled, store the SSA name(s) for the operation
+    if (state.config.shouldRetainIdentifierNames()) {
+      if (opResI == 1) {
+        for (ResultRecord &resIt : resultIDs) {
+          for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
+            op->setDiscardableAttr(
+                "mlir.ssaName",
+                StringAttr::get(getContext(),
+                                std::get<0>(resIt).drop_front(1)));
+          }
+        }
+      } else if (opResI > 1) {
+        emitError(
+            "have not yet implemented support for multiple return values");
+      }
+    }
+
     // Add this operation to the assembly state if it was provided to populate.
   } else if (state.asmState) {
     state.asmState->finalizeOperationDefinition(
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 6b8b7473bf0f8a1..164b0f97fc1cd9e 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -72,13 +72,12 @@ OpAsmParser::~OpAsmParser() = default;
 MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
 
 /// Parse a type list.
-/// This is out-of-line to work-around https://github.com/llvm/llvm-project/issues/62918
+/// This is out-of-line to work-around
+/// https://github.com/llvm/llvm-project/issues/62918
 ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) {
-    return parseCommaSeparatedList(
-        [&]() { return parseType(result.emplace_back()); });
-  }
-
-
+  return parseCommaSeparatedList(
+      [&]() { return parseType(result.emplace_back()); });
+}
 
 //===----------------------------------------------------------------------===//
 // DialectAsmPrinter
@@ -1579,6 +1578,12 @@ void SSANameState::numberValuesInOp(Operation &op) {
   }
   Value resultBegin = op.getResult(0);
 
+  // Get the original SSA for the result if available
+  if (StringAttr ssaNameAttr = op.getAttrOfType<StringAttr>("mlir.ssaName")) {
+    setValueName(resultBegin, ssaNameAttr.strref());
+    op.removeDiscardableAttr("mlir.ssaName");
+  }
+
   // If the first result wasn't numbered, give it a default number.
   if (valueIDs.try_emplace(resultBegin, nextValueID).second)
     ++nextValueID;
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 5395aa2b502d788..c4482435861590d 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -149,6 +149,11 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
         cl::desc("Round-trip the IR after parsing and ensure it succeeds"),
         cl::location(verifyRoundtripFlag), cl::init(false));
 
+    static cl::opt<bool, /*ExternalStorage=*/true> retainIdentifierNames(
+        "retain-identifier-names",
+        cl::desc("Retain the original names of identifiers when printing"),
+        cl::location(retainIdentifierNamesFlag), cl::init(false));
+
     static cl::list<std::string> passPlugins(
         "load-pass-plugin", cl::desc("Load passes from plugin library"));
 
@@ -359,8 +364,9 @@ performActions(raw_ostream &os,
   // untouched.
   PassReproducerOptions reproOptions;
   FallbackAsmResourceMap fallbackResourceMap;
-  ParserConfig parseConfig(context, /*verifyAfterParse=*/true,
-                           &fallbackResourceMap);
+  ParserConfig parseConfig(
+      context, /*verifyAfterParse=*/true, &fallbackResourceMap,
+      /*retainIdentifierName=*/config.shouldRetainIdentifierNames());
   if (config.shouldRunReproducer())
     reproOptions.attachResourceParser(parseConfig);
 

From 5a629d2cd3e6f736af20c808d53206f086ea5b05 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Fri, 26 Jan 2024 14:52:10 +0000
Subject: [PATCH 2/8] Moved SSA name stored to function

---
 mlir/lib/AsmParser/Parser.cpp | 40 ++++++++++++++++++++++-------------
 1 file changed, 25 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 8d3861a2f018d00..282077865c665a9 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -611,6 +611,10 @@ class OperationParser : public Parser {
   /// an object of type 'OperationName'. Otherwise, failure is returned.
   FailureOr<OperationName> parseCustomOperationName();
 
+  /// Store the SSA names for the current operation as attrs for debug purposes.
+  ParseResult storeSSANames(Operation *&op,
+                            SmallVector<ResultRecord, 1> resultIDs);
+
   //===--------------------------------------------------------------------===//
   // Region Parsing
   //===--------------------------------------------------------------------===//
@@ -1224,21 +1228,8 @@ ParseResult OperationParser::parseOperation() {
     }
 
     // If enabled, store the SSA name(s) for the operation
-    if (state.config.shouldRetainIdentifierNames()) {
-      if (opResI == 1) {
-        for (ResultRecord &resIt : resultIDs) {
-          for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
-            op->setDiscardableAttr(
-                "mlir.ssaName",
-                StringAttr::get(getContext(),
-                                std::get<0>(resIt).drop_front(1)));
-          }
-        }
-      } else if (opResI > 1) {
-        emitError(
-            "have not yet implemented support for multiple return values");
-      }
-    }
+    if (state.config.shouldRetainIdentifierNames())
+      storeSSANames(op, resultIDs);
 
     // Add this operation to the assembly state if it was provided to populate.
   } else if (state.asmState) {
@@ -1285,6 +1276,25 @@ OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) {
                                       /*allowEmptyList=*/false);
 }
 
+/// Store the SSA names for the current operation as attrs for debug purposes.
+ParseResult
+OperationParser::storeSSANames(Operation *&op,
+                               SmallVector<ResultRecord, 1> resultIDs) {
+  if (op->getNumResults() == 0)
+    emitError("Operation has no results\n");
+  else if (op->getNumResults() > 1)
+    emitError("have not yet implemented support for multiple return values\n");
+
+  for (ResultRecord &resIt : resultIDs) {
+    for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
+      op->setDiscardableAttr(
+          "mlir.ssaName",
+          StringAttr::get(getContext(), std::get<0>(resIt).drop_front(1)));
+    }
+  }
+  return success();
+}
+
 namespace {
 // RAII-style guard for cleaning up the regions in the operation state before
 // deleting them.  Within the parser, regions may get deleted if parsing failed,

From f2f8047bd80df0f6da6a8db4bfe976491157993e Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Fri, 26 Jan 2024 15:39:18 +0000
Subject: [PATCH 3/8] Added initial block name handling

---
 mlir/lib/AsmParser/Parser.cpp | 13 +++++++++++++
 mlir/lib/IR/AsmPrinter.cpp    | 33 ++++++++++++++++++++++++++++-----
 2 files changed, 41 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 282077865c665a9..bb3fd5d8729406a 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -1292,6 +1292,19 @@ OperationParser::storeSSANames(Operation *&op,
           StringAttr::get(getContext(), std::get<0>(resIt).drop_front(1)));
     }
   }
+
+  // Find the name of the block that contains this operation.
+  Block *blockPtr = op->getBlock();
+  for (const auto &map : blocksByName) {
+    for (const auto &entry : map) {
+      if (entry.second.block == blockPtr) {
+        op->setDiscardableAttr("mlir.blockName",
+                               StringAttr::get(getContext(), entry.first));
+        llvm::outs() << "Block name: " << entry.first << "\n";
+      }
+    }
+  }
+
   return success();
 }
 
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 164b0f97fc1cd9e..70d38c9bd8b4f89 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1298,6 +1298,10 @@ class SSANameState {
   /// conflicts, it is automatically renamed.
   StringRef uniqueValueName(StringRef name);
 
+  /// Set the original identifier names if available. Used in debugging with
+  /// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
+  void setRetainedIdentifierNames(Operation &op);
+
   /// This is the value ID for each SSA value. If this returns NameSentinel,
   /// then the valueID has an entry in valueNames.
   DenseMap<Value, unsigned> valueIDs;
@@ -1578,11 +1582,9 @@ void SSANameState::numberValuesInOp(Operation &op) {
   }
   Value resultBegin = op.getResult(0);
 
-  // Get the original SSA for the result if available
-  if (StringAttr ssaNameAttr = op.getAttrOfType<StringAttr>("mlir.ssaName")) {
-    setValueName(resultBegin, ssaNameAttr.strref());
-    op.removeDiscardableAttr("mlir.ssaName");
-  }
+  // Set the original identifier names if available. Used in debugging with
+  // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
+  setRetainedIdentifierNames(op);
 
   // If the first result wasn't numbered, give it a default number.
   if (valueIDs.try_emplace(resultBegin, nextValueID).second)
@@ -1595,6 +1597,27 @@ void SSANameState::numberValuesInOp(Operation &op) {
   }
 }
 
+void SSANameState::setRetainedIdentifierNames(Operation &op) {
+  // Get the original SSA for the result(s) if available
+  Value resultBegin = op.getResult(0);
+  if (StringAttr ssaNameAttr = op.getAttrOfType<StringAttr>("mlir.ssaName")) {
+    setValueName(resultBegin, ssaNameAttr.strref());
+    op.removeDiscardableAttr("mlir.ssaName");
+  }
+  unsigned numResults = op.getNumResults();
+  if (numResults > 1)
+    llvm::outs()
+        << "have not yet implemented support for multiple return values\n";
+
+  // Get the original SSA name for the block if available
+  if (StringAttr blockNameAttr =
+          op.getAttrOfType<StringAttr>("mlir.blockName")) {
+    blockNames[op.getBlock()] = {-1, blockNameAttr.strref()};
+    op.removeDiscardableAttr("mlir.blockName");
+  }
+  return;
+}
+
 void SSANameState::getResultIDAndNumber(
     OpResult result, Value &lookupValue,
     std::optional<int> &lookupResultNo) const {

From 1aa3814f6eb47ec8d51d2eb68e40e469194e2744 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Fri, 26 Jan 2024 16:18:27 +0000
Subject: [PATCH 4/8] Added block arg name handling

---
 mlir/lib/AsmParser/Parser.cpp | 25 ++++++++++++++++++++-----
 mlir/lib/IR/AsmPrinter.cpp    | 14 +++++++++++++-
 2 files changed, 33 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index bb3fd5d8729406a..95b1b2fe1f8f07e 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -614,6 +614,7 @@ class OperationParser : public Parser {
   /// Store the SSA names for the current operation as attrs for debug purposes.
   ParseResult storeSSANames(Operation *&op,
                             SmallVector<ResultRecord, 1> resultIDs);
+  DenseMap<BlockArgument, StringRef> blockArgNames;
 
   //===--------------------------------------------------------------------===//
   // Region Parsing
@@ -1203,6 +1204,11 @@ ParseResult OperationParser::parseOperation() {
              << op->getNumResults() << " results but was provided "
              << numExpectedResults << " to bind";
 
+    // If enabled, store the SSA name(s) for the operation
+    llvm::outs() << "parsing operation: " << op->getName() << "\n";
+    if (state.config.shouldRetainIdentifierNames())
+      storeSSANames(op, resultIDs);
+
     // Add this operation to the assembly state if it was provided to populate.
     if (state.asmState) {
       unsigned resultIt = 0;
@@ -1227,10 +1233,6 @@ ParseResult OperationParser::parseOperation() {
       }
     }
 
-    // If enabled, store the SSA name(s) for the operation
-    if (state.config.shouldRetainIdentifierNames())
-      storeSSANames(op, resultIDs);
-
     // Add this operation to the assembly state if it was provided to populate.
   } else if (state.asmState) {
     state.asmState->finalizeOperationDefinition(
@@ -1300,7 +1302,16 @@ OperationParser::storeSSANames(Operation *&op,
       if (entry.second.block == blockPtr) {
         op->setDiscardableAttr("mlir.blockName",
                                StringAttr::get(getContext(), entry.first));
-        llvm::outs() << "Block name: " << entry.first << "\n";
+
+        // Store block arguments, if present
+        llvm::SmallVector<llvm::StringRef, 1> argNames;
+
+        for (BlockArgument arg : blockPtr->getArguments()) {
+          auto it = blockArgNames.find(arg);
+          if (it != blockArgNames.end())
+            argNames.push_back(it->second.drop_front(1));
+        }
+        op->setAttr("mlir.blockArgNames", builder.getStrArrayAttr(argNames));
       }
     }
   }
@@ -2395,6 +2406,10 @@ ParseResult OperationParser::parseOptionalBlockArgList(Block *owner) {
           } else {
             auto loc = getEncodedSourceLocation(useInfo.location);
             arg = owner->addArgument(type, loc);
+
+            // Optionally store argument name for debug purposes
+            if (state.config.shouldRetainIdentifierNames())
+              blockArgNames.insert({arg, useInfo.name});
           }
 
           // If the argument has an explicit loc(...) specifier, parse and apply
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 70d38c9bd8b4f89..34e0bdfccda94e5 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1609,12 +1609,24 @@ void SSANameState::setRetainedIdentifierNames(Operation &op) {
     llvm::outs()
         << "have not yet implemented support for multiple return values\n";
 
-  // Get the original SSA name for the block if available
+  // Get the original name for the block if available
   if (StringAttr blockNameAttr =
           op.getAttrOfType<StringAttr>("mlir.blockName")) {
     blockNames[op.getBlock()] = {-1, blockNameAttr.strref()};
     op.removeDiscardableAttr("mlir.blockName");
   }
+
+  // Get the original name for the block args if available
+  if (ArrayAttr blockArgNamesAttr =
+          op.getAttrOfType<ArrayAttr>("mlir.blockArgNames")) {
+    auto blockArgNames = blockArgNamesAttr.getValue();
+    auto blockArgs = op.getBlock()->getArguments();
+    for (int i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) {
+      auto blockArgName = blockArgNames[i].cast<StringAttr>();
+      setValueName(blockArgs[i], cast<StringAttr>(blockArgNames[i]).strref());
+    }
+    op.removeDiscardableAttr("mlir.blockArgNames");
+  }
   return;
 }
 

From c0ec451d9ba155478ee5bded93cfaa912480e0ce Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Sat, 27 Jan 2024 11:24:53 +0000
Subject: [PATCH 5/8] Added support for operations with no arguments

---
 mlir/lib/AsmParser/Parser.cpp | 26 ++++++++++----------------
 mlir/lib/IR/AsmPrinter.cpp    | 27 +++++++++++++++------------
 2 files changed, 25 insertions(+), 28 deletions(-)

diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 95b1b2fe1f8f07e..36266be5af033c0 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -612,8 +612,7 @@ class OperationParser : public Parser {
   FailureOr<OperationName> parseCustomOperationName();
 
   /// Store the SSA names for the current operation as attrs for debug purposes.
-  ParseResult storeSSANames(Operation *&op,
-                            SmallVector<ResultRecord, 1> resultIDs);
+  void storeSSANames(Operation *&op, ArrayRef<ResultRecord> resultIDs);
   DenseMap<BlockArgument, StringRef> blockArgNames;
 
   //===--------------------------------------------------------------------===//
@@ -1204,11 +1203,6 @@ ParseResult OperationParser::parseOperation() {
              << op->getNumResults() << " results but was provided "
              << numExpectedResults << " to bind";
 
-    // If enabled, store the SSA name(s) for the operation
-    llvm::outs() << "parsing operation: " << op->getName() << "\n";
-    if (state.config.shouldRetainIdentifierNames())
-      storeSSANames(op, resultIDs);
-
     // Add this operation to the assembly state if it was provided to populate.
     if (state.asmState) {
       unsigned resultIt = 0;
@@ -1279,15 +1273,12 @@ OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) {
 }
 
 /// Store the SSA names for the current operation as attrs for debug purposes.
-ParseResult
-OperationParser::storeSSANames(Operation *&op,
-                               SmallVector<ResultRecord, 1> resultIDs) {
-  if (op->getNumResults() == 0)
-    emitError("Operation has no results\n");
-  else if (op->getNumResults() > 1)
+void OperationParser::storeSSANames(Operation *&op,
+                                    ArrayRef<ResultRecord> resultIDs) {
+  if (op->getNumResults() > 1)
     emitError("have not yet implemented support for multiple return values\n");
 
-  for (ResultRecord &resIt : resultIDs) {
+  for (const ResultRecord &resIt : resultIDs) {
     for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
       op->setDiscardableAttr(
           "mlir.ssaName",
@@ -1315,8 +1306,6 @@ OperationParser::storeSSANames(Operation *&op,
       }
     }
   }
-
-  return success();
 }
 
 namespace {
@@ -2082,6 +2071,11 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
 
   // Otherwise, create the operation and try to parse a location for it.
   Operation *op = opBuilder.create(opState);
+
+  // If enabled, store the SSA name(s) for the operation
+  if (state.config.shouldRetainIdentifierNames())
+    storeSSANames(op, resultIDs);
+
   if (parseTrailingLocationSpecifier(op))
     return nullptr;
 
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 34e0bdfccda94e5..d4be3b7802e6880 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1571,6 +1571,10 @@ void SSANameState::numberValuesInOp(Operation &op) {
     }
   }
 
+  // Set the original identifier names if available. Used in debugging with
+  // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
+  setRetainedIdentifierNames(op);
+
   unsigned numResults = op.getNumResults();
   if (numResults == 0) {
     // If value users should be printed, operations with no result need an id.
@@ -1582,10 +1586,6 @@ void SSANameState::numberValuesInOp(Operation &op) {
   }
   Value resultBegin = op.getResult(0);
 
-  // Set the original identifier names if available. Used in debugging with
-  // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
-  setRetainedIdentifierNames(op);
-
   // If the first result wasn't numbered, give it a default number.
   if (valueIDs.try_emplace(resultBegin, nextValueID).second)
     ++nextValueID;
@@ -1599,15 +1599,17 @@ void SSANameState::numberValuesInOp(Operation &op) {
 
 void SSANameState::setRetainedIdentifierNames(Operation &op) {
   // Get the original SSA for the result(s) if available
-  Value resultBegin = op.getResult(0);
-  if (StringAttr ssaNameAttr = op.getAttrOfType<StringAttr>("mlir.ssaName")) {
-    setValueName(resultBegin, ssaNameAttr.strref());
-    op.removeDiscardableAttr("mlir.ssaName");
-  }
   unsigned numResults = op.getNumResults();
   if (numResults > 1)
     llvm::outs()
         << "have not yet implemented support for multiple return values\n";
+  else if (numResults == 1) {
+    Value resultBegin = op.getResult(0);
+    if (StringAttr ssaNameAttr = op.getAttrOfType<StringAttr>("mlir.ssaName")) {
+      setValueName(resultBegin, ssaNameAttr.strref());
+      op.removeDiscardableAttr("mlir.ssaName");
+    }
+  }
 
   // Get the original name for the block if available
   if (StringAttr blockNameAttr =
@@ -1621,9 +1623,10 @@ void SSANameState::setRetainedIdentifierNames(Operation &op) {
           op.getAttrOfType<ArrayAttr>("mlir.blockArgNames")) {
     auto blockArgNames = blockArgNamesAttr.getValue();
     auto blockArgs = op.getBlock()->getArguments();
-    for (int i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) {
-      auto blockArgName = blockArgNames[i].cast<StringAttr>();
-      setValueName(blockArgs[i], cast<StringAttr>(blockArgNames[i]).strref());
+    for (size_t i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) {
+      auto blockArgName = blockArgNames[i].cast<StringAttr>().strref();
+      if (!usedNames.count(blockArgName))
+        setValueName(blockArgs[i], blockArgName);
     }
     op.removeDiscardableAttr("mlir.blockArgNames");
   }

From a9e3891f4ed0511de6797cfe986fe8d3dcbf8f08 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Sat, 27 Jan 2024 12:21:39 +0000
Subject: [PATCH 6/8] Added initial unit tests

---
 mlir/test/IR/print-retain-identifiers.mlir | 54 ++++++++++++++++++++++
 1 file changed, 54 insertions(+)
 create mode 100644 mlir/test/IR/print-retain-identifiers.mlir

diff --git a/mlir/test/IR/print-retain-identifiers.mlir b/mlir/test/IR/print-retain-identifiers.mlir
new file mode 100644
index 000000000000000..663f7301b817bc0
--- /dev/null
+++ b/mlir/test/IR/print-retain-identifiers.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt -retain-identifier-names %s | FileCheck %s
+
+
+//===----------------------------------------------------------------------===//
+// Test SSA results (with single return values)
+//===----------------------------------------------------------------------===//
+
+// CHECK: func.func @add_one(%arg0: f64, %arg1: f64) -> f64 {
+func.func @add_one(%arg0: f64, %arg1: f64) -> f64 {
+  // CHECK: %my_constant = arith.constant 1.000000e+00 : f64
+  %my_constant = arith.constant 1.000000e+00 : f64
+  // CHECK: %my_output = arith.addf %arg0, %my_constant : f64
+  %my_output = arith.addf %arg0, %my_constant : f64
+  // CHECK: return %my_output : f64
+  return %my_output : f64
+}
+
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test basic blocks and their arguments
+//===----------------------------------------------------------------------===//
+
+func.func @simple(i64, i1) -> i64 {
+^bb_alpha(%a: i64, %cond: i1):
+  // CHECK: cf.cond_br %cond, ^bb_beta, ^bb_gamma
+  cf.cond_br %cond, ^bb_beta, ^bb_gamma
+
+// CHECK: ^bb_beta:  // pred: ^bb_alpha
+^bb_beta:
+  // CHECK: cf.br ^bb_delta(%a : i64)
+  cf.br ^bb_delta(%a: i64)
+
+// CHECK: ^bb_gamma:  // pred: ^bb_alpha
+^bb_gamma:
+  // CHECK: %b = arith.addi %a, %a : i64
+  %b = arith.addi %a, %a : i64
+  // CHECK: cf.br ^bb_delta(%b : i64)
+  cf.br ^bb_delta(%b: i64)
+
+// CHECK: ^bb_delta(%c: i64):  // 2 preds: ^bb_gamma, ^bb_beta
+^bb_delta(%c: i64):
+  // CHECK: cf.br ^bb_eps(%c, %a : i64, i64)
+  cf.br ^bb_eps(%c, %a : i64, i64)
+
+// CHECK: ^bb_eps(%d: i64, %e: i64):  // pred: ^bb_delta
+^bb_eps(%d : i64, %e : i64):
+  // CHECK: %f = arith.addi %d, %e : i64
+  %f = arith.addi %d, %e : i64
+  return %f : i64
+}
+
+// -----

From 467c3a38abfbbc22c25b0874a3943ee7ad8bc78b Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Sat, 27 Jan 2024 14:49:16 +0000
Subject: [PATCH 7/8] Added operation operand name preservation

---
 mlir/lib/AsmParser/Parser.cpp              | 34 +++++++++++++++++-----
 mlir/lib/IR/AsmPrinter.cpp                 | 13 +++++++++
 mlir/test/IR/print-retain-identifiers.mlir |  8 ++---
 3 files changed, 43 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 36266be5af033c0..7327776ae5a55d1 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -613,7 +613,7 @@ class OperationParser : public Parser {
 
   /// Store the SSA names for the current operation as attrs for debug purposes.
   void storeSSANames(Operation *&op, ArrayRef<ResultRecord> resultIDs);
-  DenseMap<BlockArgument, StringRef> blockArgNames;
+  DenseMap<Value, StringRef> argNames;
 
   //===--------------------------------------------------------------------===//
   // Region Parsing
@@ -1286,7 +1286,19 @@ void OperationParser::storeSSANames(Operation *&op,
     }
   }
 
-  // Find the name of the block that contains this operation.
+  // Store the name information of the arguments of this operation.
+  if (op->getNumOperands() > 0) {
+    llvm::SmallVector<llvm::StringRef, 1> opArgNames;
+    for (auto &operand : op->getOpOperands()) {
+      auto it = argNames.find(operand.get());
+      if (it != argNames.end())
+        opArgNames.push_back(it->second.drop_front(1));
+    }
+    op->setDiscardableAttr("mlir.opArgNames",
+                           builder.getStrArrayAttr(opArgNames));
+  }
+
+  // Store the name information of the block that contains this operation.
   Block *blockPtr = op->getBlock();
   for (const auto &map : blocksByName) {
     for (const auto &entry : map) {
@@ -1295,14 +1307,15 @@ void OperationParser::storeSSANames(Operation *&op,
                                StringAttr::get(getContext(), entry.first));
 
         // Store block arguments, if present
-        llvm::SmallVector<llvm::StringRef, 1> argNames;
+        llvm::SmallVector<llvm::StringRef, 1> blockArgNames;
 
         for (BlockArgument arg : blockPtr->getArguments()) {
-          auto it = blockArgNames.find(arg);
-          if (it != blockArgNames.end())
-            argNames.push_back(it->second.drop_front(1));
+          auto it = argNames.find(arg);
+          if (it != argNames.end())
+            blockArgNames.push_back(it->second.drop_front(1));
         }
-        op->setAttr("mlir.blockArgNames", builder.getStrArrayAttr(argNames));
+        op->setAttr("mlir.blockArgNames",
+                    builder.getStrArrayAttr(blockArgNames));
       }
     }
   }
@@ -1712,6 +1725,11 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
                              SmallVectorImpl<Value> &result) override {
     if (auto value = parser.resolveSSAUse(operand, type)) {
       result.push_back(value);
+
+      // Optionally store argument name for debug purposes
+      if (parser.getState().config.shouldRetainIdentifierNames())
+        parser.argNames.insert({value, operand.name});
+
       return success();
     }
     return failure();
@@ -2403,7 +2421,7 @@ ParseResult OperationParser::parseOptionalBlockArgList(Block *owner) {
 
             // Optionally store argument name for debug purposes
             if (state.config.shouldRetainIdentifierNames())
-              blockArgNames.insert({arg, useInfo.name});
+              argNames.insert({arg, useInfo.name});
           }
 
           // If the argument has an explicit loc(...) specifier, parse and apply
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index d4be3b7802e6880..5755b9021dbad7f 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1611,6 +1611,19 @@ void SSANameState::setRetainedIdentifierNames(Operation &op) {
     }
   }
 
+  // Get the original name for the op args if available
+  if (ArrayAttr opArgNamesAttr =
+          op.getAttrOfType<ArrayAttr>("mlir.opArgNames")) {
+    auto opArgNames = opArgNamesAttr.getValue();
+    auto opArgs = op.getOperands();
+    for (size_t i = 0; i < opArgs.size() && i < opArgNames.size(); ++i) {
+      auto opArgName = opArgNames[i].cast<StringAttr>().strref();
+      if (!usedNames.count(opArgName))
+        setValueName(opArgs[i], opArgName);
+    }
+    op.removeDiscardableAttr("mlir.opArgNames");
+  }
+
   // Get the original name for the block if available
   if (StringAttr blockNameAttr =
           op.getAttrOfType<StringAttr>("mlir.blockName")) {
diff --git a/mlir/test/IR/print-retain-identifiers.mlir b/mlir/test/IR/print-retain-identifiers.mlir
index 663f7301b817bc0..05a6b9b42d8e820 100644
--- a/mlir/test/IR/print-retain-identifiers.mlir
+++ b/mlir/test/IR/print-retain-identifiers.mlir
@@ -5,12 +5,12 @@
 // Test SSA results (with single return values)
 //===----------------------------------------------------------------------===//
 
-// CHECK: func.func @add_one(%arg0: f64, %arg1: f64) -> f64 {
-func.func @add_one(%arg0: f64, %arg1: f64) -> f64 {
+// CHECK: func.func @add_one(%my_input: f64, %arg1: f64) -> f64 {
+func.func @add_one(%my_input: f64, %arg1: f64) -> f64 {
   // CHECK: %my_constant = arith.constant 1.000000e+00 : f64
   %my_constant = arith.constant 1.000000e+00 : f64
-  // CHECK: %my_output = arith.addf %arg0, %my_constant : f64
-  %my_output = arith.addf %arg0, %my_constant : f64
+  // CHECK: %my_output = arith.addf %my_input, %my_constant : f64
+  %my_output = arith.addf %my_input, %my_constant : f64
   // CHECK: return %my_output : f64
   return %my_output : f64
 }
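
Note on patch 7/8: the printer pairs `mlir.opArgNames` entries with operands by position (`i < opArgs.size() && i < opArgNames.size()`). The following is a minimal standalone sketch of that name collection, not the code in the patch. It uses standard-library types in place of `DenseMap`/`Value`, and the names `ValueId` and `collectOperandNames` are hypothetical. One deliberate difference, labeled as an assumption: the patch skips operands that have no recorded name, whereas this sketch records an empty placeholder so that index `i` in the list always refers to operand `i`.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for mlir::Value: a plain integer handle.
using ValueId = int;

// Collect one name per operand, in operand order. Recorded names carry the
// leading '%' sigil, which is dropped (mirroring drop_front(1) in the patch).
// Assumption (differs from the patch): operands without a recorded name get
// an empty placeholder so that list positions stay aligned with operands.
std::vector<std::string> collectOperandNames(
    const std::unordered_map<ValueId, std::string> &argNames,
    const std::vector<ValueId> &operands) {
  std::vector<std::string> names;
  names.reserve(operands.size());
  for (ValueId operand : operands) {
    auto it = argNames.find(operand);
    if (it != argNames.end()) {
      assert(!it->second.empty() && it->second[0] == '%');
      names.push_back(it->second.substr(1));
    } else {
      names.emplace_back(); // no recorded name for this operand
    }
  }
  return names;
}
```

With a placeholder scheme like this, a consumer would need to skip empty entries before assigning a name, as patch 8/8 does for `mlir.resultNames`.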

From d062cb861c283594413c8edc7dca57f9494b1f3e Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Sat, 27 Jan 2024 18:27:53 +0000
Subject: [PATCH 8/8] Added support for result groups

---
 mlir/lib/AsmParser/Parser.cpp              | 19 ++++++-----
 mlir/lib/IR/AsmPrinter.cpp                 | 36 ++++++++++++--------
 mlir/test/IR/print-retain-identifiers.mlir | 38 ++++++++++++++++++++++
 3 files changed, 72 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 7327776ae5a55d1..247e99e61c2c01b 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -1275,15 +1275,18 @@ OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) {
 /// Store the SSA names for the current operation as attrs for debug purposes.
 void OperationParser::storeSSANames(Operation *&op,
                                     ArrayRef<ResultRecord> resultIDs) {
-  if (op->getNumResults() > 1)
-    emitError("have not yet implemented support for multiple return values\n");
-
-  for (const ResultRecord &resIt : resultIDs) {
-    for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
-      op->setDiscardableAttr(
-          "mlir.ssaName",
-          StringAttr::get(getContext(), std::get<0>(resIt).drop_front(1)));
+
+  // Store the name(s) of the result(s) of this operation.
+  if (op->getNumResults() > 0) {
+    llvm::SmallVector<llvm::StringRef, 1> resultNames;
+    for (const ResultRecord &resIt : resultIDs) {
+      resultNames.push_back(std::get<0>(resIt).drop_front(1));
+      // Insert empty string for sub-results/result groups
+      for (unsigned int i = 1; i < std::get<1>(resIt); ++i)
+        resultNames.push_back(llvm::StringRef());
     }
+    op->setDiscardableAttr("mlir.resultNames",
+                           builder.getStrArrayAttr(resultNames));
   }
 
   // Store the name information of the arguments of this operation.
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 5755b9021dbad7f..84603bb6ebfba3a 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1300,7 +1300,8 @@ class SSANameState {
 
   /// Set the original identifier names if available. Used in debugging with
   /// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
-  void setRetainedIdentifierNames(Operation &op);
+  void setRetainedIdentifierNames(Operation &op,
+                                  SmallVector<int, 2> &resultGroups);
 
   /// This is the value ID for each SSA value. If this returns NameSentinel,
   /// then the valueID has an entry in valueNames.
@@ -1573,7 +1574,7 @@ void SSANameState::numberValuesInOp(Operation &op) {
 
   // Set the original identifier names if available. Used in debugging with
   // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
-  setRetainedIdentifierNames(op);
+  setRetainedIdentifierNames(op, resultGroups);
 
   unsigned numResults = op.getNumResults();
   if (numResults == 0) {
@@ -1597,18 +1598,25 @@ void SSANameState::numberValuesInOp(Operation &op) {
   }
 }
 
-void SSANameState::setRetainedIdentifierNames(Operation &op) {
-  // Get the original SSA for the result(s) if available
-  unsigned numResults = op.getNumResults();
-  if (numResults > 1)
-    llvm::outs()
-        << "have not yet implemented support for multiple return values\n";
-  else if (numResults == 1) {
-    Value resultBegin = op.getResult(0);
-    if (StringAttr ssaNameAttr = op.getAttrOfType<StringAttr>("mlir.ssaName")) {
-      setValueName(resultBegin, ssaNameAttr.strref());
-      op.removeDiscardableAttr("mlir.ssaName");
+void SSANameState::setRetainedIdentifierNames(
+    Operation &op, SmallVector<int, 2> &resultGroups) {
+  // Get the original names for the results if available
+  if (ArrayAttr resultNamesAttr =
+          op.getAttrOfType<ArrayAttr>("mlir.resultNames")) {
+    auto resultNames = resultNamesAttr.getValue();
+    auto results = op.getResults();
+    // Conservative in the case that the #results has changed
+    for (size_t i = 0; i < results.size() && i < resultNames.size(); ++i) {
+      auto resultName = resultNames[i].cast<StringAttr>().strref();
+      if (!resultName.empty()) {
+        if (!usedNames.count(resultName))
+          setValueName(results[i], resultName);
+        // If a result has a name, it is the start of a result group.
+        if (i > 0)
+          resultGroups.push_back(i);
+      }
     }
+    op.removeDiscardableAttr("mlir.resultNames");
   }
 
   // Get the original name for the op args if available
@@ -1616,6 +1624,7 @@ void SSANameState::setRetainedIdentifierNames(Operation &op) {
           op.getAttrOfType<ArrayAttr>("mlir.opArgNames")) {
     auto opArgNames = opArgNamesAttr.getValue();
     auto opArgs = op.getOperands();
+    // Conservative in the case that the #operands has changed
     for (size_t i = 0; i < opArgs.size() && i < opArgNames.size(); ++i) {
       auto opArgName = opArgNames[i].cast<StringAttr>().strref();
       if (!usedNames.count(opArgName))
@@ -1636,6 +1645,7 @@ void SSANameState::setRetainedIdentifierNames(Operation &op) {
           op.getAttrOfType<ArrayAttr>("mlir.blockArgNames")) {
     auto blockArgNames = blockArgNamesAttr.getValue();
     auto blockArgs = op.getBlock()->getArguments();
+    // Conservative in the case that the #args has changed
     for (size_t i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) {
       auto blockArgName = blockArgNames[i].cast<StringAttr>().strref();
       if (!usedNames.count(blockArgName))
diff --git a/mlir/test/IR/print-retain-identifiers.mlir b/mlir/test/IR/print-retain-identifiers.mlir
index 05a6b9b42d8e820..b3e4f075b3936a8 100644
--- a/mlir/test/IR/print-retain-identifiers.mlir
+++ b/mlir/test/IR/print-retain-identifiers.mlir
@@ -52,3 +52,41 @@ func.func @simple(i64, i1) -> i64 {
 }
 
 // -----
+
+//===----------------------------------------------------------------------===//
+// Test multiple return values
+//===----------------------------------------------------------------------===//
+
+func.func @select_min_max(%a: f64, %b: f64) -> (f64, f64) {
+  %gt = arith.cmpf "ogt", %a, %b : f64
+  // CHECK: %min, %max = scf.if %gt -> (f64, f64) {
+  %min, %max = scf.if %gt -> (f64, f64) {
+    scf.yield %b, %a : f64, f64
+  } else {
+    scf.yield %a, %b : f64, f64
+  }
+  // CHECK: return %min, %max : f64, f64
+  return %min, %max : f64, f64
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test multiple return values, with a grouped value tuple
+//===----------------------------------------------------------------------===//
+
+func.func @select_max(%a: f64, %b: f64, %c: f64, %d: f64) -> (f64, f64, f64, f64) {
+  // Find the max between %a and %b,
+  // with %c and %d being other values that are returned.
+  %gt = arith.cmpf "ogt", %a, %b : f64
+  // CHECK: %max, %others:2, %alt = scf.if %gt -> (f64, f64, f64, f64) {
+  %max, %others:2, %alt  = scf.if %gt -> (f64, f64, f64, f64) {
+    scf.yield %b, %a, %c, %d : f64, f64, f64, f64
+  } else {
+    scf.yield %a, %b, %d, %c : f64, f64, f64, f64
+  }
+  // CHECK: return %max, %others#0, %others#1, %alt : f64, f64, f64, f64
+  return %max, %others#0, %others#1, %alt : f64, f64, f64, f64
+}
+
+// -----
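
Note on patch 8/8: result groups such as `%max, %others:2, %alt` are stored as a flat list with one entry per result, where only the first result of each group carries a name and the remaining slots are empty strings. The sketch below shows that encoding and how group start indices can be recovered from it. It is a standalone illustration under stated assumptions, not the patch itself: `ResultRecord` here is a hypothetical `(name, count)` pair rather than the parser's record type, and `flattenResultNames` and `resultGroupStarts` are made-up helper names.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical record: (name including the '%' sigil, number of results).
using ResultRecord = std::pair<std::string, unsigned>;

// Encode: one slot per result. The first slot of each group holds the name
// without its sigil; any further slots in the group are left empty.
std::vector<std::string>
flattenResultNames(const std::vector<ResultRecord> &records) {
  std::vector<std::string> names;
  for (const ResultRecord &record : records) {
    assert(!record.first.empty() && record.second >= 1);
    names.push_back(record.first.substr(1));
    for (unsigned i = 1; i < record.second; ++i)
      names.emplace_back();
  }
  return names;
}

// Decode: every non-empty slot after index 0 starts a new result group.
// Index 0 is not reported here because it always starts the first group.
std::vector<int> resultGroupStarts(const std::vector<std::string> &names) {
  std::vector<int> starts;
  for (std::size_t i = 1; i < names.size(); ++i)
    if (!names[i].empty())
      starts.push_back(static_cast<int>(i));
  return starts;
}
```

For the `@select_max` test above, the records `(%max, 1), (%others, 2), (%alt, 1)` flatten to four slots, and the recovered group starts are the indices of `others` and `alt`.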


