[clang] [llvm] [mlir] [mlir][gpu] Add the `OffloadEmbeddingAttr` offloading translation attr (PR #78117)

Fabian Mora via cfe-commits cfe-commits at lists.llvm.org
Sun Jan 14 18:17:05 PST 2024


https://github.com/fabianmcg created https://github.com/llvm/llvm-project/pull/78117

This patch adds the offloading translation attribute. This attribute uses LLVM
offloading infrastructure to embed GPU binaries in the IR. At the program start,
the LLVM offloading mechanism registers kernels and variables with the runtime
library: CUDA RT, HIP RT, or LibOMPTarget.

The offloading mechanism relies on the runtime library to dispatch the correct
kernel based on the registered symbols.
    
This patch is 3/4 on introducing the `OffloadEmbeddingAttr` GPU translation
attribute.
    
Note: Ignore the base commits; those are being reviewed in PRs #78057, #78098,
and #78073.


>From 61c8809698b66cf3b4686e9908fb11773ecf0eb6 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Sat, 13 Jan 2024 23:45:57 +0000
Subject: [PATCH 1/4] [mlir][interfaces] Add the `TargetInfo` attribute
 interface

This patch adds the TargetInfo attribute interface to the set of DLTI
interfaces. Target information attributes provide essential information on the
compilation target. This information includes the target triple identifier, the
target chip identifier, and a string representation of the target features.

This patch also adds this new interface to the NVVM and ROCDL GPU target
attributes.
---
 .../include/mlir/Dialect/LLVMIR/NVVMDialect.h |  1 +
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   |  5 ++-
 .../mlir/Dialect/LLVMIR/ROCDLDialect.h        |  1 +
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td  |  6 ++--
 .../mlir/Interfaces/DataLayoutInterfaces.td   | 33 +++++++++++++++++++
 mlir/lib/Dialect/LLVMIR/CMakeLists.txt        |  2 ++
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    |  8 +++++
 mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp   |  8 +++++
 8 files changed, 61 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index 08019e77ae6af8..1a55d08be9edc2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "llvm/IR/IntrinsicsNVPTX.h"
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index c5f68a2ebe3952..0bbbde6270cd69 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -17,6 +17,7 @@ include "mlir/IR/EnumAttr.td"
 include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/DataLayoutInterfaces.td"
 include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
 
 def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
@@ -1894,7 +1895,9 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
 // NVVM target attribute.
 //===----------------------------------------------------------------------===//
 
-def NVVM_TargettAttr : NVVM_Attr<"NVVMTarget", "target"> {
+def NVVM_TargettAttr : NVVM_Attr<"NVVMTarget", "target", [
+    DeclareAttrInterfaceMethods<TargetInfoAttrInterface>
+  ]> {
   let description = [{
     GPU target attribute for controlling compilation of NVIDIA targets. All
     parameters decay into default values if not present.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h
index c2a82ffc1c43cf..fa1131a463e1ab 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h
@@ -26,6 +26,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 ///// Ops /////
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 48b830ae34f292..a492709c299544 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -15,6 +15,7 @@
 
 include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Interfaces/DataLayoutInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
 //===----------------------------------------------------------------------===//
@@ -608,8 +609,9 @@ def ROCDL_CvtSrFp8F32Op :
 // ROCDL target attribute.
 //===----------------------------------------------------------------------===//
 
-def ROCDL_TargettAttr :
-    ROCDL_Attr<"ROCDLTarget", "target"> {
+def ROCDL_TargettAttr : ROCDL_Attr<"ROCDLTarget", "target", [
+    DeclareAttrInterfaceMethods<TargetInfoAttrInterface>
+  ]> {
   let description = [{
     ROCDL target attribute for controlling compilation of AMDGPU targets. All
     parameters decay into default values if not present.
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
index a8def967fffcfa..eac9521aadc11e 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
@@ -188,6 +188,39 @@ def DataLayoutSpecInterface : AttrInterface<"DataLayoutSpecInterface"> {
   }];
 }
 
+def TargetInfoAttrInterface : AttrInterface<"TargetInfoAttrInterface"> {
+  let cppNamespace = "::mlir";
+
+  let description = [{
+    Attribute interface describing target information.
+
+    Target information attributes provide essential information on the
+    compilation target. This information includes the target triple identifier,
+    the target chip identifier, and a string representation of the target features.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      /*description=*/"Returns the target triple identifier.",
+      /*retTy=*/"::mlir::StringRef",
+      /*methodName=*/"getTargetTriple",
+      /*args=*/(ins)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns the target chip identifier.",
+      /*retTy=*/"::mlir::StringRef",
+      /*methodName=*/"getTargetChip",
+      /*args=*/(ins)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns the target features as a string.",
+      /*retTy=*/"std::string",
+      /*methodName=*/"getTargetFeatures",
+      /*args=*/(ins)
+    >
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // Operation interface
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index b00259677697a5..00b78e30ee8b09 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -61,6 +61,7 @@ add_mlir_dialect_library(MLIRNVVMDialect
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRLLVMDialect
+  MLIRDataLayoutInterfaces
   MLIRSideEffectInterfaces
   )
 
@@ -83,5 +84,6 @@ add_mlir_dialect_library(MLIRROCDLDialect
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRLLVMDialect
+  MLIRDataLayoutInterfaces
   MLIRSideEffectInterfaces
   )
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index aa49c4dc31fbc0..b73504ac4969af 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1106,6 +1106,14 @@ NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+StringRef NVVMTargetAttr::getTargetTriple() const { return getTriple(); }
+
+StringRef NVVMTargetAttr::getTargetChip() const { return getChip(); }
+
+std::string NVVMTargetAttr::getTargetFeatures() const {
+  return getFeatures().str();
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index 26e46b31ddc018..8b10c48718a3f8 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -295,6 +295,14 @@ ROCDLTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+StringRef ROCDLTargetAttr::getTargetTriple() const { return getTriple(); }
+
+StringRef ROCDLTargetAttr::getTargetChip() const { return getChip(); }
+
+std::string ROCDLTargetAttr::getTargetFeatures() const {
+  return getFeatures().str();
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/LLVMIR/ROCDLOps.cpp.inc"
 

>From 436ec9b04bb238238d4a935a8f965a13e70c6846 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Sun, 14 Jan 2024 01:29:19 +0000
Subject: [PATCH 2/4] [mlir][Target][LLVM] Add offload utility class

This patch adds the `OffloadHandler` utility class for creating LLVM offload
entries.
LLVM offload entries hold information on offload symbols; for example, for a
GPU kernel, this includes its host address to identify the kernel and the kernel
identifier in the binary. Arrays of offload entries can be used to register
functions within the CUDA/HIP runtime. Libomptarget also uses these entries to
register OMP target offload kernels and variables.

This patch is 1/4 on introducing the `OffloadEmbeddingAttr` GPU translation
attribute.
---
 mlir/include/mlir/Target/LLVM/Offload.h   |  61 ++++++++++++
 mlir/lib/Target/LLVM/CMakeLists.txt       |   2 +
 mlir/lib/Target/LLVM/Offload.cpp          | 111 ++++++++++++++++++++++
 mlir/unittests/Target/LLVM/CMakeLists.txt |   1 +
 mlir/unittests/Target/LLVM/Offload.cpp    |  49 ++++++++++
 5 files changed, 224 insertions(+)
 create mode 100644 mlir/include/mlir/Target/LLVM/Offload.h
 create mode 100644 mlir/lib/Target/LLVM/Offload.cpp
 create mode 100644 mlir/unittests/Target/LLVM/Offload.cpp

diff --git a/mlir/include/mlir/Target/LLVM/Offload.h b/mlir/include/mlir/Target/LLVM/Offload.h
new file mode 100644
index 00000000000000..7b705667d477d2
--- /dev/null
+++ b/mlir/include/mlir/Target/LLVM/Offload.h
@@ -0,0 +1,61 @@
+//===- Offload.h - LLVM Target Offload --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares LLVM target offload utility classes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVM_OFFLOAD_H
+#define MLIR_TARGET_LLVM_OFFLOAD_H
+
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace llvm {
+class Constant;
+class GlobalVariable;
+class Module;
+} // namespace llvm
+
+namespace mlir {
+namespace LLVM {
+/// `OffloadHandler` is a utility class for creating LLVM offload entries. LLVM
+/// offload entries hold information on offload symbols; for example, for a GPU
+/// kernel, this includes its host address to identify the kernel and the kernel
+/// identifier in the binary. Arrays of offload entries can be used to register
+/// functions within the CUDA/HIP runtime. Libomptarget also uses these entries
+/// to register OMP target offload kernels and variables.
+class OffloadHandler {
+public:
+  using OffloadEntryArray =
+      std::pair<llvm::GlobalVariable *, llvm::GlobalVariable *>;
+  OffloadHandler(llvm::Module &module) : module(module) {}
+
+  /// Returns the name of the begin symbol of the entry array for `suffix`.
+  static std::string getBeginSymbol(StringRef suffix);
+
+  /// Returns the name of the end symbol of the entry array for `suffix`.
+  static std::string getEndSymbol(StringRef suffix);
+
+  /// Returns the entry array globals; a missing symbol yields a null pointer.
+  OffloadEntryArray getEntryArray(StringRef suffix);
+
+  /// Emits an empty array of offloading entries and returns its globals.
+  OffloadEntryArray emitEmptyEntryArray(StringRef suffix);
+
+  /// Inserts an offloading entry into an existing entry array. This method
+  /// returns failure if the begin or end symbol of the array is missing.
+  LogicalResult insertOffloadEntry(StringRef suffix, llvm::Constant *entry);
+
+protected:
+  llvm::Module &module; ///< LLVM module holding the entry arrays.
+};
+} // namespace LLVM
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVM_OFFLOAD_H
diff --git a/mlir/lib/Target/LLVM/CMakeLists.txt b/mlir/lib/Target/LLVM/CMakeLists.txt
index cc2c3a00a02eaf..241a6c64dd868f 100644
--- a/mlir/lib/Target/LLVM/CMakeLists.txt
+++ b/mlir/lib/Target/LLVM/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_library(MLIRTargetLLVM
   ModuleToObject.cpp
+  Offload.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVM
@@ -16,6 +17,7 @@ add_mlir_library(MLIRTargetLLVM
   Passes
   Support
   Target
+  FrontendOffloading
   LINK_LIBS PUBLIC
   MLIRExecutionEngineUtils
   MLIRTargetLLVMIRExport
diff --git a/mlir/lib/Target/LLVM/Offload.cpp b/mlir/lib/Target/LLVM/Offload.cpp
new file mode 100644
index 00000000000000..81ba12403bfb99
--- /dev/null
+++ b/mlir/lib/Target/LLVM/Offload.cpp
@@ -0,0 +1,111 @@
+//===- Offload.cpp - LLVM Target Offload ------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines LLVM target offload utility classes.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVM/Offload.h"
+#include "llvm/Frontend/Offloading/Utility.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Module.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+std::string OffloadHandler::getBeginSymbol(StringRef suffix) {
+  return llvm::Twine("__begin_offload_").concat(suffix).str(); // prefix+suffix
+}
+
+std::string OffloadHandler::getEndSymbol(StringRef suffix) {
+  return llvm::Twine("__end_offload_").concat(suffix).str(); // prefix+suffix
+}
+
+/// Returns the type of an entry array holding `numElems` offload entries.
+static llvm::ArrayType *getEntryArrayType(llvm::Module &module,
+                                          size_t numElems) {
+  return llvm::ArrayType::get(llvm::offloading::getEntryTy(module), numElems);
+}
+
+/// Creates the initializer of the entry array from `entries`.
+static llvm::Constant *getEntryArrayBegin(llvm::Module &module,
+                                          ArrayRef<llvm::Constant *> entries) {
+  // If there are no entries return a constant zero initializer.
+  llvm::ArrayType *arrayTy = getEntryArrayType(module, entries.size());
+  return entries.empty() ? llvm::ConstantAggregateZero::get(arrayTy)
+                         : llvm::ConstantArray::get(arrayTy, entries);
+}
+
+/// Computes the end position of the entry array: the address of the element
+/// at index `numElems` (one past the last entry) starting from `begin`.
+static llvm::Constant *getEntryArrayEnd(llvm::Module &module,
+                                        llvm::GlobalVariable *begin,
+                                        size_t numElems) {
+  llvm::Type *intTy = module.getDataLayout().getIntPtrType(module.getContext());
+  llvm::Constant *index = llvm::ConstantInt::get(intTy, numElems);
+  return llvm::ConstantExpr::getGetElementPtr(
+      llvm::offloading::getEntryTy(module), begin, index, /*InBounds=*/true);
+}
+
+OffloadHandler::OffloadEntryArray
+OffloadHandler::getEntryArray(StringRef suffix) {
+  // Look up each symbol independently; a symbol that doesn't exist yields a
+  // null pointer in the corresponding slot of the pair.
+  return OffloadEntryArray(
+      module.getGlobalVariable(getBeginSymbol(suffix), /*AllowInternal=*/true),
+      module.getGlobalVariable(getEndSymbol(suffix), /*AllowInternal=*/true));
+}
+
+OffloadHandler::OffloadEntryArray
+OffloadHandler::emitEmptyEntryArray(StringRef suffix) {
+  llvm::ArrayType *arrayTy = getEntryArrayType(module, /*numElems=*/0);
+  auto *beginGV = new llvm::GlobalVariable(
+      module, arrayTy, /*isConstant=*/true, llvm::GlobalValue::InternalLinkage,
+      getEntryArrayBegin(module, {}), getBeginSymbol(suffix));
+  auto *endGV = new llvm::GlobalVariable(
+      module, llvm::PointerType::get(module.getContext(), /*AddressSpace=*/0),
+      /*isConstant=*/true, llvm::GlobalValue::InternalLinkage,
+      getEntryArrayEnd(module, beginGV, /*numElems=*/0), getEndSymbol(suffix));
+  return {beginGV, endGV};
+}
+
+LogicalResult OffloadHandler::insertOffloadEntry(StringRef suffix,
+                                                 llvm::Constant *entry) {
+  // Get the begin and end symbols of the entry array.
+  std::string beginSymId = getBeginSymbol(suffix);
+  llvm::GlobalVariable *beginGV = module.getGlobalVariable(beginSymId, true);
+  llvm::GlobalVariable *endGV =
+      module.getGlobalVariable(getEndSymbol(suffix), true);
+  // Fail if the symbols are missing.
+  if (!beginGV || !endGV)
+    return failure();
+  assert(beginGV->hasInitializer() && "entry array initializer is missing.");
+  // Copy the existing entries. `getAggregateElement` is used because an
+  // all-zero array is a `ConstantAggregateZero`, which has no operands.
+  SmallVector<llvm::Constant *> entries;
+  llvm::Constant *beginInit = beginGV->getInitializer();
+  if (auto *oldTy = dyn_cast<llvm::ArrayType>(beginInit->getType())) {
+    for (unsigned i = 0, e = oldTy->getNumElements(); i < e; ++i)
+      entries.push_back(beginInit->getAggregateElement(i));
+  }
+  // Add the new entry.
+  entries.push_back(entry);
+  // Create a global holding the updated set of entries, placed before `endGV`.
+  auto *arrayTy = llvm::ArrayType::get(llvm::offloading::getEntryTy(module),
+                                       entries.size());
+  auto *entryArr = new llvm::GlobalVariable(
+      module, arrayTy, /*isConstant=*/true, llvm::GlobalValue::InternalLinkage,
+      getEntryArrayBegin(module, entries), beginSymId, endGV);
+  // Replace the old entry array variable with the new one, then take its name.
+  beginGV->replaceAllUsesWith(entryArr);
+  beginGV->eraseFromParent();
+  entryArr->setName(beginSymId);
+  // Update the end symbol.
+  endGV->setInitializer(getEntryArrayEnd(module, entryArr, entries.size()));
+  return success();
+}
diff --git a/mlir/unittests/Target/LLVM/CMakeLists.txt b/mlir/unittests/Target/LLVM/CMakeLists.txt
index 6d612548a94c0f..d04f38ddddfacf 100644
--- a/mlir/unittests/Target/LLVM/CMakeLists.txt
+++ b/mlir/unittests/Target/LLVM/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_unittest(MLIRTargetLLVMTests
+  Offload.cpp
   SerializeNVVMTarget.cpp
   SerializeROCDLTarget.cpp
   SerializeToLLVMBitcode.cpp
diff --git a/mlir/unittests/Target/LLVM/Offload.cpp b/mlir/unittests/Target/LLVM/Offload.cpp
new file mode 100644
index 00000000000000..375edc2e9614d3
--- /dev/null
+++ b/mlir/unittests/Target/LLVM/Offload.cpp
@@ -0,0 +1,49 @@
+//===- Offload.cpp ----------------------------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVM/Offload.h"
+#include "llvm/Frontend/Offloading/Utility.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Module.h"
+
+#include "gmock/gmock.h"
+
+using namespace llvm;
+
+TEST(MLIRTarget, OffloadAPI) {
+  using OffloadEntryArray = mlir::LLVM::OffloadHandler::OffloadEntryArray;
+  LLVMContext llvmContext;
+  Module llvmModule("offload", llvmContext);
+  mlir::LLVM::OffloadHandler handler(llvmModule);
+  StringRef suffix = ".mlir";
+  // Check there's no entry array with `.mlir` suffix.
+  OffloadEntryArray entryArray = handler.getEntryArray(suffix);
+  EXPECT_EQ(entryArray, OffloadEntryArray());
+  // Emit the entry array.
+  handler.emitEmptyEntryArray(suffix);
+  // Check both the begin and end symbols now exist for the `.mlir` suffix.
+  entryArray = handler.getEntryArray(suffix);
+  ASSERT_NE(entryArray.first, nullptr);
+  ASSERT_NE(entryArray.second, nullptr);
+  // Check the empty array has a zero initializer.
+  auto *zeroInitializer = dyn_cast_or_null<ConstantAggregateZero>(
+      entryArray.first->getInitializer());
+  ASSERT_NE(zeroInitializer, nullptr);
+  // Insert a zero-initialized entry.
+  auto emptyEntry =
+      ConstantAggregateZero::get(offloading::getEntryTy(llvmModule));
+  ASSERT_TRUE(succeeded(handler.insertOffloadEntry(suffix, emptyEntry)));
+  // Check the entry array with `.mlir` suffix now has exactly one element.
+  entryArray = handler.getEntryArray(suffix);
+  ASSERT_NE(entryArray.first, nullptr);
+  Constant *arrayInitializer = entryArray.first->getInitializer();
+  ASSERT_NE(arrayInitializer, nullptr);
+  auto *arrayTy = dyn_cast_or_null<ArrayType>(arrayInitializer->getType());
+  ASSERT_NE(arrayTy, nullptr);
+  EXPECT_EQ(arrayTy->getNumElements(), 1u);
+}

>From 64a38946f2d72b4e1859354e8f53a1297622f2b6 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Sun, 14 Jan 2024 22:52:48 +0000
Subject: [PATCH 3/4] Base commit, PR #78057

---
 .../tools/clang-linker-wrapper/CMakeLists.txt |   1 -
 .../ClangLinkerWrapper.cpp                    |   8 +-
 .../clang-linker-wrapper/OffloadWrapper.h     |  28 ----
 .../llvm/Frontend/Offloading/OffloadWrapper.h |  55 ++++++++
 .../llvm/Frontend/Offloading/Utility.h        |   6 +
 llvm/lib/Frontend/Offloading/CMakeLists.txt   |   2 +
 .../Frontend/Offloading}/OffloadWrapper.cpp   | 124 +++++++++++-------
 llvm/lib/Frontend/Offloading/Utility.cpp      |  21 ++-
 8 files changed, 161 insertions(+), 84 deletions(-)
 delete mode 100644 clang/tools/clang-linker-wrapper/OffloadWrapper.h
 create mode 100644 llvm/include/llvm/Frontend/Offloading/OffloadWrapper.h
 rename {clang/tools/clang-linker-wrapper => llvm/lib/Frontend/Offloading}/OffloadWrapper.cpp (84%)

diff --git a/clang/tools/clang-linker-wrapper/CMakeLists.txt b/clang/tools/clang-linker-wrapper/CMakeLists.txt
index 744026a37b22c0..5556869affaa62 100644
--- a/clang/tools/clang-linker-wrapper/CMakeLists.txt
+++ b/clang/tools/clang-linker-wrapper/CMakeLists.txt
@@ -28,7 +28,6 @@ endif()
 
 add_clang_tool(clang-linker-wrapper
   ClangLinkerWrapper.cpp
-  OffloadWrapper.cpp
 
   DEPENDS
   ${tablegen_deps}
diff --git a/clang/tools/clang-linker-wrapper/ClangLinkerWrapper.cpp b/clang/tools/clang-linker-wrapper/ClangLinkerWrapper.cpp
index 122ba1998eb83f..c30d66821dae4e 100644
--- a/clang/tools/clang-linker-wrapper/ClangLinkerWrapper.cpp
+++ b/clang/tools/clang-linker-wrapper/ClangLinkerWrapper.cpp
@@ -14,11 +14,11 @@
 //
 //===---------------------------------------------------------------------===//
 
-#include "OffloadWrapper.h"
 #include "clang/Basic/Version.h"
 #include "llvm/BinaryFormat/Magic.h"
 #include "llvm/Bitcode/BitcodeWriter.h"
 #include "llvm/CodeGen/CommandFlags.h"
+#include "llvm/Frontend/Offloading/OffloadWrapper.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DiagnosticPrinter.h"
 #include "llvm/IR/Module.h"
@@ -906,15 +906,15 @@ wrapDeviceImages(ArrayRef<std::unique_ptr<MemoryBuffer>> Buffers,
 
   switch (Kind) {
   case OFK_OpenMP:
-    if (Error Err = wrapOpenMPBinaries(M, BuffersToWrap))
+    if (Error Err = offloading::wrapOpenMPBinaries(M, BuffersToWrap))
       return std::move(Err);
     break;
   case OFK_Cuda:
-    if (Error Err = wrapCudaBinary(M, BuffersToWrap.front()))
+    if (Error Err = offloading::wrapCudaBinary(M, BuffersToWrap.front()))
       return std::move(Err);
     break;
   case OFK_HIP:
-    if (Error Err = wrapHIPBinary(M, BuffersToWrap.front()))
+    if (Error Err = offloading::wrapHIPBinary(M, BuffersToWrap.front()))
       return std::move(Err);
     break;
   default:
diff --git a/clang/tools/clang-linker-wrapper/OffloadWrapper.h b/clang/tools/clang-linker-wrapper/OffloadWrapper.h
deleted file mode 100644
index 679333975b2120..00000000000000
--- a/clang/tools/clang-linker-wrapper/OffloadWrapper.h
+++ /dev/null
@@ -1,28 +0,0 @@
-//===- OffloadWrapper.h --r-------------------------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef LLVM_CLANG_TOOLS_CLANG_LINKER_WRAPPER_OFFLOAD_WRAPPER_H
-#define LLVM_CLANG_TOOLS_CLANG_LINKER_WRAPPER_OFFLOAD_WRAPPER_H
-
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/IR/Module.h"
-
-/// Wraps the input device images into the module \p M as global symbols and
-/// registers the images with the OpenMP Offloading runtime libomptarget.
-llvm::Error wrapOpenMPBinaries(llvm::Module &M,
-                               llvm::ArrayRef<llvm::ArrayRef<char>> Images);
-
-/// Wraps the input fatbinary image into the module \p M as global symbols and
-/// registers the images with the CUDA runtime.
-llvm::Error wrapCudaBinary(llvm::Module &M, llvm::ArrayRef<char> Images);
-
-/// Wraps the input bundled image into the module \p M as global symbols and
-/// registers the images with the HIP runtime.
-llvm::Error wrapHIPBinary(llvm::Module &M, llvm::ArrayRef<char> Images);
-
-#endif
diff --git a/llvm/include/llvm/Frontend/Offloading/OffloadWrapper.h b/llvm/include/llvm/Frontend/Offloading/OffloadWrapper.h
new file mode 100644
index 00000000000000..f6ab1f475cdb90
--- /dev/null
+++ b/llvm/include/llvm/Frontend/Offloading/OffloadWrapper.h
@@ -0,0 +1,55 @@
+//===- OffloadWrapper.h ----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_FRONTEND_OFFLOADING_OFFLOADWRAPPER_H
+#define LLVM_FRONTEND_OFFLOADING_OFFLOADWRAPPER_H
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/IR/Module.h"
+
+namespace llvm {
+namespace offloading {
+using EntryArrayTy = std::pair<GlobalVariable *, GlobalVariable *>;
+/// Wraps the input device images into the module \p M as global symbols and
+/// registers the images with the OpenMP Offloading runtime libomptarget.
+/// \param EntryArray Optional pair pointing to the `__start` and `__stop`
+/// symbols holding the `__tgt_offload_entry` array.
+/// \param Suffix An optional suffix appended to the emitted symbols.
+llvm::Error
+wrapOpenMPBinaries(llvm::Module &M, llvm::ArrayRef<llvm::ArrayRef<char>> Images,
+                   std::optional<EntryArrayTy> EntryArray = std::nullopt,
+                   llvm::StringRef Suffix = "");
+
+/// Wraps the input fatbinary image into the module \p M as global symbols and
+/// registers the images with the CUDA runtime.
+/// \param EntryArray Optional pair pointing to the `__start` and `__stop`
+/// symbols holding the `__tgt_offload_entry` array.
+/// \param Suffix An optional suffix appended to the emitted symbols.
+/// \param EmitSurfacesAndTextures Whether to emit surface and texture
+/// registration code. It defaults to true.
+llvm::Error
+wrapCudaBinary(llvm::Module &M, llvm::ArrayRef<char> Images,
+               std::optional<EntryArrayTy> EntryArray = std::nullopt,
+               llvm::StringRef Suffix = "",
+               bool EmitSurfacesAndTextures = true);
+
+/// Wraps the input bundled image into the module \p M as global symbols and
+/// registers the images with the HIP runtime.
+/// \param EntryArray Optional pair pointing to the `__start` and `__stop`
+/// symbols holding the `__tgt_offload_entry` array.
+/// \param Suffix An optional suffix appended to the emitted symbols.
+/// \param EmitSurfacesAndTextures Whether to emit surface and texture
+/// registration code. It defaults to true.
+llvm::Error wrapHIPBinary(llvm::Module &M, llvm::ArrayRef<char> Images,
+                          std::optional<EntryArrayTy> EntryArray = std::nullopt,
+                          llvm::StringRef Suffix = "",
+                          bool EmitSurfacesAndTextures = true);
+} // namespace offloading
+} // namespace llvm
+
+#endif // LLVM_FRONTEND_OFFLOADING_OFFLOADWRAPPER_H
diff --git a/llvm/include/llvm/Frontend/Offloading/Utility.h b/llvm/include/llvm/Frontend/Offloading/Utility.h
index 520c192996a066..f54dd7ba7ab45f 100644
--- a/llvm/include/llvm/Frontend/Offloading/Utility.h
+++ b/llvm/include/llvm/Frontend/Offloading/Utility.h
@@ -61,6 +61,12 @@ StructType *getEntryTy(Module &M);
 void emitOffloadingEntry(Module &M, Constant *Addr, StringRef Name,
                          uint64_t Size, int32_t Flags, int32_t Data,
                          StringRef SectionName);
+/// Create a constant struct initializer used to register this global at
+/// runtime.
+/// \return the constant struct and the global variable holding the symbol name.
+std::pair<Constant *, GlobalVariable *>
+getOffloadingEntryInitializer(Module &M, Constant *Addr, StringRef Name,
+                              uint64_t Size, int32_t Flags, int32_t Data);
 
 /// Creates a pair of globals used to iterate the array of offloading entries by
 /// accessing the section variables provided by the linker.
diff --git a/llvm/lib/Frontend/Offloading/CMakeLists.txt b/llvm/lib/Frontend/Offloading/CMakeLists.txt
index 2d0117c9e10059..16e0dcfa0e90d6 100644
--- a/llvm/lib/Frontend/Offloading/CMakeLists.txt
+++ b/llvm/lib/Frontend/Offloading/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_llvm_component_library(LLVMFrontendOffloading
   Utility.cpp
+  OffloadWrapper.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${LLVM_MAIN_INCLUDE_DIR}/llvm/Frontend
@@ -9,6 +10,7 @@ add_llvm_component_library(LLVMFrontendOffloading
 
   LINK_COMPONENTS
   Core
+  BinaryFormat
   Support
   TransformUtils
   TargetParser
diff --git a/clang/tools/clang-linker-wrapper/OffloadWrapper.cpp b/llvm/lib/Frontend/Offloading/OffloadWrapper.cpp
similarity index 84%
rename from clang/tools/clang-linker-wrapper/OffloadWrapper.cpp
rename to llvm/lib/Frontend/Offloading/OffloadWrapper.cpp
index 161374ae555233..2cc5e110510abd 100644
--- a/clang/tools/clang-linker-wrapper/OffloadWrapper.cpp
+++ b/llvm/lib/Frontend/Offloading/OffloadWrapper.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "OffloadWrapper.h"
+#include "llvm/Frontend/Offloading/OffloadWrapper.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/BinaryFormat/Magic.h"
 #include "llvm/Frontend/Offloading/Utility.h"
@@ -21,6 +21,7 @@
 #include "llvm/Transforms/Utils/ModuleUtils.h"
 
 using namespace llvm;
+using namespace llvm::offloading;
 
 namespace {
 /// Magic number that begins the section containing the CUDA fatbinary.
@@ -110,10 +111,10 @@ PointerType *getBinDescPtrTy(Module &M) {
 /// };
 ///
 /// Global variable that represents BinDesc is returned.
-GlobalVariable *createBinDesc(Module &M, ArrayRef<ArrayRef<char>> Bufs) {
+GlobalVariable *createBinDesc(Module &M, ArrayRef<ArrayRef<char>> Bufs,
+                              EntryArrayTy EntryArray, StringRef Suffix) {
   LLVMContext &C = M.getContext();
-  auto [EntriesB, EntriesE] =
-      offloading::getOffloadEntryArray(M, "omp_offloading_entries");
+  auto [EntriesB, EntriesE] = EntryArray;
 
   auto *Zero = ConstantInt::get(getSizeTTy(M), 0u);
   Constant *ZeroZero[] = {Zero, Zero};
@@ -126,7 +127,7 @@ GlobalVariable *createBinDesc(Module &M, ArrayRef<ArrayRef<char>> Bufs) {
     auto *Data = ConstantDataArray::get(C, Buf);
     auto *Image = new GlobalVariable(M, Data->getType(), /*isConstant=*/true,
                                      GlobalVariable::InternalLinkage, Data,
-                                     ".omp_offloading.device_image");
+                                     ".omp_offloading.device_image" + Suffix);
     Image->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
     Image->setSection(".llvm.offloading");
     Image->setAlignment(Align(object::OffloadBinary::getAlignment()));
@@ -166,7 +167,7 @@ GlobalVariable *createBinDesc(Module &M, ArrayRef<ArrayRef<char>> Bufs) {
   auto *Images =
       new GlobalVariable(M, ImagesData->getType(), /*isConstant*/ true,
                          GlobalValue::InternalLinkage, ImagesData,
-                         ".omp_offloading.device_images");
+                         ".omp_offloading.device_images" + Suffix);
   Images->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
 
   auto *ImagesB =
@@ -180,14 +181,15 @@ GlobalVariable *createBinDesc(Module &M, ArrayRef<ArrayRef<char>> Bufs) {
 
   return new GlobalVariable(M, DescInit->getType(), /*isConstant*/ true,
                             GlobalValue::InternalLinkage, DescInit,
-                            ".omp_offloading.descriptor");
+                            ".omp_offloading.descriptor" + Suffix);
 }
 
-void createRegisterFunction(Module &M, GlobalVariable *BinDesc) {
+void createRegisterFunction(Module &M, GlobalVariable *BinDesc,
+                            StringRef Suffix) {
   LLVMContext &C = M.getContext();
   auto *FuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
   auto *Func = Function::Create(FuncTy, GlobalValue::InternalLinkage,
-                                ".omp_offloading.descriptor_reg", &M);
+                                ".omp_offloading.descriptor_reg" + Suffix, &M);
   Func->setSection(".text.startup");
 
   // Get __tgt_register_lib function declaration.
@@ -210,11 +212,13 @@ void createRegisterFunction(Module &M, GlobalVariable *BinDesc) {
   appendToGlobalCtors(M, Func, /*Priority*/ 1);
 }
 
-void createUnregisterFunction(Module &M, GlobalVariable *BinDesc) {
+void createUnregisterFunction(Module &M, GlobalVariable *BinDesc,
+                              StringRef Suffix) {
   LLVMContext &C = M.getContext();
   auto *FuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
-  auto *Func = Function::Create(FuncTy, GlobalValue::InternalLinkage,
-                                ".omp_offloading.descriptor_unreg", &M);
+  auto *Func =
+      Function::Create(FuncTy, GlobalValue::InternalLinkage,
+                       ".omp_offloading.descriptor_unreg" + Suffix, &M);
   Func->setSection(".text.startup");
 
   // Get __tgt_unregister_lib function declaration.
@@ -251,7 +255,8 @@ StructType *getFatbinWrapperTy(Module &M) {
 
 /// Embed the image \p Image into the module \p M so it can be found by the
 /// runtime.
-GlobalVariable *createFatbinDesc(Module &M, ArrayRef<char> Image, bool IsHIP) {
+GlobalVariable *createFatbinDesc(Module &M, ArrayRef<char> Image, bool IsHIP,
+                                 StringRef Suffix) {
   LLVMContext &C = M.getContext();
   llvm::Type *Int8PtrTy = PointerType::getUnqual(C);
   llvm::Triple Triple = llvm::Triple(M.getTargetTriple());
@@ -263,7 +268,7 @@ GlobalVariable *createFatbinDesc(Module &M, ArrayRef<char> Image, bool IsHIP) {
   auto *Data = ConstantDataArray::get(C, Image);
   auto *Fatbin = new GlobalVariable(M, Data->getType(), /*isConstant*/ true,
                                     GlobalVariable::InternalLinkage, Data,
-                                    ".fatbin_image");
+                                    ".fatbin_image" + Suffix);
   Fatbin->setSection(FatbinConstantSection);
 
   // Create the fatbinary wrapper
@@ -282,7 +287,7 @@ GlobalVariable *createFatbinDesc(Module &M, ArrayRef<char> Image, bool IsHIP) {
   auto *FatbinDesc =
       new GlobalVariable(M, getFatbinWrapperTy(M),
                          /*isConstant*/ true, GlobalValue::InternalLinkage,
-                         FatbinInitializer, ".fatbin_wrapper");
+                         FatbinInitializer, ".fatbin_wrapper" + Suffix);
   FatbinDesc->setSection(FatbinWrapperSection);
   FatbinDesc->setAlignment(Align(8));
 
@@ -312,10 +317,12 @@ GlobalVariable *createFatbinDesc(Module &M, ArrayRef<char> Image, bool IsHIP) {
 ///                         0, entry->size, 0, 0);
 ///   }
 /// }
-Function *createRegisterGlobalsFunction(Module &M, bool IsHIP) {
+Function *createRegisterGlobalsFunction(Module &M, bool IsHIP,
+                                        EntryArrayTy EntryArray,
+                                        StringRef Suffix,
+                                        bool EmitSurfacesAndTextures) {
   LLVMContext &C = M.getContext();
-  auto [EntriesB, EntriesE] = offloading::getOffloadEntryArray(
-      M, IsHIP ? "hip_offloading_entries" : "cuda_offloading_entries");
+  auto [EntriesB, EntriesE] = EntryArray;
 
   // Get the __cudaRegisterFunction function declaration.
   PointerType *Int8PtrTy = PointerType::get(C, 0);
@@ -339,7 +346,7 @@ Function *createRegisterGlobalsFunction(Module &M, bool IsHIP) {
       IsHIP ? "__hipRegisterVar" : "__cudaRegisterVar", RegVarTy);
 
   // Get the __cudaRegisterSurface function declaration.
-  auto *RegSurfaceTy =
+  FunctionType *RegSurfaceTy =
       FunctionType::get(Type::getVoidTy(C),
                         {Int8PtrPtrTy, Int8PtrTy, Int8PtrTy, Int8PtrTy,
                          Type::getInt32Ty(C), Type::getInt32Ty(C)},
@@ -348,7 +355,7 @@ Function *createRegisterGlobalsFunction(Module &M, bool IsHIP) {
       IsHIP ? "__hipRegisterSurface" : "__cudaRegisterSurface", RegSurfaceTy);
 
   // Get the __cudaRegisterTexture function declaration.
-  auto *RegTextureTy = FunctionType::get(
+  FunctionType *RegTextureTy = FunctionType::get(
       Type::getVoidTy(C),
       {Int8PtrPtrTy, Int8PtrTy, Int8PtrTy, Int8PtrTy, Type::getInt32Ty(C),
        Type::getInt32Ty(C), Type::getInt32Ty(C)},
@@ -454,19 +461,20 @@ Function *createRegisterGlobalsFunction(Module &M, bool IsHIP) {
   Builder.CreateBr(IfEndBB);
   Switch->addCase(Builder.getInt32(llvm::offloading::OffloadGlobalManagedEntry),
                   SwManagedBB);
-
   // Create surface variable registration code.
   Builder.SetInsertPoint(SwSurfaceBB);
-  Builder.CreateCall(
-      RegSurface, {RegGlobalsFn->arg_begin(), Addr, Name, Name, Data, Extern});
+  if (EmitSurfacesAndTextures)
+    Builder.CreateCall(RegSurface, {RegGlobalsFn->arg_begin(), Addr, Name, Name,
+                                    Data, Extern});
   Builder.CreateBr(IfEndBB);
   Switch->addCase(Builder.getInt32(llvm::offloading::OffloadGlobalSurfaceEntry),
                   SwSurfaceBB);
 
   // Create texture variable registration code.
   Builder.SetInsertPoint(SwTextureBB);
-  Builder.CreateCall(RegTexture, {RegGlobalsFn->arg_begin(), Addr, Name, Name,
-                                  Data, Normalized, Extern});
+  if (EmitSurfacesAndTextures)
+    Builder.CreateCall(RegTexture, {RegGlobalsFn->arg_begin(), Addr, Name, Name,
+                                    Data, Normalized, Extern});
   Builder.CreateBr(IfEndBB);
   Switch->addCase(Builder.getInt32(llvm::offloading::OffloadGlobalTextureEntry),
                   SwTextureBB);
@@ -497,18 +505,21 @@ Function *createRegisterGlobalsFunction(Module &M, bool IsHIP) {
 // Create the constructor and destructor to register the fatbinary with the CUDA
 // runtime.
 void createRegisterFatbinFunction(Module &M, GlobalVariable *FatbinDesc,
-                                  bool IsHIP) {
+                                  bool IsHIP,
+                                  std::optional<EntryArrayTy> EntryArrayOpt,
+                                  StringRef Suffix,
+                                  bool EmitSurfacesAndTextures) {
   LLVMContext &C = M.getContext();
   auto *CtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
-  auto *CtorFunc =
-      Function::Create(CtorFuncTy, GlobalValue::InternalLinkage,
-                       IsHIP ? ".hip.fatbin_reg" : ".cuda.fatbin_reg", &M);
+  auto *CtorFunc = Function::Create(
+      CtorFuncTy, GlobalValue::InternalLinkage,
+      (IsHIP ? ".hip.fatbin_reg" : ".cuda.fatbin_reg") + Suffix, &M);
   CtorFunc->setSection(".text.startup");
 
   auto *DtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
-  auto *DtorFunc =
-      Function::Create(DtorFuncTy, GlobalValue::InternalLinkage,
-                       IsHIP ? ".hip.fatbin_unreg" : ".cuda.fatbin_unreg", &M);
+  auto *DtorFunc = Function::Create(
+      DtorFuncTy, GlobalValue::InternalLinkage,
+      (IsHIP ? ".hip.fatbin_unreg" : ".cuda.fatbin_unreg") + Suffix, &M);
   DtorFunc->setSection(".text.startup");
 
   auto *PtrTy = PointerType::getUnqual(C);
@@ -536,7 +547,7 @@ void createRegisterFatbinFunction(Module &M, GlobalVariable *FatbinDesc,
   auto *BinaryHandleGlobal = new llvm::GlobalVariable(
       M, PtrTy, false, llvm::GlobalValue::InternalLinkage,
       llvm::ConstantPointerNull::get(PtrTy),
-      IsHIP ? ".hip.binary_handle" : ".cuda.binary_handle");
+      (IsHIP ? ".hip.binary_handle" : ".cuda.binary_handle") + Suffix);
 
   // Create the constructor to register this image with the runtime.
   IRBuilder<> CtorBuilder(BasicBlock::Create(C, "entry", CtorFunc));
@@ -546,7 +557,16 @@ void createRegisterFatbinFunction(Module &M, GlobalVariable *FatbinDesc,
   CtorBuilder.CreateAlignedStore(
       Handle, BinaryHandleGlobal,
       Align(M.getDataLayout().getPointerTypeSize(PtrTy)));
-  CtorBuilder.CreateCall(createRegisterGlobalsFunction(M, IsHIP), Handle);
+  EntryArrayTy EntryArray =
+      (EntryArrayOpt ? *EntryArrayOpt
+                     : (IsHIP ? offloading::getOffloadEntryArray(
+                                    M, "hip_offloading_entries")
+                              : offloading::getOffloadEntryArray(
+                                    M, "cuda_offloading_entries")));
+  CtorBuilder.CreateCall(createRegisterGlobalsFunction(M, IsHIP, EntryArray,
+                                                       Suffix,
+                                                       EmitSurfacesAndTextures),
+                         Handle);
   if (!IsHIP)
     CtorBuilder.CreateCall(RegFatbinEnd, Handle);
   CtorBuilder.CreateCall(AtExit, DtorFunc);
@@ -565,35 +585,49 @@ void createRegisterFatbinFunction(Module &M, GlobalVariable *FatbinDesc,
   // Add this function to constructors.
   appendToGlobalCtors(M, CtorFunc, /*Priority*/ 1);
 }
-
 } // namespace
 
-Error wrapOpenMPBinaries(Module &M, ArrayRef<ArrayRef<char>> Images) {
-  GlobalVariable *Desc = createBinDesc(M, Images);
+Error offloading::wrapOpenMPBinaries(Module &M, ArrayRef<ArrayRef<char>> Images,
+                                     std::optional<EntryArrayTy> EntryArray,
+                                     llvm::StringRef Suffix) {
+  GlobalVariable *Desc = createBinDesc(
+      M, Images,
+      EntryArray
+          ? *EntryArray
+          : offloading::getOffloadEntryArray(M, "omp_offloading_entries"),
+      Suffix);
   if (!Desc)
     return createStringError(inconvertibleErrorCode(),
                              "No binary descriptors created.");
-  createRegisterFunction(M, Desc);
-  createUnregisterFunction(M, Desc);
+  createRegisterFunction(M, Desc, Suffix);
+  createUnregisterFunction(M, Desc, Suffix);
   return Error::success();
 }
 
-Error wrapCudaBinary(Module &M, ArrayRef<char> Image) {
-  GlobalVariable *Desc = createFatbinDesc(M, Image, /* IsHIP */ false);
+Error offloading::wrapCudaBinary(Module &M, ArrayRef<char> Image,
+                                 std::optional<EntryArrayTy> EntryArray,
+                                 llvm::StringRef Suffix,
+                                 bool EmitSurfacesAndTextures) {
+  GlobalVariable *Desc = createFatbinDesc(M, Image, /*IsHip=*/false, Suffix);
   if (!Desc)
     return createStringError(inconvertibleErrorCode(),
                              "No fatinbary section created.");
 
-  createRegisterFatbinFunction(M, Desc, /* IsHIP */ false);
+  createRegisterFatbinFunction(M, Desc, /*IsHip=*/false, EntryArray, Suffix,
+                               EmitSurfacesAndTextures);
   return Error::success();
 }
 
-Error wrapHIPBinary(Module &M, ArrayRef<char> Image) {
-  GlobalVariable *Desc = createFatbinDesc(M, Image, /* IsHIP */ true);
+Error offloading::wrapHIPBinary(Module &M, ArrayRef<char> Image,
+                                std::optional<EntryArrayTy> EntryArray,
+                                llvm::StringRef Suffix,
+                                bool EmitSurfacesAndTextures) {
+  GlobalVariable *Desc = createFatbinDesc(M, Image, /*IsHip=*/true, Suffix);
   if (!Desc)
     return createStringError(inconvertibleErrorCode(),
                              "No fatinbary section created.");
 
-  createRegisterFatbinFunction(M, Desc, /* IsHIP */ true);
+  createRegisterFatbinFunction(M, Desc, /*IsHip=*/true, EntryArray, Suffix,
+                               EmitSurfacesAndTextures);
   return Error::success();
 }
diff --git a/llvm/lib/Frontend/Offloading/Utility.cpp b/llvm/lib/Frontend/Offloading/Utility.cpp
index 25f609517ebeb7..531919bccb94e3 100644
--- a/llvm/lib/Frontend/Offloading/Utility.cpp
+++ b/llvm/lib/Frontend/Offloading/Utility.cpp
@@ -1,4 +1,4 @@
-//===- Utility.cpp ------ Collection of geneirc offloading utilities ------===//
+//===- Utility.cpp ------ Collection of generic offloading utilities ------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -28,11 +28,10 @@ StructType *offloading::getEntryTy(Module &M) {
 }
 
 // TODO: Rework this interface to be more generic.
-void offloading::emitOffloadingEntry(Module &M, Constant *Addr, StringRef Name,
-                                     uint64_t Size, int32_t Flags, int32_t Data,
-                                     StringRef SectionName) {
-  llvm::Triple Triple(M.getTargetTriple());
-
+std::pair<Constant *, GlobalVariable *>
+offloading::getOffloadingEntryInitializer(Module &M, Constant *Addr,
+                                          StringRef Name, uint64_t Size,
+                                          int32_t Flags, int32_t Data) {
   Type *Int8PtrTy = PointerType::getUnqual(M.getContext());
   Type *Int32Ty = Type::getInt32Ty(M.getContext());
   Type *SizeTy = M.getDataLayout().getIntPtrType(M.getContext());
@@ -54,6 +53,16 @@ void offloading::emitOffloadingEntry(Module &M, Constant *Addr, StringRef Name,
       ConstantInt::get(Int32Ty, Data),
   };
   Constant *EntryInitializer = ConstantStruct::get(getEntryTy(M), EntryData);
+  return {EntryInitializer, Str};
+}
+
+void offloading::emitOffloadingEntry(Module &M, Constant *Addr, StringRef Name,
+                                     uint64_t Size, int32_t Flags, int32_t Data,
+                                     StringRef SectionName) {
+  llvm::Triple Triple(M.getTargetTriple());
+
+  auto [EntryInitializer, NameGV] =
+      getOffloadingEntryInitializer(M, Addr, Name, Size, Flags, Data);
 
   auto *Entry = new GlobalVariable(
       M, getEntryTy(M),

>From b4ea6de32e0b9a26924b5383922b28568d15d719 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Sun, 14 Jan 2024 23:27:27 +0000
Subject: [PATCH 4/4] [mlir][gpu] Add the `OffloadEmbeddingAttr` offloading
 translation attr

This patch adds the offloading translation attribute. This attribute uses LLVM
offloading infrastructure to embed GPU binaries in the IR. At the program start,
the LLVM offloading mechanism registers kernels and variables with the runtime
library: CUDA RT, HIP RT, or LibOMPTarget.
The offloading mechanism relies on the runtime library to dispatch the correct
kernel based on the registered symbols.

This patch is part 3 of 4 in the series introducing the `OffloadEmbeddingAttr`
GPU translation attribute.

Note: Ignore the base commits; those are being reviewed in PRs #78057, #78098,
and #78073.
---
 .../mlir/Dialect/GPU/IR/CompilationAttrs.td   |  35 ++
 .../Target/LLVMIR/Dialect/GPU/CMakeLists.txt  |   5 +-
 ...ttr.cpp => OffloadingTranslationAttrs.cpp} | 434 +++++++++++++++---
 mlir/test/Target/LLVMIR/gpu.mlir              |  83 ++++
 4 files changed, 498 insertions(+), 59 deletions(-)
 rename mlir/lib/Target/LLVMIR/Dialect/GPU/{SelectObjectAttr.cpp => OffloadingTranslationAttrs.cpp} (54%)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td b/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td
index 6659f4a2c58e82..812b72681343b9 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td
@@ -107,4 +107,39 @@ def GPU_SelectObjectAttr : GPU_Attr<"SelectObject", "select_object", [
   let genVerifyDecl = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// GPU LLVM offload attribute.
+//===----------------------------------------------------------------------===//
+def GPU_OffloadOpenMP : I32EnumAttrCase<"OpenMP", 1, "omp">;
+def GPU_OffloadCUDA : I32EnumAttrCase<"CUDA", 2, "cuda">;
+def GPU_OffloadHIP : I32EnumAttrCase<"HIP", 3, "hip">;
+def GPU_OffloadKindEnum : GPU_I32Enum<
+  "OffloadKind", "GPU offload kind", [
+    GPU_OffloadOpenMP,
+    GPU_OffloadCUDA,
+    GPU_OffloadHIP
+  ]>;
+
+def GPU_OffloadEmbeddingAttr : GPU_Attr<"OffloadEmbedding", "offload_embedding", [
+    OffloadingTranslationAttrTrait
+  ]> {
+  let description = [{
+    This GPU offloading handler uses LLVM offloading infrastructure to embed GPU
+    binaries in the IR. At program start, the LLVM offloading mechanism registers
+    kernels and variables with the runtime library: CUDA RT, HIP RT or
+    LibOMPTarget.
+    The offloading mechanism relies on the runtime library to dispatch the
+    correct kernel based on the registered symbols.
+    This offload mechanism requires specifying which runtime is being called;
+    this is done through the `kind` parameter.
+    Example:
+    ```mlir
+    gpu.binary @binary <#gpu.offload_embedding<omp>> [...]
+    gpu.binary @binary <#gpu.offload_embedding<cuda>> [...]
+    ```
+  }];
+  let parameters = (ins "gpu::OffloadKind":$kind);
+  let assemblyFormat = [{ `<` $kind `>` }];
+}
+
 #endif // GPU_COMPILATION_ATTRS
diff --git a/mlir/lib/Target/LLVMIR/Dialect/GPU/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/GPU/CMakeLists.txt
index 11816ff5c2c1f1..b95b1e95a039ba 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/GPU/CMakeLists.txt
@@ -1,14 +1,17 @@
 add_mlir_translation_library(MLIRGPUToLLVMIRTranslation
   GPUToLLVMIRTranslation.cpp
-  SelectObjectAttr.cpp
+  OffloadingTranslationAttrs.cpp
 
   LINK_COMPONENTS
   Core
+  FrontendOffloading
+  Object
 
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRGPUDialect
   MLIRLLVMDialect
   MLIRSupport
+  MLIRTargetLLVM
   MLIRTargetLLVMIRExport
   )
diff --git a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp b/mlir/lib/Target/LLVMIR/Dialect/GPU/OffloadingTranslationAttrs.cpp
similarity index 54%
rename from mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
rename to mlir/lib/Target/LLVMIR/Dialect/GPU/OffloadingTranslationAttrs.cpp
index 0eb33287d608bd..4448b72615e21d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/GPU/OffloadingTranslationAttrs.cpp
@@ -25,6 +25,9 @@
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+// SelectObjectAttr
+//===----------------------------------------------------------------------===//
 namespace {
 // Implementation of the `OffloadingLLVMTranslationAttrInterface` model.
 class SelectObjectAttrImpl
@@ -54,13 +57,6 @@ std::string getBinaryIdentifier(StringRef binaryName) {
 }
 } // namespace
 
-void mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(
-    DialectRegistry &registry) {
-  registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
-    SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*ctx);
-  });
-}
-
 gpu::ObjectAttr
 SelectObjectAttrImpl::getSelectedObject(gpu::BinaryOp op) const {
   ArrayRef<Attribute> objects = op.getObjectsAttr().getValue();
@@ -136,6 +132,9 @@ class LaunchKernel {
   // Get the kernel launch callee.
   FunctionCallee getKernelLaunchFn();
 
+  // Get the kernel RT launch callee.
+  FunctionCallee getKernelRTLaunchFn();
+
   // Get the kernel launch callee.
   FunctionCallee getClusterKernelLaunchFn();
 
@@ -166,9 +165,15 @@ class LaunchKernel {
   // Create the void* kernel array for passing the arguments.
   Value *createKernelArgArray(mlir::gpu::LaunchFuncOp op);
 
+  // Returns a pair containing the function pointer to the kernel and the
+  // pointer to the kernel module.
+  mlir::FailureOr<std::pair<Value *, Value *>>
+  getKernelInfo(mlir::gpu::LaunchFuncOp op, mlir::gpu::ObjectAttr object);
+
   // Create the full kernel launch.
   mlir::LogicalResult createKernelLaunch(mlir::gpu::LaunchFuncOp op,
-                                         mlir::gpu::ObjectAttr object);
+                                         mlir::gpu::ObjectAttr object,
+                                         Value *kernelPtr = nullptr);
 
 private:
   Module &module;
@@ -244,6 +249,16 @@ llvm::FunctionCallee llvm::LaunchKernel::getClusterKernelLaunchFn() {
           false));
 }
 
+llvm::FunctionCallee llvm::LaunchKernel::getKernelRTLaunchFn() {
+  return module.getOrInsertFunction(
+      "mgpuLaunchKernelRT",
+      FunctionType::get(voidTy,
+                        ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy,
+                                          intPtrTy, intPtrTy, intPtrTy, i32Ty,
+                                          ptrTy, ptrTy, ptrTy, i64Ty}),
+                        false));
+}
+
 llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
   return module.getOrInsertFunction(
       "mgpuModuleGetFunction",
@@ -334,46 +349,14 @@ llvm::LaunchKernel::createKernelArgArray(mlir::gpu::LaunchFuncOp op) {
   return argArray;
 }
 
-// Emits LLVM IR to launch a kernel function:
+// Emits LLVM IR to load the kernel module and look up the kernel function:
 // %0 = call %binarygetter
 // %1 = call %moduleLoad(%0)
 // %2 = <see generateKernelNameConstant>
 // %3 = call %moduleGetFunction(%1, %2)
-// %4 = call %streamCreate()
-// %5 = <see generateParamsArray>
-// call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
-// call %streamSynchronize(%4)
-// call %streamDestroy(%4)
-// call %moduleUnload(%1)
-mlir::LogicalResult
-llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
-                                       mlir::gpu::ObjectAttr object) {
-  auto llvmValue = [&](mlir::Value value) -> Value * {
-    Value *v = moduleTranslation.lookupValue(value);
-    assert(v && "Value has not been translated.");
-    return v;
-  };
-
-  // Get grid dimensions.
-  mlir::gpu::KernelDim3 grid = op.getGridSizeOperandValues();
-  Value *gx = llvmValue(grid.x), *gy = llvmValue(grid.y),
-        *gz = llvmValue(grid.z);
-
-  // Get block dimensions.
-  mlir::gpu::KernelDim3 block = op.getBlockSizeOperandValues();
-  Value *bx = llvmValue(block.x), *by = llvmValue(block.y),
-        *bz = llvmValue(block.z);
-
-  // Get dynamic shared memory size.
-  Value *dynamicMemorySize = nullptr;
-  if (mlir::Value dynSz = op.getDynamicSharedMemorySize())
-    dynamicMemorySize = llvmValue(dynSz);
-  else
-    dynamicMemorySize = ConstantInt::get(i32Ty, 0);
-
-  // Create the argument array.
-  Value *argArray = createKernelArgArray(op);
-
+mlir::FailureOr<std::pair<llvm::Value *, llvm::Value *>>
+llvm::LaunchKernel::getKernelInfo(mlir::gpu::LaunchFuncOp op,
+                                  mlir::gpu::ObjectAttr object) {
   // Default JIT optimization level.
   llvm::Constant *optV = llvm::ConstantInt::get(i32Ty, 0);
   // Check if there's an optimization level embedded in the object.
@@ -385,7 +368,6 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
       return op.emitError("the optimization level must be an integer");
     optV = llvm::ConstantInt::get(i32Ty, optLevel.getValue());
   }
-
   // Load the kernel module.
   StringRef moduleName = op.getKernelModuleName().getValue();
   std::string binaryIdentifier = getBinaryIdentifier(moduleName);
@@ -417,6 +399,56 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
       getModuleFunctionFn(),
       {moduleObject,
        getOrCreateFunctionName(moduleName, op.getKernelName().getValue())});
+  return std::pair<Value *, Value *>(moduleFunction, moduleObject);
+}
+
+// Emits LLVM IR to launch a kernel function (%1, %3: see `getKernelInfo`):
+// %4 = call %streamCreate()
+// %5 = <see generateParamsArray>
+// call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
+// call %streamSynchronize(%4)
+// call %streamDestroy(%4)
+// call %moduleUnload(%1)
+mlir::LogicalResult
+llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
+                                       mlir::gpu::ObjectAttr object,
+                                       Value *kernelPtr) {
+  auto llvmValue = [&](mlir::Value value) -> Value * {
+    Value *v = moduleTranslation.lookupValue(value);
+    assert(v && "Value has not been translated.");
+    return v;
+  };
+
+  // Get grid dimensions.
+  mlir::gpu::KernelDim3 grid = op.getGridSizeOperandValues();
+  Value *gx = llvmValue(grid.x), *gy = llvmValue(grid.y),
+        *gz = llvmValue(grid.z);
+
+  // Get block dimensions.
+  mlir::gpu::KernelDim3 block = op.getBlockSizeOperandValues();
+  Value *bx = llvmValue(block.x), *by = llvmValue(block.y),
+        *bz = llvmValue(block.z);
+
+  // Get dynamic shared memory size.
+  Value *dynamicMemorySize = nullptr;
+  if (mlir::Value dynSz = op.getDynamicSharedMemorySize())
+    dynamicMemorySize = llvmValue(dynSz);
+  else
+    dynamicMemorySize = ConstantInt::get(i32Ty, 0);
+
+  // Create the argument array.
+  Value *argArray = createKernelArgArray(op);
+
+  Value *moduleObject = nullptr, *moduleFunction = nullptr;
+
+  if (!kernelPtr) {
+    mlir::FailureOr<std::pair<Value *, Value *>> kernelInfo =
+        getKernelInfo(op, object);
+    if (failed(kernelInfo))
+      return failure();
+    moduleFunction = kernelInfo->first;
+    moduleObject = kernelInfo->second;
+  }
 
   // Get the stream to use for execution. If there's no async object then create
   // a stream to make a synchronous kernel launch.
@@ -436,19 +468,27 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
   Value *nullPtr = ConstantPointerNull::get(ptrTy);
 
   // Launch kernel with clusters if cluster size is specified.
-  if (op.hasClusterSize()) {
-    mlir::gpu::KernelDim3 cluster = op.getClusterSizeOperandValues();
-    Value *cx = llvmValue(cluster.x), *cy = llvmValue(cluster.y),
-          *cz = llvmValue(cluster.z);
-    builder.CreateCall(
-        getClusterKernelLaunchFn(),
-        ArrayRef<Value *>({moduleFunction, cx, cy, cz, gx, gy, gz, bx, by, bz,
-                           dynamicMemorySize, stream, argArray, nullPtr}));
+  if (moduleFunction) {
+    if (op.hasClusterSize()) {
+      mlir::gpu::KernelDim3 cluster = op.getClusterSizeOperandValues();
+      Value *cx = llvmValue(cluster.x), *cy = llvmValue(cluster.y),
+            *cz = llvmValue(cluster.z);
+      builder.CreateCall(
+          getClusterKernelLaunchFn(),
+          ArrayRef<Value *>({moduleFunction, cx, cy, cz, gx, gy, gz, bx, by, bz,
+                             dynamicMemorySize, stream, argArray, nullPtr}));
+    } else {
+      builder.CreateCall(getKernelLaunchFn(),
+                         ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by,
+                                            bz, dynamicMemorySize, stream,
+                                            argArray, nullPtr, paramsCount}));
+    }
   } else {
-    builder.CreateCall(getKernelLaunchFn(),
-                       ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by,
-                                          bz, dynamicMemorySize, stream,
-                                          argArray, nullPtr, paramsCount}));
+    assert(kernelPtr && "invalid kernel pointer");
+    builder.CreateCall(
+        getKernelRTLaunchFn(),
+        ArrayRef<Value *>({kernelPtr, gx, gy, gz, bx, by, bz, dynamicMemorySize,
+                           stream, argArray, nullPtr, paramsCount}));
   }
 
   // Sync & destroy the stream, for synchronous launches.
@@ -458,7 +498,285 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
   }
 
   // Unload the kernel module.
-  builder.CreateCall(getModuleUnloadFn(), {moduleObject});
+  if (moduleObject)
+    builder.CreateCall(getModuleUnloadFn(), {moduleObject});
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// OffloadEmbeddingAttr
+//===----------------------------------------------------------------------===//
+#include "mlir/Target/LLVM/Offload.h"
+#include "llvm/Frontend/Offloading/OffloadWrapper.h"
+#include "llvm/Frontend/Offloading/Utility.h"
+#include "llvm/Object/OffloadBinary.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/ErrorHandling.h"
+
+namespace {
+// Implementation of the `OffloadingLLVMTranslationAttrInterface` model.
+// Both methods cast their generic arguments to the concrete GPU dialect types
+// and delegate the actual work to `OffloadManager`.
+class OffloadEmbeddingAttrImpl
+    : public gpu::OffloadingLLVMTranslationAttrInterface::FallbackModel<
+          OffloadEmbeddingAttrImpl> {
+public:
+  // Translates a `gpu.binary`, embedding the binary into a host LLVM module as
+  // global binary string.
+  // `attribute` is cast to `gpu::OffloadEmbeddingAttr` and `operation` to
+  // `gpu::BinaryOp`. Returns failure if the embedding fails.
+  LogicalResult embedBinary(Attribute attribute, Operation *operation,
+                            llvm::IRBuilderBase &builder,
+                            LLVM::ModuleTranslation &moduleTranslation) const;
+
+  // Translates a `gpu.launch_func` to a sequence of LLVM instructions resulting
+  // in a kernel launch call.
+  // `launchFuncOperation` is cast to `gpu::LaunchFuncOp` and `binaryOperation`
+  // to `gpu::BinaryOp`. Returns failure if the launch cannot be generated.
+  LogicalResult launchKernel(Attribute attribute,
+                             Operation *launchFuncOperation,
+                             Operation *binaryOperation,
+                             llvm::IRBuilderBase &builder,
+                             LLVM::ModuleTranslation &moduleTranslation) const;
+};
+} // namespace
+
+namespace {
+/// Maps a `gpu::CompilationTarget` object format to the corresponding LLVM
+/// offloading image kind.
+llvm::object::ImageKind getImageKind(gpu::CompilationTarget format) {
+  switch (format) {
+  case gpu::CompilationTarget::Offload:
+    return llvm::object::IMG_Bitcode;
+  case gpu::CompilationTarget::Assembly:
+    return llvm::object::IMG_PTX;
+  case gpu::CompilationTarget::Binary:
+    return llvm::object::IMG_Object;
+  case gpu::CompilationTarget::Fatbin:
+    return llvm::object::IMG_Fatbinary;
+  }
+  // Every enumerator returns above. Without this, control could fall off the
+  // end of a non-void function if an out-of-range value were passed in.
+  llvm_unreachable("unknown gpu::CompilationTarget");
+}
+
+/// Maps the GPU dialect `gpu::OffloadKind` to the corresponding LLVM offloading
+/// kind.
+llvm::object::OffloadKind getOffloadKind(gpu::OffloadKind offloadKind) {
+  switch (offloadKind) {
+  case gpu::OffloadKind::OpenMP:
+    return llvm::object::OFK_OpenMP;
+  case gpu::OffloadKind::CUDA:
+    return llvm::object::OFK_Cuda;
+  case gpu::OffloadKind::HIP:
+    return llvm::object::OFK_HIP;
+  }
+  // Every enumerator returns above. Without this, control could fall off the
+  // end of a non-void function if an out-of-range value were passed in.
+  llvm_unreachable("unknown gpu::OffloadKind");
+}
+
+using OffloadEntryArray = LLVM::OffloadHandler::OffloadEntryArray;
+
+/// Utility class for embedding binaries and launching kernels using the
+/// offloading attribute. An instance is bound to one `gpu.binary` op, the host
+/// LLVM module and the offload kind selected by the attribute.
+class OffloadManager : public LLVM::OffloadHandler {
+public:
+  OffloadManager(gpu::BinaryOp binaryOp, llvm::Module &module,
+                 gpu::OffloadKind offloadKind)
+      : LLVM::OffloadHandler(module), binaryOp(binaryOp),
+        offloadKind(offloadKind) {}
+
+  /// Embed a GPU binary into a module: the objects of the `gpu.binary` op are
+  /// bundled and then wrapped with the registration code.
+  LogicalResult embedBinary();
+
+  /// Generates the kernel launch call. Emits an error when the offload kind is
+  /// OpenMP.
+  LogicalResult launchKernel(gpu::LaunchFuncOp launchFunc,
+                             llvm::IRBuilderBase &builder,
+                             LLVM::ModuleTranslation &moduleTranslation);
+
+protected:
+  /// Returns the name to be used for the offloading symbols.
+  StringRef getSymbolSuffix();
+
+  /// Emits the offloading entry for `launchFunc`, associating the kernel name
+  /// with `registeredSym`.
+  LogicalResult emitOffloadingEntry(gpu::LaunchFuncOp launchFunc,
+                                    llvm::Constant *registeredSym);
+
+  /// Bundle OpenMP images together: one serialized offload binary per object.
+  SmallVector<std::unique_ptr<llvm::MemoryBuffer>>
+  bundleOpenMP(ArrayRef<Attribute> objects);
+
+  /// Bundle gpu-objects together. TODO: support more than a single object.
+  /// Emits an error if more than one object is given or if the object is
+  /// neither a binary nor a fat-binary.
+  FailureOr<SmallVector<std::unique_ptr<llvm::MemoryBuffer>>>
+  bundleGPU(ArrayRef<Attribute> objects);
+
+  /// Bundle objects depending on the `gpu::OffloadKind`.
+  FailureOr<SmallVector<std::unique_ptr<llvm::MemoryBuffer>>>
+  bundleImages(ArrayRef<Attribute> objects);
+
+  /// Emit registration code and embed the images.
+  LogicalResult wrapImages(llvm::Module &module, ArrayRef<ArrayRef<char>> imgs);
+
+  /// Convert an `ObjectAttr` to an `OffloadingImage`.
+  llvm::object::OffloadBinary::OffloadingImage
+  getOffloadingImage(gpu::ObjectAttr obj);
 
+  /// The `gpu.binary` op being translated; also used to report errors.
+  gpu::BinaryOp binaryOp;
+  /// The offload kind requested by the offloading attribute.
+  gpu::OffloadKind offloadKind;
+};
+} // namespace
+
+/// Builds the `OffloadingImage` describing `obj`: the image and offload kinds,
+/// the target triple and chip, and a copy of the object's contents.
+llvm::object::OffloadBinary::OffloadingImage
+OffloadManager::getOffloadingImage(gpu::ObjectAttr obj) {
+  llvm::object::OffloadBinary::OffloadingImage image{};
+  // The target attribute provides the triple and chip recorded in the image.
+  auto targetInfo = cast<TargetInfoAttrInterface>(obj.getTarget());
+  image.StringData["triple"] = targetInfo.getTargetTriple();
+  image.StringData["arch"] = targetInfo.getTargetChip();
+  image.TheOffloadKind = getOffloadKind(offloadKind);
+  image.TheImageKind = getImageKind(obj.getFormat());
+  // The image holds its own copy of the object's bytes.
+  image.Image =
+      llvm::MemoryBuffer::getMemBufferCopy(obj.getObject().getValue(), "");
+  return image;
+}
+
+/// Serializes every object in `objects` as an LLVM offload binary and returns
+/// one memory buffer per object, in the same order.
+SmallVector<std::unique_ptr<llvm::MemoryBuffer>>
+OffloadManager::bundleOpenMP(ArrayRef<Attribute> objects) {
+  SmallVector<std::unique_ptr<llvm::MemoryBuffer>> bundled;
+  for (Attribute objectAttr : objects) {
+    // Convert the object into an offloading image, then serialize it.
+    llvm::object::OffloadBinary::OffloadingImage image =
+        getOffloadingImage(cast<gpu::ObjectAttr>(objectAttr));
+    bundled.push_back(llvm::MemoryBuffer::getMemBufferCopy(
+        llvm::object::OffloadBinary::write(image)));
+  }
+  return bundled;
+}
+
+/// Returns a single buffer for the only object in `objects`. Emits an error if
+/// there is more than one object, or if the object is neither a binary nor a
+/// fat-binary.
+FailureOr<SmallVector<std::unique_ptr<llvm::MemoryBuffer>>>
+OffloadManager::bundleGPU(ArrayRef<Attribute> objects) {
+  if (objects.size() > 1)
+    return binaryOp.emitError("multiple objects are not yet supported");
+  assert(objects.size() == 1 && "there should be a single object");
+  auto gpuObject = cast<gpu::ObjectAttr>(objects.front());
+  const gpu::CompilationTarget format = gpuObject.getFormat();
+  const bool isSupported = format == gpu::CompilationTarget::Binary ||
+                           format == gpu::CompilationTarget::Fatbin;
+  if (!isSupported)
+    return binaryOp.emitError(
+        "the only supported objects are binaries and fat-binaries.");
+  // Unlike the OpenMP path, this uses `getMemBuffer` rather than
+  // `getMemBufferCopy` on the object's contents.
+  SmallVector<std::unique_ptr<llvm::MemoryBuffer>> buffers;
+  buffers.push_back(
+      llvm::MemoryBuffer::getMemBuffer(gpuObject.getObject().getValue()));
+  return buffers;
+}
+
+/// Bundles `objects` with the strategy matching the offload kind: OpenMP
+/// objects go through `bundleOpenMP`, CUDA and HIP objects through `bundleGPU`.
+FailureOr<SmallVector<std::unique_ptr<llvm::MemoryBuffer>>>
+OffloadManager::bundleImages(ArrayRef<Attribute> objects) {
+  switch (offloadKind) {
+  case gpu::OffloadKind::OpenMP:
+    return bundleOpenMP(objects);
+  case gpu::OffloadKind::CUDA:
+  case gpu::OffloadKind::HIP:
+    return bundleGPU(objects);
+  }
+  // Every enumerator returns above. Without this, control could fall off the
+  // end of a non-void function if an out-of-range value were stored.
+  llvm_unreachable("unknown gpu::OffloadKind");
+}
+
+/// Returns the name of the `gpu.binary` op. It identifies the offload entry
+/// array and is appended (after a `.`) to the symbols emitted by `wrapImages`.
+StringRef OffloadManager::getSymbolSuffix() { return binaryOp.getName(); }
+
+/// Builds the offload entry initializer for the kernel named by `launchFunc`,
+/// associated with `registeredSym`, and inserts it into the entry array of this
+/// binary. Emits an error if the insertion fails.
+LogicalResult
+OffloadManager::emitOffloadingEntry(gpu::LaunchFuncOp launchFunc,
+                                    llvm::Constant *registeredSym) {
+  StringRef kernelName = launchFunc.getKernelName().getValue();
+  // Only the initializer constant (the first element of the pair) is needed.
+  llvm::Constant *entryInit =
+      llvm::offloading::getOffloadingEntryInitializer(module, registeredSym,
+                                                      kernelName, 0, 0, 0)
+          .first;
+  if (succeeded(insertOffloadEntry(getSymbolSuffix(), entryInit)))
+    return success();
+  return binaryOp.emitError("entry array symbols not found");
+}
+
+/// Emits an empty offload entry array for this binary and wraps `imgs` with the
+/// registration code of the runtime selected by the offload kind. On failure an
+/// error is emitted on the `gpu.binary` op, including the reason reported by
+/// the wrapper.
+LogicalResult OffloadManager::wrapImages(llvm::Module &module,
+                                         ArrayRef<ArrayRef<char>> imgs) {
+  // This suffix is appended to all the symbols emitted by the `wrap*` methods.
+  std::string suffix = "." + getSymbolSuffix().str();
+  // Emit an empty entry array.
+  OffloadEntryArray entryArray = emitEmptyEntryArray(getSymbolSuffix());
+  // An `llvm::Error` holding a failure must be consumed before it is destroyed;
+  // `llvm::toString` consumes it and keeps the reason for the diagnostic.
+  switch (offloadKind) {
+  case gpu::OffloadKind::OpenMP:
+    if (llvm::Error error = llvm::offloading::wrapOpenMPBinaries(
+            module, imgs, entryArray, suffix))
+      return binaryOp.emitError("failed wrapping the OpenMP binaries: ")
+             << llvm::toString(std::move(error));
+    return success();
+  case gpu::OffloadKind::CUDA:
+    if (llvm::Error error = llvm::offloading::wrapCudaBinary(
+            module, imgs.front(), entryArray, suffix, false))
+      return binaryOp.emitError("failed wrapping the CUDA binaries: ")
+             << llvm::toString(std::move(error));
+    return success();
+  case gpu::OffloadKind::HIP:
+    if (llvm::Error error = llvm::offloading::wrapHIPBinary(
+            module, imgs.front(), entryArray, suffix, false))
+      return binaryOp.emitError("failed wrapping the HIP binaries: ")
+             << llvm::toString(std::move(error));
+    return success();
+  }
+  // Every enumerator returns above. Without this, control could fall off the
+  // end of a non-void function if an out-of-range value were stored.
+  llvm_unreachable("unknown gpu::OffloadKind");
+}
+
+/// Embeds the binary in two steps: the objects of the `gpu.binary` op are
+/// bundled, then the resulting images are wrapped into the module.
+LogicalResult OffloadManager::embedBinary() {
+  ArrayRef<Attribute> objects = binaryOp.getObjectsAttr().getValue();
+  FailureOr<SmallVector<std::unique_ptr<llvm::MemoryBuffer>>> bundles =
+      bundleImages(objects);
+  if (failed(bundles))
+    return failure();
+  // View each bundled buffer as a raw character range for the wrappers.
+  SmallVector<ArrayRef<char>> images;
+  for (const std::unique_ptr<llvm::MemoryBuffer> &buffer : *bundles)
+    images.emplace_back(buffer->getBufferStart(), buffer->getBufferSize());
+  return wrapImages(module, images);
+}
+
+/// Generates the launch call for `launchFunc`. The kernel is identified by a
+/// host-side symbol named `<binary name>_K<kernel name>`; the symbol and its
+/// offload entry are created the first time the kernel is launched and reused
+/// afterwards. Emits an error when the offload kind is OpenMP.
+LogicalResult
+OffloadManager::launchKernel(gpu::LaunchFuncOp launchFunc,
+                             llvm::IRBuilderBase &builder,
+                             LLVM::ModuleTranslation &moduleTranslation) {
+  // OpenMP kernel launches are handled by the `omp.target` op.
+  if (offloadKind == gpu::OffloadKind::OpenMP)
+    return binaryOp.emitError(
+        "it's invalid to call OpenMP kernels using gpu.launch_func");
+  llvm::Module *hostModule = moduleTranslation.getLLVMModule();
+  // Create or get the symbol to be registered.
+  std::string symbolId =
+      (binaryOp.getName() + "_K" + launchFunc.getKernelName().getValue()).str();
+  // The symbol is created with internal linkage, so the lookup has to opt in
+  // to internal globals. Otherwise it never finds the existing symbol, and
+  // each launch of the same kernel creates a new symbol and a duplicate
+  // offload entry.
+  llvm::Constant *registeredSym =
+      hostModule->getGlobalVariable(symbolId, /*AllowInternal=*/true);
+  if (!registeredSym) {
+    // Create the symbol used to register the kernel with the runtime.
+    registeredSym = new llvm::GlobalVariable(
+        *hostModule, builder.getInt8Ty(), /*isConstant=*/true,
+        llvm::GlobalValue::InternalLinkage, builder.getInt8(0), symbolId);
+    // Emit the offload entry.
+    if (failed(emitOffloadingEntry(launchFunc, registeredSym)))
+      return failure();
+  }
+  return llvm::LaunchKernel(*hostModule, builder, moduleTranslation)
+      .createKernelLaunch(launchFunc, nullptr, registeredSym);
+}
+
+/// Interface entry point for `gpu.binary`: casts the arguments to their
+/// concrete types and lets an `OffloadManager` embed the binary into the host
+/// LLVM module.
+LogicalResult OffloadEmbeddingAttrImpl::embedBinary(
+    Attribute attribute, Operation *operation, llvm::IRBuilderBase &builder,
+    LLVM::ModuleTranslation &moduleTranslation) const {
+  auto binaryOp = mlir::cast<gpu::BinaryOp>(operation);
+  llvm::Module &hostModule = *moduleTranslation.getLLVMModule();
+  gpu::OffloadKind kind =
+      mlir::cast<gpu::OffloadEmbeddingAttr>(attribute).getKind();
+  return OffloadManager(binaryOp, hostModule, kind).embedBinary();
+}
+
+/// Interface entry point for `gpu.launch_func`: casts the arguments to their
+/// concrete types and lets an `OffloadManager` generate the kernel launch.
+LogicalResult OffloadEmbeddingAttrImpl::launchKernel(
+    Attribute attribute, Operation *launchFuncOperation,
+    Operation *binaryOperation, llvm::IRBuilderBase &builder,
+    LLVM::ModuleTranslation &moduleTranslation) const {
+  auto embeddingAttr = mlir::cast<gpu::OffloadEmbeddingAttr>(attribute);
+  OffloadManager manager(mlir::cast<gpu::BinaryOp>(binaryOperation),
+                         *moduleTranslation.getLLVMModule(),
+                         embeddingAttr.getKind());
+  auto launchOp = mlir::cast<gpu::LaunchFuncOp>(launchFuncOperation);
+  if (failed(manager.launchKernel(launchOp, builder, moduleTranslation)))
+    return failure();
   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// Interface registration
+//===----------------------------------------------------------------------===//
+/// Adds a GPU dialect extension to `registry` that attaches the offloading
+/// translation interface models to `SelectObjectAttr` and
+/// `OffloadEmbeddingAttr`.
+void mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(
+    DialectRegistry &registry) {
+  // The leading `+` converts the capture-less lambda to a function pointer.
+  registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
+    SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*ctx);
+    OffloadEmbeddingAttr::attachInterface<OffloadEmbeddingAttrImpl>(*ctx);
+  });
+}
diff --git a/mlir/test/Target/LLVMIR/gpu.mlir b/mlir/test/Target/LLVMIR/gpu.mlir
index 88672bd231df8f..74dfa53558d71f 100644
--- a/mlir/test/Target/LLVMIR/gpu.mlir
+++ b/mlir/test/Target/LLVMIR/gpu.mlir
@@ -101,3 +101,86 @@ module attributes {gpu.container_module} {
     llvm.return
   }
 }
+
+// -----
+
+// Test the `offload_embedding<cuda>` attribute.
+module attributes {gpu.container_module} {
+  // CHECK: @__begin_offload_kernel_module = internal constant [1 x %{{.*}}] [%{{.*}} { ptr @[[KERNEL_SYMBOL:.*]], ptr @[[ENTRY_NAME:.*]], i64 0, i32 0, i32 0 }]
+  // CHECK: @__end_offload_kernel_module = internal constant ptr getelementptr inbounds (%{{.*}}, ptr @__begin_offload_kernel_module, i64 1)
+  // CHECK: @[[FATBIN:.*]] = internal constant [4 x i8] c"BLOB", section ".nv_fatbin"
+  // CHECK: @[[FATBIN_HANDLE:.*]] = internal constant %{{.*}} { i32 1180844977, i32 1, ptr @[[FATBIN]]
+  // CHECK: @llvm.global_ctors = appending global [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 1, ptr @[[REGISTRATION_CTOR:.*]], ptr null }]
+  // CHECK: @[[KERNEL_SYMBOL]] = internal constant i8 0
+  // CHECK-NEXT: @[[ENTRY_NAME]] = internal unnamed_addr constant [7 x i8] c"kernel\00"
+  gpu.binary @kernel_module <#gpu.offload_embedding<cuda>> [#gpu.object<#nvvm.target, bin = "BLOB">]
+  llvm.func @foo() {
+    // CHECK: [[ARGS:%.*]] = alloca %{{.*}}, align 8
+    // CHECK-NEXT: [[ARGS_ARRAY:%.*]] = alloca ptr, i64 2, align 8
+    // CHECK-NEXT: [[ARG0:%.*]] = getelementptr inbounds [[ARGS_TY]], ptr [[ARGS]], i32 0, i32 0
+    // CHECK-NEXT: store i32 32, ptr [[ARG0]], align 4
+    // CHECK-NEXT: %{{.*}} = getelementptr ptr, ptr [[ARGS_ARRAY]], i32 0
+    // CHECK-NEXT: store ptr [[ARG0]], ptr %{{.*}}, align 8
+    // CHECK-NEXT: [[ARG1:%.*]] = getelementptr inbounds [[ARGS_TY]], ptr [[ARGS]], i32 0, i32 1
+    // CHECK-NEXT: store i32 32, ptr [[ARG1]], align 4
+    // CHECK-NEXT: %{{.*}} = getelementptr ptr, ptr [[ARGS_ARRAY]], i32 1
+    // CHECK-NEXT: store ptr [[ARG1]], ptr %{{.*}}, align 8
+    // CHECK-NEXT: [[STREAM:%.*]] = call ptr @mgpuStreamCreate()
+    // CHECK-NEXT: call void @mgpuLaunchKernelRT(ptr @[[KERNEL_SYMBOL]], i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i32 256, ptr [[STREAM]], ptr [[ARGS_ARRAY]], ptr null, i64 2)
+    // CHECK-NEXT: call void @mgpuStreamSynchronize(ptr [[STREAM]])
+    // CHECK-NEXT: call void @mgpuStreamDestroy(ptr [[STREAM]])
+    %0 = llvm.mlir.constant(8 : index) : i64
+    %1 = llvm.mlir.constant(32 : i32) : i32
+    %2 = llvm.mlir.constant(256 : i32) : i32
+    gpu.launch_func @kernel_module::@kernel blocks in (%0, %0, %0) threads in (%0, %0, %0) : i64 dynamic_shared_memory_size %2 args(%1 : i32, %1 : i32)
+    llvm.return
+  }
+  // CHECK: define internal void @[[REGISTRATION_CTOR]]
+  // CHECK: %{{.*}} = call ptr @__cudaRegisterFatBinary(ptr @[[FATBIN_HANDLE]])
+}
+
+// -----
+
+// Test the `offload_embedding<hip>` attribute.
+module attributes {gpu.container_module} {
+  // CHECK: @__begin_offload_kernel_module = internal constant [2 x %{{.*}}] [
+  // CHECK: %{{.*}} { ptr @[[KERNEL_1_SYMBOL:.*]], ptr @[[ENTRY_NAME_1:.*]], i64 0, i32 0, i32 0 },
+  // CHECK: %{{.*}} { ptr @[[KERNEL_2_SYMBOL:.*]], ptr @[[ENTRY_NAME_2:.*]], i64 0, i32 0, i32 0 }]
+  // CHECK: @__end_offload_kernel_module = internal constant ptr getelementptr inbounds (%{{.*}}, ptr @__begin_offload_kernel_module, i64 2)
+  // CHECK: @[[FATBIN:.*]] = internal constant [4 x i8] c"BLOB", section ".hip_fatbin"
+  // CHECK: @[[FATBIN_HANDLE:.*]] = internal constant %{{.*}} { i32 1212764230, i32 1, ptr @[[FATBIN]]
+  // CHECK: @llvm.global_ctors = appending global [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 1, ptr @[[REGISTRATION_CTOR:.*]], ptr null }]
+  // CHECK: @[[KERNEL_1_SYMBOL]] = internal constant i8 0
+  // CHECK-NEXT: @[[ENTRY_NAME_1]] = internal unnamed_addr constant [9 x i8] c"kernel_1\00"
+  // CHECK: @[[KERNEL_2_SYMBOL]] = internal constant i8 0
+  // CHECK-NEXT: @[[ENTRY_NAME_2]] = internal unnamed_addr constant [9 x i8] c"kernel_2\00"
+  gpu.binary @kernel_module <#gpu.offload_embedding<hip>> [#gpu.object<#rocdl.target, bin = "BLOB">]
+  llvm.func @foo() {
+    %0 = llvm.mlir.constant(8 : index) : i64
+    %1 = llvm.mlir.constant(32 : i32) : i32
+    %2 = llvm.mlir.constant(256 : i32) : i32
+    gpu.launch_func @kernel_module::@kernel_1 blocks in (%0, %0, %0) threads in (%0, %0, %0) : i64 dynamic_shared_memory_size %2 args(%1 : i32, %1 : i32)
+    gpu.launch_func @kernel_module::@kernel_2 blocks in (%0, %0, %0) threads in (%0, %0, %0) : i64 dynamic_shared_memory_size %2 args(%1 : i32, %1 : i32)
+    llvm.return
+  }
+  // CHECK: define internal void @[[REGISTRATION_CTOR]]
+  // CHECK: %{{.*}} = call ptr @__hipRegisterFatBinary(ptr @[[FATBIN_HANDLE]])
+}
+
+// -----
+
+// Test the `offload_embedding<omp>` attribute.
+module attributes {gpu.container_module} {
+  // CHECK: @__begin_offload_kernel_module = internal constant [0 x %{{.*}}] zeroinitializer
+  // CHECK: @__end_offload_kernel_module = internal constant ptr @__begin_offload_kernel_module
+  // CHECK: @[[BINARY:.*]] = internal unnamed_addr constant [{{.*}} x i8] c"{{.*}}", section ".llvm.offloading", align 8
+  // CHECK: @[[BINARIES:.*]] = internal unnamed_addr constant [1 x %{{.*}}] [%{{.*}} { ptr getelementptr inbounds ([{{.*}} x i8], ptr @[[BINARY]], i64 0, i64 {{.*}}), ptr getelementptr inbounds ([{{.*}} x i8], ptr @[[BINARY]], i64 0, i64 {{.*}}), ptr @__begin_offload_kernel_module, ptr @__end_offload_kernel_module }]
+  // CHECK: @[[DESCRIPTOR:.*]] = internal constant %{{.*}} { i32 1, ptr @[[BINARIES]], ptr @__begin_offload_kernel_module, ptr @__end_offload_kernel_module }
+  // CHECK: @llvm.global_ctors = appending global [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 1, ptr @[[REGISTRATION_CTOR:.*]], ptr null }]
+  // CHECK: @llvm.global_dtors = appending global [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 1, ptr @[[REGISTRATION_DTOR:.*]], ptr null }]
+  gpu.binary @kernel_module <#gpu.offload_embedding<omp>> [#gpu.object<#rocdl.target, bin = "BLOB">]
+  // CHECK: define internal void @[[REGISTRATION_CTOR]]
+  // CHECK: call {{.*}} @__tgt_register_lib(ptr @[[DESCRIPTOR]])
+  // CHECK: define internal void @[[REGISTRATION_DTOR]]
+  // CHECK: call {{.*}} @__tgt_unregister_lib(ptr @[[DESCRIPTOR]])
+}



More information about the cfe-commits mailing list