[Mlir-commits] [mlir] [mlir] Add support for recursive elements in DICompositeAttr. (PR #74948)

Justin Wilson llvmlistbot at llvm.org
Sat Dec 9 12:19:02 PST 2023


https://github.com/waj334 created https://github.com/llvm/llvm-project/pull/74948

Implements mutable storage for DICompositeTypeAttr in order to allow self-references in its elements array. When the "identifier" parameter is non-empty, only this string participates in the hash key; the storage is implemented such that only the "elements" parameter is mutable. The module translator now creates the corresponding llvm::DICompositeType without elements and then calls "llvm::DICompositeType::replaceElements" to set the elements once all of them have been translated. The only required IR change is that elements must now be explicitly wrapped in square brackets so that they can be parsed.

>From ecaf00712fb551df575c061c807b288f0282268c Mon Sep 17 00:00:00 2001
From: "Justin A. Wilson" <justin.wilson at omibyte.io>
Date: Sat, 9 Dec 2023 14:14:51 -0600
Subject: [PATCH] [mlir] Add support for recursive elements in DICompositeAttr.

Implements mutable storage for DICompositeTypeAttr in order to allow self-references in its elements array. When the "identifier" parameter is non-empty, only this string participates in the hash key; the storage is implemented such that only the "elements" parameter is mutable. The module translator now creates the corresponding llvm::DICompositeType without elements and then calls "llvm::DICompositeType::replaceElements" to set the elements once all of them have been translated. The only required IR change is that elements must now be explicitly wrapped in square brackets so that they can be parsed.
---
 .../mlir/Dialect/LLVMIR/LLVMAttrDefs.td       |  76 +++-
 mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h  |   4 +
 mlir/lib/Dialect/LLVMIR/IR/AttrDetail.h       | 140 ++++++
 mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp      | 409 +++++++++++++++++-
 mlir/lib/Target/LLVMIR/DebugTranslation.cpp   |  24 +-
 mlir/test/Dialect/LLVMIR/debuginfo.mlir       |   9 +-
 6 files changed, 632 insertions(+), 30 deletions(-)
 create mode 100644 mlir/lib/Dialect/LLVMIR/IR/AttrDetail.h

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 5a65293a113c7f..0aed5b7840fbe9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -342,27 +342,6 @@ def LLVM_DICompileUnitAttr : LLVM_Attr<"DICompileUnit", "di_compile_unit",
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
-//===----------------------------------------------------------------------===//
-// DICompositeTypeAttr
-//===----------------------------------------------------------------------===//
-
-def LLVM_DICompositeTypeAttr : LLVM_Attr<"DICompositeType", "di_composite_type",
-                                         /*traits=*/[], "DITypeAttr"> {
-  let parameters = (ins
-    LLVM_DITagParameter:$tag,
-    OptionalParameter<"StringAttr">:$name,
-    OptionalParameter<"DIFileAttr">:$file,
-    OptionalParameter<"uint32_t">:$line,
-    OptionalParameter<"DIScopeAttr">:$scope,
-    OptionalParameter<"DITypeAttr">:$baseType,
-    OptionalParameter<"DIFlags", "DIFlags::Zero">:$flags,
-    OptionalParameter<"uint64_t">:$sizeInBits,
-    OptionalParameter<"uint64_t">:$alignInBits,
-    OptionalArrayRefParameter<"DINodeAttr">:$elements
-  );
-  let assemblyFormat = "`<` struct(params) `>`";
-}
-
 //===----------------------------------------------------------------------===//
 // DIDerivedTypeAttr
 //===----------------------------------------------------------------------===//
@@ -675,6 +654,61 @@ def LLVM_AliasScopeDomainAttr : LLVM_Attr<"AliasScopeDomain",
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
+//===----------------------------------------------------------------------===//
+// DICompositeTypeAttr
+//===----------------------------------------------------------------------===//
+
+// Debug info composite type. Unlike the other DI attributes, this one uses
+// hand-written storage (DICompositeTypeAttrStorage) and a custom
+// parser/printer. When `identifier` is non-empty the attribute is uniqued by
+// that string alone, and its `elements` can be replaced after construction
+// (see `replaceElements`). This is what allows the elements to refer back to
+// the attribute itself.
+// NOTE(review): NativeAttrTrait<"IsMutable"> is the attribute-specific
+// spelling of this trait; confirm NativeTypeTrait resolves to the same one.
+def LLVM_DICompositeTypeAttr : LLVM_Attr<"DICompositeType", "di_composite_type",
+                                         /*traits=*/[NativeTypeTrait<"IsMutable">], "DITypeAttr"> {
+  let parameters = (ins
+    OptionalParameter<"unsigned">:$tag,
+    OptionalParameter<"StringAttr">:$name,
+    OptionalParameter<"DIFileAttr">:$file,
+    OptionalParameter<"uint32_t">:$line,
+    OptionalParameter<"DIScopeAttr">:$scope,
+    OptionalParameter<"DITypeAttr">:$baseType,
+    OptionalParameter<"DIFlags", "DIFlags::Zero">:$flags,
+    OptionalParameter<"uint64_t">:$sizeInBits,
+    OptionalParameter<"uint64_t">:$alignInBits,
+    // The only parameter that may change after construction, and only for
+    // identified attributes.
+    OptionalArrayRefParameter<"DINodeAttr">:$elements,
+    // Non-empty for identified (possibly self-referential) composite types.
+    OptionalParameter<"StringRef">:$identifier
+  );
+  // parse() and print() are implemented by hand in LLVMAttrs.cpp.
+  let hasCustomAssemblyFormat = 1;
+  // The storage is hand-written in lib/Dialect/LLVMIR/IR/AttrDetail.h.
+  let genStorageClass = 0;
+  let storageClass = "DICompositeTypeAttrStorage";
+  let builders = [
+    // Builds a non-identified attribute, uniqued by all of its parameters.
+    AttrBuilder<(ins
+        "unsigned":$tag,
+        "StringAttr":$name,
+        "DIFileAttr":$file,
+        "uint32_t":$line,
+        "DIScopeAttr":$scope,
+        "DITypeAttr":$baseType,
+        "DIFlags":$flags,
+        "uint64_t":$sizeInBits,
+        "uint64_t":$alignInBits,
+        "::llvm::ArrayRef<DINodeAttr>":$elements
+    )>,
+    // Builds an identified attribute, uniqued by `identifier` only. The
+    // elements default to empty so they can be filled in later.
+    AttrBuilder<(ins
+        "StringRef":$identifier,
+        "unsigned":$tag,
+        "StringAttr":$name,
+        "DIFileAttr":$file,
+        "uint32_t":$line,
+        "DIScopeAttr":$scope,
+        "DITypeAttr":$baseType,
+        "DIFlags":$flags,
+        "uint64_t":$sizeInBits,
+        "uint64_t":$alignInBits,
+        CArg<"::llvm::ArrayRef<DINodeAttr>", "{}">:$elements
+    )>
+  ];
+  // getIdentified: returns the attribute registered for `identifier` (all
+  // other parameters defaulted if it does not exist yet).
+  // replaceElements: mutates the elements of an identified attribute.
+  let extraClassDeclaration = [{
+    static DICompositeTypeAttr getIdentified(MLIRContext *context, StringRef identifier);
+    void replaceElements(const ArrayRef<DINodeAttr>& elements);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // AliasScopeAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
index c370bfa2b733d6..c38bf1c66bba35 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
@@ -23,6 +23,10 @@
 namespace mlir {
 namespace LLVM {
 
+namespace detail {
+/// Hand-written storage of DICompositeTypeAttr. It is only forward-declared
+/// here; the definition lives in lib/Dialect/LLVMIR/IR/AttrDetail.h.
+struct DICompositeTypeAttrStorage;
+} // namespace detail
+
 /// This class represents the base attribute for all debug info attributes.
 class DINodeAttr : public Attribute {
 public:
diff --git a/mlir/lib/Dialect/LLVMIR/IR/AttrDetail.h b/mlir/lib/Dialect/LLVMIR/IR/AttrDetail.h
new file mode 100644
index 00000000000000..fd02af42a3fc37
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/AttrDetail.h
@@ -0,0 +1,140 @@
+//===- AttrDetail.h - Details of MLIR LLVM dialect attributes --------*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains implementation details, such as storage structures, of
+// MLIR LLVM dialect attributes.
+//
+//===----------------------------------------------------------------------===//
+#ifndef DIALECT_LLVMIR_IR_ATTRDETAIL_H
+#define DIALECT_LLVMIR_IR_ATTRDETAIL_H
+
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+namespace LLVM {
+namespace detail {
+
+//===----------------------------------------------------------------------===//
+// DICompositeTypeAttrStorage
+//===----------------------------------------------------------------------===//
+
+/// Storage for DICompositeTypeAttr.
+///
+/// Two uniquing modes are supported:
+///  * non-identified (empty `identifier`): every parameter takes part in
+///    hashing and equality, and the storage is immutable;
+///  * identified (non-empty `identifier`): only the identifier takes part in
+///    hashing and equality, and `elements` may be replaced via `mutate`. This
+///    is what allows an identified composite type to reference itself.
+struct DICompositeTypeAttrStorage : public ::mlir::AttributeStorage {
+  /// (tag, name, file, line, scope, baseType, flags, sizeInBits, alignInBits,
+  ///  elements, identifier)
+  using KeyTy = std::tuple<unsigned, StringAttr, DIFileAttr, uint32_t,
+                           DIScopeAttr, DITypeAttr, DIFlags, uint64_t, uint64_t,
+                           ArrayRef<DINodeAttr>, StringRef>;
+
+  DICompositeTypeAttrStorage(unsigned tag, StringAttr name, DIFileAttr file,
+                             uint32_t line, DIScopeAttr scope,
+                             DITypeAttr baseType, DIFlags flags,
+                             uint64_t sizeInBits, uint64_t alignInBits,
+                             ArrayRef<DINodeAttr> elements,
+                             StringRef identifier = StringRef())
+      : tag(tag), name(name), file(file), line(line), scope(scope),
+        baseType(baseType), flags(flags), sizeInBits(sizeInBits),
+        alignInBits(alignInBits), elements(elements), identifier(identifier) {}
+
+  unsigned getTag() const { return tag; }
+  StringAttr getName() const { return name; }
+  DIFileAttr getFile() const { return file; }
+  uint32_t getLine() const { return line; }
+  DIScopeAttr getScope() const { return scope; }
+  DITypeAttr getBaseType() const { return baseType; }
+  DIFlags getFlags() const { return flags; }
+  uint64_t getSizeInBits() const { return sizeInBits; }
+  uint64_t getAlignInBits() const { return alignInBits; }
+  ArrayRef<DINodeAttr> getElements() const { return elements; }
+  StringRef getIdentifier() const { return identifier; }
+
+  /// Returns true if this attribute is identified, i.e. its identifier is
+  /// non-empty.
+  bool isIdentified() const { return !identifier.empty(); }
+
+  /// Returns the key describing the current contents of this storage. For an
+  /// identified attribute `elements` reflects the latest `mutate` call.
+  KeyTy getAsKey() const {
+    return KeyTy(tag, name, file, line, scope, baseType, flags, sizeInBits,
+                 alignInBits, elements, identifier);
+  }
+
+  /// Compares this storage against a key. Identified storage compares by
+  /// identifier only; non-identified storage compares every parameter.
+  bool operator==(const KeyTy &other) const {
+    if (isIdentified())
+      return identifier == std::get<10>(other);
+    return other == getAsKey();
+  }
+
+  /// Hashes a key consistently with operator==: only the identifier when it
+  /// is non-empty, otherwise all remaining parameters.
+  static llvm::hash_code hashKey(const KeyTy &key) {
+    const auto &[tag, name, file, line, scope, baseType, flags, sizeInBits,
+                 alignInBits, elements, identifier] = key;
+    if (!identifier.empty())
+      return llvm::hash_value(identifier);
+    return llvm::hash_combine(tag, name, file, line, scope, baseType, flags,
+                              sizeInBits, alignInBits, elements);
+  }
+
+  /// Constructs new storage, copying the elements array and the identifier
+  /// string into the uniquer's allocator so they outlive the incoming key.
+  static DICompositeTypeAttrStorage *
+  construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+    const auto &[tag, name, file, line, scope, baseType, flags, sizeInBits,
+                 alignInBits, elements, identifier] = key;
+    // Copying preserves the identifier's length, so an empty identifier stays
+    // empty and a single construction path serves both uniquing modes.
+    return new (allocator.allocate<DICompositeTypeAttrStorage>())
+        DICompositeTypeAttrStorage(tag, name, file, line, scope, baseType,
+                                   flags, sizeInBits, alignInBits,
+                                   allocator.copyInto(elements),
+                                   allocator.copyInto(identifier));
+  }
+
+  /// Replaces the elements of an identified attribute.
+  ///
+  /// Fails (leaving the storage untouched) for non-identified attributes: the
+  /// elements of those participate in `hashKey`, so changing them would leave
+  /// the storage registered under a stale key in the uniquer.
+  LogicalResult mutate(AttributeStorageAllocator &allocator,
+                       ArrayRef<DINodeAttr> newElements) {
+    if (!isIdentified())
+      return failure();
+    elements = allocator.copyInto(newElements);
+    return success();
+  }
+
+private:
+  unsigned tag;
+  StringAttr name;
+  DIFileAttr file;
+  uint32_t line;
+  DIScopeAttr scope;
+  DITypeAttr baseType;
+  DIFlags flags;
+  uint64_t sizeInBits;
+  uint64_t alignInBits;
+  ArrayRef<DINodeAttr> elements;
+  StringRef identifier;
+};
+
+} // namespace detail
+} // namespace LLVM
+} // namespace mlir
+
+#endif // DIALECT_LLVMIR_IR_ATTRDETAIL_H
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index e2342670508ce4..c0c2c425d3c492 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -10,6 +10,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "AttrDetail.h"
+
 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Builders.h"
@@ -123,7 +125,7 @@ bool MemoryEffectsAttr::isReadWrite() {
 }
 
 //===----------------------------------------------------------------------===//
-// DIExpression
+// DIExpressionAttr
 //===----------------------------------------------------------------------===//
 
 DIExpressionAttr DIExpressionAttr::get(MLIRContext *context) {
@@ -183,3 +185,408 @@ void printExpressionArg(AsmPrinter &printer, uint64_t opcode,
     i++;
   });
 }
+
+//===----------------------------------------------------------------------===//
+// DICompositeTypeAttr
+//===----------------------------------------------------------------------===//
+
+/// Returns the non-identified composite type with the given parameters. The
+/// identifier is left empty, so the attribute is uniqued by every parameter
+/// (including `elements`) and can never be mutated.
+DICompositeTypeAttr DICompositeTypeAttr::get(
+    MLIRContext *context, unsigned tag, StringAttr name, DIFileAttr file,
+    uint32_t line, DIScopeAttr scope, DITypeAttr baseType, DIFlags flags,
+    uint64_t sizeInBits, uint64_t alignInBits,
+    ::llvm::ArrayRef<DINodeAttr> elements) {
+  const StringRef noIdentifier;
+  return Base::get(context, tag, name, file, line, scope, baseType, flags,
+                   sizeInBits, alignInBits, elements, noIdentifier);
+}
+
+/// Returns the composite type registered under `identifier`.
+///
+/// With a non-empty identifier, uniquing compares the identifier only. If an
+/// attribute with this identifier already exists it is returned unchanged and
+/// the remaining arguments are not applied to it; otherwise a new attribute
+/// is created from all of the arguments.
+DICompositeTypeAttr DICompositeTypeAttr::get(
+    MLIRContext *context, StringRef identifier, unsigned tag, StringAttr name,
+    DIFileAttr file, uint32_t line, DIScopeAttr scope, DITypeAttr baseType,
+    DIFlags flags, uint64_t sizeInBits, uint64_t alignInBits,
+    ::llvm::ArrayRef<DINodeAttr> elements) {
+  const StringRef uniquingKey = identifier;
+  return Base::get(context, tag, name, file, line, scope, baseType, flags,
+                   sizeInBits, alignInBits, elements, uniquingKey);
+}
+
+// Parameter accessors. Every value is read straight from the hand-written
+// DICompositeTypeAttrStorage; for identified attributes getElements() reflects
+// the most recent replaceElements() call.
+
+unsigned DICompositeTypeAttr::getTag() const {
+  return getImpl()->getTag();
+}
+
+StringAttr DICompositeTypeAttr::getName() const {
+  return getImpl()->getName();
+}
+
+DIFileAttr DICompositeTypeAttr::getFile() const {
+  return getImpl()->getFile();
+}
+
+uint32_t DICompositeTypeAttr::getLine() const {
+  return getImpl()->getLine();
+}
+
+DIScopeAttr DICompositeTypeAttr::getScope() const {
+  return getImpl()->getScope();
+}
+
+DITypeAttr DICompositeTypeAttr::getBaseType() const {
+  return getImpl()->getBaseType();
+}
+
+DIFlags DICompositeTypeAttr::getFlags() const {
+  return getImpl()->getFlags();
+}
+
+uint64_t DICompositeTypeAttr::getSizeInBits() const {
+  return getImpl()->getSizeInBits();
+}
+
+uint64_t DICompositeTypeAttr::getAlignInBits() const {
+  return getImpl()->getAlignInBits();
+}
+
+::llvm::ArrayRef<DINodeAttr> DICompositeTypeAttr::getElements() const {
+  return getImpl()->getElements();
+}
+
+/// Empty for non-identified composite types.
+StringRef DICompositeTypeAttr::getIdentifier() const {
+  return getImpl()->getIdentifier();
+}
+
+/// Parses a DICompositeTypeAttr in one of three forms:
+///   <identifier>                                     reference by identifier
+///   <key = value, ...>                               non-identified
+///   <identifier, key = value, ..., elements = [...]> identified definition
+/// `elements` is always a bracketed list. In the identified form it must be
+/// the last parameter, because it may refer back to the attribute being
+/// defined and is therefore parsed only after the attribute has been created.
+Attribute DICompositeTypeAttr::parse(AsmParser &parser, Type type) {
+  // Marks an identified attribute as being parsed until this function
+  // returns, so that a nested full definition of the same attribute is
+  // detected and rejected.
+  FailureOr<AsmParser::CyclicParseReset> cyclicParse;
+  FailureOr<unsigned> tag;
+  FailureOr<StringAttr> name;
+  FailureOr<DIFileAttr> file;
+  FailureOr<uint32_t> line;
+  FailureOr<DIScopeAttr> scope;
+  FailureOr<DITypeAttr> baseType;
+  FailureOr<DIFlags> flags;
+  FailureOr<uint64_t> sizeInBits;
+  FailureOr<uint64_t> alignInBits;
+  FailureOr<SmallVector<DINodeAttr>> elements;
+
+  // Holds the identifier of an identified composite type or, for a
+  // non-identified one, the name of its first parameter.
+  std::string identifier;
+  // True while `identifier` holds a parameter name that `paramParser` has not
+  // consumed yet.
+  bool skipFirstKeyword = false;
+
+  // Begin parsing.
+  if (parser.parseLess()) {
+    parser.emitError(parser.getCurrentLocation(), "expected `<`.");
+    return {};
+  }
+
+  // Parses a bracketed element list `[` (attr (`,` attr)*)? `]` into
+  // `elements`. An empty list `[]` is accepted as well.
+  auto parseElementList = [&]() -> LogicalResult {
+    if (parser.parseLSquare())
+      return parser.emitError(parser.getCurrentLocation(), "expected `[`.");
+    if (succeeded(parser.parseOptionalRSquare())) {
+      elements = SmallVector<DINodeAttr>();
+      return success();
+    }
+    elements = FieldParser<SmallVector<DINodeAttr>>::parse(parser);
+    if (failed(elements))
+      return parser.emitError(parser.getCurrentLocation(),
+                              "failed to parse parameter 'elements'");
+    if (parser.parseRSquare())
+      return parser.emitError(parser.getCurrentLocation(), "expected `]`.");
+    return success();
+  };
+
+  // Parses one `key = value` pair. A key that was already parsed falls
+  // through to the "unknown parameter" error. When `expectElements` is
+  // false, `elements` is not parsed here: `hasElements` is set after its `=`
+  // has been consumed and the caller parses the list later.
+  auto paramParser = [&](const bool expectElements,
+                         bool &hasElements) -> LogicalResult {
+    StringRef paramKey;
+    if (skipFirstKeyword) {
+      // The first parameter name was already consumed while probing for an
+      // identifier.
+      paramKey = identifier;
+      skipFirstKeyword = false;
+    } else if (parser.parseKeyword(&paramKey)) {
+      return parser.emitError(parser.getCurrentLocation(),
+                              "expected parameter name.");
+    }
+
+    if (parser.parseEqual())
+      return parser.emitError(parser.getCurrentLocation(),
+                              "expected `=` following parameter name.");
+
+    if (failed(tag) && paramKey == "tag") {
+      tag = [&]() -> FailureOr<unsigned> {
+        StringRef nameKeyword;
+        if (parser.parseKeyword(&nameKeyword))
+          return failure();
+        // llvm::dwarf::getTag reports unknown names as DW_TAG_invalid, not 0.
+        const unsigned value = llvm::dwarf::getTag(nameKeyword);
+        if (value != 0 && value != llvm::dwarf::DW_TAG_invalid)
+          return value;
+        return parser.emitError(parser.getCurrentLocation())
+               << "invalid debug info debug info tag name: " << nameKeyword;
+      }();
+      if (failed(tag))
+        return failure();
+    } else if (failed(name) && paramKey == "name") {
+      name = FieldParser<StringAttr>::parse(parser);
+      if (failed(name))
+        return parser.emitError(parser.getCurrentLocation(),
+                                "failed to parse parameter 'name'");
+    } else if (failed(file) && paramKey == "file") {
+      file = FieldParser<DIFileAttr>::parse(parser);
+      if (failed(file))
+        return parser.emitError(parser.getCurrentLocation(),
+                                "failed to parse parameter 'file'");
+    } else if (failed(line) && paramKey == "line") {
+      line = FieldParser<uint32_t>::parse(parser);
+      if (failed(line))
+        return parser.emitError(parser.getCurrentLocation(),
+                                "failed to parse parameter 'line'");
+    } else if (failed(scope) && paramKey == "scope") {
+      scope = FieldParser<DIScopeAttr>::parse(parser);
+      if (failed(scope))
+        return parser.emitError(parser.getCurrentLocation(),
+                                "failed to parse parameter 'scope'");
+    } else if (failed(baseType) && paramKey == "baseType") {
+      baseType = FieldParser<DITypeAttr>::parse(parser);
+      if (failed(baseType))
+        return parser.emitError(parser.getCurrentLocation(),
+                                "failed to parse parameter 'baseType'");
+    } else if (failed(flags) && paramKey == "flags") {
+      flags = FieldParser<DIFlags>::parse(parser);
+      if (failed(flags))
+        return parser.emitError(parser.getCurrentLocation(),
+                                "failed to parse parameter 'flags'");
+    } else if (failed(sizeInBits) && paramKey == "sizeInBits") {
+      // Parsed as 64-bit to match the parameter type; sizes above 4 GiB bits
+      // must not be truncated or rejected.
+      sizeInBits = FieldParser<uint64_t>::parse(parser);
+      if (failed(sizeInBits))
+        return parser.emitError(parser.getCurrentLocation(),
+                                "failed to parse parameter 'sizeInBits'");
+    } else if (failed(alignInBits) && paramKey == "alignInBits") {
+      alignInBits = FieldParser<uint64_t>::parse(parser);
+      if (failed(alignInBits))
+        return parser.emitError(parser.getCurrentLocation(),
+                                "failed to parse parameter 'alignInBits'");
+    } else if (failed(elements) && paramKey == "elements") {
+      if (!expectElements) {
+        // Defer: the caller parses the (possibly recursive) list once the
+        // attribute exists.
+        hasElements = true;
+        return success();
+      }
+      if (failed(parseElementList()))
+        return failure();
+    } else {
+      return parser.emitError(parser.getCurrentLocation(), "unknown parameter '")
+             << paramKey << "'";
+    }
+    return success();
+  };
+
+  // Every form starts with a keyword or string: either an identifier or the
+  // first parameter name.
+  if (failed(parser.parseOptionalKeywordOrString(&identifier))) {
+    parser.emitError(parser.getCurrentLocation(),
+                     "expected identifier or parameter.");
+    return {};
+  }
+
+  // `<identifier>`: a (possibly recursive) reference to an identified type.
+  if (succeeded(parser.parseOptionalGreater()))
+    return getIdentified(parser.getContext(), identifier);
+
+  bool hasElements = false;
+
+  // Without a comma after it, the keyword is the first parameter name of a
+  // non-identified attribute. Such an attribute cannot refer to itself, so
+  // all parameters, including `elements`, are parsed in a single pass.
+  if (failed(parser.parseOptionalComma())) {
+    skipFirstKeyword = true;
+    do {
+      if (failed(paramParser(/*expectElements=*/true, hasElements)))
+        return {};
+    } while (succeeded(parser.parseOptionalComma()));
+
+    if (parser.parseGreater()) {
+      parser.emitError(parser.getCurrentLocation(), "expected `>`.");
+      return {};
+    }
+    return get(parser.getContext(), tag.value_or(0),
+               name.value_or(StringAttr()), file.value_or(DIFileAttr()),
+               line.value_or(0), scope.value_or(DIScopeAttr()),
+               baseType.value_or(DITypeAttr()), flags.value_or(DIFlags::Zero),
+               sizeInBits.value_or(0), alignInBits.value_or(0),
+               elements.value_or(SmallVector<DINodeAttr>()));
+  }
+
+  // Identified definition: parse the scalar parameters, stopping right after
+  // `elements =` if it is present.
+  do {
+    if (failed(paramParser(/*expectElements=*/false, hasElements)))
+      return {};
+  } while (!hasElements && succeeded(parser.parseOptionalComma()));
+
+  DICompositeTypeAttr result =
+      get(parser.getContext(), identifier, tag.value_or(0),
+          name.value_or(StringAttr()), file.value_or(DIFileAttr()),
+          line.value_or(0), scope.value_or(DIScopeAttr()),
+          baseType.value_or(DITypeAttr()), flags.value_or(DIFlags::Zero),
+          sizeInBits.value_or(0), alignInBits.value_or(0));
+
+  // Only the bare `<identifier>` form may appear inside the attribute's own
+  // definition.
+  cyclicParse = parser.tryStartCyclicParse(result);
+  if (failed(cyclicParse)) {
+    parser.emitError(parser.getCurrentLocation(),
+                     "only identifier allowed in recursive composite "
+                     "type attribute.");
+    return {};
+  }
+
+  // Identified attributes are uniqued by identifier alone, so `get` returns
+  // any attribute previously created with this identifier unchanged (e.g. by
+  // an earlier `<identifier>` reference). Only `elements` can be updated
+  // afterwards, so reject definitions whose other parameters would otherwise
+  // be silently dropped.
+  if (result.getTag() != tag.value_or(0) ||
+      result.getName() != name.value_or(StringAttr()) ||
+      result.getFile() != file.value_or(DIFileAttr()) ||
+      result.getLine() != line.value_or(0) ||
+      result.getScope() != scope.value_or(DIScopeAttr()) ||
+      result.getBaseType() != baseType.value_or(DITypeAttr()) ||
+      result.getFlags() != flags.value_or(DIFlags::Zero) ||
+      result.getSizeInBits() != sizeInBits.value_or(0) ||
+      result.getAlignInBits() != alignInBits.value_or(0)) {
+    parser.emitError(parser.getCurrentLocation())
+        << "composite type '" << identifier
+        << "' conflicts with an existing attribute of the same identifier";
+    return {};
+  }
+
+  if (!hasElements) {
+    // `<identifier, key = value, ...>` without elements.
+    if (parser.parseGreater()) {
+      parser.emitError(parser.getCurrentLocation(), "expected `>`.");
+      return {};
+    }
+    return result;
+  }
+
+  // Parse the (possibly self-referencing) elements now that `result` exists.
+  if (failed(parseElementList()))
+    return {};
+
+  // Expect the attribute to terminate.
+  if (parser.parseGreater()) {
+    parser.emitError(parser.getCurrentLocation(), "expected `>`.");
+    return {};
+  }
+
+  result.replaceElements(*elements);
+  return result;
+}
+
+/// Prints the attribute in the form accepted by DICompositeTypeAttr::parse.
+/// Parameters holding their default value are omitted. A nested reference to
+/// an identified attribute that is already being printed is emitted as the
+/// short `<identifier>` form, which breaks cycles.
+void DICompositeTypeAttr::print(AsmPrinter &printer) const {
+  // Keeps this attribute marked as "being printed" until the function
+  // returns.
+  FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint;
+
+  // Emits the `, ` separator before every entry except the first one.
+  bool isFirst = true;
+  auto printSeparator = [&]() {
+    if (!isFirst)
+      printer << ", ";
+    isFirst = false;
+  };
+
+  printer << "<";
+  if (getImpl()->isIdentified()) {
+    // The parser accepts the identifier as a keyword or a quoted string, so
+    // print it the same way. A plain `<<` would produce unparsable output for
+    // identifiers that are not valid bare keywords.
+    cyclicPrint = printer.tryStartCyclicPrint(*this);
+    if (failed(cyclicPrint)) {
+      printer.printKeywordOrString(getIdentifier());
+      printer << ">";
+      return;
+    }
+    printSeparator();
+    printer.printKeywordOrString(getIdentifier());
+  }
+
+  if (getTag() > 0) {
+    printSeparator();
+    printer << "tag = " << llvm::dwarf::TagString(getTag());
+  }
+
+  if (getName()) {
+    printSeparator();
+    printer << "name = ";
+    printer.printStrippedAttrOrType(getName());
+  }
+
+  if (getFile()) {
+    printSeparator();
+    printer << "file = ";
+    printer.printStrippedAttrOrType(getFile());
+  }
+
+  if (getLine() > 0) {
+    printSeparator();
+    printer << "line = ";
+    printer.printStrippedAttrOrType(getLine());
+  }
+
+  if (getScope()) {
+    printSeparator();
+    printer << "scope = ";
+    printer.printStrippedAttrOrType(getScope());
+  }
+
+  if (getBaseType()) {
+    printSeparator();
+    printer << "baseType = ";
+    printer.printStrippedAttrOrType(getBaseType());
+  }
+
+  if (getFlags() != DIFlags::Zero) {
+    printSeparator();
+    printer << "flags = ";
+    printer.printStrippedAttrOrType(getFlags());
+  }
+
+  if (getSizeInBits() > 0) {
+    printSeparator();
+    printer << "sizeInBits = ";
+    printer.printStrippedAttrOrType(getSizeInBits());
+  }
+
+  if (getAlignInBits() > 0) {
+    printSeparator();
+    printer << "alignInBits = ";
+    printer.printStrippedAttrOrType(getAlignInBits());
+  }
+
+  // Printed last: in the identified form the parser requires `elements` to be
+  // the final parameter.
+  if (!getElements().empty()) {
+    printSeparator();
+    printer << "elements = [";
+    printer.printStrippedAttrOrType(getElements());
+    printer << "]";
+  }
+
+  printer << ">";
+}
+
+/// Returns the identified composite type for `identifier`.
+///
+/// Identified attributes compare by identifier only, so an existing attribute
+/// with this identifier is returned as-is. Otherwise a new attribute is
+/// created with every other parameter at its default value and no elements.
+DICompositeTypeAttr DICompositeTypeAttr::getIdentified(MLIRContext *context,
+                                                       StringRef identifier) {
+  return get(context, identifier, /*tag=*/0, /*name=*/StringAttr(),
+             /*file=*/DIFileAttr(), /*line=*/0, /*scope=*/DIScopeAttr(),
+             /*baseType=*/DITypeAttr(), /*flags=*/DIFlags::Zero,
+             /*sizeInBits=*/0, /*alignInBits=*/0);
+}
+
+/// Replaces the elements of this identified composite type in place. This is
+/// how self-referential elements are attached after the attribute has been
+/// created.
+void DICompositeTypeAttr::replaceElements(
+    const ArrayRef<DINodeAttr> &elements) {
+  // The elements of a non-identified attribute participate in its uniquing
+  // hash; mutating them would leave the attribute stored under a stale key.
+  assert(getImpl()->isIdentified() &&
+         "cannot replace the elements of a non-identified composite type");
+  (void)Base::mutate(elements);
+}
diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
index 6d845e27ffa247..e2a8d0fa48d05b 100644
--- a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
@@ -116,10 +116,6 @@ static DINodeT *getDistinctOrUnique(bool isDistinct, Ts &&...args) {
 
 llvm::DICompositeType *
 DebugTranslation::translateImpl(DICompositeTypeAttr attr) {
-  SmallVector<llvm::Metadata *> elements;
-  for (auto member : attr.getElements())
-    elements.push_back(translate(member));
-
   // TODO: Use distinct attributes to model this, once they have landed.
   // Depending on the tag, composite types must be distinct.
   bool isDistinct = false;
@@ -131,15 +127,31 @@ DebugTranslation::translateImpl(DICompositeTypeAttr attr) {
     isDistinct = true;
   }
 
-  return getDistinctOrUnique<llvm::DICompositeType>(
+  // Create the composite type metadata first with an empty set of elements.
+  llvm::DICompositeType *result = getDistinctOrUnique<llvm::DICompositeType>(
       isDistinct, llvmCtx, attr.getTag(), getMDStringOrNull(attr.getName()),
       translate(attr.getFile()), attr.getLine(), translate(attr.getScope()),
       translate(attr.getBaseType()), attr.getSizeInBits(),
       attr.getAlignInBits(),
       /*OffsetInBits=*/0,
       /*Flags=*/static_cast<llvm::DINode::DIFlags>(attr.getFlags()),
-      llvm::MDNode::get(llvmCtx, elements),
+      llvm::MDNode::get(llvmCtx, {}),
       /*RuntimeLang=*/0, /*VTableHolder=*/nullptr);
+
+  // Short-circuit the mapping for this attribute to prevent infinite recursion
+  // if this composite type is encountered while translating the elements.
+  attrToNode[attr] = result;
+
+  // Translate the elements.
+  SmallVector<llvm::Metadata*> elements;
+  for (const DINodeAttr member : attr.getElements())
+    elements.push_back(translate(member));
+
+  // Replace the elements in the resulting metadata.
+  result->replaceElements(llvm::MDTuple::get(llvmCtx, elements));
+
+  // Return the composite type.
+  return result;
 }
 
 llvm::DIDerivedType *DebugTranslation::translateImpl(DIDerivedTypeAttr attr) {
diff --git a/mlir/test/Dialect/LLVMIR/debuginfo.mlir b/mlir/test/Dialect/LLVMIR/debuginfo.mlir
index 53c38b47970310..4fc00815fc5e61 100644
--- a/mlir/test/Dialect/LLVMIR/debuginfo.mlir
+++ b/mlir/test/Dialect/LLVMIR/debuginfo.mlir
@@ -47,7 +47,12 @@
   tag = DW_TAG_array_type, name = "array1", file = #file,
   scope = #file, baseType = #int0,
   // Specify the subrange count.
-  elements = #llvm.di_subrange<count = 4>
+  elements = [#llvm.di_subrange<count = 4>]
+>
+
+// CHECK-DAG: #[[COMP3:.*]] = #llvm.di_composite_type<mystruct, tag = DW_TAG_structure_type, name = "struct1", file = #[[FILE]], scope = #[[FILE]], elements = [{{.*}}]>
+#comp3 = #llvm.di_composite_type<mystruct, tag = DW_TAG_structure_type, name = "struct1",
+  file = #file, scope = #file, elements = [#llvm.di_composite_type<mystruct>, #int0, #int1]
 >
 
 // CHECK-DAG: #[[TOPLEVEL:.*]] = #llvm.di_namespace<name = "toplevel", exportSymbols = true>
@@ -74,7 +79,7 @@
 
 // CHECK-DAG: #[[SPTYPE0:.*]] = #llvm.di_subroutine_type<callingConvention = DW_CC_normal, types = #[[NULL]], #[[INT0]], #[[PTR0]], #[[PTR1]], #[[COMP0:.*]], #[[COMP1:.*]], #[[COMP2:.*]]>
 #spType0 = #llvm.di_subroutine_type<
-  callingConvention = DW_CC_normal, types = #null, #int0, #ptr0, #ptr1, #comp0, #comp1, #comp2
+  callingConvention = DW_CC_normal, types = #null, #int0, #ptr0, #ptr1, #comp0, #comp1, #comp2, #comp3
 >
 
 // CHECK-DAG: #[[SPTYPE1:.*]] = #llvm.di_subroutine_type<types = #[[INT1]], #[[INT1]]>



More information about the Mlir-commits mailing list