[Mlir-commits] [mlir] [MLIR][LLVM] Recursion importer handle repeated self-references (PR #87295)

Billy Zhu llvmlistbot at llvm.org
Fri Apr 5 11:31:27 PDT 2024


https://github.com/zyx-billy updated https://github.com/llvm/llvm-project/pull/87295

>From 01fd161b6e26f36261ee1076ec8b625e233aecbe Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Mon, 1 Apr 2024 14:56:46 -0700
Subject: [PATCH 01/12] create translation helper

---
 mlir/lib/Target/LLVMIR/DebugImporter.cpp     | 331 +++++++++++++++----
 mlir/lib/Target/LLVMIR/DebugImporter.h       |  20 +-
 mlir/test/Target/LLVMIR/Import/debug-info.ll |  41 +++
 3 files changed, 320 insertions(+), 72 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
index 779ad26fc847e6..3de62d960ffc44 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
@@ -25,6 +25,269 @@ using namespace mlir;
 using namespace mlir::LLVM;
 using namespace mlir::LLVM::detail;
 
+/// Get the `getRecSelf` constructor for the translated type of `node` if its
+/// translated DITypeAttr supports recursion. Otherwise, returns nullptr.
+static function_ref<DIRecursiveTypeAttrInterface(DistinctAttr)>
+getRecSelfConstructor(llvm::DINode *node) {
+  using CtorType = function_ref<DIRecursiveTypeAttrInterface(DistinctAttr)>;
+  return TypeSwitch<llvm::DINode *, CtorType>(node)
+      .Case([&](llvm::DICompositeType *) {
+        return CtorType(DICompositeTypeAttr::getRecSelf);
+      })
+      .Default(CtorType());
+}
+
+/// Translation helper for recursive DINodes.
+/// Works alongside a stack-based DINode translator (the "main translator") for
+/// gracefully handling DINodes that are recursive.
+///
+/// Usage:
+/// - Before translating a node, call `tryPrune` to see if the pruner can
+///   preempt this translation. If this is a node that the pruner already knows
+///   how to handle, it will return the translated DINodeAttr.
+/// - After a node is successfully translated by the main translator, call
+///   `finalizeTranslation` to save the translated result with the pruner, and
+///   give it a chance to further modify the result.
+/// - Regardless of success or failure by the main translator, always call
+///   `finally` at the end of translating a node. This is necessary to keep the
+///   internal book-keeping in sync.
+///
+/// This helper maintains an internal cache so that no recursive type will
+/// be translated more than once by the main translator.
+/// This internal cache is different from the cache maintained by the main
+/// translator because it may store nodes that are not self-contained (i.e.
+/// contain unbounded recursive self-references).
+class mlir::LLVM::detail::RecursionPruner {
+public:
+  RecursionPruner(MLIRContext *context) : context(context) {}
+
+  /// If this node was previously cached, returns the cached result.
+  /// If this node is a recursive instance that was previously seen, returns a
+  /// self-reference.
+  /// Otherwise, returns null attr.
+  DINodeAttr tryPrune(llvm::DINode *node) {
+    // Lookup the cache first.
+    auto [result, unboundSelfRefs] = lookup(node);
+    if (result) {
+      // Need to inject unbound self-refs into the previous layer.
+      if (!unboundSelfRefs.empty())
+        translationStack.back().second.unboundSelfRefs.insert(
+            unboundSelfRefs.begin(), unboundSelfRefs.end());
+      return result;
+    }
+
+    // If the node type is capable of being recursive, check if it's seen
+    // before.
+    auto recSelfCtor = getRecSelfConstructor(node);
+    if (recSelfCtor) {
+      // If a cyclic dependency is detected since the same node is being
+      // traversed twice, emit a recursive self type, and mark the duplicate
+      // node on the translationStack so it can emit a recursive decl type.
+      auto [iter, inserted] = translationStack.try_emplace(node);
+      if (!inserted) {
+        // The original node may have already been assigned a recursive ID from
+        // a different self-reference. Use that if possible.
+        DIRecursiveTypeAttrInterface recSelf = iter->second.recSelf;
+        if (!recSelf) {
+          DistinctAttr recId = DistinctAttr::create(UnitAttr::get(context));
+          recSelf = recSelfCtor(recId);
+          iter->second.recSelf = recSelf;
+        }
+        // Inject the self-ref into the previous layer.
+        translationStack.back().second.unboundSelfRefs.insert(recSelf);
+        return cast<DINodeAttr>(recSelf);
+      }
+    }
+    return nullptr;
+  }
+
+  /// Register the translated result of `node`. Returns the finalized result
+  /// (with recId if recursive) and whether the result is self-contained
+  /// (i.e. contains no unbound self-refs).
+  std::pair<DINodeAttr, bool> finalizeTranslation(llvm::DINode *node,
+                                                  DINodeAttr result) {
+    // If `node` is not a potentially recursive type, it will not be on the
+    // translation stack. Nothing to set in this case.
+    if (translationStack.empty())
+      return {result, true};
+    if (translationStack.back().first != node)
+      return {result, translationStack.back().second.unboundSelfRefs.empty()};
+
+    TranslationState &state = translationStack.back().second;
+
+    // If this node is actually recursive, set the recId onto `result`.
+    if (DIRecursiveTypeAttrInterface recSelf = state.recSelf) {
+      auto recType = cast<DIRecursiveTypeAttrInterface>(result);
+      result = cast<DINodeAttr>(recType.withRecId(recSelf.getRecId()));
+
+      // Remove this recSelf from the set of unbound selfRefs.
+      state.unboundSelfRefs.erase(recSelf);
+
+      // Insert the newly resolved recursive type into the cache entries that
+      // rely on it.
+      // Only need to look at the caches at this level.
+      uint64_t numRemaining = state.cacheSize;
+      for (auto &cacheEntry : llvm::make_range(cache.rbegin(), cache.rend())) {
+        if (numRemaining == 0)
+          break;
+
+        if (auto refIter = cacheEntry.second.pendingReplacements.find(recSelf);
+            refIter != cacheEntry.second.pendingReplacements.end()) {
+          refIter->second = result;
+        }
+
+        --numRemaining;
+      }
+    }
+
+    // Insert the current result into the cache.
+    state.cacheSize++;
+    auto [iter, inserted] = cache.try_emplace(node);
+    assert(inserted && "invalid state: caching the same DINode twice");
+    iter->second.attr = result;
+
+    // If this node had any unbound self-refs free when it is registered into
+    // the cache, set up replacement placeholders: This result will need these
+    // unbound self-refs to be replaced before being used.
+    for (DIRecursiveTypeAttrInterface selfRef : state.unboundSelfRefs)
+      iter->second.pendingReplacements.try_emplace(selfRef, nullptr);
+
+    return {result, state.unboundSelfRefs.empty()};
+  }
+
+  /// Pop off a frame from the translation stack after a node is done being
+  /// translated.
+  void finally(llvm::DINode *node) {
+    // If `node` is not a potentially recursive type, it will not be on the
+    // translation stack. Nothing to handle in this case.
+    if (translationStack.empty() || translationStack.back().first != node)
+      return;
+
+    // At the end of the stack, all unbound self-refs must be resolved already,
+    // and the entire cache should be accounted for.
+    TranslationState &currLayerState = translationStack.back().second;
+    if (translationStack.size() == 1) {
+      assert(currLayerState.unboundSelfRefs.empty() &&
+             "internal error: unbound recursive self reference at top level.");
+      assert(currLayerState.cacheSize == cache.size() &&
+             "internal error: inconsistent cache size");
+      translationStack.pop_back();
+      cache.clear();
+      return;
+    }
+
+    // Copy unboundSelfRefs down to the previous level.
+    TranslationState &nextLayerState = (++translationStack.rbegin())->second;
+    nextLayerState.unboundSelfRefs.insert(
+        currLayerState.unboundSelfRefs.begin(),
+        currLayerState.unboundSelfRefs.end());
+
+    // The current layer cache is now considered part of the lower layer cache.
+    nextLayerState.cacheSize += currLayerState.cacheSize;
+
+    // Finally pop off this layer when all bookkeeping is done.
+    translationStack.pop_back();
+  }
+
+private:
+  /// Returns the cached result (if exists) with all known replacements applied.
+  /// Also returns the set of unbound self-refs that are unresolved in this
+  /// cached result.
+  /// The cache entry will also be updated with the replaced result and with the
+  /// applied replacements removed from the pendingReplacements map.
+  std::pair<DINodeAttr, llvm::DenseSet<DIRecursiveTypeAttrInterface>>
+  lookup(llvm::DINode *node) {
+    auto cacheIter = cache.find(node);
+    if (cacheIter == cache.end())
+      return {};
+
+    CachedTranslation &entry = cacheIter->second;
+
+    if (entry.pendingReplacements.empty())
+      return std::make_pair(entry.attr,
+                            llvm::DenseSet<DIRecursiveTypeAttrInterface>{});
+
+    mlir::AttrTypeReplacer replacer;
+    replacer.addReplacement(
+        [&entry](DIRecursiveTypeAttrInterface attr)
+            -> std::optional<std::pair<Attribute, WalkResult>> {
+          if (auto replacement = entry.pendingReplacements.lookup(attr)) {
+            // A replacement may contain additional unbound self-refs.
+            return std::make_pair(replacement, mlir::WalkResult::advance());
+          }
+          return std::make_pair(attr, mlir::WalkResult::skip());
+        });
+
+    Attribute replacedAttr = replacer.replace(entry.attr);
+    DINodeAttr result = cast<DINodeAttr>(replacedAttr);
+
+    // Update cache entry to save replaced version and remove already-applied
+    // replacements.
+    entry.attr = result;
+    DenseSet<DIRecursiveTypeAttrInterface> unboundRefs;
+    DenseSet<DIRecursiveTypeAttrInterface> boundRefs;
+    for (auto [refSelf, replacement] : entry.pendingReplacements) {
+      if (replacement)
+        boundRefs.insert(refSelf);
+      else
+        unboundRefs.insert(refSelf);
+    }
+
+    for (DIRecursiveTypeAttrInterface ref : boundRefs)
+      entry.pendingReplacements.erase(ref);
+
+    return std::make_pair(result, unboundRefs);
+  }
+
+  MLIRContext *context;
+
+  /// A lazy cached translation that contains the translated attribute as well
+  /// as any unbound self-references that need to be replaced at lookup.
+  struct CachedTranslation {
+    /// The translated attr. May contain unbound self-references for other
+    /// recursive attrs.
+    DINodeAttr attr;
+    /// The pending replacements that need to be run on this `attr` before it
+    /// can be used.
+    /// Each key is a recursive self-ref, and each value is a recursive decl
+    /// that may contain additional unbound self-refs that need to be replaced.
+    /// Each replacement will be applied at most once. Once a replacement is
+    /// applied, the cached `attr` will be updated, and the replacement will be
+    /// removed from this map.
+    DenseMap<DIRecursiveTypeAttrInterface, DINodeAttr> pendingReplacements;
+  };
+  /// A mapping between LLVM debug metadata and the corresponding attribute.
+  llvm::MapVector<llvm::DINode *, CachedTranslation> cache;
+
+  /// Each potentially recursive node will have a TranslationState pushed onto
+  /// the `translationStack` to keep track of whether this node is actually
+  /// recursive (i.e. has self-references inside), and other book-keeping.
+  struct TranslationState {
+    /// The rec-self if this node is indeed a recursive node (i.e. another
+    /// instance of itself is seen while translating it). Null if this node
+    /// has not been seen again deeper in the translation stack.
+    DIRecursiveTypeAttrInterface recSelf;
+    /// The number of cache entries belonging to this layer of the translation
+    /// stack. This corresponds to the last `cacheSize` entries of `cache`.
+    uint64_t cacheSize = 0;
+    /// All the unbound recursive self references in this layer of the
+    /// translation stack.
+    DenseSet<DIRecursiveTypeAttrInterface> unboundSelfRefs;
+  };
+  /// A stack that stores the metadata nodes that are being traversed. The stack
+  /// is used to handle cyclic dependencies during metadata translation.
+  /// Each node is pushed with an empty TranslationState. If it is ever seen
+  /// later when the stack is deeper, the node is recursive, and its
+  /// TranslationState is assigned a recSelf.
+  llvm::MapVector<llvm::DINode *, TranslationState> translationStack;
+};
+
+DebugImporter::DebugImporter(ModuleOp mlirModule)
+    : recursionPruner(new RecursionPruner(mlirModule.getContext())),
+      context(mlirModule.getContext()), mlirModule(mlirModule) {}
+
+DebugImporter::~DebugImporter() {}
+
 Location DebugImporter::translateFuncLocation(llvm::Function *func) {
   llvm::DISubprogram *subprogram = func->getSubprogram();
   if (!subprogram)
@@ -246,42 +509,12 @@ DINodeAttr DebugImporter::translate(llvm::DINode *node) {
   if (DINodeAttr attr = nodeToAttr.lookup(node))
     return attr;
 
-  // If the node type is capable of being recursive, check if it's seen before.
-  auto recSelfCtor = getRecSelfConstructor(node);
-  if (recSelfCtor) {
-    // If a cyclic dependency is detected since the same node is being traversed
-    // twice, emit a recursive self type, and mark the duplicate node on the
-    // translationStack so it can emit a recursive decl type.
-    auto [iter, inserted] = translationStack.try_emplace(node, nullptr);
-    if (!inserted) {
-      // The original node may have already been assigned a recursive ID from
-      // a different self-reference. Use that if possible.
-      DistinctAttr recId = iter->second;
-      if (!recId) {
-        recId = DistinctAttr::create(UnitAttr::get(context));
-        iter->second = recId;
-      }
-      unboundRecursiveSelfRefs.back().insert(recId);
-      return cast<DINodeAttr>(recSelfCtor(recId));
-    }
-  }
-
-  unboundRecursiveSelfRefs.emplace_back();
-
-  auto guard = llvm::make_scope_exit([&]() {
-    if (recSelfCtor)
-      translationStack.pop_back();
+  // Register with the recursive translator. If it is seen before, return the
+  // result immediately.
+  if (DINodeAttr attr = recursionPruner->tryPrune(node))
+    return attr;
 
-    // Copy unboundRecursiveSelfRefs down to the previous level.
-    if (unboundRecursiveSelfRefs.size() == 1)
-      assert(unboundRecursiveSelfRefs.back().empty() &&
-             "internal error: unbound recursive self reference at top level.");
-    else
-      unboundRecursiveSelfRefs[unboundRecursiveSelfRefs.size() - 2].insert(
-          unboundRecursiveSelfRefs.back().begin(),
-          unboundRecursiveSelfRefs.back().end());
-    unboundRecursiveSelfRefs.pop_back();
-  });
+  auto guard = llvm::make_scope_exit([&]() { recursionPruner->finally(node); });
 
   // Convert the debug metadata if possible.
   auto translateNode = [this](llvm::DINode *node) -> DINodeAttr {
@@ -318,20 +551,12 @@ DINodeAttr DebugImporter::translate(llvm::DINode *node) {
     return nullptr;
   };
   if (DINodeAttr attr = translateNode(node)) {
-    // If this node was marked as recursive, set its recId.
-    if (auto recType = dyn_cast<DIRecursiveTypeAttrInterface>(attr)) {
-      if (DistinctAttr recId = translationStack.lookup(node)) {
-        attr = cast<DINodeAttr>(recType.withRecId(recId));
-        // Remove the unbound recursive ID from the set of unbound self
-        // references in the translation stack.
-        unboundRecursiveSelfRefs.back().erase(recId);
-      }
-    }
-
+    auto [result, isSelfContained] =
+        recursionPruner->finalizeTranslation(node, attr);
     // Only cache fully self-contained nodes.
-    if (unboundRecursiveSelfRefs.back().empty())
-      nodeToAttr.try_emplace(node, attr);
-    return attr;
+    if (isSelfContained)
+      nodeToAttr.try_emplace(node, result);
+    return result;
   }
   return nullptr;
 }
@@ -394,13 +619,3 @@ DistinctAttr DebugImporter::getOrCreateDistinctID(llvm::DINode *node) {
     id = DistinctAttr::create(UnitAttr::get(context));
   return id;
 }
-
-function_ref<DIRecursiveTypeAttrInterface(DistinctAttr)>
-DebugImporter::getRecSelfConstructor(llvm::DINode *node) {
-  using CtorType = function_ref<DIRecursiveTypeAttrInterface(DistinctAttr)>;
-  return TypeSwitch<llvm::DINode *, CtorType>(node)
-      .Case([&](llvm::DICompositeType *concreteNode) {
-        return CtorType(DICompositeTypeAttr::getRecSelf);
-      })
-      .Default(CtorType());
-}
diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.h b/mlir/lib/Target/LLVMIR/DebugImporter.h
index bcf628fc4234fb..1344362ff8d749 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.h
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.h
@@ -27,10 +27,12 @@ class LLVMFuncOp;
 
 namespace detail {
 
+class RecursionPruner;
+
 class DebugImporter {
 public:
-  DebugImporter(ModuleOp mlirModule)
-      : context(mlirModule.getContext()), mlirModule(mlirModule) {}
+  DebugImporter(ModuleOp mlirModule);
+  ~DebugImporter();
 
   /// Translates the given LLVM debug location to an MLIR location.
   Location translateLoc(llvm::DILocation *loc);
@@ -86,24 +88,14 @@ class DebugImporter {
   /// for it, or create a new one if not.
   DistinctAttr getOrCreateDistinctID(llvm::DINode *node);
 
-  /// Get the `getRecSelf` constructor for the translated type of `node` if its
-  /// translated DITypeAttr supports recursion. Otherwise, returns nullptr.
-  function_ref<DIRecursiveTypeAttrInterface(DistinctAttr)>
-  getRecSelfConstructor(llvm::DINode *node);
-
   /// A mapping between LLVM debug metadata and the corresponding attribute.
   DenseMap<llvm::DINode *, DINodeAttr> nodeToAttr;
   /// A mapping between distinct LLVM debug metadata nodes and the corresponding
   /// distinct id attribute.
   DenseMap<llvm::DINode *, DistinctAttr> nodeToDistinctAttr;
 
-  /// A stack that stores the metadata nodes that are being traversed. The stack
-  /// is used to detect cyclic dependencies during the metadata translation.
-  /// A node is pushed with a null value. If it is ever seen twice, it is given
-  /// a recursive id attribute, indicating that it is a recursive node.
-  llvm::MapVector<llvm::DINode *, DistinctAttr> translationStack;
-  /// All the unbound recursive self references in the translation stack.
-  SmallVector<DenseSet<DistinctAttr>> unboundRecursiveSelfRefs;
+  // Translation copilot for recursive types.
+  std::unique_ptr<RecursionPruner> recursionPruner;
 
   MLIRContext *context;
   ModuleOp mlirModule;
diff --git a/mlir/test/Target/LLVMIR/Import/debug-info.ll b/mlir/test/Target/LLVMIR/Import/debug-info.ll
index 959a5a1cd97176..b322840d852611 100644
--- a/mlir/test/Target/LLVMIR/Import/debug-info.ll
+++ b/mlir/test/Target/LLVMIR/Import/debug-info.ll
@@ -607,3 +607,44 @@ declare !dbg !1 void @declaration()
 !0 = !{i32 2, !"Debug Info Version", i32 3}
 !1 = !DISubprogram(name: "declaration", scope: !2, file: !2, flags: DIFlagPrototyped, spFlags: 0)
 !2 = !DIFile(filename: "debug-info.ll", directory: "/")
+
+; // -----
+
+; Ensure that repeated occurrences of a recursive subtree do not result in duplicate MLIR entries.
+
+; CHECK-DAG: #[[B1_INNER:.+]] = #llvm.di_derived_type<{{.*}}name = "B:B1", baseType = #[[B_SELF:.+]]>
+; CHECK-DAG: #[[B2_INNER:.+]] = #llvm.di_derived_type<{{.*}}name = "B:B2", baseType = #[[B_SELF]]>
+; CHECK-DAG: #[[B_INNER:.+]] = #llvm.di_composite_type<{{.*}}recId = [[B_RECID:.+]], {{.*}}name = "B", {{.*}}elements = #[[B1_INNER]], #[[B2_INNER]]
+
+; CHECK-DAG: #[[B1_OUTER:.+]] = #llvm.di_derived_type<{{.*}}name = "B:B1", baseType = #[[B_INNER]]>
+; CHECK-DAG: #[[B2_OUTER:.+]] = #llvm.di_derived_type<{{.*}}name = "B:B2", baseType = #[[B_INNER]]>
+; CHECK-DAG: #[[A_OUTER:.+]] = #llvm.di_composite_type<{{.*}}recId = [[A_RECID:.+]], {{.*}}name = "A", {{.*}}elements = #[[B1_OUTER]], #[[B2_OUTER]]
+
+; CHECK-DAG: #[[A_SELF:.+]] = #llvm.di_composite_type<{{.*}}recId = [[A_RECID]]
+; CHECK-DAG: #[[B_SELF:.+]] = #llvm.di_composite_type<{{.*}}recId = [[B_RECID]]
+
+; CHECK: #llvm.di_subprogram<{{.*}}scope = #[[A_OUTER]]
+
+
+define void @class_field(ptr %arg1) !dbg !18 {
+  ret void
+}
+
+declare void @llvm.dbg.value(metadata, metadata, metadata)
+
+!llvm.dbg.cu = !{!1}
+!llvm.module.flags = !{!0}
+!0 = !{i32 2, !"Debug Info Version", i32 3}
+!1 = distinct !DICompileUnit(language: DW_LANG_C, file: !2)
+!2 = !DIFile(filename: "debug-info.ll", directory: "/")
+
+
+!3 = !DICompositeType(tag: DW_TAG_class_type, name: "A", file: !2, line: 42, flags: DIFlagTypePassByReference | DIFlagNonTrivial, elements: !4)
+!4 = !{!7, !8}
+
+!5 = !DICompositeType(tag: DW_TAG_class_type, name: "B", scope: !3, file: !2, line: 42, flags: DIFlagTypePassByReference | DIFlagNonTrivial, elements: !9)
+!7 = !DIDerivedType(tag: DW_TAG_member, name: "B:B1", file: !2, baseType: !5)
+!8 = !DIDerivedType(tag: DW_TAG_member, name: "B:B2", file: !2, baseType: !5)
+!9 = !{!7, !8}
+
+!18 = distinct !DISubprogram(name: "A", scope: !3, file: !2, spFlags: DISPFlagDefinition, unit: !1)

>From ee93b95ad5f3aee6d7deba6e1a5bd623502d1947 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Tue, 2 Apr 2024 13:30:43 -0700
Subject: [PATCH 02/12] nit suggestions

---
 mlir/lib/Target/LLVMIR/DebugImporter.cpp     | 22 ++++++++++----------
 mlir/lib/Target/LLVMIR/DebugImporter.h       |  2 +-
 mlir/test/Target/LLVMIR/Import/debug-info.ll |  2 --
 3 files changed, 12 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
index 3de62d960ffc44..60c469a03d53a1 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
@@ -127,16 +127,15 @@ class mlir::LLVM::detail::RecursionPruner {
       // rely on it.
       // Only need to look at the caches at this level.
       uint64_t numRemaining = state.cacheSize;
-      for (auto &cacheEntry : llvm::make_range(cache.rbegin(), cache.rend())) {
+      for (CachedTranslation &cacheEntry :
+           llvm::make_second_range(llvm::reverse(cache))) {
         if (numRemaining == 0)
           break;
+        --numRemaining;
 
-        if (auto refIter = cacheEntry.second.pendingReplacements.find(recSelf);
-            refIter != cacheEntry.second.pendingReplacements.end()) {
+        if (auto refIter = cacheEntry.pendingReplacements.find(recSelf);
+            refIter != cacheEntry.pendingReplacements.end())
           refIter->second = result;
-        }
-
-        --numRemaining;
       }
     }
 
@@ -157,7 +156,7 @@ class mlir::LLVM::detail::RecursionPruner {
 
   /// Pop off a frame from the translation stack after a node is done being
   /// translated.
-  void finally(llvm::DINode *node) {
+  void popTranslationStack(llvm::DINode *node) {
     // If `node` is not a potentially recursive type, it will not be on the
     // translation stack. Nothing to handle in this case.
     if (translationStack.empty() || translationStack.back().first != node)
@@ -195,7 +194,7 @@ class mlir::LLVM::detail::RecursionPruner {
   /// cached result.
   /// The cache entry will also be updated with the replaced result and with the
   /// applied replacements removed from the pendingReplacements map.
-  std::pair<DINodeAttr, llvm::DenseSet<DIRecursiveTypeAttrInterface>>
+  std::pair<DINodeAttr, DenseSet<DIRecursiveTypeAttrInterface>>
   lookup(llvm::DINode *node) {
     auto cacheIter = cache.find(node);
     if (cacheIter == cache.end())
@@ -205,9 +204,9 @@ class mlir::LLVM::detail::RecursionPruner {
 
     if (entry.pendingReplacements.empty())
       return std::make_pair(entry.attr,
-                            llvm::DenseSet<DIRecursiveTypeAttrInterface>{});
+                            DenseSet<DIRecursiveTypeAttrInterface>{});
 
-    mlir::AttrTypeReplacer replacer;
+    AttrTypeReplacer replacer;
     replacer.addReplacement(
         [&entry](DIRecursiveTypeAttrInterface attr)
             -> std::optional<std::pair<Attribute, WalkResult>> {
@@ -514,7 +513,8 @@ DINodeAttr DebugImporter::translate(llvm::DINode *node) {
   if (DINodeAttr attr = recursionPruner->tryPrune(node))
     return attr;
 
-  auto guard = llvm::make_scope_exit([&]() { recursionPruner->finally(node); });
+  auto guard = llvm::make_scope_exit(
+      [&]() { recursionPruner->popTranslationStack(node); });
 
   // Convert the debug metadata if possible.
   auto translateNode = [this](llvm::DINode *node) -> DINodeAttr {
diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.h b/mlir/lib/Target/LLVMIR/DebugImporter.h
index 1344362ff8d749..a95346cd3cd9c4 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.h
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.h
@@ -94,7 +94,7 @@ class DebugImporter {
   /// distinct id attribute.
   DenseMap<llvm::DINode *, DistinctAttr> nodeToDistinctAttr;
 
-  // Translation copilot for recursive types.
+  /// A translation helper for recursive types.
   std::unique_ptr<RecursionPruner> recursionPruner;
 
   MLIRContext *context;
diff --git a/mlir/test/Target/LLVMIR/Import/debug-info.ll b/mlir/test/Target/LLVMIR/Import/debug-info.ll
index b322840d852611..931f914c995bb6 100644
--- a/mlir/test/Target/LLVMIR/Import/debug-info.ll
+++ b/mlir/test/Target/LLVMIR/Import/debug-info.ll
@@ -625,7 +625,6 @@ declare !dbg !1 void @declaration()
 
 ; CHECK: #llvm.di_subprogram<{{.*}}scope = #[[A_OUTER]]
 
-
 define void @class_field(ptr %arg1) !dbg !18 {
   ret void
 }
@@ -638,7 +637,6 @@ declare void @llvm.dbg.value(metadata, metadata, metadata)
 !1 = distinct !DICompileUnit(language: DW_LANG_C, file: !2)
 !2 = !DIFile(filename: "debug-info.ll", directory: "/")
 
-
 !3 = !DICompositeType(tag: DW_TAG_class_type, name: "A", file: !2, line: 42, flags: DIFlagTypePassByReference | DIFlagNonTrivial, elements: !4)
 !4 = !{!7, !8}
 

>From 5f01209001046325101bb2e57b07d5a258a9c3f9 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Tue, 2 Apr 2024 13:37:27 -0700
Subject: [PATCH 03/12] move decl into header and make direct member

---
 mlir/lib/Target/LLVMIR/DebugImporter.cpp | 450 ++++++++++-------------
 mlir/lib/Target/LLVMIR/DebugImporter.h   |  96 ++++-
 2 files changed, 282 insertions(+), 264 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
index 60c469a03d53a1..f57411dc7d97c0 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
@@ -25,264 +25,8 @@ using namespace mlir;
 using namespace mlir::LLVM;
 using namespace mlir::LLVM::detail;
 
-/// Get the `getRecSelf` constructor for the translated type of `node` if its
-/// translated DITypeAttr supports recursion. Otherwise, returns nullptr.
-static function_ref<DIRecursiveTypeAttrInterface(DistinctAttr)>
-getRecSelfConstructor(llvm::DINode *node) {
-  using CtorType = function_ref<DIRecursiveTypeAttrInterface(DistinctAttr)>;
-  return TypeSwitch<llvm::DINode *, CtorType>(node)
-      .Case([&](llvm::DICompositeType *) {
-        return CtorType(DICompositeTypeAttr::getRecSelf);
-      })
-      .Default(CtorType());
-}
-
-/// Translation helper for recursive DINodes.
-/// Works alongside a stack-based DINode translator (the "main translator") for
-/// gracefully handling DINodes that are recursive.
-///
-/// Usage:
-/// - Before translating a node, call `tryPrune` to see if the pruner can
-///   preempt this translation. If this is a node that the pruner already knows
-///   how to handle, it will return the translated DINodeAttr.
-/// - After a node is successfully translated by the main translator, call
-///   `finalizeTranslation` to save the translated result with the pruner, and
-///   give it a chance to further modify the result.
-/// - Regardless of success or failure by the main translator, always call
-///   `finally` at the end of translating a node. This is necessary to keep the
-///   internal book-keeping in sync.
-///
-/// This helper maintains an internal cache so that no recursive type will
-/// be translated more than once by the main translator.
-/// This internal cache is different from the cache maintained by the main
-/// translator because it may store nodes that are not self-contained (i.e.
-/// contain unbounded recursive self-references).
-class mlir::LLVM::detail::RecursionPruner {
-public:
-  RecursionPruner(MLIRContext *context) : context(context) {}
-
-  /// If this node was previously cached, returns the cached result.
-  /// If this node is a recursive instance that was previously seen, returns a
-  /// self-reference.
-  /// Otherwise, returns null attr.
-  DINodeAttr tryPrune(llvm::DINode *node) {
-    // Lookup the cache first.
-    auto [result, unboundSelfRefs] = lookup(node);
-    if (result) {
-      // Need to inject unbound self-refs into the previous layer.
-      if (!unboundSelfRefs.empty())
-        translationStack.back().second.unboundSelfRefs.insert(
-            unboundSelfRefs.begin(), unboundSelfRefs.end());
-      return result;
-    }
-
-    // If the node type is capable of being recursive, check if it's seen
-    // before.
-    auto recSelfCtor = getRecSelfConstructor(node);
-    if (recSelfCtor) {
-      // If a cyclic dependency is detected since the same node is being
-      // traversed twice, emit a recursive self type, and mark the duplicate
-      // node on the translationStack so it can emit a recursive decl type.
-      auto [iter, inserted] = translationStack.try_emplace(node);
-      if (!inserted) {
-        // The original node may have already been assigned a recursive ID from
-        // a different self-reference. Use that if possible.
-        DIRecursiveTypeAttrInterface recSelf = iter->second.recSelf;
-        if (!recSelf) {
-          DistinctAttr recId = DistinctAttr::create(UnitAttr::get(context));
-          recSelf = recSelfCtor(recId);
-          iter->second.recSelf = recSelf;
-        }
-        // Inject the self-ref into the previous layer.
-        translationStack.back().second.unboundSelfRefs.insert(recSelf);
-        return cast<DINodeAttr>(recSelf);
-      }
-    }
-    return nullptr;
-  }
-
-  /// Register the translated result of `node`. Returns the finalized result
-  /// (with recId if recursive) and whether the result is self-contained
-  /// (i.e. contains no unbound self-refs).
-  std::pair<DINodeAttr, bool> finalizeTranslation(llvm::DINode *node,
-                                                  DINodeAttr result) {
-    // If `node` is not a potentially recursive type, it will not be on the
-    // translation stack. Nothing to set in this case.
-    if (translationStack.empty())
-      return {result, true};
-    if (translationStack.back().first != node)
-      return {result, translationStack.back().second.unboundSelfRefs.empty()};
-
-    TranslationState &state = translationStack.back().second;
-
-    // If this node is actually recursive, set the recId onto `result`.
-    if (DIRecursiveTypeAttrInterface recSelf = state.recSelf) {
-      auto recType = cast<DIRecursiveTypeAttrInterface>(result);
-      result = cast<DINodeAttr>(recType.withRecId(recSelf.getRecId()));
-
-      // Remove this recSelf from the set of unbound selfRefs.
-      state.unboundSelfRefs.erase(recSelf);
-
-      // Insert the newly resolved recursive type into the cache entries that
-      // rely on it.
-      // Only need to look at the caches at this level.
-      uint64_t numRemaining = state.cacheSize;
-      for (CachedTranslation &cacheEntry :
-           llvm::make_second_range(llvm::reverse(cache))) {
-        if (numRemaining == 0)
-          break;
-        --numRemaining;
-
-        if (auto refIter = cacheEntry.pendingReplacements.find(recSelf);
-            refIter != cacheEntry.pendingReplacements.end())
-          refIter->second = result;
-      }
-    }
-
-    // Insert the current result into the cache.
-    state.cacheSize++;
-    auto [iter, inserted] = cache.try_emplace(node);
-    assert(inserted && "invalid state: caching the same DINode twice");
-    iter->second.attr = result;
-
-    // If this node had any unbound self-refs free when it is registered into
-    // the cache, set up replacement placeholders: This result will need these
-    // unbound self-refs to be replaced before being used.
-    for (DIRecursiveTypeAttrInterface selfRef : state.unboundSelfRefs)
-      iter->second.pendingReplacements.try_emplace(selfRef, nullptr);
-
-    return {result, state.unboundSelfRefs.empty()};
-  }
-
-  /// Pop off a frame from the translation stack after a node is done being
-  /// translated.
-  void popTranslationStack(llvm::DINode *node) {
-    // If `node` is not a potentially recursive type, it will not be on the
-    // translation stack. Nothing to handle in this case.
-    if (translationStack.empty() || translationStack.back().first != node)
-      return;
-
-    // At the end of the stack, all unbound self-refs must be resolved already,
-    // and the entire cache should be accounted for.
-    TranslationState &currLayerState = translationStack.back().second;
-    if (translationStack.size() == 1) {
-      assert(currLayerState.unboundSelfRefs.empty() &&
-             "internal error: unbound recursive self reference at top level.");
-      assert(currLayerState.cacheSize == cache.size() &&
-             "internal error: inconsistent cache size");
-      translationStack.pop_back();
-      cache.clear();
-      return;
-    }
-
-    // Copy unboundSelfRefs down to the previous level.
-    TranslationState &nextLayerState = (++translationStack.rbegin())->second;
-    nextLayerState.unboundSelfRefs.insert(
-        currLayerState.unboundSelfRefs.begin(),
-        currLayerState.unboundSelfRefs.end());
-
-    // The current layer cache is now considered part of the lower layer cache.
-    nextLayerState.cacheSize += currLayerState.cacheSize;
-
-    // Finally pop off this layer when all bookkeeping is done.
-    translationStack.pop_back();
-  }
-
-private:
-  /// Returns the cached result (if exists) with all known replacements applied.
-  /// Also returns the set of unbound self-refs that are unresolved in this
-  /// cached result.
-  /// The cache entry will also be updated with the replaced result and with the
-  /// applied replacements removed from the pendingReplacements map.
-  std::pair<DINodeAttr, DenseSet<DIRecursiveTypeAttrInterface>>
-  lookup(llvm::DINode *node) {
-    auto cacheIter = cache.find(node);
-    if (cacheIter == cache.end())
-      return {};
-
-    CachedTranslation &entry = cacheIter->second;
-
-    if (entry.pendingReplacements.empty())
-      return std::make_pair(entry.attr,
-                            DenseSet<DIRecursiveTypeAttrInterface>{});
-
-    AttrTypeReplacer replacer;
-    replacer.addReplacement(
-        [&entry](DIRecursiveTypeAttrInterface attr)
-            -> std::optional<std::pair<Attribute, WalkResult>> {
-          if (auto replacement = entry.pendingReplacements.lookup(attr)) {
-            // A replacement may contain additional unbound self-refs.
-            return std::make_pair(replacement, mlir::WalkResult::advance());
-          }
-          return std::make_pair(attr, mlir::WalkResult::skip());
-        });
-
-    Attribute replacedAttr = replacer.replace(entry.attr);
-    DINodeAttr result = cast<DINodeAttr>(replacedAttr);
-
-    // Update cache entry to save replaced version and remove already-applied
-    // replacements.
-    entry.attr = result;
-    DenseSet<DIRecursiveTypeAttrInterface> unboundRefs;
-    DenseSet<DIRecursiveTypeAttrInterface> boundRefs;
-    for (auto [refSelf, replacement] : entry.pendingReplacements) {
-      if (replacement)
-        boundRefs.insert(refSelf);
-      else
-        unboundRefs.insert(refSelf);
-    }
-
-    for (DIRecursiveTypeAttrInterface ref : boundRefs)
-      entry.pendingReplacements.erase(ref);
-
-    return std::make_pair(result, unboundRefs);
-  }
-
-  MLIRContext *context;
-
-  /// A lazy cached translation that contains the translated attribute as well
-  /// as any unbound self-references that need to be replaced at lookup.
-  struct CachedTranslation {
-    /// The translated attr. May contain unbound self-references for other
-    /// recursive attrs.
-    DINodeAttr attr;
-    /// The pending replacements that need to be run on this `attr` before it
-    /// can be used.
-    /// Each key is a recursive self-ref, and each value is a recursive decl
-    /// that may contain additional unbound self-refs that need to be replaced.
-    /// Each replacement will be applied at most once. Once a replacement is
-    /// applied, the cached `attr` will be updated, and the replacement will be
-    /// removed from this map.
-    DenseMap<DIRecursiveTypeAttrInterface, DINodeAttr> pendingReplacements;
-  };
-  /// A mapping between LLVM debug metadata and the corresponding attribute.
-  llvm::MapVector<llvm::DINode *, CachedTranslation> cache;
-
-  /// Each potentially recursive node will have a TranslationState pushed onto
-  /// the `translationStack` to keep track of whether this node is actually
-  /// recursive (i.e. has self-references inside), and other book-keeping.
-  struct TranslationState {
-    /// The rec-self if this node is indeed a recursive node (i.e. another
-    /// instance of itself is seen while translating it). Null if this node
-    /// has not been seen again deeper in the translation stack.
-    DIRecursiveTypeAttrInterface recSelf;
-    /// The number of cache entries belonging to this layer of the translation
-    /// stack. This corresponds to the last `cacheSize` entries of `cache`.
-    uint64_t cacheSize = 0;
-    /// All the unbound recursive self references in this layer of the
-    /// translation stack.
-    DenseSet<DIRecursiveTypeAttrInterface> unboundSelfRefs;
-  };
-  /// A stack that stores the metadata nodes that are being traversed. The stack
-  /// is used to handle cyclic dependencies during metadata translation.
-  /// Each node is pushed with an empty TranslationState. If it is ever seen
-  /// later when the stack is deeper, the node is recursive, and its
-  /// TranslationState is assigned a recSelf.
-  llvm::MapVector<llvm::DINode *, TranslationState> translationStack;
-};
-
 DebugImporter::DebugImporter(ModuleOp mlirModule)
-    : recursionPruner(new RecursionPruner(mlirModule.getContext())),
+    : recursionPruner(mlirModule.getContext()),
       context(mlirModule.getContext()), mlirModule(mlirModule) {}
 
 DebugImporter::~DebugImporter() {}
@@ -510,11 +254,11 @@ DINodeAttr DebugImporter::translate(llvm::DINode *node) {
 
   // Register with the recursive translator. If it is seen before, return the
   // result immediately.
-  if (DINodeAttr attr = recursionPruner->tryPrune(node))
+  if (DINodeAttr attr = recursionPruner.tryPrune(node))
     return attr;
 
   auto guard = llvm::make_scope_exit(
-      [&]() { recursionPruner->popTranslationStack(node); });
+      [&]() { recursionPruner.popTranslationStack(node); });
 
   // Convert the debug metadata if possible.
   auto translateNode = [this](llvm::DINode *node) -> DINodeAttr {
@@ -552,7 +296,7 @@ DINodeAttr DebugImporter::translate(llvm::DINode *node) {
   };
   if (DINodeAttr attr = translateNode(node)) {
     auto [result, isSelfContained] =
-        recursionPruner->finalizeTranslation(node, attr);
+        recursionPruner.finalizeTranslation(node, attr);
     // Only cache fully self-contained nodes.
     if (isSelfContained)
       nodeToAttr.try_emplace(node, result);
@@ -561,6 +305,192 @@ DINodeAttr DebugImporter::translate(llvm::DINode *node) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// RecursionPruner
+//===----------------------------------------------------------------------===//
+/// Get the `getRecSelf` constructor for the translated type of `node` if its
+/// translated DITypeAttr supports recursion. Otherwise, returns nullptr.
+static function_ref<DIRecursiveTypeAttrInterface(DistinctAttr)>
+getRecSelfConstructor(llvm::DINode *node) {
+  using CtorType = function_ref<DIRecursiveTypeAttrInterface(DistinctAttr)>;
+  return TypeSwitch<llvm::DINode *, CtorType>(node)
+      .Case([&](llvm::DICompositeType *) {
+        return CtorType(DICompositeTypeAttr::getRecSelf);
+      })
+      .Default(CtorType());
+}
+
+DINodeAttr DebugImporter::RecursionPruner::tryPrune(llvm::DINode *node) {
+  // Lookup the cache first.
+  auto [result, unboundSelfRefs] = lookup(node);
+  if (result) {
+    // Need to inject unbound self-refs into the previous layer.
+    if (!unboundSelfRefs.empty())
+      translationStack.back().second.unboundSelfRefs.insert(
+          unboundSelfRefs.begin(), unboundSelfRefs.end());
+    return result;
+  }
+
+  // If the node type is capable of being recursive, check if it's seen
+  // before.
+  auto recSelfCtor = getRecSelfConstructor(node);
+  if (recSelfCtor) {
+    // If a cyclic dependency is detected since the same node is being
+    // traversed twice, emit a recursive self type, and mark the duplicate
+    // node on the translationStack so it can emit a recursive decl type.
+    auto [iter, inserted] = translationStack.try_emplace(node);
+    if (!inserted) {
+      // The original node may have already been assigned a recursive ID from
+      // a different self-reference. Use that if possible.
+      DIRecursiveTypeAttrInterface recSelf = iter->second.recSelf;
+      if (!recSelf) {
+        DistinctAttr recId = DistinctAttr::create(UnitAttr::get(context));
+        recSelf = recSelfCtor(recId);
+        iter->second.recSelf = recSelf;
+      }
+      // Inject the self-ref into the previous layer.
+      translationStack.back().second.unboundSelfRefs.insert(recSelf);
+      return cast<DINodeAttr>(recSelf);
+    }
+  }
+  return nullptr;
+}
+
+/// Register the translated result of `node`. Returns the finalized result
+/// (with recId if recursive) and whether the result is self-contained
+/// (i.e. contains no unbound self-refs).
+std::pair<DINodeAttr, bool>
+DebugImporter::RecursionPruner::finalizeTranslation(llvm::DINode *node,
+                                                    DINodeAttr result) {
+  // If `node` is not a potentially recursive type, it will not be on the
+  // translation stack. Nothing to set in this case.
+  if (translationStack.empty())
+    return {result, true};
+  if (translationStack.back().first != node)
+    return {result, translationStack.back().second.unboundSelfRefs.empty()};
+
+  TranslationState &state = translationStack.back().second;
+
+  // If this node is actually recursive, set the recId onto `result`.
+  if (DIRecursiveTypeAttrInterface recSelf = state.recSelf) {
+    auto recType = cast<DIRecursiveTypeAttrInterface>(result);
+    result = cast<DINodeAttr>(recType.withRecId(recSelf.getRecId()));
+
+    // Remove this recSelf from the set of unbound selfRefs.
+    state.unboundSelfRefs.erase(recSelf);
+
+    // Insert the newly resolved recursive type into the cache entries that
+    // rely on it.
+    // Only need to look at the caches at this level.
+    uint64_t numRemaining = state.cacheSize;
+    for (CachedTranslation &cacheEntry :
+         llvm::make_second_range(llvm::reverse(cache))) {
+      if (numRemaining == 0)
+        break;
+      --numRemaining;
+
+      if (auto refIter = cacheEntry.pendingReplacements.find(recSelf);
+          refIter != cacheEntry.pendingReplacements.end())
+        refIter->second = result;
+    }
+  }
+
+  // Insert the current result into the cache.
+  state.cacheSize++;
+  auto [iter, inserted] = cache.try_emplace(node);
+  assert(inserted && "invalid state: caching the same DINode twice");
+  iter->second.attr = result;
+
+  // If this node had any unbound self-refs free when it is registered into
+  // the cache, set up replacement placeholders: This result will need these
+  // unbound self-refs to be replaced before being used.
+  for (DIRecursiveTypeAttrInterface selfRef : state.unboundSelfRefs)
+    iter->second.pendingReplacements.try_emplace(selfRef, nullptr);
+
+  return {result, state.unboundSelfRefs.empty()};
+}
+
+/// Pop off a frame from the translation stack after a node is done being
+/// translated.
+void DebugImporter::RecursionPruner::popTranslationStack(llvm::DINode *node) {
+  // If `node` is not a potentially recursive type, it will not be on the
+  // translation stack. Nothing to handle in this case.
+  if (translationStack.empty() || translationStack.back().first != node)
+    return;
+
+  // At the end of the stack, all unbound self-refs must be resolved already,
+  // and the entire cache should be accounted for.
+  TranslationState &currLayerState = translationStack.back().second;
+  if (translationStack.size() == 1) {
+    assert(currLayerState.unboundSelfRefs.empty() &&
+           "internal error: unbound recursive self reference at top level.");
+    assert(currLayerState.cacheSize == cache.size() &&
+           "internal error: inconsistent cache size");
+    translationStack.pop_back();
+    cache.clear();
+    return;
+  }
+
+  // Copy unboundSelfRefs down to the previous level.
+  TranslationState &nextLayerState = (++translationStack.rbegin())->second;
+  nextLayerState.unboundSelfRefs.insert(currLayerState.unboundSelfRefs.begin(),
+                                        currLayerState.unboundSelfRefs.end());
+
+  // The current layer cache is now considered part of the lower layer cache.
+  nextLayerState.cacheSize += currLayerState.cacheSize;
+
+  // Finally pop off this layer when all bookkeeping is done.
+  translationStack.pop_back();
+}
+
+/// Returns the cached result (if exists) with all known replacements applied.
+/// Also returns the set of unbound self-refs that are unresolved in this
+/// cached result.
+/// The cache entry will also be updated with the replaced result and with the
+/// applied replacements removed from the pendingReplacements map.
+std::pair<DINodeAttr, DenseSet<DIRecursiveTypeAttrInterface>>
+DebugImporter::RecursionPruner::lookup(llvm::DINode *node) {
+  auto cacheIter = cache.find(node);
+  if (cacheIter == cache.end())
+    return {};
+
+  CachedTranslation &entry = cacheIter->second;
+
+  if (entry.pendingReplacements.empty())
+    return std::make_pair(entry.attr, DenseSet<DIRecursiveTypeAttrInterface>{});
+
+  AttrTypeReplacer replacer;
+  replacer.addReplacement(
+      [&entry](DIRecursiveTypeAttrInterface attr)
+          -> std::optional<std::pair<Attribute, WalkResult>> {
+        if (auto replacement = entry.pendingReplacements.lookup(attr)) {
+          // A replacement may contain additional unbound self-refs.
+          return std::make_pair(replacement, mlir::WalkResult::advance());
+        }
+        return std::make_pair(attr, mlir::WalkResult::skip());
+      });
+
+  Attribute replacedAttr = replacer.replace(entry.attr);
+  DINodeAttr result = cast<DINodeAttr>(replacedAttr);
+
+  // Update cache entry to save replaced version and remove already-applied
+  // replacements.
+  entry.attr = result;
+  DenseSet<DIRecursiveTypeAttrInterface> unboundRefs;
+  DenseSet<DIRecursiveTypeAttrInterface> boundRefs;
+  for (auto [refSelf, replacement] : entry.pendingReplacements) {
+    if (replacement)
+      boundRefs.insert(refSelf);
+    else
+      unboundRefs.insert(refSelf);
+  }
+
+  for (DIRecursiveTypeAttrInterface ref : boundRefs)
+    entry.pendingReplacements.erase(ref);
+
+  return std::make_pair(result, unboundRefs);
+}
+
 //===----------------------------------------------------------------------===//
 // Locations
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.h b/mlir/lib/Target/LLVMIR/DebugImporter.h
index a95346cd3cd9c4..dd5aba9169dbd1 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.h
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.h
@@ -27,8 +27,6 @@ class LLVMFuncOp;
 
 namespace detail {
 
-class RecursionPruner;
-
 class DebugImporter {
 public:
   DebugImporter(ModuleOp mlirModule);
@@ -94,8 +92,98 @@ class DebugImporter {
   /// distinct id attribute.
   DenseMap<llvm::DINode *, DistinctAttr> nodeToDistinctAttr;
 
-  /// A translation helper for recursive types.
-  std::unique_ptr<RecursionPruner> recursionPruner;
+  /// Translation helper for recursive DINodes.
+  /// Works alongside a stack-based DINode translator (the "main translator")
+  /// for gracefully handling DINodes that are recursive.
+  ///
+  /// Usage:
+  /// - Before translating a node, call `tryPrune` to see if the pruner can
+  ///   preempt this translation. If this is a node that the pruner already
+  ///   knows how to handle, it will return the translated DINodeAttr.
+  /// - After a node is successfully translated by the main translator, call
+  ///   `finalizeTranslation` to save the translated result with the pruner, and
+  ///   give it a chance to further modify the result.
+  /// - Regardless of success or failure by the main translator, always call
+  ///   `finally` at the end of translating a node. This is necessary to keep
+  ///   the internal book-keeping in sync.
+  ///
+  /// This helper maintains an internal cache so that no recursive type will
+  /// be translated more than once by the main translator.
+  /// This internal cache is different from the cache maintained by the main
+  /// translator because it may store nodes that are not self-contained (i.e.
+  /// contain unbound recursive self-references).
+  class RecursionPruner {
+  public:
+    RecursionPruner(MLIRContext *context) : context(context) {}
+
+    /// If this node was previously cached, returns the cached result.
+    /// If this node is a recursive instance that was previously seen, returns a
+    /// self-reference.
+    /// Otherwise, returns null attr.
+    DINodeAttr tryPrune(llvm::DINode *node);
+
+    /// Register the translated result of `node`. Returns the finalized result
+    /// (with recId if recursive) and whether the result is self-contained
+    /// (i.e. contains no unbound self-refs).
+    std::pair<DINodeAttr, bool> finalizeTranslation(llvm::DINode *node,
+                                                    DINodeAttr result);
+
+    /// Pop off a frame from the translation stack after a node is done being
+    /// translated.
+    void popTranslationStack(llvm::DINode *node);
+
+  private:
+    /// Returns the cached result (if exists) with all known replacements
+    /// applied. Also returns the set of unbound self-refs that are unresolved
+    /// in this cached result. The cache entry will also be updated with the
+    /// replaced result and with the applied replacements removed from the
+    /// pendingReplacements map.
+    std::pair<DINodeAttr, DenseSet<DIRecursiveTypeAttrInterface>>
+    lookup(llvm::DINode *node);
+
+    MLIRContext *context;
+
+    /// A lazy cached translation that contains the translated attribute as well
+    /// as any unbound self-references that need to be replaced at lookup.
+    struct CachedTranslation {
+      /// The translated attr. May contain unbound self-references for other
+      /// recursive attrs.
+      DINodeAttr attr;
+      /// The pending replacements that need to be run on this `attr` before it
+      /// can be used.
+      /// Each key is a recursive self-ref, and each value is a recursive decl
+      /// that may contain additional unbound self-refs that need to be
+      /// replaced. Each replacement will be applied at most once. Once a
+      /// replacement is applied, the cached `attr` will be updated, and the
+      /// replacement will be removed from this map.
+      DenseMap<DIRecursiveTypeAttrInterface, DINodeAttr> pendingReplacements;
+    };
+    /// A mapping between LLVM debug metadata and the corresponding attribute.
+    llvm::MapVector<llvm::DINode *, CachedTranslation> cache;
+
+    /// Each potentially recursive node will have a TranslationState pushed onto
+    /// the `translationStack` to keep track of whether this node is actually
+    /// recursive (i.e. has self-references inside), and other book-keeping.
+    struct TranslationState {
+      /// The rec-self if this node is indeed a recursive node (i.e. another
+      /// instance of itself is seen while translating it). Null if this node
+      /// has not been seen again deeper in the translation stack.
+      DIRecursiveTypeAttrInterface recSelf;
+      /// The number of cache entries belonging to this layer of the translation
+      /// stack. This corresponds to the last `cacheSize` entries of `cache`.
+      uint64_t cacheSize = 0;
+      /// All the unbound recursive self references in this layer of the
+      /// translation stack.
+      DenseSet<DIRecursiveTypeAttrInterface> unboundSelfRefs;
+    };
+    /// A stack that stores the metadata nodes that are being traversed. The
+    /// stack is used to handle cyclic dependencies during metadata translation.
+    /// Each node is pushed with an empty TranslationState. If it is ever seen
+    /// later when the stack is deeper, the node is recursive, and its
+    /// TranslationState is assigned a recSelf.
+    llvm::MapVector<llvm::DINode *, TranslationState> translationStack;
+  };
+  RecursionPruner recursionPruner;
 
   MLIRContext *context;
   ModuleOp mlirModule;

>From ffe56df7537a33f6c6c69938369a174c7e33c8a4 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Tue, 2 Apr 2024 13:39:50 -0700
Subject: [PATCH 04/12] rename methods for clarity

---
 mlir/lib/Target/LLVMIR/DebugImporter.cpp |  9 +++++----
 mlir/lib/Target/LLVMIR/DebugImporter.h   | 17 ++++++++++-------
 2 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
index f57411dc7d97c0..f4329dda2c4388 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
@@ -252,9 +252,9 @@ DINodeAttr DebugImporter::translate(llvm::DINode *node) {
   if (DINodeAttr attr = nodeToAttr.lookup(node))
     return attr;
 
-  // Register with the recursive translator. If it is seen before, return the
-  // result immediately.
-  if (DINodeAttr attr = recursionPruner.tryPrune(node))
+  // Register with the recursive translator. If it can be handled without
+  // recursing into it, return the result immediately.
+  if (DINodeAttr attr = recursionPruner.pruneOrPushTranslationStack(node))
     return attr;
 
   auto guard = llvm::make_scope_exit(
@@ -320,7 +320,8 @@ getRecSelfConstructor(llvm::DINode *node) {
       .Default(CtorType());
 }
 
-DINodeAttr DebugImporter::RecursionPruner::tryPrune(llvm::DINode *node) {
+DINodeAttr DebugImporter::RecursionPruner::pruneOrPushTranslationStack(
+    llvm::DINode *node) {
   // Lookup the cache first.
   auto [result, unboundSelfRefs] = lookup(node);
   if (result) {
diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.h b/mlir/lib/Target/LLVMIR/DebugImporter.h
index dd5aba9169dbd1..7f21b836ac9f38 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.h
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.h
@@ -97,15 +97,16 @@ class DebugImporter {
   /// for gracefully handling DINodes that are recursive.
   ///
   /// Usage:
-  /// - Before translating a node, call `tryPrune` to see if the pruner can
-  ///   preempt this translation. If this is a node that the pruner already
-  ///   knows how to handle, it will return the translated DINodeAttr.
+  /// - Before translating a node, call `pruneOrPushTranslationStack` to see if
+  ///   the pruner can preempt this translation. If this is a node that the
+  ///   pruner already knows how to handle, it will return the translated
+  ///   DINodeAttr.
   /// - After a node is successfully translated by the main translator, call
   ///   `finalizeTranslation` to save the translated result with the pruner, and
   ///   give it a chance to further modify the result.
   /// - Regardless of success or failure by the main translator, always call
-  ///   `finally` at the end of translating a node. This is necessary to keep
-  ///   the internal book-keeping in sync.
+  ///   `popTranslationStack` at the end of translating a node. This is
+  ///   necessary to keep the internal book-keeping in sync.
   ///
   /// This helper maintains an internal cache so that no recursive type will
   /// be translated more than once by the main translator.
@@ -119,8 +120,10 @@ class DebugImporter {
     /// If this node was previously cached, returns the cached result.
     /// If this node is a recursive instance that was previously seen, returns a
     /// self-reference.
-    /// Otherwise, returns null attr.
-    DINodeAttr tryPrune(llvm::DINode *node);
+    /// Otherwise, returns null attr, and a translation stack frame is created
+    /// for this node. Expects `finalizeTranslation` & `popTranslationStack`
+    /// to be called on this node later.
+    DINodeAttr pruneOrPushTranslationStack(llvm::DINode *node);
 
     /// Register the translated result of `node`. Returns the finalized result
     /// (with recId if recursive) and whether the result is self-contained

>From 3e369fdd51acc41689e0c6f1577206418af6e999 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Tue, 2 Apr 2024 15:41:32 -0700
Subject: [PATCH 05/12] recurse into decl and add test for it

---
 mlir/lib/Target/LLVMIR/DebugImporter.cpp     |  2 +-
 mlir/test/Target/LLVMIR/Import/debug-info.ll | 58 +++++++++++++++++++-
 2 files changed, 56 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
index f4329dda2c4388..db499a02553fdf 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
@@ -468,7 +468,7 @@ DebugImporter::RecursionPruner::lookup(llvm::DINode *node) {
           // A replacement may contain additional unbound self-refs.
           return std::make_pair(replacement, mlir::WalkResult::advance());
         }
-        return std::make_pair(attr, mlir::WalkResult::skip());
+        return std::make_pair(attr, mlir::WalkResult::advance());
       });
 
   Attribute replacedAttr = replacer.replace(entry.attr);
diff --git a/mlir/test/Target/LLVMIR/Import/debug-info.ll b/mlir/test/Target/LLVMIR/Import/debug-info.ll
index 931f914c995bb6..ac1a6b79675c1d 100644
--- a/mlir/test/Target/LLVMIR/Import/debug-info.ll
+++ b/mlir/test/Target/LLVMIR/Import/debug-info.ll
@@ -610,7 +610,15 @@ declare !dbg !1 void @declaration()
 
 ; // -----
 
-; Ensure that repeated occurence of recursive subtree does not result in duplicate MLIR entries.
+; Ensure that repeated occurrence of recursive subtree does not result in
+; duplicate MLIR entries.
+;
+; +--> B:B1 ----+
+; |     ^       v
+; A <---+------ B
+; |     v       ^
+; +--> B:B2 ----+
+; This should result in only one B instance.
 
 ; CHECK-DAG: #[[B1_INNER:.+]] = #llvm.di_derived_type<{{.*}}name = "B:B1", baseType = #[[B_SELF:.+]]>
 ; CHECK-DAG: #[[B2_INNER:.+]] = #llvm.di_derived_type<{{.*}}name = "B:B2", baseType = #[[B_SELF]]>
@@ -629,8 +637,6 @@ define void @class_field(ptr %arg1) !dbg !18 {
   ret void
 }
 
-declare void @llvm.dbg.value(metadata, metadata, metadata)
-
 !llvm.dbg.cu = !{!1}
 !llvm.module.flags = !{!0}
 !0 = !{i32 2, !"Debug Info Version", i32 3}
@@ -646,3 +652,49 @@ declare void @llvm.dbg.value(metadata, metadata, metadata)
 !9 = !{!7, !8}
 
 !18 = distinct !DISubprogram(name: "A", scope: !3, file: !2, spFlags: DISPFlagDefinition, unit: !1)
+
+; // -----
+
+; Ensure that recursive cycles with multiple entry points are cached correctly.
+;
+; +---- A ----+
+; v           v
+; B <-------> C
+; This should result in a cached instance of B --> C --> B_SELF being reused
+; when visiting B from C (after visiting B from A).
+
+; CHECK-DAG: #[[C_INNER:.+]] = #llvm.di_composite_type<{{.*}}name = "C", {{.*}}scope = #[[A_SELF:[^ ]+]], {{.*}}elements = #[[C_INNER_TO_B_SELF:.+]]
+; CHECK-DAG: #[[C_INNER_TO_B_SELF:.+]] = #llvm.di_derived_type<{{.*}}name = "->B", {{.*}}baseType = #[[B_SELF:[^ ]+]]>
+; CHECK-DAG: #[[B_TO_C_INNER:.+]] = #llvm.di_derived_type<{{.*}}name = "->C", {{.*}}baseType = #[[C_INNER]]
+; CHECK-DAG: #[[B:.+]] = #llvm.di_composite_type<{{.*}}recId = [[B_RECID:.+]], {{.*}}name = "B", {{.*}}scope = #[[A_SELF]], {{.*}}elements = #[[B_TO_C_INNER]]
+; CHECK-DAG: #[[A_SELF]] = #llvm.di_composite_type<{{.*}}recId = [[A_RECID:.+]]>
+; CHECK-DAG: #[[B_SELF]] = #llvm.di_composite_type<{{.*}}recId = [[B_RECID]]>
+
+; CHECK-DAG: #[[TO_B:.+]] = #llvm.di_derived_type<{{.*}}name = "->B", {{.*}}baseType = #[[B]]
+; CHECK-DAG: #[[C_OUTER:.+]] = #llvm.di_composite_type<{{.*}}name = "C", {{.*}}scope = #[[A_SELF]], {{.*}}elements = #[[TO_B]]
+; CHECK-DAG: #[[TO_C:.+]] = #llvm.di_derived_type<{{.*}}name = "->C", {{.*}}baseType = #[[C_OUTER]]
+; CHECK-DAG: #[[A:.+]] = #llvm.di_composite_type<{{.*}}recId = [[A_RECID]], {{.*}}name = "A", {{.*}}elements = #[[TO_B]], #[[TO_C]]
+
+; CHECK-DAG: #llvm.di_subprogram<{{.*}}scope = #[[A]],
+
+define void @class_field(ptr %arg1) !dbg !18 {
+  ret void
+}
+
+!llvm.dbg.cu = !{!1}
+!llvm.module.flags = !{!0}
+!0 = !{i32 2, !"Debug Info Version", i32 3}
+!1 = distinct !DICompileUnit(language: DW_LANG_C, file: !2)
+!2 = !DIFile(filename: "debug-info.ll", directory: "/")
+
+!3 = !DICompositeType(tag: DW_TAG_class_type, name: "A", file: !2, line: 42, flags: DIFlagTypePassByReference | DIFlagNonTrivial, elements: !4)
+!5 = !DICompositeType(tag: DW_TAG_class_type, name: "B", scope: !3, file: !2, line: 42, flags: DIFlagTypePassByReference | DIFlagNonTrivial, elements: !10)
+!6 = !DICompositeType(tag: DW_TAG_class_type, name: "C", scope: !3, file: !2, line: 42, flags: DIFlagTypePassByReference | DIFlagNonTrivial, elements: !9)
+
+!7 = !DIDerivedType(tag: DW_TAG_member, name: "->B", file: !2, baseType: !5)
+!8 = !DIDerivedType(tag: DW_TAG_member, name: "->C", file: !2, baseType: !6)
+!4 = !{!7, !8}
+!9 = !{!7}
+!10 = !{!8}
+
+!18 = distinct !DISubprogram(name: "SP", scope: !3, file: !2, spFlags: DISPFlagDefinition, unit: !1)

>From a8e999cd8abb0fdbaf318f64ca77d304edf58244 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Tue, 2 Apr 2024 15:50:11 -0700
Subject: [PATCH 06/12] cleanup

---
 mlir/lib/Target/LLVMIR/DebugImporter.cpp | 12 ------------
 mlir/lib/Target/LLVMIR/DebugImporter.h   |  1 -
 2 files changed, 13 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
index db499a02553fdf..b5b709f00a07fd 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
@@ -29,8 +29,6 @@ DebugImporter::DebugImporter(ModuleOp mlirModule)
     : recursionPruner(mlirModule.getContext()),
       context(mlirModule.getContext()), mlirModule(mlirModule) {}
 
-DebugImporter::~DebugImporter() {}
-
 Location DebugImporter::translateFuncLocation(llvm::Function *func) {
   llvm::DISubprogram *subprogram = func->getSubprogram();
   if (!subprogram)
@@ -357,9 +355,6 @@ DINodeAttr DebugImporter::RecursionPruner::pruneOrPushTranslationStack(
   return nullptr;
 }
 
-/// Register the translated result of `node`. Returns the finalized result
-/// (with recId if recursive) and whether the result is self-contained
-/// (i.e. contains no unbound self-refs).
 std::pair<DINodeAttr, bool>
 DebugImporter::RecursionPruner::finalizeTranslation(llvm::DINode *node,
                                                     DINodeAttr result) {
@@ -411,8 +406,6 @@ DebugImporter::RecursionPruner::finalizeTranslation(llvm::DINode *node,
   return {result, state.unboundSelfRefs.empty()};
 }
 
-/// Pop off a frame from the translation stack after a node is done being
-/// translated.
 void DebugImporter::RecursionPruner::popTranslationStack(llvm::DINode *node) {
   // If `node` is not a potentially recursive type, it will not be on the
   // translation stack. Nothing to handle in this case.
@@ -444,11 +437,6 @@ void DebugImporter::RecursionPruner::popTranslationStack(llvm::DINode *node) {
   translationStack.pop_back();
 }
 
-/// Returns the cached result (if exists) with all known replacements applied.
-/// Also returns the set of unbound self-refs that are unresolved in this
-/// cached result.
-/// The cache entry will also be updated with the replaced result and with the
-/// applied replacements removed from the pendingReplacements map.
 std::pair<DINodeAttr, DenseSet<DIRecursiveTypeAttrInterface>>
 DebugImporter::RecursionPruner::lookup(llvm::DINode *node) {
   auto cacheIter = cache.find(node);
diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.h b/mlir/lib/Target/LLVMIR/DebugImporter.h
index 7f21b836ac9f38..20f9b90154a358 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.h
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.h
@@ -30,7 +30,6 @@ namespace detail {
 class DebugImporter {
 public:
   DebugImporter(ModuleOp mlirModule);
-  ~DebugImporter();
 
   /// Translates the given LLVM debug location to an MLIR location.
   Location translateLoc(llvm::DILocation *loc);

>From 3a0842b1dd3f78c7c5db3720beb69b3e1f2576e6 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Wed, 3 Apr 2024 15:58:59 -0700
Subject: [PATCH 07/12] handle mutual recursion replacement

---
 mlir/lib/Target/LLVMIR/DebugImporter.cpp     | 83 +++++++++++++++++---
 mlir/test/Target/LLVMIR/Import/debug-info.ll | 53 +++++++++++++
 2 files changed, 124 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
index b5b709f00a07fd..eaa546a8b8f834 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
@@ -318,6 +318,75 @@ getRecSelfConstructor(llvm::DINode *node) {
       .Default(CtorType());
 }
 
+/// An attribute replacer that replaces nested recursive decls with recursive
+/// self-references instead.
+///
+/// - Recurses down the attribute tree while replacing attributes based on the
+///   provided replacement map.
+/// - Keeps track of the currently open recursive declarations, and upon
+///   encountering a duplicate declaration, replace with a self-ref instead.
+static Attribute replaceAndPruneRecursiveTypesImpl(
+    Attribute node,
+    const DenseMap<DIRecursiveTypeAttrInterface, DINodeAttr> &mapping,
+    DenseSet<DistinctAttr> &openDecls) {
+  DistinctAttr recId;
+  if (auto recType = dyn_cast<DIRecursiveTypeAttrInterface>(node)) {
+    recId = recType.getRecId();
+
+    // Configure context.
+    if (recId) {
+      if (recType.isRecSelf()) {
+        // Replace selfRef based on the provided mapping.
+        if (DINodeAttr replacement = mapping.lookup(recType))
+          return replaceAndPruneRecursiveTypesImpl(replacement, mapping,
+                                                   openDecls);
+        return node;
+      }
+
+      auto [_, inserted] = openDecls.insert(recId);
+      if (!inserted) {
+        // This is a nested decl. Replace with recSelf.
+        return recType.getRecSelf(recId);
+      }
+    }
+  }
+
+  // Collect sub attrs.
+  SmallVector<Attribute> attrs;
+  SmallVector<Type> types;
+  node.walkImmediateSubElements(
+      [&attrs](Attribute attr) { attrs.push_back(attr); },
+      [&types](Type type) { types.push_back(type); });
+
+  // Recurse into attributes.
+  bool changed = false;
+  for (auto it = attrs.begin(); it != attrs.end(); it++) {
+    Attribute replaced =
+        replaceAndPruneRecursiveTypesImpl(*it, mapping, openDecls);
+    if (replaced != *it) {
+      *it = replaced;
+      changed = true;
+    }
+  }
+
+  Attribute result = node;
+  if (changed)
+    result = result.replaceImmediateSubElements(attrs, types);
+
+  // Reset context.
+  if (recId)
+    openDecls.erase(recId);
+
+  return result;
+}
+
+static Attribute replaceAndPruneRecursiveTypes(
+    DINodeAttr node,
+    const DenseMap<DIRecursiveTypeAttrInterface, DINodeAttr> &mapping) {
+  DenseSet<DistinctAttr> openDecls;
+  return replaceAndPruneRecursiveTypesImpl(node, mapping, openDecls);
+}
+
 DINodeAttr DebugImporter::RecursionPruner::pruneOrPushTranslationStack(
     llvm::DINode *node) {
   // Lookup the cache first.
@@ -448,18 +517,8 @@ DebugImporter::RecursionPruner::lookup(llvm::DINode *node) {
   if (entry.pendingReplacements.empty())
     return std::make_pair(entry.attr, DenseSet<DIRecursiveTypeAttrInterface>{});
 
-  AttrTypeReplacer replacer;
-  replacer.addReplacement(
-      [&entry](DIRecursiveTypeAttrInterface attr)
-          -> std::optional<std::pair<Attribute, WalkResult>> {
-        if (auto replacement = entry.pendingReplacements.lookup(attr)) {
-          // A replacement may contain additional unbound self-refs.
-          return std::make_pair(replacement, mlir::WalkResult::advance());
-        }
-        return std::make_pair(attr, mlir::WalkResult::advance());
-      });
-
-  Attribute replacedAttr = replacer.replace(entry.attr);
+  Attribute replacedAttr =
+      replaceAndPruneRecursiveTypes(entry.attr, entry.pendingReplacements);
   DINodeAttr result = cast<DINodeAttr>(replacedAttr);
 
   // Update cache entry to save replaced version and remove already-applied
diff --git a/mlir/test/Target/LLVMIR/Import/debug-info.ll b/mlir/test/Target/LLVMIR/Import/debug-info.ll
index ac1a6b79675c1d..2218e6cd5b7bbd 100644
--- a/mlir/test/Target/LLVMIR/Import/debug-info.ll
+++ b/mlir/test/Target/LLVMIR/Import/debug-info.ll
@@ -698,3 +698,56 @@ define void @class_field(ptr %arg1) !dbg !18 {
 !10 = !{!8}
 
 !18 = distinct !DISubprogram(name: "SP", scope: !3, file: !2, spFlags: DISPFlagDefinition, unit: !1)
+
+; // -----
+
+; Ensures that replacing a nested mutually recursive decl does not result in
+; nested duplicate recursive decls.
+;
+; A ---> B <--> C
+; ^             ^
+; +-------------+
+
+; CHECK-DAG: #[[A:.+]] = #llvm.di_composite_type<{{.*}}recId = [[A_RECID:.+]], {{.*}}name = "A", {{.*}}elements = #[[A_TO_B:.+]], #[[A_TO_C:.+]]>
+; CHECK-DAG: #llvm.di_subprogram<{{.*}}scope = #[[A]],
+; CHECK-DAG: #[[A_TO_B]] = #llvm.di_derived_type<{{.*}}name = "->B", {{.*}}baseType = #[[B_FROM_A:.+]]>
+; CHECK-DAG: #[[A_TO_C]] = #llvm.di_derived_type<{{.*}}name = "->C", {{.*}}baseType = #[[C_FROM_A:.+]]>
+
+; CHECK-DAG: #[[B_FROM_A]] = #llvm.di_composite_type<{{.*}}recId = [[B_RECID:.+]], {{.*}}name = "B", {{.*}}elements = #[[B_TO_C:.+]]>
+; CHECK-DAG: #[[B_TO_C]] = #llvm.di_derived_type<{{.*}}name = "->C", {{.*}}baseType = #[[C_FROM_B:.+]]>
+; CHECK-DAG: #[[C_FROM_B]] = #llvm.di_composite_type<{{.*}}recId = [[C_RECID:.+]], {{.*}}name = "C", {{.*}}elements = #[[TO_A_SELF:.+]], #[[TO_B_SELF:.+]], #[[TO_C_SELF:.+]]>
+
+; CHECK-DAG: #[[C_FROM_A]] = #llvm.di_composite_type<{{.*}}recId = [[C_RECID]], {{.*}}name = "C", {{.*}}elements = #[[TO_A_SELF]], #[[TO_B_INNER:.+]], #[[TO_C_SELF]]
+; CHECK-DAG: #[[TO_B_INNER]] = #llvm.di_derived_type<{{.*}}name = "->B", {{.*}}baseType = #[[B_INNER:.+]]>
+; CHECK-DAG: #[[B_INNER]] = #llvm.di_composite_type<{{.*}}name = "B", {{.*}}elements = #[[TO_C_SELF]]>
+
+; CHECK-DAG: #[[TO_A_SELF]] = #llvm.di_derived_type<{{.*}}name = "->A", {{.*}}baseType = #[[A_SELF:.+]]>
+; CHECK-DAG: #[[TO_B_SELF]] = #llvm.di_derived_type<{{.*}}name = "->B", {{.*}}baseType = #[[B_SELF:.+]]>
+; CHECK-DAG: #[[TO_C_SELF]] = #llvm.di_derived_type<{{.*}}name = "->C", {{.*}}baseType = #[[C_SELF:.+]]>
+; CHECK-DAG: #[[A_SELF]] = #llvm.di_composite_type<{{.*}}recId = [[A_RECID]]>
+; CHECK-DAG: #[[B_SELF]] = #llvm.di_composite_type<{{.*}}recId = [[B_RECID]]>
+; CHECK-DAG: #[[C_SELF]] = #llvm.di_composite_type<{{.*}}recId = [[C_RECID]]>
+
+define void @class_field(ptr %arg1) !dbg !18 {
+  ret void
+}
+
+!llvm.dbg.cu = !{!1}
+!llvm.module.flags = !{!0}
+!0 = !{i32 2, !"Debug Info Version", i32 3}
+!1 = distinct !DICompileUnit(language: DW_LANG_C, file: !2)
+!2 = !DIFile(filename: "debug-info.ll", directory: "/")
+
+!3 = !DICompositeType(tag: DW_TAG_class_type, name: "A", file: !2, line: 42, flags: DIFlagTypePassByReference | DIFlagNonTrivial, elements: !9)
+!4 = !DICompositeType(tag: DW_TAG_class_type, name: "B", file: !2, line: 42, flags: DIFlagTypePassByReference | DIFlagNonTrivial, elements: !10)
+!5 = !DICompositeType(tag: DW_TAG_class_type, name: "C", file: !2, line: 42, flags: DIFlagTypePassByReference | DIFlagNonTrivial, elements: !11)
+
+!6 = !DIDerivedType(tag: DW_TAG_member, name: "->A", file: !2, baseType: !3)
+!7 = !DIDerivedType(tag: DW_TAG_member, name: "->B", file: !2, baseType: !4)
+!8 = !DIDerivedType(tag: DW_TAG_member, name: "->C", file: !2, baseType: !5)
+
+!9 = !{!7, !8} ; A -> B, C
+!10 = !{!8} ; B -> C
+!11 = !{!6, !7, !8} ; C -> A, B, C
+
+!18 = distinct !DISubprogram(name: "SP", scope: !3, file: !2, spFlags: DISPFlagDefinition, unit: !1)

>From 6f57bc1c1697cb10cb2f61485c476a0ede73bf77 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Wed, 3 Apr 2024 16:50:07 -0700
Subject: [PATCH 08/12] replace with iterative version instead

---
 mlir/lib/Target/LLVMIR/DebugImporter.cpp | 140 ++++++++++++++---------
 1 file changed, 84 insertions(+), 56 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
index eaa546a8b8f834..ecc179eb68136c 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
@@ -318,73 +318,101 @@ getRecSelfConstructor(llvm::DINode *node) {
       .Default(CtorType());
 }
 
-/// An attribute replacer that replaces nested recursive decls with recursive
-/// self-references instead.
-///
+/// An iterative attribute replacer that also handles pruning recursive types.
 /// - Recurses down the attribute tree while replacing attributes based on the
 ///   provided replacement map.
 /// - Keeps track of the currently open recursive declarations, and upon
-///   encountering a duplicate declaration, replace with a self-ref instead.
-static Attribute replaceAndPruneRecursiveTypesImpl(
-    Attribute node,
-    const DenseMap<DIRecursiveTypeAttrInterface, DINodeAttr> &mapping,
-    DenseSet<DistinctAttr> &openDecls) {
-  DistinctAttr recId;
-  if (auto recType = dyn_cast<DIRecursiveTypeAttrInterface>(node)) {
-    recId = recType.getRecId();
-
-    // Configure context.
-    if (recId) {
-      if (recType.isRecSelf()) {
-        // Replace selfRef based on the provided mapping.
-        if (DINodeAttr replacement = mapping.lookup(recType))
-          return replaceAndPruneRecursiveTypesImpl(replacement, mapping,
-                                                   openDecls);
-        return node;
-      }
+///   encountering a duplicate declaration, replace with a self-reference.
+static Attribute replaceAndPruneRecursiveTypes(
+    Attribute baseNode,
+    const DenseMap<DIRecursiveTypeAttrInterface, DINodeAttr> &mapping) {
 
-      auto [_, inserted] = openDecls.insert(recId);
-      if (!inserted) {
-        // This is a nested decl. Replace with recSelf.
-        return recType.getRecSelf(recId);
+  struct ReplacementState {
+    // The attribute being replaced.
+    Attribute node;
+    // The nested elements of `node`.
+    SmallVector<Attribute> attrs;
+    SmallVector<Type> types;
+    // The current attr being walked on (index into `attrs`).
+    size_t attrIndex;
+    // Whether or not any attr was replaced.
+    bool changed;
+
+    void replaceCurrAttr(Attribute attr) {
+      Attribute &target = attrs[attrIndex];
+      if (attr != target) {
+        changed = true;
+        target = attr;
       }
     }
-  }
+  };
 
-  // Collect sub attrs.
-  SmallVector<Attribute> attrs;
-  SmallVector<Type> types;
-  node.walkImmediateSubElements(
-      [&attrs](Attribute attr) { attrs.push_back(attr); },
-      [&types](Type type) { types.push_back(type); });
-
-  // Recurse into attributes.
-  bool changed = false;
-  for (auto it = attrs.begin(); it != attrs.end(); it++) {
-    Attribute replaced =
-        replaceAndPruneRecursiveTypesImpl(*it, mapping, openDecls);
-    if (replaced != *it) {
-      *it = replaced;
-      changed = true;
-    }
-  }
+  // Every iteration, perform replacement on the attribute at `attrIndex` at the
+  // top of `workStack`.
+  // If `attrIndex` reaches past the size of `attrs`, replace `node` with
+  // `attrs` & `types`, and pop off a stack frame and assign the replacement to
+  // the attr being replaced on the previous stack frame.
+  SmallVector<ReplacementState> workStack;
+  workStack.push_back({nullptr, {baseNode}, {}, 0, false});
 
-  Attribute result = node;
-  if (changed)
-    result = result.replaceImmediateSubElements(attrs, types);
+  DenseSet<DistinctAttr> openDecls;
+  while (workStack.size() > 1 || workStack.back().attrIndex == 0) {
+    ReplacementState &state = workStack.back();
+
+    // Check for popping condition.
+    if (state.attrIndex == state.attrs.size()) {
+      Attribute result = state.node;
+      if (state.changed)
+        result = result.replaceImmediateSubElements(state.attrs, state.types);
+
+      // Reset context.
+      if (auto recType = dyn_cast<DIRecursiveTypeAttrInterface>(state.node))
+        if (DistinctAttr recId = recType.getRecId())
+          openDecls.erase(recId);
+
+      workStack.pop_back();
+      ReplacementState &prevState = workStack.back();
+      prevState.replaceCurrAttr(result);
+      ++prevState.attrIndex;
+      continue;
+    }
 
-  // Reset context.
-  if (recId)
-    openDecls.erase(recId);
+    Attribute node = state.attrs[state.attrIndex];
+    if (auto recType = dyn_cast<DIRecursiveTypeAttrInterface>(node)) {
+      if (DistinctAttr recId = recType.getRecId()) {
+        if (recType.isRecSelf()) {
+          // Replace selfRef based on the provided mapping and re-walk from the
+          // replacement node (do not increment attrIndex).
+          if (DINodeAttr replacement = mapping.lookup(recType)) {
+            state.replaceCurrAttr(replacement);
+            continue;
+          }
+
+          // Otherwise, nothing to do. Advance to next attr.
+          ++state.attrIndex;
+          continue;
+        }
+
+        // Configure context.
+        auto [_, inserted] = openDecls.insert(recId);
+        if (!inserted) {
+          // This is a nested decl. Replace with recSelf. Nothing more to do.
+          state.replaceCurrAttr(recType.getRecSelf(recId));
+          ++state.attrIndex;
+          continue;
+        }
+      }
+    }
 
-  return result;
-}
+    // Recurse into this node.
+    workStack.push_back({node, {}, {}, 0, false});
+    ReplacementState &newState = workStack.back();
+    node.walkImmediateSubElements(
+        [&newState](Attribute attr) { newState.attrs.push_back(attr); },
+        [&newState](Type type) { newState.types.push_back(type); });
+  };
 
-static Attribute replaceAndPruneRecursiveTypes(
-    DINodeAttr node,
-    const DenseMap<DIRecursiveTypeAttrInterface, DINodeAttr> &mapping) {
-  DenseSet<DistinctAttr> openDecls;
-  return replaceAndPruneRecursiveTypesImpl(node, mapping, openDecls);
+  return workStack.back().attrs.front();
 }
 
 DINodeAttr DebugImporter::RecursionPruner::pruneOrPushTranslationStack(

>From 628f5201a47facc433e5195cacaa69e9b861a67c Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Thu, 4 Apr 2024 11:12:58 -0700
Subject: [PATCH 09/12] add cache to replacer

---
 mlir/lib/Target/LLVMIR/DebugImporter.cpp | 25 ++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
index ecc179eb68136c..2113af1aa44299 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
@@ -13,6 +13,7 @@
 #include "mlir/IR/Location.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/BinaryFormat/Dwarf.h"
 #include "llvm/IR/Constants.h"
@@ -355,6 +356,14 @@ static Attribute replaceAndPruneRecursiveTypes(
   SmallVector<ReplacementState> workStack;
   workStack.push_back({nullptr, {baseNode}, {}, 0, false});
 
+  // Replacement cache that remembers the context in which the cache is valid.
+  // All unboundRecIds must be in openDecls at time of lookup.
+  struct CacheWithContext {
+    DenseSet<DistinctAttr> unboundRecIds;
+    Attribute entry;
+  };
+  DenseMap<Attribute, CacheWithContext> replacementCache;
+
   DenseSet<DistinctAttr> openDecls;
   while (workStack.size() > 1 || workStack.back().attrIndex == 0) {
     ReplacementState &state = workStack.back();
@@ -370,6 +379,8 @@ static Attribute replaceAndPruneRecursiveTypes(
         if (DistinctAttr recId = recType.getRecId())
           openDecls.erase(recId);
 
+      replacementCache[state.node] = CacheWithContext{openDecls, result};
+
       workStack.pop_back();
       ReplacementState &prevState = workStack.back();
       prevState.replaceCurrAttr(result);
@@ -378,6 +389,20 @@ static Attribute replaceAndPruneRecursiveTypes(
     }
 
     Attribute node = state.attrs[state.attrIndex];
+
+    // Lookup in cache first.
+    if (auto it = replacementCache.find(node); it != replacementCache.end()) {
+      // If all the requried recIds are open decls, use cache.
+      if (llvm::set_is_subset(it->second.unboundRecIds, openDecls)) {
+        state.replaceCurrAttr(it->second.entry);
+        ++state.attrIndex;
+        continue;
+      }
+
+      // Otherwise, the cache entry is stale and can be removed now.
+      replacementCache.erase(it);
+    }
+
     if (auto recType = dyn_cast<DIRecursiveTypeAttrInterface>(node)) {
       if (DistinctAttr recId = recType.getRecId()) {
         if (recType.isRecSelf()) {

>From 9e2cb3a8c948f2d881e9ec91700e09d94eb3f86d Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Thu, 4 Apr 2024 13:31:44 -0700
Subject: [PATCH 10/12] simplify away cache with holes

---
 mlir/lib/Target/LLVMIR/DebugImporter.cpp     | 223 +++----------------
 mlir/lib/Target/LLVMIR/DebugImporter.h       |  49 ++--
 mlir/test/Target/LLVMIR/Import/debug-info.ll |  25 +--
 3 files changed, 59 insertions(+), 238 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
index 2113af1aa44299..9719b614f60584 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
@@ -319,139 +319,8 @@ getRecSelfConstructor(llvm::DINode *node) {
       .Default(CtorType());
 }
 
-/// An iterative attribute replacer that also handles pruning recursive types.
-/// - Recurses down the attribute tree while replacing attributes based on the
-///   provided replacement map.
-/// - Keeps track of the currently open recursive declarations, and upon
-///   encountering a duplicate declaration, replace with a self-reference.
-static Attribute replaceAndPruneRecursiveTypes(
-    Attribute baseNode,
-    const DenseMap<DIRecursiveTypeAttrInterface, DINodeAttr> &mapping) {
-
-  struct ReplacementState {
-    // The attribute being replaced.
-    Attribute node;
-    // The nested elements of `node`.
-    SmallVector<Attribute> attrs;
-    SmallVector<Type> types;
-    // The current attr being walked on (index into `attrs`).
-    size_t attrIndex;
-    // Whether or not any attr was replaced.
-    bool changed;
-
-    void replaceCurrAttr(Attribute attr) {
-      Attribute &target = attrs[attrIndex];
-      if (attr != target) {
-        changed = true;
-        target = attr;
-      }
-    }
-  };
-
-  // Every iteration, perform replacement on the attribute at `attrIndex` at the
-  // top of `workStack`.
-  // If `attrIndex` reaches past the size of `attrs`, replace `node` with
-  // `attrs` & `types`, and pop off a stack frame and assign the replacement to
-  // the attr being replaced on the previous stack frame.
-  SmallVector<ReplacementState> workStack;
-  workStack.push_back({nullptr, {baseNode}, {}, 0, false});
-
-  // Replacement cache that remembers the context in which the cache is valid.
-  // All unboundRecIds must be in openDecls at time of lookup.
-  struct CacheWithContext {
-    DenseSet<DistinctAttr> unboundRecIds;
-    Attribute entry;
-  };
-  DenseMap<Attribute, CacheWithContext> replacementCache;
-
-  DenseSet<DistinctAttr> openDecls;
-  while (workStack.size() > 1 || workStack.back().attrIndex == 0) {
-    ReplacementState &state = workStack.back();
-
-    // Check for popping condition.
-    if (state.attrIndex == state.attrs.size()) {
-      Attribute result = state.node;
-      if (state.changed)
-        result = result.replaceImmediateSubElements(state.attrs, state.types);
-
-      // Reset context.
-      if (auto recType = dyn_cast<DIRecursiveTypeAttrInterface>(state.node))
-        if (DistinctAttr recId = recType.getRecId())
-          openDecls.erase(recId);
-
-      replacementCache[state.node] = CacheWithContext{openDecls, result};
-
-      workStack.pop_back();
-      ReplacementState &prevState = workStack.back();
-      prevState.replaceCurrAttr(result);
-      ++prevState.attrIndex;
-      continue;
-    }
-
-    Attribute node = state.attrs[state.attrIndex];
-
-    // Lookup in cache first.
-    if (auto it = replacementCache.find(node); it != replacementCache.end()) {
-      // If all the requried recIds are open decls, use cache.
-      if (llvm::set_is_subset(it->second.unboundRecIds, openDecls)) {
-        state.replaceCurrAttr(it->second.entry);
-        ++state.attrIndex;
-        continue;
-      }
-
-      // Otherwise, the cache entry is stale and can be removed now.
-      replacementCache.erase(it);
-    }
-
-    if (auto recType = dyn_cast<DIRecursiveTypeAttrInterface>(node)) {
-      if (DistinctAttr recId = recType.getRecId()) {
-        if (recType.isRecSelf()) {
-          // Replace selfRef based on the provided mapping and re-walk from the
-          // replacement node (do not increment attrIndex).
-          if (DINodeAttr replacement = mapping.lookup(recType)) {
-            state.replaceCurrAttr(replacement);
-            continue;
-          }
-
-          // Otherwise, nothing to do. Advance to next attr.
-          ++state.attrIndex;
-          continue;
-        }
-
-        // Configure context.
-        auto [_, inserted] = openDecls.insert(recId);
-        if (!inserted) {
-          // This is a nested decl. Replace with recSelf. Nothing more to do.
-          state.replaceCurrAttr(recType.getRecSelf(recId));
-          ++state.attrIndex;
-          continue;
-        }
-      }
-    }
-
-    // Recurse into this node.
-    workStack.push_back({node, {}, {}, 0, false});
-    ReplacementState &newState = workStack.back();
-    node.walkImmediateSubElements(
-        [&newState](Attribute attr) { newState.attrs.push_back(attr); },
-        [&newState](Type type) { newState.types.push_back(type); });
-  };
-
-  return workStack.back().attrs.front();
-}
-
 DINodeAttr DebugImporter::RecursionPruner::pruneOrPushTranslationStack(
     llvm::DINode *node) {
-  // Lookup the cache first.
-  auto [result, unboundSelfRefs] = lookup(node);
-  if (result) {
-    // Need to inject unbound self-refs into the previous layer.
-    if (!unboundSelfRefs.empty())
-      translationStack.back().second.unboundSelfRefs.insert(
-          unboundSelfRefs.begin(), unboundSelfRefs.end());
-    return result;
-  }
-
   // If the node type is capable of being recursive, check if it's seen
   // before.
   auto recSelfCtor = getRecSelfConstructor(node);
@@ -465,7 +334,11 @@ DINodeAttr DebugImporter::RecursionPruner::pruneOrPushTranslationStack(
       // a different self-reference. Use that if possible.
       DIRecursiveTypeAttrInterface recSelf = iter->second.recSelf;
       if (!recSelf) {
-        DistinctAttr recId = DistinctAttr::create(UnitAttr::get(context));
+        DistinctAttr recId = nodeToRecId.lookup(node);
+        if (!recId) {
+          recId = DistinctAttr::create(UnitAttr::get(context));
+          nodeToRecId[node] = recId;
+        }
         recSelf = recSelfCtor(recId);
         iter->second.recSelf = recSelf;
       }
@@ -474,7 +347,8 @@ DINodeAttr DebugImporter::RecursionPruner::pruneOrPushTranslationStack(
       return cast<DINodeAttr>(recSelf);
     }
   }
-  return nullptr;
+
+  return lookup(node);
 }
 
 std::pair<DINodeAttr, bool>
@@ -493,39 +367,18 @@ DebugImporter::RecursionPruner::finalizeTranslation(llvm::DINode *node,
   if (DIRecursiveTypeAttrInterface recSelf = state.recSelf) {
     auto recType = cast<DIRecursiveTypeAttrInterface>(result);
     result = cast<DINodeAttr>(recType.withRecId(recSelf.getRecId()));
-
     // Remove this recSelf from the set of unbound selfRefs.
     state.unboundSelfRefs.erase(recSelf);
-
-    // Insert the newly resolved recursive type into the cache entries that
-    // rely on it.
-    // Only need to look at the caches at this level.
-    uint64_t numRemaining = state.cacheSize;
-    for (CachedTranslation &cacheEntry :
-         llvm::make_second_range(llvm::reverse(cache))) {
-      if (numRemaining == 0)
-        break;
-      --numRemaining;
-
-      if (auto refIter = cacheEntry.pendingReplacements.find(recSelf);
-          refIter != cacheEntry.pendingReplacements.end())
-        refIter->second = result;
-    }
   }
 
-  // Insert the current result into the cache.
-  state.cacheSize++;
-  auto [iter, inserted] = cache.try_emplace(node);
-  assert(inserted && "invalid state: caching the same DINode twice");
-  iter->second.attr = result;
-
-  // If this node had any unbound self-refs free when it is registered into
-  // the cache, set up replacement placeholders: This result will need these
-  // unbound self-refs to be replaced before being used.
-  for (DIRecursiveTypeAttrInterface selfRef : state.unboundSelfRefs)
-    iter->second.pendingReplacements.try_emplace(selfRef, nullptr);
-
-  return {result, state.unboundSelfRefs.empty()};
+  // Insert the result into our internal cache if it's not self-contained.
+  if (!state.unboundSelfRefs.empty()) {
+    auto [_, inserted] = dependentCache.try_emplace(
+        node, DependentTranslation{result, state.unboundSelfRefs});
+    assert(inserted && "invalid state: caching the same DINode twice");
+    return {result, false};
+  }
+  return {result, true};
 }
 
 void DebugImporter::RecursionPruner::popTranslationStack(llvm::DINode *node) {
@@ -540,10 +393,7 @@ void DebugImporter::RecursionPruner::popTranslationStack(llvm::DINode *node) {
   if (translationStack.size() == 1) {
     assert(currLayerState.unboundSelfRefs.empty() &&
            "internal error: unbound recursive self reference at top level.");
-    assert(currLayerState.cacheSize == cache.size() &&
-           "internal error: inconsistent cache size");
     translationStack.pop_back();
-    cache.clear();
     return;
   }
 
@@ -551,45 +401,22 @@ void DebugImporter::RecursionPruner::popTranslationStack(llvm::DINode *node) {
   TranslationState &nextLayerState = (++translationStack.rbegin())->second;
   nextLayerState.unboundSelfRefs.insert(currLayerState.unboundSelfRefs.begin(),
                                         currLayerState.unboundSelfRefs.end());
-
-  // The current layer cache is now considered part of the lower layer cache.
-  nextLayerState.cacheSize += currLayerState.cacheSize;
-
-  // Finally pop off this layer when all bookkeeping is done.
   translationStack.pop_back();
 }
 
-std::pair<DINodeAttr, DenseSet<DIRecursiveTypeAttrInterface>>
-DebugImporter::RecursionPruner::lookup(llvm::DINode *node) {
-  auto cacheIter = cache.find(node);
-  if (cacheIter == cache.end())
+DINodeAttr DebugImporter::RecursionPruner::lookup(llvm::DINode *node) {
+  auto cacheIter = dependentCache.find(node);
+  if (cacheIter == dependentCache.end())
     return {};
 
-  CachedTranslation &entry = cacheIter->second;
-
-  if (entry.pendingReplacements.empty())
-    return std::make_pair(entry.attr, DenseSet<DIRecursiveTypeAttrInterface>{});
-
-  Attribute replacedAttr =
-      replaceAndPruneRecursiveTypes(entry.attr, entry.pendingReplacements);
-  DINodeAttr result = cast<DINodeAttr>(replacedAttr);
-
-  // Update cache entry to save replaced version and remove already-applied
-  // replacements.
-  entry.attr = result;
-  DenseSet<DIRecursiveTypeAttrInterface> unboundRefs;
-  DenseSet<DIRecursiveTypeAttrInterface> boundRefs;
-  for (auto [refSelf, replacement] : entry.pendingReplacements) {
-    if (replacement)
-      boundRefs.insert(refSelf);
-    else
-      unboundRefs.insert(refSelf);
-  }
-
-  for (DIRecursiveTypeAttrInterface ref : boundRefs)
-    entry.pendingReplacements.erase(ref);
+  DependentTranslation &entry = cacheIter->second;
+  if (llvm::set_is_subset(entry.unboundSelfRefs,
+                          translationStack.back().second.unboundSelfRefs))
+    return entry.attr;
 
-  return std::make_pair(result, unboundRefs);
+  // Stale cache entry.
+  dependentCache.erase(cacheIter);
+  return {};
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.h b/mlir/lib/Target/LLVMIR/DebugImporter.h
index 20f9b90154a358..8b22dc63456775 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.h
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.h
@@ -116,12 +116,11 @@ class DebugImporter {
   public:
     RecursionPruner(MLIRContext *context) : context(context) {}
 
-    /// If this node was previously cached, returns the cached result.
     /// If this node is a recursive instance that was previously seen, returns a
-    /// self-reference.
-    /// Otherwise, returns null attr, and a translation stack frame is created
-    /// for this node. Expects `finalizeTranslation` & `popTranslationStack`
-    /// to be called on this node later.
+    /// self-reference. If this node was previously cached, returns the cached
+    /// result. Otherwise, returns null attr, and a translation stack frame is
+    /// created for this node. Expects `finalizeTranslation` &
+    /// `popTranslationStack` to be called on this node later.
     DINodeAttr pruneOrPushTranslationStack(llvm::DINode *node);
 
     /// Register the translated result of `node`. Returns the finalized result
@@ -135,33 +134,27 @@ class DebugImporter {
     void popTranslationStack(llvm::DINode *node);
 
   private:
-    /// Returns the cached result (if exists) with all known replacements
-    /// applied. Also returns the set of unbound self-refs that are unresolved
-    /// in this cached result. The cache entry will also be updated with the
-    /// replaced result and with the applied replacements removed from the
-    /// pendingReplacements map.
-    std::pair<DINodeAttr, DenseSet<DIRecursiveTypeAttrInterface>>
-    lookup(llvm::DINode *node);
+    /// Returns the cached result (if it exists) or null.
+    /// The cache entry will be removed if not all of its dependent self-refs
+    /// exist.
+    DINodeAttr lookup(llvm::DINode *node);
 
     MLIRContext *context;
 
-    /// A lazy cached translation that contains the translated attribute as well
-    /// as any unbound self-references that need to be replaced at lookup.
-    struct CachedTranslation {
+    /// A cached translation that contains the translated attribute as well
+    /// as any unbound self-references that it depends on.
+    struct DependentTranslation {
       /// The translated attr. May contain unbound self-references for other
       /// recursive attrs.
       DINodeAttr attr;
-      /// The pending replacements that need to be run on this `attr` before it
-      /// can be used.
-      /// Each key is a recursive self-ref, and each value is a recursive decl
-      /// that may contain additional unbound self-refs that need to be
-      /// replaced. Each replacement will be applied at most once. Once a
-      /// replacement is applied, the cached `attr` will be updated, and the
-      /// replacement will be removed from this map.
-      DenseMap<DIRecursiveTypeAttrInterface, DINodeAttr> pendingReplacements;
+      /// The set of unbound self-refs that this cached entry refers to. All
+      /// these self-refs must exist for the cached entry to be valid.
+      DenseSet<DIRecursiveTypeAttrInterface> unboundSelfRefs;
     };
     /// A mapping between LLVM debug metadata and the corresponding attribute.
-    llvm::MapVector<llvm::DINode *, CachedTranslation> cache;
+    /// Only contains those with unboundSelfRefs. Fully self-contained attrs
+    /// will be cached by the outer main translator.
+    DenseMap<llvm::DINode *, DependentTranslation> dependentCache;
 
     /// Each potentially recursive node will have a TranslationState pushed onto
     /// the `translationStack` to keep track of whether this node is actually
@@ -171,9 +164,6 @@ class DebugImporter {
       /// instance of itself is seen while translating it). Null if this node
       /// has not been seen again deeper in the translation stack.
       DIRecursiveTypeAttrInterface recSelf;
-      /// The number of cache entries belonging to this layer of the translation
-      /// stack. This corresponds to the last `cacheSize` entries of `cache`.
-      uint64_t cacheSize = 0;
       /// All the unbound recursive self references in this layer of the
       /// translation stack.
       DenseSet<DIRecursiveTypeAttrInterface> unboundSelfRefs;
@@ -184,6 +174,11 @@ class DebugImporter {
     /// later when the stack is deeper, the node is recursive, and its
     /// TranslationState is assigned a recSelf.
     llvm::MapVector<llvm::DINode *, TranslationState> translationStack;
+
+    /// A mapping between DINodes that are recursive, and their assigned recId.
+    /// This is kept so that repeated occurrences of the same node can reuse the
+    /// same ID and be deduplicated.
+    DenseMap<llvm::DINode *, DistinctAttr> nodeToRecId;
   };
   RecursionPruner recursionPruner;
 
diff --git a/mlir/test/Target/LLVMIR/Import/debug-info.ll b/mlir/test/Target/LLVMIR/Import/debug-info.ll
index 2218e6cd5b7bbd..e2ef94617dd7d3 100644
--- a/mlir/test/Target/LLVMIR/Import/debug-info.ll
+++ b/mlir/test/Target/LLVMIR/Import/debug-info.ll
@@ -663,19 +663,18 @@ define void @class_field(ptr %arg1) !dbg !18 {
 ; This should result in a cached instance of B --> C --> B_SELF to be reused
 ; when visiting B from C (after visiting B from A).
 
-; CHECK-DAG: #[[C_INNER:.+]] = #llvm.di_composite_type<{{.*}}name = "C", {{.*}}scope = #[[A_SELF:[^ ]+]], {{.*}}elements = #[[C_INNER_TO_B_SELF:.+]]
-; CHECK-DAG: #[[C_INNER_TO_B_SELF:.+]] = #llvm.di_derived_type<{{.*}}name = "->B", {{.*}}baseType = #[[B_SELF:[^ ]+]]>
-; CHECK-DAG: #[[B_TO_C_INNER:.+]] = #llvm.di_derived_type<{{.*}}name = "->C", {{.*}}baseType = #[[C_INNER]]
-; CHECK-DAG: #[[B:.+]] = #llvm.di_composite_type<{{.*}}recId = [[B_RECID:.+]], {{.*}}name = "B", {{.*}}scope = #[[A_SELF]], {{.*}}elements = #[[B_TO_C_INNER]]
-; CHECK-DAG: #[[A_SELF]] = #llvm.di_composite_type<{{.*}}recId = [[A_RECID:.+]]>
-; CHECK-DAG: #[[B_SELF]] = #llvm.di_composite_type<{{.*}}recId = [[B_RECID]]>
+; CHECK-DAG: #[[A:.+]] = #llvm.di_composite_type<{{.*}}name = "A", {{.*}}elements = #[[TO_B_OUTER:.+]], #[[TO_C_OUTER:.+]]>
+; CHECK-DAG: #llvm.di_subprogram<{{.*}}scope = #[[A]],
 
-; CHECK-DAG: #[[TO_B:.+]] = #llvm.di_derived_type<{{.*}}name = "->B", {{.*}}baseType = #[[B]]
-; CHECK-DAG: #[[C_OUTER:.+]] = #llvm.di_composite_type<{{.*}}name = "C", {{.*}}scope = #[[A_SELF]], {{.*}}elements = #[[TO_B]]
-; CHECK-DAG: #[[TO_C:.+]] = #llvm.di_derived_type<{{.*}}name = "->C", {{.*}}baseType = #[[C_OUTER]]
-; CHECK-DAG: #[[A:.+]] = #llvm.di_composite_type<{{.*}}recId = [[A_RECID]], {{.*}}name = "A", {{.*}}elements = #[[TO_B]], #[[TO_C]]
+; CHECK-DAG: #[[TO_B_OUTER]] = #llvm.di_derived_type<{{.*}}name = "->B", {{.*}}baseType = #[[B_OUTER:.+]]>
+; CHECK-DAG: #[[B_OUTER]] = #llvm.di_composite_type<{{.*}}recId = [[B_RECID:.+]], {{.*}}name = "B", {{.*}}elements = #[[TO_C_INNER:.+]]>
+; CHECK-DAG: #[[TO_C_INNER]] = #llvm.di_derived_type<{{.*}}name = "->C", {{.*}}baseType = #[[C_INNER:.+]]>
+; CHECK-DAG: #[[C_INNER]] = #llvm.di_composite_type<{{.*}}name = "C", {{.*}}elements = #[[TO_B_SELF:.+]]>
+; CHECK-DAG: #[[TO_B_SELF]] = #llvm.di_derived_type<{{.*}}name = "->B", {{.*}}baseType = #[[B_SELF:.+]]>
+; CHECK-DAG: #[[B_SELF]] = #llvm.di_composite_type<{{.*}}recId = [[B_RECID]]>
 
-; CHECK-DAG: #llvm.di_subprogram<{{.*}}scope = #[[A]],
+; CHECK-DAG: #[[TO_C_OUTER]] = #llvm.di_derived_type<{{.*}}name = "->C", {{.*}}baseType = #[[C_OUTER:.+]]>
+; CHECK-DAG: #[[C_OUTER]] = #llvm.di_composite_type<{{.*}}name = "C", {{.*}}elements = #[[TO_B_OUTER]]>
 
 define void @class_field(ptr %arg1) !dbg !18 {
   ret void
@@ -688,8 +687,8 @@ define void @class_field(ptr %arg1) !dbg !18 {
 !2 = !DIFile(filename: "debug-info.ll", directory: "/")
 
 !3 = !DICompositeType(tag: DW_TAG_class_type, name: "A", file: !2, line: 42, flags: DIFlagTypePassByReference | DIFlagNonTrivial, elements: !4)
-!5 = !DICompositeType(tag: DW_TAG_class_type, name: "B", scope: !3, file: !2, line: 42, flags: DIFlagTypePassByReference | DIFlagNonTrivial, elements: !10)
-!6 = !DICompositeType(tag: DW_TAG_class_type, name: "C", scope: !3, file: !2, line: 42, flags: DIFlagTypePassByReference | DIFlagNonTrivial, elements: !9)
+!5 = !DICompositeType(tag: DW_TAG_class_type, name: "B", file: !2, line: 42, flags: DIFlagTypePassByReference | DIFlagNonTrivial, elements: !10)
+!6 = !DICompositeType(tag: DW_TAG_class_type, name: "C", file: !2, line: 42, flags: DIFlagTypePassByReference | DIFlagNonTrivial, elements: !9)
 
 !7 = !DIDerivedType(tag: DW_TAG_member, name: "->B", file: !2, baseType: !5)
 !8 = !DIDerivedType(tag: DW_TAG_member, name: "->C", file: !2, baseType: !6)

>From 11ba5d4a4bd67b812e39000c8c59bced94f05274 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Thu, 4 Apr 2024 13:50:59 -0700
Subject: [PATCH 11/12] nit space

---
 mlir/lib/Target/LLVMIR/DebugImporter.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
index 9719b614f60584..3dc2d4e3a7509f 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
@@ -307,6 +307,7 @@ DINodeAttr DebugImporter::translate(llvm::DINode *node) {
 //===----------------------------------------------------------------------===//
 // RecursionPruner
 //===----------------------------------------------------------------------===//
+
 /// Get the `getRecSelf` constructor for the translated type of `node` if its
 /// translated DITypeAttr supports recursion. Otherwise, returns nullptr.
 static function_ref<DIRecursiveTypeAttrInterface(DistinctAttr)>

>From dffb213884d0bb460d3968f519d7b68f8a7cfe5a Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Fri, 5 Apr 2024 11:23:47 -0700
Subject: [PATCH 12/12] allow nested rec-decl

---
 mlir/lib/Target/LLVMIR/DebugTranslation.cpp | 13 ++++------
 mlir/test/Target/LLVMIR/llvmir-debug.mlir   | 28 +++++++++++++++++++++
 2 files changed, 33 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
index 642359a23756af..f6e05e25ace6ae 100644
--- a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
@@ -216,18 +216,15 @@ DebugTranslation::translateImpl(DIGlobalVariableAttr attr) {
 llvm::DIType *
 DebugTranslation::translateRecursive(DIRecursiveTypeAttrInterface attr) {
   DistinctAttr recursiveId = attr.getRecId();
-  if (attr.isRecSelf()) {
-    auto *iter = recursiveTypeMap.find(recursiveId);
-    assert(iter != recursiveTypeMap.end() && "unbound DI recursive self type");
+  if (auto *iter = recursiveTypeMap.find(recursiveId);
+      iter != recursiveTypeMap.end()) {
     return iter->second;
+  } else {
+    assert(!attr.isRecSelf() && "unbound DI recursive self type");
   }
 
   auto setRecursivePlaceholder = [&](llvm::DIType *placeholder) {
-    [[maybe_unused]] auto [iter, inserted] =
-        recursiveTypeMap.try_emplace(recursiveId, placeholder);
-    (void)iter;
-    (void)inserted;
-    assert(inserted && "illegal reuse of recursive id");
+    recursiveTypeMap.try_emplace(recursiveId, placeholder);
   };
 
   llvm::DIType *result =
diff --git a/mlir/test/Target/LLVMIR/llvmir-debug.mlir b/mlir/test/Target/LLVMIR/llvmir-debug.mlir
index 785a525caab8c2..c4ca0e83f81ee3 100644
--- a/mlir/test/Target/LLVMIR/llvmir-debug.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-debug.mlir
@@ -423,3 +423,31 @@ llvm.mlir.global @global_variable() {dbg_expr = #di_global_variable_expression}
 // CHECK: ![[SCOPE]] = !DISubprogram({{.*}}type: ![[SUBROUTINE:[0-9]+]],
 // CHECK: ![[SUBROUTINE]] = !DISubroutineType(types: ![[SR_TYPES:[0-9]+]])
 // CHECK: ![[SR_TYPES]] = !{![[COMP]]}
+
+// -----
+
+// Ensures nested recursive decls work.
+// The output should be the same as if the inner composite type decl were
+// replaced with the recursive self reference.
+
+#di_file = #llvm.di_file<"test.mlir" in "/">
+#di_composite_type_self = #llvm.di_composite_type<tag = DW_TAG_null, recId = distinct[0]<>>
+
+#di_subroutine_type_inner = #llvm.di_subroutine_type<types = #di_composite_type_self>
+#di_subprogram_inner = #llvm.di_subprogram<scope = #di_file, file = #di_file, subprogramFlags = Optimized, type = #di_subroutine_type_inner>
+#di_composite_type_inner = #llvm.di_composite_type<tag = DW_TAG_class_type, recId = distinct[0]<>, scope = #di_subprogram_inner>
+
+#di_subroutine_type = #llvm.di_subroutine_type<types = #di_composite_type_inner>
+#di_subprogram = #llvm.di_subprogram<scope = #di_file, file = #di_file, subprogramFlags = Optimized, type = #di_subroutine_type>
+#di_composite_type = #llvm.di_composite_type<tag = DW_TAG_class_type, recId = distinct[0]<>, scope = #di_subprogram>
+
+#di_global_variable = #llvm.di_global_variable<file = #di_file, line = 1, type = #di_composite_type>
+#di_global_variable_expression = #llvm.di_global_variable_expression<var = #di_global_variable>
+
+llvm.mlir.global @global_variable() {dbg_expr = #di_global_variable_expression} : !llvm.struct<()>
+
+// CHECK: distinct !DIGlobalVariable({{.*}}type: ![[COMP:[0-9]+]],
+// CHECK: ![[COMP]] = distinct !DICompositeType({{.*}}scope: ![[SCOPE:[0-9]+]],
+// CHECK: ![[SCOPE]] = !DISubprogram({{.*}}type: ![[SUBROUTINE:[0-9]+]],
+// CHECK: ![[SUBROUTINE]] = !DISubroutineType(types: ![[SR_TYPES:[0-9]+]])
+// CHECK: ![[SR_TYPES]] = !{![[COMP]]}



More information about the Mlir-commits mailing list