[clang] f49d576 - [CUDA] Add wrapper code generation for registering CUDA images

Joseph Huber via cfe-commits cfe-commits at lists.llvm.org
Wed May 11 04:30:42 PDT 2022


Author: Joseph Huber
Date: 2022-05-11T07:30:25-04:00
New Revision: f49d576a882da81292b5730af442fa38899af312

URL: https://github.com/llvm/llvm-project/commit/f49d576a882da81292b5730af442fa38899af312
DIFF: https://github.com/llvm/llvm-project/commit/f49d576a882da81292b5730af442fa38899af312.diff

LOG: [CUDA] Add wrapper code generation for registering CUDA images

This patch adds the necessary code generation to create the wrapper code
that registers all the globals in CUDA. We create the necessary
functions and iterate through the offloading entries between
`__start_cuda_offloading_entries` and `__stop_cuda_offloading_entries` to
find which globals must be registered. This is very similar to the code
generation done currently
in Clang for non-rdc builds, but here we are registering a fully linked
fatbinary and finding the globals via the above sections.

With this we should be able to fully support basic RDC / LTO building of CUDA
code.

It's also worth noting that this does not include the necessary PTX to JIT the
image, so to use this support the offloading architecture must match the
system's architecture.

Depends on D123810

Reviewed By: tra

Differential Revision: https://reviews.llvm.org/D123812

Added: 
    

Modified: 
    clang/test/Driver/linker-wrapper-image.c
    clang/test/Driver/linker-wrapper.c
    clang/tools/clang-linker-wrapper/ClangLinkerWrapper.cpp
    clang/tools/clang-linker-wrapper/OffloadWrapper.cpp

Removed: 
    


################################################################################
diff  --git a/clang/test/Driver/linker-wrapper-image.c b/clang/test/Driver/linker-wrapper-image.c
index c10b4c5a45e03..6badd97e2eb16 100644
--- a/clang/test/Driver/linker-wrapper-image.c
+++ b/clang/test/Driver/linker-wrapper-image.c
@@ -27,3 +27,65 @@
 // OPENMP-NEXT:   call void @__tgt_unregister_lib(%__tgt_bin_desc* @.omp_offloading.descriptor)
 // OPENMP-NEXT:   ret void
 // OPENMP-NEXT: }
