[Mlir-commits] [mlir] [mlir] Add support for recursive elements in DICompositeTypeAttr. (PR #74948)

Justin Wilson llvmlistbot at llvm.org
Sat Dec 9 23:12:07 PST 2023


https://github.com/waj334 updated https://github.com/llvm/llvm-project/pull/74948

>From e880caf3ae6b01768bddaed41362f04b7b660393 Mon Sep 17 00:00:00 2001
From: "Justin A. Wilson" <justin.wilson at omibyte.io>
Date: Sat, 9 Dec 2023 14:14:51 -0600
Subject: [PATCH] [mlir] Add support for recursive elements in DICompositeTypeAttr.

Implements mutable storage for DICompositeTypeAttr in order to allow for self-references in its elements array. When the "identifier" parameter is set (non-null), only the identifier participates in the hash key and in equality comparisons, and the storage is implemented such that only the "elements" parameter is mutable. The module translator now creates the respective instance of llvm::DICompositeType without elements and then calls "llvm::DICompositeType::replaceElements" to set the elements once all of them have been translated. The only required IR change is that elements are now written as a parenthesized list after the other parameters (instead of `elements = ...`), for the sake of parsing.
---
 .../mlir/Dialect/LLVMIR/LLVMAttrDefs.td       |  76 ++--
 mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h  |   4 +
 mlir/lib/AsmParser/AttributeParser.cpp        |   1 +
 mlir/lib/Dialect/LLVMIR/IR/AttrDetail.h       | 136 +++++++
 mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp      | 342 +++++++++++++++++-
 mlir/lib/Target/LLVMIR/DebugTranslation.cpp   |  24 +-
 mlir/test/Dialect/LLVMIR/debuginfo.mlir       |  11 +-
 7 files changed, 563 insertions(+), 31 deletions(-)
 create mode 100644 mlir/lib/Dialect/LLVMIR/IR/AttrDetail.h

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 6975b18ab7f81..69df4c710a910 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -351,27 +351,6 @@ def LLVM_DICompileUnitAttr : LLVM_Attr<"DICompileUnit", "di_compile_unit",
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
-//===----------------------------------------------------------------------===//
-// DICompositeTypeAttr
-//===----------------------------------------------------------------------===//
-
-def LLVM_DICompositeTypeAttr : LLVM_Attr<"DICompositeType", "di_composite_type",
-                                         /*traits=*/[], "DITypeAttr"> {
-  let parameters = (ins
-    LLVM_DITagParameter:$tag,
-    OptionalParameter<"StringAttr">:$name,
-    OptionalParameter<"DIFileAttr">:$file,
-    OptionalParameter<"uint32_t">:$line,
-    OptionalParameter<"DIScopeAttr">:$scope,
-    OptionalParameter<"DITypeAttr">:$baseType,
-    OptionalParameter<"DIFlags", "DIFlags::Zero">:$flags,
-    OptionalParameter<"uint64_t">:$sizeInBits,
-    OptionalParameter<"uint64_t">:$alignInBits,
-    OptionalArrayRefParameter<"DINodeAttr">:$elements
-  );
-  let assemblyFormat = "`<` struct(params) `>`";
-}
-
 //===----------------------------------------------------------------------===//
 // DIDerivedTypeAttr
 //===----------------------------------------------------------------------===//
@@ -684,6 +663,61 @@ def LLVM_AliasScopeDomainAttr : LLVM_Attr<"AliasScopeDomain",
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
+//===----------------------------------------------------------------------===//
+// DICompositeTypeAttr
+//===----------------------------------------------------------------------===//
+
+// Debug info for a composite type (e.g. a struct or array type).
+//
+// The attribute is mutable so that an *identified* instance, i.e. one whose
+// `identifier` (a DistinctAttr) is set, can be created first and have its
+// `elements` filled in afterwards. This lets elements refer back to the
+// attribute itself. Identified instances are uniqued on `identifier` alone;
+// all other instances are uniqued on every parameter. Only `elements` is
+// ever modified after construction (see `replaceElements`).
+def LLVM_DICompositeTypeAttr : LLVM_Attr<"DICompositeType", "di_composite_type",
+                                         /*traits=*/[NativeTypeTrait<"IsMutable">], "DITypeAttr"> {
+  let parameters = (ins
+    // DWARF tag value; the custom parser/printer spell it by name (DW_TAG_*).
+    OptionalParameter<"unsigned">:$tag,
+    OptionalParameter<"StringAttr">:$name,
+    OptionalParameter<"DIFileAttr">:$file,
+    OptionalParameter<"uint32_t">:$line,
+    OptionalParameter<"DIScopeAttr">:$scope,
+    OptionalParameter<"DITypeAttr">:$baseType,
+    OptionalParameter<"DIFlags", "DIFlags::Zero">:$flags,
+    OptionalParameter<"uint64_t">:$sizeInBits,
+    OptionalParameter<"uint64_t">:$alignInBits,
+    // The only mutable parameter (replaced through `replaceElements`).
+    OptionalArrayRefParameter<"DINodeAttr">:$elements,
+    // When set, this alone identifies (and uniques) the attribute.
+    OptionalParameter<"DistinctAttr">:$identifier
+  );
+  // Storage, parser and printer are hand-written in C++ because the storage
+  // needs identifier-only uniquing and in-place mutation of `elements`.
+  let hasCustomAssemblyFormat = 1;
+  let genStorageClass = 0;
+  let storageClass = "DICompositeTypeAttrStorage";
+  let builders = [
+    // Non-identified attribute, uniqued on all of its parameters.
+    AttrBuilder<(ins
+        "unsigned":$tag,
+        "StringAttr":$name,
+        "DIFileAttr":$file,
+        "uint32_t":$line,
+        "DIScopeAttr":$scope,
+        "DITypeAttr":$baseType,
+        "DIFlags":$flags,
+        "uint64_t":$sizeInBits,
+        "uint64_t":$alignInBits,
+        "::llvm::ArrayRef<DINodeAttr>":$elements
+    )>,
+    // Identified attribute, uniqued on `identifier` only; the elements may be
+    // supplied later through `replaceElements`.
+    AttrBuilder<(ins
+        "DistinctAttr":$identifier,
+        "unsigned":$tag,
+        "StringAttr":$name,
+        "DIFileAttr":$file,
+        "uint32_t":$line,
+        "DIScopeAttr":$scope,
+        "DITypeAttr":$baseType,
+        "DIFlags":$flags,
+        "uint64_t":$sizeInBits,
+        "uint64_t":$alignInBits,
+        CArg<"::llvm::ArrayRef<DINodeAttr>", "{}">:$elements
+    )>
+  ];
+  let extraClassDeclaration = [{
+    /// Returns the attribute identified by `identifier`, creating one with
+    /// default parameters and no elements if none exists yet.
+    static DICompositeTypeAttr getIdentified(MLIRContext *context, DistinctAttr identifier);
+    /// Replaces the elements in place. Intended for identified attributes: for
+    /// non-identified ones the elements are part of the uniquing key.
+    void replaceElements(const ArrayRef<DINodeAttr>& elements);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // AliasScopeAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
index c370bfa2b733d..c38bf1c66bba3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
@@ -23,6 +23,10 @@
 namespace mlir {
 namespace LLVM {
 
+namespace detail {
+/// Storage of DICompositeTypeAttr, defined in
+/// mlir/lib/Dialect/LLVMIR/IR/AttrDetail.h.
+struct DICompositeTypeAttrStorage;
+} // namespace detail
+
 /// This class represents the base attribute for all debug info attributes.
 class DINodeAttr : public Attribute {
 public:
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index d085fb6af6bc1..de120bfd7a4f9 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -248,6 +248,7 @@ Attribute Parser::parseAttribute(Type type) {
 OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
                                                    Type type) {
   switch (getToken().getKind()) {
+  case Token::kw_distinct:
   case Token::at_identifier:
   case Token::floatliteral:
   case Token::integer:
diff --git a/mlir/lib/Dialect/LLVMIR/IR/AttrDetail.h b/mlir/lib/Dialect/LLVMIR/IR/AttrDetail.h
new file mode 100644
index 0000000000000..60478174f7ca0
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/AttrDetail.h
@@ -0,0 +1,136 @@
+//===- AttrDetail.h - Details of MLIR LLVM dialect attributes --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains implementation details, such as storage structures, of
+// MLIR LLVM dialect attributes.
+//
+//===----------------------------------------------------------------------===//
+#ifndef DIALECT_LLVMIR_IR_ATTRDETAIL_H
+#define DIALECT_LLVMIR_IR_ATTRDETAIL_H
+
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+namespace LLVM {
+namespace detail {
+
+//===----------------------------------------------------------------------===//
+// DICompositeTypeAttrStorage
+//===----------------------------------------------------------------------===//
+
+/// Attribute storage for DICompositeTypeAttr.
+///
+/// The storage is mutable so that an identified attribute (one with a non-null
+/// `identifier`) can be created first and receive its `elements` afterwards,
+/// which allows those elements to refer back to the attribute itself.
+/// Identified storage is uniqued on `identifier` alone; all other storage is
+/// uniqued structurally on every field.
+struct DICompositeTypeAttrStorage : public ::mlir::AttributeStorage {
+  using KeyTy = std::tuple<unsigned, StringAttr, DIFileAttr, uint32_t,
+                           DIScopeAttr, DITypeAttr, DIFlags, uint64_t, uint64_t,
+                           ArrayRef<DINodeAttr>, DistinctAttr>;
+
+  DICompositeTypeAttrStorage(unsigned tag, StringAttr name, DIFileAttr file,
+                             uint32_t line, DIScopeAttr scope,
+                             DITypeAttr baseType, DIFlags flags,
+                             uint64_t sizeInBits, uint64_t alignInBits,
+                             ArrayRef<DINodeAttr> elements,
+                             DistinctAttr identifier = DistinctAttr())
+      : tag(tag), name(name), file(file), line(line), scope(scope),
+        baseType(baseType), flags(flags), sizeInBits(sizeInBits),
+        alignInBits(alignInBits), elements(elements), identifier(identifier) {}
+
+  unsigned getTag() const { return tag; }
+  StringAttr getName() const { return name; }
+  DIFileAttr getFile() const { return file; }
+  uint32_t getLine() const { return line; }
+  DIScopeAttr getScope() const { return scope; }
+  DITypeAttr getBaseType() const { return baseType; }
+  DIFlags getFlags() const { return flags; }
+  uint64_t getSizeInBits() const { return sizeInBits; }
+  uint64_t getAlignInBits() const { return alignInBits; }
+  ArrayRef<DINodeAttr> getElements() const { return elements; }
+  DistinctAttr getIdentifier() const { return identifier; }
+
+  /// Returns true if this attribute carries a distinct identifier.
+  bool isIdentified() const { return static_cast<bool>(identifier); }
+
+  /// Returns the uniquing key describing this storage. A null identifier in
+  /// the key denotes a non-identified attribute.
+  KeyTy getAsKey() const {
+    return KeyTy(tag, name, file, line, scope, baseType, flags, sizeInBits,
+                 alignInBits, elements, identifier);
+  }
+
+  /// Compares this storage against a key; must stay consistent with hashKey.
+  bool operator==(const KeyTy &other) const {
+    // Identified attributes are equal iff their identifiers are: the other
+    // fields, notably the mutable elements, must not affect uniquing.
+    if (isIdentified())
+      return identifier == std::get<10>(other);
+    return getAsKey() == other;
+  }
+
+  /// Hashes a key. Identified keys hash only the identifier, so the hash of
+  /// an identified storage stays valid when its elements are replaced.
+  static llvm::hash_code hashKey(const KeyTy &key) {
+    const auto &[tag, name, file, line, scope, baseType, flags, sizeInBits,
+                 alignInBits, elements, identifier] = key;
+    if (identifier)
+      return hash_value(identifier);
+    return llvm::hash_combine(tag, name, file, line, scope, baseType, flags,
+                              sizeInBits, alignInBits, elements);
+  }
+
+  /// Constructs new storage for `key` in the context allocator.
+  static DICompositeTypeAttrStorage *
+  construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+    auto [tag, name, file, line, scope, baseType, flags, sizeInBits,
+          alignInBits, elements, identifier] = key;
+    // The key only references the element array; copy it into storage owned
+    // by the allocator. A null identifier yields non-identified storage.
+    elements = allocator.copyInto(elements);
+    return new (allocator.allocate<DICompositeTypeAttrStorage>())
+        DICompositeTypeAttrStorage(tag, name, file, line, scope, baseType,
+                                   flags, sizeInBits, alignInBits, elements,
+                                   identifier);
+  }
+
+  /// Replaces the elements, the only field that may change after
+  /// construction. Fails for non-identified storage: there the elements are
+  /// part of the uniquing key, and changing them would break uniquing.
+  LogicalResult mutate(AttributeStorageAllocator &allocator,
+                       ArrayRef<DINodeAttr> elements) {
+    if (!isIdentified())
+      return failure();
+    this->elements = allocator.copyInto(elements);
+    return success();
+  }
+
+private:
+  unsigned tag;
+  StringAttr name;
+  DIFileAttr file;
+  uint32_t line;
+  DIScopeAttr scope;
+  DITypeAttr baseType;
+  DIFlags flags;
+  uint64_t sizeInBits;
+  uint64_t alignInBits;
+  ArrayRef<DINodeAttr> elements;
+  DistinctAttr identifier;
+};
+
+} // namespace detail
+} // namespace LLVM
+} // namespace mlir
+
+#endif // DIALECT_LLVMIR_IR_ATTRDETAIL_H
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index 645a45dd96bef..cf6dd803b2160 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -10,6 +10,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "AttrDetail.h"
+
 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Builders.h"
@@ -47,6 +49,7 @@ void LLVMDialect::registerAttributes() {
   addAttributes<
 #define GET_ATTRDEF_LIST
 #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc"
+
       >();
 }
 
@@ -124,7 +127,7 @@ bool MemoryEffectsAttr::isReadWrite() {
 }
 
 //===----------------------------------------------------------------------===//
-// DIExpression
+// DIExpressionAttr
 //===----------------------------------------------------------------------===//
 
 DIExpressionAttr DIExpressionAttr::get(MLIRContext *context) {
@@ -248,3 +251,340 @@ TargetFeaturesAttr TargetFeaturesAttr::featuresAt(Operation *op) {
   return parentFunction.getOperation()->getAttrOfType<TargetFeaturesAttr>(
       getAttributeName());
 }
+
+//===----------------------------------------------------------------------===//
+// DICompositeTypeAttr
+//===----------------------------------------------------------------------===//
+
+/// Builds a non-identified DICompositeTypeAttr, uniqued on all of its
+/// parameters. Forwards to the identified builder with a null identifier,
+/// which the storage treats as "not identified".
+DICompositeTypeAttr
+DICompositeTypeAttr::get(MLIRContext *context, unsigned tag, StringAttr name,
+                         DIFileAttr file, uint32_t line, DIScopeAttr scope,
+                         DITypeAttr baseType, DIFlags flags,
+                         uint64_t sizeInBits, uint64_t alignInBits,
+                         ::llvm::ArrayRef<DINodeAttr> elements) {
+  return get(context, /*identifier=*/DistinctAttr(), tag, name, file, line,
+             scope, baseType, flags, sizeInBits, alignInBits, elements);
+}
+
+/// Builds an identified DICompositeTypeAttr. The storage is uniqued on
+/// `identifier` alone, so if an attribute with this identifier already exists
+/// it is returned as-is and the remaining arguments are not applied.
+DICompositeTypeAttr DICompositeTypeAttr::get(
+    MLIRContext *context, DistinctAttr identifier, unsigned tag,
+    StringAttr name, DIFileAttr file, uint32_t line, DIScopeAttr scope,
+    DITypeAttr baseType, DIFlags flags, uint64_t sizeInBits,
+    uint64_t alignInBits, ::llvm::ArrayRef<DINodeAttr> elements) {
+  // Note the storage key order: the identifier comes last.
+  return Base::get(context, tag, name, file, line, scope, baseType, flags,
+                   sizeInBits, alignInBits, elements, identifier);
+}
+
+// Parameter accessors. The storage class is hand-written (genStorageClass = 0),
+// so every accessor simply forwards to DICompositeTypeAttrStorage.
+
+unsigned DICompositeTypeAttr::getTag() const {
+  return getImpl()->getTag();
+}
+
+StringAttr DICompositeTypeAttr::getName() const {
+  return getImpl()->getName();
+}
+
+DIFileAttr DICompositeTypeAttr::getFile() const {
+  return getImpl()->getFile();
+}
+
+uint32_t DICompositeTypeAttr::getLine() const {
+  return getImpl()->getLine();
+}
+
+DIScopeAttr DICompositeTypeAttr::getScope() const {
+  return getImpl()->getScope();
+}
+
+DITypeAttr DICompositeTypeAttr::getBaseType() const {
+  return getImpl()->getBaseType();
+}
+
+DIFlags DICompositeTypeAttr::getFlags() const {
+  return getImpl()->getFlags();
+}
+
+uint64_t DICompositeTypeAttr::getSizeInBits() const {
+  return getImpl()->getSizeInBits();
+}
+
+uint64_t DICompositeTypeAttr::getAlignInBits() const {
+  return getImpl()->getAlignInBits();
+}
+
+::llvm::ArrayRef<DINodeAttr> DICompositeTypeAttr::getElements() const {
+  return getImpl()->getElements();
+}
+
+DistinctAttr DICompositeTypeAttr::getIdentifier() const {
+  return getImpl()->getIdentifier();
+}
+
+/// Parses a DICompositeTypeAttr. Between `<` and `>` the accepted syntax is
+///
+///   [identifier `,`] param (`,` param)* [`(` element (`,` element)* `)`]
+///
+/// or a lone `identifier`. Here `identifier` is a DistinctAttr and each
+/// `param` is `key = value` for one of: tag, name, file, line, scope,
+/// baseType, flags, sizeInBits, alignInBits. The lone-identifier form is only
+/// accepted while the identified attribute it names is itself being parsed,
+/// i.e. as a self-reference nested in its own elements.
+Attribute DICompositeTypeAttr::parse(AsmParser &parser, Type type) {
+  // Holds the cyclic-parse registration of an identified attribute; it must
+  // stay alive until the elements, which may refer back to it, are parsed.
+  FailureOr<AsmParser::CyclicParseReset> cyclicParse;
+  FailureOr<unsigned> tag;
+  FailureOr<StringAttr> name;
+  FailureOr<DIFileAttr> file;
+  FailureOr<uint32_t> line;
+  FailureOr<DIScopeAttr> scope;
+  FailureOr<DITypeAttr> baseType;
+  FailureOr<DIFlags> flags;
+  FailureOr<uint64_t> sizeInBits;
+  FailureOr<uint64_t> alignInBits;
+  SmallVector<DINodeAttr> elements;
+  DistinctAttr identifier;
+
+  // Parses one `key = value` parameter. Each key is accepted at most once; a
+  // repeated key falls through to the "unknown parameter" error.
+  auto paramParser = [&]() -> LogicalResult {
+    StringRef paramKey;
+    if (parser.parseKeyword(&paramKey)) {
+      return parser.emitError(parser.getCurrentLocation(),
+                              "expected parameter name.");
+    }
+
+    if (parser.parseEqual()) {
+      return parser.emitError(parser.getCurrentLocation(),
+                              "expected `=` following parameter name.");
+    }
+
+    if (failed(tag) && paramKey == "tag") {
+      const auto tagLoc = parser.getCurrentLocation();
+      StringRef tagName;
+      if (parser.parseKeyword(&tagName))
+        return failure();
+      // llvm::dwarf::getTag reports unknown names as DW_TAG_invalid (~0U),
+      // not as 0, so both values must be rejected explicitly.
+      const unsigned value = llvm::dwarf::getTag(tagName);
+      if (value == 0 || value == llvm::dwarf::DW_TAG_invalid) {
+        return parser.emitError(tagLoc)
+               << "invalid debug info tag name: " << tagName;
+      }
+      tag = value;
+    } else if (failed(name) && paramKey == "name") {
+      name = FieldParser<StringAttr>::parse(parser);
+      if (failed(name)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "failed to parse parameter 'name'");
+      }
+    } else if (failed(file) && paramKey == "file") {
+      file = FieldParser<DIFileAttr>::parse(parser);
+      if (failed(file)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "failed to parse parameter 'file'");
+      }
+    } else if (failed(line) && paramKey == "line") {
+      line = FieldParser<uint32_t>::parse(parser);
+      if (failed(line)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "failed to parse parameter 'line'");
+      }
+    } else if (failed(scope) && paramKey == "scope") {
+      scope = FieldParser<DIScopeAttr>::parse(parser);
+      if (failed(scope)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "failed to parse parameter 'scope'");
+      }
+    } else if (failed(baseType) && paramKey == "baseType") {
+      baseType = FieldParser<DITypeAttr>::parse(parser);
+      if (failed(baseType)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "failed to parse parameter 'baseType'");
+      }
+    } else if (failed(flags) && paramKey == "flags") {
+      flags = FieldParser<DIFlags>::parse(parser);
+      if (failed(flags)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "failed to parse parameter 'flags'");
+      }
+    } else if (failed(sizeInBits) && paramKey == "sizeInBits") {
+      // Parse the full 64-bit width of the parameter (no truncation).
+      sizeInBits = FieldParser<uint64_t>::parse(parser);
+      if (failed(sizeInBits)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "failed to parse parameter 'sizeInBits'");
+      }
+    } else if (failed(alignInBits) && paramKey == "alignInBits") {
+      alignInBits = FieldParser<uint64_t>::parse(parser);
+      if (failed(alignInBits)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "failed to parse parameter 'alignInBits'");
+      }
+    } else {
+      return parser.emitError(parser.getCurrentLocation(),
+                              "unknown parameter '")
+             << paramKey << "'";
+    }
+    return success();
+  };
+
+  // Begin parsing.
+  if (parser.parseLess()) {
+    parser.emitError(parser.getCurrentLocation(), "expected `<`");
+    return {};
+  }
+
+  // An optional leading DistinctAttr makes this an identified attribute.
+  const OptionalParseResult idResult =
+      parser.parseOptionalAttribute(identifier);
+  if (idResult.has_value() && failed(*idResult))
+    return {};
+
+  if (identifier) {
+    // `<identifier>` alone refers to an enclosing definition.
+    if (succeeded(parser.parseOptionalGreater())) {
+      DICompositeTypeAttr result =
+          getIdentified(parser.getContext(), identifier);
+      // Cyclic parsing should not initiate with only the identifier. Only
+      // nested instances should terminate early.
+      if (succeeded(parser.tryStartCyclicParse(result))) {
+        parser.emitError(parser.getCurrentLocation(),
+                         "Expected identified attribute to contain at least "
+                         "one other parameter");
+        return {};
+      }
+      return result;
+    }
+
+    if (parser.parseComma()) {
+      parser.emitError(parser.getCurrentLocation(), "Expected `,`");
+      return {};
+    }
+  }
+
+  // Parse the immutable parameters.
+  if (parser.parseCommaSeparatedList(paramParser))
+    return {};
+
+  // Create an identified attribute before parsing its elements so that nested
+  // `<identifier>` references resolve to it.
+  DICompositeTypeAttr identified;
+  if (identifier) {
+    identified =
+        get(parser.getContext(), identifier, tag.value_or(0),
+            name.value_or(StringAttr()), file.value_or(DIFileAttr()),
+            line.value_or(0), scope.value_or(DIScopeAttr()),
+            baseType.value_or(DITypeAttr()), flags.value_or(DIFlags::Zero),
+            sizeInBits.value_or(0), alignInBits.value_or(0));
+
+    cyclicParse = parser.tryStartCyclicParse(identified);
+    if (failed(cyclicParse)) {
+      parser.emitError(parser.getCurrentLocation())
+          << "recursive redefinition of identified attribute " << identifier;
+      return {};
+    }
+  }
+
+  // Parse the optional parenthesized element list.
+  if (succeeded(parser.parseOptionalLParen())) {
+    auto elementParser = [&]() -> LogicalResult {
+      const auto elementLoc = parser.getCurrentLocation();
+      Attribute attr;
+      if (parser.parseAttribute(attr))
+        return parser.emitError(elementLoc, "expected attribute");
+      // Report non-DINode elements instead of asserting in a cast.
+      auto element = mlir::dyn_cast<DINodeAttr>(attr);
+      if (!element)
+        return parser.emitError(elementLoc,
+                                "expected debug info node attribute");
+      elements.push_back(element);
+      return success();
+    };
+    if (parser.parseCommaSeparatedList(elementParser))
+      return {};
+
+    if (parser.parseRParen()) {
+      parser.emitError(parser.getCurrentLocation(), "expected `)`");
+      return {};
+    }
+  }
+
+  // Expect the attribute to terminate.
+  if (parser.parseGreater()) {
+    parser.emitError(parser.getCurrentLocation(), "expected `>`");
+    return {};
+  }
+
+  if (!identifier)
+    return get(parser.getContext(), tag.value_or(0),
+               name.value_or(StringAttr()), file.value_or(DIFileAttr()),
+               line.value_or(0), scope.value_or(DIScopeAttr()),
+               baseType.value_or(DITypeAttr()), flags.value_or(DIFlags::Zero),
+               sizeInBits.value_or(0), alignInBits.value_or(0), elements);
+
+  // The identified attribute was created without elements; fill them in now.
+  identified.replaceElements(elements);
+  return identified;
+}
+
+/// Prints the attribute in the syntax read by `parse`:
+/// `<` [identifier `, `] params [` (` elements `)`] `>`. An identified
+/// attribute reached again while it is still being printed is emitted as
+/// just `<identifier>`, which terminates the recursion.
+void DICompositeTypeAttr::print(AsmPrinter &printer) const {
+  // Keeps this attribute registered as "being printed" for the whole body.
+  FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint;
+
+  // Writes the ", " separator before every item except the first one.
+  bool printedItem = false;
+  auto separate = [&]() {
+    if (printedItem)
+      printer << ", ";
+    printedItem = true;
+  };
+
+  printer << "<";
+  if (getImpl()->isIdentified()) {
+    cyclicPrint = printer.tryStartCyclicPrint(*this);
+    if (failed(cyclicPrint)) {
+      printer << getIdentifier() << ">";
+      return;
+    }
+    separate();
+    printer << getIdentifier();
+  }
+
+  if (getTag() > 0) {
+    separate();
+    printer << "tag = " << llvm::dwarf::TagString(getTag());
+  }
+  if (getName()) {
+    separate();
+    printer << "name = ";
+    printer.printStrippedAttrOrType(getName());
+  }
+  if (getFile()) {
+    separate();
+    printer << "file = ";
+    printer.printStrippedAttrOrType(getFile());
+  }
+  if (getLine() > 0) {
+    separate();
+    printer << "line = ";
+    printer.printStrippedAttrOrType(getLine());
+  }
+  if (getScope()) {
+    separate();
+    printer << "scope = ";
+    printer.printStrippedAttrOrType(getScope());
+  }
+  if (getBaseType()) {
+    separate();
+    printer << "baseType = ";
+    printer.printStrippedAttrOrType(getBaseType());
+  }
+  if (getFlags() != DIFlags::Zero) {
+    separate();
+    printer << "flags = ";
+    printer.printStrippedAttrOrType(getFlags());
+  }
+  if (getSizeInBits() > 0) {
+    separate();
+    printer << "sizeInBits = ";
+    printer.printStrippedAttrOrType(getSizeInBits());
+  }
+  if (getAlignInBits() > 0) {
+    separate();
+    printer << "alignInBits = ";
+    printer.printStrippedAttrOrType(getAlignInBits());
+  }
+
+  // Elements follow the parameters as a parenthesized list.
+  if (!getElements().empty()) {
+    printer << " (";
+    printer.printStrippedAttrOrType(getElements());
+    printer << ")";
+  }
+
+  printer << ">";
+}
+
+/// Returns the attribute identified by `identifier`. If none exists yet, one
+/// is created with default parameters and no elements; otherwise the existing
+/// instance is returned unchanged (identified storage is uniqued on the
+/// identifier only).
+DICompositeTypeAttr
+DICompositeTypeAttr::getIdentified(MLIRContext *context,
+                                   DistinctAttr identifier) {
+  return get(context, identifier, /*tag=*/0, /*name=*/StringAttr(),
+             /*file=*/DIFileAttr(), /*line=*/0, /*scope=*/DIScopeAttr(),
+             /*baseType=*/DITypeAttr(), /*flags=*/DIFlags::Zero,
+             /*sizeInBits=*/0, /*alignInBits=*/0,
+             /*elements=*/ArrayRef<DINodeAttr>());
+}
+
+/// Replaces the elements of an identified attribute in place. Only identified
+/// attributes may be mutated: for the others the elements are part of the
+/// uniquing key, so changing them would break uniquing.
+void DICompositeTypeAttr::replaceElements(
+    const ArrayRef<DINodeAttr> &elements) {
+  assert(getImpl()->isIdentified() &&
+         "replaceElements requires an identified DICompositeTypeAttr");
+  (void)Base::mutate(elements);
+}
diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
index 16918aab54978..1f73d16b8ead4 100644
--- a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
@@ -118,10 +118,6 @@ static DINodeT *getDistinctOrUnique(bool isDistinct, Ts &&...args) {
 
 llvm::DICompositeType *
 DebugTranslation::translateImpl(DICompositeTypeAttr attr) {
-  SmallVector<llvm::Metadata *> elements;
-  for (auto member : attr.getElements())
-    elements.push_back(translate(member));
-
   // TODO: Use distinct attributes to model this, once they have landed.
   // Depending on the tag, composite types must be distinct.
   bool isDistinct = false;
@@ -133,15 +129,31 @@ DebugTranslation::translateImpl(DICompositeTypeAttr attr) {
     isDistinct = true;
   }
 
-  return getDistinctOrUnique<llvm::DICompositeType>(
+  // Create the composite type metadata first with an empty set of elements.
+  llvm::DICompositeType *result = getDistinctOrUnique<llvm::DICompositeType>(
       isDistinct, llvmCtx, attr.getTag(), getMDStringOrNull(attr.getName()),
       translate(attr.getFile()), attr.getLine(), translate(attr.getScope()),
       translate(attr.getBaseType()), attr.getSizeInBits(),
       attr.getAlignInBits(),
       /*OffsetInBits=*/0,
       /*Flags=*/static_cast<llvm::DINode::DIFlags>(attr.getFlags()),
-      llvm::MDNode::get(llvmCtx, elements),
+      llvm::MDNode::get(llvmCtx, {}),
       /*RuntimeLang=*/0, /*VTableHolder=*/nullptr);
+
+  // Short-circuit the mapping for this attribute to prevent infinite recursion
+  // if this composite type is encountered while translating the elements.
+  attrToNode[attr] = result;
+
+  // Translate the elements.
+  SmallVector<llvm::Metadata*> elements;
+  for (const DINodeAttr member : attr.getElements())
+    elements.push_back(translate(member));
+
+  // Replace the elements in the resulting metadata.
+  result->replaceElements(llvm::MDTuple::get(llvmCtx, elements));
+
+  // Return the composite type.
+  return result;
 }
 
 llvm::DIDerivedType *DebugTranslation::translateImpl(DIDerivedTypeAttr attr) {
diff --git a/mlir/test/Dialect/LLVMIR/debuginfo.mlir b/mlir/test/Dialect/LLVMIR/debuginfo.mlir
index 53c38b4797031..857f9cb0f575a 100644
--- a/mlir/test/Dialect/LLVMIR/debuginfo.mlir
+++ b/mlir/test/Dialect/LLVMIR/debuginfo.mlir
@@ -36,6 +36,11 @@
   tag = DW_TAG_pointer_type, name = "ptr1"
 >
 
+// Identified composite type whose first element refers back to itself.
+// CHECK-DAG: #llvm.di_composite_type<distinct[{{[0-9]+}}]<>, tag = DW_TAG_structure_type, name = "struct1", file = #[[FILE]], scope = #[[FILE]]
+#comp3 = #llvm.di_composite_type<distinct[0]<>, tag = DW_TAG_structure_type, name = "struct1",
+  file = #file, scope = #file (#llvm.di_composite_type<distinct[0]<>>, #int0, #int1)
+>
+
 // CHECK-DAG: #[[COMP0:.*]] = #llvm.di_composite_type<tag = DW_TAG_array_type, name = "array0", line = 10, sizeInBits = 128, alignInBits = 32>
 #comp0 = #llvm.di_composite_type<
   tag = DW_TAG_array_type, name = "array0",
@@ -45,9 +50,9 @@
 // CHECK-DAG: #[[COMP1:.*]] = #llvm.di_composite_type<tag = DW_TAG_array_type, name = "array1", file = #[[FILE]], scope = #[[FILE]], baseType = #[[INT0]], elements = #llvm.di_subrange<count = 4 : i64>>
 #comp1 = #llvm.di_composite_type<
   tag = DW_TAG_array_type, name = "array1", file = #file,
-  scope = #file, baseType = #int0,
+  scope = #file, baseType = #int0
   // Specify the subrange count.
-  elements = #llvm.di_subrange<count = 4>
+  (#llvm.di_subrange<count = 4>)
 >
 
 // CHECK-DAG: #[[TOPLEVEL:.*]] = #llvm.di_namespace<name = "toplevel", exportSymbols = true>
@@ -74,7 +79,7 @@
 
 // CHECK-DAG: #[[SPTYPE0:.*]] = #llvm.di_subroutine_type<callingConvention = DW_CC_normal, types = #[[NULL]], #[[INT0]], #[[PTR0]], #[[PTR1]], #[[COMP0:.*]], #[[COMP1:.*]], #[[COMP2:.*]]>
 #spType0 = #llvm.di_subroutine_type<
-  callingConvention = DW_CC_normal, types = #null, #int0, #ptr0, #ptr1, #comp0, #comp1, #comp2
+  callingConvention = DW_CC_normal, types = #null, #int0, #ptr0, #ptr1, #comp0, #comp1, #comp2, #comp3
 >
 
 // CHECK-DAG: #[[SPTYPE1:.*]] = #llvm.di_subroutine_type<types = #[[INT1]], #[[INT1]]>



More information about the Mlir-commits mailing list