[Mlir-commits] [mlir] [mlir][gpu] Add metadata attributes for storing kernel metadata in GPU objects (PR #95292)

Fabian Mora llvmlistbot at llvm.org
Tue Jul 30 10:18:38 PDT 2024


https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/95292

>From 26050b5740d0af525e1c04a568e0c1665ca11e05 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Wed, 12 Jun 2024 20:49:09 +0000
Subject: [PATCH 1/8] move to gpu

---
 mlir/include/mlir-c/Dialect/GPU.h             |   3 +-
 .../mlir/Dialect/GPU/IR/CompilationAttrs.td   | 147 +++++++++++-
 mlir/include/mlir/Target/LLVM/ROCDL/Utils.h   |   7 +
 mlir/lib/CAPI/Dialect/GPU.cpp                 |  12 +-
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        |  37 ++-
 mlir/lib/Target/LLVM/CMakeLists.txt           |   2 +
 mlir/lib/Target/LLVM/NVVM/Target.cpp          |   2 +-
 mlir/lib/Target/LLVM/ROCDL/Target.cpp         |   6 +-
 mlir/lib/Target/LLVM/ROCDL/Utils.cpp          | 226 ++++++++++++++++++
 mlir/lib/Target/SPIRV/Target.cpp              |   2 +-
 mlir/test/Dialect/GPU/ops.mlir                |   6 +
 .../Target/LLVM/SerializeROCDLTarget.cpp      |  45 ++++
 12 files changed, 483 insertions(+), 12 deletions(-)
 create mode 100644 mlir/lib/Target/LLVM/ROCDL/Utils.cpp

diff --git a/mlir/include/mlir-c/Dialect/GPU.h b/mlir/include/mlir-c/Dialect/GPU.h
index c42ff61f9592c..0c2a603ed9b89 100644
--- a/mlir/include/mlir-c/Dialect/GPU.h
+++ b/mlir/include/mlir-c/Dialect/GPU.h
@@ -35,7 +35,8 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsAGPUObjectAttr(MlirAttribute attr);
 
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirGPUObjectAttrGet(MlirContext mlirCtx, MlirAttribute target, uint32_t format,
-                     MlirStringRef objectStrRef, MlirAttribute mlirObjectProps);
+                     MlirStringRef objectStrRef, MlirAttribute mlirObjectProps,
+                     MlirAttribute mlirKernelsAttr);
 
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirGPUObjectAttrGetTarget(MlirAttribute mlirObjectAttr);
diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td b/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td
index 6659f4a2c58e8..f4037b55c85b4 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td
@@ -16,6 +16,136 @@
 include "mlir/Dialect/GPU/IR/GPUBase.td"
 include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
 