+
+// RUN: %clang -cc1 %s -triple x86_64-unknown-linux-gnu -emit-obj -o %t.o \
+// RUN:   -fembed-offload-object=%S/Inputs/dummy-elf.o,cuda,nvptx64-nvida-cuda,sm_70
+// RUN: clang-linker-wrapper --print-wrapped-module --dry-run --host-triple x86_64-unknown-linux-gnu \
+// RUN:   -linker-path /usr/bin/ld -- %t.o -o a.out 2>&1 | FileCheck %s --check-prefix=CUDA
+
+// CUDA: @.fatbin_image = internal constant [0 x i8] zeroinitializer, section ".nv_fatbin"
+// CUDA-NEXT: @.fatbin_wrapper = internal constant %fatbin_wrapper { i32 1180844977, i32 1, i8* getelementptr inbounds ([0 x i8], [0 x i8]* @.fatbin_image, i32 0, i32 0), i8* null }, section ".nvFatBinSegment", align 8
+// CUDA-NEXT: @__dummy.cuda_offloading.entry = hidden constant [0 x %__tgt_offload_entry] zeroinitializer, section "cuda_offloading_entries"
+// CUDA-NEXT: @.cuda.binary_handle = internal global i8** null
+// CUDA-NEXT: @__start_cuda_offloading_entries = external hidden constant [0 x %__tgt_offload_entry]
+// CUDA-NEXT: @__stop_cuda_offloading_entries = external hidden constant [0 x %__tgt_offload_entry]
+// CUDA-NEXT: @llvm.global_ctors = appending global [1 x { i32, void ()*, i8* }] [{ i32, void ()*, i8* } { i32 1, void ()* @.cuda.fatbin_reg, i8* null }]
+
+// CUDA: define internal void @.cuda.fatbin_reg() section ".text.startup" {
+// CUDA-NEXT: entry:
+// CUDA-NEXT:   %0 = call i8** @__cudaRegisterFatBinary(i8* bitcast (%fatbin_wrapper* @.fatbin_wrapper to i8*))
+// CUDA-NEXT:   store i8** %0, i8*** @.cuda.binary_handle, align 8
+// CUDA-NEXT:   call void @.cuda.globals_reg(i8** %0)
+// CUDA-NEXT:   call void @__cudaRegisterFatBinaryEnd(i8** %0)
+// CUDA-NEXT:   %1 = call i32 @atexit(void ()* @.cuda.fatbin_unreg)
+// CUDA-NEXT:   ret void
+// CUDA-NEXT: }
+
+// CUDA: define internal void @.cuda.fatbin_unreg() section ".text.startup" {
+// CUDA-NEXT: entry:
+// CUDA-NEXT:   %0 = load i8**, i8*** @.cuda.binary_handle, align 8
+// CUDA-NEXT:   call void @__cudaUnregisterFatBinary(i8** %0)
+// CUDA-NEXT:   ret void
+// CUDA-NEXT: }
+
+// CUDA: define internal void @.cuda.globals_reg(i8** %0) section ".text.startup" {
+// CUDA-NEXT: entry:
+// CUDA-NEXT:   br i1 icmp ne ([0 x %__tgt_offload_entry]* @__start_cuda_offloading_entries, [0 x %__tgt_offload_entry]* @__stop_cuda_offloading_entries), label %while.entry, label %while.end
+
+// CUDA: while.entry:
+// CUDA-NEXT:   %entry1 = phi %__tgt_offload_entry* [ getelementptr inbounds ([0 x %__tgt_offload_entry], [0 x %__tgt_offload_entry]* @__start_cuda_offloading_entries, i64 0, i64 0), %entry ], [ %7, %if.end ]
+// CUDA-NEXT:   %1 = getelementptr inbounds %__tgt_offload_entry, %__tgt_offload_entry* %entry1, i64 0, i32 0
+// CUDA-NEXT:   %addr = load i8*, i8** %1, align 8
+// CUDA-NEXT:   %2 = getelementptr inbounds %__tgt_offload_entry, %__tgt_offload_entry* %entry1, i64 0, i32 1
+// CUDA-NEXT:   %name = load i8*, i8** %2, align 8
+// CUDA-NEXT:   %3 = getelementptr inbounds %__tgt_offload_entry, %__tgt_offload_entry* %entry1, i64 0, i32 2
+// CUDA-NEXT:   %size = load i64, i64* %3, align 4
+// CUDA-NEXT:   %4 = icmp eq i64 %size, 0
+// CUDA-NEXT:   br i1 %4, label %if.then, label %if.else
+
+// CUDA: if.then:
+// CUDA-NEXT:   %5 = call i32 @__cudaRegisterFunction(i8** %0, i8* %addr, i8* %name, i8* %name, i32 -1, i8* null, i8* null, i8* null, i8* null, i32* null)
+// CUDA-NEXT:   br label %if.end
+
+// CUDA: if.else:
+// CUDA-NEXT:   %6 = call i32 @__cudaRegisterVar(i8** %0, i8* %addr, i8* %name, i8* %name, i32 0, i64 %size, i32 0, i32 0)
+// CUDA-NEXT:   br label %if.end
+
+// CUDA: if.end:
+// CUDA-NEXT:   %7 = getelementptr inbounds %__tgt_offload_entry, %__tgt_offload_entry* %entry1, i64 1
+// CUDA-NEXT:   %8 = icmp eq %__tgt_offload_entry* %7, getelementptr inbounds ([0 x %__tgt_offload_entry], [0 x %__tgt_offload_entry]* @__stop_cuda_offloading_entries, i64 0, i64 0)
+// CUDA-NEXT:   br i1 %8, label %while.end, label %while.entry
+
+// CUDA: while.end:
+// CUDA-NEXT:   ret void
+// CUDA-NEXT: }

