[Mlir-commits] [mlir] [MLIR][LLVM] Support Recursive DITypes (PR #80251)

Billy Zhu llvmlistbot at llvm.org
Tue Mar 5 17:46:01 PST 2024


https://github.com/zyx-billy updated https://github.com/llvm/llvm-project/pull/80251

>From b8fd81a71674d7ac7aea53f5c709eb519ef6bb18 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Thu, 25 Jan 2024 15:39:29 -0800
Subject: [PATCH 1/9] generic recursive type with importer impl

---
 .../mlir/Dialect/LLVMIR/LLVMAttrDefs.td       | 69 +++++++++++++-
 mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h  | 35 ++++++-
 mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp      | 22 ++++-
 mlir/lib/Target/LLVMIR/DebugImporter.cpp      | 95 ++++++++++++++-----
 mlir/lib/Target/LLVMIR/DebugImporter.h        | 34 +++++--
 mlir/test/Target/LLVMIR/Import/debug-info.ll  | 56 ++++-------
 6 files changed, 235 insertions(+), 76 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index a831b076fb864f..a9a471800bda3c 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -301,6 +301,70 @@ def LLVM_DIExpressionAttr : LLVM_Attr<"DIExpression", "di_expression"> {
   let assemblyFormat = "`<` ( `[` $operations^ `]` ) : (``)? `>`";
 }
 
+//===----------------------------------------------------------------------===//
+// DIRecursiveTypeAttr
+//===----------------------------------------------------------------------===//
+
+def LLVM_DIRecursiveTypeAttr : LLVM_Attr<"DIRecursiveType", "di_recursive_type",
+                                    /*traits=*/[], "DITypeAttr"> {
+  let description = [{
+    This attribute enables recursive DITypes. There are two modes for this
+    attribute.
+
+    1. If `baseType` is present:
+      - This type is considered a recursive declaration (rec-decl).
+      - The `baseType` is a self-recursive type identified by the `id` field.
+
+    2. If `baseType` is not present:
+      - This type is considered a recursive self reference (rec-self).
+      - This DIRecursiveType itself is a placeholder type that should be
+        conceptually replaced with the closest parent DIRecursiveType with the
+        same `id` field.
+
+    e.g. To represent a linked list struct:
+
+      #rec_self = di_recursive_type<id = distinct[0]<>>
+      #ptr = di_derived_type<baseType = #rec_self, ...>
+      #field = di_derived_type<name = "next", baseType = #ptr, ...>
+      #struct = di_composite_type<name = "Node", elements = #field, ...>
+      #rec = di_recursive_type<id = distinct[0]<>, baseType = #struct>
+
+      #var = di_local_variable<type = #rec, ...>
+
+    Note that a rec-self without an outer rec-decl with the same id is
+    conceptually the same as an "unbound" variable. The context needs to provide
+    meaning to the rec-self.
+
+    This can be avoided by calling the `getUnfoldedBaseType()` method on a
+    rec-decl, which returns the `baseType` with all matching rec-self instances
+    replaced with this rec-decl again. This is useful, for example, for fetching
+    a field out of a recursive struct and maintaining the legality of the field
+    type.
+  }];
+
+  let parameters = (ins
+    "DistinctAttr":$id,
+    OptionalParameter<"DITypeAttr">:$baseType
+  );
+
+  let builders = [
+    AttrBuilderWithInferredContext<(ins "DistinctAttr":$id), [{
+      return $_get(id.getContext(), id, nullptr);
+    }]>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Whether this node represents a self-reference.
+    bool isRecSelf() { return !getBaseType(); }
+
+    /// Get the `baseType` with all instances of the corresponding rec-self
+    /// replaced with this attribute. This can only be called if `!isRecSelf()`.
+    DITypeAttr getUnfoldedBaseType();
+  }];
+
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
 //===----------------------------------------------------------------------===//
 // DINullTypeAttr
 //===----------------------------------------------------------------------===//
@@ -526,14 +590,15 @@ def LLVM_DISubprogramAttr : LLVM_Attr<"DISubprogram", "di_subprogram",
     OptionalParameter<"unsigned">:$line,
     OptionalParameter<"unsigned">:$scopeLine,
     "DISubprogramFlags":$subprogramFlags,
-    OptionalParameter<"DISubroutineTypeAttr">:$type
+    OptionalParameter<"DIRecursiveTypeAttrOf<DISubroutineTypeAttr>">:$type
   );
   let builders = [
     AttrBuilderWithInferredContext<(ins
       "DistinctAttr":$id, "DICompileUnitAttr":$compileUnit,
       "DIScopeAttr":$scope, "StringRef":$name, "StringRef":$linkageName,
       "DIFileAttr":$file, "unsigned":$line, "unsigned":$scopeLine,
-      "DISubprogramFlags":$subprogramFlags, "DISubroutineTypeAttr":$type
+      "DISubprogramFlags":$subprogramFlags,
+      "DIRecursiveTypeAttrOf<DISubroutineTypeAttr>":$type
     ), [{
       MLIRContext *ctx = file.getContext();
       return $_get(ctx, id, compileUnit, scope, StringAttr::get(ctx, name),
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
index c370bfa2b733d6..c9c0ba635b8549 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
@@ -52,9 +52,9 @@ class DILocalScopeAttr : public DIScopeAttr {
 };
 
 /// This class represents a LLVM attribute that describes a debug info type.
-class DITypeAttr : public DINodeAttr {
+class DITypeAttr : public DIScopeAttr {
 public:
-  using DINodeAttr::DINodeAttr;
+  using DIScopeAttr::DIScopeAttr;
 
   /// Support LLVM type casting.
   static bool classof(Attribute attr);
@@ -74,6 +74,10 @@ class TBAANodeAttr : public Attribute {
   }
 };
 
+// Forward declare.
+template <typename BaseType>
+class DIRecursiveTypeAttrOf;
+
 // Inline the LLVM generated Linkage enum and utility.
 // This is only necessary to isolate the "enum generated code" from the
 // attribute definition itself.
@@ -87,4 +91,31 @@ using linkage::Linkage;
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.h.inc"
 
+namespace mlir {
+namespace LLVM {
+/// This class represents either a concrete attr, or a DIRecursiveTypeAttr
+/// containing such a concrete attr, or a rec-self referring to one.
+template <typename BaseType>
+class DIRecursiveTypeAttrOf : public DITypeAttr {
+public:
+  static_assert(std::is_base_of_v<DITypeAttr, BaseType>);
+  using DITypeAttr::DITypeAttr;
+  /// Support LLVM type casting. A rec-self carries no baseType to inspect.
+  static bool classof(Attribute attr) {
+    if (auto rec = llvm::dyn_cast<DIRecursiveTypeAttr>(attr))
+      return rec.isRecSelf() || llvm::isa<BaseType>(rec.getBaseType());
+    return llvm::isa<BaseType>(attr);
+  }
+
+  DIRecursiveTypeAttrOf(BaseType baseType) : DITypeAttr(baseType) {}
+  /// Return the concrete attr, unfolding a rec-decl. Invalid on a rec-self.
+  BaseType getUnfoldedBaseType() {
+    if (auto rec = llvm::dyn_cast<DIRecursiveTypeAttr>(*this))
+      return llvm::cast<BaseType>(rec.getUnfoldedBaseType());
+    return llvm::cast<BaseType>(*this);
+  }
+};
+} // namespace LLVM
+} // namespace mlir
+
 #endif // MLIR_DIALECT_LLVMIR_LLVMATTRS_H_
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index 645a45dd96befb..ed4001170bfe55 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -68,8 +68,8 @@ bool DINodeAttr::classof(Attribute attr) {
 //===----------------------------------------------------------------------===//
 
 bool DIScopeAttr::classof(Attribute attr) {
-  return llvm::isa<DICompileUnitAttr, DICompositeTypeAttr, DIFileAttr,
-                   DILocalScopeAttr, DIModuleAttr, DINamespaceAttr>(attr);
+  return llvm::isa<DICompileUnitAttr, DIFileAttr, DILocalScopeAttr,
+                   DIModuleAttr, DINamespaceAttr, DITypeAttr>(attr);
 }
 
 //===----------------------------------------------------------------------===//
@@ -86,8 +86,9 @@ bool DILocalScopeAttr::classof(Attribute attr) {
 //===----------------------------------------------------------------------===//
 
 bool DITypeAttr::classof(Attribute attr) {
-  return llvm::isa<DINullTypeAttr, DIBasicTypeAttr, DICompositeTypeAttr,
-                   DIDerivedTypeAttr, DISubroutineTypeAttr>(attr);
+  return llvm::isa<DIRecursiveTypeAttr, DINullTypeAttr, DIBasicTypeAttr,
+                   DICompositeTypeAttr, DIDerivedTypeAttr,
+                   DISubroutineTypeAttr>(attr);
 }
 
 //===----------------------------------------------------------------------===//
@@ -185,6 +186,19 @@ void printExpressionArg(AsmPrinter &printer, uint64_t opcode,
   });
 }
 
+//===----------------------------------------------------------------------===//
+// DIRecursiveTypeAttr
+//===----------------------------------------------------------------------===//
+DITypeAttr DIRecursiveTypeAttr::getUnfoldedBaseType() {
+  assert(!isRecSelf() && "cannot get baseType from a rec-self type");
+  return llvm::cast<DITypeAttr>(getBaseType().replace(
+      [&](DIRecursiveTypeAttr rec) -> std::optional<DIRecursiveTypeAttr> {
+        if (rec.getId() == getId())
+          return *this;
+        return std::nullopt;
+      }));
+}
+
 //===----------------------------------------------------------------------===//
 // TargetFeaturesAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
index c631617f973544..506da06906280a 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
@@ -51,10 +51,9 @@ DICompileUnitAttr DebugImporter::translateImpl(llvm::DICompileUnit *node) {
   std::optional<DIEmissionKind> emissionKind =
       symbolizeDIEmissionKind(node->getEmissionKind());
   return DICompileUnitAttr::get(
-      context, DistinctAttr::create(UnitAttr::get(context)),
-      node->getSourceLanguage(), translate(node->getFile()),
-      getStringAttrOrNull(node->getRawProducer()), node->isOptimized(),
-      emissionKind.value());
+      context, getOrCreateDistinctID(node), node->getSourceLanguage(),
+      translate(node->getFile()), getStringAttrOrNull(node->getRawProducer()),
+      node->isOptimized(), emissionKind.value());
 }
 
 DICompositeTypeAttr DebugImporter::translateImpl(llvm::DICompositeType *node) {
@@ -64,11 +63,7 @@ DICompositeTypeAttr DebugImporter::translateImpl(llvm::DICompositeType *node) {
     assert(element && "expected a non-null element type");
     elements.push_back(translate(element));
   }
-  // Drop the elements parameter if a cyclic dependency is detected. We
-  // currently cannot model these cycles and thus drop the parameter if
-  // required. A cyclic dependency is detected if one of the element nodes
-  // translates to a nullptr since the node is already on the translation stack.
-  // TODO: Support debug metadata with cyclic dependencies.
+  // Drop the elements parameter if any of the elements are invalid.
   if (llvm::is_contained(elements, nullptr))
     elements.clear();
   DITypeAttr baseType = translate(node->getBaseType());
@@ -84,7 +79,7 @@ DICompositeTypeAttr DebugImporter::translateImpl(llvm::DICompositeType *node) {
 }
 
 DIDerivedTypeAttr DebugImporter::translateImpl(llvm::DIDerivedType *node) {
-  // Return nullptr if the base type is a cyclic dependency.
+  // Return nullptr if the base type is invalid.
   DITypeAttr baseType = translate(node->getBaseType());
   if (node->getBaseType() && !baseType)
     return nullptr;
@@ -179,14 +174,14 @@ DISubprogramAttr DebugImporter::translateImpl(llvm::DISubprogram *node) {
   // Only definitions require a distinct identifier.
   mlir::DistinctAttr id;
   if (node->isDistinct())
-    id = DistinctAttr::create(UnitAttr::get(context));
+    id = getOrCreateDistinctID(node);
   std::optional<DISubprogramFlags> subprogramFlags =
       symbolizeDISubprogramFlags(node->getSubprogram()->getSPFlags());
-  // Return nullptr if the scope or type is a cyclic dependency.
-  DIScopeAttr scope = translate(node->getScope());
+  // Return nullptr if the scope or type is invalid (translation yields null).
+  DIScopeAttr scope = cast_or_null<DIScopeAttr>(translate(node->getScope()));
   if (node->getScope() && !scope)
     return nullptr;
-  DISubroutineTypeAttr type = translate(node->getType());
+  DIRecursiveTypeAttrOf<DISubroutineTypeAttr> type = translate(node->getType());
   if (node->getType() && !type)
     return nullptr;
   return DISubprogramAttr::get(context, id, translate(node->getUnit()), scope,
@@ -229,7 +224,7 @@ DebugImporter::translateImpl(llvm::DISubroutineType *node) {
     }
     types.push_back(translate(type));
   }
-  // Return nullptr if any of the types is a cyclic dependency.
+  // Return nullptr if any of the types is invalid.
   if (llvm::is_contained(types, nullptr))
     return nullptr;
   return DISubroutineTypeAttr::get(context, node->getCC(), types);
@@ -247,12 +242,47 @@ DINodeAttr DebugImporter::translate(llvm::DINode *node) {
   if (DINodeAttr attr = nodeToAttr.lookup(node))
     return attr;
 
-  // Return nullptr if a cyclic dependency is detected since the same node is
-  // being traversed twice. This check avoids infinite recursion if the debug
-  // metadata contains cycles.
-  if (!translationStack.insert(node))
-    return nullptr;
-  auto guard = llvm::make_scope_exit([&]() { translationStack.pop_back(); });
+  // If a cyclic dependency is detected (the same type node is traversed
+  // twice), emit a recursive self type and mark the duplicate node on the
+  // typeTranslationStack so that it emits a recursive decl type.
+  auto *typeNode = dyn_cast<llvm::DIType>(node);
+  if (typeNode) {
+    auto [iter, inserted] = typeTranslationStack.try_emplace(typeNode, nullptr);
+    if (!inserted) {
+      // The original node may have already been assigned a recursive ID from
+      // a different self-reference. Use that if possible.
+      DistinctAttr recId = iter->second;
+      if (!recId) {
+        recId = DistinctAttr::create(UnitAttr::get(context));
+        iter->second = recId;
+      }
+      unboundRecursiveSelfRefs.back().insert(recId);
+      return DIRecursiveTypeAttr::get(recId);
+    }
+  } else {
+    [[maybe_unused]] bool inserted =
+        nonTypeTranslationStack.insert({node, typeTranslationStack.size()});
+    assert(inserted && "recursion is only supported via DITypes");
+  }
+
+  unboundRecursiveSelfRefs.emplace_back();
+
+  auto guard = llvm::make_scope_exit([&]() {
+    if (typeNode)
+      typeTranslationStack.pop_back();
+    else
+      nonTypeTranslationStack.pop_back();
+
+    // Copy unboundRecursiveSelfRefs down to the previous level.
+    if (unboundRecursiveSelfRefs.size() == 1)
+      assert(unboundRecursiveSelfRefs.back().empty() &&
+             "internal error: unbound recursive self reference at top level.");
+    else
+      unboundRecursiveSelfRefs[unboundRecursiveSelfRefs.size() - 2].insert(
+          unboundRecursiveSelfRefs.back().begin(),
+          unboundRecursiveSelfRefs.back().end());
+    unboundRecursiveSelfRefs.pop_back();
+  });
 
   // Convert the debug metadata if possible.
   auto translateNode = [this](llvm::DINode *node) -> DINodeAttr {
@@ -289,7 +319,21 @@ DINodeAttr DebugImporter::translate(llvm::DINode *node) {
     return nullptr;
   };
   if (DINodeAttr attr = translateNode(node)) {
-    nodeToAttr.insert({node, attr});
+    // If this node was marked as recursive, wrap with a recursive type.
+    if (typeNode) {
+      if (DistinctAttr id = typeTranslationStack.lookup(typeNode)) {
+        DITypeAttr typeAttr = cast<DITypeAttr>(attr);
+        attr = DIRecursiveTypeAttr::get(context, id, typeAttr);
+
+        // This node binds all rec-selves with `id`, so they are no longer
+        // unbound at this level.
+        unboundRecursiveSelfRefs.back().erase(id);
+      }
+    }
+
+    // Only cache fully self-contained nodes.
+    if (unboundRecursiveSelfRefs.back().empty())
+      nodeToAttr.try_emplace(node, attr);
     return attr;
   }
   return nullptr;
@@ -346,3 +390,10 @@ StringAttr DebugImporter::getStringAttrOrNull(llvm::MDString *stringNode) {
     return StringAttr();
   return StringAttr::get(context, stringNode->getString());
 }
+
+DistinctAttr DebugImporter::getOrCreateDistinctID(llvm::DINode *node) {
+  DistinctAttr &id = nodeToDistinctAttr[node];
+  if (!id)
+    id = DistinctAttr::create(UnitAttr::get(context));
+  return id;
+}
diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.h b/mlir/lib/Target/LLVMIR/DebugImporter.h
index 7d4a371284b68b..7e9e28b274b983 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.h
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.h
@@ -52,8 +52,15 @@ class DebugImporter {
   /// Infers the metadata type and translates it to MLIR.
   template <typename DINodeT>
   auto translate(DINodeT *node) {
-    // Infer the MLIR type from the LLVM metadata type.
-    using MLIRTypeT = decltype(translateImpl(node));
+    // Infer the result MLIR type from the LLVM metadata type.
+    // If the result is a DIType, it can also be wrapped in a recursive type,
+    // so the result is wrapped into a DIRecursiveTypeAttrOf.
+    // Otherwise, the exact result type is used.
+    constexpr bool isDIType = std::is_base_of_v<llvm::DIType, DINodeT>;
+    using RawMLIRTypeT = decltype(translateImpl(node));
+    using MLIRTypeT =
+        std::conditional_t<isDIType, DIRecursiveTypeAttrOf<RawMLIRTypeT>,
+                           RawMLIRTypeT>;
     return cast_or_null<MLIRTypeT>(
         translate(static_cast<llvm::DINode *>(node)));
   }
@@ -82,12 +89,27 @@ class DebugImporter {
   /// null attribute otherwise.
   StringAttr getStringAttrOrNull(llvm::MDString *stringNode);
 
+  DistinctAttr getOrCreateDistinctID(llvm::DINode *node);
+
   /// A mapping between LLVM debug metadata and the corresponding attribute.
   DenseMap<llvm::DINode *, DINodeAttr> nodeToAttr;
-
-  /// A stack that stores the metadata nodes that are being traversed. The stack
-  /// is used to detect cyclic dependencies during the metadata translation.
-  SetVector<llvm::DINode *> translationStack;
+  /// A mapping between LLVM debug metadata and the distinct ID attr for DI
+  /// nodes that require distinction.
+  DenseMap<llvm::DINode *, DistinctAttr> nodeToDistinctAttr;
+
+  /// A stack that stores the metadata type nodes that are being traversed. The
+  /// stack is used to detect cyclic dependencies during the metadata
+  /// translation. Nodes are pushed with a null value. If it is ever seen twice,
+  /// it is given a DistinctAttr, indicating that it is a recursive node and
+  /// should take on that DistinctAttr as ID.
+  llvm::MapVector<llvm::DIType *, DistinctAttr> typeTranslationStack;
+  /// All the unbound recursive self references in the translation stack.
+  SmallVector<DenseSet<DistinctAttr>> unboundRecursiveSelfRefs;
+  /// A stack that stores the non-type metadata nodes that are being traversed.
+  /// Each node is associated with the size of the `typeTranslationStack` at the
+  /// time of push. This is used to identify a recursion purely in the non-type
+  /// metadata nodes, which is not supported yet.
+  SetVector<std::pair<llvm::DINode *, unsigned>> nonTypeTranslationStack;
 
   MLIRContext *context;
   ModuleOp mlirModule;
diff --git a/mlir/test/Target/LLVMIR/Import/debug-info.ll b/mlir/test/Target/LLVMIR/Import/debug-info.ll
index 9ef6580bcf2408..032acc8f8c8115 100644
--- a/mlir/test/Target/LLVMIR/Import/debug-info.ll
+++ b/mlir/test/Target/LLVMIR/Import/debug-info.ll
@@ -296,12 +296,16 @@ define void @class_method() {
   ret void, !dbg !9
 }
 
-; Verify the elements parameter is dropped due to the cyclic dependencies.
-; CHECK: #[[COMP:.+]] = #llvm.di_composite_type<tag = DW_TAG_class_type, name = "class_name", file = #{{.*}}, line = 42, flags = "TypePassByReference|NonTrivial">
-; CHECK: #[[COMP_PTR:.+]] = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #[[COMP]], sizeInBits = 64>
+; Verify the cyclic composite type is identified, even though conversion begins from the subprogram type.
+; CHECK: #[[COMP_PTR:.+]] = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #llvm.di_recursive_type<id = [[COMP_ID:.+]]>, sizeInBits = 64>
 ; CHECK: #[[SP_TYPE:.+]] = #llvm.di_subroutine_type<types = #{{.*}}, #[[COMP_PTR]]>
-; CHECK: #[[SP:.+]] = #llvm.di_subprogram<id = distinct[{{.*}}]<>, compileUnit = #{{.*}}, scope = #[[COMP]], name = "class_method", file = #{{.*}}, subprogramFlags = Definition, type = #[[SP_TYPE]]>
-; CHECK: #[[LOC]] = loc(fused<#[[SP]]>
+; CHECK: #[[SP_INNER:.+]] = #llvm.di_subprogram<id = [[SP_ID:.+]], compileUnit = #{{.*}}, scope = #llvm.di_recursive_type<id = [[COMP_ID]]>, name = "class_method", file = #{{.*}}, subprogramFlags = Definition, type = #[[SP_TYPE]]>
+; CHECK: #[[COMP:.+]] = #llvm.di_composite_type<tag = DW_TAG_class_type, name = "class_name", file = #{{.*}}, line = 42, flags = "TypePassByReference|NonTrivial", elements = #[[SP_INNER]]>
+
+; CHECK: #[[COMP_PTR_OUTER:.+]] = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #llvm.di_recursive_type<id = [[COMP_ID]], baseType = #[[COMP]]>, sizeInBits = 64>
+; CHECK: #[[SP_TYPE_OUTER:.+]] = #llvm.di_subroutine_type<types = #{{.*}}, #[[COMP_PTR_OUTER]]>
+; CHECK: #[[SP_OUTER:.+]] = #llvm.di_subprogram<id = [[SP_ID]], compileUnit = #{{.*}}, scope = #llvm.di_recursive_type<id = [[COMP_ID]], baseType = #[[COMP]]>, name = "class_method", file = #{{.*}}, subprogramFlags = Definition, type = #[[SP_TYPE_OUTER]]>
+; CHECK: #[[LOC]] = loc(fused<#[[SP_OUTER]]>
 
 !llvm.dbg.cu = !{!1}
 !llvm.module.flags = !{!0}
@@ -318,15 +322,16 @@ define void @class_method() {
 
 ; // -----
 
-; Verify the elements parameter is dropped due to the cyclic dependencies.
-; CHECK: #[[$COMP:.+]] = #llvm.di_composite_type<tag = DW_TAG_class_type, name = "class_field", file = #{{.*}}, line = 42, flags = "TypePassByReference|NonTrivial">
-; CHECK: #[[$COMP_PTR:.+]] = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #[[$COMP]]>
-; CHECK: #[[$VAR0:.+]] = #llvm.di_local_variable<scope = #{{.*}}, name = "class_field", file = #{{.*}}, type = #[[$COMP_PTR]]>
+; Verify the cyclic composite type is handled correctly.
+; CHECK: #[[FIELD:.+]] = #llvm.di_derived_type<tag = DW_TAG_member, name = "call_field", baseType = #llvm.di_recursive_type<id = [[COMP_ID:.+]]>>
+; CHECK: #[[COMP:.+]] = #llvm.di_composite_type<tag = DW_TAG_class_type, name = "class_field", file = #{{.*}}, line = 42, flags = "TypePassByReference|NonTrivial", elements = #[[FIELD]]>
+; CHECK: #[[COMP_PTR:.+]] = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #[[COMP]]>
+; CHECK: #[[VAR0:.+]] = #llvm.di_local_variable<scope = #{{.*}}, name = "class_field", file = #{{.*}}, type = #llvm.di_recursive_type<id = [[COMP_ID]], baseType = #[[COMP_PTR]]>>
 
-; CHECK-LABEL: @class_field
+; CHECK: @class_field
 ; CHECK-SAME:  %[[ARG0:[a-zA-Z0-9]+]]
 define void @class_field(ptr %arg1) {
-  ; CHECK: llvm.intr.dbg.value #[[$VAR0]] = %[[ARG0]] : !llvm.ptr
+  ; CHECK: llvm.intr.dbg.value #[[VAR0]] = %[[ARG0]] : !llvm.ptr
   call void @llvm.dbg.value(metadata ptr %arg1, metadata !7, metadata !DIExpression()), !dbg !9
   ret void
 }
@@ -563,35 +568,6 @@ define void @func_in_module(ptr %arg) !dbg !8 {
 
 ; // -----
 
-; Verifies that array types that have an unimportable base type are removed to
-; avoid producing invalid IR.
-; CHECK: #[[DI_LOCAL_VAR:.+]] = #llvm.di_local_variable<
-; CHECK-NOT: type =
-
-; CHECK-LABEL: @array_with_cyclic_base_type
-define i32 @array_with_cyclic_base_type(ptr %0) !dbg !3 {
-  call void @llvm.dbg.value(metadata ptr %0, metadata !4, metadata !DIExpression()), !dbg !7
-  ret i32 0
-}
-
-; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
-declare void @llvm.dbg.value(metadata, metadata, metadata)
-
-
-!llvm.module.flags = !{!0}
-!llvm.dbg.cu = !{!1}
-
-!0 = !{i32 2, !"Debug Info Version", i32 3}
-!1 = distinct !DICompileUnit(language: DW_LANG_C, file: !2)
-!2 = !DIFile(filename: "debug-info.ll", directory: "/")
-!3 = distinct !DISubprogram(name: "func", scope: !2, file: !2, line: 46, scopeLine: 48, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !1)
-!4 = !DILocalVariable(name: "op", arg: 5, scope: !3, file: !2, line: 47, type: !5)
-!5 = !DICompositeType(tag: DW_TAG_array_type, size: 42, baseType: !6)
-!6 = !DIDerivedType(tag: DW_TAG_pointer_type, baseType: !5)
-!7 = !DILocation(line: 0, scope: !3)
-
-; // -----
-
 ; Verifies that import compile units respect the distinctness of the input.
 ; CHECK-LABEL: @distinct_cu_func0
 define void @distinct_cu_func0() !dbg !4 {

>From fe4a78050d560464f89f7a9d8888ea99ecc62984 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Mon, 5 Feb 2024 12:22:20 -0800
Subject: [PATCH 2/9] add exporter support for composite type

---
 mlir/lib/Target/LLVMIR/DebugTranslation.cpp | 96 ++++++++++++++++-----
 mlir/lib/Target/LLVMIR/DebugTranslation.h   | 23 ++++-
 mlir/test/Target/LLVMIR/llvmir-debug.mlir   | 60 +++++++++++++
 3 files changed, 156 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
index 420bb8d8274ecb..fec9f9d77318ca 100644
--- a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
@@ -117,11 +117,8 @@ static DINodeT *getDistinctOrUnique(bool isDistinct, Ts &&...args) {
 }
 
 llvm::DICompositeType *
-DebugTranslation::translateImpl(DICompositeTypeAttr attr) {
-  SmallVector<llvm::Metadata *> elements;
-  for (auto member : attr.getElements())
-    elements.push_back(translate(member));
-
+DebugTranslation::translateImpl(DICompositeTypeAttr attr,
+                                SetRecursivePlaceholderFn setRec) {
   // TODO: Use distinct attributes to model this, once they have landed.
   // Depending on the tag, composite types must be distinct.
   bool isDistinct = false;
@@ -133,15 +130,26 @@ DebugTranslation::translateImpl(DICompositeTypeAttr attr) {
     isDistinct = true;
   }
 
-  return getDistinctOrUnique<llvm::DICompositeType>(
-      isDistinct, llvmCtx, attr.getTag(), getMDStringOrNull(attr.getName()),
-      translate(attr.getFile()), attr.getLine(), translate(attr.getScope()),
-      translate(attr.getBaseType()), attr.getSizeInBits(),
-      attr.getAlignInBits(),
-      /*OffsetInBits=*/0,
-      /*Flags=*/static_cast<llvm::DINode::DIFlags>(attr.getFlags()),
-      llvm::MDNode::get(llvmCtx, elements),
-      /*RuntimeLang=*/0, /*VTableHolder=*/nullptr);
+  llvm::DICompositeType *placeholder =
+      getDistinctOrUnique<llvm::DICompositeType>(
+          isDistinct, llvmCtx, attr.getTag(), getMDStringOrNull(attr.getName()),
+          translate(attr.getFile()), attr.getLine(), translate(attr.getScope()),
+          translate(attr.getBaseType()), attr.getSizeInBits(),
+          attr.getAlignInBits(),
+          /*OffsetInBits=*/0,
+          /*Flags=*/static_cast<llvm::DINode::DIFlags>(attr.getFlags()),
+          /*Elements=*/nullptr, /*RuntimeLang=*/0, /*VTableHolder=*/nullptr);
+
+  if (setRec)
+    setRec(placeholder);
+
+  SmallVector<llvm::Metadata *> elements;
+  for (auto member : attr.getElements())
+    elements.push_back(translate(member));
+
+  placeholder->replaceElements(llvm::MDNode::get(llvmCtx, elements));
+
+  return placeholder;
 }
 
 llvm::DIDerivedType *DebugTranslation::translateImpl(DIDerivedTypeAttr attr) {
@@ -201,22 +209,66 @@ DebugTranslation::translateImpl(DIGlobalVariableAttr attr) {
       attr.getIsDefined(), nullptr, nullptr, attr.getAlignInBits(), nullptr);
 }
 
+llvm::DIType *DebugTranslation::translateImpl(DIRecursiveTypeAttr attr) {
+  if (attr.isRecSelf()) {
+    auto iter = recursiveTypeMap.find(attr.getId());
+    assert(iter != recursiveTypeMap.end() && "unbound DI recursive self type");
+    return iter->second;
+  }
+  // rec-decl: the base type's placeholder is bound to `id` via the callback.
+  [[maybe_unused]] size_t recursiveStackSize = recursiveTypeMap.size();
+  auto setRecursivePlaceholderFn = [&](llvm::DIType *node) {
+    [[maybe_unused]] auto [iter, inserted] =
+        recursiveTypeMap.try_emplace(attr.getId(), node);
+    assert(inserted && "illegal reuse of recursive id");
+  };
+  llvm::DIType *node =
+      TypeSwitch<DITypeAttr, llvm::DIType *>(attr.getBaseType())
+          .Case<DICompositeTypeAttr>([&](auto composite) {
+            return translateImpl(composite, setRecursivePlaceholderFn);
+          });
+
+  assert((recursiveStackSize + 1 == recursiveTypeMap.size()) &&
+         "internal inconsistency: unexpected recursive translation stack");
+  recursiveTypeMap.pop_back();
+
+  return node;
+}
+
 llvm::DIScope *DebugTranslation::translateImpl(DIScopeAttr attr) {
   return cast<llvm::DIScope>(translate(DINodeAttr(attr)));
 }
 
 llvm::DISubprogram *DebugTranslation::translateImpl(DISubprogramAttr attr) {
+  if (auto iter = distinctAttrToNode.find(attr.getId());
+      iter != distinctAttrToNode.end())
+    return cast<llvm::DISubprogram>(iter->second);
+
+  llvm::DIScope *scope = translate(attr.getScope());
+  llvm::DIFile *file = translate(attr.getFile());
+  llvm::DIType *type = translate(attr.getType());
+  llvm::DICompileUnit *compileUnit = translate(attr.getCompileUnit());
+
+  // Check again after recursive calls in case this distinct node recurses back
+  // to itself.
+  if (auto iter = distinctAttrToNode.find(attr.getId());
+      iter != distinctAttrToNode.end())
+    return cast<llvm::DISubprogram>(iter->second);
+
   bool isDefinition = static_cast<bool>(attr.getSubprogramFlags() &
                                         LLVM::DISubprogramFlags::Definition);
-  return getDistinctOrUnique<llvm::DISubprogram>(
-      isDefinition, llvmCtx, translate(attr.getScope()),
-      getMDStringOrNull(attr.getName()),
-      getMDStringOrNull(attr.getLinkageName()), translate(attr.getFile()),
-      attr.getLine(), translate(attr.getType()), attr.getScopeLine(),
+  llvm::DISubprogram *node = getDistinctOrUnique<llvm::DISubprogram>(
+      isDefinition, llvmCtx, scope, getMDStringOrNull(attr.getName()),
+      getMDStringOrNull(attr.getLinkageName()), file, attr.getLine(), type,
+      attr.getScopeLine(),
       /*ContainingType=*/nullptr, /*VirtualIndex=*/0,
       /*ThisAdjustment=*/0, llvm::DINode::FlagZero,
       static_cast<llvm::DISubprogram::DISPFlags>(attr.getSubprogramFlags()),
-      translate(attr.getCompileUnit()));
+      compileUnit);
+
+  if (attr.getId())
+    distinctAttrToNode.try_emplace(attr.getId(), node);
+  return node;
 }
 
 llvm::DIModule *DebugTranslation::translateImpl(DIModuleAttr attr) {
@@ -275,8 +327,8 @@ llvm::DINode *DebugTranslation::translate(DINodeAttr attr) {
                 DIDerivedTypeAttr, DIFileAttr, DIGlobalVariableAttr,
                 DILabelAttr, DILexicalBlockAttr, DILexicalBlockFileAttr,
                 DILocalVariableAttr, DIModuleAttr, DINamespaceAttr,
-                DINullTypeAttr, DISubprogramAttr, DISubrangeAttr,
-                DISubroutineTypeAttr>(
+                DINullTypeAttr, DIRecursiveTypeAttr, DISubprogramAttr,
+                DISubrangeAttr, DISubroutineTypeAttr>(
               [&](auto attr) { return translateImpl(attr); });
   attrToNode.insert({attr, node});
   return node;
diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.h b/mlir/lib/Target/LLVMIR/DebugTranslation.h
index 627c6846844983..f4ba687ad564e6 100644
--- a/mlir/lib/Target/LLVMIR/DebugTranslation.h
+++ b/mlir/lib/Target/LLVMIR/DebugTranslation.h
@@ -68,10 +68,22 @@ class DebugTranslation {
   llvm::DIFile *translateFile(StringRef fileName);
 
   /// Translate the given attribute to the corresponding llvm debug metadata.
+  ///
+  /// For attributes corresponding to DITypes that can be recursive (i.e.
+  /// support replacing subelements), an additional optional argument with type
+  /// `SetRecursivePlaceholderFn` should be supported.
+  /// The translation impl for recursive support must follow these three steps:
+  /// 1. Produce a placeholder version of the translated node without calling
+  /// `translate` on any subelements of the MLIR attr.
+  /// 2. Call the SetRecursivePlaceholderFn with the placeholder node.
+  /// 3. Translate subelements recursively using `translate` and fill the
+  /// original placeholder.
+  using SetRecursivePlaceholderFn = llvm::function_ref<void(llvm::DIType *)>;
   llvm::DIType *translateImpl(DINullTypeAttr attr);
   llvm::DIBasicType *translateImpl(DIBasicTypeAttr attr);
   llvm::DICompileUnit *translateImpl(DICompileUnitAttr attr);
-  llvm::DICompositeType *translateImpl(DICompositeTypeAttr attr);
+  llvm::DICompositeType *translateImpl(DICompositeTypeAttr attr,
+                                       SetRecursivePlaceholderFn setRec = {});
   llvm::DIDerivedType *translateImpl(DIDerivedTypeAttr attr);
   llvm::DIFile *translateImpl(DIFileAttr attr);
   llvm::DILabel *translateImpl(DILabelAttr attr);
@@ -82,6 +94,7 @@ class DebugTranslation {
   llvm::DIGlobalVariable *translateImpl(DIGlobalVariableAttr attr);
   llvm::DIModule *translateImpl(DIModuleAttr attr);
   llvm::DINamespace *translateImpl(DINamespaceAttr attr);
+  llvm::DIType *translateImpl(DIRecursiveTypeAttr attr);
   llvm::DIScope *translateImpl(DIScopeAttr attr);
   llvm::DISubprogram *translateImpl(DISubprogramAttr attr);
   llvm::DISubrange *translateImpl(DISubrangeAttr attr);
@@ -102,6 +115,14 @@ class DebugTranslation {
   /// metadata.
   DenseMap<Attribute, llvm::DINode *> attrToNode;
 
+  /// A mapping from DIRecursiveTypeAttr id to the translated DIType.
+  llvm::MapVector<DistinctAttr, llvm::DIType *> recursiveTypeMap;
+
+  /// A mapping between distinct ID attr for DI nodes that require distinction
+  /// and the translate LLVM metadata node. This helps identify attrs that
+  /// should translate into the same LLVM debug node.
+  DenseMap<DistinctAttr, llvm::DINode *> distinctAttrToNode;
+
   /// A mapping between filename and llvm debug file.
   /// TODO: Change this to DenseMap<Identifier, ...> when we can
   /// access the Identifier filename in FileLineColLoc.
diff --git a/mlir/test/Target/LLVMIR/llvmir-debug.mlir b/mlir/test/Target/LLVMIR/llvmir-debug.mlir
index cfd5239515c9c0..446ed0afc39e1a 100644
--- a/mlir/test/Target/LLVMIR/llvmir-debug.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-debug.mlir
@@ -342,3 +342,63 @@ llvm.func @func_line_tables() {
 llvm.func @func_debug_directives() {
   llvm.return
 } loc(fused<#di_subprogram_2>["foo2.mlir":0:0])
+
+// -----
+
+// Ensure recursive types with multiple external references work.
+
+// Common base nodes.
+#di_file = #llvm.di_file<"test.mlir" in "/">
+#di_null_type = #llvm.di_null_type
+#di_compile_unit = #llvm.di_compile_unit<id = distinct[1]<>, sourceLanguage = DW_LANG_C, file = #di_file, isOptimized = false, emissionKind = None>
+
+// Recursive type itself.
+#di_rec_self = #llvm.di_recursive_type<id = distinct[0]<>>
+#di_ptr_inner = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #di_rec_self, sizeInBits = 64>
+#di_subroutine_inner = #llvm.di_subroutine_type<types = #di_null_type, #di_ptr_inner>
+#di_subprogram_inner = #llvm.di_subprogram<
+  id = distinct[2]<>,
+  compileUnit = #di_compile_unit,
+  scope = #di_rec_self,
+  name = "class_method",
+  file = #di_file,
+  subprogramFlags = Definition,
+  type = #di_subroutine_inner>
+#di_struct = #llvm.di_composite_type<
+  tag = DW_TAG_class_type,
+  name = "class_name",
+  file = #di_file,
+  line = 42,
+  flags = "TypePassByReference|NonTrivial",
+  elements = #di_subprogram_inner>
+#di_rec_struct = #llvm.di_recursive_type<id = distinct[0]<>, baseType = #di_struct>
+
+// Outer types referencing the entire recursive type.
+#di_ptr_outer = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #di_rec_struct, sizeInBits = 64>
+#di_subroutine_outer = #llvm.di_subroutine_type<types = #di_null_type, #di_ptr_outer>
+#di_subprogram_outer = #llvm.di_subprogram<
+  id = distinct[2]<>,
+  compileUnit = #di_compile_unit,
+  scope = #di_rec_struct,
+  name = "class_method",
+  file = #di_file,
+  subprogramFlags = Definition,
+  type = #di_subroutine_outer>
+
+#loc3 = loc(fused<#di_subprogram_outer>["test.mlir":1:1])
+
+// CHECK: @class_method
+// CHECK: ret void, !dbg ![[LOC:.*]]
+
+// CHECK: ![[CU:.*]] = distinct !DICompileUnit(
+// CHECK: ![[SP:.*]] = distinct !DISubprogram(name: "class_method", scope: ![[STRUCT:.*]], file: !{{.*}}, type: ![[SUBROUTINE:.*]], spFlags: DISPFlagDefinition, unit: ![[CU]])
+// CHECK: ![[STRUCT]] = distinct !DICompositeType(tag: DW_TAG_class_type, name: "class_name", {{.*}}, elements: ![[ELEMS:.*]])
+// CHECK: ![[ELEMS]] = !{![[SP]]}
+// CHECK: ![[SUBROUTINE]] = !DISubroutineType(types: ![[SUBROUTINE_ELEMS:.*]])
+// CHECK: ![[SUBROUTINE_ELEMS]] = !{null, ![[PTR:.*]]}
+// CHECK: ![[PTR]] = !DIDerivedType(tag: DW_TAG_pointer_type, baseType: ![[STRUCT]], size: 64)
+// CHECK: ![[LOC]] = !DILocation(line: 1, column: 1, scope: ![[SP]])
+
+llvm.func @class_method() {
+  llvm.return loc(#loc3)
+} loc(#loc3)

>From 57b8570aaef8087563120b467540e6cd3cfde289 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Tue, 27 Feb 2024 15:07:14 -0800
Subject: [PATCH 3/9] remove supported failure test

---
 .../Target/LLVMIR/Import/import-failure.ll    | 32 -------------------
 1 file changed, 32 deletions(-)

diff --git a/mlir/test/Target/LLVMIR/Import/import-failure.ll b/mlir/test/Target/LLVMIR/Import/import-failure.ll
index 9a4e939d106516..3f4efab70e1c02 100644
--- a/mlir/test/Target/LLVMIR/Import/import-failure.ll
+++ b/mlir/test/Target/LLVMIR/Import/import-failure.ll
@@ -85,38 +85,6 @@ define void @unsupported_argument(i64 %arg1) {
 
 ; // -----
 
-; Check that debug intrinsics that depend on cyclic metadata are dropped.
-
-declare void @llvm.dbg.value(metadata, metadata, metadata)
-
-; CHECK:      import-failure.ll
-; CHECK-SAME: warning: dropped instruction: call void @llvm.dbg.label(metadata !{{.*}})
-; CHECK:      import-failure.ll
-; CHECK-SAME: warning: dropped intrinsic: call void @llvm.dbg.value(metadata i64 %{{.*}}, metadata !3, metadata !DIExpression())
-define void @cylic_metadata(i64 %arg1) {
-  call void @llvm.dbg.value(metadata i64 %arg1, metadata !10, metadata !DIExpression()), !dbg !14
-  call void @llvm.dbg.label(metadata !13), !dbg !14
-  ret void
-}
-
-!llvm.dbg.cu = !{!1}
-!llvm.module.flags = !{!0}
-!0 = !{i32 2, !"Debug Info Version", i32 3}
-!1 = distinct !DICompileUnit(language: DW_LANG_C, file: !2)
-!2 = !DIFile(filename: "import-failure.ll", directory: "/")
-!3 = !DICompositeType(tag: DW_TAG_array_type, size: 42, baseType: !4)
-!4 = !DIDerivedType(tag: DW_TAG_pointer_type, baseType: !3)
-!5 = distinct !DISubprogram(name: "class_method", scope: !2, file: !2, type: !6, spFlags: DISPFlagDefinition, unit: !1)
-!6 = !DISubroutineType(types: !7)
-!7 = !{!3}
-!10 = !DILocalVariable(scope: !5, name: "arg1", file: !2, line: 1, arg: 1, align: 64);
-!11 = !DILexicalBlock(scope: !5)
-!12 = !DILexicalBlockFile(scope: !11, discriminator: 0)
-!13 = !DILabel(scope: !12, name: "label", file: !2, line: 42)
-!14 = !DILocation(line: 1, column: 2, scope: !5)
-
-; // -----
-
 ; global_dtors with non-null data fields cannot be represented in MLIR.
 ; CHECK:      <unknown>
 ; CHECK-SAME: error: unhandled global variable: @llvm.global_dtors

>From ccd76ece0a14d5340c20843995a3f6988d0f5eec Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Wed, 28 Feb 2024 15:44:40 -0800
Subject: [PATCH 4/9] cleanup

---
 mlir/lib/Target/LLVMIR/DebugImporter.cpp    | 1 -
 mlir/lib/Target/LLVMIR/DebugTranslation.cpp | 5 ++++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
index 506da06906280a..f25a35acfe526d 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
@@ -326,7 +326,6 @@ DINodeAttr DebugImporter::translate(llvm::DINode *node) {
         attr = DIRecursiveTypeAttr::get(context, id, typeAttr);
 
         // Remove the unbound recursive attr.
-        AttrTypeReplacer replacer;
         unboundRecursiveSelfRefs.back().erase(id);
       }
     }
diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
index fec9f9d77318ca..436b31d6f78e0a 100644
--- a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
@@ -130,6 +130,8 @@ DebugTranslation::translateImpl(DICompositeTypeAttr attr,
     isDistinct = true;
   }
 
+  llvm::TempMDTuple placeholderElements =
+      llvm::MDNode::getTemporary(llvmCtx, std::nullopt);
   llvm::DICompositeType *placeholder =
       getDistinctOrUnique<llvm::DICompositeType>(
           isDistinct, llvmCtx, attr.getTag(), getMDStringOrNull(attr.getName()),
@@ -138,7 +140,8 @@ DebugTranslation::translateImpl(DICompositeTypeAttr attr,
           attr.getAlignInBits(),
           /*OffsetInBits=*/0,
           /*Flags=*/static_cast<llvm::DINode::DIFlags>(attr.getFlags()),
-          /*Elements=*/nullptr, /*RuntimeLang=*/0, /*VTableHolder=*/nullptr);
+          /*Elements=*/placeholderElements.get(), /*RuntimeLang=*/0,
+          /*VTableHolder=*/nullptr);
 
   if (setRec)
     setRec(placeholder);

>From 019f0f26527b075df28932dba1b0b72c46a3b5f9 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Wed, 28 Feb 2024 16:25:03 -0800
Subject: [PATCH 5/9] use two-step recursive translation hook

---
 mlir/lib/Target/LLVMIR/DebugTranslation.cpp | 64 +++++++++++----------
 mlir/lib/Target/LLVMIR/DebugTranslation.h   | 27 ++++-----
 2 files changed, 49 insertions(+), 42 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
index 436b31d6f78e0a..73c59282913062 100644
--- a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
@@ -117,8 +117,7 @@ static DINodeT *getDistinctOrUnique(bool isDistinct, Ts &&...args) {
 }
 
 llvm::DICompositeType *
-DebugTranslation::translateImpl(DICompositeTypeAttr attr,
-                                SetRecursivePlaceholderFn setRec) {
+DebugTranslation::translateImplGetPlaceholder(DICompositeTypeAttr attr) {
   // TODO: Use distinct attributes to model this, once they have landed.
   // Depending on the tag, composite types must be distinct.
   bool isDistinct = false;
@@ -132,26 +131,29 @@ DebugTranslation::translateImpl(DICompositeTypeAttr attr,
 
   llvm::TempMDTuple placeholderElements =
       llvm::MDNode::getTemporary(llvmCtx, std::nullopt);
-  llvm::DICompositeType *placeholder =
-      getDistinctOrUnique<llvm::DICompositeType>(
-          isDistinct, llvmCtx, attr.getTag(), getMDStringOrNull(attr.getName()),
-          translate(attr.getFile()), attr.getLine(), translate(attr.getScope()),
-          translate(attr.getBaseType()), attr.getSizeInBits(),
-          attr.getAlignInBits(),
-          /*OffsetInBits=*/0,
-          /*Flags=*/static_cast<llvm::DINode::DIFlags>(attr.getFlags()),
-          /*Elements=*/placeholderElements.get(), /*RuntimeLang=*/0,
-          /*VTableHolder=*/nullptr);
-
-  if (setRec)
-    setRec(placeholder);
+  return getDistinctOrUnique<llvm::DICompositeType>(
+      isDistinct, llvmCtx, attr.getTag(), getMDStringOrNull(attr.getName()),
+      translate(attr.getFile()), attr.getLine(), translate(attr.getScope()),
+      translate(attr.getBaseType()), attr.getSizeInBits(),
+      attr.getAlignInBits(),
+      /*OffsetInBits=*/0,
+      /*Flags=*/static_cast<llvm::DINode::DIFlags>(attr.getFlags()),
+      /*Elements=*/placeholderElements.get(), /*RuntimeLang=*/0,
+      /*VTableHolder=*/nullptr);
+}
 
+void DebugTranslation::translateImplFillPlaceholder(
+    DICompositeTypeAttr attr, llvm::DICompositeType *placeholder) {
   SmallVector<llvm::Metadata *> elements;
   for (auto member : attr.getElements())
     elements.push_back(translate(member));
-
   placeholder->replaceElements(llvm::MDNode::get(llvmCtx, elements));
+}
 
+llvm::DICompositeType *
+DebugTranslation::translateImpl(DICompositeTypeAttr attr) {
+  llvm::DICompositeType *placeholder = translateImplGetPlaceholder(attr);
+  translateImplFillPlaceholder(attr, placeholder);
   return placeholder;
 }
 
@@ -213,29 +215,33 @@ DebugTranslation::translateImpl(DIGlobalVariableAttr attr) {
 }
 
 llvm::DIType *DebugTranslation::translateImpl(DIRecursiveTypeAttr attr) {
+  DistinctAttr recursiveId = attr.getId();
   if (attr.isRecSelf()) {
-    auto *iter = recursiveTypeMap.find(attr.getId());
+    auto *iter = recursiveTypeMap.find(recursiveId);
     assert(iter != recursiveTypeMap.end() && "unbound DI recursive self type");
     return iter->second;
   }
 
-  size_t recursiveStackSize = recursiveTypeMap.size();
-  auto setRecursivePlaceholderFn = [&](llvm::DIType *node) {
-    auto [iter, inserted] = recursiveTypeMap.try_emplace(attr.getId(), node);
-    assert(inserted && "illegal reuse of recursive id");
-  };
-
-  llvm::DIType *node =
+  llvm::DIType *placeholder =
       TypeSwitch<DITypeAttr, llvm::DIType *>(attr.getBaseType())
-          .Case<DICompositeTypeAttr>([&](auto attr) {
-            return translateImpl(attr, setRecursivePlaceholderFn);
-          });
+          .Case<DICompositeTypeAttr>(
+              [&](auto attr) { return translateImplGetPlaceholder(attr); });
+
+  auto [iter, inserted] =
+      recursiveTypeMap.try_emplace(recursiveId, placeholder);
+  assert(inserted && "illegal reuse of recursive id");
+
+  TypeSwitch<DITypeAttr>(attr.getBaseType())
+      .Case<DICompositeTypeAttr>([&](auto attr) {
+        translateImplFillPlaceholder(attr,
+                                     cast<llvm::DICompositeType>(placeholder));
+      });
 
-  assert((recursiveStackSize + 1 == recursiveTypeMap.size()) &&
+  assert(recursiveTypeMap.back().first == recursiveId &&
          "internal inconsistency: unexpected recursive translation stack");
   recursiveTypeMap.pop_back();
 
-  return node;
+  return placeholder;
 }
 
 llvm::DIScope *DebugTranslation::translateImpl(DIScopeAttr attr) {
diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.h b/mlir/lib/Target/LLVMIR/DebugTranslation.h
index f4ba687ad564e6..50a6a6f4112264 100644
--- a/mlir/lib/Target/LLVMIR/DebugTranslation.h
+++ b/mlir/lib/Target/LLVMIR/DebugTranslation.h
@@ -68,22 +68,10 @@ class DebugTranslation {
   llvm::DIFile *translateFile(StringRef fileName);
 
   /// Translate the given attribute to the corresponding llvm debug metadata.
-  ///
-  /// For attributes corresponding to DITypes that can be recursive (i.e.
-  /// supports replacing subelements), an additional optional argument with type
-  /// `SetRecursivePlaceholderFn` should be supported.
-  /// The translation impl for recursive support must follow these three steps:
-  /// 1. Produce a placeholder version of the translated node without calling
-  /// `translate` on any subelements of the MLIR attr.
-  /// 2. Call the SetRecursivePlaceholderFn with the placeholder node.
-  /// 3. Translate subelements recursively using `translate` and fill the
-  /// original placeholder.
-  using SetRecursivePlaceholderFn = llvm::function_ref<void(llvm::DIType *)>;
   llvm::DIType *translateImpl(DINullTypeAttr attr);
   llvm::DIBasicType *translateImpl(DIBasicTypeAttr attr);
   llvm::DICompileUnit *translateImpl(DICompileUnitAttr attr);
-  llvm::DICompositeType *translateImpl(DICompositeTypeAttr attr,
-                                       SetRecursivePlaceholderFn setRec = {});
+  llvm::DICompositeType *translateImpl(DICompositeTypeAttr attr);
   llvm::DIDerivedType *translateImpl(DIDerivedTypeAttr attr);
   llvm::DIFile *translateImpl(DIFileAttr attr);
   llvm::DILabel *translateImpl(DILabelAttr attr);
@@ -101,6 +89,19 @@ class DebugTranslation {
   llvm::DISubroutineType *translateImpl(DISubroutineTypeAttr attr);
   llvm::DIType *translateImpl(DITypeAttr attr);
 
+  /// Attributes that support self recursion need to implement two methods and
+  /// hook into the `translateImpl` method of `DIRecursiveTypeAttr`.
+  /// - `<llvm type> translateImplGetPlaceholder(<mlir type>)`:
+  ///   Translate the DI attr without translating any potentially recursive
+  ///   nested DI attrs.
+  /// - `void translateImplFillPlaceholder(<mlir type>, <llvm type>)`:
+  ///   Given the placeholder returned by `translateImplGetPlaceholder`, fill
+  ///   any holes by recursively translating nested DI attrs. This method must
+  ///   mutate the placeholder that is passed in, instead of creating a new one.
+  llvm::DICompositeType *translateImplGetPlaceholder(DICompositeTypeAttr attr);
+  void translateImplFillPlaceholder(DICompositeTypeAttr attr,
+                                    llvm::DICompositeType *placeholder);
+
   /// Constructs a string metadata node from the string attribute. Returns
   /// nullptr if `stringAttr` is null or contains and empty string.
   llvm::MDString *getMDStringOrNull(StringAttr stringAttr);

>From 650161b7502408e910a4f5ef8fe11184743b9925 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Thu, 29 Feb 2024 22:03:30 -0800
Subject: [PATCH 6/9] Apply suggestions from code review

Co-authored-by: Tobias Gysi <tobias.gysi at nextsilicon.com>
---
 mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td | 2 +-
 mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h     | 2 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp         | 3 ++-
 mlir/lib/Target/LLVMIR/DebugImporter.cpp         | 4 ++--
 mlir/lib/Target/LLVMIR/DebugImporter.h           | 6 ++----
 5 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index a9a471800bda3c..07bf4ba8906e11 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -331,7 +331,7 @@ def LLVM_DIRecursiveTypeAttr : LLVM_Attr<"DIRecursiveType", "di_recursive_type",
 
       #var = di_local_variable<type = #struct_type, ...>
 
-    Note that the a rec-self without an outer rec-decl with the same id is
+    Note that a rec-self without an outer rec-decl with the same id is
     conceptually the same as an "unbound" variable. The context needs to provide
     meaning to the rec-self.
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
index c9c0ba635b8549..c7398d293a71c7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
@@ -107,7 +107,7 @@ class DIRecursiveTypeAttrOf : public DITypeAttr {
     return llvm::isa<BaseType>(attr);
   }
 
-  DIRecursiveTypeAttrOf(BaseType baseType) : DITypeAttr(baseType) {}
+  DIRecursiveTypeAttrOf(BaseType baseTypeAttr) : DITypeAttr(baseTypeAttr) {}
 
   BaseType getUnfoldedBaseType() {
     if (auto rec = llvm::dyn_cast<DIRecursiveTypeAttr>(this))
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index ed4001170bfe55..5275d0ae00573d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -189,9 +189,10 @@ void printExpressionArg(AsmPrinter &printer, uint64_t opcode,
 //===----------------------------------------------------------------------===//
 // DIRecursiveTypeAttr
 //===----------------------------------------------------------------------===//
+
 DITypeAttr DIRecursiveTypeAttr::getUnfoldedBaseType() {
   assert(!isRecSelf() && "cannot get baseType from a rec-self type");
-  return llvm::cast<DITypeAttr>(getBaseType().replace(
+  return cast<DITypeAttr>(getBaseType().replace(
       [&](DIRecursiveTypeAttr rec) -> std::optional<DIRecursiveTypeAttr> {
         if (rec.getId() == getId())
           return *this;
diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
index f25a35acfe526d..dd7762df1723ad 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
@@ -178,7 +178,7 @@ DISubprogramAttr DebugImporter::translateImpl(llvm::DISubprogram *node) {
   std::optional<DISubprogramFlags> subprogramFlags =
       symbolizeDISubprogramFlags(node->getSubprogram()->getSPFlags());
   // Return nullptr if the scope or type is invalid.
-  DIScopeAttr scope = cast<DIScopeAttr>(translate(node->getScope()));
+  auto scope = cast<DIScopeAttr>(translate(node->getScope()));
   if (node->getScope() && !scope)
     return nullptr;
   DIRecursiveTypeAttrOf<DISubroutineTypeAttr> type = translate(node->getType());
@@ -325,7 +325,7 @@ DINodeAttr DebugImporter::translate(llvm::DINode *node) {
         DITypeAttr typeAttr = cast<DITypeAttr>(attr);
         attr = DIRecursiveTypeAttr::get(context, id, typeAttr);
 
-        // Remove the unbound recursive attr.
+        // Remove the unbound recursive DistinctAttr ID.
         unboundRecursiveSelfRefs.back().erase(id);
       }
     }
diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.h b/mlir/lib/Target/LLVMIR/DebugImporter.h
index 7e9e28b274b983..a320e6e4448904 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.h
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.h
@@ -93,15 +93,13 @@ class DebugImporter {
 
   /// A mapping between LLVM debug metadata and the corresponding attribute.
   DenseMap<llvm::DINode *, DINodeAttr> nodeToAttr;
-  /// A mapping between LLVM debug metadata and the distinct ID attr for DI
-  /// nodes that require distinction.
+  /// A mapping between distinct LLVM debug metadata nodes and the corresponding distinct id attribute.
   DenseMap<llvm::DINode *, DistinctAttr> nodeToDistinctAttr;
 
   /// A stack that stores the metadata type nodes that are being traversed. The
   /// stack is used to detect cyclic dependencies during the metadata
   /// translation. Nodes are pushed with a null value. If it is ever seen twice,
-  /// it is given a DistinctAttr, indicating that it is a recursive node and
-  /// should take on that DistinctAttr as ID.
+  /// it is given a DistinctAttr ID, indicating that it is a recursive node.
   llvm::MapVector<llvm::DIType *, DistinctAttr> typeTranslationStack;
   /// All the unbound recursive self references in the translation stack.
   SmallVector<DenseSet<DistinctAttr>> unboundRecursiveSelfRefs;

>From d632957b9cceb4a34cf03f3f75d19e39cdc4bdb8 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Thu, 29 Feb 2024 22:29:21 -0800
Subject: [PATCH 7/9] addressed some local comments

---
 .../include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td |  2 +-
 mlir/lib/Target/LLVMIR/DebugImporter.cpp        |  2 +-
 mlir/lib/Target/LLVMIR/DebugTranslation.h       | 17 +++++++++++------
 3 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 07bf4ba8906e11..519904f584bb96 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -306,7 +306,7 @@ def LLVM_DIExpressionAttr : LLVM_Attr<"DIExpression", "di_expression"> {
 //===----------------------------------------------------------------------===//
 
 def LLVM_DIRecursiveTypeAttr : LLVM_Attr<"DIRecursiveType", "di_recursive_type",
-                                    /*traits=*/[], "DITypeAttr"> {
+                                          /*traits=*/[], "DITypeAttr"> {
   let description = [{
     This attribute enables recursive DITypes. There are two modes for this
     attribute.
diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
index dd7762df1723ad..e8aa0a640e591d 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
@@ -178,7 +178,7 @@ DISubprogramAttr DebugImporter::translateImpl(llvm::DISubprogram *node) {
   std::optional<DISubprogramFlags> subprogramFlags =
       symbolizeDISubprogramFlags(node->getSubprogram()->getSPFlags());
   // Return nullptr if the scope or type is invalid.
-  auto scope = cast<DIScopeAttr>(translate(node->getScope()));
+  auto scope = translate(node->getScope());
   if (node->getScope() && !scope)
     return nullptr;
   DIRecursiveTypeAttrOf<DISubroutineTypeAttr> type = translate(node->getType());
diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.h b/mlir/lib/Target/LLVMIR/DebugTranslation.h
index 50a6a6f4112264..24df966d2a614a 100644
--- a/mlir/lib/Target/LLVMIR/DebugTranslation.h
+++ b/mlir/lib/Target/LLVMIR/DebugTranslation.h
@@ -82,7 +82,6 @@ class DebugTranslation {
   llvm::DIGlobalVariable *translateImpl(DIGlobalVariableAttr attr);
   llvm::DIModule *translateImpl(DIModuleAttr attr);
   llvm::DINamespace *translateImpl(DINamespaceAttr attr);
-  llvm::DIType *translateImpl(DIRecursiveTypeAttr attr);
   llvm::DIScope *translateImpl(DIScopeAttr attr);
   llvm::DISubprogram *translateImpl(DISubprogramAttr attr);
   llvm::DISubrange *translateImpl(DISubrangeAttr attr);
@@ -90,7 +89,7 @@ class DebugTranslation {
   llvm::DIType *translateImpl(DITypeAttr attr);
 
   /// Attributes that support self recursion need to implement two methods and
-  /// hook into the `translateImpl` method of `DIRecursiveTypeAttr`.
+  /// hook into the `translateImpl` overload for `DIRecursiveTypeAttr`.
   /// - `<llvm type> translateImplGetPlaceholder(<mlir type>)`:
   ///   Translate the DI attr without translating any potentially recursive
   ///   nested DI attrs.
@@ -98,7 +97,12 @@ class DebugTranslation {
   ///   Given the placeholder returned by `translateImplGetPlaceholder`, fill
   ///   any holes by recursively translating nested DI attrs. This method must
   ///   mutate the placeholder that is passed in, instead of creating a new one.
+  llvm::DIType *translateImpl(DIRecursiveTypeAttr attr);
+
+  /// Get a placeholder DICompositeType without recursing into the elements.
   llvm::DICompositeType *translateImplGetPlaceholder(DICompositeTypeAttr attr);
+  /// Fill out the DICompositeType placeholder by recursively translating the
+  /// elements.
   void translateImplFillPlaceholder(DICompositeTypeAttr attr,
                                     llvm::DICompositeType *placeholder);
 
@@ -116,12 +120,13 @@ class DebugTranslation {
   /// metadata.
   DenseMap<Attribute, llvm::DINode *> attrToNode;
 
-  /// A mapping from DIRecursiveTypeAttr id to the translated DIType.
+  /// A mapping from DistinctAttr ID of DIRecursiveTypeAttr to the translated
+  /// DIType.
   llvm::MapVector<DistinctAttr, llvm::DIType *> recursiveTypeMap;
 
-  /// A mapping between distinct ID attr for DI nodes that require distinction
-  /// and the translate LLVM metadata node. This helps identify attrs that
-  /// should translate into the same LLVM debug node.
+  /// A mapping between DistinctAttr ID and the translated LLVM metadata node.
+  /// This helps identify attrs that should translate into the same LLVM debug
+  /// node.
   DenseMap<DistinctAttr, llvm::DINode *> distinctAttrToNode;
 
   /// A mapping between filename and llvm debug file.

>From 143658c48dd64b794498f73a48e3f5a952021114 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Thu, 29 Feb 2024 22:39:25 -0800
Subject: [PATCH 8/9] rename id to recId

---
 mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td | 16 ++++++++--------
 mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp         |  4 ++--
 mlir/lib/Target/LLVMIR/DebugTranslation.cpp      |  2 +-
 3 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 519904f584bb96..d0a479d04646ed 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -313,25 +313,25 @@ def LLVM_DIRecursiveTypeAttr : LLVM_Attr<"DIRecursiveType", "di_recursive_type",
 
     1. If `baseType` is present:
       - This type is considered a recursive declaration (rec-decl).
-      - The `baseType` is a self-recursive type identified by the `id` field.
+      - The `baseType` is a self-recursive type identified by the `recId` field.
 
     2. If `baseType` is not present:
       - This type is considered a recursive self reference (rec-self).
       - This DIRecursiveType itself is a placeholder type that should be
         conceptually replaced with the closet parent DIRecursiveType with the
-        same `id` field.
+        same `recId` field.
 
     e.g. To represent a linked list struct:
 
-      #rec_self = di_recursive_type<self_id = 0>
+      #rec_self = di_recursive_type<recId = 0>
       #ptr = di_derived_type<baseType: #rec_self, ...>
       #field = di_derived_type<name = "next", baseType: #ptr, ...>
       #struct = di_composite_type<name = "Node", elements: #field, ...>
-      #rec = di_recursive_type<self_id = 0, baseType: #struct>
+      #rec = di_recursive_type<recId = 0, baseType: #struct>
 
       #var = di_local_variable<type = #struct_type, ...>
 
-    Note that a rec-self without an outer rec-decl with the same id is
+    Note that a rec-self without an outer rec-decl with the same recId is
     conceptually the same as an "unbound" variable. The context needs to provide
     meaning to the rec-self.
 
@@ -343,13 +343,13 @@ def LLVM_DIRecursiveTypeAttr : LLVM_Attr<"DIRecursiveType", "di_recursive_type",
   }];
 
   let parameters = (ins
-    "DistinctAttr":$id,
+    "DistinctAttr":$recId,
     OptionalParameter<"DITypeAttr">:$baseType
   );
 
   let builders = [
-    AttrBuilderWithInferredContext<(ins "DistinctAttr":$id), [{
-      return $_get(id.getContext(), id, nullptr);
+    AttrBuilderWithInferredContext<(ins "DistinctAttr":$recId), [{
+      return $_get(recId.getContext(), recId, nullptr);
     }]>
   ];
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index 5275d0ae00573d..ff240fe4c20411 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -192,9 +192,9 @@ void printExpressionArg(AsmPrinter &printer, uint64_t opcode,
 
 DITypeAttr DIRecursiveTypeAttr::getUnfoldedBaseType() {
   assert(!isRecSelf() && "cannot get baseType from a rec-self type");
-  return cast<DITypeAttr>(getBaseType().replace(
+  return llvm::cast<DITypeAttr>(getBaseType().replace(
       [&](DIRecursiveTypeAttr rec) -> std::optional<DIRecursiveTypeAttr> {
-        if (rec.getId() == getId())
+        if (rec.getRecId() == getRecId())
           return *this;
         return std::nullopt;
       }));
diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
index 73c59282913062..e54974d8e0559a 100644
--- a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
@@ -215,7 +215,7 @@ DebugTranslation::translateImpl(DIGlobalVariableAttr attr) {
 }
 
 llvm::DIType *DebugTranslation::translateImpl(DIRecursiveTypeAttr attr) {
-  DistinctAttr recursiveId = attr.getId();
+  DistinctAttr recursiveId = attr.getRecId();
   if (attr.isRecSelf()) {
     auto *iter = recursiveTypeMap.find(recursiveId);
     assert(iter != recursiveTypeMap.end() && "unbound DI recursive self type");

>From ac053949dab063bcbf6884e6c31d5cb7e1871b20 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Tue, 5 Mar 2024 12:06:57 -0800
Subject: [PATCH 9/9] Replace dedicated type with interface instead

---
 .../mlir/Dialect/LLVMIR/CMakeLists.txt        |   2 +
 .../mlir/Dialect/LLVMIR/LLVMAttrDefs.td       | 110 +++++-------------
 mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h  |  33 +-----
 .../mlir/Dialect/LLVMIR/LLVMInterfaces.td     |  62 +++++++++-
 mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp      |  28 +++--
 mlir/lib/Target/LLVMIR/DebugImporter.cpp      |  57 +++++----
 mlir/lib/Target/LLVMIR/DebugImporter.h        |  36 +++---
 mlir/lib/Target/LLVMIR/DebugTranslation.cpp   |  57 +++++----
 mlir/lib/Target/LLVMIR/DebugTranslation.h     |   2 +-
 mlir/test/Target/LLVMIR/Import/debug-info.ll  |  21 ++--
 mlir/test/Target/LLVMIR/llvmir-debug.mlir     |  12 +-
 11 files changed, 213 insertions(+), 207 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index 862abf00d03450..759de745440c21 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -29,6 +29,8 @@ add_mlir_doc(LLVMIntrinsicOps LLVMIntrinsicOps Dialects/ -gen-op-doc)
 set(LLVM_TARGET_DEFINITIONS LLVMInterfaces.td)
 mlir_tablegen(LLVMInterfaces.h.inc -gen-op-interface-decls)
 mlir_tablegen(LLVMInterfaces.cpp.inc -gen-op-interface-defs)
+mlir_tablegen(LLVMAttrInterfaces.h.inc -gen-attr-interface-decls)
+mlir_tablegen(LLVMAttrInterfaces.cpp.inc -gen-attr-interface-defs)
 mlir_tablegen(LLVMTypeInterfaces.h.inc -gen-type-interface-decls)
 mlir_tablegen(LLVMTypeInterfaces.cpp.inc -gen-type-interface-defs)
 add_public_tablegen_target(MLIRLLVMInterfacesIncGen)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index d0a479d04646ed..97e88f00921785 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -12,6 +12,7 @@
 include "mlir/Dialect/LLVMIR/LLVMDialect.td"
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/CommonAttrConstraints.td"
+include "mlir/Dialect/LLVMIR/LLVMInterfaces.td"
 
 // All of the attributes will extend this class.
 class LLVM_Attr<string name, string attrMnemonic,
@@ -238,7 +239,7 @@ def LoopAnnotationAttr : LLVM_Attr<"LoopAnnotation", "loop_annotation"> {
 //===----------------------------------------------------------------------===//
 
 class LLVM_DIParameter<string summary, string default, string parseName,
-                       string printName = parseName>
+                       string errorCase, string printName = parseName>
     : AttrOrTypeParameter<"unsigned", "debug info " # summary> {
   let parser = [{ [&]() -> FailureOr<unsigned> {
     SMLoc tagLoc = $_parser.getCurrentLocation();
@@ -246,33 +247,34 @@ class LLVM_DIParameter<string summary, string default, string parseName,
     if ($_parser.parseKeyword(&name))
       return failure();
 
-    if (unsigned tag = llvm::dwarf::get}] # parseName # [{(name))
-      return tag;
-    return $_parser.emitError(tagLoc)
-      << "invalid debug info }] # summary # [{ name: " << name;
+    unsigned tag = llvm::dwarf::get}] # parseName # [{(name);
+    if (tag == }] # errorCase # [{)
+      return $_parser.emitError(tagLoc)
+        << "invalid debug info }] # summary # [{ name: " << name;
+    return tag;
   }() }];
   let printer = "$_printer << llvm::dwarf::" # printName # "String($_self)";
   let defaultValue = default;
 }
 
 def LLVM_DICallingConventionParameter : LLVM_DIParameter<
-  "calling convention", /*default=*/"0", "CallingConvention", "Convention"
+  "calling convention", /*default=*/"0", "CallingConvention", "0", "Convention"
 >;
 
 def LLVM_DIEncodingParameter : LLVM_DIParameter<
-  "encoding", /*default=*/"0", "AttributeEncoding"
+  "encoding", /*default=*/"0", "AttributeEncoding", "0"
 >;
 
 def LLVM_DILanguageParameter : LLVM_DIParameter<
-  "language", /*default=*/"", "Language"
+  "language", /*default=*/"", "Language", "0"
 >;
 
 def LLVM_DITagParameter : LLVM_DIParameter<
-  "tag", /*default=*/"", "Tag"
+  "tag", /*default=*/"", "Tag", /*errorCase=*/"llvm::dwarf::DW_TAG_invalid"
 >;
 
 def LLVM_DIOperationEncodingParameter : LLVM_DIParameter<
-  "operation encoding", /*default=*/"", "OperationEncoding"
+  "operation encoding", /*default=*/"", "OperationEncoding", "0"
 >;
 
 //===----------------------------------------------------------------------===//
@@ -301,70 +303,6 @@ def LLVM_DIExpressionAttr : LLVM_Attr<"DIExpression", "di_expression"> {
   let assemblyFormat = "`<` ( `[` $operations^ `]` ) : (``)? `>`";
 }
 
-//===----------------------------------------------------------------------===//
-// DIRecursiveTypeAttr
-//===----------------------------------------------------------------------===//
-
-def LLVM_DIRecursiveTypeAttr : LLVM_Attr<"DIRecursiveType", "di_recursive_type",
-                                          /*traits=*/[], "DITypeAttr"> {
-  let description = [{
-    This attribute enables recursive DITypes. There are two modes for this
-    attribute.
-
-    1. If `baseType` is present:
-      - This type is considered a recursive declaration (rec-decl).
-      - The `baseType` is a self-recursive type identified by the `recId` field.
-
-    2. If `baseType` is not present:
-      - This type is considered a recursive self reference (rec-self).
-      - This DIRecursiveType itself is a placeholder type that should be
-        conceptually replaced with the closet parent DIRecursiveType with the
-        same `recId` field.
-
-    e.g. To represent a linked list struct:
-
-      #rec_self = di_recursive_type<recId = 0>
-      #ptr = di_derived_type<baseType: #rec_self, ...>
-      #field = di_derived_type<name = "next", baseType: #ptr, ...>
-      #struct = di_composite_type<name = "Node", elements: #field, ...>
-      #rec = di_recursive_type<recId = 0, baseType: #struct>
-
-      #var = di_local_variable<type = #struct_type, ...>
-
-    Note that a rec-self without an outer rec-decl with the same recId is
-    conceptually the same as an "unbound" variable. The context needs to provide
-    meaning to the rec-self.
-
-    This can be avoided by calling the `getUnfoldedBaseType()` method on a
-    rec-decl, which returns the `baseType` with all matching rec-self instances
-    replaced with this rec-decl again. This is useful, for example, for fetching
-    a field out of a recursive struct and maintaining the legality of the field
-    type.
-  }];
-
-  let parameters = (ins
-    "DistinctAttr":$recId,
-    OptionalParameter<"DITypeAttr">:$baseType
-  );
-
-  let builders = [
-    AttrBuilderWithInferredContext<(ins "DistinctAttr":$recId), [{
-      return $_get(recId.getContext(), recId, nullptr);
-    }]>
-  ];
-
-  let extraClassDeclaration = [{
-    /// Whether this node represents a self-reference.
-    bool isRecSelf() { return !getBaseType(); }
-
-    /// Get the `baseType` with all instances of the corresponding rec-self
-    /// replaced with this attribute. This can only be called if `!isRecSelf()`.
-    DITypeAttr getUnfoldedBaseType();
-  }];
-
-  let assemblyFormat = "`<` struct(params) `>`";
-}
-
 //===----------------------------------------------------------------------===//
 // DINullTypeAttr
 //===----------------------------------------------------------------------===//
@@ -421,9 +359,11 @@ def LLVM_DICompileUnitAttr : LLVM_Attr<"DICompileUnit", "di_compile_unit",
 //===----------------------------------------------------------------------===//
 
 def LLVM_DICompositeTypeAttr : LLVM_Attr<"DICompositeType", "di_composite_type",
-                                         /*traits=*/[], "DITypeAttr"> {
+                                         [LLVM_DIRecursiveTypeAttrInterface],
+                                         "DITypeAttr"> {
   let parameters = (ins
     LLVM_DITagParameter:$tag,
+    OptionalParameter<"DistinctAttr">:$recId,
     OptionalParameter<"StringAttr">:$name,
     OptionalParameter<"DIFileAttr">:$file,
     OptionalParameter<"uint32_t">:$line,
@@ -435,6 +375,21 @@ def LLVM_DICompositeTypeAttr : LLVM_Attr<"DICompositeType", "di_composite_type",
     OptionalArrayRefParameter<"DINodeAttr">:$elements
   );
   let assemblyFormat = "`<` struct(params) `>`";
+  let extraClassDeclaration = [{
+    /// Requirements of DIRecursiveTypeAttrInterface.
+    /// @{
+
+    /// Get whether this attr describes a recursive self reference.
+    bool isRecSelf() { return getTag() == 0; }
+
+    /// Get a copy of this type attr but with the recursive ID set to `recId`.
+    DIRecursiveTypeAttrInterface withRecId(DistinctAttr recId);
+
+    /// Build a rec-self instance using the provided recId.
+    static DIRecursiveTypeAttrInterface getRecSelf(DistinctAttr recId);
+
+    /// @}
+  }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -590,15 +545,14 @@ def LLVM_DISubprogramAttr : LLVM_Attr<"DISubprogram", "di_subprogram",
     OptionalParameter<"unsigned">:$line,
     OptionalParameter<"unsigned">:$scopeLine,
     "DISubprogramFlags":$subprogramFlags,
-    OptionalParameter<"DIRecursiveTypeAttrOf<DISubroutineTypeAttr>">:$type
+    OptionalParameter<"DISubroutineTypeAttr">:$type
   );
   let builders = [
     AttrBuilderWithInferredContext<(ins
       "DistinctAttr":$id, "DICompileUnitAttr":$compileUnit,
       "DIScopeAttr":$scope, "StringRef":$name, "StringRef":$linkageName,
       "DIFileAttr":$file, "unsigned":$line, "unsigned":$scopeLine,
-      "DISubprogramFlags":$subprogramFlags,
-      "DIRecursiveTypeAttrOf<DISubroutineTypeAttr>":$type
+      "DISubprogramFlags":$subprogramFlags, "DISubroutineTypeAttr":$type
     ), [{
       MLIRContext *ctx = file.getContext();
       return $_get(ctx, id, compileUnit, scope, StringAttr::get(ctx, name),
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
index c7398d293a71c7..ae9cca1ced13de 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
@@ -74,10 +74,6 @@ class TBAANodeAttr : public Attribute {
   }
 };
 
-// Forward declare.
-template <typename BaseType>
-class DIRecursiveTypeAttrOf;
-
 // Inline the LLVM generated Linkage enum and utility.
 // This is only necessary to isolate the "enum generated code" from the
 // attribute definition itself.
@@ -88,34 +84,9 @@ using linkage::Linkage;
 } // namespace LLVM
 } // namespace mlir
 
+#include "mlir/Dialect/LLVMIR/LLVMAttrInterfaces.h.inc"
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.h.inc"
 
-namespace mlir {
-namespace LLVM {
-/// This class represents either a concrete attr, or a DIRecursiveTypeAttr
-/// containing such a concrete attr.
-template <typename BaseType>
-class DIRecursiveTypeAttrOf : public DITypeAttr {
-public:
-  static_assert(std::is_base_of_v<DITypeAttr, BaseType>);
-  using DITypeAttr::DITypeAttr;
-  /// Support LLVM type casting.
-  static bool classof(Attribute attr) {
-    if (auto rec = llvm::dyn_cast<DIRecursiveTypeAttr>(attr))
-      return llvm::isa<BaseType>(rec.getBaseType());
-    return llvm::isa<BaseType>(attr);
-  }
-
-  DIRecursiveTypeAttrOf(BaseType baseTypeAttr) : DITypeAttr(baseTypeAttr) {}
-
-  BaseType getUnfoldedBaseType() {
-    if (auto rec = llvm::dyn_cast<DIRecursiveTypeAttr>(this))
-      return llvm::cast<BaseType>(rec.getUnfoldedBaseType());
-    return llvm::cast<BaseType>(this);
-  }
-};
-} // namespace LLVM
-} // namespace mlir
-
 #endif // MLIR_DIALECT_LLVMIR_LLVMATTRS_H_
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 3b2a132a881e4e..ee14ee8dee1b2f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file defines op and type interfaces for the LLVM dialect in MLIR.
+// This file defines op, type, & attr interfaces for the LLVM dialect in MLIR.
 //
 //===----------------------------------------------------------------------===//
 
@@ -319,4 +319,64 @@ def LLVM_PointerElementTypeInterface
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// LLVM dialect attr interfaces.
+//===----------------------------------------------------------------------===//
+
+def LLVM_DIRecursiveTypeAttrInterface
+  : AttrInterface<"DIRecursiveTypeAttrInterface"> {
+  let description = [{
+    This interface represents a DITypeAttr that may be recursive. Only DITypeAttrs
+    that translate to LLVM DITypes that support mutation should implement this
+    interface.
+
+    There are two modes for conforming attributes:
+
+    1. "rec-decl":
+      - This attr is a recursive declaration identified by a recId.
+
+    2. "rec-self":
+      - This attr is considered a recursive self reference.
+      - This attr itself is a placeholder type that should be conceptually
+        replaced with the closest parent attr of the same type with the same
+        recId.
+
+    e.g. To represent a linked list struct:
+
+      #rec_self = di_composite_type<recId = 0>
+      #ptr = di_derived_type<baseType: #rec_self, ...>
+      #field = di_derived_type<name = "next", baseType: #ptr, ...>
+      #rec = di_composite_type<recId = 0, name = "Node", elements: #field, ...>
+      #var = di_local_variable<type = #rec, ...>
+
+    Note that a rec-self without an outer rec-decl with the same recId is
+    conceptually the same as an "unbound" variable. The context needs to provide
+    meaning to the rec-self.
+
+    Care is therefore needed when extracting a nested type out of a rec-decl
+    (e.g. fetching a field out of a recursive struct): any matching rec-self
+    instances inside the extracted type must be replaced with the rec-decl
+    again to keep the field type well-formed. This interface does not
+    currently provide a helper for that.
+  }];
+  let cppNamespace = "::mlir::LLVM";
+  let methods = [
+    InterfaceMethod<[{
+      Get whether this attr describes a recursive self reference.
+    }], "bool", "isRecSelf", (ins)>,
+    InterfaceMethod<[{
+      Get the recursive ID used for matching "rec-decl" with "rec-self".
+      If this attr instance is not recursive, return a null attribute.
+    }], "DistinctAttr", "getRecId", (ins)>,
+    InterfaceMethod<[{
+      Get a copy of this type attr but with the recursive ID set to `recId`.
+    }], "DIRecursiveTypeAttrInterface", "withRecId",
+    (ins "DistinctAttr":$recId)>,
+    StaticInterfaceMethod<[{
+      Build a rec-self instance using the provided recId.
+    }], "DIRecursiveTypeAttrInterface", "getRecSelf",
+    (ins "DistinctAttr":$recId)>
+  ];
+}
+
 #endif // LLVMIR_INTERFACES
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index ff240fe4c20411..2cafac2f750d87 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -35,6 +35,7 @@ static LogicalResult parseExpressionArg(AsmParser &parser, uint64_t opcode,
 static void printExpressionArg(AsmPrinter &printer, uint64_t opcode,
                                ArrayRef<uint64_t> args);
 
+#include "mlir/Dialect/LLVMIR/LLVMAttrInterfaces.cpp.inc"
 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc"
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc"
@@ -86,9 +87,8 @@ bool DILocalScopeAttr::classof(Attribute attr) {
 //===----------------------------------------------------------------------===//
 
 bool DITypeAttr::classof(Attribute attr) {
-  return llvm::isa<DIRecursiveTypeAttr, DINullTypeAttr, DIBasicTypeAttr,
-                   DICompositeTypeAttr, DIDerivedTypeAttr,
-                   DISubroutineTypeAttr>(attr);
+  return llvm::isa<DINullTypeAttr, DIBasicTypeAttr, DICompositeTypeAttr,
+                   DIDerivedTypeAttr, DISubroutineTypeAttr>(attr);
 }
 
 //===----------------------------------------------------------------------===//
@@ -187,17 +187,21 @@ void printExpressionArg(AsmPrinter &printer, uint64_t opcode,
 }
 
 //===----------------------------------------------------------------------===//
-// DIRecursiveTypeAttr
+// DICompositeTypeAttr
 //===----------------------------------------------------------------------===//
 
-DITypeAttr DIRecursiveTypeAttr::getUnfoldedBaseType() {
-  assert(!isRecSelf() && "cannot get baseType from a rec-self type");
-  return llvm::cast<DITypeAttr>(getBaseType().replace(
-      [&](DIRecursiveTypeAttr rec) -> std::optional<DIRecursiveTypeAttr> {
-        if (rec.getRecId() == getRecId())
-          return *this;
-        return std::nullopt;
-      }));
+DIRecursiveTypeAttrInterface
+DICompositeTypeAttr::withRecId(DistinctAttr recId) {
+  return DICompositeTypeAttr::get(getContext(), getTag(), recId, getName(),
+                                  getFile(), getLine(), getScope(),
+                                  getBaseType(), getFlags(), getSizeInBits(),
+                                  getAlignInBits(), getElements());
+}
+
+DIRecursiveTypeAttrInterface
+DICompositeTypeAttr::getRecSelf(DistinctAttr recId) {
+  return DICompositeTypeAttr::get(recId.getContext(), 0, recId, {}, {}, 0, {},
+                                  {}, DIFlags(), 0, 0, {});
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
index e8aa0a640e591d..506ef890e876e0 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
@@ -13,6 +13,7 @@
 #include "mlir/IR/Location.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/BinaryFormat/Dwarf.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DebugInfoMetadata.h"
@@ -72,9 +73,10 @@ DICompositeTypeAttr DebugImporter::translateImpl(llvm::DICompositeType *node) {
   if (node->getTag() == llvm::dwarf::DW_TAG_array_type && !baseType)
     return nullptr;
   return DICompositeTypeAttr::get(
-      context, node->getTag(), getStringAttrOrNull(node->getRawName()),
-      translate(node->getFile()), node->getLine(), translate(node->getScope()),
-      baseType, flags.value_or(DIFlags::Zero), node->getSizeInBits(),
+      context, node->getTag(), /*recId=*/{},
+      getStringAttrOrNull(node->getRawName()), translate(node->getFile()),
+      node->getLine(), translate(node->getScope()), baseType,
+      flags.value_or(DIFlags::Zero), node->getSizeInBits(),
       node->getAlignInBits(), elements);
 }
 
@@ -178,10 +180,10 @@ DISubprogramAttr DebugImporter::translateImpl(llvm::DISubprogram *node) {
   std::optional<DISubprogramFlags> subprogramFlags =
       symbolizeDISubprogramFlags(node->getSubprogram()->getSPFlags());
   // Return nullptr if the scope or type is invalid.
-  auto scope = translate(node->getScope());
+  DIScopeAttr scope = translate(node->getScope());
   if (node->getScope() && !scope)
     return nullptr;
-  DIRecursiveTypeAttrOf<DISubroutineTypeAttr> type = translate(node->getType());
+  DISubroutineTypeAttr type = translate(node->getType());
   if (node->getType() && !type)
     return nullptr;
   return DISubprogramAttr::get(context, id, translate(node->getUnit()), scope,
@@ -242,12 +244,13 @@ DINodeAttr DebugImporter::translate(llvm::DINode *node) {
   if (DINodeAttr attr = nodeToAttr.lookup(node))
     return attr;
 
-  // If a cyclic dependency is detected since the same node is being traversed
-  // twice, emit a recursive self type, and mark the duplicate node on the
-  // translationStack so it can emit a recursive decl type.
-  auto *typeNode = dyn_cast<llvm::DIType>(node);
-  if (typeNode) {
-    auto [iter, inserted] = typeTranslationStack.try_emplace(typeNode, nullptr);
+  // If the node type is capable of being recursive, check if it's seen before.
+  auto recSelfCtor = getRecSelfConstructor(node);
+  if (recSelfCtor) {
+    // If a cyclic dependency is detected since the same node is being traversed
+    // twice, emit a recursive self type, and mark the duplicate node on the
+    // translationStack so it can emit a recursive decl type.
+    auto [iter, inserted] = translationStack.try_emplace(node, nullptr);
     if (!inserted) {
       // The original node may have already been assigned a recursive ID from
       // a different self-reference. Use that if possible.
@@ -257,21 +260,16 @@ DINodeAttr DebugImporter::translate(llvm::DINode *node) {
         iter->second = recId;
       }
       unboundRecursiveSelfRefs.back().insert(recId);
-      return DIRecursiveTypeAttr::get(recId);
+
+      return cast<DINodeAttr>(recSelfCtor(recId));
     }
-  } else {
-    bool inserted =
-        nonTypeTranslationStack.insert({node, typeTranslationStack.size()});
-    assert(inserted && "recursion is only supported via DITypes");
   }
 
   unboundRecursiveSelfRefs.emplace_back();
 
   auto guard = llvm::make_scope_exit([&]() {
-    if (typeNode)
-      typeTranslationStack.pop_back();
-    else
-      nonTypeTranslationStack.pop_back();
+    if (recSelfCtor)
+      translationStack.pop_back();
 
     // Copy unboundRecursiveSelfRefs down to the previous level.
     if (unboundRecursiveSelfRefs.size() == 1)
@@ -320,11 +318,9 @@ DINodeAttr DebugImporter::translate(llvm::DINode *node) {
   };
   if (DINodeAttr attr = translateNode(node)) {
     // If this node was marked as recursive, wrap with a recursive type.
-    if (typeNode) {
-      if (DistinctAttr id = typeTranslationStack.lookup(typeNode)) {
-        DITypeAttr typeAttr = cast<DITypeAttr>(attr);
-        attr = DIRecursiveTypeAttr::get(context, id, typeAttr);
-
+    if (auto recType = dyn_cast<DIRecursiveTypeAttrInterface>(attr)) {
+      if (DistinctAttr id = translationStack.lookup(node)) {
+        attr = cast<DINodeAttr>(recType.withRecId(id));
         // Remove the unbound recursive DistinctAttr ID.
         unboundRecursiveSelfRefs.back().erase(id);
       }
@@ -396,3 +392,14 @@ DistinctAttr DebugImporter::getOrCreateDistinctID(llvm::DINode *node) {
     id = DistinctAttr::create(UnitAttr::get(context));
   return id;
 }
+
+llvm::function_ref<DIRecursiveTypeAttrInterface(DistinctAttr)>
+DebugImporter::getRecSelfConstructor(llvm::DINode *node) {
+  using CtorType =
+      llvm::function_ref<DIRecursiveTypeAttrInterface(DistinctAttr)>;
+  return TypeSwitch<llvm::DINode *, CtorType>(node)
+      .Case<llvm::DICompositeType>([](auto *concreteNode) {
+        return CtorType(decltype(translateImpl(concreteNode))::getRecSelf);
+      })
+      .Default(CtorType());
+}
diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.h b/mlir/lib/Target/LLVMIR/DebugImporter.h
index a320e6e4448904..c0238732f777c1 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.h
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.h
@@ -52,15 +52,8 @@ class DebugImporter {
   /// Infers the metadata type and translates it to MLIR.
   template <typename DINodeT>
   auto translate(DINodeT *node) {
-    // Infer the result MLIR type from the LLVM metadata type.
-    // If the result is a DIType, it can also be wrapped in a recursive type,
-    // so the result is wrapped into a DIRecursiveTypeAttrOf.
-    // Otherwise, the exact result type is used.
-    constexpr bool isDIType = std::is_base_of_v<llvm::DIType, DINodeT>;
-    using RawMLIRTypeT = decltype(translateImpl(node));
-    using MLIRTypeT =
-        std::conditional_t<isDIType, DIRecursiveTypeAttrOf<RawMLIRTypeT>,
-                           RawMLIRTypeT>;
+    // Infer the MLIR type from the LLVM metadata type.
+    using MLIRTypeT = decltype(translateImpl(node));
     return cast_or_null<MLIRTypeT>(
         translate(static_cast<llvm::DINode *>(node)));
   }
@@ -89,25 +82,28 @@ class DebugImporter {
   /// null attribute otherwise.
   StringAttr getStringAttrOrNull(llvm::MDString *stringNode);
 
+  /// Get the DistinctAttr used to represent `node` if one was already created
+  /// for it, or create a new one if not.
   DistinctAttr getOrCreateDistinctID(llvm::DINode *node);
 
+  /// Get the `getRecSelf` constructor for the translated type of `node` if its
+  /// translated DITypeAttr supports recursion. Otherwise, returns an empty ref.
+  llvm::function_ref<DIRecursiveTypeAttrInterface(DistinctAttr)>
+  getRecSelfConstructor(llvm::DINode *node);
+
   /// A mapping between LLVM debug metadata and the corresponding attribute.
   DenseMap<llvm::DINode *, DINodeAttr> nodeToAttr;
-  /// A mapping between distinct LLVM debug metadata nodes and the corresponding distinct id attribute.
+  /// A mapping between distinct LLVM debug metadata nodes and the corresponding
+  /// distinct id attribute.
   DenseMap<llvm::DINode *, DistinctAttr> nodeToDistinctAttr;
 
-  /// A stack that stores the metadata type nodes that are being traversed. The
-  /// stack is used to detect cyclic dependencies during the metadata
-  /// translation. Nodes are pushed with a null value. If it is ever seen twice,
-  /// it is given a DistinctAttr ID, indicating that it is a recursive node.
-  llvm::MapVector<llvm::DIType *, DistinctAttr> typeTranslationStack;
+  /// A stack that stores the metadata nodes that are being traversed. The stack
+  /// is used to detect cyclic dependencies during the metadata translation.
+  /// Nodes are pushed with a null value. If it is ever seen twice, it is given
+  /// a DistinctAttr ID, indicating that it is a recursive node.
+  llvm::MapVector<llvm::DINode *, DistinctAttr> translationStack;
   /// All the unbound recursive self references in the translation stack.
   SmallVector<DenseSet<DistinctAttr>> unboundRecursiveSelfRefs;
-  /// A stack that stores the non-type metadata nodes that are being traversed.
-  /// Each node is associated with the size of the `typeTranslationStack` at the
-  /// time of push. This is used to identify a recursion purely in the non-type
-  /// metadata nodes, which is not supported yet.
-  SetVector<std::pair<llvm::DINode *, unsigned>> nonTypeTranslationStack;
 
   MLIRContext *context;
   ModuleOp mlirModule;
diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
index e54974d8e0559a..407c56a2e7ed4f 100644
--- a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
@@ -214,7 +214,8 @@ DebugTranslation::translateImpl(DIGlobalVariableAttr attr) {
       attr.getIsDefined(), nullptr, nullptr, attr.getAlignInBits(), nullptr);
 }
 
-llvm::DIType *DebugTranslation::translateImpl(DIRecursiveTypeAttr attr) {
+llvm::DIType *
+DebugTranslation::translateRecursive(DIRecursiveTypeAttrInterface attr) {
   DistinctAttr recursiveId = attr.getRecId();
   if (attr.isRecSelf()) {
     auto *iter = recursiveTypeMap.find(recursiveId);
@@ -222,26 +223,26 @@ llvm::DIType *DebugTranslation::translateImpl(DIRecursiveTypeAttr attr) {
     return iter->second;
   }
 
-  llvm::DIType *placeholder =
-      TypeSwitch<DITypeAttr, llvm::DIType *>(attr.getBaseType())
-          .Case<DICompositeTypeAttr>(
-              [&](auto attr) { return translateImplGetPlaceholder(attr); });
-
-  auto [iter, inserted] =
-      recursiveTypeMap.try_emplace(recursiveId, placeholder);
-  assert(inserted && "illegal reuse of recursive id");
+  auto setRecursivePlaceholder = [&](llvm::DIType *placeholder) {
+    auto [iter, inserted] =
+        recursiveTypeMap.try_emplace(recursiveId, placeholder);
+    assert(inserted && "illegal reuse of recursive id");
+  };
 
-  TypeSwitch<DITypeAttr>(attr.getBaseType())
-      .Case<DICompositeTypeAttr>([&](auto attr) {
-        translateImplFillPlaceholder(attr,
-                                     cast<llvm::DICompositeType>(placeholder));
-      });
+  llvm::DIType *result =
+      TypeSwitch<DIRecursiveTypeAttrInterface, llvm::DIType *>(attr)
+          .Case<DICompositeTypeAttr>([&](auto attr) {
+            auto *placeholder = translateImplGetPlaceholder(attr);
+            setRecursivePlaceholder(placeholder);
+            translateImplFillPlaceholder(attr, placeholder);
+            return placeholder;
+          });
 
   assert(recursiveTypeMap.back().first == recursiveId &&
          "internal inconsistency: unexpected recursive translation stack");
   recursiveTypeMap.pop_back();
 
-  return placeholder;
+  return result;
 }
 
 llvm::DIScope *DebugTranslation::translateImpl(DIScopeAttr attr) {
@@ -330,15 +331,23 @@ llvm::DINode *DebugTranslation::translate(DINodeAttr attr) {
   if (llvm::DINode *node = attrToNode.lookup(attr))
     return node;
 
-  llvm::DINode *node =
-      TypeSwitch<DINodeAttr, llvm::DINode *>(attr)
-          .Case<DIBasicTypeAttr, DICompileUnitAttr, DICompositeTypeAttr,
-                DIDerivedTypeAttr, DIFileAttr, DIGlobalVariableAttr,
-                DILabelAttr, DILexicalBlockAttr, DILexicalBlockFileAttr,
-                DILocalVariableAttr, DIModuleAttr, DINamespaceAttr,
-                DINullTypeAttr, DIRecursiveTypeAttr, DISubprogramAttr,
-                DISubrangeAttr, DISubroutineTypeAttr>(
-              [&](auto attr) { return translateImpl(attr); });
+  llvm::DINode *node = nullptr;
+  // Attrs carrying a recursive ID go through a dedicated handler. All others,
+  // including recursion-capable attrs without a recId, use their own handler.
+  if (auto recTypeAttr = dyn_cast<DIRecursiveTypeAttrInterface>(attr))
+    if (recTypeAttr.getRecId())
+      node = translateRecursive(recTypeAttr);
+
+  if (!node)
+    node = TypeSwitch<DINodeAttr, llvm::DINode *>(attr)
+               .Case<DIBasicTypeAttr, DICompileUnitAttr, DICompositeTypeAttr,
+                     DIDerivedTypeAttr, DIFileAttr, DIGlobalVariableAttr,
+                     DILabelAttr, DILexicalBlockAttr, DILexicalBlockFileAttr,
+                     DILocalVariableAttr, DIModuleAttr, DINamespaceAttr,
+                     DINullTypeAttr, DISubprogramAttr, DISubrangeAttr,
+                     DISubroutineTypeAttr>(
+                   [&](auto attr) { return translateImpl(attr); });
+
   attrToNode.insert({attr, node});
   return node;
 }
diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.h b/mlir/lib/Target/LLVMIR/DebugTranslation.h
index 24df966d2a614a..a1f9abb29c07bc 100644
--- a/mlir/lib/Target/LLVMIR/DebugTranslation.h
+++ b/mlir/lib/Target/LLVMIR/DebugTranslation.h
@@ -97,7 +97,7 @@ class DebugTranslation {
   ///   Given the placeholder returned by `translateImplGetPlaceholder`, fill
   ///   any holes by recursively translating nested DI attrs. This method must
   ///   mutate the placeholder that is passed in, instead of creating a new one.
-  llvm::DIType *translateImpl(DIRecursiveTypeAttr attr);
+  llvm::DIType *translateRecursive(DIRecursiveTypeAttrInterface attr);
 
   /// Get a placeholder DICompositeType without recursing into the elements.
   llvm::DICompositeType *translateImplGetPlaceholder(DICompositeTypeAttr attr);
diff --git a/mlir/test/Target/LLVMIR/Import/debug-info.ll b/mlir/test/Target/LLVMIR/Import/debug-info.ll
index 032acc8f8c8115..9af40d8c8d3ee6 100644
--- a/mlir/test/Target/LLVMIR/Import/debug-info.ll
+++ b/mlir/test/Target/LLVMIR/Import/debug-info.ll
@@ -297,14 +297,15 @@ define void @class_method() {
 }
 
 ; Verify the cyclic composite type is identified, even though conversion begins from the subprogram type.
-; CHECK: #[[COMP_PTR:.+]] = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #llvm.di_recursive_type<id = [[COMP_ID:.+]]>, sizeInBits = 64>
+; CHECK: #[[COMP_SELF:.+]] = #llvm.di_composite_type<tag = DW_TAG_null, recId = [[REC_ID:.+]]>
+; CHECK: #[[COMP_PTR:.+]] = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #[[COMP_SELF]], sizeInBits = 64>
 ; CHECK: #[[SP_TYPE:.+]] = #llvm.di_subroutine_type<types = #{{.*}}, #[[COMP_PTR]]>
-; CHECK: #[[SP_INNER:.+]] = #llvm.di_subprogram<id = [[SP_ID:.+]], compileUnit = #{{.*}}, scope = #llvm.di_recursive_type<id = [[COMP_ID]]>, name = "class_method", file = #{{.*}}, subprogramFlags = Definition, type = #[[SP_TYPE]]>
-; CHECK: #[[COMP:.+]] = #llvm.di_composite_type<tag = DW_TAG_class_type, name = "class_name", file = #{{.*}}, line = 42, flags = "TypePassByReference|NonTrivial", elements = #[[SP_INNER]]>
+; CHECK: #[[SP_INNER:.+]] = #llvm.di_subprogram<id = [[SP_ID:.+]], compileUnit = #{{.*}}, scope = #[[COMP_SELF]], name = "class_method", file = #{{.*}}, subprogramFlags = Definition, type = #[[SP_TYPE]]>
+; CHECK: #[[COMP:.+]] = #llvm.di_composite_type<tag = DW_TAG_class_type, recId = [[REC_ID]], name = "class_name", file = #{{.*}}, line = 42, flags = "TypePassByReference|NonTrivial", elements = #[[SP_INNER]]>
 
-; CHECK: #[[COMP_PTR_OUTER:.+]] = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #llvm.di_recursive_type<id = [[COMP_ID]], baseType = #[[COMP]]>, sizeInBits = 64>
+; CHECK: #[[COMP_PTR_OUTER:.+]] = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #[[COMP]], sizeInBits = 64>
 ; CHECK: #[[SP_TYPE_OUTER:.+]] = #llvm.di_subroutine_type<types = #{{.*}}, #[[COMP_PTR_OUTER]]>
-; CHECK: #[[SP_OUTER:.+]] = #llvm.di_subprogram<id = [[SP_ID]], compileUnit = #{{.*}}, scope = #llvm.di_recursive_type<id = [[COMP_ID]], baseType = #[[COMP]]>, name = "class_method", file = #{{.*}}, subprogramFlags = Definition, type = #[[SP_TYPE_OUTER]]>
+; CHECK: #[[SP_OUTER:.+]] = #llvm.di_subprogram<id = [[SP_ID]], compileUnit = #{{.*}}, scope = #[[COMP]], name = "class_method", file = #{{.*}}, subprogramFlags = Definition, type = #[[SP_TYPE_OUTER]]>
 ; CHECK: #[[LOC]] = loc(fused<#[[SP_OUTER]]>
 
 !llvm.dbg.cu = !{!1}
@@ -323,10 +324,12 @@ define void @class_method() {
 ; // -----
 
 ; Verify the cyclic composite type is handled correctly.
-; CHECK: #[[FIELD:.+]] = #llvm.di_derived_type<tag = DW_TAG_member, name = "call_field", baseType = #llvm.di_recursive_type<id = [[COMP_ID:.+]]>>
-; CHECK: #[[COMP:.+]] = #llvm.di_composite_type<tag = DW_TAG_class_type, name = "class_field", file = #{{.*}}, line = 42, flags = "TypePassByReference|NonTrivial", elements = #[[FIELD]]>
-; CHECK: #[[COMP_PTR:.+]] = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #[[COMP]]>
-; CHECK: #[[VAR0:.+]] = #llvm.di_local_variable<scope = #{{.*}}, name = "class_field", file = #{{.*}}, type = #llvm.di_recursive_type<id = [[COMP_ID]], baseType = #[[COMP_PTR]]>>
+; CHECK: #[[COMP_SELF:.+]] = #llvm.di_composite_type<tag = DW_TAG_null, recId = [[REC_ID:.+]]>
+; CHECK: #[[COMP_PTR_INNER:.+]] = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #[[COMP_SELF]]>
+; CHECK: #[[FIELD:.+]] = #llvm.di_derived_type<tag = DW_TAG_member, name = "call_field", baseType = #[[COMP_PTR_INNER]]>
+; CHECK: #[[COMP:.+]] = #llvm.di_composite_type<tag = DW_TAG_class_type, recId = [[REC_ID]], name = "class_field", file = #{{.*}}, line = 42, flags = "TypePassByReference|NonTrivial", elements = #[[FIELD]]>
+; CHECK: #[[COMP_PTR_OUTER:.+]] = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #[[COMP]]>
+; CHECK: #[[VAR0:.+]] = #llvm.di_local_variable<scope = #{{.*}}, name = "class_field", file = #{{.*}}, type = #[[COMP_PTR_OUTER]]>
 
 ; CHECK: @class_field
 ; CHECK-SAME:  %[[ARG0:[a-zA-Z0-9]+]]
diff --git a/mlir/test/Target/LLVMIR/llvmir-debug.mlir b/mlir/test/Target/LLVMIR/llvmir-debug.mlir
index 446ed0afc39e1a..5d70bff52bb2b9 100644
--- a/mlir/test/Target/LLVMIR/llvmir-debug.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-debug.mlir
@@ -353,33 +353,33 @@ llvm.func @func_debug_directives() {
 #di_compile_unit = #llvm.di_compile_unit<id = distinct[1]<>, sourceLanguage = DW_LANG_C, file = #di_file, isOptimized = false, emissionKind = None>
 
 // Recursive type itself.
-#di_rec_self = #llvm.di_recursive_type<id = distinct[0]<>>
-#di_ptr_inner = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #di_rec_self, sizeInBits = 64>
+#di_struct_self = #llvm.di_composite_type<tag = DW_TAG_null, recId = distinct[0]<>>
+#di_ptr_inner = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #di_struct_self, sizeInBits = 64>
 #di_subroutine_inner = #llvm.di_subroutine_type<types = #di_null_type, #di_ptr_inner>
 #di_subprogram_inner = #llvm.di_subprogram<
   id = distinct[2]<>,
   compileUnit = #di_compile_unit,
-  scope = #di_rec_self,
+  scope = #di_struct_self,
   name = "class_method",
   file = #di_file,
   subprogramFlags = Definition,
   type = #di_subroutine_inner>
 #di_struct = #llvm.di_composite_type<
   tag = DW_TAG_class_type,
+  recId = distinct[0]<>,
   name = "class_name",
   file = #di_file,
   line = 42,
   flags = "TypePassByReference|NonTrivial",
   elements = #di_subprogram_inner>
-#di_rec_struct = #llvm.di_recursive_type<id = distinct[0]<>, baseType = #di_struct>
 
 // Outer types referencing the entire recursive type.
-#di_ptr_outer = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #di_rec_struct, sizeInBits = 64>
+#di_ptr_outer = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #di_struct, sizeInBits = 64>
 #di_subroutine_outer = #llvm.di_subroutine_type<types = #di_null_type, #di_ptr_outer>
 #di_subprogram_outer = #llvm.di_subprogram<
   id = distinct[2]<>,
   compileUnit = #di_compile_unit,
-  scope = #di_rec_struct,
+  scope = #di_struct,
   name = "class_method",
   file = #di_file,
   subprogramFlags = Definition,



More information about the Mlir-commits mailing list