[Mlir-commits] [mlir] [mlir] Retain original identifier names for debugging (PR #79704)
    Perry Gibson 
    llvmlistbot at llvm.org
       
    Mon Jan 29 05:04:20 PST 2024
    
    
  
https://github.com/Wheest updated https://github.com/llvm/llvm-project/pull/79704
>From 7e4a6a7ce6018cf2ee8da7138bd84a3c66ee82a8 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Fri, 26 Jan 2024 14:34:14 +0000
Subject: [PATCH 01/11] Added single result SSA processing
---
 mlir/include/mlir/IR/AsmState.h                | 12 +++++++++---
 mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h | 10 ++++++++++
 mlir/lib/AsmParser/Parser.cpp                  | 17 +++++++++++++++++
 mlir/lib/IR/AsmPrinter.cpp                     | 17 +++++++++++------
 mlir/lib/Tools/mlir-opt/MlirOptMain.cpp        | 10 ++++++++--
 5 files changed, 55 insertions(+), 11 deletions(-)
diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index 42cbedcf9f88374..9c4eadb04cdf2fe 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -144,8 +144,7 @@ class AsmResourceBlob {
   /// Return the underlying data as an array of the given type. This is an
   /// inherrently unsafe operation, and should only be used when the data is
   /// known to be of the correct type.
-  template <typename T>
-  ArrayRef<T> getDataAs() const {
+  template <typename T> ArrayRef<T> getDataAs() const {
     return llvm::ArrayRef<T>((const T *)data.data(), data.size() / sizeof(T));
   }
 
@@ -464,8 +463,10 @@ class ParserConfig {
   /// `fallbackResourceMap` is an optional fallback handler that can be used to
   /// parse external resources not explicitly handled by another parser.
   ParserConfig(MLIRContext *context, bool verifyAfterParse = true,
-               FallbackAsmResourceMap *fallbackResourceMap = nullptr)
+               FallbackAsmResourceMap *fallbackResourceMap = nullptr,
+               bool retainIdentifierNames = false)
       : context(context), verifyAfterParse(verifyAfterParse),
+        retainIdentifierNames(retainIdentifierNames),
         fallbackResourceMap(fallbackResourceMap) {
     assert(context && "expected valid MLIR context");
   }
@@ -476,6 +477,10 @@ class ParserConfig {
   /// Returns if the parser should verify the IR after parsing.
   bool shouldVerifyAfterParse() const { return verifyAfterParse; }
 
+  /// Returns if the parser should retain identifier names collected during
+  /// parsing.
+  bool shouldRetainIdentifierNames() const { return retainIdentifierNames; }
+
   /// Returns the parsing configurations associated to the bytecode read.
   BytecodeReaderConfig &getBytecodeReaderConfig() const {
     return const_cast<BytecodeReaderConfig &>(bytecodeReaderConfig);
@@ -513,6 +518,7 @@ class ParserConfig {
 private:
   MLIRContext *context;
   bool verifyAfterParse;
+  bool retainIdentifierNames;
   DenseMap<StringRef, std::unique_ptr<AsmResourceParser>> resourceParsers;
   FallbackAsmResourceMap *fallbackResourceMap;
   BytecodeReaderConfig bytecodeReaderConfig;
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index 6e90fad1618d21b..a85dca186a4f3c9 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -176,6 +176,13 @@ class MlirOptMainConfig {
   /// Reproducer file generation (no crash required).
   StringRef getReproducerFilename() const { return generateReproducerFileFlag; }
 
+  /// Set whether to retain the original identifier names when printing.
+  MlirOptMainConfig &retainIdentifierNames(bool retain) {
+    retainIdentifierNamesFlag = retain;
+    return *this;
+  }
+  bool shouldRetainIdentifierNames() const { return retainIdentifierNamesFlag; }
+
 protected:
   /// Allow operation with no registered dialects.
   /// This option is for convenience during testing only and discouraged in
@@ -226,6 +233,9 @@ class MlirOptMainConfig {
   /// the corresponding line. This is meant for implementing diagnostic tests.
   bool verifyDiagnosticsFlag = false;
 
+  /// Retain identifier names in the output (e.g., `%my_var` instead of `%0`).
+  bool retainIdentifierNamesFlag = false;
+
   /// Run the verifier after each transformation pass.
   bool verifyPassesFlag = true;
 
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 00f2b0c0c2f12f9..8d3861a2f018d00 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -1223,6 +1223,23 @@ ParseResult OperationParser::parseOperation() {
       }
     }
 
+    // If enabled, store the SSA name(s) for the operation
+    if (state.config.shouldRetainIdentifierNames()) {
+      if (opResI == 1) {
+        for (ResultRecord &resIt : resultIDs) {
+          for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
+            op->setDiscardableAttr(
+                "mlir.ssaName",
+                StringAttr::get(getContext(),
+                                std::get<0>(resIt).drop_front(1)));
+          }
+        }
+      } else if (opResI > 1) {
+        emitError(
+            "have not yet implemented support for multiple return values");
+      }
+    }
+
     // Add this operation to the assembly state if it was provided to populate.
   } else if (state.asmState) {
     state.asmState->finalizeOperationDefinition(
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 6b8b7473bf0f8a1..164b0f97fc1cd9e 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -72,13 +72,12 @@ OpAsmParser::~OpAsmParser() = default;
 MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
 
 /// Parse a type list.
-/// This is out-of-line to work-around https://github.com/llvm/llvm-project/issues/62918
+/// This is out-of-line to work-around
+/// https://github.com/llvm/llvm-project/issues/62918
 ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) {
-    return parseCommaSeparatedList(
-        [&]() { return parseType(result.emplace_back()); });
-  }
-
-
+  return parseCommaSeparatedList(
+      [&]() { return parseType(result.emplace_back()); });
+}
 
 //===----------------------------------------------------------------------===//
 // DialectAsmPrinter
@@ -1579,6 +1578,12 @@ void SSANameState::numberValuesInOp(Operation &op) {
   }
   Value resultBegin = op.getResult(0);
 
+  // Get the original SSA for the result if available
+  if (StringAttr ssaNameAttr = op.getAttrOfType<StringAttr>("mlir.ssaName")) {
+    setValueName(resultBegin, ssaNameAttr.strref());
+    op.removeDiscardableAttr("mlir.ssaName");
+  }
+
   // If the first result wasn't numbered, give it a default number.
   if (valueIDs.try_emplace(resultBegin, nextValueID).second)
     ++nextValueID;
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 5395aa2b502d788..c4482435861590d 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -149,6 +149,11 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
         cl::desc("Round-trip the IR after parsing and ensure it succeeds"),
         cl::location(verifyRoundtripFlag), cl::init(false));
 
+    static cl::opt<bool, /*ExternalStorage=*/true> retainIdentifierNames(
+        "retain-identifier-names",
+        cl::desc("Retain the original names of identifiers when printing"),
+        cl::location(retainIdentifierNamesFlag), cl::init(false));
+
     static cl::list<std::string> passPlugins(
         "load-pass-plugin", cl::desc("Load passes from plugin library"));
 
@@ -359,8 +364,9 @@ performActions(raw_ostream &os,
   // untouched.
   PassReproducerOptions reproOptions;
   FallbackAsmResourceMap fallbackResourceMap;
-  ParserConfig parseConfig(context, /*verifyAfterParse=*/true,
-                           &fallbackResourceMap);
+  ParserConfig parseConfig(
+      context, /*verifyAfterParse=*/true, &fallbackResourceMap,
+      /*retainIdentifierName=*/config.shouldRetainIdentifierNames());
   if (config.shouldRunReproducer())
     reproOptions.attachResourceParser(parseConfig);
 
>From 5a629d2cd3e6f736af20c808d53206f086ea5b05 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Fri, 26 Jan 2024 14:52:10 +0000
Subject: [PATCH 02/11] Moved SSA name storage to a function
---
 mlir/lib/AsmParser/Parser.cpp | 40 ++++++++++++++++++++++-------------
 1 file changed, 25 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 8d3861a2f018d00..282077865c665a9 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -611,6 +611,10 @@ class OperationParser : public Parser {
   /// an object of type 'OperationName'. Otherwise, failure is returned.
   FailureOr<OperationName> parseCustomOperationName();
 
+  /// Store the SSA names for the current operation as attrs for debug purposes.
+  ParseResult storeSSANames(Operation *&op,
+                            SmallVector<ResultRecord, 1> resultIDs);
+
   //===--------------------------------------------------------------------===//
   // Region Parsing
   //===--------------------------------------------------------------------===//
@@ -1224,21 +1228,8 @@ ParseResult OperationParser::parseOperation() {
     }
 
     // If enabled, store the SSA name(s) for the operation
-    if (state.config.shouldRetainIdentifierNames()) {
-      if (opResI == 1) {
-        for (ResultRecord &resIt : resultIDs) {
-          for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
-            op->setDiscardableAttr(
-                "mlir.ssaName",
-                StringAttr::get(getContext(),
-                                std::get<0>(resIt).drop_front(1)));
-          }
-        }
-      } else if (opResI > 1) {
-        emitError(
-            "have not yet implemented support for multiple return values");
-      }
-    }
+    if (state.config.shouldRetainIdentifierNames())
+      storeSSANames(op, resultIDs);
 
     // Add this operation to the assembly state if it was provided to populate.
   } else if (state.asmState) {
@@ -1285,6 +1276,25 @@ OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) {
                                       /*allowEmptyList=*/false);
 }
 
+/// Store the SSA names for the current operation as attrs for debug purposes.
+ParseResult
+OperationParser::storeSSANames(Operation *&op,
+                               SmallVector<ResultRecord, 1> resultIDs) {
+  if (op->getNumResults() == 0)
+    emitError("Operation has no results\n");
+  else if (op->getNumResults() > 1)
+    emitError("have not yet implemented support for multiple return values\n");
+
+  for (ResultRecord &resIt : resultIDs) {
+    for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
+      op->setDiscardableAttr(
+          "mlir.ssaName",
+          StringAttr::get(getContext(), std::get<0>(resIt).drop_front(1)));
+    }
+  }
+  return success();
+}
+
 namespace {
 // RAII-style guard for cleaning up the regions in the operation state before
 // deleting them.  Within the parser, regions may get deleted if parsing failed,
>From f2f8047bd80df0f6da6a8db4bfe976491157993e Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Fri, 26 Jan 2024 15:39:18 +0000
Subject: [PATCH 03/11] Added initial block name handling
---
 mlir/lib/AsmParser/Parser.cpp | 13 +++++++++++++
 mlir/lib/IR/AsmPrinter.cpp    | 33 ++++++++++++++++++++++++++++-----
 2 files changed, 41 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 282077865c665a9..bb3fd5d8729406a 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -1292,6 +1292,19 @@ OperationParser::storeSSANames(Operation *&op,
           StringAttr::get(getContext(), std::get<0>(resIt).drop_front(1)));
     }
   }
+
+  // Find the name of the block that contains this operation.
+  Block *blockPtr = op->getBlock();
+  for (const auto &map : blocksByName) {
+    for (const auto &entry : map) {
+      if (entry.second.block == blockPtr) {
+        op->setDiscardableAttr("mlir.blockName",
+                               StringAttr::get(getContext(), entry.first));
+        llvm::outs() << "Block name: " << entry.first << "\n";
+      }
+    }
+  }
+
   return success();
 }
 
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 164b0f97fc1cd9e..70d38c9bd8b4f89 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1298,6 +1298,10 @@ class SSANameState {
   /// conflicts, it is automatically renamed.
   StringRef uniqueValueName(StringRef name);
 
+  /// Set the original identifier names if available. Used in debugging with
+  /// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
+  void setRetainedIdentifierNames(Operation &op);
+
   /// This is the value ID for each SSA value. If this returns NameSentinel,
   /// then the valueID has an entry in valueNames.
   DenseMap<Value, unsigned> valueIDs;
@@ -1578,11 +1582,9 @@ void SSANameState::numberValuesInOp(Operation &op) {
   }
   Value resultBegin = op.getResult(0);
 
-  // Get the original SSA for the result if available
-  if (StringAttr ssaNameAttr = op.getAttrOfType<StringAttr>("mlir.ssaName")) {
-    setValueName(resultBegin, ssaNameAttr.strref());
-    op.removeDiscardableAttr("mlir.ssaName");
-  }
+  // Set the original identifier names if available. Used in debugging with
+  // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
+  setRetainedIdentifierNames(op);
 
   // If the first result wasn't numbered, give it a default number.
   if (valueIDs.try_emplace(resultBegin, nextValueID).second)
@@ -1595,6 +1597,27 @@ void SSANameState::numberValuesInOp(Operation &op) {
   }
 }
 
