[Mlir-commits] [mlir] [mlir] Add support for recursive elements in DICompositeTypeAttr. (PR #74948)

Justin Wilson llvmlistbot at llvm.org
Thu Mar 14 14:56:54 PDT 2024


https://github.com/waj334 updated https://github.com/llvm/llvm-project/pull/74948

>From fa557fc4265d93eb8c4b7ad943331a8201b1d083 Mon Sep 17 00:00:00 2001
From: "Justin A. Wilson" <justin.wilson at omibyte.io>
Date: Sat, 9 Dec 2023 14:14:51 -0600
Subject: [PATCH] [mlir] Add support for recursive elements in DICompositeTypeAttr.

Implements mutable storage for DICompositeTypeAttr in order to allow for self-references in its elements array. When the "identifier" parameter is set to a non-empty value, only this string participates in the hash key, and the storage is implemented such that only the "elements" parameter is mutable. The module translator now creates the respective instance of llvm::DICompositeType without elements and then calls "llvm::DICompositeType::replaceElements" to set the elements after each element is translated. The only required IR change is that elements are now explicitly wrapped in parentheses for the sake of parsing.
---
 .../mlir/Dialect/LLVMIR/LLVMAttrDefs.td       |  76 ++--
 mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h  |   4 +
 mlir/include/mlir/Target/LLVMIR/Dialect/All.h |   2 +-
 mlir/lib/AsmParser/AttributeParser.cpp        |   1 +
 mlir/lib/Dialect/LLVMIR/IR/AttrDetail.h       | 136 +++++++
 mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp      | 344 +++++++++++++++++-
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    |  16 +-
 mlir/lib/IR/AsmPrinter.cpp                    |   1 +
 mlir/lib/Interfaces/DataLayoutInterfaces.cpp  |  41 ++-
 mlir/lib/Target/LLVMIR/DebugTranslation.cpp   |  24 +-
 mlir/test/Dialect/LLVMIR/debuginfo.mlir       |  11 +-
 11 files changed, 602 insertions(+), 54 deletions(-)
 create mode 100644 mlir/lib/Dialect/LLVMIR/IR/AttrDetail.h

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 6975b18ab7f81f..74bcc02ff13149 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -351,27 +351,6 @@ def LLVM_DICompileUnitAttr : LLVM_Attr<"DICompileUnit", "di_compile_unit",
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
-//===----------------------------------------------------------------------===//
-// DICompositeTypeAttr
-//===----------------------------------------------------------------------===//
-
-def LLVM_DICompositeTypeAttr : LLVM_Attr<"DICompositeType", "di_composite_type",
-                                         /*traits=*/[], "DITypeAttr"> {
-  let parameters = (ins
-    LLVM_DITagParameter:$tag,
-    OptionalParameter<"StringAttr">:$name,
-    OptionalParameter<"DIFileAttr">:$file,
-    OptionalParameter<"uint32_t">:$line,
-    OptionalParameter<"DIScopeAttr">:$scope,
-    OptionalParameter<"DITypeAttr">:$baseType,
-    OptionalParameter<"DIFlags", "DIFlags::Zero">:$flags,
-    OptionalParameter<"uint64_t">:$sizeInBits,
-    OptionalParameter<"uint64_t">:$alignInBits,
-    OptionalArrayRefParameter<"DINodeAttr">:$elements
-  );
-  let assemblyFormat = "`<` struct(params) `>`";
-}
-
 //===----------------------------------------------------------------------===//
 // DIDerivedTypeAttr
 //===----------------------------------------------------------------------===//
@@ -684,6 +663,61 @@ def LLVM_AliasScopeDomainAttr : LLVM_Attr<"AliasScopeDomain",
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
+//===----------------------------------------------------------------------===//
+// DICompositeTypeAttr
+//===----------------------------------------------------------------------===//
+
+def LLVM_DICompositeTypeAttr : LLVM_Attr<"DICompositeType", "di_composite_type",
+                                         /*traits=*/[NativeTypeTrait<"IsMutable">], "DITypeAttr"> {
+  let parameters = (ins
+    OptionalParameter<"unsigned">:$tag, // DWARF tag; printed/parsed by name
+    OptionalParameter<"StringAttr">:$name,
+    OptionalParameter<"DIFileAttr">:$file,
+    OptionalParameter<"uint32_t">:$line,
+    OptionalParameter<"DIScopeAttr">:$scope,
+    OptionalParameter<"DITypeAttr">:$baseType,
+    OptionalParameter<"DIFlags", "DIFlags::Zero">:$flags,
+    OptionalParameter<"uint64_t">:$sizeInBits,
+    OptionalParameter<"uint64_t">:$alignInBits,
+    OptionalArrayRefParameter<"DINodeAttr">:$elements, // the only mutable parameter
+    OptionalParameter<"StringAttr">:$identifier // when set, the sole uniquing key
+  );
+  let hasCustomAssemblyFormat = 1; // parse()/print() are hand-written
+  let genStorageClass = 0; // storage is hand-written (DICompositeTypeAttrStorage)
+  let storageClass = "DICompositeTypeAttrStorage";
+  let builders = [
+    AttrBuilder<(ins // anonymous form: uniqued by all parameters
+        "unsigned":$tag,
+        "StringAttr":$name,
+        "DIFileAttr":$file,
+        "uint32_t":$line,
+        "DIScopeAttr":$scope,
+        "DITypeAttr":$baseType,
+        "DIFlags":$flags,
+        "uint64_t":$sizeInBits,
+        "uint64_t":$alignInBits,
+        "::llvm::ArrayRef<DINodeAttr>":$elements
+    )>,
+    AttrBuilder<(ins // identified form: uniqued by $identifier only
+        "StringAttr":$identifier,
+        "unsigned":$tag,
+        "StringAttr":$name,
+        "DIFileAttr":$file,
+        "uint32_t":$line,
+        "DIScopeAttr":$scope,
+        "DITypeAttr":$baseType,
+        "DIFlags":$flags,
+        "uint64_t":$sizeInBits,
+        "uint64_t":$alignInBits,
+        CArg<"::llvm::ArrayRef<DINodeAttr>", "{}">:$elements
+    )>
+  ];
+  let extraClassDeclaration = [{
+    static DICompositeTypeAttr getIdentified(MLIRContext *context, StringAttr identifier);
+    void replaceElements(const ArrayRef<DINodeAttr>& elements); // mutates storage in place
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // AliasScopeAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
index c370bfa2b733d6..c38bf1c66bba35 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
@@ -23,6 +23,10 @@
 namespace mlir {
 namespace LLVM {
 
+namespace detail {
+struct DICompositeTypeAttrStorage;
+} // namespace detail
+
 /// This class represents the base attribute for all debug info attributes.
 class DINodeAttr : public Attribute {
 public:
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index d085fb6af6bc14..de120bfd7a4f92 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -248,6 +248,7 @@ Attribute Parser::parseAttribute(Type type) {
 OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
                                                    Type type) {
   switch (getToken().getKind()) {
+  case Token::kw_distinct:
   case Token::at_identifier:
   case Token::floatliteral:
   case Token::integer:
diff --git a/mlir/lib/Dialect/LLVMIR/IR/AttrDetail.h b/mlir/lib/Dialect/LLVMIR/IR/AttrDetail.h
new file mode 100644
index 00000000000000..ccecd540ab0469
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/AttrDetail.h
@@ -0,0 +1,138 @@
+//===- AttrDetail.h - Details of MLIR LLVM dialect attributes ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains implementation details, such as storage structures, of
+// MLIR LLVM dialect attributes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef DIALECT_LLVMIR_IR_ATTRDETAIL_H
+#define DIALECT_LLVMIR_IR_ATTRDETAIL_H
+
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
+#include "mlir/IR/AttributeSupport.h"
+#include "llvm/ADT/Hashing.h"
+
+#include <cassert>
+#include <tuple>
+
+namespace mlir {
+namespace LLVM {
+namespace detail {
+
+//===----------------------------------------------------------------------===//
+// DICompositeTypeAttrStorage
+//===----------------------------------------------------------------------===//
+
+/// Hand-written storage for DICompositeTypeAttr with two uniquing modes:
+///  * Identified (non-null `identifier`): only the identifier takes part in
+///    hashing and equality, so every lookup with the same identifier yields
+///    the same storage. Its `elements` may be replaced after construction
+///    via `mutate`, which lets an element refer back to the enclosing
+///    composite type.
+///  * Anonymous (null `identifier`): every parameter is part of the key, so
+///    the storage must not be mutated.
+struct DICompositeTypeAttrStorage : public ::mlir::AttributeStorage {
+  using KeyTy = std::tuple<unsigned, StringAttr, DIFileAttr, uint32_t,
+                           DIScopeAttr, DITypeAttr, DIFlags, uint64_t, uint64_t,
+                           ArrayRef<DINodeAttr>, StringAttr>;
+
+  DICompositeTypeAttrStorage(unsigned tag, StringAttr name, DIFileAttr file,
+                             uint32_t line, DIScopeAttr scope,
+                             DITypeAttr baseType, DIFlags flags,
+                             uint64_t sizeInBits, uint64_t alignInBits,
+                             ArrayRef<DINodeAttr> elements,
+                             StringAttr identifier = StringAttr())
+      : tag(tag), name(name), file(file), line(line), scope(scope),
+        baseType(baseType), flags(flags), sizeInBits(sizeInBits),
+        alignInBits(alignInBits), elements(elements), identifier(identifier) {}
+
+  unsigned getTag() const { return tag; }
+  StringAttr getName() const { return name; }
+  DIFileAttr getFile() const { return file; }
+  uint32_t getLine() const { return line; }
+  DIScopeAttr getScope() const { return scope; }
+  DITypeAttr getBaseType() const { return baseType; }
+  DIFlags getFlags() const { return flags; }
+  uint64_t getSizeInBits() const { return sizeInBits; }
+  uint64_t getAlignInBits() const { return alignInBits; }
+  ArrayRef<DINodeAttr> getElements() const { return elements; }
+  StringAttr getIdentifier() const { return identifier; }
+
+  /// Returns true if this attribute is uniqued by its identifier alone.
+  bool isIdentified() const { return static_cast<bool>(identifier); }
+
+  /// Returns the key describing this storage. Anonymous storage holds a null
+  /// identifier, so the last key element is null for it as well.
+  KeyTy getAsKey() const {
+    return KeyTy(tag, name, file, line, scope, baseType, flags, sizeInBits,
+                 alignInBits, elements, identifier);
+  }
+
+  /// Compares this storage against a lookup key. Identified storage matches
+  /// any key carrying the same identifier; anonymous storage requires every
+  /// key element to match.
+  bool operator==(const KeyTy &other) const {
+    if (isIdentified())
+      return identifier == std::get<10>(other);
+    return other == getAsKey();
+  }
+
+  /// Hashes a lookup key consistently with operator==: identified keys hash
+  /// on the identifier only, anonymous keys on every other element.
+  static llvm::hash_code hashKey(const KeyTy &key) {
+    const auto &[tag, name, file, line, scope, baseType, flags, sizeInBits,
+                 alignInBits, elements, identifier] = key;
+    if (identifier)
+      return hash_value(identifier);
+    return hash_combine(tag, name, file, line, scope, baseType, flags,
+                        sizeInBits, alignInBits, elements);
+  }
+
+  /// Allocates new storage for `key`, copying the element list into
+  /// `allocator`.
+  static DICompositeTypeAttrStorage *
+  construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+    const auto &[tag, name, file, line, scope, baseType, flags, sizeInBits,
+                 alignInBits, elements, identifier] = key;
+    return new (allocator.allocate<DICompositeTypeAttrStorage>())
+        DICompositeTypeAttrStorage(tag, name, file, line, scope, baseType,
+                                   flags, sizeInBits, alignInBits,
+                                   allocator.copyInto(elements), identifier);
+  }
+
+  /// Replaces the element list of an identified attribute, copying the new
+  /// elements into `allocator`.
+  LogicalResult mutate(AttributeStorageAllocator &allocator,
+                       ArrayRef<DINodeAttr> elements) {
+    // Elements are part of the uniquing key of anonymous storage; changing
+    // them there would silently break uniquing.
+    assert(isIdentified() && "only identified composite types are mutable");
+    this->elements = allocator.copyInto(elements);
+    return success();
+  }
+
+private:
+  unsigned tag;
+  StringAttr name;
+  DIFileAttr file;
+  uint32_t line;
+  DIScopeAttr scope;
+  DITypeAttr baseType;
+  DIFlags flags;
+  uint64_t sizeInBits;
+  uint64_t alignInBits;
+  ArrayRef<DINodeAttr> elements;
+  StringAttr identifier;
+};
+
+} // namespace detail
+} // namespace LLVM
+} // namespace mlir
+
+#endif // DIALECT_LLVMIR_IR_ATTRDETAIL_H
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index 645a45dd96befb..c11ed72fa3557a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -10,6 +10,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "AttrDetail.h"
+
 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Builders.h"
@@ -124,7 +126,7 @@ bool MemoryEffectsAttr::isReadWrite() {
 }
 
 //===----------------------------------------------------------------------===//
-// DIExpression
+// DIExpressionAttr
 //===----------------------------------------------------------------------===//
 
 DIExpressionAttr DIExpressionAttr::get(MLIRContext *context) {
@@ -248,3 +251,362 @@ TargetFeaturesAttr TargetFeaturesAttr::featuresAt(Operation *op) {
   return parentFunction.getOperation()->getAttrOfType<TargetFeaturesAttr>(
       getAttributeName());
 }
+
+//===----------------------------------------------------------------------===//
+// DICompositeTypeAttr
+//===----------------------------------------------------------------------===//
+
+DICompositeTypeAttr
+DICompositeTypeAttr::get(MLIRContext *context, unsigned tag, StringAttr name,
+                         DIFileAttr file, uint32_t line, DIScopeAttr scope,
+                         DITypeAttr baseType, DIFlags flags,
+                         uint64_t sizeInBits, uint64_t alignInBits,
+                         ::llvm::ArrayRef<DINodeAttr> elements) {
+  // Anonymous form: every parameter takes part in uniquing.
+  return Base::get(context, tag, name, file, line, scope, baseType, flags,
+                   sizeInBits, alignInBits, elements, StringAttr());
+}
+
+DICompositeTypeAttr DICompositeTypeAttr::get(
+    MLIRContext *context, StringAttr identifier, unsigned tag, StringAttr name,
+    DIFileAttr file, uint32_t line, DIScopeAttr scope, DITypeAttr baseType,
+    DIFlags flags, uint64_t sizeInBits, uint64_t alignInBits,
+    ::llvm::ArrayRef<DINodeAttr> elements) {
+  // Identified form: uniqued by `identifier` alone. If an attribute with this
+  // identifier already exists it is returned as-is and the other arguments
+  // are ignored.
+  return Base::get(context, tag, name, file, line, scope, baseType, flags,
+                   sizeInBits, alignInBits, elements, identifier);
+}
+
+unsigned DICompositeTypeAttr::getTag() const { return getImpl()->getTag(); }
+
+StringAttr DICompositeTypeAttr::getName() const { return getImpl()->getName(); }
+
+DIFileAttr DICompositeTypeAttr::getFile() const { return getImpl()->getFile(); }
+
+uint32_t DICompositeTypeAttr::getLine() const { return getImpl()->getLine(); }
+
+DIScopeAttr DICompositeTypeAttr::getScope() const {
+  return getImpl()->getScope();
+}
+
+DITypeAttr DICompositeTypeAttr::getBaseType() const {
+  return getImpl()->getBaseType();
+}
+
+DIFlags DICompositeTypeAttr::getFlags() const { return getImpl()->getFlags(); }
+
+uint64_t DICompositeTypeAttr::getSizeInBits() const {
+  return getImpl()->getSizeInBits();
+}
+
+uint64_t DICompositeTypeAttr::getAlignInBits() const {
+  return getImpl()->getAlignInBits();
+}
+
+::llvm::ArrayRef<DINodeAttr> DICompositeTypeAttr::getElements() const {
+  return getImpl()->getElements();
+}
+
+StringAttr DICompositeTypeAttr::getIdentifier() const {
+  return getImpl()->getIdentifier();
+}
+
+/// Parses `<` (identifier `,`)? param (`,` param)*
+///        (`(` element (`,` element)* `)`)? `>`
+/// where each param is `key = value`. The short form `<` identifier `>` is
+/// accepted only as a back-reference nested inside the definition of the
+/// same identified attribute.
+Attribute DICompositeTypeAttr::parse(AsmParser &parser, Type type) {
+  // Keeps an identified attribute registered as being parsed until this
+  // function returns, so nested short-form references to it are accepted.
+  FailureOr<AsmParser::CyclicParseReset> cyclicParse;
+  FailureOr<unsigned> tag;
+  FailureOr<StringAttr> name;
+  FailureOr<DIFileAttr> file;
+  FailureOr<uint32_t> line;
+  FailureOr<DIScopeAttr> scope;
+  FailureOr<DITypeAttr> baseType;
+  FailureOr<DIFlags> flags;
+  FailureOr<uint64_t> sizeInBits;
+  FailureOr<uint64_t> alignInBits;
+  SmallVector<DINodeAttr> elements;
+  StringAttr identifier;
+
+  // Parses one `key = value` parameter. A key that is unknown or was already
+  // seen is reported as an unknown parameter.
+  auto paramParser = [&]() -> LogicalResult {
+    StringRef paramKey;
+    if (parser.parseKeyword(&paramKey)) {
+      return parser.emitError(parser.getCurrentLocation(),
+                              "expected parameter name.");
+    }
+
+    if (parser.parseEqual()) {
+      return parser.emitError(parser.getCurrentLocation(),
+                              "expected `=` following parameter name.");
+    }
+
+    if (failed(tag) && paramKey == "tag") {
+      SMLoc tagLoc = parser.getCurrentLocation();
+      StringRef tagName;
+      if (parser.parseKeyword(&tagName))
+        return failure();
+      // llvm::dwarf::getTag reports unknown names as DW_TAG_invalid, not 0.
+      unsigned value = llvm::dwarf::getTag(tagName);
+      if (value == llvm::dwarf::DW_TAG_invalid)
+        return parser.emitError(tagLoc)
+               << "invalid debug info tag name: " << tagName;
+      tag = value;
+    } else if (failed(name) && paramKey == "name") {
+      name = FieldParser<StringAttr>::parse(parser);
+      if (failed(name)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "failed to parse parameter 'name'");
+      }
+    } else if (failed(file) && paramKey == "file") {
+      file = FieldParser<DIFileAttr>::parse(parser);
+      if (failed(file)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "failed to parse parameter 'file'");
+      }
+    } else if (failed(line) && paramKey == "line") {
+      line = FieldParser<uint32_t>::parse(parser);
+      if (failed(line)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "failed to parse parameter 'line'");
+      }
+    } else if (failed(scope) && paramKey == "scope") {
+      scope = FieldParser<DIScopeAttr>::parse(parser);
+      if (failed(scope)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "failed to parse parameter 'scope'");
+      }
+    } else if (failed(baseType) && paramKey == "baseType") {
+      baseType = FieldParser<DITypeAttr>::parse(parser);
+      if (failed(baseType)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "failed to parse parameter 'baseType'");
+      }
+    } else if (failed(flags) && paramKey == "flags") {
+      flags = FieldParser<DIFlags>::parse(parser);
+      if (failed(flags)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "failed to parse parameter 'flags'");
+      }
+    } else if (failed(sizeInBits) && paramKey == "sizeInBits") {
+      // Parse with the full 64-bit width of the parameter.
+      sizeInBits = FieldParser<uint64_t>::parse(parser);
+      if (failed(sizeInBits)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "failed to parse parameter 'sizeInBits'");
+      }
+    } else if (failed(alignInBits) && paramKey == "alignInBits") {
+      alignInBits = FieldParser<uint64_t>::parse(parser);
+      if (failed(alignInBits)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "failed to parse parameter 'alignInBits'");
+      }
+    } else {
+      return parser.emitError(parser.getCurrentLocation(),
+                              "unknown parameter '")
+             << paramKey << "'";
+    }
+    return success();
+  };
+
+  // The punctuation parsers below emit their own diagnostics on failure.
+  if (parser.parseLess())
+    return {};
+
+  // An optional leading string attribute is the identifier.
+  const OptionalParseResult idResult =
+      parser.parseOptionalAttribute(identifier);
+  if (idResult.has_value()) {
+    // Something was there, but it is not a valid identifier.
+    if (failed(*idResult))
+      return {};
+
+    if (succeeded(parser.parseOptionalGreater())) {
+      DICompositeTypeAttr result =
+          getIdentified(parser.getContext(), identifier);
+      // Cyclic parsing should not initiate with only the identifier. Only
+      // nested instances should terminate early.
+      if (succeeded(parser.tryStartCyclicParse(result))) {
+        parser.emitError(parser.getCurrentLocation(),
+                         "Expected identified attribute to contain at least "
+                         "one other parameter");
+        return {};
+      }
+      return result;
+    }
+
+    if (parser.parseComma())
+      return {};
+  }
+
+  // Parse the immutable parameters.
+  if (parser.parseCommaSeparatedList(paramParser))
+    return {};
+
+  DICompositeTypeAttr identified;
+  if (identifier) {
+    // Create (or look up) the identified attribute before parsing the
+    // elements so that nested back-references resolve to it; its elements
+    // are replaced once they have been parsed.
+    // NOTE(review): if an attribute with this identifier already exists, the
+    // parameters parsed above are silently ignored; confirm this is intended.
+    identified =
+        get(parser.getContext(), identifier, tag.value_or(0),
+            name.value_or(StringAttr()), file.value_or(DIFileAttr()),
+            line.value_or(0), scope.value_or(DIScopeAttr()),
+            baseType.value_or(DITypeAttr()), flags.value_or(DIFlags::Zero),
+            sizeInBits.value_or(0), alignInBits.value_or(0));
+
+    // Initiate cyclic parsing; it stays active until this function returns.
+    cyclicParse = parser.tryStartCyclicParse(identified);
+    if (failed(cyclicParse)) {
+      parser.emitError(parser.getCurrentLocation(),
+                       "recursive redefinition of identified composite type");
+      return {};
+    }
+  }
+
+  // Parse the optional parenthesized element list.
+  if (succeeded(parser.parseOptionalLParen())) {
+    auto parseElement = [&]() -> LogicalResult {
+      SMLoc elementLoc = parser.getCurrentLocation();
+      Attribute attr;
+      if (parser.parseAttribute(attr))
+        return failure();
+      // Diagnose non-debug-info attributes instead of asserting in cast<>.
+      auto element = mlir::dyn_cast<DINodeAttr>(attr);
+      if (!element)
+        return parser.emitError(elementLoc,
+                                "expected a debug info node attribute");
+      elements.push_back(element);
+      return success();
+    };
+    if (parser.parseCommaSeparatedList(parseElement) || parser.parseRParen())
+      return {};
+  }
+
+  // Expect the attribute to terminate.
+  if (parser.parseGreater())
+    return {};
+
+  if (!identified)
+    return get(parser.getContext(), tag.value_or(0),
+               name.value_or(StringAttr()), file.value_or(DIFileAttr()),
+               line.value_or(0), scope.value_or(DIScopeAttr()),
+               baseType.value_or(DITypeAttr()), flags.value_or(DIFlags::Zero),
+               sizeInBits.value_or(0), alignInBits.value_or(0), elements);
+
+  // Identified attributes receive their elements by mutation, which is what
+  // allows an element to refer back to this attribute.
+  identified.replaceElements(elements);
+  return identified;
+}
+
+void DICompositeTypeAttr::print(AsmPrinter &printer) const {
+  // An identified attribute reached again while printing its own elements
+  // is printed in the short form `<identifier>` to break the cycle.
+  FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint;
+  SmallVector<std::function<void()>> valuePrinters;
+  printer << "<";
+  if (getImpl()->isIdentified()) {
+    cyclicPrint = printer.tryStartCyclicPrint(*this);
+    if (failed(cyclicPrint)) {
+      printer << getIdentifier() << ">";
+      return;
+    }
+    valuePrinters.push_back([&]() { printer << getIdentifier(); });
+  }
+
+  if (getTag() > 0) {
+    valuePrinters.push_back(
+        [&]() { printer << "tag = " << llvm::dwarf::TagString(getTag()); });
+  }
+
+  if (getName()) {
+    valuePrinters.push_back([&]() {
+      printer << "name = ";
+      printer.printStrippedAttrOrType(getName());
+    });
+  }
+
+  if (getFile()) {
+    valuePrinters.push_back([&]() {
+      printer << "file = ";
+      printer.printStrippedAttrOrType(getFile());
+    });
+  }
+
+  if (getLine() > 0) {
+    valuePrinters.push_back([&]() {
+      printer << "line = ";
+      printer.printStrippedAttrOrType(getLine());
+    });
+  }
+
+  if (getScope()) {
+    valuePrinters.push_back([&]() {
+      printer << "scope = ";
+      printer.printStrippedAttrOrType(getScope());
+    });
+  }
+
+  if (getBaseType()) {
+    valuePrinters.push_back([&]() {
+      printer << "baseType = ";
+      printer.printStrippedAttrOrType(getBaseType());
+    });
+  }
+
+  if (getFlags() != DIFlags::Zero) {
+    valuePrinters.push_back([&]() {
+      printer << "flags = ";
+      printer.printStrippedAttrOrType(getFlags());
+    });
+  }
+
+  if (getSizeInBits() > 0) {
+    valuePrinters.push_back([&]() {
+      printer << "sizeInBits = ";
+      printer.printStrippedAttrOrType(getSizeInBits());
+    });
+  }
+
+  if (getAlignInBits() > 0) {
+    valuePrinters.push_back([&]() {
+      printer << "alignInBits = ";
+      printer.printStrippedAttrOrType(getAlignInBits());
+    });
+  }
+  interleaveComma(valuePrinters, printer,
+                  [&](const std::function<void()> &fn) { fn(); });
+
+  // Elements follow the comma-separated parameters, in parentheses.
+  if (!getElements().empty()) {
+    printer << " (";
+    printer.printStrippedAttrOrType(getElements());
+    printer << ")";
+  }
+  printer << ">";
+}
+
+/// Returns the identified attribute for `identifier`, creating one with all
+/// other parameters defaulted if none exists yet.
+DICompositeTypeAttr DICompositeTypeAttr::getIdentified(MLIRContext *context,
+                                                       StringAttr identifier) {
+  return Base::get(context, 0, StringAttr(), DIFileAttr(), 0, DIScopeAttr(),
+                   DITypeAttr(), DIFlags::Zero, 0, 0, ArrayRef<DINodeAttr>(),
+                   identifier);
+}
+
+/// Replaces the elements of this (identified) attribute in place.
+void DICompositeTypeAttr::replaceElements(
+    const ArrayRef<DINodeAttr> &elements) {
+  (void)Base::mutate(elements);
+}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 53e1088f620d7e..70e8a6ce208584 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2933,14 +2933,17 @@ struct LLVMOpAsmDialectInterface : public OpAsmDialectInterface {
     return TypeSwitch<Attribute, AliasResult>(attr)
         .Case<AccessGroupAttr, AliasScopeAttr, AliasScopeDomainAttr,
               DIBasicTypeAttr, DICompileUnitAttr, DICompositeTypeAttr,
-              DIDerivedTypeAttr, DIFileAttr, DIGlobalVariableAttr,
-              DIGlobalVariableExpressionAttr, DILabelAttr, DILexicalBlockAttr,
-              DILexicalBlockFileAttr, DILocalVariableAttr, DIModuleAttr,
-              DINamespaceAttr, DINullTypeAttr, DISubprogramAttr,
-              DISubroutineTypeAttr, LoopAnnotationAttr, LoopVectorizeAttr,
-              LoopInterleaveAttr, LoopUnrollAttr, LoopUnrollAndJamAttr,
-              LoopLICMAttr, LoopDistributeAttr, LoopPipelineAttr,
-              LoopPeeledAttr, LoopUnswitchAttr, TBAARootAttr, TBAATagAttr,
+              // NOTE(review): DIDerivedTypeAttr is no longer aliased. This
+              // presumably keeps alias definitions from referring to each
+              // other through recursive DICompositeTypeAttr elements; confirm.
+              DIFileAttr, DIGlobalVariableAttr, DIGlobalVariableExpressionAttr,
+              DILabelAttr, DILexicalBlockAttr, DILexicalBlockFileAttr,
+              DILocalVariableAttr, DIModuleAttr, DINamespaceAttr,
+              DINullTypeAttr, DISubprogramAttr, DISubroutineTypeAttr,
+              LoopAnnotationAttr, LoopVectorizeAttr, LoopInterleaveAttr,
+              LoopUnrollAttr, LoopUnrollAndJamAttr, LoopLICMAttr,
+              LoopDistributeAttr, LoopPipelineAttr, LoopPeeledAttr,
+              LoopUnswitchAttr, TBAARootAttr, TBAATagAttr,
               TBAATypeDescriptorAttr>([&](auto attr) {
           os << decltype(attr)::getMnemonic();
           return AliasResult::OverridableAlias;
diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
index 1178417fd2a6ca..ce7f6371cb150a 100644
--- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
+++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
@@ -35,9 +35,22 @@ using namespace mlir;
 /// Returns the bitwidth of the index type if specified in the param list.
 /// Assumes 64-bit index otherwise.
 static uint64_t getIndexBitwidth(DataLayoutEntryListRef params) {
-  if (params.empty())
+  DataLayoutEntryInterface entry;
+
+  // Look up the entry keyed by the index type.
+  for (DataLayoutEntryInterface param : params) {
+    if (param.getKey().is<Type>() &&
+        mlir::isa<IndexType>(param.getKey().get<Type>()))
+      entry = param;
+  }
+
+  // No corresponding entry was found, so assume the bitwidth is 64-bit.
+  if (!entry)
     return 64;
-  auto attr = cast<IntegerAttr>(params.front().getValue());
+
+  // The expected attribute is an IntegerAttr. Cast to it and retrieve the
+  // bitwidth value.
+  auto attr = cast<IntegerAttr>(entry.getValue());
   return attr.getValue().getZExtValue();
 }
 
@@ -86,16 +99,26 @@ mlir::detail::getDefaultTypeSizeInBits(Type type, const DataLayout &dataLayout,
   reportMissingDataLayout(type);
 }
 
+/// Returns the entry of `params` keyed by a type of kind `T` with the
+/// smallest bit width not below that of `type`, or the widest such entry if
+/// none is wide enough. Entries keyed by other kinds of types are ignored; if
+/// no entry is left, the first entry of `params` is returned.
+template <typename T>
 static DataLayoutEntryInterface
-findEntryForIntegerType(IntegerType intType,
-                        ArrayRef<DataLayoutEntryInterface> params) {
+findEntryForType(T type, ArrayRef<DataLayoutEntryInterface> params) {
   assert(!params.empty() && "expected non-empty parameter list");
   std::map<unsigned, DataLayoutEntryInterface> sortedParams;
   for (DataLayoutEntryInterface entry : params) {
-    sortedParams.insert(std::make_pair(
-        entry.getKey().get<Type>().getIntOrFloatBitWidth(), entry));
+    // Only consider entries keyed by a type of the same kind as `type`.
+    if (entry.getKey().is<Type>() && mlir::isa<T>(entry.getKey().get<Type>()))
+      sortedParams.insert(std::make_pair(
+          entry.getKey().get<Type>().getIntOrFloatBitWidth(), entry));
   }
-  auto iter = sortedParams.lower_bound(intType.getWidth());
+  // No entry of the requested kind: fall back to the first entry rather than
+  // calling std::prev on the end of an empty map below.
+  if (sortedParams.empty())
+    return params.front();
+  auto iter = sortedParams.lower_bound(type.getWidth());
   if (iter == sortedParams.end())
     iter = std::prev(iter);
 
@@ -122,17 +145,15 @@ getIntegerTypeABIAlignment(IntegerType intType,
                : kDefaultSmallIntAlignment;
   }
 
-  return extractABIAlignment(findEntryForIntegerType(intType, params));
+  return extractABIAlignment(findEntryForType(intType, params));
 }
 
 static uint64_t
 getFloatTypeABIAlignment(FloatType fltType, const DataLayout &dataLayout,
                          ArrayRef<DataLayoutEntryInterface> params) {
-  assert(params.size() <= 1 && "at most one data layout entry is expected for "
-                               "the singleton floating-point type");
   if (params.empty())
     return llvm::PowerOf2Ceil(dataLayout.getTypeSize(fltType).getFixedValue());
-  return extractABIAlignment(params[0]);
+  return extractABIAlignment(findEntryForType(fltType, params));
 }
 
 uint64_t mlir::detail::getDefaultABIAlignment(
@@ -175,17 +196,15 @@ getIntegerTypePreferredAlignment(IntegerType intType,
   if (params.empty())
     return llvm::PowerOf2Ceil(dataLayout.getTypeSize(intType).getFixedValue());
 
-  return extractPreferredAlignment(findEntryForIntegerType(intType, params));
+  return extractPreferredAlignment(findEntryForType(intType, params));
 }
 
 static uint64_t
 getFloatTypePreferredAlignment(FloatType fltType, const DataLayout &dataLayout,
                                ArrayRef<DataLayoutEntryInterface> params) {
-  assert(params.size() <= 1 && "at most one data layout entry is expected for "
-                               "the singleton floating-point type");
   if (params.empty())
     return dataLayout.getTypeABIAlignment(fltType);
-  return extractPreferredAlignment(params[0]);
+  return extractPreferredAlignment(findEntryForType(fltType, params));
 }
 
 uint64_t mlir::detail::getDefaultPreferredAlignment(
diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
index 16918aab549788..1f73d16b8ead44 100644
--- a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
@@ -118,10 +118,6 @@ static DINodeT *getDistinctOrUnique(bool isDistinct, Ts &&...args) {
 
 llvm::DICompositeType *
 DebugTranslation::translateImpl(DICompositeTypeAttr attr) {
-  SmallVector<llvm::Metadata *> elements;
-  for (auto member : attr.getElements())
-    elements.push_back(translate(member));
-
   // TODO: Use distinct attributes to model this, once they have landed.
   // Depending on the tag, composite types must be distinct.
   bool isDistinct = false;
@@ -133,15 +129,38 @@ DebugTranslation::translateImpl(DICompositeTypeAttr attr) {
     isDistinct = true;
   }
 
+  // Translate the elements while `attr` maps to a temporary placeholder, so an
+  // element that refers back to this composite type resolves to the
+  // placeholder instead of recursing forever. The final node is only built
+  // once its elements are known: patching the elements of an already created
+  // uniqued node would re-unique it and could merge it with another node.
+  auto placeholder = llvm::DICompositeType::getTemporary(
+      llvmCtx, attr.getTag(), getMDStringOrNull(attr.getName()),
+      /*File=*/nullptr, /*Line=*/0, /*Scope=*/nullptr, /*BaseType=*/nullptr,
+      /*SizeInBits=*/0, /*AlignInBits=*/0, /*OffsetInBits=*/0,
+      /*Flags=*/llvm::DINode::FlagZero, /*Elements=*/nullptr,
+      /*RuntimeLang=*/0, /*VTableHolder=*/nullptr);
+  attrToNode[attr] = placeholder.get();
+
+  SmallVector<llvm::Metadata *> elements;
+  for (DINodeAttr member : attr.getElements())
+    elements.push_back(translate(member));
+
-  return getDistinctOrUnique<llvm::DICompositeType>(
+  llvm::DICompositeType *result = getDistinctOrUnique<llvm::DICompositeType>(
       isDistinct, llvmCtx, attr.getTag(), getMDStringOrNull(attr.getName()),
       translate(attr.getFile()), attr.getLine(), translate(attr.getScope()),
       translate(attr.getBaseType()), attr.getSizeInBits(),
       attr.getAlignInBits(),
       /*OffsetInBits=*/0,
       /*Flags=*/static_cast<llvm::DINode::DIFlags>(attr.getFlags()),
       llvm::MDNode::get(llvmCtx, elements),
       /*RuntimeLang=*/0, /*VTableHolder=*/nullptr);
+
+  // Point every use of the placeholder (e.g. self-referential members) at the
+  // final node, and cache the final node for this attribute.
+  placeholder->replaceAllUsesWith(result);
+  attrToNode[attr] = result;
+  return result;
 }
 
 llvm::DIDerivedType *DebugTranslation::translateImpl(DIDerivedTypeAttr attr) {
diff --git a/mlir/test/Dialect/LLVMIR/debuginfo.mlir b/mlir/test/Dialect/LLVMIR/debuginfo.mlir
index 53c38b47970310..c9724db7196ba8 100644
--- a/mlir/test/Dialect/LLVMIR/debuginfo.mlir
+++ b/mlir/test/Dialect/LLVMIR/debuginfo.mlir
@@ -36,6 +36,11 @@
   tag = DW_TAG_pointer_type, name = "ptr1"
 >
 
+// CHECK-DAG: #[[COMP3:.*]] = #llvm.di_composite_type<"mystruct", tag = DW_TAG_structure_type, name = "struct1", file = #[[FILE]], scope = #[[FILE]] (#llvm.di_composite_type<"mystruct">, #[[INT0]], #[[INT1]])>
+#comp3 = #llvm.di_composite_type<"mystruct", tag = DW_TAG_structure_type, name = "struct1",
+  file = #file, scope = #file (#llvm.di_composite_type<"mystruct">, #int0, #int1)
+>
+
 // CHECK-DAG: #[[COMP0:.*]] = #llvm.di_composite_type<tag = DW_TAG_array_type, name = "array0", line = 10, sizeInBits = 128, alignInBits = 32>
 #comp0 = #llvm.di_composite_type<
   tag = DW_TAG_array_type, name = "array0",
@@ -45,9 +50,9 @@
-// CHECK-DAG: #[[COMP1:.*]] = #llvm.di_composite_type<tag = DW_TAG_array_type, name = "array1", file = #[[FILE]], scope = #[[FILE]], baseType = #[[INT0]], elements = #llvm.di_subrange<count = 4 : i64>>
+// CHECK-DAG: #[[COMP1:.*]] = #llvm.di_composite_type<tag = DW_TAG_array_type, name = "array1", file = #[[FILE]], scope = #[[FILE]], baseType = #[[INT0]] (#llvm.di_subrange<count = 4 : i64>)>
 #comp1 = #llvm.di_composite_type<
   tag = DW_TAG_array_type, name = "array1", file = #file,
-  scope = #file, baseType = #int0,
+  scope = #file, baseType = #int0
   // Specify the subrange count.
-  elements = #llvm.di_subrange<count = 4>
+  (#llvm.di_subrange<count = 4>)
 >
 
 // CHECK-DAG: #[[TOPLEVEL:.*]] = #llvm.di_namespace<name = "toplevel", exportSymbols = true>
@@ -74,7 +79,7 @@
 
-// CHECK-DAG: #[[SPTYPE0:.*]] = #llvm.di_subroutine_type<callingConvention = DW_CC_normal, types = #[[NULL]], #[[INT0]], #[[PTR0]], #[[PTR1]], #[[COMP0:.*]], #[[COMP1:.*]], #[[COMP2:.*]]>
+// CHECK-DAG: #[[SPTYPE0:.*]] = #llvm.di_subroutine_type<callingConvention = DW_CC_normal, types = #[[NULL]], #[[INT0]], #[[PTR0]], #[[PTR1]], #[[COMP0:.*]], #[[COMP1:.*]], #[[COMP2:.*]], #[[COMP3]]>
 #spType0 = #llvm.di_subroutine_type<
-  callingConvention = DW_CC_normal, types = #null, #int0, #ptr0, #ptr1, #comp0, #comp1, #comp2
+  callingConvention = DW_CC_normal, types = #null, #int0, #ptr0, #ptr1, #comp0, #comp1, #comp2, #comp3
 >
 
 // CHECK-DAG: #[[SPTYPE1:.*]] = #llvm.di_subroutine_type<types = #[[INT1]], #[[INT1]]>



More information about the Mlir-commits mailing list