+//===----------------------------------------------------------------------===//
+// GPU kernel attribute
+//===----------------------------------------------------------------------===//
+
+def GPU_KernelAttr : GPU_Attr<"Kernel", "kernel"> {
+  let description = [{
+    GPU attribute for storing metadata related to a compiled kernel. It
+    contains the attribute dictionary of the LLVM function used to generate the
+    kernel, as well as an optional dictionary for additional metadata, like
+    occupancy information.
+
+    Examples:
+    ```mlir
+      #gpu.kernel<{sym_name = "test_fusion__part_0", ...},
+                   metadata = {reg_count = 255, ...}>
+    ```
+  }];
+  let parameters = (ins
+    "DictionaryAttr":$func_attrs,
+    OptionalParameter<"DictionaryAttr", "metadata dictionary">:$metadata
+  );
+  let assemblyFormat = [{
+    `<` $func_attrs (`,` `metadata` `=` $metadata^ )? `>`
+  }];
+  let builders = [
+    AttrBuilderWithInferredContext<(ins "DictionaryAttr":$funcAttrs,
+                                         CArg<"DictionaryAttr",
+                                              "nullptr">:$metadata), [{
+      assert(funcAttrs && "invalid function attributes dictionary");
+      return $_get(funcAttrs.getContext(), funcAttrs, metadata);
+    }]>,
+    AttrBuilderWithInferredContext<(ins "Operation*":$kernel,
+                                         CArg<"DictionaryAttr",
+                                              "nullptr">:$metadata)>
+  ];
+  let extraClassDeclaration = [{
+    /// Returns the function attribute corresponding to key or nullptr if missing.
+    Attribute getAttr(StringRef key) const {
+      return getFuncAttrs().get(key);
+    }
+    template <typename ConcreteAttr>
+    ConcreteAttr getAttr(StringRef key) const {
+      return llvm::dyn_cast_or_null<ConcreteAttr>(getAttr(key));
+    }
+    Attribute getAttr(StringAttr key) const;
+    template <typename ConcreteAttr>
+    ConcreteAttr getAttr(StringAttr key) const {
+      return llvm::dyn_cast_or_null<ConcreteAttr>(getAttr(key));
+    }
+
+    /// Returns the name of the kernel.
+    StringAttr getName() const {
+      return getAttr<StringAttr>("sym_name");
+    }
+
+    /// Returns the metadata attribute corresponding to key or nullptr if missing.
+    Attribute getMDAttr(StringRef key) const {
+      if (DictionaryAttr attrs = getMetadata())
+        return attrs.get(key);
+      return nullptr;
+    }
+    template <typename ConcreteAttr>
+    ConcreteAttr getMDAttr(StringRef key) const {
+      return llvm::dyn_cast_or_null<ConcreteAttr>(getMDAttr(key));
+    }
+    Attribute getMDAttr(StringAttr key) const;
+    template <typename ConcreteAttr>
+    ConcreteAttr getMDAttr(StringAttr key) const {
+      return llvm::dyn_cast_or_null<ConcreteAttr>(getMDAttr(key));
+    }
+
+    /// Helper function for appending metadata to a kernel attribute.
+    KernelAttr appendMetadata(ArrayRef<NamedAttribute> attrs) const;
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// GPU kernel table attribute
+//===----------------------------------------------------------------------===//
+
+def GPU_KernelTableAttr : GPU_Attr<"KernelTable", "kernel_table"> {
+  let description = [{
+    GPU attribute representing a table of kernel metadata. All the attributes
+    in the dictionary must be of type `#gpu.kernel`.
+
+    Examples:
+    ```mlir
+      #gpu.kernel_table<{kernel0 = #gpu.kernel<...>}>
+    ```
+  }];
+  let parameters = (ins
+    "DictionaryAttr":$kernel_table
+  );
+  let assemblyFormat = [{
+    `<` $kernel_table `>`
+  }];
+  let builders = [
+    AttrBuilderWithInferredContext<(ins "DictionaryAttr":$kernel_table), [{
+      assert(kernel_table && "invalid kernel table");
+      return $_get(kernel_table.getContext(), kernel_table);
+    }]>
+  ];
+  let skipDefaultBuilders = 1;
+  let genVerifyDecl = 1;
+  let extraClassDeclaration = [{
+    /// Helper iterator class for traversing the kernel table.
+    struct KernelIterator
+        : llvm::mapped_iterator_base<KernelIterator,
+                                    llvm::ArrayRef<NamedAttribute>::iterator,
+                                    std::pair<StringAttr, KernelAttr>> {
+      using llvm::mapped_iterator_base<
+          KernelIterator, llvm::ArrayRef<NamedAttribute>::iterator,
+          std::pair<StringAttr, KernelAttr>>::mapped_iterator_base;
+      /// Map the iterator to the kernel name and a KernelAttribute.
+      std::pair<StringAttr, KernelAttr> mapElement(NamedAttribute attr) const {
+        return {attr.getName(), llvm::cast<KernelAttr>(attr.getValue())};
+      }
+    };
+    auto begin() const {
+      return KernelIterator(getKernelTable().begin());
+    }
+    auto end() const {
+      return KernelIterator(getKernelTable().end());
+    }
+    size_t size() const {
+      return getKernelTable().size();
+    }
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // GPU object attribute.
 //===----------------------------------------------------------------------===//
@@ -63,16 +193,29 @@ def GPU_ObjectAttr : GPU_Attr<"Object", "object"> {
       #gpu.object<#nvvm.target, properties = {O = 3 : i32}, assembly = "..."> // An assembly object with additional properties.
       #gpu.object<#rocdl.target, bin = "..."> // A binary object.
       #gpu.object<#nvvm.target, "..."> // A fatbin object.
+      #gpu.object<#nvvm.target, kernels = #gpu.kernel_table<...>, "..."> // An object with a kernel table.
     ```
   }];
   let parameters = (ins
     "Attribute":$target,
     DefaultValuedParameter<"CompilationTarget", "CompilationTarget::Fatbin">:$format,
     "StringAttr":$object,
-    OptionalParameter<"DictionaryAttr">:$properties
+    OptionalParameter<"DictionaryAttr">:$properties,
+    OptionalParameter<"KernelTableAttr">:$kernels
   );
+  let builders = [
+    AttrBuilderWithInferredContext<(ins "Attribute":$target,
+                                        "CompilationTarget":$format,
+                                        "StringAttr":$object,
+                                        CArg<"DictionaryAttr", "nullptr">:$properties,
+                                        CArg<"KernelTableAttr", "nullptr">:$kernels), [{
+      assert(target && "invalid target");
+      return $_get(target.getContext(), target, format, object, properties, kernels);
+    }]>
+  ];
   let assemblyFormat = [{ `<`
-      $target `,`  (`properties` `=` $properties ^ `,`)?
+      $target `,`  (`properties` `=` $properties^ `,`)?
+      (`kernels` `=` $kernels^ `,`)?
       custom<Object>($format, $object)
     `>`
   }];
diff --git a/mlir/include/mlir/Target/LLVM/ROCDL/Utils.h b/mlir/include/mlir/Target/LLVM/ROCDL/Utils.h
index 3c637a01b0e3b..904d60a2a75db 100644
--- a/mlir/include/mlir/Target/LLVM/ROCDL/Utils.h
+++ b/mlir/include/mlir/Target/LLVM/ROCDL/Utils.h
@@ -14,6 +14,7 @@
 #define MLIR_TARGET_LLVM_ROCDL_UTILS_H
 
 #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Target/LLVM/ModuleToObject.h"
@@ -107,6 +108,12 @@ class SerializeGPUModuleBase : public LLVM::ModuleToObject {
   /// AMD GCN libraries to use when linking, the default is using none.
   AMDGCNLibraries deviceLibs = AMDGCNLibraries::None;
 };
+
+/// Returns a `#gpu.kernel_table` containing kernel metadata for each of the
+/// kernels in `gpuModule`. If `elfData` is valid, then the `amdhsa.kernels` ELF
+/// metadata will be added to the `#gpu.kernel_table`.
+gpu::KernelTableAttr getAMDHSAKernelsMetadata(Operation *gpuModule,
+                                              ArrayRef<char> elfData = {});
 } // namespace ROCDL
 } // namespace mlir
 
diff --git a/mlir/lib/CAPI/Dialect/GPU.cpp b/mlir/lib/CAPI/Dialect/GPU.cpp
index 0acebb2300429..ffe60658cb2ce 100644
--- a/mlir/lib/CAPI/Dialect/GPU.cpp
+++ b/mlir/lib/CAPI/Dialect/GPU.cpp
@@ -37,15 +37,19 @@ bool mlirAttributeIsAGPUObjectAttr(MlirAttribute attr) {
 
 MlirAttribute mlirGPUObjectAttrGet(MlirContext mlirCtx, MlirAttribute target,
                                    uint32_t format, MlirStringRef objectStrRef,
-                                   MlirAttribute mlirObjectProps) {
+                                   MlirAttribute mlirObjectProps,
+                                   MlirAttribute mlirKernelsAttr) {
   MLIRContext *ctx = unwrap(mlirCtx);
   llvm::StringRef object = unwrap(objectStrRef);
   DictionaryAttr objectProps;
   if (mlirObjectProps.ptr != nullptr)
     objectProps = llvm::cast<DictionaryAttr>(unwrap(mlirObjectProps));
-  return wrap(gpu::ObjectAttr::get(ctx, unwrap(target),
-                                   static_cast<gpu::CompilationTarget>(format),
-                                   StringAttr::get(ctx, object), objectProps));
+  gpu::KernelTableAttr kernels;
+  if (mlirKernelsAttr.ptr != nullptr)
+    kernels = llvm::cast<gpu::KernelTableAttr>(unwrap(mlirKernelsAttr));
+  return wrap(gpu::ObjectAttr::get(
+      ctx, unwrap(target), static_cast<gpu::CompilationTarget>(format),
+      StringAttr::get(ctx, object), objectProps, kernels));
 }
 
 MlirAttribute mlirGPUObjectAttrGetTarget(MlirAttribute mlirObjectAttr) {
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 7bc2668310ddb..7873a0f89ef94 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2147,7 +2147,8 @@ void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                  Attribute target, CompilationTarget format,
-                                 StringAttr object, DictionaryAttr properties) {
+                                 StringAttr object, DictionaryAttr properties,
+                                 KernelTableAttr kernels) {
   if (!target)
     return emitError() << "the target attribute cannot be null";
   if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
@@ -2233,6 +2234,40 @@ LogicalResult gpu::DynamicSharedMemoryOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// GPU KernelAttr
+//===----------------------------------------------------------------------===//
+
+KernelAttr KernelAttr::get(Operation *kernelOp, DictionaryAttr metadata) {
+  assert(kernelOp && "invalid kernel");
+  return get(kernelOp->getAttrDictionary(), metadata);
+}
+
+KernelAttr KernelAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
+  if (attrs.empty())
+    return *this;
+  NamedAttrList attrList(attrs);
+  attrList.append(getMetadata());
+  return KernelAttr::get(getFuncAttrs(), attrList.getDictionary(getContext()));
+}
+
+//===----------------------------------------------------------------------===//
+// GPU KernelTableAttr
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                        DictionaryAttr dict) {
+  if (!dict)
+    return emitError() << "table cannot be null";
+  if (llvm::any_of(dict, [](NamedAttribute attr) {
+        return !llvm::isa<KernelAttr>(attr.getValue());
+      }))
+    return emitError()
+           << "all the dictionary values must be `#gpu.kernel` attributes";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // GPU target options
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVM/CMakeLists.txt b/mlir/lib/Target/LLVM/CMakeLists.txt
index 93dc5ff9d35b7..4999ce00ba5bc 100644
--- a/mlir/lib/Target/LLVM/CMakeLists.txt
+++ b/mlir/lib/Target/LLVM/CMakeLists.txt
@@ -110,11 +110,13 @@ endif()
 
 add_mlir_dialect_library(MLIRROCDLTarget
   ROCDL/Target.cpp
+  ROCDL/Utils.cpp
 
   OBJECT
 
   LINK_COMPONENTS
   MCParser
+  Object
   ${AMDGPU_LIBS}
 
   LINK_LIBS PUBLIC
diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp
index e608d26e8d2ec..a5099dc033765 100644
--- a/mlir/lib/Target/LLVM/NVVM/Target.cpp
+++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp
@@ -604,5 +604,5 @@ NVVMTargetAttrImpl::createObject(Attribute attribute,
   return builder.getAttr<gpu::ObjectAttr>(
       attribute, format,
       builder.getStringAttr(StringRef(object.data(), object.size())),
-      objectProps);
+      objectProps, nullptr);
 }
diff --git a/mlir/lib/Target/LLVM/ROCDL/Target.cpp b/mlir/lib/Target/LLVM/ROCDL/Target.cpp
index 4d23f987eb05e..42053e3ec18d3 100644
--- a/mlir/lib/Target/LLVM/ROCDL/Target.cpp
+++ b/mlir/lib/Target/LLVM/ROCDL/Target.cpp
@@ -512,7 +512,9 @@ ROCDLTargetAttrImpl::createObject(Attribute attribute,
   DictionaryAttr properties{};
   Builder builder(attribute.getContext());
   return builder.getAttr<gpu::ObjectAttr>(
-      attribute, format,
+      attribute,
+      format > gpu::CompilationTarget::Binary ? gpu::CompilationTarget::Binary
+                                              : format,
       builder.getStringAttr(StringRef(object.data(), object.size())),
-      properties);
+      properties, nullptr);
 }
diff --git a/mlir/lib/Target/LLVM/ROCDL/Utils.cpp b/mlir/lib/Target/LLVM/ROCDL/Utils.cpp
new file mode 100644
index 0000000000000..5029293bae2a6
--- /dev/null
+++ b/mlir/lib/Target/LLVM/ROCDL/Utils.cpp
@@ -0,0 +1,226 @@
+//===- Utils.cpp - MLIR ROCDL target utils ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines ROCDL target related utility classes and functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVM/ROCDL/Utils.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+
+#include "llvm/ADT/StringMap.h"
+#include "llvm/BinaryFormat/MsgPackDocument.h"
+#include "llvm/Object/ELFObjectFile.h"
+#include "llvm/Object/ObjectFile.h"
+#include "llvm/Support/AMDGPUMetadata.h"
+
+using namespace mlir;
+using namespace mlir::ROCDL;
+
+/// Search the ELF object and return an object containing the `amdhsa.kernels`
+/// metadata note. Function adapted from:
+/// llvm-project/llvm/tools/llvm-readobj/ELFDumper.cpp. Also see
+/// `amdhsa.kernels`:
+/// https://llvm.org/docs/AMDGPUUsage.html#code-object-v3-metadata
+template <typename ELFT>
+static std::optional<llvm::msgpack::Document>
+getAMDHSANote(llvm::object::ELFObjectFile<ELFT> &elfObj) {
+  using namespace llvm;
+  using namespace llvm::object;
+  using namespace llvm::ELF;
+  const ELFFile<ELFT> &elf = elfObj.getELFFile();
+  auto secOrErr = elf.sections();
+  if (!secOrErr)
+    return std::nullopt;
+  ArrayRef<typename ELFT::Shdr> sections = *secOrErr;
+  for (auto section : sections) {
+    if (section.sh_type != ELF::SHT_NOTE)
+      continue;
+    size_t align = std::max(static_cast<unsigned>(section.sh_addralign), 4u);
+    Error err = Error::success();
+    for (const typename ELFT::Note note : elf.notes(section, err)) {
+      StringRef name = note.getName();
+      if (name != "AMDGPU")
+        continue;
+      uint32_t type = note.getType();
+      if (type != ELF::NT_AMDGPU_METADATA)
+        continue;
+      ArrayRef<uint8_t> desc = note.getDesc(align);
+      StringRef msgPackString =
+          StringRef(reinterpret_cast<const char *>(desc.data()), desc.size());
+      msgpack::Document msgPackDoc;
+      if (!msgPackDoc.readFromBlob(msgPackString, /*Multi=*/false))
+        return std::nullopt;
+      if (msgPackDoc.getRoot().isScalar())
+        return std::nullopt;
+      return std::optional<llvm::msgpack::Document>(std::move(msgPackDoc));
+    }
+  }
+  return std::nullopt;
+}
+
+/// Return the `amdhsa.kernels` metadata in the ELF object or std::nullopt on
+/// failure. This is a helper function that casts a generic `ObjectFile` to the
+/// appropriate `ELFObjectFile`.
+static std::optional<llvm::msgpack::Document>
+getAMDHSANote(ArrayRef<char> elfData) {
+  using namespace llvm;
+  using namespace llvm::object;
+  if (elfData.empty())
+    return std::nullopt;
+  MemoryBufferRef buffer(StringRef(elfData.data(), elfData.size()), "buffer");
+  Expected<std::unique_ptr<ObjectFile>> objOrErr =
+      ObjectFile::createELFObjectFile(buffer);
+  if (!objOrErr || !objOrErr.get()) {
+    // Drop the error.
+    llvm::consumeError(objOrErr.takeError());
+    return std::nullopt;
+  }
+  ObjectFile &elf = *(objOrErr.get());
+  std::optional<llvm::msgpack::Document> metadata;
+  if (auto *obj = dyn_cast<ELF32LEObjectFile>(&elf))
+    metadata = getAMDHSANote(*obj);
+  else if (auto *obj = dyn_cast<ELF32BEObjectFile>(&elf))
+    metadata = getAMDHSANote(*obj);
+  else if (auto *obj = dyn_cast<ELF64LEObjectFile>(&elf))
+    metadata = getAMDHSANote(*obj);
+  else if (auto *obj = dyn_cast<ELF64BEObjectFile>(&elf))
+    metadata = getAMDHSANote(*obj);
+  return metadata;
+}
+
+/// Utility functions for converting `llvm::msgpack::DocNode` nodes.
+static Attribute convertNode(Builder &builder, llvm::msgpack::DocNode &node);
+static Attribute convertNode(Builder &builder,
+                             llvm::msgpack::MapDocNode &node) {
+  NamedAttrList attrs;
+  for (auto kv : node) {
+    if (!kv.first.isString())
+      continue;
+    if (Attribute attr = convertNode(builder, kv.second)) {
+      auto key = kv.first.getString();
+      key.consume_front(".");
+      key.consume_back(".");
+      attrs.append(key, attr);
+    }
+  }
+  if (attrs.empty())
+    return nullptr;
+  return builder.getDictionaryAttr(attrs);
+}
+
+static Attribute convertNode(Builder &builder,
+                             llvm::msgpack::ArrayDocNode &node) {
+  using NodeKind = llvm::msgpack::Type;
+  // Use `DenseI64ArrayAttr` if we know all the attrs are ints.
+  if (llvm::all_of(node, [](llvm::msgpack::DocNode &n) {
+        auto kind = n.getKind();
+        return kind == NodeKind::Int || kind == NodeKind::UInt;
+      })) {
+    SmallVector<int64_t> values;
+    for (llvm::msgpack::DocNode &n : node) {
+      auto kind = n.getKind();
+      if (kind == NodeKind::Int)
+        values.push_back(n.getInt());
+      else if (kind == NodeKind::UInt)
+        values.push_back(n.getUInt());
+    }
+    return builder.getDenseI64ArrayAttr(values);
+  }
+  // Convert the array.
+  SmallVector<Attribute> attrs;
+  for (llvm::msgpack::DocNode &n : node) {
+    if (Attribute attr = convertNode(builder, n))
+      attrs.push_back(attr);
+  }
+  if (attrs.empty())
+    return nullptr;
+  return builder.getArrayAttr(attrs);
+}
+
+static Attribute convertNode(Builder &builder, llvm::msgpack::DocNode &node) {
+  using namespace llvm::msgpack;
+  using NodeKind = llvm::msgpack::Type;
+  switch (node.getKind()) {
+  case NodeKind::Int:
+    return builder.getI64IntegerAttr(node.getInt());
+  case NodeKind::UInt:
+    return builder.getI64IntegerAttr(node.getUInt());
+  case NodeKind::Boolean:
+    return builder.getI64IntegerAttr(node.getBool());
+  case NodeKind::String:
+    return builder.getStringAttr(node.getString());
+  case NodeKind::Array:
+    return convertNode(builder, node.getArray());
+  case NodeKind::Map:
+    return convertNode(builder, node.getMap());
+  default:
+    return nullptr;
+  }
+}
+
+/// The following function should succeed for Code object V3 and above.
+static llvm::StringMap<DictionaryAttr> getELFMetadata(Builder &builder,
+                                                      ArrayRef<char> elfData) {
+  std::optional<llvm::msgpack::Document> metadata = getAMDHSANote(elfData);
+  if (!metadata)
+    return {};
+  llvm::StringMap<DictionaryAttr> kernelMD;
+  llvm::msgpack::DocNode &root = (metadata)->getRoot();
+  // Fail if `root` is not a map; it should be one for code object V3.
+  if (!root.isMap())
+    return kernelMD;
+  auto &kernels = root.getMap()["amdhsa.kernels"];
+  // Fail if `amdhsa.kernels` is not an array.
+  if (!kernels.isArray())
+    return kernelMD;
+  // Convert each of the kernels.
+  for (auto &kernel : kernels.getArray()) {
+    if (!kernel.isMap())
+      continue;
+    auto &kernelMap = kernel.getMap();
+    auto &name = kernelMap[".name"];
+    if (!name.isString())
+      continue;
+    NamedAttrList attrList;
+    // Convert the kernel properties.
+    for (auto kv : kernelMap) {
+      if (!kv.first.isString())
+        continue;
+      StringRef key = kv.first.getString();
+      key.consume_front(".");
+      key.consume_back(".");
+      if (key == "name")
+        continue;
+      if (Attribute attr = convertNode(builder, kv.second))
+        attrList.append(key, attr);
+    }
+    if (!attrList.empty())
+      kernelMD[name.getString()] = builder.getDictionaryAttr(attrList);
+  }
+  return kernelMD;
+}
+
+gpu::KernelTableAttr
+mlir::ROCDL::getAMDHSAKernelsMetadata(Operation *gpuModule,
+                                      ArrayRef<char> elfData) {
+  auto module = cast<gpu::GPUModuleOp>(gpuModule);
+  Builder builder(module.getContext());
+  NamedAttrList moduleAttrs;
+  llvm::StringMap<DictionaryAttr> mdMap = getELFMetadata(builder, elfData);
+  for (auto funcOp : module.getBody()->getOps<LLVM::LLVMFuncOp>()) {
+    if (!funcOp->getDiscardableAttr("rocdl.kernel"))
+      continue;
+    moduleAttrs.append(
+        funcOp.getName(),
+        gpu::KernelAttr::get(funcOp, mdMap.lookup(funcOp.getName())));
+  }
+  return gpu::KernelTableAttr::get(
+      moduleAttrs.getDictionary(module.getContext()));
+}
diff --git a/mlir/lib/Target/SPIRV/Target.cpp b/mlir/lib/Target/SPIRV/Target.cpp
index 4c416abe71cac..e7651b5f3c767 100644
--- a/mlir/lib/Target/SPIRV/Target.cpp
+++ b/mlir/lib/Target/SPIRV/Target.cpp
@@ -98,5 +98,5 @@ SPIRVTargetAttrImpl::createObject(Attribute attribute,
   return builder.getAttr<gpu::ObjectAttr>(
       attribute, format,
       builder.getStringAttr(StringRef(object.data(), object.size())),
-      objectProps);
+      objectProps, nullptr);
 }
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index ba7897f4e80cb..692ef2a5e3bef 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -441,3 +441,9 @@ gpu.module @module_with_two_target [#nvvm.target, #rocdl.target<chip = "gfx90a">
 
 gpu.module @module_with_offload_handler <#gpu.select_object<0>> [#nvvm.target] {
 }
+
+
+gpu.binary @binary [#gpu.object<#rocdl.target<chip = "gfx900">, kernels = #gpu.kernel_table<{
+    kernel0 = #gpu.kernel<{sym_name = "kernel0"}, metadata = {sgpr_count = 255}>,
+    kernel1 = #gpu.kernel<{sym_name = "kernel1"}>
+  }> , bin = "BLOB">]
diff --git a/mlir/unittests/Target/LLVM/SerializeROCDLTarget.cpp b/mlir/unittests/Target/LLVM/SerializeROCDLTarget.cpp
index 33291bc4bcaed..3d5c84efb6f4f 100644
--- a/mlir/unittests/Target/LLVM/SerializeROCDLTarget.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeROCDLTarget.cpp
@@ -158,3 +158,48 @@ TEST_F(MLIRTargetLLVMROCDL, SKIP_WITHOUT_AMDGPU(SerializeROCDLToBinary)) {
     ASSERT_FALSE(object->empty());
   }
 }
+
+// Test ROCDL metadata.
+TEST_F(MLIRTargetLLVMROCDL, SKIP_WITHOUT_AMDGPU(GetELFMetadata)) {
+  if (!hasROCMTools())
+    GTEST_SKIP() << "ROCm installation not found, skipping test.";
+
+  MLIRContext context(registry);
+
+  OwningOpRef<ModuleOp> module =
+      parseSourceString<ModuleOp>(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+
+  // Create a ROCDL target.
+  ROCDL::ROCDLTargetAttr target = ROCDL::ROCDLTargetAttr::get(&context);
+
+  // Serialize the module.
+  auto serializer = dyn_cast<gpu::TargetAttrInterface>(target);
+  ASSERT_TRUE(!!serializer);
+  gpu::TargetOptions options("", {}, "", gpu::CompilationTarget::Binary);
+  for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) {
+    std::optional<SmallVector<char, 0>> object =
+        serializer.serializeToObject(gpuModule, options);
+    // Check that the serializer was successful.
+    ASSERT_TRUE(object != std::nullopt);
+    ASSERT_FALSE(object->empty());
+    if (!object)
+      continue;
+    // Get the metadata.
+    gpu::KernelTableAttr metadata =
+        ROCDL::getAMDHSAKernelsMetadata(gpuModule, *object);
+    ASSERT_TRUE(metadata != nullptr);
+    // There should be only a single kernel.
+    ASSERT_TRUE(metadata.size() == 1);
+    // Test the `KernelTableAttr` iterators.
+    for (auto [name, kernel] : metadata) {
+      ASSERT_TRUE(name.getValue() == "rocdl_kernel");
+      // Check that the ELF metadata is present.
+      ASSERT_TRUE(kernel.getMetadata() != nullptr);
+      // Verify that `sgpr_count` is present and it is an integer attribute.
+      ASSERT_TRUE(kernel.getMDAttr<IntegerAttr>("sgpr_count") != nullptr);
+      // Verify that `vgpr_count` is present and it is an integer attribute.
+      ASSERT_TRUE(kernel.getMDAttr<IntegerAttr>("vgpr_count") != nullptr);
+    }
+  }
+}

>From 73b6cfc3d696ff05994b7c650bb13a3371507407 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Wed, 12 Jun 2024 21:11:57 +0000
Subject: [PATCH 2/8] fix python bindings

---
 mlir/include/mlir-c/Dialect/GPU.h       |  6 ++++++
 mlir/lib/Bindings/Python/DialectGPU.cpp | 21 ++++++++++++++++-----
 mlir/lib/CAPI/Dialect/GPU.cpp           | 12 ++++++++++++
 3 files changed, 34 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir-c/Dialect/GPU.h b/mlir/include/mlir-c/Dialect/GPU.h
index 0c2a603ed9b89..d6d8eee639de9 100644
--- a/mlir/include/mlir-c/Dialect/GPU.h
+++ b/mlir/include/mlir-c/Dialect/GPU.h
@@ -53,6 +53,12 @@ mlirGPUObjectAttrHasProperties(MlirAttribute mlirObjectAttr);
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirGPUObjectAttrGetProperties(MlirAttribute mlirObjectAttr);
 
+MLIR_CAPI_EXPORTED bool
+mlirGPUObjectAttrHasKernels(MlirAttribute mlirObjectAttr);
+
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirGPUObjectAttrGetKernels(MlirAttribute mlirObjectAttr);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/lib/Bindings/Python/DialectGPU.cpp b/mlir/lib/Bindings/Python/DialectGPU.cpp
index a9e339b50dabc..d4df5d5b58130 100644
--- a/mlir/lib/Bindings/Python/DialectGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectGPU.cpp
@@ -48,17 +48,21 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
       .def_classmethod(
           "get",
           [](py::object cls, MlirAttribute target, uint32_t format,
-             py::bytes object, std::optional<MlirAttribute> mlirObjectProps) {
+             py::bytes object, std::optional<MlirAttribute> mlirObjectProps,
+             std::optional<MlirAttribute> mlirKernelsAttr) {
             py::buffer_info info(py::buffer(object).request());
             MlirStringRef objectStrRef =
                 mlirStringRefCreate(static_cast<char *>(info.ptr), info.size);
             return cls(mlirGPUObjectAttrGet(
                 mlirAttributeGetContext(target), target, format, objectStrRef,
                 mlirObjectProps.has_value() ? *mlirObjectProps
+                                            : MlirAttribute{nullptr},
+                mlirKernelsAttr.has_value() ? *mlirKernelsAttr
                                             : MlirAttribute{nullptr}));
           },
           "cls"_a, "target"_a, "format"_a, "object"_a,
-          "properties"_a = py::none(), "Gets a gpu.object from parameters.")
+          "properties"_a = py::none(), "kernels"_a = py::none(),
+          "Gets a gpu.object from parameters.")
       .def_property_readonly(
           "target",
           [](MlirAttribute self) { return mlirGPUObjectAttrGetTarget(self); })
@@ -71,9 +75,16 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
             MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
             return py::bytes(stringRef.data, stringRef.length);
           })
