[Mlir-commits] [mlir] [MLIR][LLVM] Support Recursive DITypes (PR #80251)
Billy Zhu
llvmlistbot at llvm.org
Tue Feb 6 11:49:00 PST 2024
https://github.com/zyx-billy updated https://github.com/llvm/llvm-project/pull/80251
>From dc531647d96ce375fd46469c74cbc5e755947bc2 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Thu, 25 Jan 2024 15:39:29 -0800
Subject: [PATCH 1/2] generic recursive type with importer impl
---
.../mlir/Dialect/LLVMIR/LLVMAttrDefs.td | 69 +++++++++++++-
mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h | 35 ++++++-
mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp | 22 ++++-
mlir/lib/Target/LLVMIR/DebugImporter.cpp | 95 ++++++++++++++-----
mlir/lib/Target/LLVMIR/DebugImporter.h | 34 +++++--
mlir/test/Target/LLVMIR/Import/debug-info.ll | 56 ++++-------
6 files changed, 235 insertions(+), 76 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index a831b076fb864..a9a471800bda3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -301,6 +301,70 @@ def LLVM_DIExpressionAttr : LLVM_Attr<"DIExpression", "di_expression"> {
let assemblyFormat = "`<` ( `[` $operations^ `]` ) : (``)? `>`";
}
+//===----------------------------------------------------------------------===//
+// DIRecursiveTypeAttr
+//===----------------------------------------------------------------------===//
+
+def LLVM_DIRecursiveTypeAttr : LLVM_Attr<"DIRecursiveType", "di_recursive_type",
+ /*traits=*/[], "DITypeAttr"> {
+ let description = [{
+ This attribute enables recursive DITypes. There are two modes for this
+ attribute.
+
+ 1. If `baseType` is present:
+ - This type is considered a recursive declaration (rec-decl).
+ - The `baseType` is a self-recursive type identified by the `id` field.
+
+ 2. If `baseType` is not present:
+ - This type is considered a recursive self reference (rec-self).
+ - This DIRecursiveType itself is a placeholder type that should be
+ conceptually replaced with the closest parent DIRecursiveType with the
+ same `id` field.
+
+ e.g. To represent a linked list struct:
+
+ #rec_self = di_recursive_type<id = 0>
+ #ptr = di_derived_type<baseType: #rec_self, ...>
+ #field = di_derived_type<name = "next", baseType: #ptr, ...>
+ #struct = di_composite_type<name = "Node", elements: #field, ...>
+ #rec = di_recursive_type<id = 0, baseType: #struct>
+
+ #var = di_local_variable<type = #rec, ...>
+
+ Note that a rec-self without an outer rec-decl with the same id is
+ conceptually the same as an "unbound" variable. The context needs to provide
+ meaning to the rec-self.
+
+ This can be avoided by calling the `getUnfoldedBaseType()` method on a
+ rec-decl, which returns the `baseType` with all matching rec-self instances
+ replaced with this rec-decl again. This is useful, for example, for fetching
+ a field out of a recursive struct and maintaining the legality of the field
+ type.
+ }];
+
+ let parameters = (ins
+ "DistinctAttr":$id,
+ OptionalParameter<"DITypeAttr">:$baseType
+ );
+
+ let builders = [
+ AttrBuilderWithInferredContext<(ins "DistinctAttr":$id), [{
+ return $_get(id.getContext(), id, nullptr);
+ }]>
+ ];
+
+ let extraClassDeclaration = [{
+ /// Whether this node represents a self-reference.
+ bool isRecSelf() { return !getBaseType(); }
+
+ /// Get the `baseType` with all instances of the corresponding rec-self
+ /// replaced with this attribute. This can only be called if `!isRecSelf()`.
+ DITypeAttr getUnfoldedBaseType();
+ }];
+
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
//===----------------------------------------------------------------------===//
// DINullTypeAttr
//===----------------------------------------------------------------------===//
@@ -526,14 +590,15 @@ def LLVM_DISubprogramAttr : LLVM_Attr<"DISubprogram", "di_subprogram",
OptionalParameter<"unsigned">:$line,
OptionalParameter<"unsigned">:$scopeLine,
"DISubprogramFlags":$subprogramFlags,
- OptionalParameter<"DISubroutineTypeAttr">:$type
+ OptionalParameter<"DIRecursiveTypeAttrOf<DISubroutineTypeAttr>">:$type
);
let builders = [
AttrBuilderWithInferredContext<(ins
"DistinctAttr":$id, "DICompileUnitAttr":$compileUnit,
"DIScopeAttr":$scope, "StringRef":$name, "StringRef":$linkageName,
"DIFileAttr":$file, "unsigned":$line, "unsigned":$scopeLine,
- "DISubprogramFlags":$subprogramFlags, "DISubroutineTypeAttr":$type
+ "DISubprogramFlags":$subprogramFlags,
+ "DIRecursiveTypeAttrOf<DISubroutineTypeAttr>":$type
), [{
MLIRContext *ctx = file.getContext();
return $_get(ctx, id, compileUnit, scope, StringAttr::get(ctx, name),
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
index c370bfa2b733d..c9c0ba635b854 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
@@ -52,9 +52,9 @@ class DILocalScopeAttr : public DIScopeAttr {
};
/// This class represents a LLVM attribute that describes a debug info type.
-class DITypeAttr : public DINodeAttr {
+class DITypeAttr : public DIScopeAttr {
public:
- using DINodeAttr::DINodeAttr;
+ using DIScopeAttr::DIScopeAttr;
/// Support LLVM type casting.
static bool classof(Attribute attr);
@@ -74,6 +74,10 @@ class TBAANodeAttr : public Attribute {
}
};
+// Forward declare.
+template <typename BaseType>
+class DIRecursiveTypeAttrOf;
+
// Inline the LLVM generated Linkage enum and utility.
// This is only necessary to isolate the "enum generated code" from the
// attribute definition itself.
@@ -87,4 +91,31 @@ using linkage::Linkage;
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.h.inc"
+namespace mlir {
+namespace LLVM {
+/// This class represents either a concrete attr, or a DIRecursiveTypeAttr
+/// containing such a concrete attr.
+///
+/// `BaseType` must derive from DITypeAttr; this is enforced by the
+/// static_assert in the class body.
+template <typename BaseType>
+class DIRecursiveTypeAttrOf : public DITypeAttr {
+public:
+  static_assert(std::is_base_of_v<DITypeAttr, BaseType>);
+  using DITypeAttr::DITypeAttr;
+
+  /// Support LLVM type casting. Accepts `attr` if it is a `BaseType`, or if it
+  /// is a DIRecursiveTypeAttr whose `baseType` is a `BaseType`.
+  /// NOTE(review): for a rec-self the `baseType` is null; confirm that
+  /// `isa<BaseType>` on a null attribute is acceptable for every `BaseType`
+  /// this template is instantiated with.
+  static bool classof(Attribute attr) {
+    if (auto rec = llvm::dyn_cast<DIRecursiveTypeAttr>(attr))
+      return llvm::isa<BaseType>(rec.getBaseType());
+    return llvm::isa<BaseType>(attr);
+  }
+
+  /// Implicitly wrap a concrete `BaseType` attribute.
+  DIRecursiveTypeAttrOf(BaseType baseType) : DITypeAttr(baseType) {}
+
+  /// Returns the concrete `BaseType` behind this attribute. If this is a
+  /// DIRecursiveTypeAttr, the result is its unfolded base type; otherwise the
+  /// attribute itself is returned as a `BaseType`.
+  BaseType getUnfoldedBaseType() {
+    // Attributes are value types: cast the attribute (`*this`), not the
+    // `this` pointer, so that the result is an attribute handle.
+    if (auto rec = llvm::dyn_cast<DIRecursiveTypeAttr>(*this))
+      return llvm::cast<BaseType>(rec.getUnfoldedBaseType());
+    return llvm::cast<BaseType>(*this);
+  }
+};
+} // namespace LLVM
+} // namespace mlir
+
#endif // MLIR_DIALECT_LLVMIR_LLVMATTRS_H_
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index 645a45dd96bef..ed4001170bfe5 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -68,8 +68,8 @@ bool DINodeAttr::classof(Attribute attr) {
//===----------------------------------------------------------------------===//
bool DIScopeAttr::classof(Attribute attr) {
- return llvm::isa<DICompileUnitAttr, DICompositeTypeAttr, DIFileAttr,
- DILocalScopeAttr, DIModuleAttr, DINamespaceAttr>(attr);
+ return llvm::isa<DICompileUnitAttr, DIFileAttr, DILocalScopeAttr,
+ DIModuleAttr, DINamespaceAttr, DITypeAttr>(attr);
}
//===----------------------------------------------------------------------===//
@@ -86,8 +86,9 @@ bool DILocalScopeAttr::classof(Attribute attr) {
//===----------------------------------------------------------------------===//
bool DITypeAttr::classof(Attribute attr) {
- return llvm::isa<DINullTypeAttr, DIBasicTypeAttr, DICompositeTypeAttr,
- DIDerivedTypeAttr, DISubroutineTypeAttr>(attr);
+ return llvm::isa<DIRecursiveTypeAttr, DINullTypeAttr, DIBasicTypeAttr,
+ DICompositeTypeAttr, DIDerivedTypeAttr,
+ DISubroutineTypeAttr>(attr);
}
//===----------------------------------------------------------------------===//
@@ -185,6 +186,19 @@ void printExpressionArg(AsmPrinter &printer, uint64_t opcode,
});
}
+//===----------------------------------------------------------------------===//
+// DIRecursiveTypeAttr
+//===----------------------------------------------------------------------===//
+/// Returns the `baseType` of this rec-decl in which every nested
+/// DIRecursiveTypeAttr carrying the same `id` has been replaced by this
+/// attribute. Must not be called on a rec-self.
+DITypeAttr DIRecursiveTypeAttr::getUnfoldedBaseType() {
+  assert(!isRecSelf() && "cannot get baseType from a rec-self type");
+  DITypeAttr folded = getBaseType();
+  auto unfolded = folded.replace(
+      [&](DIRecursiveTypeAttr candidate)
+          -> std::optional<DIRecursiveTypeAttr> {
+        // Recursive types that belong to a different id are left untouched.
+        if (candidate.getId() != getId())
+          return std::nullopt;
+        return *this;
+      });
+  return llvm::cast<DITypeAttr>(unfolded);
+}
+
//===----------------------------------------------------------------------===//
// TargetFeaturesAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
index 6521295230091..11f96af8126be 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
@@ -51,10 +51,9 @@ DICompileUnitAttr DebugImporter::translateImpl(llvm::DICompileUnit *node) {
std::optional<DIEmissionKind> emissionKind =
symbolizeDIEmissionKind(node->getEmissionKind());
return DICompileUnitAttr::get(
- context, DistinctAttr::create(UnitAttr::get(context)),
- node->getSourceLanguage(), translate(node->getFile()),
- getStringAttrOrNull(node->getRawProducer()), node->isOptimized(),
- emissionKind.value());
+ context, getOrCreateDistinctID(node), node->getSourceLanguage(),
+ translate(node->getFile()), getStringAttrOrNull(node->getRawProducer()),
+ node->isOptimized(), emissionKind.value());
}
DICompositeTypeAttr DebugImporter::translateImpl(llvm::DICompositeType *node) {
@@ -64,11 +63,7 @@ DICompositeTypeAttr DebugImporter::translateImpl(llvm::DICompositeType *node) {
assert(element && "expected a non-null element type");
elements.push_back(translate(element));
}
- // Drop the elements parameter if a cyclic dependency is detected. We
- // currently cannot model these cycles and thus drop the parameter if
- // required. A cyclic dependency is detected if one of the element nodes
- // translates to a nullptr since the node is already on the translation stack.
- // TODO: Support debug metadata with cyclic dependencies.
+ // Drop the elements parameter if any of the elements are invalid.
if (llvm::is_contained(elements, nullptr))
elements.clear();
DITypeAttr baseType = translate(node->getBaseType());
@@ -84,7 +79,7 @@ DICompositeTypeAttr DebugImporter::translateImpl(llvm::DICompositeType *node) {
}
DIDerivedTypeAttr DebugImporter::translateImpl(llvm::DIDerivedType *node) {
- // Return nullptr if the base type is a cyclic dependency.
+ // Return nullptr if the base type is invalid.
DITypeAttr baseType = translate(node->getBaseType());
if (node->getBaseType() && !baseType)
return nullptr;
@@ -166,14 +161,14 @@ DISubprogramAttr DebugImporter::translateImpl(llvm::DISubprogram *node) {
// Only definitions require a distinct identifier.
mlir::DistinctAttr id;
if (node->isDistinct())
- id = DistinctAttr::create(UnitAttr::get(context));
+ id = getOrCreateDistinctID(node);
std::optional<DISubprogramFlags> subprogramFlags =
symbolizeDISubprogramFlags(node->getSubprogram()->getSPFlags());
- // Return nullptr if the scope or type is a cyclic dependency.
- DIScopeAttr scope = translate(node->getScope());
+ // Return nullptr if the scope or type is invalid.
+ DIScopeAttr scope = cast<DIScopeAttr>(translate(node->getScope()));
if (node->getScope() && !scope)
return nullptr;
- DISubroutineTypeAttr type = translate(node->getType());
+ DIRecursiveTypeAttrOf<DISubroutineTypeAttr> type = translate(node->getType());
if (node->getType() && !type)
return nullptr;
return DISubprogramAttr::get(context, id, translate(node->getUnit()), scope,
@@ -216,7 +211,7 @@ DebugImporter::translateImpl(llvm::DISubroutineType *node) {
}
types.push_back(translate(type));
}
- // Return nullptr if any of the types is a cyclic dependency.
+ // Return nullptr if any of the types is invalid.
if (llvm::is_contained(types, nullptr))
return nullptr;
return DISubroutineTypeAttr::get(context, node->getCC(), types);
@@ -234,12 +229,47 @@ DINodeAttr DebugImporter::translate(llvm::DINode *node) {
if (DINodeAttr attr = nodeToAttr.lookup(node))
return attr;
- // Return nullptr if a cyclic dependency is detected since the same node is
- // being traversed twice. This check avoids infinite recursion if the debug
- // metadata contains cycles.
- if (!translationStack.insert(node))
- return nullptr;
- auto guard = llvm::make_scope_exit([&]() { translationStack.pop_back(); });
+ // If a cyclic dependency is detected since the same node is being traversed
+ // twice, emit a recursive self type, and mark the duplicate node on the
+ // translationStack so it can emit a recursive decl type.
+ auto *typeNode = dyn_cast<llvm::DIType>(node);
+ if (typeNode) {
+ auto [iter, inserted] = typeTranslationStack.try_emplace(typeNode, nullptr);
+ if (!inserted) {
+ // The original node may have already been assigned a recursive ID from
+ // a different self-reference. Use that if possible.
+ DistinctAttr recId = iter->second;
+ if (!recId) {
+ recId = DistinctAttr::create(UnitAttr::get(context));
+ iter->second = recId;
+ }
+ unboundRecursiveSelfRefs.back().insert(recId);
+ return DIRecursiveTypeAttr::get(recId);
+ }
+ } else {
+ bool inserted =
+ nonTypeTranslationStack.insert({node, typeTranslationStack.size()});
+ assert(inserted && "recursion is only supported via DITypes");
+ }
+
+ unboundRecursiveSelfRefs.emplace_back();
+
+ auto guard = llvm::make_scope_exit([&]() {
+ if (typeNode)
+ typeTranslationStack.pop_back();
+ else
+ nonTypeTranslationStack.pop_back();
+
+ // Copy unboundRecursiveSelfRefs down to the previous level.
+ if (unboundRecursiveSelfRefs.size() == 1)
+ assert(unboundRecursiveSelfRefs.back().empty() &&
+ "internal error: unbound recursive self reference at top level.");
+ else
+ unboundRecursiveSelfRefs[unboundRecursiveSelfRefs.size() - 2].insert(
+ unboundRecursiveSelfRefs.back().begin(),
+ unboundRecursiveSelfRefs.back().end());
+ unboundRecursiveSelfRefs.pop_back();
+ });
// Convert the debug metadata if possible.
auto translateNode = [this](llvm::DINode *node) -> DINodeAttr {
@@ -276,7 +306,21 @@ DINodeAttr DebugImporter::translate(llvm::DINode *node) {
return nullptr;
};
if (DINodeAttr attr = translateNode(node)) {
- nodeToAttr.insert({node, attr});
+ // If this node was marked as recursive, wrap with a recursive type.
+ if (typeNode) {
+ if (DistinctAttr id = typeTranslationStack.lookup(typeNode)) {
+ DITypeAttr typeAttr = cast<DITypeAttr>(attr);
+ attr = DIRecursiveTypeAttr::get(context, id, typeAttr);
+
+ // Remove the unbound recursive attr.
+ AttrTypeReplacer replacer;
+ unboundRecursiveSelfRefs.back().erase(id);
+ }
+ }
+
+ // Only cache fully self-contained nodes.
+ if (unboundRecursiveSelfRefs.back().empty())
+ nodeToAttr.try_emplace(node, attr);
return attr;
}
return nullptr;
@@ -333,3 +377,10 @@ StringAttr DebugImporter::getStringAttrOrNull(llvm::MDString *stringNode) {
return StringAttr();
return StringAttr::get(context, stringNode->getString());
}
+
+/// Returns the DistinctAttr recorded for `node` in `nodeToDistinctAttr`,
+/// creating and recording a fresh one if none has been stored yet.
+DistinctAttr DebugImporter::getOrCreateDistinctID(llvm::DINode *node) {
+  // Inserts a null entry when `node` has not been seen before.
+  auto slot = nodeToDistinctAttr.try_emplace(node).first;
+  if (slot->second)
+    return slot->second;
+  slot->second = DistinctAttr::create(UnitAttr::get(context));
+  return slot->second;
+}
diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.h b/mlir/lib/Target/LLVMIR/DebugImporter.h
index 7d4a371284b68..7e9e28b274b98 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.h
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.h
@@ -52,8 +52,15 @@ class DebugImporter {
/// Infers the metadata type and translates it to MLIR.
template <typename DINodeT>
auto translate(DINodeT *node) {
- // Infer the MLIR type from the LLVM metadata type.
- using MLIRTypeT = decltype(translateImpl(node));
+ // Infer the result MLIR type from the LLVM metadata type.
+ // If the result is a DIType, it can also be wrapped in a recursive type,
+ // so the result is wrapped into a DIRecursiveTypeAttrOf.
+ // Otherwise, the exact result type is used.
+ constexpr bool isDIType = std::is_base_of_v<llvm::DIType, DINodeT>;
+ using RawMLIRTypeT = decltype(translateImpl(node));
+ using MLIRTypeT =
+ std::conditional_t<isDIType, DIRecursiveTypeAttrOf<RawMLIRTypeT>,
+ RawMLIRTypeT>;
return cast_or_null<MLIRTypeT>(
translate(static_cast<llvm::DINode *>(node)));
}
@@ -82,12 +89,27 @@ class DebugImporter {
/// null attribute otherwise.
StringAttr getStringAttrOrNull(llvm::MDString *stringNode);
+ DistinctAttr getOrCreateDistinctID(llvm::DINode *node);
+
/// A mapping between LLVM debug metadata and the corresponding attribute.
DenseMap<llvm::DINode *, DINodeAttr> nodeToAttr;
-
- /// A stack that stores the metadata nodes that are being traversed. The stack
- /// is used to detect cyclic dependencies during the metadata translation.
- SetVector<llvm::DINode *> translationStack;
+ /// A mapping between LLVM debug metadata and the distinct ID attr for DI
+ /// nodes that require distinction.
+ DenseMap<llvm::DINode *, DistinctAttr> nodeToDistinctAttr;
+
+ /// A stack that stores the metadata type nodes that are being traversed. The
+ /// stack is used to detect cyclic dependencies during the metadata
+ /// translation. Each node is pushed with a null value. If a node is ever
+ /// seen twice, it is given a DistinctAttr, indicating that it is a recursive
+ /// node and should take on that DistinctAttr as ID.
+ llvm::MapVector<llvm::DIType *, DistinctAttr> typeTranslationStack;
+ /// All the unbound recursive self references in the translation stack.
+ SmallVector<DenseSet<DistinctAttr>> unboundRecursiveSelfRefs;
+ /// A stack that stores the non-type metadata nodes that are being traversed.
+ /// Each node is associated with the size of the `typeTranslationStack` at the
+ /// time of push. This is used to identify a recursion purely in the non-type
+ /// metadata nodes, which is not supported yet.
+ SetVector<std::pair<llvm::DINode *, unsigned>> nonTypeTranslationStack;
MLIRContext *context;
ModuleOp mlirModule;
diff --git a/mlir/test/Target/LLVMIR/Import/debug-info.ll b/mlir/test/Target/LLVMIR/Import/debug-info.ll
index 9ef6580bcf240..032acc8f8c811 100644
--- a/mlir/test/Target/LLVMIR/Import/debug-info.ll
+++ b/mlir/test/Target/LLVMIR/Import/debug-info.ll
@@ -296,12 +296,16 @@ define void @class_method() {
ret void, !dbg !9
}
-; Verify the elements parameter is dropped due to the cyclic dependencies.
-; CHECK: #[[COMP:.+]] = #llvm.di_composite_type<tag = DW_TAG_class_type, name = "class_name", file = #{{.*}}, line = 42, flags = "TypePassByReference|NonTrivial">
-; CHECK: #[[COMP_PTR:.+]] = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #[[COMP]], sizeInBits = 64>
+; Verify the cyclic composite type is identified, even though conversion begins from the subprogram type.
+; CHECK: #[[COMP_PTR:.+]] = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #llvm.di_recursive_type<id = [[COMP_ID:.+]]>, sizeInBits = 64>
; CHECK: #[[SP_TYPE:.+]] = #llvm.di_subroutine_type<types = #{{.*}}, #[[COMP_PTR]]>
-; CHECK: #[[SP:.+]] = #llvm.di_subprogram<id = distinct[{{.*}}]<>, compileUnit = #{{.*}}, scope = #[[COMP]], name = "class_method", file = #{{.*}}, subprogramFlags = Definition, type = #[[SP_TYPE]]>
-; CHECK: #[[LOC]] = loc(fused<#[[SP]]>
+; CHECK: #[[SP_INNER:.+]] = #llvm.di_subprogram<id = [[SP_ID:.+]], compileUnit = #{{.*}}, scope = #llvm.di_recursive_type<id = [[COMP_ID]]>, name = "class_method", file = #{{.*}}, subprogramFlags = Definition, type = #[[SP_TYPE]]>
+; CHECK: #[[COMP:.+]] = #llvm.di_composite_type<tag = DW_TAG_class_type, name = "class_name", file = #{{.*}}, line = 42, flags = "TypePassByReference|NonTrivial", elements = #[[SP_INNER]]>
+
+; CHECK: #[[COMP_PTR_OUTER:.+]] = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #llvm.di_recursive_type<id = [[COMP_ID]], baseType = #[[COMP]]>, sizeInBits = 64>
+; CHECK: #[[SP_TYPE_OUTER:.+]] = #llvm.di_subroutine_type<types = #{{.*}}, #[[COMP_PTR_OUTER]]>
+; CHECK: #[[SP_OUTER:.+]] = #llvm.di_subprogram<id = [[SP_ID]], compileUnit = #{{.*}}, scope = #llvm.di_recursive_type<id = [[COMP_ID]], baseType = #[[COMP]]>, name = "class_method", file = #{{.*}}, subprogramFlags = Definition, type = #[[SP_TYPE_OUTER]]>
+; CHECK: #[[LOC]] = loc(fused<#[[SP_OUTER]]>
!llvm.dbg.cu = !{!1}
!llvm.module.flags = !{!0}
@@ -318,15 +322,16 @@ define void @class_method() {
; // -----
-; Verify the elements parameter is dropped due to the cyclic dependencies.
-; CHECK: #[[$COMP:.+]] = #llvm.di_composite_type<tag = DW_TAG_class_type, name = "class_field", file = #{{.*}}, line = 42, flags = "TypePassByReference|NonTrivial">
-; CHECK: #[[$COMP_PTR:.+]] = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #[[$COMP]]>
-; CHECK: #[[$VAR0:.+]] = #llvm.di_local_variable<scope = #{{.*}}, name = "class_field", file = #{{.*}}, type = #[[$COMP_PTR]]>
+; Verify the cyclic composite type is handled correctly.
+; CHECK: #[[FIELD:.+]] = #llvm.di_derived_type<tag = DW_TAG_member, name = "call_field", baseType = #llvm.di_recursive_type<id = [[COMP_ID:.+]]>>
+; CHECK: #[[COMP:.+]] = #llvm.di_composite_type<tag = DW_TAG_class_type, name = "class_field", file = #{{.*}}, line = 42, flags = "TypePassByReference|NonTrivial", elements = #[[FIELD]]>
+; CHECK: #[[COMP_PTR:.+]] = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #[[COMP]]>
+; CHECK: #[[VAR0:.+]] = #llvm.di_local_variable<scope = #{{.*}}, name = "class_field", file = #{{.*}}, type = #llvm.di_recursive_type<id = [[COMP_ID]], baseType = #[[COMP_PTR]]>>
-; CHECK-LABEL: @class_field
+; CHECK: @class_field
; CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
define void @class_field(ptr %arg1) {
- ; CHECK: llvm.intr.dbg.value #[[$VAR0]] = %[[ARG0]] : !llvm.ptr
+ ; CHECK: llvm.intr.dbg.value #[[VAR0]] = %[[ARG0]] : !llvm.ptr
call void @llvm.dbg.value(metadata ptr %arg1, metadata !7, metadata !DIExpression()), !dbg !9
ret void
}
@@ -563,35 +568,6 @@ define void @func_in_module(ptr %arg) !dbg !8 {
; // -----
-; Verifies that array types that have an unimportable base type are removed to
-; avoid producing invalid IR.
-; CHECK: #[[DI_LOCAL_VAR:.+]] = #llvm.di_local_variable<
-; CHECK-NOT: type =
-
-; CHECK-LABEL: @array_with_cyclic_base_type
-define i32 @array_with_cyclic_base_type(ptr %0) !dbg !3 {
- call void @llvm.dbg.value(metadata ptr %0, metadata !4, metadata !DIExpression()), !dbg !7
- ret i32 0
-}
-
-; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
-declare void @llvm.dbg.value(metadata, metadata, metadata)
-
-
-!llvm.module.flags = !{!0}
-!llvm.dbg.cu = !{!1}
-
-!0 = !{i32 2, !"Debug Info Version", i32 3}
-!1 = distinct !DICompileUnit(language: DW_LANG_C, file: !2)
-!2 = !DIFile(filename: "debug-info.ll", directory: "/")
-!3 = distinct !DISubprogram(name: "func", scope: !2, file: !2, line: 46, scopeLine: 48, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !1)
-!4 = !DILocalVariable(name: "op", arg: 5, scope: !3, file: !2, line: 47, type: !5)
-!5 = !DICompositeType(tag: DW_TAG_array_type, size: 42, baseType: !6)
-!6 = !DIDerivedType(tag: DW_TAG_pointer_type, baseType: !5)
-!7 = !DILocation(line: 0, scope: !3)
-
-; // -----
-
; Verifies that import compile units respect the distinctness of the input.
; CHECK-LABEL: @distinct_cu_func0
define void @distinct_cu_func0() !dbg !4 {
>From 761a0dd508c236fd36917ebdaf4f0470e8a31d43 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Mon, 5 Feb 2024 12:22:20 -0800
Subject: [PATCH 2/2] add exporter support for composite type
---
mlir/lib/Target/LLVMIR/DebugTranslation.cpp | 96 ++++++++++++++++-----
mlir/lib/Target/LLVMIR/DebugTranslation.h | 23 ++++-
mlir/test/Target/LLVMIR/llvmir-debug.mlir | 60 +++++++++++++
3 files changed, 156 insertions(+), 23 deletions(-)
diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
index 16918aab54978..b7780958ee4af 100644
--- a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
@@ -117,11 +117,8 @@ static DINodeT *getDistinctOrUnique(bool isDistinct, Ts &&...args) {
}
llvm::DICompositeType *
-DebugTranslation::translateImpl(DICompositeTypeAttr attr) {
- SmallVector<llvm::Metadata *> elements;
- for (auto member : attr.getElements())
- elements.push_back(translate(member));
-
+DebugTranslation::translateImpl(DICompositeTypeAttr attr,
+ SetRecursivePlaceholderFn setRec) {
// TODO: Use distinct attributes to model this, once they have landed.
// Depending on the tag, composite types must be distinct.
bool isDistinct = false;
@@ -133,15 +130,26 @@ DebugTranslation::translateImpl(DICompositeTypeAttr attr) {
isDistinct = true;
}
- return getDistinctOrUnique<llvm::DICompositeType>(
- isDistinct, llvmCtx, attr.getTag(), getMDStringOrNull(attr.getName()),
- translate(attr.getFile()), attr.getLine(), translate(attr.getScope()),
- translate(attr.getBaseType()), attr.getSizeInBits(),
- attr.getAlignInBits(),
- /*OffsetInBits=*/0,
- /*Flags=*/static_cast<llvm::DINode::DIFlags>(attr.getFlags()),
- llvm::MDNode::get(llvmCtx, elements),
- /*RuntimeLang=*/0, /*VTableHolder=*/nullptr);
+ llvm::DICompositeType *placeholder =
+ getDistinctOrUnique<llvm::DICompositeType>(
+ isDistinct, llvmCtx, attr.getTag(), getMDStringOrNull(attr.getName()),
+ translate(attr.getFile()), attr.getLine(), translate(attr.getScope()),
+ translate(attr.getBaseType()), attr.getSizeInBits(),
+ attr.getAlignInBits(),
+ /*OffsetInBits=*/0,
+ /*Flags=*/static_cast<llvm::DINode::DIFlags>(attr.getFlags()),
+ /*Elements=*/nullptr, /*RuntimeLang=*/0, /*VTableHolder=*/nullptr);
+
+ if (setRec)
+ setRec(placeholder);
+
+ SmallVector<llvm::Metadata *> elements;
+ for (auto member : attr.getElements())
+ elements.push_back(translate(member));
+
+ placeholder->replaceElements(llvm::MDNode::get(llvmCtx, elements));
+
+ return placeholder;
}
llvm::DIDerivedType *DebugTranslation::translateImpl(DIDerivedTypeAttr attr) {
@@ -200,22 +208,66 @@ DebugTranslation::translateImpl(DIGlobalVariableAttr attr) {
attr.getIsDefined(), nullptr, nullptr, attr.getAlignInBits(), nullptr);
}
+/// Translates a DIRecursiveTypeAttr to LLVM debug metadata.
+///
+/// A rec-self resolves to the node currently registered in `recursiveTypeMap`
+/// under its id. A rec-decl translates its base type and, through the
+/// placeholder callback, registers the in-progress node under its id so that
+/// lookups made while the base type is being translated can find it. The
+/// entry is removed again before returning.
+llvm::DIType *DebugTranslation::translateImpl(DIRecursiveTypeAttr attr) {
+  if (attr.isRecSelf()) {
+    auto iter = recursiveTypeMap.find(attr.getId());
+    assert(iter != recursiveTypeMap.end() && "unbound DI recursive self type");
+    return iter->second;
+  }
+
+  size_t recursiveStackSize = recursiveTypeMap.size();
+  // Only read by the consistency assert below; avoid an unused-variable
+  // warning when asserts are compiled out.
+  (void)recursiveStackSize;
+  auto setRecursivePlaceholderFn = [&](llvm::DIType *node) {
+    bool inserted = recursiveTypeMap.try_emplace(attr.getId(), node).second;
+    (void)inserted;
+    assert(inserted && "illegal reuse of recursive id");
+  };
+
+  // Only composite base types are handled here; no other case is provided to
+  // the type switch.
+  llvm::DIType *node =
+      TypeSwitch<DITypeAttr, llvm::DIType *>(attr.getBaseType())
+          .Case<DICompositeTypeAttr>([&](auto compositeAttr) {
+            return translateImpl(compositeAttr, setRecursivePlaceholderFn);
+          });
+
+  // The placeholder callback must have pushed exactly one entry, which is
+  // popped again now that the recursive declaration is complete.
+  assert((recursiveStackSize + 1 == recursiveTypeMap.size()) &&
+         "internal inconsistency: unexpected recursive translation stack");
+  recursiveTypeMap.pop_back();
+
+  return node;
+}
+
llvm::DIScope *DebugTranslation::translateImpl(DIScopeAttr attr) {
return cast<llvm::DIScope>(translate(DINodeAttr(attr)));
}
llvm::DISubprogram *DebugTranslation::translateImpl(DISubprogramAttr attr) {
+ if (auto iter = distinctAttrToNode.find(attr.getId());
+ iter != distinctAttrToNode.end())
+ return cast<llvm::DISubprogram>(iter->second);
+
+ llvm::DIScope *scope = translate(attr.getScope());
+ llvm::DIFile *file = translate(attr.getFile());
+ llvm::DIType *type = translate(attr.getType());
+ llvm::DICompileUnit *compileUnit = translate(attr.getCompileUnit());
+
+ // Check again after recursive calls in case this distinct node recurses back
+ // to itself.
+ if (auto iter = distinctAttrToNode.find(attr.getId());
+ iter != distinctAttrToNode.end())
+ return cast<llvm::DISubprogram>(iter->second);
+
bool isDefinition = static_cast<bool>(attr.getSubprogramFlags() &
LLVM::DISubprogramFlags::Definition);
- return getDistinctOrUnique<llvm::DISubprogram>(
- isDefinition, llvmCtx, translate(attr.getScope()),
- getMDStringOrNull(attr.getName()),
- getMDStringOrNull(attr.getLinkageName()), translate(attr.getFile()),
- attr.getLine(), translate(attr.getType()), attr.getScopeLine(),
+ llvm::DISubprogram *node = getDistinctOrUnique<llvm::DISubprogram>(
+ isDefinition, llvmCtx, scope, getMDStringOrNull(attr.getName()),
+ getMDStringOrNull(attr.getLinkageName()), file, attr.getLine(), type,
+ attr.getScopeLine(),
/*ContainingType=*/nullptr, /*VirtualIndex=*/0,
/*ThisAdjustment=*/0, llvm::DINode::FlagZero,
static_cast<llvm::DISubprogram::DISPFlags>(attr.getSubprogramFlags()),
- translate(attr.getCompileUnit()));
+ compileUnit);
+
+ if (attr.getId())
+ distinctAttrToNode.try_emplace(attr.getId(), node);
+ return node;
}
llvm::DIModule *DebugTranslation::translateImpl(DIModuleAttr attr) {
@@ -274,8 +326,8 @@ llvm::DINode *DebugTranslation::translate(DINodeAttr attr) {
DIDerivedTypeAttr, DIFileAttr, DIGlobalVariableAttr,
DILabelAttr, DILexicalBlockAttr, DILexicalBlockFileAttr,
DILocalVariableAttr, DIModuleAttr, DINamespaceAttr,
- DINullTypeAttr, DISubprogramAttr, DISubrangeAttr,
- DISubroutineTypeAttr>(
+ DINullTypeAttr, DIRecursiveTypeAttr, DISubprogramAttr,
+ DISubrangeAttr, DISubroutineTypeAttr>(
[&](auto attr) { return translateImpl(attr); });
attrToNode.insert({attr, node});
return node;
diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.h b/mlir/lib/Target/LLVMIR/DebugTranslation.h
index 627c684684498..f4ba687ad564e 100644
--- a/mlir/lib/Target/LLVMIR/DebugTranslation.h
+++ b/mlir/lib/Target/LLVMIR/DebugTranslation.h
@@ -68,10 +68,22 @@ class DebugTranslation {
llvm::DIFile *translateFile(StringRef fileName);
/// Translate the given attribute to the corresponding llvm debug metadata.
+ ///
+ /// For attributes corresponding to DITypes that can be recursive (i.e.
+ /// supports replacing subelements), an additional optional argument with type
+ /// `SetRecursivePlaceholderFn` should be supported.
+ /// The translation impl for recursive support must follow these three steps:
+ /// 1. Produce a placeholder version of the translated node without calling
+ /// `translate` on any subelements of the MLIR attr.
+ /// 2. Call the SetRecursivePlaceholderFn with the placeholder node.
+ /// 3. Translate subelements recursively using `translate` and fill the
+ /// original placeholder.
+ using SetRecursivePlaceholderFn = llvm::function_ref<void(llvm::DIType *)>;
llvm::DIType *translateImpl(DINullTypeAttr attr);
llvm::DIBasicType *translateImpl(DIBasicTypeAttr attr);
llvm::DICompileUnit *translateImpl(DICompileUnitAttr attr);
- llvm::DICompositeType *translateImpl(DICompositeTypeAttr attr);
+ llvm::DICompositeType *translateImpl(DICompositeTypeAttr attr,
+ SetRecursivePlaceholderFn setRec = {});
llvm::DIDerivedType *translateImpl(DIDerivedTypeAttr attr);
llvm::DIFile *translateImpl(DIFileAttr attr);
llvm::DILabel *translateImpl(DILabelAttr attr);
@@ -82,6 +94,7 @@ class DebugTranslation {
llvm::DIGlobalVariable *translateImpl(DIGlobalVariableAttr attr);
llvm::DIModule *translateImpl(DIModuleAttr attr);
llvm::DINamespace *translateImpl(DINamespaceAttr attr);
+ llvm::DIType *translateImpl(DIRecursiveTypeAttr attr);
llvm::DIScope *translateImpl(DIScopeAttr attr);
llvm::DISubprogram *translateImpl(DISubprogramAttr attr);
llvm::DISubrange *translateImpl(DISubrangeAttr attr);
@@ -102,6 +115,14 @@ class DebugTranslation {
/// metadata.
DenseMap<Attribute, llvm::DINode *> attrToNode;
+ /// A mapping from DIRecursiveTypeAttr id to the translated DIType.
+ llvm::MapVector<DistinctAttr, llvm::DIType *> recursiveTypeMap;
+
+ /// A mapping between distinct ID attr for DI nodes that require distinction
+ /// and the translated LLVM metadata node. This helps identify attrs that
+ /// should translate into the same LLVM debug node.
+ DenseMap<DistinctAttr, llvm::DINode *> distinctAttrToNode;
+
/// A mapping between filename and llvm debug file.
/// TODO: Change this to DenseMap<Identifier, ...> when we can
/// access the Identifier filename in FileLineColLoc.
diff --git a/mlir/test/Target/LLVMIR/llvmir-debug.mlir b/mlir/test/Target/LLVMIR/llvmir-debug.mlir
index cfd5239515c9c..446ed0afc39e1 100644
--- a/mlir/test/Target/LLVMIR/llvmir-debug.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-debug.mlir
@@ -342,3 +342,63 @@ llvm.func @func_line_tables() {
llvm.func @func_debug_directives() {
llvm.return
} loc(fused<#di_subprogram_2>["foo2.mlir":0:0])
+
+// -----
+
+// Ensure recursive types with multiple external references work.
+
+// Common base nodes.
+#di_file = #llvm.di_file<"test.mlir" in "/">
+#di_null_type = #llvm.di_null_type
+#di_compile_unit = #llvm.di_compile_unit<id = distinct[1]<>, sourceLanguage = DW_LANG_C, file = #di_file, isOptimized = false, emissionKind = None>
+
+// Recursive type itself.
+#di_rec_self = #llvm.di_recursive_type<id = distinct[0]<>>
+#di_ptr_inner = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #di_rec_self, sizeInBits = 64>
+#di_subroutine_inner = #llvm.di_subroutine_type<types = #di_null_type, #di_ptr_inner>
+#di_subprogram_inner = #llvm.di_subprogram<
+ id = distinct[2]<>,
+ compileUnit = #di_compile_unit,
+ scope = #di_rec_self,
+ name = "class_method",
+ file = #di_file,
+ subprogramFlags = Definition,
+ type = #di_subroutine_inner>
+#di_struct = #llvm.di_composite_type<
+ tag = DW_TAG_class_type,
+ name = "class_name",
+ file = #di_file,
+ line = 42,
+ flags = "TypePassByReference|NonTrivial",
+ elements = #di_subprogram_inner>
+#di_rec_struct = #llvm.di_recursive_type<id = distinct[0]<>, baseType = #di_struct>
+
+// Outer types referencing the entire recursive type.
+#di_ptr_outer = #llvm.di_derived_type<tag = DW_TAG_pointer_type, baseType = #di_rec_struct, sizeInBits = 64>
+#di_subroutine_outer = #llvm.di_subroutine_type<types = #di_null_type, #di_ptr_outer>
+#di_subprogram_outer = #llvm.di_subprogram<
+ id = distinct[2]<>,
+ compileUnit = #di_compile_unit,
+ scope = #di_rec_struct,
+ name = "class_method",
+ file = #di_file,
+ subprogramFlags = Definition,
+ type = #di_subroutine_outer>
+
+#loc3 = loc(fused<#di_subprogram_outer>["test.mlir":1:1])
+
+// CHECK: @class_method
+// CHECK: ret void, !dbg ![[LOC:.*]]
+
+// CHECK: ![[CU:.*]] = distinct !DICompileUnit(
+// CHECK: ![[SP:.*]] = distinct !DISubprogram(name: "class_method", scope: ![[STRUCT:.*]], file: !{{.*}}, type: ![[SUBROUTINE:.*]], spFlags: DISPFlagDefinition, unit: ![[CU]])
+// CHECK: ![[STRUCT]] = distinct !DICompositeType(tag: DW_TAG_class_type, name: "class_name", {{.*}}, elements: ![[ELEMS:.*]])
+// CHECK: ![[ELEMS]] = !{![[SP]]}
+// CHECK: ![[SUBROUTINE]] = !DISubroutineType(types: ![[SUBROUTINE_ELEMS:.*]])
+// CHECK: ![[SUBROUTINE_ELEMS]] = !{null, ![[PTR:.*]]}
+// CHECK: ![[PTR]] = !DIDerivedType(tag: DW_TAG_pointer_type, baseType: ![[STRUCT]], size: 64)
+// CHECK: ![[LOC]] = !DILocation(line: 1, column: 1, scope: ![[SP]])
+
+llvm.func @class_method() {
+ llvm.return loc(#loc3)
+} loc(#loc3)
More information about the Mlir-commits
mailing list