diff  --git a/clang/test/Driver/linker-wrapper.c b/clang/test/Driver/linker-wrapper.c
index 342822ff46637..3e126f0e7484d 100644
--- a/clang/test/Driver/linker-wrapper.c
+++ b/clang/test/Driver/linker-wrapper.c
@@ -60,3 +60,14 @@
 
 // STATIC-LIBRARY: nvlink{{.*}} -arch sm_70
 // STATIC-LIBRARY-NOT: nvlink{{.*}} -arch sm_50
+
+// RUN: %clang -cc1 %s -triple x86_64-unknown-linux-gnu -emit-obj -o %t.o \
+// RUN:   -fembed-offload-object=%S/Inputs/dummy-elf.o,cuda,nvptx64-nvida-cuda,sm_70 \
+// RUN:   -fembed-offload-object=%S/Inputs/dummy-elf.o,openmp,nvptx64-nvida-cuda,sm_70 \
+// RUN:   -fembed-offload-object=%S/Inputs/dummy-elf.o,cuda,nvptx64-nvida-cuda,sm_52
+// RUN: clang-linker-wrapper --dry-run --host-triple x86_64-unknown-linux-gnu -linker-path \
+// RUN:   /usr/bin/ld -- %t.o -o a.out 2>&1 | FileCheck %s --check-prefix=CUDA
+
+// CUDA: nvlink{{.*}}-m64 -o {{.*}}.out -arch sm_70 {{.*}}.o {{.*}}.o
+// CUDA: nvlink{{.*}}-m64 -o {{.*}}.out -arch sm_52 {{.*}}.o
+// CUDA: fatbinary{{.*}}-64 --create {{.*}}.fatbin --image=profile=sm_70,file={{.*}}.out  --image=profile=sm_52,file={{.*}}.out

