[Mlir-commits] [mlir] [mlir][ROCDL] Add metadata attributes for storing ELF object information (PR #95292)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 12 12:31:36 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Fabian Mora (fabianmcg)

<details>
<summary>Changes</summary>

This patch adds the `#rocdl.kernel` and `#rocdl.object_metadata` attributes. The `#rocdl.kernel` attribute allows storing metadata related to a compiled kernel, for example, the number of scalar registers used by the kernel. It also stores the attribute dictionary of the LLVM function used to generate the kernel.

The `#rocdl.object_metadata` stores a table of `#rocdl.kernel`, mapping the name of the kernel to the metadata.

Finally, the function `ROCDL::getAMDHSAKernelsMetadata` was added to collect ELF metadata from a binary. The binary is expected to be compliant with:
https://llvm.org/docs/AMDGPUUsage.html#code-object-v5-metadata

Example:
```mlir
gpu.binary @<!-- -->binary [#gpu.object<#rocdl.target<chip = "gfx900">, properties = {
    "rocdl.object_metadata" = #rocdl.object_metadata<{
      kernel0 = #rocdl.kernel<{sym_name = "kernel0", ...}, metadata = {sgpr_count = 255, ...}>,
    }>
  }, bin = "...">]
```
The motivation behind these attributes is that they provide useful information for things like tuning.


---

Patch is 20.87 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95292.diff


8 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h (+8) 
- (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+160) 
- (modified) mlir/include/mlir/Target/LLVM/ROCDL/Utils.h (+6) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp (+32) 
- (modified) mlir/lib/Target/LLVM/CMakeLists.txt (+2) 
- (added) mlir/lib/Target/LLVM/ROCDL/Utils.cpp (+225) 
- (modified) mlir/test/Dialect/LLVMIR/rocdl.mlir (+7) 
- (modified) mlir/unittests/Target/LLVM/SerializeROCDLTarget.cpp (+43) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h
index c2a82ffc1c43c..cf043b2be65bf 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h
@@ -37,4 +37,12 @@
 
 #include "mlir/Dialect/LLVMIR/ROCDLOpsDialect.h.inc"
 
+namespace mlir {
+namespace ROCDL {
+/// Returns the key used for storing the ROCDL metadata dictionary in the
+/// property field dictionary in `#gpu.object`.
+StringRef getROCDLObjectMetadataName();
+} // namespace ROCDL
+} // namespace mlir
+
 #endif /* MLIR_DIALECT_LLVMIR_ROCDLDIALECT_H_ */
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 868208ff74a52..656df3389f4cb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -722,4 +722,164 @@ def ROCDL_TargettAttr :
     }
   }];
 }
