[Mlir-commits] [mlir] [mlir] Retain original identifier names for debugging v2 (PR #119944)

Maksim Levental llvmlistbot at llvm.org
Sat Dec 14 00:01:54 PST 2024


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/119944

>From 2e7563bf2cd9c41db4e40e54c020bbb38c2371bf Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Fri, 13 Dec 2024 23:39:19 -0500
Subject: [PATCH] [mlir] retain identifier names

---
 mlir/include/mlir/IR/AsmState.h               |   9 +-
 mlir/include/mlir/IR/Value.h                  |  18 +++-
 .../include/mlir/Tools/mlir-opt/MlirOptMain.h |   9 ++
 mlir/lib/AsmParser/Parser.cpp                 | 100 ++++++++++++++----
 mlir/lib/IR/AsmPrinter.cpp                    |  34 ++++--
 mlir/lib/IR/Operation.cpp                     |   5 +-
 6 files changed, 138 insertions(+), 37 deletions(-)

diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index edbd3bb6fc15db..e13a9324b1f669 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -471,8 +471,12 @@ class ParserConfig {
   /// `fallbackResourceMap` is an optional fallback handler that can be used to
   /// parse external resources not explicitly handled by another parser.
+  /// `retainIdentifierNames` indicates if source SSA names should be kept as
+  /// NameLoc locations on operation results and block arguments.
   ParserConfig(MLIRContext *context, bool verifyAfterParse = true,
-               FallbackAsmResourceMap *fallbackResourceMap = nullptr)
+               FallbackAsmResourceMap *fallbackResourceMap = nullptr,
+               bool retainIdentifierNames = false)
       : context(context), verifyAfterParse(verifyAfterParse),
+        retainIdentifierNames(retainIdentifierNames),
         fallbackResourceMap(fallbackResourceMap) {
     assert(context && "expected valid MLIR context");
   }
@@ -483,6 +487,10 @@ class ParserConfig {
   /// Returns if the parser should verify the IR after parsing.
   bool shouldVerifyAfterParse() const { return verifyAfterParse; }
 
+  /// Returns if the parser should retain the SSA names written in the source
+  /// as NameLoc locations on operation results and block arguments.
+  bool shouldRetainIdentifierNames() const { return retainIdentifierNames; }
+
   /// Returns the parsing configurations associated to the bytecode read.
   BytecodeReaderConfig &getBytecodeReaderConfig() const {
     return const_cast<BytecodeReaderConfig &>(bytecodeReaderConfig);
@@ -520,6 +528,7 @@ class ParserConfig {
 private:
   MLIRContext *context;
   bool verifyAfterParse;
+  bool retainIdentifierNames;
   DenseMap<StringRef, std::unique_ptr<AsmResourceParser>> resourceParsers;
   FallbackAsmResourceMap *fallbackResourceMap;
   BytecodeReaderConfig bytecodeReaderConfig;
diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index a7344c64e6730d..d9335fbe7a5a7e 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -367,7 +367,8 @@ namespace detail {
 /// This class provides the implementation for an operation result.
 class alignas(8) OpResultImpl : public ValueImpl {
 public:
-  using ValueImpl::ValueImpl;
+  OpResultImpl(Type type, Kind kind, Location loc)
+      : ValueImpl(type, kind), loc(loc) {}
 
   static bool classof(const ValueImpl *value) {
     return value->getKind() != ValueImpl::Kind::BlockArgument;
@@ -390,14 +391,17 @@ class alignas(8) OpResultImpl : public ValueImpl {
   static unsigned getMaxInlineResults() {
     return static_cast<unsigned>(Kind::OutOfLineOpResult);
   }
+
+  /// The source location of this result.
+  Location loc;
 };
 
 /// This class provides the implementation for an operation result whose index
 /// can be represented "inline" in the underlying ValueImpl.
 struct InlineOpResult : public OpResultImpl {
 public:
-  InlineOpResult(Type type, unsigned resultNo)
-      : OpResultImpl(type, static_cast<ValueImpl::Kind>(resultNo)) {
+  InlineOpResult(Type type, unsigned resultNo, Location loc)
+      : OpResultImpl(type, static_cast<ValueImpl::Kind>(resultNo), loc) {
     assert(resultNo < getMaxInlineResults());
   }
 
@@ -413,8 +417,8 @@ struct InlineOpResult : public OpResultImpl {
 /// cannot be represented "inline", and thus requires an additional index field.
 class OutOfLineOpResult : public OpResultImpl {
 public:
-  OutOfLineOpResult(Type type, uint64_t outOfLineIndex)
-      : OpResultImpl(type, Kind::OutOfLineOpResult),
+  OutOfLineOpResult(Type type, uint64_t outOfLineIndex, Location loc)
+      : OpResultImpl(type, Kind::OutOfLineOpResult, loc),
         outOfLineIndex(outOfLineIndex) {}
 
   static bool classof(const OpResultImpl *value) {
@@ -468,6 +472,16 @@ class OpResult : public Value {
   /// Returns the number of this result.
   unsigned getResultNumber() const { return getImpl()->getResultNumber(); }
 
+  /// Return the location recorded for this result. Results start out with the
+  /// location of their owning operation; the parser may replace it with a
+  /// NameLoc that carries the identifier written in the source.
+  /// NOTE(review): this hides rather than overrides `Value::getLoc()`, so
+  /// queries made through a plain `Value` do not observe this location.
+  Location getLoc() const { return getImpl()->loc; }
+
+  /// Set the location recorded for this result.
+  void setLoc(Location loc) { getImpl()->loc = loc; }
+
 private:
   /// Get a raw pointer to the internal implementation.
   detail::OpResultImpl *getImpl() const {
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index 160585e7da5486..d8cc11f13c815a 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -198,6 +198,14 @@ class MlirOptMainConfig {
   }
   bool shouldVerifyPasses() const { return verifyPassesFlag; }
 
+  /// Set whether to retain SSA identifier names (e.g. `%my_var`) when parsing
+  /// the input, so they can be reused when printing.
+  MlirOptMainConfig &retainIdentifierNames(bool retain) {
+    retainIdentifierNamesFlag = retain;
+    return *this;
+  }
+  bool shouldRetainIdentifierNames() const { return retainIdentifierNamesFlag; }
+
   /// Set whether to run the verifier on parsing.
   MlirOptMainConfig &verifyOnParsing(bool verify) {
     disableVerifierOnParsingFlag = !verify;
@@ -284,6 +292,9 @@ class MlirOptMainConfig {
   /// Run the verifier after each transformation pass.
   bool verifyPassesFlag = true;
 
+  /// Retain identifier names in the output (e.g., `%my_var` instead of `%0`).
+  bool retainIdentifierNamesFlag = false;
+
   /// Disable the verifier on parsing.
   bool disableVerifierOnParsingFlag = false;
 
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index e3db248164672c..61af172df154ae 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -543,6 +543,10 @@ Type Parser::codeCompleteDialectSymbol(const llvm::StringMap<Type> &aliases) {
 //===----------------------------------------------------------------------===//
 
 namespace {
+/// This is the structure of a result specifier in the assembly syntax,
+/// including the name, number of results, and location.
+using ResultRecord = std::tuple<StringRef, unsigned, SMLoc>;
+
 /// This class provides support for parsing operations and regions of
 /// operations.
 class OperationParser : public Parser {
@@ -618,7 +622,11 @@ class OperationParser : public Parser {
   ParseResult parseSuccessors(SmallVectorImpl<Block *> &destinations);
 
   /// Parse an operation instance that is in the generic form.
-  Operation *parseGenericOperation();
+  /// `resultIDs`, if provided, describes the `%name =` specifiers preceding
+  /// the operation; they supply result names and locations when identifier
+  /// names are retained.
+  Operation *parseGenericOperation(
+      std::optional<ArrayRef<ResultRecord>> resultIDs = std::nullopt);
 
   /// Parse different components, viz., use-info of operand(s), successor(s),
   /// region(s), attribute(s) and function-type, of the generic form of an
@@ -659,10 +667,6 @@ class OperationParser : public Parser {
   /// token is actually an alias, which means it must not contain a dot.
   ParseResult parseLocationAlias(LocationAttr &loc);
 
-  /// This is the structure of a result specifier in the assembly syntax,
-  /// including the name, number of results, and location.
-  using ResultRecord = std::tuple<StringRef, unsigned, SMLoc>;
-
   /// Parse an operation instance that is in the op-defined custom form.
   /// resultInfo specifies information about the "%name =" specifiers.
   Operation *parseCustomOperation(ArrayRef<ResultRecord> resultIDs);
@@ -1238,7 +1242,7 @@ ParseResult OperationParser::parseOperation() {
   if (nameTok.is(Token::bare_identifier) || nameTok.isKeyword())
     op = parseCustomOperation(resultIDs);
   else if (nameTok.is(Token::string))
-    op = parseGenericOperation();
+    op = parseGenericOperation(resultIDs);
   else if (nameTok.isCodeCompletionFor(Token::string))
     return codeCompleteStringDialectOrOperationName(nameTok.getStringValue());
   else if (nameTok.isCodeCompletion())
@@ -1344,6 +1345,45 @@ struct CleanupOpStateRegions {
   }
   OperationState &state;
 };
+
+/// Find the result specifier in `resultIDs` that defines result `resultNo`.
+/// Returns that specifier and the position of the result within it, or
+/// {nullptr, ~0U} if no specifier covers `resultNo`.
+std::pair<const ResultRecord *, unsigned>
+findResultRecord(ArrayRef<ResultRecord> resultIDs, unsigned resultNo) {
+  for (const ResultRecord &entry : resultIDs) {
+    unsigned numResultsInEntry = std::get<1>(entry);
+    if (resultNo < numResultsInEntry)
+      return {&entry, resultNo};
+    resultNo -= numResultsInEntry;
+  }
+  return {nullptr, ~0U};
+}
+
+/// Return the name (without the leading `%`) of the result specifier that
+/// defines result `resultNo`, and the position of the result within it. For
+/// `%x, %y:2, %z = foo.op`, result 2 maps to {"y", 1}. Returns {"", ~0U} if
+/// no specifier covers `resultNo`.
+std::pair<StringRef, unsigned> getResultName(ArrayRef<ResultRecord> resultIDs,
+                                             unsigned resultNo) {
+  auto [entry, index] = findResultRecord(resultIDs, resultNo);
+  if (!entry)
+    return {"", ~0U};
+  // Don't pass on the leading %.
+  return {std::get<0>(*entry).drop_front(), index};
+}
+
+/// Return the source location of the result specifier that defines result
+/// `resultNo`, and the position of the result within it. Returns
+/// {SMLoc(), ~0U} if no specifier covers `resultNo`.
+std::pair<SMLoc, unsigned> getResultLoc(ArrayRef<ResultRecord> resultIDs,
+                                        unsigned resultNo) {
+  auto [entry, index] = findResultRecord(resultIDs, resultNo);
+  if (!entry)
+    return {SMLoc{}, ~0U};
+  return {std::get<2>(*entry), index};
+}
+
 } // namespace
 
 ParseResult OperationParser::parseGenericOperationAfterOpName(
@@ -1457,7 +1490,8 @@ ParseResult OperationParser::parseGenericOperationAfterOpName(
   return success();
 }
 
-Operation *OperationParser::parseGenericOperation() {
+Operation *OperationParser::parseGenericOperation(
+    std::optional<ArrayRef<ResultRecord>> maybeResultIDs) {
   // Get location information for the operation.
   auto srcLocation = getEncodedSourceLocation(getToken().getLoc());
 
@@ -1531,6 +1565,22 @@ Operation *OperationParser::parseGenericOperation() {
 
   // Create the operation and try to parse a location for it.
   Operation *op = opBuilder.create(result);
+  // When requested, record each result's `%name` from its result specifier as
+  // a NameLoc wrapping the specifier's location.
+  if (state.config.shouldRetainIdentifierNames() && maybeResultIDs) {
+    for (OpResult opResult : op->getResults()) {
+      unsigned resultNum = opResult.getResultNumber();
+      auto [resName, resIdx] = getResultName(*maybeResultIDs, resultNum);
+      // Results not covered by any `%name =` specifier (e.g. the op was
+      // written without one) have no name or SMLoc; keep their location.
+      if (resIdx == ~0U)
+        continue;
+      Location resultLoc = getEncodedSourceLocation(
+          getResultLoc(*maybeResultIDs, resultNum).first);
+      opResult.setLoc(NameLoc::get(
+          StringAttr::get(state.config.getContext(), resName), resultLoc));
+    }
+  }
   if (parseTrailingLocationSpecifier(op))
     return nullptr;
 
@@ -1571,7 +1616,7 @@ namespace {
 class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
 public:
   CustomOpAsmParser(
-      SMLoc nameLoc, ArrayRef<OperationParser::ResultRecord> resultIDs,
+      SMLoc nameLoc, ArrayRef<ResultRecord> resultIDs,
       function_ref<ParseResult(OpAsmParser &, OperationState &)> parseAssembly,
       bool isIsolatedFromAbove, StringRef opName, OperationParser &parser)
       : AsmParserImpl<OpAsmParser>(nameLoc, parser), resultIDs(resultIDs),
@@ -1634,18 +1679,7 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
   ///    getResultName(3) == {"z", 0 }
   std::pair<StringRef, unsigned>
   getResultName(unsigned resultNo) const override {
-    // Scan for the resultID that contains this result number.
-    for (const auto &entry : resultIDs) {
-      if (resultNo < std::get<1>(entry)) {
-        // Don't pass on the leading %.
-        StringRef name = std::get<0>(entry).drop_front();
-        return {name, resultNo};
-      }
-      resultNo -= std::get<1>(entry);
-    }
-
-    // Invalid result number.
-    return {"", ~0U};
+    return ::getResultName(resultIDs, resultNo);
   }
 
   /// Return the number of declared SSA results.  This returns 4 for the foo.op
@@ -1962,7 +1996,7 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
 
 private:
   /// Information about the result name specifiers.
-  ArrayRef<OperationParser::ResultRecord> resultIDs;
+  ArrayRef<ResultRecord> resultIDs;
 
   /// The abstract information of the operation.
   function_ref<ParseResult(OpAsmParser &, OperationState &)> parseAssembly;
@@ -2093,6 +2127,24 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
 
   // Otherwise, create the operation and try to parse a location for it.
   Operation *op = opBuilder.create(opState);
+
+  // When requested, record each result's `%name` from its result specifier as
+  // a NameLoc wrapping the specifier's location.
+  if (state.config.shouldRetainIdentifierNames()) {
+    for (OpResult opResult : op->getResults()) {
+      unsigned resultNum = opResult.getResultNumber();
+      auto [resName, resIdx] = opAsmParser.getResultName(resultNum);
+      // Results not covered by any `%name =` specifier (e.g. the op was
+      // written without one) have no name or SMLoc; keep their location.
+      if (resIdx == ~0U)
+        continue;
+      Location resultLoc =
+          getEncodedSourceLocation(getResultLoc(resultIDs, resultNum).first);
+      opResult.setLoc(NameLoc::get(
+          StringAttr::get(state.config.getContext(), resName), resultLoc));
+    }
+  }
+
   if (parseTrailingLocationSpecifier(op))
     return nullptr;
 
@@ -2235,6 +2281,12 @@ ParseResult OperationParser::parseRegionBody(Region &region, SMLoc startLoc,
       Location loc = entryArg.sourceLoc.has_value()
                          ? *entryArg.sourceLoc
                          : getEncodedSourceLocation(argInfo.location);
+      // Retain the source name, minus its leading `%`, as a NameLoc.
+      if (state.config.shouldRetainIdentifierNames()) {
+        loc = NameLoc::get(StringAttr::get(state.config.getContext(),
+                                           entryArg.ssaName.name.drop_front(1)),
+                           loc);
+      }
       BlockArgument arg = block->addArgument(entryArg.type, loc);
 
       // Add a definition of this arg to the assembly state if provided.
@@ -2414,7 +2466,13 @@ ParseResult OperationParser::parseOptionalBlockArgList(Block *owner) {
             if (arg.getType() != type)
               return emitError("argument and block argument type mismatch");
           } else {
-            auto loc = getEncodedSourceLocation(useInfo.location);
+            Location loc = getEncodedSourceLocation(useInfo.location);
+            // Retain the source name, minus its leading `%`, as a NameLoc.
+            if (state.config.shouldRetainIdentifierNames()) {
+              loc = NameLoc::get(StringAttr::get(state.config.getContext(),
+                                                 useInfo.name.drop_front(1)),
+                                 loc);
+            }
             arg = owner->addArgument(type, loc);
           }
 
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 61b90bc9b0a7bb..479409f34f6316 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1511,7 +1511,11 @@ void SSANameState::numberValuesInRegion(Region &region) {
     assert(!valueIDs.count(arg) && "arg numbered multiple times");
     assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == &region &&
            "arg not defined in current region");
-    setValueName(arg, name);
+    // Prefer a name carried by the argument's NameLoc, if any.
+    if (auto nameLoc = dyn_cast<NameLoc>(arg.getLoc()))
+      setValueName(arg, nameLoc.getName());
+    else
+      setValueName(arg, name);
   };
 
   if (!printerFlags.shouldPrintGenericOpForm()) {
@@ -1553,7 +1557,11 @@ void SSANameState::numberValuesInBlock(Block &block) {
       specialNameBuffer.resize(strlen("arg"));
       specialName << nextArgumentID++;
     }
-    setValueName(arg, specialName.str());
+    // Prefer a name carried by the argument's NameLoc, if any.
+    if (auto nameLoc = dyn_cast<NameLoc>(arg.getLoc()))
+      setValueName(arg, nameLoc.getName());
+    else
+      setValueName(arg, specialName.str());
   }
 
   // Number the operations in this block.
@@ -1567,7 +1573,12 @@ void SSANameState::numberValuesInOp(Operation &op) {
   auto setResultNameFn = [&](Value result, StringRef name) {
     assert(!valueIDs.count(result) && "result numbered multiple times");
     assert(result.getDefiningOp() == &op && "result not defined by 'op'");
-    setValueName(result, name);
+    // Prefer a name carried by the result's own NameLoc. Query it through
+    // OpResult, since Value::getLoc() yields the defining op's location.
+    if (auto nameLoc = dyn_cast<NameLoc>(cast<OpResult>(result).getLoc()))
+      setValueName(result, nameLoc.getName());
+    else
+      setValueName(result, name);
 
     // Record the result number for groups not anchored at 0.
     if (int resultNo = llvm::cast<OpResult>(result).getResultNumber())
@@ -1605,11 +1614,21 @@ void SSANameState::numberValuesInOp(Operation &op) {
     }
     return;
   }
   Value resultBegin = op.getResult(0);
 
+  // Name results whose location is a NameLoc (e.g. an identifier name
+  // retained by the parser). Ops implementing OpAsmOpInterface are skipped;
+  // presumably they name their results through the interface.
+  if (!isa<OpAsmOpInterface>(&op)) {
+    for (OpResult opResult : op.getOpResults()) {
+      if (auto nameLoc = dyn_cast<NameLoc>(opResult.getLoc()))
+        setResultNameFn(opResult, nameLoc.getName());
+    }
+  }
+
   // If the first result wasn't numbered, give it a default number.
   if (valueIDs.try_emplace(resultBegin, nextValueID).second)
     ++nextValueID;
 
   // If this operation has multiple result groups, mark it.
   if (resultGroups.size() != 1) {
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index fe0fee0f8db2ce..8291faee59d240 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -125,10 +125,15 @@ Operation *Operation::create(Location location, OperationName name,
   // Initialize the results.
+  // Each result starts out with the operation's location; the parser may
+  // later replace it with a NameLoc. NOTE(review): an op created with a
+  // NameLoc location thus passes that name to all of its results, which the
+  // AsmPrinter then uses as their SSA names; confirm this is intended.
   auto resultTypeIt = resultTypes.begin();
   for (unsigned i = 0; i < numInlineResults; ++i, ++resultTypeIt)
-    new (op->getInlineOpResult(i)) detail::InlineOpResult(*resultTypeIt, i);
+    new (op->getInlineOpResult(i))
+        detail::InlineOpResult(*resultTypeIt, i, location);
   for (unsigned i = 0; i < numTrailingResults; ++i, ++resultTypeIt) {
     new (op->getOutOfLineOpResult(i))
-        detail::OutOfLineOpResult(*resultTypeIt, i);
+        detail::OutOfLineOpResult(*resultTypeIt, i, location);
   }
 
   // Initialize the regions.



More information about the Mlir-commits mailing list