diff  --git a/clang/tools/clang-linker-wrapper/ClangLinkerWrapper.cpp b/clang/tools/clang-linker-wrapper/ClangLinkerWrapper.cpp
index 68e2f3dd7def5..670e1af943046 100644
--- a/clang/tools/clang-linker-wrapper/ClangLinkerWrapper.cpp
+++ b/clang/tools/clang-linker-wrapper/ClangLinkerWrapper.cpp
@@ -1513,7 +1513,16 @@ int main(int argc, const char **argv) {
   auto FileOrErr = wrapDeviceImages(LinkedImages);
   if (!FileOrErr)
     return reportError(FileOrErr.takeError());
-  LinkerArgs.append(*FileOrErr);
+
+  // We need to insert the new files next to the old ones to make sure they're
+  // linked with the same libraries / arguments.
+  if (!FileOrErr->empty()) {
+    auto FirstInput = std::next(llvm::find_if(LinkerArgs, [](StringRef Str) {
+      return sys::fs::exists(Str) && !sys::fs::is_directory(Str) &&
+             Str != ExecutableName;
+    }));
+    LinkerArgs.insert(FirstInput, FileOrErr->begin(), FileOrErr->end());
+  }
 
   // Run the host linking job.
   if (Error Err = runLinker(LinkerUserPath, LinkerArgs))

diff  --git a/clang/tools/clang-linker-wrapper/OffloadWrapper.cpp b/clang/tools/clang-linker-wrapper/OffloadWrapper.cpp
index f24a39a59d069..cc030580b3d12 100644
--- a/clang/tools/clang-linker-wrapper/OffloadWrapper.cpp
+++ b/clang/tools/clang-linker-wrapper/OffloadWrapper.cpp
@@ -20,6 +20,8 @@
 using namespace llvm;
 
 namespace {
+/// Magic number that begins the section containing the CUDA fatbinary.
+constexpr unsigned CudaFatMagic = 0x466243b1;
 
 IntegerType *getSizeTTy(Module &M) {
   LLVMContext &C = M.getContext();
@@ -255,6 +257,278 @@ void createUnregisterFunction(Module &M, GlobalVariable *BinDesc) {
   appendToGlobalDtors(M, Func, /*Priority*/ 1);
 }
 
+// struct fatbin_wrapper {
+//  int32_t magic;
+//  int32_t version;
+//  void *image;
+//  void *reserved;
+//};
+StructType *getFatbinWrapperTy(Module &M) {
+  LLVMContext &C = M.getContext();
+  StructType *FatbinTy = StructType::getTypeByName(C, "fatbin_wrapper");
+  if (!FatbinTy)
+    FatbinTy = StructType::create("fatbin_wrapper", Type::getInt32Ty(C),
+                                  Type::getInt32Ty(C), Type::getInt8PtrTy(C),
+                                  Type::getInt8PtrTy(C));
+  return FatbinTy;
+}
+
+/// Embed the image \p Image into the module \p M so it can be found by the
+/// runtime.
+GlobalVariable *createFatbinDesc(Module &M, ArrayRef<char> Image) {
+  LLVMContext &C = M.getContext();
+  llvm::Type *Int8PtrTy = Type::getInt8PtrTy(C);
+  llvm::Triple Triple = llvm::Triple(M.getTargetTriple());
+
+  // Create the global string containing the fatbinary.
+  StringRef FatbinConstantSection =
+      Triple.isMacOSX() ? "__NV_CUDA,__nv_fatbin" : ".nv_fatbin";
+  auto *Data = ConstantDataArray::get(C, Image);
+  auto *Fatbin = new GlobalVariable(M, Data->getType(), /*isConstant*/ true,
+                                    GlobalVariable::InternalLinkage, Data,
+                                    ".fatbin_image");
+  Fatbin->setSection(FatbinConstantSection);
+
+  // Create the fatbinary wrapper
+  StringRef FatbinWrapperSection =
+      Triple.isMacOSX() ? "__NV_CUDA,__fatbin" : ".nvFatBinSegment";
+  Constant *FatbinWrapper[] = {
+      ConstantInt::get(Type::getInt32Ty(C), CudaFatMagic),
+      ConstantInt::get(Type::getInt32Ty(C), 1),
+      ConstantExpr::getPointerBitCastOrAddrSpaceCast(Fatbin, Int8PtrTy),
+      ConstantPointerNull::get(Type::getInt8PtrTy(C))};
+
+  Constant *FatbinInitializer =
+      ConstantStruct::get(getFatbinWrapperTy(M), FatbinWrapper);
+
+  auto *FatbinDesc =
+      new GlobalVariable(M, getFatbinWrapperTy(M),
+                         /*isConstant*/ true, GlobalValue::InternalLinkage,
+                         FatbinInitializer, ".fatbin_wrapper");
+  FatbinDesc->setSection(FatbinWrapperSection);
+  FatbinDesc->setAlignment(Align(8));
+
+  // We create a dummy entry to ensure the linker will define the begin / end
+  // symbols. The CUDA runtime should ignore the null address if we attempt to
+  // register it.
+  auto *DummyInit =
+      ConstantAggregateZero::get(ArrayType::get(getEntryTy(M), 0u));
+  auto *DummyEntry = new GlobalVariable(
+      M, DummyInit->getType(), true, GlobalVariable::ExternalLinkage, DummyInit,
+      "__dummy.cuda_offloading.entry");
+  DummyEntry->setSection("cuda_offloading_entries");
+  DummyEntry->setVisibility(GlobalValue::HiddenVisibility);
+
+  return FatbinDesc;
+}
+
+/// Create the register globals function. We will iterate all of the offloading
+/// entries stored at the begin / end symbols and register them according to
+/// their type. This creates the following function in IR:
+///
+/// extern struct __tgt_offload_entry __start_cuda_offloading_entries;
+/// extern struct __tgt_offload_entry __stop_cuda_offloading_entries;
+///
+/// extern void __cudaRegisterFunction(void **, void *, void *, void *, int,
+///                                    void *, void *, void *, void *, int *);
+/// extern void __cudaRegisterVar(void **, void *, void *, void *, int32_t,
+///                               int64_t, int32_t, int32_t);
+///
+/// void __cudaRegisterTest(void **fatbinHandle) {
+///   for (struct __tgt_offload_entry *entry = &__start_cuda_offloading_entries;
+///        entry != &__stop_cuda_offloading_entries; ++entry) {
+///     if (!entry->size)
+///       __cudaRegisterFunction(fatbinHandle, entry->addr, entry->name,
+///                              entry->name, -1, 0, 0, 0, 0, 0);
+///     else
+///       __cudaRegisterVar(fatbinHandle, entry->addr, entry->name, entry->name,
+///                         0, entry->size, 0, 0);
+///   }
+/// }
+///
+/// TODO: This only registers functions and variables. Additional support is
+///       required for texture / surface / managed variables.
+Function *createRegisterGlobalsFunction(Module &M) {
+  LLVMContext &C = M.getContext();
+  // Get the __cudaRegisterFunction function declaration.
+  auto *RegFuncTy = FunctionType::get(
+      Type::getInt32Ty(C),
+      {Type::getInt8PtrTy(C)->getPointerTo(), Type::getInt8PtrTy(C),
+       Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt32Ty(C),
+       Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt8PtrTy(C),
+       Type::getInt8PtrTy(C), Type::getInt32PtrTy(C)},
+      /*isVarArg*/ false);
+  FunctionCallee RegFunc =
+      M.getOrInsertFunction("__cudaRegisterFunction", RegFuncTy);
+
+  // Get the __cudaRegisterVar function declaration.
+  auto *RegVarTy = FunctionType::get(
+      Type::getInt32Ty(C),
+      {Type::getInt8PtrTy(C)->getPointerTo(), Type::getInt8PtrTy(C),
+       Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt32Ty(C),
+       getSizeTTy(M), Type::getInt32Ty(C), Type::getInt32Ty(C)},
+      /*isVarArg*/ false);
+  FunctionCallee RegVar = M.getOrInsertFunction("__cudaRegisterVar", RegVarTy);
+
+  // Create the references to the start / stop symbols defined by the linker.
+  auto *EntriesB = new GlobalVariable(
+      M, ArrayType::get(getEntryTy(M), 0), /*isConstant*/ true,
+      GlobalValue::ExternalLinkage,
+      /*Initializer*/ nullptr, "__start_cuda_offloading_entries");
+  EntriesB->setVisibility(GlobalValue::HiddenVisibility);
+  auto *EntriesE = new GlobalVariable(
+      M, ArrayType::get(getEntryTy(M), 0), /*isConstant*/ true,
+      GlobalValue::ExternalLinkage,
+      /*Initializer*/ nullptr, "__stop_cuda_offloading_entries");
+  EntriesE->setVisibility(GlobalValue::HiddenVisibility);
+
+  auto *RegGlobalsTy = FunctionType::get(Type::getVoidTy(C),
+                                         Type::getInt8PtrTy(C)->getPointerTo(),
+                                         /*isVarArg*/ false);
+  auto *RegGlobalsFn = Function::Create(
+      RegGlobalsTy, GlobalValue::InternalLinkage, ".cuda.globals_reg", &M);
+  RegGlobalsFn->setSection(".text.startup");
+
+  // Create the loop to register all the entries.
+  IRBuilder<> Builder(BasicBlock::Create(C, "entry", RegGlobalsFn));
+  auto *EntryBB = BasicBlock::Create(C, "while.entry", RegGlobalsFn);
+  auto *IfThenBB = BasicBlock::Create(C, "if.then", RegGlobalsFn);
+  auto *IfElseBB = BasicBlock::Create(C, "if.else", RegGlobalsFn);
+  auto *IfEndBB = BasicBlock::Create(C, "if.end", RegGlobalsFn);
+  auto *ExitBB = BasicBlock::Create(C, "while.end", RegGlobalsFn);
+
+  auto *EntryCmp = Builder.CreateICmpNE(EntriesB, EntriesE);
+  Builder.CreateCondBr(EntryCmp, EntryBB, ExitBB);
+  Builder.SetInsertPoint(EntryBB);
+  auto *Entry = Builder.CreatePHI(getEntryPtrTy(M), 2, "entry");
+  auto *AddrPtr =
+      Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
+                                {ConstantInt::get(getSizeTTy(M), 0),
+                                 ConstantInt::get(Type::getInt32Ty(C), 0)});
+  auto *Addr = Builder.CreateLoad(Type::getInt8PtrTy(C), AddrPtr, "addr");
+  auto *NamePtr =
+      Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
+                                {ConstantInt::get(getSizeTTy(M), 0),
+                                 ConstantInt::get(Type::getInt32Ty(C), 1)});
+  auto *Name = Builder.CreateLoad(Type::getInt8PtrTy(C), NamePtr, "name");
+  auto *SizePtr =
+      Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
+                                {ConstantInt::get(getSizeTTy(M), 0),
+                                 ConstantInt::get(Type::getInt32Ty(C), 2)});
+  auto *Size = Builder.CreateLoad(getSizeTTy(M), SizePtr, "size");
+  auto *FnCond =
+      Builder.CreateICmpEQ(Size, ConstantInt::getNullValue(getSizeTTy(M)));
+  Builder.CreateCondBr(FnCond, IfThenBB, IfElseBB);
+  Builder.SetInsertPoint(IfThenBB);
+  Builder.CreateCall(RegFunc,
+                     {RegGlobalsFn->arg_begin(), Addr, Name, Name,
+                      ConstantInt::get(Type::getInt32Ty(C), -1),
+                      ConstantPointerNull::get(Type::getInt8PtrTy(C)),
+                      ConstantPointerNull::get(Type::getInt8PtrTy(C)),
+                      ConstantPointerNull::get(Type::getInt8PtrTy(C)),
+                      ConstantPointerNull::get(Type::getInt8PtrTy(C)),
+                      ConstantPointerNull::get(Type::getInt32PtrTy(C))});
+  Builder.CreateBr(IfEndBB);
+  Builder.SetInsertPoint(IfElseBB);
+  Builder.CreateCall(RegVar, {RegGlobalsFn->arg_begin(), Addr, Name, Name,
+                              ConstantInt::get(Type::getInt32Ty(C), 0), Size,
+                              ConstantInt::get(Type::getInt32Ty(C), 0),
+                              ConstantInt::get(Type::getInt32Ty(C), 0)});
+  Builder.CreateBr(IfEndBB);
+  Builder.SetInsertPoint(IfEndBB);
+  auto *NewEntry = Builder.CreateInBoundsGEP(
+      getEntryTy(M), Entry, ConstantInt::get(getSizeTTy(M), 1));
+  auto *Cmp = Builder.CreateICmpEQ(
+      NewEntry,
+      ConstantExpr::getInBoundsGetElementPtr(
+          ArrayType::get(getEntryTy(M), 0), EntriesE,
+          ArrayRef<Constant *>({ConstantInt::get(getSizeTTy(M), 0),
+                                ConstantInt::get(getSizeTTy(M), 0)})));
+  Entry->addIncoming(
+      ConstantExpr::getInBoundsGetElementPtr(
+          ArrayType::get(getEntryTy(M), 0), EntriesB,
+          ArrayRef<Constant *>({ConstantInt::get(getSizeTTy(M), 0),
+                                ConstantInt::get(getSizeTTy(M), 0)})),
+      &RegGlobalsFn->getEntryBlock());
+  Entry->addIncoming(NewEntry, IfEndBB);
+  Builder.CreateCondBr(Cmp, ExitBB, EntryBB);
+  Builder.SetInsertPoint(ExitBB);
+  Builder.CreateRetVoid();
+
+  return RegGlobalsFn;
+}
+
+// Create the constructor and destructor to register the fatbinary with the CUDA
+// runtime.
+void createRegisterFatbinFunction(Module &M, GlobalVariable *FatbinDesc) {
+  LLVMContext &C = M.getContext();
+  auto *CtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
+  auto *CtorFunc = Function::Create(CtorFuncTy, GlobalValue::InternalLinkage,
+                                    ".cuda.fatbin_reg", &M);
+  CtorFunc->setSection(".text.startup");
+
+  auto *DtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
+  auto *DtorFunc = Function::Create(DtorFuncTy, GlobalValue::InternalLinkage,
+                                    ".cuda.fatbin_unreg", &M);
+  DtorFunc->setSection(".text.startup");
+
+  // Get the __cudaRegisterFatBinary function declaration.
+  auto *RegFatTy = FunctionType::get(Type::getInt8PtrTy(C)->getPointerTo(),
+                                     Type::getInt8PtrTy(C),
+                                     /*isVarArg*/ false);
+  FunctionCallee RegFatbin =
+      M.getOrInsertFunction("__cudaRegisterFatBinary", RegFatTy);
+  // Get the __cudaRegisterFatBinaryEnd function declaration.
+  auto *RegFatEndTy = FunctionType::get(Type::getVoidTy(C),
+                                        Type::getInt8PtrTy(C)->getPointerTo(),
+                                        /*isVarArg*/ false);
+  FunctionCallee RegFatbinEnd =
+      M.getOrInsertFunction("__cudaRegisterFatBinaryEnd", RegFatEndTy);
+  // Get the __cudaUnregisterFatBinary function declaration.
+  auto *UnregFatTy = FunctionType::get(Type::getVoidTy(C),
+                                       Type::getInt8PtrTy(C)->getPointerTo(),
+                                       /*isVarArg*/ false);
+  FunctionCallee UnregFatbin =
+      M.getOrInsertFunction("__cudaUnregisterFatBinary", UnregFatTy);
+
+  auto *AtExitTy =
+      FunctionType::get(Type::getInt32Ty(C), DtorFuncTy->getPointerTo(),
+                        /*isVarArg*/ false);
+  FunctionCallee AtExit = M.getOrInsertFunction("atexit", AtExitTy);
+
+  auto *BinaryHandleGlobal = new llvm::GlobalVariable(
+      M, Type::getInt8PtrTy(C)->getPointerTo(), false,
+      llvm::GlobalValue::InternalLinkage,
+      llvm::ConstantPointerNull::get(Type::getInt8PtrTy(C)->getPointerTo()),
+      ".cuda.binary_handle");
+
+  // Create the constructor to register this image with the runtime.
+  IRBuilder<> CtorBuilder(BasicBlock::Create(C, "entry", CtorFunc));
+  CallInst *Handle = CtorBuilder.CreateCall(
+      RegFatbin, ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+                     FatbinDesc, Type::getInt8PtrTy(C)));
+  CtorBuilder.CreateAlignedStore(
+      Handle, BinaryHandleGlobal,
+      Align(M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))));
+  CtorBuilder.CreateCall(createRegisterGlobalsFunction(M), Handle);
+  CtorBuilder.CreateCall(RegFatbinEnd, Handle);
+  CtorBuilder.CreateCall(AtExit, DtorFunc);
+  CtorBuilder.CreateRetVoid();
+
+  // Create the destructor to unregister the image with the runtime. We cannot
+  // use a standard global destructor after CUDA 9.2 so this must be called by
+  // `atexit()` instead.
+  IRBuilder<> DtorBuilder(BasicBlock::Create(C, "entry", DtorFunc));
+  LoadInst *BinaryHandle = DtorBuilder.CreateAlignedLoad(
+      Type::getInt8PtrTy(C)->getPointerTo(), BinaryHandleGlobal,
+      Align(M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))));
+  DtorBuilder.CreateCall(UnregFatbin, BinaryHandle);
+  DtorBuilder.CreateRetVoid();
+
+  // Add this function to constructors.
+  appendToGlobalCtors(M, CtorFunc, /*Priority*/ 1);
+}
+
 } // namespace
 
 Error wrapOpenMPBinaries(Module &M, ArrayRef<ArrayRef<char>> Images) {
@@ -267,7 +541,12 @@ Error wrapOpenMPBinaries(Module &M, ArrayRef<ArrayRef<char>> Images) {
   return Error::success();
 }
 
-llvm::Error wrapCudaBinary(llvm::Module &M, llvm::ArrayRef<char> Images) {
-  // TODO: Implement this.
+Error wrapCudaBinary(Module &M, ArrayRef<char> Image) {
+  GlobalVariable *Desc = createFatbinDesc(M, Image);
+  if (!Desc)
+    return createStringError(inconvertibleErrorCode(),
+                             "No fatinbary section created.");
+
+  createRegisterFatbinFunction(M, Desc);
   return Error::success();
 }


        


More information about the cfe-commits mailing list