+
+//===----------------------------------------------------------------------===//
+// ROCDL kernel attribute
+//===----------------------------------------------------------------------===//
+
+def ROCDL_KernelAttr :
+    ROCDL_Attr<"ROCDLKernel", "kernel"> {
+  let description = [{
+    ROCDL attribute for storing metadata related to a compiled kernel. It
+    contains the attribute dictionary of the LLVM function used to generate the
+    kernel, as well as an optional dictionary for additional metadata, like ELF
+    related metadata.
+    For details on the ELF metadata see:
+    https://llvm.org/docs/AMDGPUUsage.html#code-object-v5-metadata
+
+    Examples:
+    ```mlir
+      #rocdl.kernel<{sym_name = "test_fusion__part_0", ...},
+                    metadata = {sgpr_count = 255, ...}>
+    ```
+  }];
+  let parameters = (ins
+    "DictionaryAttr":$func_attrs,
+    OptionalParameter<"DictionaryAttr", "metadata dictionary">:$metadata
+  );
+  let assemblyFormat = [{
+    `<` $func_attrs (`,` `metadata` `=` $metadata^ )? `>`
+  }];
+  let builders = [
+    AttrBuilderWithInferredContext<(ins "DictionaryAttr":$funcAttrs,
+                                         CArg<"DictionaryAttr",
+                                              "nullptr">:$metadata), [{
+      assert(funcAttrs && "invalid function attributes dictionary");
+      return $_get(funcAttrs.getContext(), funcAttrs, metadata);
+    }]>
+  ];
+  let extraClassDeclaration = [{
+    /// Returns the function attribute corresponding to key or nullptr if missing.
+    Attribute getAttr(StringRef key) const {
+      return getFuncAttrs().get(key);
+    }
+    template <typename ConcreteAttr>
+    ConcreteAttr getAttr(StringRef key) const {
+      return llvm::dyn_cast_or_null<ConcreteAttr>(getAttr(key));
+    }
+    Attribute getAttr(StringAttr key) const;
+    template <typename ConcreteAttr>
+    ConcreteAttr getAttr(StringAttr key) const {
+      return llvm::dyn_cast_or_null<ConcreteAttr>(getAttr(key));
+    }
+
+    /// Returns the name of the kernel.
+    StringAttr getName() const {
+      return getAttr<StringAttr>("sym_name");
+    }
+
+    /// Returns the metadata attribute corresponding to key or nullptr if missing.
+    Attribute getMDAttr(StringRef key) const {
+      if (DictionaryAttr attrs = getMetadata())
+        return attrs.get(key);
+      return nullptr;
+    }
+    template <typename ConcreteAttr>
+    ConcreteAttr getMDAttr(StringRef key) const {
+      return llvm::dyn_cast_or_null<ConcreteAttr>(getMDAttr(key));
+    }
+    Attribute getMDAttr(StringAttr key) const;
+    template <typename ConcreteAttr>
+    ConcreteAttr getMDAttr(StringAttr key) const {
+      return llvm::dyn_cast_or_null<ConcreteAttr>(getMDAttr(key));
+    }
+
+    /// Returns the number of required scalar registers, or nullptr if the field
+    /// is missing.
+    IntegerAttr getSGPR() const {
+      return getMDAttr<IntegerAttr>("sgpr_count");
+    }
+
+    /// Returns the number of required vector registers, or nullptr if the field
+    /// is missing.
+    IntegerAttr getVGPR() const {
+      return getMDAttr<IntegerAttr>("vgpr_count");
+    }
+
+    /// Returns the number of required accumulation registers, or nullptr if the
+    /// field is missing.
+    IntegerAttr getAGPR() const {
+      return getMDAttr<IntegerAttr>("agpr_count");
+    }
+
+    /// Returns the number of spilled SGPR, or nullptr if the field is missing.
+    IntegerAttr getSGPRSpill() const {
+      return getMDAttr<IntegerAttr>("sgpr_spill_count");
+    }
+
+    /// Returns the number of spilled VGPR, or nullptr if the field is missing.
+    IntegerAttr getVGPRSpill() const {
+      return getMDAttr<IntegerAttr>("vgpr_spill_count");
+    }
+
+    /// Helper function for appending metadata to a kernel attribute.
+    ROCDLKernelAttr appendMetadata(ArrayRef<NamedAttribute> attrs) const;
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// ROCDL object metadata
+//===----------------------------------------------------------------------===//
+
+def ROCDL_ObjectMDAttr :
+    ROCDL_Attr<"ROCDLObjectMD", "object_metadata"> {
+  let description = [{
+    ROCDL attribute representing a table of kernels metadata. All the attributes
+    in the dictionary must be of type `#rocdl.kernel`.
+
+    Examples:
+    ```mlir
+      #rocdl.object_metadata<{kernel0 = #rocdl.kernel<...>}>
+    ```
+  }];
+  let parameters = (ins
+    "DictionaryAttr":$kernel_table
+  );
+  let assemblyFormat = [{
+    `<` $kernel_table `>`
+  }];
+  let builders = [
+    AttrBuilderWithInferredContext<(ins "DictionaryAttr":$kernel_table), [{
+      assert(kernel_table && "invalid kernel table");
+      return $_get(kernel_table.getContext(), kernel_table);
+    }]>
+  ];
+  let skipDefaultBuilders = 1;
+  let genVerifyDecl = 1;
+  let extraClassDeclaration = [{
+    /// Helper iterator class for traversing the kernel table.
+    struct KernelIterator
+        : llvm::mapped_iterator_base<KernelIterator,
+                                    llvm::ArrayRef<NamedAttribute>::iterator,
+                                    std::pair<StringAttr, ROCDLKernelAttr>> {
+      using llvm::mapped_iterator_base<
+          KernelIterator, llvm::ArrayRef<NamedAttribute>::iterator,
+          std::pair<StringAttr, ROCDLKernelAttr>>::mapped_iterator_base;
+      /// Map the iterator to the kernel name and a KernelAttribute.
+      std::pair<StringAttr, ROCDLKernelAttr> mapElement(NamedAttribute attr) const {
+        return {attr.getName(), llvm::cast<ROCDLKernelAttr>(attr.getValue())};
+      }
+    };
+    auto begin() const {
+      return KernelIterator(getKernelTable().begin());
+    }
+    auto end() const {
+      return KernelIterator(getKernelTable().end());
+    }
+    size_t size() const {
+      return getKernelTable().size();
+    }
+  }];
+}
+
 #endif // ROCDLIR_OPS
diff --git a/mlir/include/mlir/Target/LLVM/ROCDL/Utils.h b/mlir/include/mlir/Target/LLVM/ROCDL/Utils.h
index 374fa65bd02e3..733d3919a2276 100644
--- a/mlir/include/mlir/Target/LLVM/ROCDL/Utils.h
+++ b/mlir/include/mlir/Target/LLVM/ROCDL/Utils.h
@@ -85,6 +85,12 @@ class SerializeGPUModuleBase : public LLVM::ModuleToObject {
   /// List of LLVM bitcode files to link to.
   SmallVector<std::string> fileList;
 };
+
+/// Returns a dictionary containing kernel metadata for each of the kernels in
+/// `gpuModule`. If `elfData` is valid, then the `amdhsa.kernels` ELF metadata
+/// will be added to the dictionary.
+ROCDLObjectMDAttr getAMDHSAKernelsMetadata(Operation *gpuModule,
+                                           ArrayRef<char> elfData = {});
 } // namespace ROCDL
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index 0c9c61fad1363..fde5bd1b26cf6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -217,6 +217,7 @@ LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op,
 //===----------------------------------------------------------------------===//
 // ROCDL target attribute.
 //===----------------------------------------------------------------------===//
+
 LogicalResult
 ROCDLTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                         int optLevel, StringRef triple, StringRef chip,
@@ -247,6 +248,37 @@ ROCDLTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ROCDL object metadata
+//===----------------------------------------------------------------------===//
+
+StringRef mlir::ROCDL::getROCDLObjectMetadataName() {
+  return "rocdl.object_metadata";
+}
+
+ROCDLKernelAttr
+ROCDLKernelAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
+  if (attrs.empty())
+    return *this;
+  NamedAttrList attrList(attrs);
+  attrList.append(getMetadata());
+  return ROCDLKernelAttr::get(getFuncAttrs(),
+                              attrList.getDictionary(getContext()));
+}
+
+LogicalResult
+ROCDLObjectMDAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                          DictionaryAttr dict) {
+  if (!dict)
+    return emitError() << "table cannot be null";
+  if (llvm::any_of(dict, [](NamedAttribute attr) {
+        return !llvm::isa<ROCDLKernelAttr>(attr.getValue());
+      }))
+    return emitError()
+           << "all the dictionary values must be `#rocdl.kernel` attributes";
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/LLVMIR/ROCDLOps.cpp.inc"
 
diff --git a/mlir/lib/Target/LLVM/CMakeLists.txt b/mlir/lib/Target/LLVM/CMakeLists.txt
index 5a3fa160850b4..08da4966e1499 100644
--- a/mlir/lib/Target/LLVM/CMakeLists.txt
+++ b/mlir/lib/Target/LLVM/CMakeLists.txt
@@ -108,9 +108,11 @@ endif()
 
 add_mlir_dialect_library(MLIRROCDLTarget
   ROCDL/Target.cpp
+  ROCDL/Utils.cpp
 
   LINK_COMPONENTS
   MCParser
+  Object
   ${AMDGPU_LIBS}
 
   LINK_LIBS PUBLIC
diff --git a/mlir/lib/Target/LLVM/ROCDL/Utils.cpp b/mlir/lib/Target/LLVM/ROCDL/Utils.cpp
new file mode 100644
index 0000000000000..fa3b49d94e414
--- /dev/null
+++ b/mlir/lib/Target/LLVM/ROCDL/Utils.cpp
@@ -0,0 +1,225 @@
+//===- Utils.cpp - MLIR ROCDL target utils ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines ROCDL target related utility classes and functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVM/ROCDL/Utils.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+
+#include "llvm/ADT/StringMap.h"
+#include "llvm/BinaryFormat/MsgPackDocument.h"
+#include "llvm/Object/ELFObjectFile.h"
+#include "llvm/Object/ObjectFile.h"
+#include "llvm/Support/AMDGPUMetadata.h"
+
+using namespace mlir;
+using namespace mlir::ROCDL;
+
+/// Search the ELF object and return an object containing the `amdhsa.kernels`
+/// metadata note. Function adapted from:
+/// llvm-project/llvm/tools/llvm-readobj/ELFDumper.cpp Also see
+/// `amdhsa.kernels`:
+/// https://llvm.org/docs/AMDGPUUsage.html#code-object-v3-metadata
+template <typename ELFT>
+static std::optional<llvm::msgpack::Document>
+getAMDHSANote(llvm::object::ELFObjectFile<ELFT> &elfObj) {
+  using namespace llvm;
+  using namespace llvm::object;
+  using namespace llvm::ELF;
+  const ELFFile<ELFT> &elf = elfObj.getELFFile();
+  auto secOrErr = elf.sections();
+  if (!secOrErr)
+    return std::nullopt;
+  ArrayRef<typename ELFT::Shdr> sections = *secOrErr;
+  for (auto section : sections) {
+    if (section.sh_type != ELF::SHT_NOTE)
+      continue;
+    size_t align = std::max(static_cast<unsigned>(section.sh_addralign), 4u);
+    Error err = Error::success();
+    for (const typename ELFT::Note note : elf.notes(section, err)) {
+      StringRef name = note.getName();
+      if (name != "AMDGPU")
+        continue;
+      uint32_t type = note.getType();
+      if (type != ELF::NT_AMDGPU_METADATA)
+        continue;
+      ArrayRef<uint8_t> desc = note.getDesc(align);
+      StringRef msgPackString =
+          StringRef(reinterpret_cast<const char *>(desc.data()), desc.size());
+      msgpack::Document msgPackDoc;
+      if (!msgPackDoc.readFromBlob(msgPackString, /*Multi=*/false))
+        return std::nullopt;
+      if (msgPackDoc.getRoot().isScalar())
+        return std::nullopt;
+      return std::optional<llvm::msgpack::Document>(std::move(msgPackDoc));
+    }
+  }
+  return std::nullopt;
+}
+
+/// Return the `amdhsa.kernels` metadata in the ELF object or std::nullopt on
+/// failure. This is a helper function that casts a generic `ObjectFile` to the
+/// appropriate `ELFObjectFile`.
+static std::optional<llvm::msgpack::Document>
+getAMDHSANote(ArrayRef<char> elfData) {
+  using namespace llvm;
+  using namespace llvm::object;
+  if (elfData.empty())
+    return std::nullopt;
+  MemoryBufferRef buffer(StringRef(elfData.data(), elfData.size()), "buffer");
+  Expected<std::unique_ptr<ObjectFile>> objOrErr =
+      ObjectFile::createELFObjectFile(buffer);
+  if (!objOrErr || !objOrErr.get()) {
+    // Drop the error.
+    llvm::consumeError(objOrErr.takeError());
+    return std::nullopt;
+  }
+  ObjectFile &elf = *(objOrErr.get());
+  std::optional<llvm::msgpack::Document> metadata;
+  if (auto *obj = dyn_cast<ELF32LEObjectFile>(&elf))
+    metadata = getAMDHSANote(*obj);
+  else if (auto *obj = dyn_cast<ELF32BEObjectFile>(&elf))
+    metadata = getAMDHSANote(*obj);
+  else if (auto *obj = dyn_cast<ELF64LEObjectFile>(&elf))
+    metadata = getAMDHSANote(*obj);
+  else if (auto *obj = dyn_cast<ELF64BEObjectFile>(&elf))
+    metadata = getAMDHSANote(*obj);
+  return metadata;
+}
+
+/// Utility functions for converting `llvm::msgpack::DocNode` nodes.
+static Attribute convertNode(Builder &builder, llvm::msgpack::DocNode &node);
+static Attribute convertNode(Builder &builder,
+                             llvm::msgpack::MapDocNode &node) {
+  NamedAttrList attrs;
+  for (auto kv : node) {
+    if (!kv.first.isString())
+      continue;
+    if (Attribute attr = convertNode(builder, kv.second)) {
+      auto key = kv.first.getString();
+      key.consume_front(".");
+      key.consume_back(".");
+      attrs.append(key, attr);
+    }
+  }
+  if (attrs.empty())
+    return nullptr;
+  return builder.getDictionaryAttr(attrs);
+}
+
+static Attribute convertNode(Builder &builder,
+                             llvm::msgpack::ArrayDocNode &node) {
+  using NodeKind = llvm::msgpack::Type;
+  // Use `DenseI64ArrayAttr` if we know all the attrs are ints.
+  if (llvm::all_of(node, [](llvm::msgpack::DocNode &n) {
+        auto kind = n.getKind();
+        return kind == NodeKind::Int || kind == NodeKind::UInt;
+      })) {
+    SmallVector<int64_t> values;
+    for (llvm::msgpack::DocNode &n : node) {
+      auto kind = n.getKind();
+      if (kind == NodeKind::Int)
+        values.push_back(n.getInt());
+      else if (kind == NodeKind::UInt)
+        values.push_back(n.getUInt());
+    }
+    return builder.getDenseI64ArrayAttr(values);
+  }
+  // Convert the array.
+  SmallVector<Attribute> attrs;
+  for (llvm::msgpack::DocNode &n : node) {
+    if (Attribute attr = convertNode(builder, n))
+      attrs.push_back(attr);
+  }
+  if (attrs.empty())
+    return nullptr;
+  return builder.getArrayAttr(attrs);
+}
+
+static Attribute convertNode(Builder &builder, llvm::msgpack::DocNode &node) {
+  using namespace llvm::msgpack;
+  using NodeKind = llvm::msgpack::Type;
+  switch (node.getKind()) {
+  case NodeKind::Int:
+    return builder.getI64IntegerAttr(node.getInt());
+  case NodeKind::UInt:
+    return builder.getI64IntegerAttr(node.getUInt());
+  case NodeKind::Boolean:
+    return builder.getI64IntegerAttr(node.getBool());
+  case NodeKind::String:
+    return builder.getStringAttr(node.getString());
+  case NodeKind::Array:
+    return convertNode(builder, node.getArray());
+  case NodeKind::Map:
+    return convertNode(builder, node.getMap());
+  default:
+    return nullptr;
+  }
+}
+
+/// The following function should succeed for Code object V3 and above.
+static llvm::StringMap<DictionaryAttr> getELFMetadata(Builder &builder,
+                                                      ArrayRef<char> elfData) {
+  std::optional<llvm::msgpack::Document> metadata = getAMDHSANote(elfData);
+  if (!metadata)
+    return {};
+  llvm::StringMap<DictionaryAttr> kernelMD;
+  llvm::msgpack::DocNode &root = (metadata)->getRoot();
+  // Fail if `root` is not a map -it should be for AMD Obj Ver 3.
+  if (!root.isMap())
+    return kernelMD;
+  auto &kernels = root.getMap()["amdhsa.kernels"];
+  // Fail if `amdhsa.kernels` is not an array.
+  if (!kernels.isArray())
+    return kernelMD;
+  // Convert each of the kernels.
+  for (auto &kernel : kernels.getArray()) {
+    if (!kernel.isMap())
+      continue;
+    auto &kernelMap = kernel.getMap();
+    auto &name = kernelMap[".name"];
+    if (!name.isString())
+      continue;
+    NamedAttrList attrList;
+    // Convert the kernel properties.
+    for (auto kv : kernelMap) {
+      if (!kv.first.isString())
+        continue;
+      StringRef key = kv.first.getString();
+      key.consume_front(".");
+      key.consume_back(".");
+      if (key == "name")
+        continue;
+      if (Attribute attr = convertNode(builder, kv.second))
+        attrList.append(key, attr);
+    }
+    if (!attrList.empty())
+      kernelMD[name.getString()] = builder.getDictionaryAttr(attrList);
+  }
+  return kernelMD;
+}
+
+ROCDLObjectMDAttr
+mlir::ROCDL::getAMDHSAKernelsMetadata(Operation *gpuModule,
+                                      ArrayRef<char> elfData) {
+  auto module = cast<gpu::GPUModuleOp>(gpuModule);
+  Builder builder(module.getContext());
+  NamedAttrList moduleAttrs;
+  llvm::StringMap<DictionaryAttr> mdMap = getELFMetadata(builder, elfData);
+  for (auto funcOp : module.getBody()->getOps<LLVM::LLVMFuncOp>()) {
+    if (!funcOp->getDiscardableAttr("rocdl.kernel"))
+      continue;
+    moduleAttrs.append(funcOp.getName(),
+                       ROCDLKernelAttr::get(funcOp->getAttrDictionary(),
+                                            mdMap.lookup(funcOp.getName())));
+  }
+  return ROCDLObjectMDAttr::get(moduleAttrs.getDictionary(module.getContext()));
+}
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index f5dd5721c45e6..e7c3206b4d6db 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -365,3 +365,10 @@ gpu.module @module_1 [#rocdl.target<O = 1, chip = "gfx900", abi = "500", link =
 
 gpu.module @module_2 [#rocdl.target<chip = "gfx900">, #rocdl.target<chip = "gfx90a">] {
 }
+
+gpu.binary @binary [#gpu.object<#rocdl.target<chip = "gfx900">, properties = {
+    "rocdl.object_metadata" = #rocdl.object_metadata<{
+      kernel0 = #rocdl.kernel<{sym_name = "kernel0"}, metadata = {sgpr_count = 255}>,
+      kernel1 = #rocdl.kernel<{sym_name = "kernel1"}>
+    }>
+  }, bin = "BLOB">]
diff --git a/mlir/unittests/Target/LLVM/SerializeROCDLTarget.cpp b/mlir/unittests/Target/LLVM/SerializeROCDLTarget.cpp
index 33291bc4bcaed..fa9b752daaca4 100644
--- a/mlir/unittests/Target/LLVM/SerializeROCDLTarget.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeROCDLTarget.cpp
@@ -158,3 +158,46 @@ TEST_F(MLIRTargetLLVMROCDL, SKIP_WITHOUT_AMDGPU(SerializeROCDLToBinary)) {
     ASSERT_FALSE(object->empty());
   }
 }
+
+// Test ROCDL ELF metadata extraction.
+TEST_F(MLIRTargetLLVMROCDL, SKIP_WITHOUT_AMDGPU(GetELFMetadata)) {
+  if (!hasROCMTools())
+    GTEST_SKIP() << "ROCm installation not found, skipping test.";
+
+  MLIRContext context(registry);
+
+  OwningOpRef<ModuleOp> module =
+      parseSourceString<ModuleOp>(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+
+  // Create a ROCDL target.
+  ROCDL::ROCDLTargetAttr target = ROCDL::ROCDLTargetAttr::get(&context);
+
+  // Serialize the module.
+  auto serializer = dyn_cast<gpu::TargetAttrInterface>(target);
+  ASSERT_TRUE(!!serializer);
+  gpu::TargetOptions options("", {}, "", gpu::CompilationTarget::Binary);
+  for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/95292


More information about the Mlir-commits mailing list