[Mlir-commits] [mlir] [MLIR][LLVM] Recursion importer handle repeated self-references (PR #87295)

Billy Zhu llvmlistbot at llvm.org
Mon Apr 1 16:43:10 PDT 2024


https://github.com/zyx-billy created https://github.com/llvm/llvm-project/pull/87295

Followup to this discussion: https://github.com/llvm/llvm-project/pull/80251#discussion_r1535599920.

The previous debug importer was correct but inefficient. For cases with mutual recursion that contain more than one back-edge, each back-edge would result in a new translated instance. This is because the previous implementation never cached any translated result that contained unbound self-references. As a result, all translation inside a recursive context was performed from scratch, which incurs repeated run-time cost as well as repeated attribute sub-trees in the translated IR (differing only in their `recId`s).

This PR refactors the importer to handle caching inside a recursive context.
- In the presence of unbound self-refs, the translation result is cached with the assumption that any future usage of this cached result must replace these unbound self-refs.
- Any time a recursive type is resolved (translated into a self-contained DINodeAttr), we update the cache so that unbound self-refs that refer to this recursive type are replaced with the self-contained version.
- To improve efficiency, this replacement is deferred until the particular cached entry is retrieved. On retrieval, we apply all pending updates to the cache entry before returning it. This avoids any attribute replacement unless the cached entry is actually used.
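
The lazy-replacement scheme in the list above can be illustrated with a minimal, standalone sketch. This is not the code from the patch: attributes are modeled as plain strings, a self-reference is a token such as `self<B>`, and the names `LazyCache`, `insert`, `resolve`, and `rewrites` are hypothetical and exist only for illustration.

```cpp
#include <map>
#include <optional>
#include <set>
#include <string>
#include <utility>

// Hypothetical stand-ins: an attribute is a plain string and a self-reference
// is a token such as "self<B>" embedded in that string.
struct CachedTranslation {
  std::string attr;
  // Key: unbound self-ref token. Value: its resolved form, once known.
  std::map<std::string, std::optional<std::string>> pending;
};

class LazyCache {
public:
  // Cache a result together with the self-refs it still leaves unbound.
  void insert(const std::string &node, std::string attr,
              const std::set<std::string> &unboundSelfRefs) {
    CachedTranslation &entry = cache[node];
    entry.attr = std::move(attr);
    for (const std::string &ref : unboundSelfRefs)
      entry.pending.emplace(ref, std::nullopt);
  }

  // A recursive type was resolved: only record the replacement here.
  // No cached attribute is rewritten at this point.
  void resolve(const std::string &selfRef, const std::string &resolved) {
    for (auto &kv : cache) {
      auto it = kv.second.pending.find(selfRef);
      if (it != kv.second.pending.end())
        it->second = resolved;
    }
  }

  // Apply the recorded replacements on retrieval. Returns the updated result
  // and the self-refs that are still unbound, or nullopt on a cache miss.
  std::optional<std::pair<std::string, std::set<std::string>>>
  lookup(const std::string &node) {
    auto cacheIt = cache.find(node);
    if (cacheIt == cache.end())
      return std::nullopt;
    CachedTranslation &entry = cacheIt->second;
    std::set<std::string> stillUnbound;
    for (auto it = entry.pending.begin(); it != entry.pending.end();) {
      if (!it->second) {
        stillUnbound.insert(it->first);
        ++it;
        continue;
      }
      replaceAll(entry.attr, it->first, *it->second);
      it = entry.pending.erase(it); // Each replacement is applied at most once.
      ++rewrites;
    }
    return std::make_pair(entry.attr, stillUnbound);
  }

  int rewrites = 0; // Number of replacements actually applied so far.

private:
  // Replace every occurrence of `from` in `text`, skipping over inserted text.
  static void replaceAll(std::string &text, const std::string &from,
                         const std::string &to) {
    for (std::string::size_type pos = text.find(from);
         pos != std::string::npos; pos = text.find(from, pos + to.size()))
      text.replace(pos, from.size(), to);
  }

  std::map<std::string, CachedTranslation> cache;
};
```

In this toy model, calling `resolve` leaves `rewrites` unchanged. The string is rewritten only when `lookup` is called for an entry that depends on the resolved self-ref, and any self-refs that are still unbound are handed back to the caller so they can be propagated to the enclosing layer.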

From 01fd161b6e26f36261ee1076ec8b625e233aecbe Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Mon, 1 Apr 2024 14:56:46 -0700
Subject: [PATCH] create translation helper

---
 mlir/lib/Target/LLVMIR/DebugImporter.cpp     | 331 +++++++++++++++----
 mlir/lib/Target/LLVMIR/DebugImporter.h       |  20 +-
 mlir/test/Target/LLVMIR/Import/debug-info.ll |  41 +++
 3 files changed, 320 insertions(+), 72 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
index 779ad26fc847e6..3de62d960ffc44 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
@@ -25,6 +25,269 @@ using namespace mlir;
 using namespace mlir::LLVM;
 using namespace mlir::LLVM::detail;
 
+/// Get the `getRecSelf` constructor for the translated type of `node` if its
+/// translated DITypeAttr supports recursion. Otherwise, returns nullptr.
+static function_ref<DIRecursiveTypeAttrInterface(DistinctAttr)>
+getRecSelfConstructor(llvm::DINode *node) {
+  using CtorType = function_ref<DIRecursiveTypeAttrInterface(DistinctAttr)>;
+  return TypeSwitch<llvm::DINode *, CtorType>(node)
+      .Case([&](llvm::DICompositeType *) {
+        return CtorType(DICompositeTypeAttr::getRecSelf);
+      })
+      .Default(CtorType());
+}
+
+/// Translation helper for recursive DINodes.
+/// Works alongside a stack-based DINode translator (the "main translator") for
+/// gracefully handling DINodes that are recursive.
+///
+/// Usage:
+/// - Before translating a node, call `tryPrune` to see if the pruner can
+///   preempt this translation. If this is a node that the pruner already knows
+///   how to handle, it will return the translated DINodeAttr.
+/// - After a node is successfully translated by the main translator, call
+///   `finalizeTranslation` to save the translated result with the pruner, and
+///   give it a chance to further modify the result.
+/// - Regardless of success or failure by the main translator, always call
+///   `finally` at the end of translating a node. This is necessary to keep the
+///   internal book-keeping in sync.
+///
+/// This helper maintains an internal cache so that no recursive type will
+/// be translated more than once by the main translator.
+/// This internal cache is different from the cache maintained by the main
+/// translator because it may store nodes that are not self-contained (i.e.
+/// contain unbound recursive self-references).
+class mlir::LLVM::detail::RecursionPruner {
+public:
+  RecursionPruner(MLIRContext *context) : context(context) {}
+
+  /// If this node was previously cached, returns the cached result.
+  /// If this node is a recursive instance that was previously seen, returns a
+  /// self-reference.
+  /// Otherwise, returns null attr.
+  DINodeAttr tryPrune(llvm::DINode *node) {
+    // Lookup the cache first.
+    auto [result, unboundSelfRefs] = lookup(node);
+    if (result) {
+      // Need to inject unbound self-refs into the previous layer.
+      if (!unboundSelfRefs.empty())
+        translationStack.back().second.unboundSelfRefs.insert(
+            unboundSelfRefs.begin(), unboundSelfRefs.end());
+      return result;
+    }
+
+    // If the node type is capable of being recursive, check if it's seen
+    // before.
+    auto recSelfCtor = getRecSelfConstructor(node);
+    if (recSelfCtor) {
+      // If a cyclic dependency is detected since the same node is being
+      // traversed twice, emit a recursive self type, and mark the duplicate
+      // node on the translationStack so it can emit a recursive decl type.
+      auto [iter, inserted] = translationStack.try_emplace(node);
+      if (!inserted) {
+        // The original node may have already been assigned a recursive ID from
+        // a different self-reference. Use that if possible.
+        DIRecursiveTypeAttrInterface recSelf = iter->second.recSelf;
+        if (!recSelf) {
+          DistinctAttr recId = DistinctAttr::create(UnitAttr::get(context));
+          recSelf = recSelfCtor(recId);
+          iter->second.recSelf = recSelf;
+        }
+        // Inject the self-ref into the previous layer.
+        translationStack.back().second.unboundSelfRefs.insert(recSelf);
+        return cast<DINodeAttr>(recSelf);
+      }
+    }
+    return nullptr;
+  }
+
+  /// Register the translated result of `node`. Returns the finalized result
+  /// (with recId if recursive) and whether the result is self-contained
+  /// (i.e. contains no unbound self-refs).
+  std::pair<DINodeAttr, bool> finalizeTranslation(llvm::DINode *node,
+                                                  DINodeAttr result) {
+    // If `node` is not a potentially recursive type, it will not be on the
+    // translation stack. Nothing to set in this case.
+    if (translationStack.empty())
+      return {result, true};
+    if (translationStack.back().first != node)
+      return {result, translationStack.back().second.unboundSelfRefs.empty()};
+
+    TranslationState &state = translationStack.back().second;
+
+    // If this node is actually recursive, set the recId onto `result`.
+    if (DIRecursiveTypeAttrInterface recSelf = state.recSelf) {
+      auto recType = cast<DIRecursiveTypeAttrInterface>(result);
+      result = cast<DINodeAttr>(recType.withRecId(recSelf.getRecId()));
+
+      // Remove this recSelf from the set of unbound selfRefs.
+      state.unboundSelfRefs.erase(recSelf);
+
+      // Insert the newly resolved recursive type into the cache entries that
+      // rely on it.
+      // Only need to look at the caches at this level.
+      uint64_t numRemaining = state.cacheSize;
+      for (auto &cacheEntry : llvm::make_range(cache.rbegin(), cache.rend())) {
+        if (numRemaining == 0)
+          break;
+
+        if (auto refIter = cacheEntry.second.pendingReplacements.find(recSelf);
+            refIter != cacheEntry.second.pendingReplacements.end()) {
+          refIter->second = result;
+        }
+
+        --numRemaining;
+      }
+    }
+
+    // Insert the current result into the cache.
+    state.cacheSize++;
+    auto [iter, inserted] = cache.try_emplace(node);
+    assert(inserted && "invalid state: caching the same DINode twice");
+    iter->second.attr = result;
+
+    // If this node had any unbound self-refs free when it is registered into
+    // the cache, set up replacement placeholders: This result will need these
+    // unbound self-refs to be replaced before being used.
+    for (DIRecursiveTypeAttrInterface selfRef : state.unboundSelfRefs)
+      iter->second.pendingReplacements.try_emplace(selfRef, nullptr);
+
+    return {result, state.unboundSelfRefs.empty()};
+  }
+
+  /// Pop off a frame from the translation stack after a node is done being
+  /// translated.
+  void finally(llvm::DINode *node) {
+    // If `node` is not a potentially recursive type, it will not be on the
+    // translation stack. Nothing to handle in this case.
+    if (translationStack.empty() || translationStack.back().first != node)
+      return;
+
+    // At the end of the stack, all unbound self-refs must be resolved already,
+    // and the entire cache should be accounted for.
+    TranslationState &currLayerState = translationStack.back().second;
+    if (translationStack.size() == 1) {
+      assert(currLayerState.unboundSelfRefs.empty() &&
+             "internal error: unbound recursive self reference at top level.");
+      assert(currLayerState.cacheSize == cache.size() &&
+             "internal error: inconsistent cache size");
+      translationStack.pop_back();
+      cache.clear();
+      return;
+    }
+
+    // Copy unboundSelfRefs down to the previous level.
+    TranslationState &nextLayerState = (++translationStack.rbegin())->second;
+    nextLayerState.unboundSelfRefs.insert(
+        currLayerState.unboundSelfRefs.begin(),
+        currLayerState.unboundSelfRefs.end());
+
+    // The current layer cache is now considered part of the lower layer cache.
+    nextLayerState.cacheSize += currLayerState.cacheSize;
+
+    // Finally pop off this layer when all bookkeeping is done.
+    translationStack.pop_back();
+  }
+
+private:
+  /// Returns the cached result (if exists) with all known replacements applied.
+  /// Also returns the set of unbound self-refs that are unresolved in this
+  /// cached result.
+  /// The cache entry will also be updated with the replaced result and with the
+  /// applied replacements removed from the pendingReplacements map.
+  std::pair<DINodeAttr, llvm::DenseSet<DIRecursiveTypeAttrInterface>>
+  lookup(llvm::DINode *node) {
+    auto cacheIter = cache.find(node);
+    if (cacheIter == cache.end())
+      return {};
+
+    CachedTranslation &entry = cacheIter->second;
+
+    if (entry.pendingReplacements.empty())
+      return std::make_pair(entry.attr,
+                            llvm::DenseSet<DIRecursiveTypeAttrInterface>{});
+
+    mlir::AttrTypeReplacer replacer;
+    replacer.addReplacement(
+        [&entry](DIRecursiveTypeAttrInterface attr)
+            -> std::optional<std::pair<Attribute, WalkResult>> {
+          if (auto replacement = entry.pendingReplacements.lookup(attr)) {
+            // A replacement may contain additional unbound self-refs.
+            return std::make_pair(replacement, mlir::WalkResult::advance());
+          }
+          return std::make_pair(attr, mlir::WalkResult::skip());
+        });
+
+    Attribute replacedAttr = replacer.replace(entry.attr);
+    DINodeAttr result = cast<DINodeAttr>(replacedAttr);
+
+    // Update cache entry to save replaced version and remove already-applied
+    // replacements.
+    entry.attr = result;
+    DenseSet<DIRecursiveTypeAttrInterface> unboundRefs;
+    DenseSet<DIRecursiveTypeAttrInterface> boundRefs;
+    for (auto [refSelf, replacement] : entry.pendingReplacements) {
+      if (replacement)
+        boundRefs.insert(refSelf);
+      else
+        unboundRefs.insert(refSelf);
+    }
+
+    for (DIRecursiveTypeAttrInterface ref : boundRefs)
+      entry.pendingReplacements.erase(ref);
+
+    return std::make_pair(result, unboundRefs);
+  }
+
+  MLIRContext *context;
+
+  /// A lazy cached translation that contains the translated attribute as well
+  /// as any unbound self-references that need to be replaced at lookup.
+  struct CachedTranslation {
+    /// The translated attr. May contain unbound self-references for other
+    /// recursive attrs.
+    DINodeAttr attr;
+    /// The pending replacements that need to be run on this `attr` before it
+    /// can be used.
+    /// Each key is a recursive self-ref, and each value is a recursive decl
+    /// that may contain additional unbound self-refs that need to be replaced.
+    /// Each replacement will be applied at most once. Once a replacement is
+    /// applied, the cached `attr` will be updated, and the replacement will be
+    /// removed from this map.
+    DenseMap<DIRecursiveTypeAttrInterface, DINodeAttr> pendingReplacements;
+  };
+  /// A mapping between LLVM debug metadata and the corresponding attribute.
+  llvm::MapVector<llvm::DINode *, CachedTranslation> cache;
+
+  /// Each potentially recursive node will have a TranslationState pushed onto
+  /// the `translationStack` to keep track of whether this node is actually
+  /// recursive (i.e. has self-references inside), and other book-keeping.
+  struct TranslationState {
+    /// The rec-self if this node is indeed a recursive node (i.e. another
+    /// instance of itself is seen while translating it). Null if this node
+    /// has not been seen again deeper in the translation stack.
+    DIRecursiveTypeAttrInterface recSelf;
+    /// The number of cache entries belonging to this layer of the translation
+    /// stack. This corresponds to the last `cacheSize` entries of `cache`.
+    uint64_t cacheSize = 0;
+    /// All the unbound recursive self references in this layer of the
+    /// translation stack.
+    DenseSet<DIRecursiveTypeAttrInterface> unboundSelfRefs;
+  };
+  /// A stack that stores the metadata nodes that are being traversed. The stack
+  /// is used to handle cyclic dependencies during metadata translation.
+  /// Each node is pushed with an empty TranslationState. If it is ever seen
+  /// later when the stack is deeper, the node is recursive, and its
+  /// TranslationState is assigned a recSelf.
+  llvm::MapVector<llvm::DINode *, TranslationState> translationStack;
+};
+
+DebugImporter::DebugImporter(ModuleOp mlirModule)
+    : recursionPruner(new RecursionPruner(mlirModule.getContext())),
+      context(mlirModule.getContext()), mlirModule(mlirModule) {}
+
+DebugImporter::~DebugImporter() {}
+
 Location DebugImporter::translateFuncLocation(llvm::Function *func) {
   llvm::DISubprogram *subprogram = func->getSubprogram();
   if (!subprogram)
@@ -246,42 +509,12 @@ DINodeAttr DebugImporter::translate(llvm::DINode *node) {
   if (DINodeAttr attr = nodeToAttr.lookup(node))
     return attr;
 
-  // If the node type is capable of being recursive, check if it's seen before.
-  auto recSelfCtor = getRecSelfConstructor(node);
-  if (recSelfCtor) {
-    // If a cyclic dependency is detected since the same node is being traversed
-    // twice, emit a recursive self type, and mark the duplicate node on the
-    // translationStack so it can emit a recursive decl type.
-    auto [iter, inserted] = translationStack.try_emplace(node, nullptr);
-    if (!inserted) {
-      // The original node may have already been assigned a recursive ID from
-      // a different self-reference. Use that if possible.
-      DistinctAttr recId = iter->second;
-      if (!recId) {
-        recId = DistinctAttr::create(UnitAttr::get(context));
-        iter->second = recId;
-      }
-      unboundRecursiveSelfRefs.back().insert(recId);
-      return cast<DINodeAttr>(recSelfCtor(recId));
-    }
-  }
-
-  unboundRecursiveSelfRefs.emplace_back();
-
-  auto guard = llvm::make_scope_exit([&]() {
-    if (recSelfCtor)
-      translationStack.pop_back();
+  // Register with the recursive translator. If it is seen before, return the
+  // result immediately.
+  if (DINodeAttr attr = recursionPruner->tryPrune(node))
+    return attr;
 
-    // Copy unboundRecursiveSelfRefs down to the previous level.
-    if (unboundRecursiveSelfRefs.size() == 1)
-      assert(unboundRecursiveSelfRefs.back().empty() &&
-             "internal error: unbound recursive self reference at top level.");
-    else
-      unboundRecursiveSelfRefs[unboundRecursiveSelfRefs.size() - 2].insert(
-          unboundRecursiveSelfRefs.back().begin(),
-          unboundRecursiveSelfRefs.back().end());
-    unboundRecursiveSelfRefs.pop_back();
-  });
+  auto guard = llvm::make_scope_exit([&]() { recursionPruner->finally(node); });
 
   // Convert the debug metadata if possible.
   auto translateNode = [this](llvm::DINode *node) -> DINodeAttr {
@@ -318,20 +551,12 @@ DINodeAttr DebugImporter::translate(llvm::DINode *node) {
     return nullptr;
   };
   if (DINodeAttr attr = translateNode(node)) {
-    // If this node was marked as recursive, set its recId.
-    if (auto recType = dyn_cast<DIRecursiveTypeAttrInterface>(attr)) {
-      if (DistinctAttr recId = translationStack.lookup(node)) {
-        attr = cast<DINodeAttr>(recType.withRecId(recId));
-        // Remove the unbound recursive ID from the set of unbound self
-        // references in the translation stack.
-        unboundRecursiveSelfRefs.back().erase(recId);
-      }
-    }
-
+    auto [result, isSelfContained] =
+        recursionPruner->finalizeTranslation(node, attr);
     // Only cache fully self-contained nodes.
-    if (unboundRecursiveSelfRefs.back().empty())
-      nodeToAttr.try_emplace(node, attr);
-    return attr;
+    if (isSelfContained)
+      nodeToAttr.try_emplace(node, result);
+    return result;
   }
   return nullptr;
 }
@@ -394,13 +619,3 @@ DistinctAttr DebugImporter::getOrCreateDistinctID(llvm::DINode *node) {
     id = DistinctAttr::create(UnitAttr::get(context));
   return id;
 }
-
-function_ref<DIRecursiveTypeAttrInterface(DistinctAttr)>
-DebugImporter::getRecSelfConstructor(llvm::DINode *node) {
-  using CtorType = function_ref<DIRecursiveTypeAttrInterface(DistinctAttr)>;
-  return TypeSwitch<llvm::DINode *, CtorType>(node)
-      .Case([&](llvm::DICompositeType *concreteNode) {
-        return CtorType(DICompositeTypeAttr::getRecSelf);
-      })
-      .Default(CtorType());
-}
diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.h b/mlir/lib/Target/LLVMIR/DebugImporter.h
index bcf628fc4234fb..1344362ff8d749 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.h
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.h
@@ -27,10 +27,12 @@ class LLVMFuncOp;
 
 namespace detail {
 
+class RecursionPruner;
+
 class DebugImporter {
 public:
-  DebugImporter(ModuleOp mlirModule)
-      : context(mlirModule.getContext()), mlirModule(mlirModule) {}
+  DebugImporter(ModuleOp mlirModule);
+  ~DebugImporter();
 
   /// Translates the given LLVM debug location to an MLIR location.
   Location translateLoc(llvm::DILocation *loc);
@@ -86,24 +88,14 @@ class DebugImporter {
   /// for it, or create a new one if not.
   DistinctAttr getOrCreateDistinctID(llvm::DINode *node);
 
-  /// Get the `getRecSelf` constructor for the translated type of `node` if its
-  /// translated DITypeAttr supports recursion. Otherwise, returns nullptr.
-  function_ref<DIRecursiveTypeAttrInterface(DistinctAttr)>
-  getRecSelfConstructor(llvm::DINode *node);
-
   /// A mapping between LLVM debug metadata and the corresponding attribute.
   DenseMap<llvm::DINode *, DINodeAttr> nodeToAttr;
   /// A mapping between distinct LLVM debug metadata nodes and the corresponding
   /// distinct id attribute.
   DenseMap<llvm::DINode *, DistinctAttr> nodeToDistinctAttr;
 
-  /// A stack that stores the metadata nodes that are being traversed. The stack
-  /// is used to detect cyclic dependencies during the metadata translation.
-  /// A node is pushed with a null value. If it is ever seen twice, it is given
-  /// a recursive id attribute, indicating that it is a recursive node.
-  llvm::MapVector<llvm::DINode *, DistinctAttr> translationStack;
-  /// All the unbound recursive self references in the translation stack.
-  SmallVector<DenseSet<DistinctAttr>> unboundRecursiveSelfRefs;
+  // Translation copilot for recursive types.
+  std::unique_ptr<RecursionPruner> recursionPruner;
 
   MLIRContext *context;
   ModuleOp mlirModule;
diff --git a/mlir/test/Target/LLVMIR/Import/debug-info.ll b/mlir/test/Target/LLVMIR/Import/debug-info.ll
index 959a5a1cd97176..b322840d852611 100644
--- a/mlir/test/Target/LLVMIR/Import/debug-info.ll
+++ b/mlir/test/Target/LLVMIR/Import/debug-info.ll
@@ -607,3 +607,44 @@ declare !dbg !1 void @declaration()
 !0 = !{i32 2, !"Debug Info Version", i32 3}
 !1 = !DISubprogram(name: "declaration", scope: !2, file: !2, flags: DIFlagPrototyped, spFlags: 0)
 !2 = !DIFile(filename: "debug-info.ll", directory: "/")
+
+; // -----
+
+; Ensure that repeated occurrences of a recursive subtree do not result in duplicate MLIR entries.
+
+; CHECK-DAG: #[[B1_INNER:.+]] = #llvm.di_derived_type<{{.*}}name = "B:B1", baseType = #[[B_SELF:.+]]>
+; CHECK-DAG: #[[B2_INNER:.+]] = #llvm.di_derived_type<{{.*}}name = "B:B2", baseType = #[[B_SELF]]>
+; CHECK-DAG: #[[B_INNER:.+]] = #llvm.di_composite_type<{{.*}}recId = [[B_RECID:.+]], {{.*}}name = "B", {{.*}}elements = #[[B1_INNER]], #[[B2_INNER]]
+
+; CHECK-DAG: #[[B1_OUTER:.+]] = #llvm.di_derived_type<{{.*}}name = "B:B1", baseType = #[[B_INNER]]>
+; CHECK-DAG: #[[B2_OUTER:.+]] = #llvm.di_derived_type<{{.*}}name = "B:B2", baseType = #[[B_INNER]]>
+; CHECK-DAG: #[[A_OUTER:.+]] = #llvm.di_composite_type<{{.*}}recId = [[A_RECID:.+]], {{.*}}name = "A", {{.*}}elements = #[[B1_OUTER]], #[[B2_OUTER]]
+
+; CHECK-DAG: #[[A_SELF:.+]] = #llvm.di_composite_type<{{.*}}recId = [[A_RECID]]
+; CHECK-DAG: #[[B_SELF:.+]] = #llvm.di_composite_type<{{.*}}recId = [[B_RECID]]
+
+; CHECK: #llvm.di_subprogram<{{.*}}scope = #[[A_OUTER]]
+
+
+define void @class_field(ptr %arg1) !dbg !18 {
+  ret void
+}
+
+declare void @llvm.dbg.value(metadata, metadata, metadata)
+
+!llvm.dbg.cu = !{!1}
+!llvm.module.flags = !{!0}
+!0 = !{i32 2, !"Debug Info Version", i32 3}
+!1 = distinct !DICompileUnit(language: DW_LANG_C, file: !2)
+!2 = !DIFile(filename: "debug-info.ll", directory: "/")
+
+
+!3 = !DICompositeType(tag: DW_TAG_class_type, name: "A", file: !2, line: 42, flags: DIFlagTypePassByReference | DIFlagNonTrivial, elements: !4)
+!4 = !{!7, !8}
+
+!5 = !DICompositeType(tag: DW_TAG_class_type, name: "B", scope: !3, file: !2, line: 42, flags: DIFlagTypePassByReference | DIFlagNonTrivial, elements: !9)
+!7 = !DIDerivedType(tag: DW_TAG_member, name: "B:B1", file: !2, baseType: !5)
+!8 = !DIDerivedType(tag: DW_TAG_member, name: "B:B2", file: !2, baseType: !5)
+!9 = !{!7, !8}
+
+!18 = distinct !DISubprogram(name: "A", scope: !3, file: !2, spFlags: DISPFlagDefinition, unit: !1)