-      .def_property_readonly("properties", [](MlirAttribute self) {
-        if (mlirGPUObjectAttrHasProperties(self))
-          return py::cast(mlirGPUObjectAttrGetProperties(self));
+      .def_property_readonly("properties",
+                             [](MlirAttribute self) {
+                               if (mlirGPUObjectAttrHasProperties(self))
+                                 return py::cast(
+                                     mlirGPUObjectAttrGetProperties(self));
+                               return py::none().cast<py::object>();
+                             })
+      .def_property_readonly("kernels", [](MlirAttribute self) {
+        if (mlirGPUObjectAttrHasKernels(self))
+          return py::cast(mlirGPUObjectAttrGetKernels(self));
         return py::none().cast<py::object>();
       });
 }
diff --git a/mlir/lib/CAPI/Dialect/GPU.cpp b/mlir/lib/CAPI/Dialect/GPU.cpp
index ffe60658cb2ce..61b3c23f21bed 100644
--- a/mlir/lib/CAPI/Dialect/GPU.cpp
+++ b/mlir/lib/CAPI/Dialect/GPU.cpp
@@ -82,3 +82,15 @@ MlirAttribute mlirGPUObjectAttrGetProperties(MlirAttribute mlirObjectAttr) {
       llvm::cast<gpu::ObjectAttr>(unwrap(mlirObjectAttr));
   return wrap(objectAttr.getProperties());
 }
+
+bool mlirGPUObjectAttrHasKernels(MlirAttribute mlirObjectAttr) {
+  gpu::ObjectAttr objectAttr =
+      llvm::cast<gpu::ObjectAttr>(unwrap(mlirObjectAttr));
+  return objectAttr.getKernels() != nullptr;
+}
+
+MlirAttribute mlirGPUObjectAttrGetKernels(MlirAttribute mlirObjectAttr) {
+  gpu::ObjectAttr objectAttr =
+      llvm::cast<gpu::ObjectAttr>(unwrap(mlirObjectAttr));
+  return wrap(objectAttr.getKernels());
+}

>From f8ef1b8a78c7e7f403016381ff3b653ab3abedb5 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Thu, 13 Jun 2024 15:03:06 +0000
Subject: [PATCH 3/8] addressed reviewer comments

---
 .../mlir/Dialect/GPU/IR/CompilationAttrs.td   | 96 ++++++++++++-------
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        | 54 ++++++++---
 mlir/test/Dialect/GPU/ops.mlir                |  6 +-
 .../Target/LLVM/SerializeROCDLTarget.cpp      |  4 +-
 4 files changed, 108 insertions(+), 52 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td b/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td
index f4037b55c85b4..0088b729ab8a7 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td
@@ -22,72 +22,83 @@ include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
 
 def GPU_KernelAttr : GPU_Attr<"Kernel", "kernel"> {
   let description = [{
-    GPU attribute for storing metadata related to a compiled kernel. It
-    contains the attribute dictionary of the LLVM function used to generate the
-    kernel, as well as an optional dictionary for additional metadata, like
-    occupancy information.
+    GPU attribute for storing metadata related to a compiled kernel. The
+    attribute contains the name and function type of the kernel.
+
+    The attribute also contains optional parameters for storing the argument
+    attributes, as well as a dictionary for additional metadata such as
+    occupancy information or other function attributes.
+
+    Note: The `arg_attrs` parameter is expected to follow all the constraints
+    imposed by the `mlir::FunctionOpInterface` interface.
 
     Examples:
     ```mlir
-      #gpu.kernel<{sym_name = "test_fusion__part_0", ...},
-                   metadata = {reg_count = 255, ...}>
+      #gpu.kernel<"kernel1", (i32) -> (), arg_attrs = [...], metadata = {reg_count = 255, ...}>
+      #gpu.kernel<"kernel2", (i32, f64) -> ()>
     ```
   }];
   let parameters = (ins
-    "DictionaryAttr":$func_attrs,
+    "StringAttr":$name,
+    "Type":$function_type,
+    OptionalParameter<"ArrayAttr", "arguments attributes">:$arg_attrs,
     OptionalParameter<"DictionaryAttr", "metadata dictionary">:$metadata
   );
   let assemblyFormat = [{
-    `<` $func_attrs (`,` `metadata` `=` $metadata^ )? `>`
+    `<` $name `,` $function_type (`,` struct($arg_attrs, $metadata)^)? `>`
   }];
   let builders = [
-    AttrBuilderWithInferredContext<(ins "DictionaryAttr":$funcAttrs,
-                                         CArg<"DictionaryAttr",
-                                              "nullptr">:$metadata), [{
-      assert(funcAttrs && "invalid function attributes dictionary");
-      return $_get(funcAttrs.getContext(), funcAttrs, metadata);
+    AttrBuilderWithInferredContext<(ins "StringAttr":$name,
+                                        "Type":$functionType,
+                                        CArg<"ArrayAttr", "nullptr">:$argAttrs,
+                                        CArg<"DictionaryAttr",
+                                             "nullptr">:$metadata), [{
+      assert(name && "invalid name");
+      return $_get(name.getContext(), name, functionType, argAttrs, metadata);
     }]>,
-    AttrBuilderWithInferredContext<(ins "Operation*":$kernel,
+    AttrBuilderWithInferredContext<(ins "FunctionOpInterface":$kernel,
                                          CArg<"DictionaryAttr",
                                               "nullptr">:$metadata)>
   ];
+  let genVerifyDecl = 1;
   let extraClassDeclaration = [{
-    /// Returns the function attribute corresponding to key or nullptr if missing.
+    /// Returns the metadata attribute corresponding to `key` or `nullptr`
+    /// if missing.
     Attribute getAttr(StringRef key) const {
-      return getFuncAttrs().get(key);
+      auto attrs = getMetadata();
+      return attrs ? attrs.get(key) : nullptr;
     }
     template <typename ConcreteAttr>
     ConcreteAttr getAttr(StringRef key) const {
       return llvm::dyn_cast_or_null<ConcreteAttr>(getAttr(key));
     }
-    Attribute getAttr(StringAttr key) const;
+    Attribute getAttr(StringAttr key) const {
+      auto attrs = getMetadata();
+      return attrs ? attrs.get(key) : nullptr;
+    }
     template <typename ConcreteAttr>
     ConcreteAttr getAttr(StringAttr key) const {
       return llvm::dyn_cast_or_null<ConcreteAttr>(getAttr(key));
     }
 
-    /// Returns the name of the kernel.
-    StringAttr getName() const {
-      return getAttr<StringAttr>("sym_name");
+    /// Returns the attribute dictionary at position `index`.
+    DictionaryAttr getArgAttrDict(unsigned index) {
+      auto argArray = getArgAttrs();
+      return argArray ? llvm::cast<DictionaryAttr>(argArray[index]) : nullptr;
     }
 
-    /// Returns the metadta attribute corresponding to key or nullptr if missing.
-    Attribute getMDAttr(StringRef key) const {
-      if (DictionaryAttr attrs = getMetadata())
-        return attrs.get(key);
-      return nullptr;
+    /// Return the specified attribute, if present, for the argument at 'index',
+    /// null otherwise.
+    Attribute getArgAttr(unsigned index, StringAttr name) {
+      auto argDict = getArgAttrDict(index);
+      return argDict ? argDict.get(name) : nullptr;
     }
-    template <typename ConcreteAttr>
-    ConcreteAttr getMDAttr(StringRef key) const {
-      return llvm::dyn_cast_or_null<ConcreteAttr>(getMDAttr(key));
-    }
-    Attribute getMDAttr(StringAttr key) const;
-    template <typename ConcreteAttr>
-    ConcreteAttr getMDAttr(StringAttr key) const {
-      return llvm::dyn_cast_or_null<ConcreteAttr>(getMDAttr(key));
+    Attribute getArgAttr(unsigned index, StringRef name) {
+      auto argDict = getArgAttrDict(index);
+      return argDict ? argDict.get(name) : nullptr;
     }
 
-    /// Helper function for appending metadata to a kernel attribute.
+    /// Returns a new KernelAttr that contains `attrs` in the metadata dictionary.
     KernelAttr appendMetadata(ArrayRef<NamedAttribute> attrs) const;
   }];
 }
@@ -143,6 +154,14 @@ def GPU_KernelTableAttr : GPU_Attr<"KernelTable", "kernel_table"> {
     size_t size() const {
       return getKernelTable().size();
     }
+
+    /// Returns the kernel with name `key` or `nullptr` if not present.
+    KernelAttr lookup(StringRef key) const {
+      return getKernelTable().getAs<KernelAttr>(key);
+    }
+    KernelAttr lookup(StringAttr key) const {
+      return getKernelTable().getAs<KernelAttr>(key);
+    }
   }];
 }
 
@@ -166,8 +185,9 @@ def GPU_CompilationTargetEnum : GPU_I32Enum<
 def GPU_ObjectAttr : GPU_Attr<"Object", "object"> {
   let description = [{
     A GPU object attribute glues together a GPU target, the object kind, a
-    binary string with the object, and the object properties, encapsulating how
-    the object was generated and its properties with the object itself.
+    binary string with the object, the object properties, and kernel metadata,
+    encapsulating how the object was generated and its properties with the
+    object itself.
 
     There are four object formats:
     1. `Offload`: represents generic objects not described by the other three
@@ -185,6 +205,10 @@ def GPU_ObjectAttr : GPU_Attr<"Object", "object"> {
 
     Object properties are specified through the `properties` dictionary
     attribute and can be used to define additional information.
+
+    Kernel metadata is specified through the `kernels` parameter, and can be
+    used to specify additional information on a per-kernel basis.
+
     The target attribute must implement or promise the `TargetAttrInterface`
     interface.
 
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 7873a0f89ef94..5396a9c7ac9f5 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2238,17 +2238,45 @@ LogicalResult gpu::DynamicSharedMemoryOp::verify() {
 // GPU KernelAttr
 //===----------------------------------------------------------------------===//
 
-KernelAttr KernelAttr::get(Operation *kernelOp, DictionaryAttr metadata) {
-  assert(kernelOp && "invalid kernel");
-  return get(kernelOp->getAttrDictionary(), metadata);
+KernelAttr KernelAttr::get(FunctionOpInterface kernel,
+                           DictionaryAttr metadata) {
+  assert(kernel && "invalid kernel");
+  return get(kernel.getNameAttr(), kernel.getFunctionType(),
+             kernel.getAllArgAttrs(), metadata);
+}
+
+KernelAttr KernelAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                                  FunctionOpInterface kernel,
+                                  DictionaryAttr metadata) {
+  assert(kernel && "invalid kernel");
+  return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
+                    kernel.getAllArgAttrs(), metadata);
 }
 
 KernelAttr KernelAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
   if (attrs.empty())
     return *this;
-  NamedAttrList attrList(attrs);
-  attrList.append(getMetadata());
-  return KernelAttr::get(getFuncAttrs(), attrList.getDictionary(getContext()));
+  NamedAttrList attrList;
+  if (auto dict = getMetadata())
+    attrList.append(dict);
+  attrList.append(attrs);
+  return KernelAttr::get(getName(), getFunctionType(), getArgAttrs(),
+                         attrList.getDictionary(getContext()));
+}
+
+LogicalResult KernelAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                                 StringAttr name, Type functionType,
+                                 ArrayAttr argAttrs, DictionaryAttr metadata) {
+  if (name.empty())
+    return emitError() << "the kernel name can't be empty";
+  if (argAttrs) {
+    if (llvm::any_of(argAttrs, [](Attribute attr) {
+          return !llvm::isa<DictionaryAttr>(attr);
+        }))
+      return emitError()
+             << "all attributes in the array must be a dictionary attribute";
+  }
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -2260,11 +2288,15 @@ KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                         DictionaryAttr dict) {
   if (!dict)
     return emitError() << "table cannot be null";
-  if (llvm::any_of(dict, [](NamedAttribute attr) {
-        return !llvm::isa<KernelAttr>(attr.getValue());
-      }))
-    return emitError()
-           << "all the dictionary values must be `#gpu.kernel` attributes";
+  for (NamedAttribute attr : dict) {
+    auto kernel = llvm::dyn_cast<KernelAttr>(attr.getValue());
+    if (!kernel)
+      return emitError()
+             << "all the dictionary values must be `#gpu.kernel` attributes";
+    if (kernel.getName() != attr.getName())
+      return emitError() << "expected kernel to be named `" << attr.getName()
+                         << "` but got `" << kernel.getName() << "`";
+  }
   return success();
 }
 
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 692ef2a5e3bef..d902cf3f41e08 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -444,6 +444,6 @@ gpu.module @module_with_offload_handler <#gpu.select_object<0>> [#nvvm.target] {
 
 
 gpu.binary @binary [#gpu.object<#rocdl.target<chip = "gfx900">, kernels = #gpu.kernel_table<{
-    kernel0 = #gpu.kernel<{sym_name = "kernel0"}, metadata = {sgpr_count = 255}>,
-    kernel1 = #gpu.kernel<{sym_name = "kernel1"}>
-  }> , bin = "BLOB">]
+    kernel0 = #gpu.kernel<"kernel0", (i32, f32) -> (), metadata = {sgpr_count = 255}>,
+    kernel1 = #gpu.kernel<"kernel1", (i32) -> (), arg_attrs = [{llvm.read_only}]>
+  }>, bin = "BLOB">]
diff --git a/mlir/unittests/Target/LLVM/SerializeROCDLTarget.cpp b/mlir/unittests/Target/LLVM/SerializeROCDLTarget.cpp
index 3d5c84efb6f4f..a75cb05468250 100644
--- a/mlir/unittests/Target/LLVM/SerializeROCDLTarget.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeROCDLTarget.cpp
@@ -197,9 +197,9 @@ TEST_F(MLIRTargetLLVMROCDL, SKIP_WITHOUT_AMDGPU(GetELFMetadata)) {
       // Check that the ELF metadata is present.
       ASSERT_TRUE(kernel.getMetadata() != nullptr);
       // Verify that `sgpr_count` is present and it is an integer attribute.
-      ASSERT_TRUE(kernel.getMDAttr<IntegerAttr>("sgpr_count") != nullptr);
+      ASSERT_TRUE(kernel.getAttr<IntegerAttr>("sgpr_count") != nullptr);
       // Verify that `vgpr_count` is present and it is an integer attribute.
-      ASSERT_TRUE(kernel.getMDAttr<IntegerAttr>("vgpr_count") != nullptr);
+      ASSERT_TRUE(kernel.getAttr<IntegerAttr>("vgpr_count") != nullptr);
     }
   }
 }

>From 4ac6a0d35d2e9d5bbc1bf905c3d6924c24e003f7 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Mon, 17 Jun 2024 16:41:29 +0000
Subject: [PATCH 4/8] expose function retrieving the ELF metadata dict

---
 mlir/include/mlir/Target/LLVM/ROCDL/Utils.h   | 12 ++++++--
 mlir/lib/Target/LLVM/ROCDL/Utils.cpp          | 28 +++++++++++--------
 .../Target/LLVM/SerializeROCDLTarget.cpp      |  2 +-
 3 files changed, 27 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Target/LLVM/ROCDL/Utils.h b/mlir/include/mlir/Target/LLVM/ROCDL/Utils.h
index 904d60a2a75db..3d2174c144815 100644
--- a/mlir/include/mlir/Target/LLVM/ROCDL/Utils.h
+++ b/mlir/include/mlir/Target/LLVM/ROCDL/Utils.h
@@ -109,11 +109,19 @@ class SerializeGPUModuleBase : public LLVM::ModuleToObject {
   AMDGCNLibraries deviceLibs = AMDGCNLibraries::None;
 };
 
+/// Returns a map containing the `amdhsa.kernels` ELF metadata for each of the
+/// kernels in the binary, or `std::nullopt` if the metadata couldn't be
+/// retrieved. The map associates the name of the kernel with the list of named
+/// attributes found in `amdhsa.kernels`. For more information on the ELF
+/// metadata see: https://llvm.org/docs/AMDGPUUsage.html#amdhsa
+std::optional<DenseMap<StringAttr, NamedAttrList>>
+getAMDHSAKernelsELFMetadata(Builder &builder, ArrayRef<char> elfData);
+
 /// Returns a `#gpu.kernel_table` containing kernel metadata for each of the
 /// kernels in `gpuModule`. If `elfData` is valid, then the `amdhsa.kernels` ELF
 /// metadata will be added to the `#gpu.kernel_table`.
-gpu::KernelTableAttr getAMDHSAKernelsMetadata(Operation *gpuModule,
-                                              ArrayRef<char> elfData = {});
+gpu::KernelTableAttr getKernelMetadata(Operation *gpuModule,
+                                       ArrayRef<char> elfData = {});
 } // namespace ROCDL
 } // namespace mlir
 
