[Mlir-commits] [mlir] [mlir] Add support for recursive elements in DICompositeAttr. (PR #74948)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Dec 9 12:19:29 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Justin Wilson (waj334)

<details>
<summary>Changes</summary>

Implements mutable storage for DICompositeTypeAttr in order to allow for self-references in its elements array. When the "identifier" parameter is non-empty, only this string participates in the hash key, and the storage is implemented such that only the "elements" parameter is mutable. The module translator now creates the respective instance of llvm::DICompositeType without elements and then calls "llvm::DICompositeType::replaceElements" to set the elements after they have been translated. The only required IR change is that elements are now explicitly wrapped in square brackets to simplify parsing.

---

Patch is 28.44 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/74948.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td (+55-21) 
- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h (+4) 
- (added) mlir/lib/Dialect/LLVMIR/IR/AttrDetail.h (+140) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp (+408-1) 
- (modified) mlir/lib/Target/LLVMIR/DebugTranslation.cpp (+18-6) 
- (modified) mlir/test/Dialect/LLVMIR/debuginfo.mlir (+7-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 5a65293a113c7..0aed5b7840fbe 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -342,27 +342,6 @@ def LLVM_DICompileUnitAttr : LLVM_Attr<"DICompileUnit", "di_compile_unit",
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
-//===----------------------------------------------------------------------===//
-// DICompositeTypeAttr
-//===----------------------------------------------------------------------===//
-
-def LLVM_DICompositeTypeAttr : LLVM_Attr<"DICompositeType", "di_composite_type",
-                                         /*traits=*/[], "DITypeAttr"> {
-  let parameters = (ins
-    LLVM_DITagParameter:$tag,
-    OptionalParameter<"StringAttr">:$name,
-    OptionalParameter<"DIFileAttr">:$file,
-    OptionalParameter<"uint32_t">:$line,
-    OptionalParameter<"DIScopeAttr">:$scope,
-    OptionalParameter<"DITypeAttr">:$baseType,
-    OptionalParameter<"DIFlags", "DIFlags::Zero">:$flags,
-    OptionalParameter<"uint64_t">:$sizeInBits,
-    OptionalParameter<"uint64_t">:$alignInBits,
-    OptionalArrayRefParameter<"DINodeAttr">:$elements
-  );
-  let assemblyFormat = "`<` struct(params) `>`";
-}
-
 //===----------------------------------------------------------------------===//
 // DIDerivedTypeAttr
 //===----------------------------------------------------------------------===//
@@ -675,6 +654,61 @@ def LLVM_AliasScopeDomainAttr : LLVM_Attr<"AliasScopeDomain",
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
+//===----------------------------------------------------------------------===//
+// DICompositeTypeAttr
+//===----------------------------------------------------------------------===//
+
+// Debug info composite type attribute. It is mutable and uses a hand-written
+// storage class (DICompositeTypeAttrStorage in AttrDetail.h): when
+// `identifier` is non-empty, the attribute is uniqued by the identifier alone
+// and its `elements` may be replaced after creation, which allows elements to
+// refer back to the composite type itself.
+def LLVM_DICompositeTypeAttr : LLVM_Attr<"DICompositeType", "di_composite_type",
+                                         /*traits=*/[NativeTypeTrait<"IsMutable">], "DITypeAttr"> {
+  let parameters = (ins
+    // DWARF tag value; the custom parser accepts it as a tag name.
+    OptionalParameter<"unsigned">:$tag,
+    OptionalParameter<"StringAttr">:$name,
+    OptionalParameter<"DIFileAttr">:$file,
+    OptionalParameter<"uint32_t">:$line,
+    OptionalParameter<"DIScopeAttr">:$scope,
+    OptionalParameter<"DITypeAttr">:$baseType,
+    OptionalParameter<"DIFlags", "DIFlags::Zero">:$flags,
+    OptionalParameter<"uint64_t">:$sizeInBits,
+    OptionalParameter<"uint64_t">:$alignInBits,
+    // The only parameter that may be replaced after creation
+    // (see replaceElements).
+    OptionalArrayRefParameter<"DINodeAttr">:$elements,
+    // When non-empty, this string alone determines uniquing.
+    OptionalParameter<"StringRef">:$identifier
+  );
+  // Parsing/printing and the storage class are implemented by hand in C++.
+  let hasCustomAssemblyFormat = 1;
+  let genStorageClass = 0;
+  let storageClass = "DICompositeTypeAttrStorage";
+  let builders = [
+    // Builds a non-identified attribute (empty identifier).
+    AttrBuilder<(ins
+        "unsigned":$tag,
+        "StringAttr":$name,
+        "DIFileAttr":$file,
+        "uint32_t":$line,
+        "DIScopeAttr":$scope,
+        "DITypeAttr":$baseType,
+        "DIFlags":$flags,
+        "uint64_t":$sizeInBits,
+        "uint64_t":$alignInBits,
+        "::llvm::ArrayRef<DINodeAttr>":$elements
+    )>,
+    // Builds an identified attribute; elements default to empty so that they
+    // can be filled in later with replaceElements.
+    AttrBuilder<(ins
+        "StringRef":$identifier,
+        "unsigned":$tag,
+        "StringAttr":$name,
+        "DIFileAttr":$file,
+        "uint32_t":$line,
+        "DIScopeAttr":$scope,
+        "DITypeAttr":$baseType,
+        "DIFlags":$flags,
+        "uint64_t":$sizeInBits,
+        "uint64_t":$alignInBits,
+        CArg<"::llvm::ArrayRef<DINodeAttr>", "{}">:$elements
+    )>
+  ];
+  let extraClassDeclaration = [{
+    /// Presumably returns the identified attribute for `identifier` with all
+    /// other parameters defaulted (the parser calls it when only an identifier
+    /// is given); the definition is not visible in this excerpt -- confirm.
+    static DICompositeTypeAttr getIdentified(MLIRContext *context, StringRef identifier);
+    /// Replaces the elements of this attribute.
+    /// NOTE(review): ArrayRef is a lightweight view that LLVM code normally
+    /// passes by value; `const ArrayRef<...> &` is unnecessary here.
+    void replaceElements(const ArrayRef<DINodeAttr>& elements);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // AliasScopeAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
index c370bfa2b733d..c38bf1c66bba3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
@@ -23,6 +23,10 @@
 namespace mlir {
 namespace LLVM {
 
+namespace detail {
+// Storage for DICompositeTypeAttr; defined in
+// mlir/lib/Dialect/LLVMIR/IR/AttrDetail.h.
+struct DICompositeTypeAttrStorage;
+} // namespace detail
+
 /// This class represents the base attribute for all debug info attributes.
 class DINodeAttr : public Attribute {
 public:
diff --git a/mlir/lib/Dialect/LLVMIR/IR/AttrDetail.h b/mlir/lib/Dialect/LLVMIR/IR/AttrDetail.h
new file mode 100644
index 0000000000000..fd02af42a3fc3
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/AttrDetail.h
@@ -0,0 +1,140 @@
+//===- AttrDetail.h - Details of MLIR LLVM dialect attributes ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains implementation details, such as storage structures, of
+// MLIR LLVM dialect attributes.
+//
+//===----------------------------------------------------------------------===//
+#ifndef DIALECT_LLVMIR_IR_ATTRDETAIL_H
+#define DIALECT_LLVMIR_IR_ATTRDETAIL_H
+
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+namespace LLVM {
+namespace detail {
+
+//===----------------------------------------------------------------------===//
+// DICompositeTypeAttrStorage
+//===----------------------------------------------------------------------===//
+
+/// Storage for DICompositeTypeAttr.
+///
+/// An attribute with a non-empty `identifier` is "identified": it is uniqued
+/// by the identifier alone (see `hashKey` and `operator==`), and its
+/// `elements` may be replaced after construction via `mutate`, which allows
+/// elements to refer back to the composite type itself. An attribute with an
+/// empty identifier is uniqued by all of its parameters and is immutable.
+struct DICompositeTypeAttrStorage : public ::mlir::AttributeStorage {
+  using KeyTy = std::tuple<unsigned, StringAttr, DIFileAttr, uint32_t,
+                           DIScopeAttr, DITypeAttr, DIFlags, uint64_t, uint64_t,
+                           ArrayRef<DINodeAttr>, StringRef>;
+
+  DICompositeTypeAttrStorage(unsigned tag, StringAttr name, DIFileAttr file,
+                             uint32_t line, DIScopeAttr scope,
+                             DITypeAttr baseType, DIFlags flags,
+                             uint64_t sizeInBits, uint64_t alignInBits,
+                             ArrayRef<DINodeAttr> elements,
+                             StringRef identifier = StringRef())
+      : tag(tag), name(name), file(file), line(line), scope(scope),
+        baseType(baseType), flags(flags), sizeInBits(sizeInBits),
+        alignInBits(alignInBits), elements(elements), identifier(identifier) {}
+
+  unsigned getTag() const { return tag; }
+  StringAttr getName() const { return name; }
+  DIFileAttr getFile() const { return file; }
+  uint32_t getLine() const { return line; }
+  DIScopeAttr getScope() const { return scope; }
+  DITypeAttr getBaseType() const { return baseType; }
+  DIFlags getFlags() const { return flags; }
+  uint64_t getSizeInBits() const { return sizeInBits; }
+  uint64_t getAlignInBits() const { return alignInBits; }
+  ArrayRef<DINodeAttr> getElements() const { return elements; }
+  StringRef getIdentifier() const { return identifier; }
+
+  /// Returns true if this attribute is identified, i.e. its identifier is
+  /// non-empty.
+  bool isIdentified() const { return !identifier.empty(); }
+
+  /// Returns the respective key for this attribute. The identifier slot is
+  /// empty for non-identified attributes.
+  KeyTy getAsKey() const {
+    return KeyTy(tag, name, file, line, scope, baseType, flags, sizeInBits,
+                 alignInBits, elements, identifier);
+  }
+
+  /// Compares this storage against a key. Identified storage matches any key
+  /// carrying the same identifier, regardless of the other parameters.
+  bool operator==(const KeyTy &other) const {
+    if (isIdentified())
+      // Just compare against the identifier.
+      return identifier == std::get<10>(other);
+
+    // Otherwise, compare the entire tuple.
+    return other == getAsKey();
+  }
+
+  /// Returns the hash value of the key. Kept consistent with `operator==`:
+  /// identified keys hash only the identifier.
+  static llvm::hash_code hashKey(const KeyTy &key) {
+    const auto &[tag, name, file, line, scope, baseType, flags, sizeInBits,
+                 alignInBits, elements, identifier] = key;
+
+    if (!identifier.empty())
+      // The hash only consists of the unique identifier string.
+      return llvm::hash_value(identifier);
+
+    // Otherwise, everything else is included in the hash.
+    return llvm::hash_combine(tag, name, file, line, scope, baseType, flags,
+                              sizeInBits, alignInBits, elements);
+  }
+
+  /// Constructs new storage for an attribute, copying the elements and (for
+  /// identified attributes) the identifier string into `allocator`.
+  static DICompositeTypeAttrStorage *
+  construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+    auto [tag, name, file, line, scope, baseType, flags, sizeInBits,
+          alignInBits, elements, identifier] = key;
+    elements = allocator.copyInto(elements);
+    // Only identified attributes own an identifier string; normalize the
+    // empty case so the storage never points into caller-owned memory.
+    identifier =
+        identifier.empty() ? StringRef() : allocator.copyInto(identifier);
+    return new (allocator.allocate<DICompositeTypeAttrStorage>())
+        DICompositeTypeAttrStorage(tag, name, file, line, scope, baseType,
+                                   flags, sizeInBits, alignInBits, elements,
+                                   identifier);
+  }
+
+  /// Replaces the elements of an identified attribute.
+  ///
+  /// Returns failure for non-identified attributes: their elements take part
+  /// in `hashKey` and `operator==`, so changing them after uniquing would
+  /// leave the storage registered under a stale hash and break uniquing.
+  LogicalResult mutate(AttributeStorageAllocator &allocator,
+                       ArrayRef<DINodeAttr> newElements) {
+    if (!isIdentified())
+      return failure();
+    elements = allocator.copyInto(newElements);
+    return success();
+  }
+
+private:
+  unsigned tag;
+  StringAttr name;
+  DIFileAttr file;
+  uint32_t line;
+  DIScopeAttr scope;
+  DITypeAttr baseType;
+  DIFlags flags;
+  uint64_t sizeInBits;
+  uint64_t alignInBits;
+  ArrayRef<DINodeAttr> elements;
+  StringRef identifier;
+};
+
+} // namespace detail
+} // namespace LLVM
+} // namespace mlir
+
+#endif // DIALECT_LLVMIR_IR_ATTRDETAIL_H
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index e2342670508ce..c0c2c425d3c49 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -10,6 +10,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "AttrDetail.h"
+
 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Builders.h"
@@ -123,7 +125,7 @@ bool MemoryEffectsAttr::isReadWrite() {
 }
 
 //===----------------------------------------------------------------------===//
-// DIExpression
+// DIExpressionAttr
 //===----------------------------------------------------------------------===//
 
 DIExpressionAttr DIExpressionAttr::get(MLIRContext *context) {
@@ -183,3 +185,408 @@ void printExpressionArg(AsmPrinter &printer, uint64_t opcode,
     i++;
   });
 }
+
+//===----------------------------------------------------------------------===//
+// DICompositeTypeAttr
+//===----------------------------------------------------------------------===//
+
+/// Builds a non-identified attribute: the identifier is left empty, so the
+/// attribute is uniqued by all of its parameters.
+DICompositeTypeAttr
+DICompositeTypeAttr::get(MLIRContext *context, unsigned tag, StringAttr name,
+                         DIFileAttr file, uint32_t line, DIScopeAttr scope,
+                         DITypeAttr baseType, DIFlags flags,
+                         uint64_t sizeInBits, uint64_t alignInBits,
+                         ::llvm::ArrayRef<DINodeAttr> elements) {
+  return Base::get(context, tag, name, file, line, scope, baseType, flags,
+                   sizeInBits, alignInBits, elements, StringRef());
+}
+
+/// Builds an identified attribute. If `identifier` is non-empty it alone
+/// determines uniquing: a later call with the same identifier returns the
+/// existing attribute even if the other parameters differ. An empty
+/// identifier behaves like the non-identified overload.
+DICompositeTypeAttr DICompositeTypeAttr::get(
+    MLIRContext *context, StringRef identifier, unsigned tag, StringAttr name,
+    DIFileAttr file, uint32_t line, DIScopeAttr scope, DITypeAttr baseType,
+    DIFlags flags, uint64_t sizeInBits, uint64_t alignInBits,
+    ::llvm::ArrayRef<DINodeAttr> elements) {
+  return Base::get(context, tag, name, file, line, scope, baseType, flags,
+                   sizeInBits, alignInBits, elements, identifier);
+}
+
+// Parameter accessors; each forwards to the corresponding getter on
+// DICompositeTypeAttrStorage.
+unsigned DICompositeTypeAttr::getTag() const { return getImpl()->getTag(); }
+
+StringAttr DICompositeTypeAttr::getName() const { return getImpl()->getName(); }
+
+DIFileAttr DICompositeTypeAttr::getFile() const { return getImpl()->getFile(); }
+
+uint32_t DICompositeTypeAttr::getLine() const { return getImpl()->getLine(); }
+
+DIScopeAttr DICompositeTypeAttr::getScope() const {
+  return getImpl()->getScope();
+}
+
+DITypeAttr DICompositeTypeAttr::getBaseType() const {
+  return getImpl()->getBaseType();
+}
+
+DIFlags DICompositeTypeAttr::getFlags() const { return getImpl()->getFlags(); }
+
+uint64_t DICompositeTypeAttr::getSizeInBits() const {
+  return getImpl()->getSizeInBits();
+}
+
+uint64_t DICompositeTypeAttr::getAlignInBits() const {
+  return getImpl()->getAlignInBits();
+}
+
+// The returned elements reflect the storage's current state; they can change
+// if the storage is mutated (replaceElements) after this call.
+::llvm::ArrayRef<DINodeAttr> DICompositeTypeAttr::getElements() const {
+  return getImpl()->getElements();
+}
+
+// Empty for non-identified attributes.
+StringRef DICompositeTypeAttr::getIdentifier() const {
+  return getImpl()->getIdentifier();
+}
+
+Attribute DICompositeTypeAttr::parse(AsmParser &parser, Type type) {
+  const Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
+  FailureOr<AsmParser::CyclicParseReset> cyclicParse;
+  FailureOr<unsigned> tag;
+  FailureOr<StringAttr> name;
+  FailureOr<DIFileAttr> file;
+  FailureOr<uint32_t> line;
+  FailureOr<DIScopeAttr> scope;
+  FailureOr<DITypeAttr> baseType;
+  FailureOr<DIFlags> flags;
+  FailureOr<uint64_t> sizeInBits;
+  FailureOr<uint64_t> alignInBits;
+  FailureOr<SmallVector<DINodeAttr>> elements;
+
+  std::string identifier;
+  bool skipFirstKeyword = false;
+
+  // Begin parsing.
+  if (parser.parseLess()) {
+    parser.emitError(parser.getCurrentLocation(), "expected `<`.");
+    return {};
+  }
+
+  auto paramParser = [&](const bool expectElements,
+                         bool &hasElements) -> LogicalResult {
+    StringRef paramKey;
+
+    /// The first keyword needs to be skipped because it will already have
+    /// been parsed while attempting to parse the identifier.
+    if (skipFirstKeyword) {
+      // The parameter key is the identifier string.
+      paramKey = identifier;
+      skipFirstKeyword = false;
+    } else {
+      if (parser.parseKeyword(&paramKey)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                       "expected parameter name.");
+      }
+    }
+
+    if (parser.parseEqual()) {
+      return parser.emitError(parser.getCurrentLocation(),
+                       "expected `=` following parameter name.");
+    }
+
+    if (paramKey == "tag") {
+      tag = [&]() -> FailureOr<unsigned> {
+        StringRef nameKeyword;
+        if (parser.parseKeyword(&nameKeyword))
+          return failure();
+        if (const unsigned value = llvm::dwarf::getTag(nameKeyword))
+          return value;
+        return parser.emitError(parser.getCurrentLocation())
+               << "invalid debug info debug info tag name: " << nameKeyword;
+      }();
+    } else if (failed(name) && paramKey == "name") {
+      name = FieldParser<StringAttr>::parse(parser);
+      if (failed(name)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                         "failed to parse parameter 'name'");
+      }
+    } else if (failed(file) && paramKey == "file") {
+      file = FieldParser<DIFileAttr>::parse(parser);
+      if (failed(file)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                         "failed to parse parameter 'file'");
+      }
+    } else if (failed(line) && paramKey == "line") {
+      line = FieldParser<uint32_t>::parse(parser);
+      if (failed(line)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                         "failed to parse parameter 'line'");
+      }
+    } else if (failed(scope) && paramKey == "scope") {
+      scope = FieldParser<DIScopeAttr>::parse(parser);
+      if (failed(scope)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                         "failed to parse parameter 'scope'");
+      }
+    } else if (failed(baseType) && paramKey == "baseType") {
+      baseType = FieldParser<DITypeAttr>::parse(parser);
+      if (failed(baseType)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                         "failed to parse parameter 'baseType'");
+      }
+    } else if (failed(flags) && paramKey == "flags") {
+      flags = FieldParser<DIFlags>::parse(parser);
+      if (failed(flags)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                         "failed to parse parameter 'flags'");
+      }
+    } else if (failed(sizeInBits) && paramKey == "sizeInBits") {
+      sizeInBits = FieldParser<uint32_t>::parse(parser);
+      if (failed(sizeInBits)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                         "failed to parse parameter 'sizeInBits'");
+      }
+    } else if (failed(alignInBits) && paramKey == "alignInBits") {
+      alignInBits = FieldParser<uint32_t>::parse(parser);
+      if (failed(alignInBits)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                         "failed to parse parameter 'alignInBits'");
+      }
+    } else if (failed(elements) && paramKey == "elements") {
+      if (expectElements) {
+        if (parser.parseLSquare()) {
+          return parser.emitError(parser.getCurrentLocation(), "expected `[`.");
+        }
+
+        elements = FieldParser<SmallVector<DINodeAttr>>::parse(parser);
+        if (failed(elements)) {
+          return parser.emitError(parser.getCurrentLocation(),
+                           "failed to parse parameter 'elements'");
+        }
+
+        if (parser.parseRSquare()) {
+          return parser.emitError(parser.getCurrentLocation(), "expected `]`.");
+        }
+      } else {
+        // Set hasElements to true to signal that parsing should stop until
+        // recursive parsing has been set up.
+        hasElements = true;
+        return success();
+      }
+    } else {
+      return parser.emitError(parser.getCurrentLocation(), "unknown parameter '")
+          << paramKey << "'";
+    }
+    return success();
+  };
+
+  // This attribute is identified if a keyword followed by a comma or greater
+  // than is parsed.
+  if (succeeded(parser.parseOptionalKeywordOrString(&identifier))) {
+    bool hasElements = false;
+    skipFirstKeyword = true;
+    if (succeeded(parser.parseOptionalGreater()))
+      return getIdentified(parser.getContext(), identifier);
+
+    if (succeeded(parser.parseOptionalComma())) {
+      skipFirstKeyword = false;
+      // auto result = getChecked(
+      //   [&] { return emitError(loc); }, loc.getContext(),
+      //   StringRef(identifier));
+
+      // Parse immutable parameters.
+      do {
+        if (failed(paramParser(false, hasElements))) {
+          return {};
+        }
+
+        if (hasElements) {
+          // Stop parsing if "elements" was encountered. "elements" may be
+          // recursive in this context.
+          break;
+        }
+      } while (succeeded(parser.parseOptionalComma()));
+
+      if (succeeded(parser.parseOptionalGreater())) {
+        DICompositeTypeAttr result =
+            get(parser.getContext(), identifier, tag.value_or(0),
+                name.value_or(StringAttr()), file.value_or(DIFileAttr()),
+                line.value_or(0), scope.value_or(DIScopeAttr()),
+                baseType.value_or(DITypeAttr()), flags.value_or(DIFlags::Zero),
+                sizeInBits.value_or(0), alignInBits.value_or(0));
+
+        if (cyclicParse = parser.tryStartCyclicParse(result);
+            failed(cyclicParse)) {
+          parser.emitError(parser.getCurrentLocation(),
+                           "only identifier allowed in recursive composite "
+                           "type attribute.");
+          return {};
+        }
+        return result;
+      }
+    } else {
+      // Attempt to parse everything at once.
+      do {
+        if (failed(paramParser(true, hasElements))) {
+          return {};
+        }
+      } while (succeeded(parser.parseOptionalComma()));
+
+      // Expect the attribute to terminate.
+      if (parser.parseGreater()) {
+        parser.emitError(parser.getCurrentLocation(), "expected `>`.");
+        return {};
+  ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/74948


More information about the Mlir-commits mailing list