+void SSANameState::setRetainedIdentifierNames(Operation &op) {
+  // Get the original SSA for the result(s) if available
+  Value resultBegin = op.getResult(0);
+  if (StringAttr ssaNameAttr = op.getAttrOfType<StringAttr>("mlir.ssaName")) {
+    setValueName(resultBegin, ssaNameAttr.strref());
+    op.removeDiscardableAttr("mlir.ssaName");
+  }
+  unsigned numResults = op.getNumResults();
+  if (numResults > 1)
+    llvm::outs()
+        << "have not yet implemented support for multiple return values\n";
+
+  // Get the original SSA name for the block if available
+  if (StringAttr blockNameAttr =
+          op.getAttrOfType<StringAttr>("mlir.blockName")) {
+    blockNames[op.getBlock()] = {-1, blockNameAttr.strref()};
+    op.removeDiscardableAttr("mlir.blockName");
+  }
+  return;
+}
+
 void SSANameState::getResultIDAndNumber(
     OpResult result, Value &lookupValue,
     std::optional<int> &lookupResultNo) const {
>From 1aa3814f6eb47ec8d51d2eb68e40e469194e2744 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Fri, 26 Jan 2024 16:18:27 +0000
Subject: [PATCH 04/11] Added block arg name handling
---
 mlir/lib/AsmParser/Parser.cpp | 25 ++++++++++++++++++++-----
 mlir/lib/IR/AsmPrinter.cpp    | 14 +++++++++++++-
 2 files changed, 33 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index bb3fd5d8729406a..95b1b2fe1f8f07e 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -614,6 +614,7 @@ class OperationParser : public Parser {
   /// Store the SSA names for the current operation as attrs for debug purposes.
   ParseResult storeSSANames(Operation *&op,
                             SmallVector<ResultRecord, 1> resultIDs);
+  DenseMap<BlockArgument, StringRef> blockArgNames;
 
   //===--------------------------------------------------------------------===//
   // Region Parsing
@@ -1203,6 +1204,11 @@ ParseResult OperationParser::parseOperation() {
              << op->getNumResults() << " results but was provided "
              << numExpectedResults << " to bind";
 
+    // If enabled, store the SSA name(s) for the operation
+    llvm::outs() << "parsing operation: " << op->getName() << "\n";
+    if (state.config.shouldRetainIdentifierNames())
+      storeSSANames(op, resultIDs);
+
     // Add this operation to the assembly state if it was provided to populate.
     if (state.asmState) {
       unsigned resultIt = 0;
@@ -1227,10 +1233,6 @@ ParseResult OperationParser::parseOperation() {
       }
     }
 
-    // If enabled, store the SSA name(s) for the operation
-    if (state.config.shouldRetainIdentifierNames())
-      storeSSANames(op, resultIDs);
-
     // Add this operation to the assembly state if it was provided to populate.
   } else if (state.asmState) {
     state.asmState->finalizeOperationDefinition(
@@ -1300,7 +1302,16 @@ OperationParser::storeSSANames(Operation *&op,
       if (entry.second.block == blockPtr) {
         op->setDiscardableAttr("mlir.blockName",
                                StringAttr::get(getContext(), entry.first));
-        llvm::outs() << "Block name: " << entry.first << "\n";
+
+        // Store block arguments, if present
+        llvm::SmallVector<llvm::StringRef, 1> argNames;
+
+        for (BlockArgument arg : blockPtr->getArguments()) {
+          auto it = blockArgNames.find(arg);
+          if (it != blockArgNames.end())
+            argNames.push_back(it->second.drop_front(1));
+        }
+        op->setAttr("mlir.blockArgNames", builder.getStrArrayAttr(argNames));
       }
     }
   }
@@ -2395,6 +2406,10 @@ ParseResult OperationParser::parseOptionalBlockArgList(Block *owner) {
           } else {
             auto loc = getEncodedSourceLocation(useInfo.location);
             arg = owner->addArgument(type, loc);
+
+            // Optionally store argument name for debug purposes
+            if (state.config.shouldRetainIdentifierNames())
+              blockArgNames.insert({arg, useInfo.name});
           }
 
           // If the argument has an explicit loc(...) specifier, parse and apply
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 70d38c9bd8b4f89..34e0bdfccda94e5 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1609,12 +1609,24 @@ void SSANameState::setRetainedIdentifierNames(Operation &op) {
     llvm::outs()
         << "have not yet implemented support for multiple return values\n";
 
-  // Get the original SSA name for the block if available
+  // Get the original name for the block if available
   if (StringAttr blockNameAttr =
           op.getAttrOfType<StringAttr>("mlir.blockName")) {
     blockNames[op.getBlock()] = {-1, blockNameAttr.strref()};
     op.removeDiscardableAttr("mlir.blockName");
   }
+
+  // Get the original name for the block args if available
+  if (ArrayAttr blockArgNamesAttr =
+          op.getAttrOfType<ArrayAttr>("mlir.blockArgNames")) {
+    auto blockArgNames = blockArgNamesAttr.getValue();
+    auto blockArgs = op.getBlock()->getArguments();
+    for (int i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) {
+      auto blockArgName = blockArgNames[i].cast<StringAttr>();
+      setValueName(blockArgs[i], cast<StringAttr>(blockArgNames[i]).strref());
+    }
+    op.removeDiscardableAttr("mlir.blockArgNames");
+  }
   return;
 }
 
>From c0ec451d9ba155478ee5bded93cfaa912480e0ce Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Sat, 27 Jan 2024 11:24:53 +0000
Subject: [PATCH 05/11] Added support for operations with no results
---
 mlir/lib/AsmParser/Parser.cpp | 26 ++++++++++----------------
 mlir/lib/IR/AsmPrinter.cpp    | 27 +++++++++++++++------------
 2 files changed, 25 insertions(+), 28 deletions(-)
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 95b1b2fe1f8f07e..36266be5af033c0 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -612,8 +612,7 @@ class OperationParser : public Parser {
   FailureOr<OperationName> parseCustomOperationName();
 
   /// Store the SSA names for the current operation as attrs for debug purposes.
-  ParseResult storeSSANames(Operation *&op,
-                            SmallVector<ResultRecord, 1> resultIDs);
+  void storeSSANames(Operation *&op, ArrayRef<ResultRecord> resultIDs);
   DenseMap<BlockArgument, StringRef> blockArgNames;
 
   //===--------------------------------------------------------------------===//
@@ -1204,11 +1203,6 @@ ParseResult OperationParser::parseOperation() {
              << op->getNumResults() << " results but was provided "
              << numExpectedResults << " to bind";
 
-    // If enabled, store the SSA name(s) for the operation
-    llvm::outs() << "parsing operation: " << op->getName() << "\n";
-    if (state.config.shouldRetainIdentifierNames())
-      storeSSANames(op, resultIDs);
-
     // Add this operation to the assembly state if it was provided to populate.
     if (state.asmState) {
       unsigned resultIt = 0;
@@ -1279,15 +1273,12 @@ OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) {
 }
 
 /// Store the SSA names for the current operation as attrs for debug purposes.
-ParseResult
-OperationParser::storeSSANames(Operation *&op,
-                               SmallVector<ResultRecord, 1> resultIDs) {
-  if (op->getNumResults() == 0)
-    emitError("Operation has no results\n");
-  else if (op->getNumResults() > 1)
+void OperationParser::storeSSANames(Operation *&op,
+                                    ArrayRef<ResultRecord> resultIDs) {
+  if (op->getNumResults() > 1)
     emitError("have not yet implemented support for multiple return values\n");
 
-  for (ResultRecord &resIt : resultIDs) {
+  for (const ResultRecord &resIt : resultIDs) {
     for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
       op->setDiscardableAttr(
           "mlir.ssaName",
@@ -1315,8 +1306,6 @@ OperationParser::storeSSANames(Operation *&op,
       }
     }
   }
-
-  return success();
 }
 
 namespace {
@@ -2082,6 +2071,11 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
 
   // Otherwise, create the operation and try to parse a location for it.
   Operation *op = opBuilder.create(opState);
+
+  // If enabled, store the SSA name(s) for the operation
+  if (state.config.shouldRetainIdentifierNames())
+    storeSSANames(op, resultIDs);
+
   if (parseTrailingLocationSpecifier(op))
     return nullptr;
 
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 34e0bdfccda94e5..d4be3b7802e6880 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1571,6 +1571,10 @@ void SSANameState::numberValuesInOp(Operation &op) {
     }
   }
 
+  // Set the original identifier names if available. Used in debugging with
+  // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
+  setRetainedIdentifierNames(op);
+
   unsigned numResults = op.getNumResults();
   if (numResults == 0) {
     // If value users should be printed, operations with no result need an id.
@@ -1582,10 +1586,6 @@ void SSANameState::numberValuesInOp(Operation &op) {
   }
   Value resultBegin = op.getResult(0);
 
-  // Set the original identifier names if available. Used in debugging with
-  // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
-  setRetainedIdentifierNames(op);
-
   // If the first result wasn't numbered, give it a default number.
   if (valueIDs.try_emplace(resultBegin, nextValueID).second)
     ++nextValueID;
@@ -1599,15 +1599,17 @@ void SSANameState::numberValuesInOp(Operation &op) {
 
 void SSANameState::setRetainedIdentifierNames(Operation &op) {
   // Get the original SSA for the result(s) if available
-  Value resultBegin = op.getResult(0);
-  if (StringAttr ssaNameAttr = op.getAttrOfType<StringAttr>("mlir.ssaName")) {
-    setValueName(resultBegin, ssaNameAttr.strref());
-    op.removeDiscardableAttr("mlir.ssaName");
-  }
   unsigned numResults = op.getNumResults();
   if (numResults > 1)
     llvm::outs()
         << "have not yet implemented support for multiple return values\n";
+  else if (numResults == 1) {
+    Value resultBegin = op.getResult(0);
+    if (StringAttr ssaNameAttr = op.getAttrOfType<StringAttr>("mlir.ssaName")) {
+      setValueName(resultBegin, ssaNameAttr.strref());
+      op.removeDiscardableAttr("mlir.ssaName");
+    }
+  }
 
   // Get the original name for the block if available
   if (StringAttr blockNameAttr =
@@ -1621,9 +1623,10 @@ void SSANameState::setRetainedIdentifierNames(Operation &op) {
           op.getAttrOfType<ArrayAttr>("mlir.blockArgNames")) {
     auto blockArgNames = blockArgNamesAttr.getValue();
     auto blockArgs = op.getBlock()->getArguments();
-    for (int i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) {
-      auto blockArgName = blockArgNames[i].cast<StringAttr>();
-      setValueName(blockArgs[i], cast<StringAttr>(blockArgNames[i]).strref());
+    for (size_t i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) {
+      auto blockArgName = blockArgNames[i].cast<StringAttr>().strref();
+      if (!usedNames.count(blockArgName))
+        setValueName(blockArgs[i], blockArgName);
     }
     op.removeDiscardableAttr("mlir.blockArgNames");
   }
>From a9e3891f4ed0511de6797cfe986fe8d3dcbf8f08 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Sat, 27 Jan 2024 12:21:39 +0000
Subject: [PATCH 06/11] Added initial unit tests
---
 mlir/test/IR/print-retain-identifiers.mlir | 54 ++++++++++++++++++++++
 1 file changed, 54 insertions(+)
 create mode 100644 mlir/test/IR/print-retain-identifiers.mlir
diff --git a/mlir/test/IR/print-retain-identifiers.mlir b/mlir/test/IR/print-retain-identifiers.mlir
new file mode 100644
index 000000000000000..663f7301b817bc0
--- /dev/null
+++ b/mlir/test/IR/print-retain-identifiers.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt -retain-identifier-names %s | FileCheck %s
+
+
+//===----------------------------------------------------------------------===//
+// Test SSA results (with single return values)
+//===----------------------------------------------------------------------===//
+
+// CHECK: func.func @add_one(%arg0: f64, %arg1: f64) -> f64 {
+func.func @add_one(%arg0: f64, %arg1: f64) -> f64 {
+  // CHECK: %my_constant = arith.constant 1.000000e+00 : f64
+  %my_constant = arith.constant 1.000000e+00 : f64
+  // CHECK: %my_output = arith.addf %arg0, %my_constant : f64
+  %my_output = arith.addf %arg0, %my_constant : f64
+  // CHECK: return %my_output : f64
+  return %my_output : f64
+}
+
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test basic blocks and their arguments
+//===----------------------------------------------------------------------===//
+
+func.func @simple(i64, i1) -> i64 {
+^bb_alpha(%a: i64, %cond: i1):
+  // CHECK: cf.cond_br %cond, ^bb_beta, ^bb_gamma
+  cf.cond_br %cond, ^bb_beta, ^bb_gamma
+
+// CHECK: ^bb_beta:  // pred: ^bb_alpha
+^bb_beta:
+  // CHECK: cf.br ^bb_delta(%a : i64)
+  cf.br ^bb_delta(%a: i64)
+
+// CHECK: ^bb_gamma:  // pred: ^bb_alpha
+^bb_gamma:
+  // CHECK: %b = arith.addi %a, %a : i64
+  %b = arith.addi %a, %a : i64
+  // CHECK: cf.br ^bb_delta(%b : i64)
+  cf.br ^bb_delta(%b: i64)
+
+// CHECK: ^bb_delta(%c: i64):  // 2 preds: ^bb_gamma, ^bb_beta
+^bb_delta(%c: i64):
+  // CHECK: cf.br ^bb_eps(%c, %a : i64, i64)
+  cf.br ^bb_eps(%c, %a : i64, i64)
+
+// CHECK: ^bb_eps(%d: i64, %e: i64):  // pred: ^bb_delta
+^bb_eps(%d : i64, %e : i64):
+  // CHECK: %f = arith.addi %d, %e : i64
+  %f = arith.addi %d, %e : i64
+  return %f : i64
+}
+
+// -----
>From 467c3a38abfbbc22c25b0874a3943ee7ad8bc78b Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Sat, 27 Jan 2024 14:49:16 +0000
Subject: [PATCH 07/11] Added operation operand name preservation
---
 mlir/lib/AsmParser/Parser.cpp              | 34 +++++++++++++++++-----
 mlir/lib/IR/AsmPrinter.cpp                 | 13 +++++++++
 mlir/test/IR/print-retain-identifiers.mlir |  8 ++---
 3 files changed, 43 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 36266be5af033c0..7327776ae5a55d1 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -613,7 +613,7 @@ class OperationParser : public Parser {
 
   /// Store the SSA names for the current operation as attrs for debug purposes.
   void storeSSANames(Operation *&op, ArrayRef<ResultRecord> resultIDs);
-  DenseMap<BlockArgument, StringRef> blockArgNames;
+  DenseMap<Value, StringRef> argNames;
 
   //===--------------------------------------------------------------------===//
   // Region Parsing
@@ -1286,7 +1286,19 @@ void OperationParser::storeSSANames(Operation *&op,
     }
   }
 
-  // Find the name of the block that contains this operation.
+  // Store the name information of the arguments of this operation.
+  if (op->getNumOperands() > 0) {
+    llvm::SmallVector<llvm::StringRef, 1> opArgNames;
+    for (auto &operand : op->getOpOperands()) {
+      auto it = argNames.find(operand.get());
+      if (it != argNames.end())
+        opArgNames.push_back(it->second.drop_front(1));
+    }
+    op->setDiscardableAttr("mlir.opArgNames",
+                           builder.getStrArrayAttr(opArgNames));
+  }
+
+  // Store the name information of the block that contains this operation.
   Block *blockPtr = op->getBlock();
   for (const auto &map : blocksByName) {
     for (const auto &entry : map) {
@@ -1295,14 +1307,15 @@ void OperationParser::storeSSANames(Operation *&op,
                                StringAttr::get(getContext(), entry.first));
 
         // Store block arguments, if present
-        llvm::SmallVector<llvm::StringRef, 1> argNames;
+        llvm::SmallVector<llvm::StringRef, 1> blockArgNames;
 
         for (BlockArgument arg : blockPtr->getArguments()) {
-          auto it = blockArgNames.find(arg);
-          if (it != blockArgNames.end())
-            argNames.push_back(it->second.drop_front(1));
+          auto it = argNames.find(arg);
+          if (it != argNames.end())
+            blockArgNames.push_back(it->second.drop_front(1));
         }
-        op->setAttr("mlir.blockArgNames", builder.getStrArrayAttr(argNames));
+        op->setAttr("mlir.blockArgNames",
+                    builder.getStrArrayAttr(blockArgNames));
       }
     }
   }
@@ -1712,6 +1725,11 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
                              SmallVectorImpl<Value> &result) override {
     if (auto value = parser.resolveSSAUse(operand, type)) {
       result.push_back(value);
+
+      // Optionally store argument name for debug purposes
+      if (parser.getState().config.shouldRetainIdentifierNames())
+        parser.argNames.insert({value, operand.name});
+
       return success();
     }
     return failure();
@@ -2403,7 +2421,7 @@ ParseResult OperationParser::parseOptionalBlockArgList(Block *owner) {
 
             // Optionally store argument name for debug purposes
             if (state.config.shouldRetainIdentifierNames())
-              blockArgNames.insert({arg, useInfo.name});
+              argNames.insert({arg, useInfo.name});
           }
 
           // If the argument has an explicit loc(...) specifier, parse and apply
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index d4be3b7802e6880..5755b9021dbad7f 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1611,6 +1611,19 @@ void SSANameState::setRetainedIdentifierNames(Operation &op) {
     }
   }
 
+  // Get the original name for the op args if available
+  if (ArrayAttr opArgNamesAttr =
+          op.getAttrOfType<ArrayAttr>("mlir.opArgNames")) {
+    auto opArgNames = opArgNamesAttr.getValue();
+    auto opArgs = op.getOperands();
+    for (size_t i = 0; i < opArgs.size() && i < opArgNames.size(); ++i) {
+      auto opArgName = opArgNames[i].cast<StringAttr>().strref();
+      if (!usedNames.count(opArgName))
+        setValueName(opArgs[i], opArgName);
+    }
+    op.removeDiscardableAttr("mlir.opArgNames");
+  }
+
   // Get the original name for the block if available
   if (StringAttr blockNameAttr =
           op.getAttrOfType<StringAttr>("mlir.blockName")) {
diff --git a/mlir/test/IR/print-retain-identifiers.mlir b/mlir/test/IR/print-retain-identifiers.mlir
index 663f7301b817bc0..05a6b9b42d8e820 100644
--- a/mlir/test/IR/print-retain-identifiers.mlir
+++ b/mlir/test/IR/print-retain-identifiers.mlir
@@ -5,12 +5,12 @@
 // Test SSA results (with single return values)
 //===----------------------------------------------------------------------===//
 
-// CHECK: func.func @add_one(%arg0: f64, %arg1: f64) -> f64 {
-func.func @add_one(%arg0: f64, %arg1: f64) -> f64 {
+// CHECK: func.func @add_one(%my_input: f64, %arg1: f64) -> f64 {
+func.func @add_one(%my_input: f64, %arg1: f64) -> f64 {
   // CHECK: %my_constant = arith.constant 1.000000e+00 : f64
   %my_constant = arith.constant 1.000000e+00 : f64
-  // CHECK: %my_output = arith.addf %arg0, %my_constant : f64
-  %my_output = arith.addf %arg0, %my_constant : f64
+  // CHECK: %my_output = arith.addf %my_input, %my_constant : f64
+  %my_output = arith.addf %my_input, %my_constant : f64
   // CHECK: return %my_output : f64
   return %my_output : f64
 }
>From d062cb861c283594413c8edc7dca57f9494b1f3e Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Sat, 27 Jan 2024 18:27:53 +0000
Subject: [PATCH 08/11] Added support for result groups
---
 mlir/lib/AsmParser/Parser.cpp              | 19 ++++++-----
 mlir/lib/IR/AsmPrinter.cpp                 | 36 ++++++++++++--------
 mlir/test/IR/print-retain-identifiers.mlir | 38 ++++++++++++++++++++++
 3 files changed, 72 insertions(+), 21 deletions(-)
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 7327776ae5a55d1..247e99e61c2c01b 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -1275,15 +1275,18 @@ OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) {
 /// Store the SSA names for the current operation as attrs for debug purposes.
 void OperationParser::storeSSANames(Operation *&op,
                                     ArrayRef<ResultRecord> resultIDs) {
-  if (op->getNumResults() > 1)
-    emitError("have not yet implemented support for multiple return values\n");
-
-  for (const ResultRecord &resIt : resultIDs) {
-    for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
-      op->setDiscardableAttr(
-          "mlir.ssaName",
-          StringAttr::get(getContext(), std::get<0>(resIt).drop_front(1)));
+
+  // Store the name(s) of the result(s) of this operation.
+  if (op->getNumResults() > 0) {
+    llvm::SmallVector<llvm::StringRef, 1> resultNames;
+    for (const ResultRecord &resIt : resultIDs) {
+      resultNames.push_back(std::get<0>(resIt).drop_front(1));
+      // Insert empty string for sub-results/result groups
+      for (unsigned int i = 1; i < std::get<1>(resIt); ++i)
+        resultNames.push_back(llvm::StringRef());
     }
+    op->setDiscardableAttr("mlir.resultNames",
+                           builder.getStrArrayAttr(resultNames));
   }
 
   // Store the name information of the arguments of this operation.
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 5755b9021dbad7f..84603bb6ebfba3a 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1300,7 +1300,8 @@ class SSANameState {
 
   /// Set the original identifier names if available. Used in debugging with
   /// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
-  void setRetainedIdentifierNames(Operation &op);
+  void setRetainedIdentifierNames(Operation &op,
+                                  SmallVector<int, 2> &resultGroups);
 
   /// This is the value ID for each SSA value. If this returns NameSentinel,
   /// then the valueID has an entry in valueNames.
@@ -1573,7 +1574,7 @@ void SSANameState::numberValuesInOp(Operation &op) {
 
   // Set the original identifier names if available. Used in debugging with
   // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
-  setRetainedIdentifierNames(op);
+  setRetainedIdentifierNames(op, resultGroups);
 
   unsigned numResults = op.getNumResults();
   if (numResults == 0) {
@@ -1597,18 +1598,25 @@ void SSANameState::numberValuesInOp(Operation &op) {
   }
 }
 
-void SSANameState::setRetainedIdentifierNames(Operation &op) {
-  // Get the original SSA for the result(s) if available
-  unsigned numResults = op.getNumResults();
-  if (numResults > 1)
-    llvm::outs()
-        << "have not yet implemented support for multiple return values\n";
-  else if (numResults == 1) {
-    Value resultBegin = op.getResult(0);
-    if (StringAttr ssaNameAttr = op.getAttrOfType<StringAttr>("mlir.ssaName")) {
-      setValueName(resultBegin, ssaNameAttr.strref());
-      op.removeDiscardableAttr("mlir.ssaName");
+void SSANameState::setRetainedIdentifierNames(
+    Operation &op, SmallVector<int, 2> &resultGroups) {
+  // Get the original names for the results if available
+  if (ArrayAttr resultNamesAttr =
+          op.getAttrOfType<ArrayAttr>("mlir.resultNames")) {
+    auto resultNames = resultNamesAttr.getValue();
+    auto results = op.getResults();
+    // Conservative in the case that the #results has changed
+    for (size_t i = 0; i < results.size() && i < resultNames.size(); ++i) {
+      auto resultName = resultNames[i].cast<StringAttr>().strref();
+      if (!resultName.empty()) {
+        if (!usedNames.count(resultName))
+          setValueName(results[i], resultName);
+        // If a result has a name, it is the start of a result group.
+        if (i > 0)
+          resultGroups.push_back(i);
+      }
     }
+    op.removeDiscardableAttr("mlir.resultNames");
   }
 
   // Get the original name for the op args if available
@@ -1616,6 +1624,7 @@ void SSANameState::setRetainedIdentifierNames(Operation &op) {
           op.getAttrOfType<ArrayAttr>("mlir.opArgNames")) {
     auto opArgNames = opArgNamesAttr.getValue();
     auto opArgs = op.getOperands();
+    // Conservative in the case that the #operands has changed
     for (size_t i = 0; i < opArgs.size() && i < opArgNames.size(); ++i) {
       auto opArgName = opArgNames[i].cast<StringAttr>().strref();
       if (!usedNames.count(opArgName))
@@ -1636,6 +1645,7 @@ void SSANameState::setRetainedIdentifierNames(Operation &op) {
           op.getAttrOfType<ArrayAttr>("mlir.blockArgNames")) {
     auto blockArgNames = blockArgNamesAttr.getValue();
     auto blockArgs = op.getBlock()->getArguments();
+    // Conservative in the case that the #args has changed
     for (size_t i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) {
       auto blockArgName = blockArgNames[i].cast<StringAttr>().strref();
       if (!usedNames.count(blockArgName))
diff --git a/mlir/test/IR/print-retain-identifiers.mlir b/mlir/test/IR/print-retain-identifiers.mlir
index 05a6b9b42d8e820..b3e4f075b3936a8 100644
--- a/mlir/test/IR/print-retain-identifiers.mlir
+++ b/mlir/test/IR/print-retain-identifiers.mlir
@@ -52,3 +52,41 @@ func.func @simple(i64, i1) -> i64 {
 }
 
 // -----
+
+//===----------------------------------------------------------------------===//
+// Test multiple return values
+//===----------------------------------------------------------------------===//
+
+func.func @select_min_max(%a: f64, %b: f64) -> (f64, f64) {
+  %gt = arith.cmpf "ogt", %a, %b : f64
+  // CHECK: %min, %max = scf.if %gt -> (f64, f64) {
+  %min, %max = scf.if %gt -> (f64, f64) {
+    scf.yield %b, %a : f64, f64
+  } else {
+    scf.yield %a, %b : f64, f64
+  }
+  // CHECK: return %min, %max : f64, f64
+  return %min, %max : f64, f64
+}
+
+// -----
+
+////===----------------------------------------------------------------------===//
+// Test multiple return values, with a grouped value tuple
+//===----------------------------------------------------------------------===//
+
+func.func @select_max(%a: f64, %b: f64, %c: f64, %d: f64) -> (f64, f64, f64, f64) {
+  // Find the max between %a and %b,
+  // with %c and %d being other values that are returned.
+  %gt = arith.cmpf "ogt", %a, %b : f64
+  // CHECK: %max, %others:2, %alt = scf.if %gt -> (f64, f64, f64, f64) {
+  %max, %others:2, %alt  = scf.if %gt -> (f64, f64, f64, f64) {
+    scf.yield %b, %a, %c, %d : f64, f64, f64, f64
+  } else {
+    scf.yield %a, %b, %d, %c : f64, f64, f64, f64
+  }
+  // CHECK: return %max, %others#0, %others#1, %alt : f64, f64, f64, f64
+  return %max, %others#0, %others#1, %alt : f64, f64, f64, f64
+}
+
+// -----
>From fb557df2c4afec5aaaf876428769b70d29306244 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Sat, 27 Jan 2024 18:58:35 +0000
Subject: [PATCH 09/11] Fix clang-format issue in AsmState.h
---
 mlir/include/mlir/IR/AsmState.h | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index 9c4eadb04cdf2fe..36f80712efb6ac5 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -144,7 +144,8 @@ class AsmResourceBlob {
   /// Return the underlying data as an array of the given type. This is an
   /// inherrently unsafe operation, and should only be used when the data is
   /// known to be of the correct type.
-  template <typename T> ArrayRef<T> getDataAs() const {
+  template <typename T>
+  ArrayRef<T> getDataAs() const {
     return llvm::ArrayRef<T>((const T *)data.data(), data.size() / sizeof(T));
   }
 
>From 233a987408742eb47c9d55bbe1013af4c0b0af10 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Sun, 28 Jan 2024 16:22:52 +0000
Subject: [PATCH 10/11] Added system to handle when we use default names
---
 mlir/lib/IR/AsmPrinter.cpp | 45 ++++++++++++++++++++++----------------
 1 file changed, 26 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 84603bb6ebfba3a..e3bbe75441d588a 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -981,7 +981,8 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
 /// store the new copy,
 static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
                                     StringRef allowedPunctChars = "$._-",
-                                    bool allowTrailingDigit = true) {
+                                    bool allowTrailingDigit = true,
+                                    bool allowNumeric = false) {
   assert(!name.empty() && "Shouldn't have an empty name here");
 
   auto copyNameToBuffer = [&] {
@@ -997,16 +998,17 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
 
   // Check to see if this name is valid. If it starts with a digit, then it
   // could conflict with the autogenerated numeric ID's, so add an underscore
-  // prefix to avoid problems.
-  if (isdigit(name[0])) {
+  // prefix to avoid problems. This can be overridden by setting allowNumeric.
+  if (isdigit(name[0]) && !allowNumeric) {
     buffer.push_back('_');
     copyNameToBuffer();
     return buffer;
   }
 
   // If the name ends with a trailing digit, add a '_' to avoid potential
-  // conflicts with autogenerated ID's.
-  if (!allowTrailingDigit && isdigit(name.back())) {
+  // conflicts with autogenerated ID's. This can be overridden by setting
+  // allowNumeric.
+  if (!allowTrailingDigit && isdigit(name.back()) && !allowNumeric) {
     copyNameToBuffer();
     buffer.push_back('_');
     return buffer;
@@ -1292,11 +1294,11 @@ class SSANameState {
                             std::optional<int> &lookupResultNo) const;
 
   /// Set a special value name for the given value.
-  void setValueName(Value value, StringRef name);
+  void setValueName(Value value, StringRef name, bool allowNumeric = false);
 
   /// Uniques the given value name within the printer. If the given name
   /// conflicts, it is automatically renamed.
-  StringRef uniqueValueName(StringRef name);
+  StringRef uniqueValueName(StringRef name, bool allowNumeric = false);
 
   /// Set the original identifier names if available. Used in debugging with
   /// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
@@ -1541,7 +1543,10 @@ void SSANameState::numberValuesInOp(Operation &op) {
   // Function used to set the special result names for the operation.
   SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
   auto setResultNameFn = [&](Value result, StringRef name) {
-    assert(!valueIDs.count(result) && "result numbered multiple times");
+    // Case where the result has already been named
+    if (valueIDs.count(result))
+      return;
+    // assert(!valueIDs.count(result) && "result numbered multiple times");
     assert(result.getDefiningOp() == &op && "result not defined by 'op'");
     setValueName(result, name);
 
@@ -1565,6 +1570,10 @@ void SSANameState::numberValuesInOp(Operation &op) {
     blockNames[block] = {-1, name};
   };
 
+  // Set the original identifier names if available. Used in debugging with
+  // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
+  setRetainedIdentifierNames(op, resultGroups);
+
   if (!printerFlags.shouldPrintGenericOpForm()) {
     if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
       asmInterface.getAsmBlockNames(setBlockNameFn);
@@ -1572,10 +1581,6 @@ void SSANameState::numberValuesInOp(Operation &op) {
     }
   }
 
-  // Set the original identifier names if available. Used in debugging with
-  // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
-  setRetainedIdentifierNames(op, resultGroups);
-
   unsigned numResults = op.getNumResults();
   if (numResults == 0) {
     // If value users should be printed, operations with no result need an id.
@@ -1610,7 +1615,7 @@ void SSANameState::setRetainedIdentifierNames(
       auto resultName = resultNames[i].cast<StringAttr>().strref();
       if (!resultName.empty()) {
         if (!usedNames.count(resultName))
-          setValueName(results[i], resultName);
+          setValueName(results[i], resultName, /*allowNumeric=*/true);
         // If a result has a name, it is the start of a result group.
         if (i > 0)
           resultGroups.push_back(i);
@@ -1628,7 +1633,7 @@ void SSANameState::setRetainedIdentifierNames(
     for (size_t i = 0; i < opArgs.size() && i < opArgNames.size(); ++i) {
       auto opArgName = opArgNames[i].cast<StringAttr>().strref();
       if (!usedNames.count(opArgName))
-        setValueName(opArgs[i], opArgName);
+        setValueName(opArgs[i], opArgName, /*allowNumeric=*/true);
     }
     op.removeDiscardableAttr("mlir.opArgNames");
   }
@@ -1649,7 +1654,7 @@ void SSANameState::setRetainedIdentifierNames(
     for (size_t i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) {
       auto blockArgName = blockArgNames[i].cast<StringAttr>().strref();
       if (!usedNames.count(blockArgName))
-        setValueName(blockArgs[i], blockArgName);
+        setValueName(blockArgs[i], blockArgName, /*allowNumeric=*/true);
     }
     op.removeDiscardableAttr("mlir.blockArgNames");
   }
@@ -1695,7 +1700,8 @@ void SSANameState::getResultIDAndNumber(
   lookupValue = owner->getResult(groupResultNo);
 }
 
-void SSANameState::setValueName(Value value, StringRef name) {
+void SSANameState::setValueName(Value value, StringRef name,
+                                bool allowNumeric) {
   // If the name is empty, the value uses the default numbering.
   if (name.empty()) {
     valueIDs[value] = nextValueID++;
@@ -1703,12 +1709,13 @@ void SSANameState::setValueName(Value value, StringRef name) {
   }
 
   valueIDs[value] = NameSentinel;
-  valueNames[value] = uniqueValueName(name);
+  valueNames[value] = uniqueValueName(name, allowNumeric);
 }
 
-StringRef SSANameState::uniqueValueName(StringRef name) {
+StringRef SSANameState::uniqueValueName(StringRef name, bool allowNumeric) {
   SmallString<16> tmpBuffer;
-  name = sanitizeIdentifier(name, tmpBuffer);
+  name = sanitizeIdentifier(name, tmpBuffer, /*allowedPunctChars=*/"$._-",
+                            /*allowTrailingDigit=*/true, allowNumeric);
 
   // Check to see if this name is already unique.
   if (!usedNames.count(name)) {
>From b257ab53e508e5fb00e16a1c73e7fceeea4d0923 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at gibsonic.org>
Date: Mon, 29 Jan 2024 13:03:04 +0000
Subject: [PATCH 11/11] Added region arg support & ambiguous name test
---
 mlir/lib/AsmParser/Parser.cpp              |  28 ++++-
 mlir/lib/IR/AsmPrinter.cpp                 | 122 ++++++++++++---------
 mlir/test/IR/print-retain-identifiers.mlir |  27 ++++-
 3 files changed, 115 insertions(+), 62 deletions(-)
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 247e99e61c2c01b..6f0c5fa30ffa560 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -611,8 +611,9 @@ class OperationParser : public Parser {
   /// an object of type 'OperationName'. Otherwise, failure is returned.
   FailureOr<OperationName> parseCustomOperationName();
 
-  /// Store the SSA names for the current operation as attrs for debug purposes.
-  void storeSSANames(Operation *&op, ArrayRef<ResultRecord> resultIDs);
+  /// Store the identifier names for the current operation as attrs for debug
+  /// purposes.
+  void storeIdentifierNames(Operation *&op, ArrayRef<ResultRecord> resultIDs);
   DenseMap<Value, StringRef> argNames;
 
   //===--------------------------------------------------------------------===//
@@ -1273,8 +1274,8 @@ OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) {
 }
 
 /// Store the SSA names for the current operation as attrs for debug purposes.
-void OperationParser::storeSSANames(Operation *&op,
-                                    ArrayRef<ResultRecord> resultIDs) {
+void OperationParser::storeIdentifierNames(Operation *&op,
+                                           ArrayRef<ResultRecord> resultIDs) {
 
   // Store the name(s) of the result(s) of this operation.
   if (op->getNumResults() > 0) {
@@ -1322,6 +1323,18 @@ void OperationParser::storeSSANames(Operation *&op,
       }
     }
   }
+
+  // Store names of region arguments (e.g., for FuncOps)
+  if (op->getNumRegions() > 0 && op->getRegion(0).getNumArguments() > 0) {
+    llvm::SmallVector<llvm::StringRef, 1> regionArgNames;
+    for (BlockArgument arg : op->getRegion(0).getArguments()) {
+      auto it = argNames.find(arg);
+      if (it != argNames.end()) {
+        regionArgNames.push_back(it->second.drop_front(1));
+      }
+    }
+    op->setAttr("mlir.regionArgNames", builder.getStrArrayAttr(regionArgNames));
+  }
 }
 
 namespace {
@@ -2093,9 +2106,9 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
   // Otherwise, create the operation and try to parse a location for it.
   Operation *op = opBuilder.create(opState);
 
-  // If enabled, store the SSA name(s) for the operation
+  // If enabled, store the original identifier name(s) for the operation
   if (state.config.shouldRetainIdentifierNames())
-    storeSSANames(op, resultIDs);
+    storeIdentifierNames(op, resultIDs);
 
   if (parseTrailingLocationSpecifier(op))
     return nullptr;
@@ -2246,6 +2259,9 @@ ParseResult OperationParser::parseRegionBody(Region ®ion, SMLoc startLoc,
       if (state.asmState)
         state.asmState->addDefinition(arg, argInfo.location);
 
+      if (state.config.shouldRetainIdentifierNames())
+        argNames.insert({arg, argInfo.name});
+
       // Record the definition for this argument.
       if (addDefinition(argInfo, arg))
         return failure();
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index e3bbe75441d588a..51f4bb66a8414ce 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1303,7 +1303,9 @@ class SSANameState {
   /// Set the original identifier names if available. Used in debugging with
   /// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
   void setRetainedIdentifierNames(Operation &op,
-                                  SmallVector<int, 2> &resultGroups);
+                                  SmallVector<int, 2> &resultGroups,
+                                  bool hasRegion = false);
+  void setRetainedIdentifierNames(Region ®ion);
 
   /// This is the value ID for each SSA value. If this returns NameSentinel,
   /// then the valueID has an entry in valueNames.
@@ -1492,6 +1494,9 @@ void SSANameState::numberValuesInRegion(Region ®ion) {
     setValueName(arg, name);
   };
 
+  // Use manually specified region arg names if available
+  setRetainedIdentifierNames(region);
+
   if (!printerFlags.shouldPrintGenericOpForm()) {
     if (Operation *op = region.getParentOp()) {
       if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
@@ -1603,64 +1608,75 @@ void SSANameState::numberValuesInOp(Operation &op) {
   }
 }
 
-void SSANameState::setRetainedIdentifierNames(
-    Operation &op, SmallVector<int, 2> &resultGroups) {
-  // Get the original names for the results if available
-  if (ArrayAttr resultNamesAttr =
-          op.getAttrOfType<ArrayAttr>("mlir.resultNames")) {
-    auto resultNames = resultNamesAttr.getValue();
-    auto results = op.getResults();
-    // Conservative in the case that the #results has changed
-    for (size_t i = 0; i < results.size() && i < resultNames.size(); ++i) {
-      auto resultName = resultNames[i].cast<StringAttr>().strref();
-      if (!resultName.empty()) {
-        if (!usedNames.count(resultName))
-          setValueName(results[i], resultName, /*allowNumeric=*/true);
-        // If a result has a name, it is the start of a result group.
-        if (i > 0)
-          resultGroups.push_back(i);
-      }
-    }
-    op.removeDiscardableAttr("mlir.resultNames");
-  }
-
-  // Get the original name for the op args if available
-  if (ArrayAttr opArgNamesAttr =
-          op.getAttrOfType<ArrayAttr>("mlir.opArgNames")) {
-    auto opArgNames = opArgNamesAttr.getValue();
-    auto opArgs = op.getOperands();
-    // Conservative in the case that the #operands has changed
-    for (size_t i = 0; i < opArgs.size() && i < opArgNames.size(); ++i) {
-      auto opArgName = opArgNames[i].cast<StringAttr>().strref();
-      if (!usedNames.count(opArgName))
-        setValueName(opArgs[i], opArgName, /*allowNumeric=*/true);
+void SSANameState::setRetainedIdentifierNames(Operation &op,
+                                              SmallVector<int, 2> &resultGroups,
+                                              bool hasRegion) {
+
+  // Lambda which fetches the list of relevant attributes (e.g.,
+  // mlir.resultNames) and associates them with the relevant values
+  auto handleNamedAttributes =
+      [this](Operation &op, const Twine &attrName, auto getValuesFunc,
+             std::optional<std::function<void(int)>> customAction =
+                 std::nullopt) {
+        if (ArrayAttr namesAttr = op.getAttrOfType<ArrayAttr>(attrName.str())) {
+          auto names = namesAttr.getValue();
+          auto values = getValuesFunc();
+          // Conservative in case the number of values has changed
+          for (size_t i = 0; i < values.size() && i < names.size(); ++i) {
+            auto name = names[i].cast<StringAttr>().strref();
+            if (!name.empty()) {
+              if (!this->usedNames.count(name))
+                this->setValueName(values[i], name, true);
+              if (customAction.has_value())
+                customAction.value()(i);
+            }
+          }
+          op.removeDiscardableAttr(attrName.str());
+        }
+      };
+
+  if (hasRegion) {
+    // Get the original name(s) for the region arg(s) if available (e.g., for
+    // FuncOp args).  Requires hasRegion flag to ensure scoping is correct
+    if (hasRegion && op.getNumRegions() > 0 &&
+        op.getRegion(0).getNumArguments() > 0) {
+      handleNamedAttributes(op, "mlir.regionArgNames",
+                            [&]() { return op.getRegion(0).getArguments(); });
     }
-    op.removeDiscardableAttr("mlir.opArgNames");
-  }
-
-  // Get the original name for the block if available
-  if (StringAttr blockNameAttr =
-          op.getAttrOfType<StringAttr>("mlir.blockName")) {
-    blockNames[op.getBlock()] = {-1, blockNameAttr.strref()};
-    op.removeDiscardableAttr("mlir.blockName");
-  }
-
-  // Get the original name for the block args if available
-  if (ArrayAttr blockArgNamesAttr =
-          op.getAttrOfType<ArrayAttr>("mlir.blockArgNames")) {
-    auto blockArgNames = blockArgNamesAttr.getValue();
-    auto blockArgs = op.getBlock()->getArguments();
-    // Conservative in the case that the #args has changed
-    for (size_t i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) {
-      auto blockArgName = blockArgNames[i].cast<StringAttr>().strref();
-      if (!usedNames.count(blockArgName))
-        setValueName(blockArgs[i], blockArgName, /*allowNumeric=*/true);
+  } else {
+    // Get the original names for the results if available
+    handleNamedAttributes(
+        op, "mlir.resultNames", [&]() { return op.getResults(); },
+        [&resultGroups](int i) { /*handles result groups*/
+                                 if (i > 0)
+                                   resultGroups.push_back(i);
+        });
+
+    // Get the original name for the op args if available
+    handleNamedAttributes(op, "mlir.opArgNames",
+                          [&]() { return op.getOperands(); });
+
+    // Get the original name for the block if available
+    if (StringAttr blockNameAttr =
+            op.getAttrOfType<StringAttr>("mlir.blockName")) {
+      blockNames[op.getBlock()] = {-1, blockNameAttr.strref()};
+      op.removeDiscardableAttr("mlir.blockName");
     }
-    op.removeDiscardableAttr("mlir.blockArgNames");
+
+    // Get the original name(s) for the block arg(s) if available
+    handleNamedAttributes(op, "mlir.blockArgNames",
+                          [&]() { return op.getBlock()->getArguments(); });
   }
   return;
 }
 
+void SSANameState::setRetainedIdentifierNames(Region ®ion) {
+  if (Operation *op = region.getParentOp()) {
+    SmallVector<int, 2> resultGroups;
+    setRetainedIdentifierNames(*op, resultGroups, true);
+  }
+}
+
 void SSANameState::getResultIDAndNumber(
     OpResult result, Value &lookupValue,
     std::optional<int> &lookupResultNo) const {
diff --git a/mlir/test/IR/print-retain-identifiers.mlir b/mlir/test/IR/print-retain-identifiers.mlir
index b3e4f075b3936a8..65aa0507d205fa2 100644
--- a/mlir/test/IR/print-retain-identifiers.mlir
+++ b/mlir/test/IR/print-retain-identifiers.mlir
@@ -5,8 +5,8 @@
 // Test SSA results (with single return values)
 //===----------------------------------------------------------------------===//
 
-// CHECK: func.func @add_one(%my_input: f64, %arg1: f64) -> f64 {
-func.func @add_one(%my_input: f64, %arg1: f64) -> f64 {
+// CHECK: func.func @add_one(%my_input: f64) -> f64 {
+func.func @add_one(%my_input: f64) -> f64 {
   // CHECK: %my_constant = arith.constant 1.000000e+00 : f64
   %my_constant = arith.constant 1.000000e+00 : f64
   // CHECK: %my_output = arith.addf %my_input, %my_constant : f64
@@ -71,7 +71,7 @@ func.func @select_min_max(%a: f64, %b: f64) -> (f64, f64) {
 
 // -----
 
-////===----------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
 // Test multiple return values, with a grouped value tuple
 //===----------------------------------------------------------------------===//
 
@@ -90,3 +90,24 @@ func.func @select_max(%a: f64, %b: f64, %c: f64, %d: f64) -> (f64, f64, f64, f64
 }
 
 // -----
+
+//===----------------------------------------------------------------------===//
+// Test identifiers that may clash with generated names (e.g., %cst, %1, %arg0)
+//===----------------------------------------------------------------------===//
+
+// CHECK: func.func @clash(%arg1: f64, %arg0: f64, %arg2: f64) -> f64 {
+func.func @clash(%arg1: f64, %arg0: f64, %arg2: f64) -> f64 {
+  %my_constant = arith.constant 1.000000e+00 : f64
+  // CHECK: %cst = arith.constant 2.000000e+00 : f64
+  %cst = arith.constant 2.000000e+00 : f64
+  // CHECK: %cst_1 = arith.constant 3.000000e+00 : f64
+  %cst_1 = arith.constant 3.000000e+00 : f64
+  // CHECK: %1 = arith.addf %arg1, %cst : f64
+  %1 = arith.addf %arg1, %cst : f64
+  // CHECK: %0 = arith.addf %arg1, %cst_1 : f64
+  %0 = arith.addf %arg1, %cst_1 : f64
+  // CHECK: return %1 : f64
+  return %1 : f64
+}
+
+// -----
    
    
More information about the Mlir-commits
mailing list