diff --git a/mlir/lib/Target/LLVM/ROCDL/Utils.cpp b/mlir/lib/Target/LLVM/ROCDL/Utils.cpp
index 5029293bae2a6..c07ea8eebaf2e 100644
--- a/mlir/lib/Target/LLVM/ROCDL/Utils.cpp
+++ b/mlir/lib/Target/LLVM/ROCDL/Utils.cpp
@@ -166,20 +166,21 @@ static Attribute convertNode(Builder &builder, llvm::msgpack::DocNode &node) {
 }
 
 /// The following function should succeed for Code object V3 and above.
-static llvm::StringMap<DictionaryAttr> getELFMetadata(Builder &builder,
-                                                      ArrayRef<char> elfData) {
+std::optional<DenseMap<StringAttr, NamedAttrList>>
+mlir::ROCDL::getAMDHSAKernelsELFMetadata(Builder &builder,
+                                         ArrayRef<char> elfData) {
   std::optional<llvm::msgpack::Document> metadata = getAMDHSANote(elfData);
   if (!metadata)
-    return {};
-  llvm::StringMap<DictionaryAttr> kernelMD;
+    return std::nullopt;
+  DenseMap<StringAttr, NamedAttrList> kernelMD;
   llvm::msgpack::DocNode &root = (metadata)->getRoot();
   // Fail if `root` is not a map -it should be for AMD Obj Ver 3.
   if (!root.isMap())
-    return kernelMD;
+    return std::nullopt;
   auto &kernels = root.getMap()["amdhsa.kernels"];
   // Fail if `amdhsa.kernels` is not an array.
   if (!kernels.isArray())
-    return kernelMD;
+    return std::nullopt;
   // Convert each of the kernels.
   for (auto &kernel : kernels.getArray()) {
     if (!kernel.isMap())
@@ -202,24 +203,27 @@ static llvm::StringMap<DictionaryAttr> getELFMetadata(Builder &builder,
         attrList.append(key, attr);
     }
     if (!attrList.empty())
-      kernelMD[name.getString()] = builder.getDictionaryAttr(attrList);
+      kernelMD[builder.getStringAttr(name.getString())] = std::move(attrList);
   }
   return kernelMD;
 }
 
-gpu::KernelTableAttr
-mlir::ROCDL::getAMDHSAKernelsMetadata(Operation *gpuModule,
-                                      ArrayRef<char> elfData) {
+gpu::KernelTableAttr mlir::ROCDL::getKernelMetadata(Operation *gpuModule,
+                                                    ArrayRef<char> elfData) {
   auto module = cast<gpu::GPUModuleOp>(gpuModule);
   Builder builder(module.getContext());
   NamedAttrList moduleAttrs;
-  llvm::StringMap<DictionaryAttr> mdMap = getELFMetadata(builder, elfData);
+  std::optional<DenseMap<StringAttr, NamedAttrList>> mdMapOrNull =
+      getAMDHSAKernelsELFMetadata(builder, elfData);
   for (auto funcOp : module.getBody()->getOps<LLVM::LLVMFuncOp>()) {
     if (!funcOp->getDiscardableAttr("rocdl.kernel"))
       continue;
     moduleAttrs.append(
         funcOp.getName(),
-        gpu::KernelAttr::get(funcOp, mdMap.lookup(funcOp.getName())));
+        gpu::KernelAttr::get(
+            funcOp, mdMapOrNull ? builder.getDictionaryAttr(
+                                      mdMapOrNull->lookup(funcOp.getNameAttr()))
+                                : nullptr));
   }
   return gpu::KernelTableAttr::get(
       moduleAttrs.getDictionary(module.getContext()));
diff --git a/mlir/unittests/Target/LLVM/SerializeROCDLTarget.cpp b/mlir/unittests/Target/LLVM/SerializeROCDLTarget.cpp
index a75cb05468250..c5ad5a045d2f7 100644
--- a/mlir/unittests/Target/LLVM/SerializeROCDLTarget.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeROCDLTarget.cpp
@@ -187,7 +187,7 @@ TEST_F(MLIRTargetLLVMROCDL, SKIP_WITHOUT_AMDGPU(GetELFMetadata)) {
       continue;
     // Get the metadata.
     gpu::KernelTableAttr metadata =
-        ROCDL::getAMDHSAKernelsMetadata(gpuModule, *object);
+        ROCDL::getKernelMetadata(gpuModule, *object);
     ASSERT_TRUE(metadata != nullptr);
     // There should be only a single kernel.
     ASSERT_TRUE(metadata.size() == 1);

>From e629c5fc1e4a1127e00ef70217095eee5def488f Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Tue, 30 Jul 2024 15:30:31 +0000
Subject: [PATCH 5/8] address reviewer comments

---
 mlir/include/mlir-c/Dialect/GPU.h             |   8 +-
 .../mlir/Dialect/GPU/IR/CompilationAttrs.td   |  59 +++++-----
 mlir/lib/CAPI/Dialect/GPU.cpp                 |  19 +++-
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        |  63 ++++++++---
 mlir/lib/Target/LLVM/ROCDL/Target.cpp         |  11 +-
 mlir/lib/Target/LLVM/ROCDL/Utils.cpp          | 101 +++++++++---------
 mlir/test/Dialect/GPU/invalid.mlir            |  12 +++
 mlir/test/Dialect/GPU/ops.mlir                |  27 ++++-
 .../Target/LLVM/SerializeROCDLTarget.cpp      |  29 ++++-
 9 files changed, 212 insertions(+), 117 deletions(-)

diff --git a/mlir/include/mlir-c/Dialect/GPU.h b/mlir/include/mlir-c/Dialect/GPU.h
index d6d8eee639de9..321c1122c3370 100644
--- a/mlir/include/mlir-c/Dialect/GPU.h
+++ b/mlir/include/mlir-c/Dialect/GPU.h
@@ -35,8 +35,12 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsAGPUObjectAttr(MlirAttribute attr);
 
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirGPUObjectAttrGet(MlirContext mlirCtx, MlirAttribute target, uint32_t format,
-                     MlirStringRef objectStrRef, MlirAttribute mlirObjectProps,
-                     MlirAttribute mlirKernelsAttr);
+                     MlirStringRef objectStrRef, MlirAttribute mlirObjectProps);
+
+MLIR_CAPI_EXPORTED MlirAttribute mlirGPUObjectAttrGetWithKernels(
+    MlirContext mlirCtx, MlirAttribute target, uint32_t format,
+    MlirStringRef objectStrRef, MlirAttribute mlirObjectProps,
+    MlirAttribute mlirKernelsAttr);
 
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirGPUObjectAttrGetTarget(MlirAttribute mlirObjectAttr);
diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td b/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td
index 0088b729ab8a7..78740a1cb5e1d 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td
@@ -62,10 +62,15 @@ def GPU_KernelAttr : GPU_Attr<"Kernel", "kernel"> {
   ];
   let genVerifyDecl = 1;
   let extraClassDeclaration = [{
+    /// Compare two kernels based on the name.
+    bool operator<(const KernelAttr& other) const {
+      return getName().getValue() < other.getName().getValue();
+    }
+
     /// Returns the metadata attribute corresponding to `key` or `nullptr`
     /// if missing.
     Attribute getAttr(StringRef key) const {
-      auto attrs = getMetadata();
+      DictionaryAttr attrs = getMetadata();
       return attrs ? attrs.get(key) : nullptr;
     }
     template <typename ConcreteAttr>
@@ -73,7 +78,7 @@ def GPU_KernelAttr : GPU_Attr<"Kernel", "kernel"> {
       return llvm::dyn_cast_or_null<ConcreteAttr>(getAttr(key));
     }
     Attribute getAttr(StringAttr key) const {
-      auto attrs = getMetadata();
+      DictionaryAttr attrs = getMetadata();
       return attrs ? attrs.get(key) : nullptr;
     }
     template <typename ConcreteAttr>
@@ -83,18 +88,18 @@ def GPU_KernelAttr : GPU_Attr<"Kernel", "kernel"> {
 
     /// Returns the attribute dictionary at position `index`.
     DictionaryAttr getArgAttrDict(unsigned index) {
-      auto argArray = getArgAttrs();
+      ArrayAttr argArray = getArgAttrs();
       return argArray ? llvm::cast<DictionaryAttr>(argArray[index]) : nullptr;
     }
 
     /// Return the specified attribute, if present, for the argument at 'index',
     /// null otherwise.
     Attribute getArgAttr(unsigned index, StringAttr name) {
-      auto argDict = getArgAttrDict(index);
+      DictionaryAttr argDict = getArgAttrDict(index);
       return argDict ? argDict.get(name) : nullptr;
     }
     Attribute getArgAttr(unsigned index, StringRef name) {
-      auto argDict = getArgAttrDict(index);
+      DictionaryAttr argDict = getArgAttrDict(index);
       return argDict ? argDict.get(name) : nullptr;
     }
 
@@ -114,54 +119,38 @@ def GPU_KernelTableAttr : GPU_Attr<"KernelTable", "kernel_table"> {
 
     Examples:
     ```mlir
-      #gpu.kernel_table<{kernel0 = #gpu.kernel<...>}>
+      #gpu.kernel_table<[#gpu.kernel<"kernel0", ...>]>
     ```
   }];
   let parameters = (ins
-    "DictionaryAttr":$kernel_table
+    OptionalArrayRefParameter<"KernelAttr", "array of kernels">:$kernel_table
   );
   let assemblyFormat = [{
-    `<` $kernel_table `>`
+    `<` (`[` qualified($kernel_table)^ `]`)? `>`
   }];
   let builders = [
-    AttrBuilderWithInferredContext<(ins "DictionaryAttr":$kernel_table), [{
-      assert(kernel_table && "invalid kernel table");
-      return $_get(kernel_table.getContext(), kernel_table);
-    }]>
+    AttrBuilder<(ins "ArrayRef<KernelAttr>":$kernels,
+                     CArg<"bool", "false">:$isSorted)>
   ];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
   let extraClassDeclaration = [{
-    /// Helper iterator class for traversing the kernel table.
-    struct KernelIterator
-        : llvm::mapped_iterator_base<KernelIterator,
-                                    llvm::ArrayRef<NamedAttribute>::iterator,
-                                    std::pair<StringAttr, KernelAttr>> {
-      using llvm::mapped_iterator_base<
-          KernelIterator, llvm::ArrayRef<NamedAttribute>::iterator,
-          std::pair<StringAttr, KernelAttr>>::mapped_iterator_base;
-      /// Map the iterator to the kernel name and a KernelAttribute.
-      std::pair<StringAttr, KernelAttr> mapElement(NamedAttribute attr) const {
-        return {attr.getName(), llvm::cast<KernelAttr>(attr.getValue())};
-      }
-    };
-    auto begin() const {
-      return KernelIterator(getKernelTable().begin());
+    llvm::ArrayRef<KernelAttr>::iterator begin() const {
+      return getKernelTable().begin();
     }
-    auto end() const {
-      return KernelIterator(getKernelTable().end());
+    llvm::ArrayRef<KernelAttr>::iterator end() const {
+      return getKernelTable().end();
     }
     size_t size() const {
       return getKernelTable().size();
     }
+    bool empty() const {
+      return getKernelTable().empty();
+    }
 
     /// Returns the kernel with name `key` or `nullptr` if not present.
-    KernelAttr lookup(StringRef key) const {
-      return getKernelTable().getAs<KernelAttr>(key);
-    }
-    KernelAttr lookup(StringAttr key) const {
-      return getKernelTable().getAs<KernelAttr>(key);
-    }
+    KernelAttr lookup(StringRef key) const;
+    KernelAttr lookup(StringAttr key) const;
   }];
 }
 
diff --git a/mlir/lib/CAPI/Dialect/GPU.cpp b/mlir/lib/CAPI/Dialect/GPU.cpp
index 61b3c23f21bed..e4796ed1499ea 100644
--- a/mlir/lib/CAPI/Dialect/GPU.cpp
+++ b/mlir/lib/CAPI/Dialect/GPU.cpp
@@ -37,8 +37,23 @@ bool mlirAttributeIsAGPUObjectAttr(MlirAttribute attr) {
 
 MlirAttribute mlirGPUObjectAttrGet(MlirContext mlirCtx, MlirAttribute target,
                                    uint32_t format, MlirStringRef objectStrRef,
-                                   MlirAttribute mlirObjectProps,
-                                   MlirAttribute mlirKernelsAttr) {
+                                   MlirAttribute mlirObjectProps) {
+  MLIRContext *ctx = unwrap(mlirCtx);
+  llvm::StringRef object = unwrap(objectStrRef);
+  DictionaryAttr objectProps;
+  if (mlirObjectProps.ptr != nullptr)
+    objectProps = llvm::cast<DictionaryAttr>(unwrap(mlirObjectProps));
+  return wrap(gpu::ObjectAttr::get(
+      ctx, unwrap(target), static_cast<gpu::CompilationTarget>(format),
+      StringAttr::get(ctx, object), objectProps, nullptr));
+}
+
+MlirAttribute mlirGPUObjectAttrGetWithKernels(MlirContext mlirCtx,
+                                              MlirAttribute target,
+                                              uint32_t format,
+                                              MlirStringRef objectStrRef,
+                                              MlirAttribute mlirObjectProps,
+                                              MlirAttribute mlirKernelsAttr) {
   MLIRContext *ctx = unwrap(mlirCtx);
   llvm::StringRef object = unwrap(objectStrRef);
   DictionaryAttr objectProps;
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 5396a9c7ac9f5..c78f82fef39f9 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2257,7 +2257,7 @@ KernelAttr KernelAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
   if (attrs.empty())
     return *this;
   NamedAttrList attrList;
-  if (auto dict = getMetadata())
+  if (DictionaryAttr dict = getMetadata())
     attrList.append(dict);
   attrList.append(attrs);
   return KernelAttr::get(getName(), getFunctionType(), getArgAttrs(),
@@ -2283,23 +2283,62 @@ LogicalResult KernelAttr::verify(function_ref<InFlightDiagnostic()> emitError,
 // GPU KernelTableAttr
 //===----------------------------------------------------------------------===//
 
+KernelTableAttr KernelTableAttr::get(MLIRContext *context,
+                                     ArrayRef<KernelAttr> kernels,
+                                     bool isSorted) {
+  // Note that `is_sorted` is invoked at most once, even with assertions ON.
+  assert((!isSorted || llvm::is_sorted(kernels)) &&
+         "expected a sorted kernel array");
+  // Immediately return the attribute if the array is sorted.
+  if (isSorted || llvm::is_sorted(kernels))
+    return Base::get(context, kernels);
+  // Sort the array.
+  SmallVector<KernelAttr> kernelsTmp(kernels);
+  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
+  return Base::get(context, kernelsTmp);
+}
+
+KernelTableAttr
+KernelTableAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                            MLIRContext *context, ArrayRef<KernelAttr> kernels,
+                            bool isSorted) {
+  // Note that `is_sorted` is invoked at most once, even with assertions ON.
+  assert((!isSorted || llvm::is_sorted(kernels)) &&
+         "expected a sorted kernel array");
+  // Immediately return the attribute if the array is sorted.
+  if (isSorted || llvm::is_sorted(kernels))
+    return Base::getChecked(emitError, context, kernels);
+  // Sort the array.
+  SmallVector<KernelAttr> kernelsTmp(kernels);
+  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
+  return Base::getChecked(emitError, context, kernelsTmp);
+}
+
 LogicalResult
 KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
-                        DictionaryAttr dict) {
-  if (!dict)
-    return emitError() << "table cannot be null";
-  for (NamedAttribute attr : dict) {
-    auto kernel = llvm::dyn_cast<KernelAttr>(attr.getValue());
-    if (!kernel)
-      return emitError()
-             << "all the dictionary values must be `#gpu.kernel` attributes";
-    if (kernel.getName() != attr.getName())
-      return emitError() << "expected kernel to be named `" << attr.getName()
-                         << "` but got `" << kernel.getName() << "`";
+                        ArrayRef<KernelAttr> kernels) {
+  if (kernels.size() < 2)
+    return success();
+  // Check that the kernels are uniquely named.
+  if (std::adjacent_find(kernels.begin(), kernels.end(),
+                         [](KernelAttr l, KernelAttr r) {
+                           return l.getName() == r.getName();
+                         }) != kernels.end()) {
+    return emitError() << "expected all kernels to be uniquely named";
   }
   return success();
 }
 
+KernelAttr KernelTableAttr::lookup(StringRef key) const {
+  auto it = impl::findAttrSorted(begin(), end(), key);
+  return it.second ? *it.first : KernelAttr();
+}
+
+KernelAttr KernelTableAttr::lookup(StringAttr key) const {
+  auto it = impl::findAttrSorted(begin(), end(), key);
+  return it.second ? *it.first : KernelAttr();
+}
+
 //===----------------------------------------------------------------------===//
 // GPU target options
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVM/ROCDL/Target.cpp b/mlir/lib/Target/LLVM/ROCDL/Target.cpp
index 42053e3ec18d3..1ca35f09e255f 100644
--- a/mlir/lib/Target/LLVM/ROCDL/Target.cpp
+++ b/mlir/lib/Target/LLVM/ROCDL/Target.cpp
@@ -508,13 +508,10 @@ ROCDLTargetAttrImpl::createObject(Attribute attribute,
   // supported.
   if (format > gpu::CompilationTarget::Binary)
     format = gpu::CompilationTarget::Binary;
-
   DictionaryAttr properties{};
   Builder builder(attribute.getContext());
-  return builder.getAttr<gpu::ObjectAttr>(
-      attribute,
-      format > gpu::CompilationTarget::Binary ? gpu::CompilationTarget::Binary
-                                              : format,
-      builder.getStringAttr(StringRef(object.data(), object.size())),
-      properties, nullptr);
+  StringAttr objectStr =
+      builder.getStringAttr(StringRef(object.data(), object.size()));
+  return builder.getAttr<gpu::ObjectAttr>(attribute, format, objectStr,
+                                          properties, nullptr);
 }
diff --git a/mlir/lib/Target/LLVM/ROCDL/Utils.cpp b/mlir/lib/Target/LLVM/ROCDL/Utils.cpp
index c07ea8eebaf2e..4829950a1818c 100644
--- a/mlir/lib/Target/LLVM/ROCDL/Utils.cpp
+++ b/mlir/lib/Target/LLVM/ROCDL/Utils.cpp
@@ -29,17 +29,19 @@ using namespace mlir::ROCDL;
 /// `amdhsa.kernels`:
 /// https://llvm.org/docs/AMDGPUUsage.html#code-object-v3-metadata
 template <typename ELFT>
-static std::optional<llvm::msgpack::Document>
+static std::unique_ptr<llvm::msgpack::Document>
 getAMDHSANote(llvm::object::ELFObjectFile<ELFT> &elfObj) {
   using namespace llvm;
   using namespace llvm::object;
   using namespace llvm::ELF;
   const ELFFile<ELFT> &elf = elfObj.getELFFile();
-  auto secOrErr = elf.sections();
-  if (!secOrErr)
-    return std::nullopt;
+  Expected<typename ELFT::ShdrRange> secOrErr = elf.sections();
+  if (!secOrErr) {
+    consumeError(secOrErr.takeError());
+    return nullptr;
+  }
   ArrayRef<typename ELFT::Shdr> sections = *secOrErr;
-  for (auto section : sections) {
+  for (const typename ELFT::Shdr &section : sections) {
     if (section.sh_type != ELF::SHT_NOTE)
       continue;
     size_t align = std::max(static_cast<unsigned>(section.sh_addralign), 4u);
@@ -54,45 +56,45 @@ getAMDHSANote(llvm::object::ELFObjectFile<ELFT> &elfObj) {
       ArrayRef<uint8_t> desc = note.getDesc(align);
       StringRef msgPackString =
           StringRef(reinterpret_cast<const char *>(desc.data()), desc.size());
-      msgpack::Document msgPackDoc;
-      if (!msgPackDoc.readFromBlob(msgPackString, /*Multi=*/false))
-        return std::nullopt;
-      if (msgPackDoc.getRoot().isScalar())
-        return std::nullopt;
-      return std::optional<llvm::msgpack::Document>(std::move(msgPackDoc));
+      std::unique_ptr<llvm::msgpack::Document> msgPackDoc(
+          new llvm::msgpack::Document());
+      if (!msgPackDoc->readFromBlob(msgPackString, /*Multi=*/false))
+        return nullptr;
+      if (msgPackDoc->getRoot().isScalar())
+        return nullptr;
+      return msgPackDoc;
     }
   }
-  return std::nullopt;
+  return nullptr;
 }
 
-/// Return the `amdhsa.kernels` metadata in the ELF object or std::nullopt on
+/// Return the `amdhsa.kernels` metadata in the ELF object or nullptr on
 /// failure. This is a helper function that casts a generic `ObjectFile` to the
 /// appropiate `ELFObjectFile`.
-static std::optional<llvm::msgpack::Document>
+static std::unique_ptr<llvm::msgpack::Document>
 getAMDHSANote(ArrayRef<char> elfData) {
   using namespace llvm;
   using namespace llvm::object;
   if (elfData.empty())
-    return std::nullopt;
+    return nullptr;
   MemoryBufferRef buffer(StringRef(elfData.data(), elfData.size()), "buffer");
   Expected<std::unique_ptr<ObjectFile>> objOrErr =
       ObjectFile::createELFObjectFile(buffer);
   if (!objOrErr || !objOrErr.get()) {
     // Drop the error.
     llvm::consumeError(objOrErr.takeError());
-    return std::nullopt;
+    return nullptr;
   }
   ObjectFile &elf = *(objOrErr.get());
-  std::optional<llvm::msgpack::Document> metadata;
   if (auto *obj = dyn_cast<ELF32LEObjectFile>(&elf))
-    metadata = getAMDHSANote(*obj);
+    return getAMDHSANote(*obj);
   else if (auto *obj = dyn_cast<ELF32BEObjectFile>(&elf))
-    metadata = getAMDHSANote(*obj);
+    return getAMDHSANote(*obj);
   else if (auto *obj = dyn_cast<ELF64LEObjectFile>(&elf))
-    metadata = getAMDHSANote(*obj);
+    return getAMDHSANote(*obj);
   else if (auto *obj = dyn_cast<ELF64BEObjectFile>(&elf))
-    metadata = getAMDHSANote(*obj);
-  return metadata;
+    return getAMDHSANote(*obj);
+  return nullptr;
 }
 
 /// Utility functions for converting `llvm::msgpack::DocNode` nodes.
@@ -100,11 +102,11 @@ static Attribute convertNode(Builder &builder, llvm::msgpack::DocNode &node);
 static Attribute convertNode(Builder &builder,
                              llvm::msgpack::MapDocNode &node) {
   NamedAttrList attrs;
-  for (auto kv : node) {
-    if (!kv.first.isString())
+  for (auto &[keyNode, valueNode] : node) {
+    if (!keyNode.isString())
       continue;
-    if (Attribute attr = convertNode(builder, kv.second)) {
-      auto key = kv.first.getString();
+    StringRef key = keyNode.getString();
+    if (Attribute attr = convertNode(builder, valueNode)) {
       key.consume_front(".");
       key.consume_back(".");
       attrs.append(key, attr);
@@ -125,7 +127,7 @@ static Attribute convertNode(Builder &builder,
       })) {
     SmallVector<int64_t> values;
     for (llvm::msgpack::DocNode &n : node) {
-      auto kind = n.getKind();
+      llvm::msgpack::Type kind = n.getKind();
       if (kind == NodeKind::Int)
         values.push_back(n.getInt());
       else if (kind == NodeKind::UInt)
@@ -169,41 +171,43 @@ static Attribute convertNode(Builder &builder, llvm::msgpack::DocNode &node) {
 std::optional<DenseMap<StringAttr, NamedAttrList>>
 mlir::ROCDL::getAMDHSAKernelsELFMetadata(Builder &builder,
                                          ArrayRef<char> elfData) {
-  std::optional<llvm::msgpack::Document> metadata = getAMDHSANote(elfData);
+  using namespace llvm::msgpack;
+  std::unique_ptr<llvm::msgpack::Document> metadata = getAMDHSANote(elfData);
   if (!metadata)
     return std::nullopt;
   DenseMap<StringAttr, NamedAttrList> kernelMD;
-  llvm::msgpack::DocNode &root = (metadata)->getRoot();
-  // Fail if `root` is not a map -it should be for AMD Obj Ver 3.
-  if (!root.isMap())
+  DocNode &rootNode = (metadata)->getRoot();
+  // Fail if `rootNode` is not a map; it should be for AMD Obj Ver 3.
+  if (!rootNode.isMap())
     return std::nullopt;
-  auto &kernels = root.getMap()["amdhsa.kernels"];
+  DocNode &kernels = rootNode.getMap()["amdhsa.kernels"];
   // Fail if `amdhsa.kernels` is not an array.
   if (!kernels.isArray())
     return std::nullopt;
   // Convert each of the kernels.
-  for (auto &kernel : kernels.getArray()) {
+  for (DocNode &kernel : kernels.getArray()) {
     if (!kernel.isMap())
       continue;
-    auto &kernelMap = kernel.getMap();
-    auto &name = kernelMap[".name"];
-    if (!name.isString())
+    MapDocNode &kernelMap = kernel.getMap();
+    DocNode &nameNode = kernelMap[".name"];
+    if (!nameNode.isString())
       continue;
+    StringRef name = nameNode.getString();
     NamedAttrList attrList;
     // Convert the kernel properties.
-    for (auto kv : kernelMap) {
-      if (!kv.first.isString())
+    for (auto &[keyNode, valueNode] : kernelMap) {
+      if (!keyNode.isString())
         continue;
-      StringRef key = kv.first.getString();
+      StringRef key = keyNode.getString();
       key.consume_front(".");
       key.consume_back(".");
       if (key == "name")
         continue;
-      if (Attribute attr = convertNode(builder, kv.second))
+      if (Attribute attr = convertNode(builder, valueNode))
         attrList.append(key, attr);
     }
     if (!attrList.empty())
-      kernelMD[builder.getStringAttr(name.getString())] = std::move(attrList);
+      kernelMD[builder.getStringAttr(name)] = std::move(attrList);
   }
   return kernelMD;
 }
@@ -212,19 +216,16 @@ gpu::KernelTableAttr mlir::ROCDL::getKernelMetadata(Operation *gpuModule,
                                                     ArrayRef<char> elfData) {
   auto module = cast<gpu::GPUModuleOp>(gpuModule);
   Builder builder(module.getContext());
-  NamedAttrList moduleAttrs;
+  SmallVector<gpu::KernelAttr> kernels;
   std::optional<DenseMap<StringAttr, NamedAttrList>> mdMapOrNull =
       getAMDHSAKernelsELFMetadata(builder, elfData);
   for (auto funcOp : module.getBody()->getOps<LLVM::LLVMFuncOp>()) {
     if (!funcOp->getDiscardableAttr("rocdl.kernel"))
       continue;
-    moduleAttrs.append(
-        funcOp.getName(),
-        gpu::KernelAttr::get(
-            funcOp, mdMapOrNull ? builder.getDictionaryAttr(
-                                      mdMapOrNull->lookup(funcOp.getNameAttr()))
-                                : nullptr));
+    kernels.push_back(gpu::KernelAttr::get(
+        funcOp, mdMapOrNull ? builder.getDictionaryAttr(
+                                  mdMapOrNull->lookup(funcOp.getNameAttr()))
+                            : nullptr));
   }
-  return gpu::KernelTableAttr::get(
-      moduleAttrs.getDictionary(module.getContext()));
+  return gpu::KernelTableAttr::get(gpuModule->getContext(), kernels);
 }
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index e9d8f329be8ed..1c65a9ccf7336 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -832,3 +832,15 @@ module attributes {gpu.container_module} {
   gpu.module @kernel <> {
   }
 }
+
+// -----
+
+gpu.binary @binary [#gpu.object<#rocdl.target<chip = "gfx900">,
+  // expected-error@+1{{expected all kernels to be uniquely named}}
+    kernels = #gpu.kernel_table<[
+      #gpu.kernel<"kernel", (i32) -> ()>,
+      #gpu.kernel<"kernel", (i32, f32) -> (), metadata = {sgpr_count = 255}>
+  // expected-error@below{{failed to parse GPU_ObjectAttr parameter 'kernels' which is to be a `KernelTableAttr`}}
+    ]>,
+    bin = "BLOB">
+  ]
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index d902cf3f41e08..f80dca7552623 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -442,8 +442,25 @@ gpu.module @module_with_two_target [#nvvm.target, #rocdl.target<chip = "gfx90a">
 gpu.module @module_with_offload_handler <#gpu.select_object<0>> [#nvvm.target] {
 }
 
-
-gpu.binary @binary [#gpu.object<#rocdl.target<chip = "gfx900">, kernels = #gpu.kernel_table<{
-    kernel0 = #gpu.kernel<"kernel0", (i32, f32) -> (), metadata = {sgpr_count = 255}>,
-    kernel1 = #gpu.kernel<"kernel1", (i32) -> (), arg_attrs = [{llvm.read_only}]>
-  }>, bin = "BLOB">]
+// Test kernel attributes
+gpu.binary @kernel_attrs_1 [
+    #gpu.object<#rocdl.target<chip = "gfx900">,
+      kernels = #gpu.kernel_table<[
+        #gpu.kernel<"kernel0", (i32, f32) -> (), metadata = {sgpr_count = 255}>,
+        #gpu.kernel<"kernel1", (i32) -> (), arg_attrs = [{llvm.read_only}]>
+      ]>,
+      bin = "BLOB">
+  ]
+
+// Verify the kernels are sorted
+// CHECK-LABEL: gpu.binary @kernel_attrs_2
+gpu.binary @kernel_attrs_2 [
+    // CHECK: [#gpu.kernel<"a_kernel", () -> ()>, #gpu.kernel<"m_kernel", () -> ()>, #gpu.kernel<"z_kernel", () -> ()>]
+    #gpu.object<#rocdl.target<chip = "gfx900">,
+      kernels = #gpu.kernel_table<[
+        #gpu.kernel<"z_kernel", () -> ()>,
+        #gpu.kernel<"m_kernel", () -> ()>,
+        #gpu.kernel<"a_kernel", () -> ()>
+      ]>,
+      bin = "BLOB">
+  ]
diff --git a/mlir/unittests/Target/LLVM/SerializeROCDLTarget.cpp b/mlir/unittests/Target/LLVM/SerializeROCDLTarget.cpp
index c5ad5a045d2f7..100c483c515de 100644
--- a/mlir/unittests/Target/LLVM/SerializeROCDLTarget.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeROCDLTarget.cpp
@@ -166,6 +166,23 @@ TEST_F(MLIRTargetLLVMROCDL, SKIP_WITHOUT_AMDGPU(GetELFMetadata)) {
 
   MLIRContext context(registry);
 
+  // MLIR module used for the tests.
+  const std::string moduleStr = R"mlir(
+    gpu.module @rocdl_test {
+    llvm.func @rocdl_kernel_1(%arg0: f32) attributes {gpu.kernel, rocdl.kernel} {
+      llvm.return
+    }
+    llvm.func @rocdl_kernel_0(%arg0: f32) attributes {gpu.kernel, rocdl.kernel} {
+      llvm.return
+    }
+    llvm.func @rocdl_kernel_2(%arg0: f32) attributes {gpu.kernel, rocdl.kernel} {
+      llvm.return
+    }
+    llvm.func @a_kernel(%arg0: f32) attributes {gpu.kernel, rocdl.kernel} {
+      llvm.return
+    }
+  })mlir";
+
   OwningOpRef<ModuleOp> module =
       parseSourceString<ModuleOp>(moduleStr, &context);
   ASSERT_TRUE(!!module);
@@ -189,11 +206,15 @@ TEST_F(MLIRTargetLLVMROCDL, SKIP_WITHOUT_AMDGPU(GetELFMetadata)) {
     gpu::KernelTableAttr metadata =
         ROCDL::getKernelMetadata(gpuModule, *object);
     ASSERT_TRUE(metadata != nullptr);
-    // There should be only a single kernel.
-    ASSERT_TRUE(metadata.size() == 1);
+    // There should be 4 kernels.
+    ASSERT_TRUE(metadata.size() == 4);
+    // Check that the lookup method finds the kernel.
+    ASSERT_TRUE(metadata.lookup("a_kernel") != nullptr);
+    ASSERT_TRUE(metadata.lookup("rocdl_kernel_0") != nullptr);
+    // Check that the kernel doesn't exist.
+    ASSERT_TRUE(metadata.lookup("not_existent_kernel") == nullptr);
     // Test the `ROCDLObjectMDAttr` iterators.
-    for (auto [name, kernel] : metadata) {
-      ASSERT_TRUE(name.getValue() == "rocdl_kernel");
+    for (gpu::KernelAttr kernel : metadata) {
       // Check that the ELF metadata is present.
       ASSERT_TRUE(kernel.getMetadata() != nullptr);
       // Verify that `sgpr_count` is present and it is an integer attribute.

>From 0d20f287c97cf10a402684ff1504d82fe7043d39 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Tue, 30 Jul 2024 15:45:03 +0000
Subject: [PATCH 6/8] fix docs

---
 .../mlir/Dialect/GPU/IR/CompilationAttrs.td     | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td b/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td
index 78740a1cb5e1d..0cc51620d0b6a 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td
@@ -114,12 +114,23 @@ def GPU_KernelAttr : GPU_Attr<"Kernel", "kernel"> {
 
 def GPU_KernelTableAttr : GPU_Attr<"KernelTable", "kernel_table"> {
   let description = [{
-    GPU attribute representing a table of kernels metadata. All the attributes
-    in the dictionary must be of type `#gpu.kernel`.
+    GPU attribute representing a list of `#gpu.kernel` attributes. This
+    attribute supports searching kernels by name. All kernels in the table must
+    have a unique name.
 
     Examples:
     ```mlir
-      #gpu.kernel_table<[#gpu.kernel<kernel0, ...>]>
+      // Empty table
+      #gpu.kernel_table<>
+
+      // Table with a single kernel
+      #gpu.kernel_table<[#gpu.kernel<"kernel0", () -> ()>]>
+
+      // Table with multiple kernels.
+      #gpu.kernel_table<[
+        #gpu.kernel<"kernel0", (i32, f32) -> (), metadata = {sgpr_count = 255}>,
+        #gpu.kernel<"kernel1", (i32) -> ()>
+      ]>
     ```
   }];
   let parameters = (ins

>From a78c9dfd51d3941a99dec4a4366e01a111f4f20f Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Tue, 30 Jul 2024 15:49:18 +0000
Subject: [PATCH 7/8] expand types

---
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp |  6 ++++--
 mlir/lib/Target/LLVM/ROCDL/Utils.cpp   | 23 +++++++++++------------
 2 files changed, 15 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index c78f82fef39f9..5ac085f7d8c8b 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2330,12 +2330,14 @@ KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
 }
 
 KernelAttr KernelTableAttr::lookup(StringRef key) const {
-  auto it = impl::findAttrSorted(begin(), end(), key);
+  std::pair<ArrayRef<KernelAttr>::iterator, bool> it =
+      impl::findAttrSorted(begin(), end(), key);
   return it.second ? *it.first : KernelAttr();
 }
 
 KernelAttr KernelTableAttr::lookup(StringAttr key) const {
-  auto it = impl::findAttrSorted(begin(), end(), key);
+  std::pair<ArrayRef<KernelAttr>::iterator, bool> it =
+      impl::findAttrSorted(begin(), end(), key);
   return it.second ? *it.first : KernelAttr();
 }
 
diff --git a/mlir/lib/Target/LLVM/ROCDL/Utils.cpp b/mlir/lib/Target/LLVM/ROCDL/Utils.cpp
index 4829950a1818c..7e703924db0ed 100644
--- a/mlir/lib/Target/LLVM/ROCDL/Utils.cpp
+++ b/mlir/lib/Target/LLVM/ROCDL/Utils.cpp
@@ -119,18 +119,18 @@ static Attribute convertNode(Builder &builder,
 
 static Attribute convertNode(Builder &builder,
                              llvm::msgpack::ArrayDocNode &node) {
-  using NodeKind = llvm::msgpack::Type;
   // Use `DenseIntAttr` if we know all the attrs are ints.
   if (llvm::all_of(node, [](llvm::msgpack::DocNode &n) {
-        auto kind = n.getKind();
-        return kind == NodeKind::Int || kind == NodeKind::UInt;
+        llvm::msgpack::Type kind = n.getKind();
+        return kind == llvm::msgpack::Type::Int ||
+               kind == llvm::msgpack::Type::UInt;
       })) {
     SmallVector<int64_t> values;
     for (llvm::msgpack::DocNode &n : node) {
       llvm::msgpack::Type kind = n.getKind();
-      if (kind == NodeKind::Int)
+      if (kind == llvm::msgpack::Type::Int)
         values.push_back(n.getInt());
-      else if (kind == NodeKind::UInt)
+      else if (kind == llvm::msgpack::Type::UInt)
         values.push_back(n.getUInt());
     }
     return builder.getDenseI64ArrayAttr(values);
@@ -148,19 +148,18 @@ static Attribute convertNode(Builder &builder,
 
 static Attribute convertNode(Builder &builder, llvm::msgpack::DocNode &node) {
   using namespace llvm::msgpack;
-  using NodeKind = llvm::msgpack::Type;
   switch (node.getKind()) {
-  case NodeKind::Int:
+  case llvm::msgpack::Type::Int:
     return builder.getI64IntegerAttr(node.getInt());
-  case NodeKind::UInt:
+  case llvm::msgpack::Type::UInt:
     return builder.getI64IntegerAttr(node.getUInt());
-  case NodeKind::Boolean:
+  case llvm::msgpack::Type::Boolean:
     return builder.getI64IntegerAttr(node.getBool());
-  case NodeKind::String:
+  case llvm::msgpack::Type::String:
     return builder.getStringAttr(node.getString());
-  case NodeKind::Array:
+  case llvm::msgpack::Type::Array:
     return convertNode(builder, node.getArray());
-  case NodeKind::Map:
+  case llvm::msgpack::Type::Map:
     return convertNode(builder, node.getMap());
   default:
     return nullptr;

>From f56506feef92562bd0616beef206975de49e48f0 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Tue, 30 Jul 2024 17:18:10 +0000
Subject: [PATCH 8/8] fix python bindings

---
 mlir/lib/Bindings/Python/DialectGPU.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Bindings/Python/DialectGPU.cpp b/mlir/lib/Bindings/Python/DialectGPU.cpp
index d4df5d5b58130..560a54bcd1591 100644
--- a/mlir/lib/Bindings/Python/DialectGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectGPU.cpp
@@ -53,7 +53,7 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
             py::buffer_info info(py::buffer(object).request());
             MlirStringRef objectStrRef =
                 mlirStringRefCreate(static_cast<char *>(info.ptr), info.size);
-            return cls(mlirGPUObjectAttrGet(
+            return cls(mlirGPUObjectAttrGetWithKernels(
                 mlirAttributeGetContext(target), target, format, objectStrRef,
                 mlirObjectProps.has_value() ? *mlirObjectProps
                                             : MlirAttribute{nullptr},



More information about the Mlir-commits mailing list